suanPan
SparseMatSuperLU.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2024 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef SPARSEMATSUPERLU_HPP
31 #define SPARSEMATSUPERLU_HPP
32 
33 #include <superlu-mt/superlu-mt.h>
34 #include "SparseMat.hpp"
35 #include "csc_form.hpp"
36 
37 template<sp_d T> class SparseMatSuperLU final : public SparseMat<T> {
38  SuperMatrix A{}, L{}, U{}, B{};
39 
40 #ifndef SUANPAN_SUPERLUMT
41  superlu_options_t options{};
42 
43  SuperLUStat_t stat{};
44 #else
45  const int ordering_num = 1;
46 
47  Gstat_t stat{};
48 #endif
49 
50  void* t_val = nullptr;
51  int* t_row = nullptr;
52  int* t_col = nullptr;
53 
54  int* perm_r = nullptr;
55  int* perm_c = nullptr;
56 
57  bool allocated = false;
58 
59  template<sp_d ET> void alloc(csc_form<ET, int>&&);
60  void dealloc();
61 
62  template<sp_d ET> void wrap_b(const Mat<ET>&);
63  template<sp_d ET> void tri_solve(int&);
64  template<sp_d ET> void full_solve(int&);
65 
66  int solve_trs(Mat<T>&, Mat<T>&&);
67 
68 protected:
69  int direct_solve(Mat<T>& out_mat, const Mat<T>& in_mat) override { return this->direct_solve(out_mat, Mat<T>(in_mat)); }
70 
71  int direct_solve(Mat<T>&, Mat<T>&&) override;
72 
73 public:
74  SparseMatSuperLU(uword, uword, uword = 0);
76  SparseMatSuperLU(SparseMatSuperLU&&) noexcept = delete;
77  SparseMatSuperLU& operator=(const SparseMatSuperLU&) = delete;
78  SparseMatSuperLU& operator=(SparseMatSuperLU&&) noexcept = delete;
79  ~SparseMatSuperLU() override;
80 
81  void zeros() override;
82 
83  unique_ptr<MetaMat<T>> make_copy() override;
84 };
85 
86 template<sp_d T> template<sp_d ET> void SparseMatSuperLU<T>::alloc(csc_form<ET, int>&& in) {
87  dealloc();
88 
89  auto t_size = sizeof(ET) * in.n_elem;
90  t_val = superlu_malloc(t_size);
91  memcpy(t_val, (void*)in.val_mem(), t_size);
92 
93  t_size = sizeof(int) * in.n_elem;
94  t_row = (int*)superlu_malloc(t_size);
95  memcpy(t_row, (void*)in.row_mem(), t_size);
96 
97  t_size = sizeof(int) * (in.n_cols + 1llu);
98  t_col = (int*)superlu_malloc(t_size);
99  memcpy(t_col, (void*)in.col_mem(), t_size);
100 
101  if constexpr(std::is_same_v<ET, double>) {
102  using E = double;
103  dCreate_CompCol_Matrix(&A, in.n_rows, in.n_cols, in.n_elem, (E*)t_val, t_row, t_col, Stype_t::SLU_NC, Dtype_t::SLU_D, Mtype_t::SLU_GE);
104  }
105  else {
106  using E = float;
107  sCreate_CompCol_Matrix(&A, in.n_rows, in.n_cols, in.n_elem, (E*)t_val, t_row, t_col, Stype_t::SLU_NC, Dtype_t::SLU_S, Mtype_t::SLU_GE);
108  }
109 
110  perm_r = (int*)superlu_malloc(sizeof(int) * (this->n_rows + 1));
111  perm_c = (int*)superlu_malloc(sizeof(int) * (this->n_cols + 1));
112 
113  allocated = true;
114 }
115 
116 template<sp_d T> void SparseMatSuperLU<T>::dealloc() {
117  if(!allocated) return;
118 
119  Destroy_SuperMatrix_Store(&A);
120 #ifdef SUANPAN_SUPERLUMT
121  Destroy_SuperNode_SCP(&L);
122  Destroy_CompCol_NCP(&U);
123 #else
124  Destroy_SuperNode_Matrix(&L);
125  Destroy_CompCol_Matrix(&U);
126 #endif
127 
128  if(t_val) superlu_free(t_val);
129  if(t_row) superlu_free(t_row);
130  if(t_col) superlu_free(t_col);
131  if(perm_r) superlu_free(perm_r);
132  if(perm_c) superlu_free(perm_c);
133 
134  allocated = false;
135 }
136 
137 template<sp_d T> template<sp_d ET> void SparseMatSuperLU<T>::wrap_b(const Mat<ET>& in_mat) {
138  if constexpr(std::is_same_v<ET, float>) {
139  using E = float;
140  sCreate_Dense_Matrix(&B, (int)in_mat.n_rows, (int)in_mat.n_cols, (E*)in_mat.memptr(), (int)in_mat.n_rows, Stype_t::SLU_DN, Dtype_t::SLU_S, Mtype_t::SLU_GE);
141  }
142  else {
143  using E = double;
144  dCreate_Dense_Matrix(&B, (int)in_mat.n_rows, (int)in_mat.n_cols, (E*)in_mat.memptr(), (int)in_mat.n_rows, Stype_t::SLU_DN, Dtype_t::SLU_D, Mtype_t::SLU_GE);
145  }
146 }
147 
148 template<sp_d T> template<sp_d ET> void SparseMatSuperLU<T>::tri_solve(int& flag) {
149 #ifdef SUANPAN_SUPERLUMT
150  if(std::is_same_v<ET, float>) sgstrs(NOTRANS, &L, &U, perm_c, perm_r, &B, &stat, &flag);
151  else dgstrs(NOTRANS, &L, &U, perm_c, perm_r, &B, &stat, &flag);
152 #else
153  superlu::gstrs<ET>(options.Trans, &L, &U, perm_c, perm_r, &B, &stat, &flag);
154 #endif
155 
156  Destroy_SuperMatrix_Store(&B);
157 }
158 
159 template<sp_d T> template<sp_d ET> void SparseMatSuperLU<T>::full_solve(int& flag) {
160 #ifdef SUANPAN_SUPERLUMT
161  get_perm_c(ordering_num, &A, perm_c);
162  if(std::is_same_v<ET, float>) psgssv(SUANPAN_NUM_THREADS, &A, perm_c, perm_r, &L, &U, &B, &flag);
163  else pdgssv(SUANPAN_NUM_THREADS, &A, perm_c, perm_r, &L, &U, &B, &flag);
164 #else
165  superlu::gssv<ET>(&options, &A, perm_c, perm_r, &L, &U, &B, &stat, &flag);
166 #endif
167 
168  Destroy_SuperMatrix_Store(&B);
169 }
170 
171 template<sp_d T> SparseMatSuperLU<T>::SparseMatSuperLU(const uword in_row, const uword in_col, const uword in_elem)
172  : SparseMat<T>(in_row, in_col, in_elem) {
173 #ifndef SUANPAN_SUPERLUMT
174  set_default_options(&options);
175  options.IterRefine = std::is_same_v<T, float> ? superlu::IterRefine_t::SLU_SINGLE : superlu::IterRefine_t::SLU_DOUBLE;
176  options.Equil = superlu::yes_no_t::NO;
177 
178  arrayops::fill_zeros(reinterpret_cast<char*>(&stat), sizeof(SuperLUStat_t));
179 
180  StatInit(&stat);
181 #else
182  StatAlloc(static_cast<int>(in_col), SUANPAN_NUM_THREADS, sp_ienv(1), sp_ienv(2), &stat);
183  StatInit(static_cast<int>(in_col), SUANPAN_NUM_THREADS, &stat);
184 #endif
185 }
186 
188  : SparseMat<T>(other) {
189 #ifndef SUANPAN_SUPERLUMT
190  set_default_options(&options);
191  options.IterRefine = std::is_same_v<T, float> ? superlu::IterRefine_t::SLU_SINGLE : superlu::IterRefine_t::SLU_DOUBLE;
192  options.Equil = superlu::yes_no_t::NO;
193 
194  arrayops::fill_zeros(reinterpret_cast<char*>(&stat), sizeof(SuperLUStat_t));
195 
196  StatInit(&stat);
197 #else
198  StatAlloc(static_cast<int>(other.n_cols), SUANPAN_NUM_THREADS, sp_ienv(1), sp_ienv(2), &stat);
199  StatInit(static_cast<int>(other.n_cols), SUANPAN_NUM_THREADS, &stat);
200 #endif
201 }
202 
204  dealloc();
205  StatFree(&stat);
206 }
207 
208 template<sp_d T> void SparseMatSuperLU<T>::zeros() {
210  dealloc();
211 }
212 
213 template<sp_d T> unique_ptr<MetaMat<T>> SparseMatSuperLU<T>::make_copy() { return std::make_unique<SparseMatSuperLU>(*this); }
214 
215 template<sp_d T> int SparseMatSuperLU<T>::direct_solve(Mat<T>& out_mat, Mat<T>&& in_mat) {
216  if(this->factored) return solve_trs(out_mat, std::forward<Mat<T>>(in_mat));
217 
218  this->factored = true;
219 
220  auto flag = 0;
221 
222  if constexpr(std::is_same_v<T, float>) {
223  alloc(csc_form<float, int>(this->triplet_mat));
224 
225  wrap_b(in_mat);
226 
227  full_solve<float>(flag);
228 
229  out_mat = std::move(in_mat);
230  }
231  else if(Precision::FULL == this->setting.precision) {
232  alloc(csc_form<double, int>(this->triplet_mat));
233 
234  wrap_b(in_mat);
235 
236  full_solve<double>(flag);
237 
238  out_mat = std::move(in_mat);
239  }
240  else {
241  alloc(csc_form<float, int>(this->triplet_mat));
242 
243  const fmat f_mat(arma::size(in_mat), fill::none);
244 
245  wrap_b(f_mat);
246 
247  full_solve<float>(flag);
248 
249  if(0 == flag) flag = solve_trs(out_mat, std::forward<Mat<T>>(in_mat));
250  }
251 
252  return flag;
253 }
254 
255 template<sp_d T> int SparseMatSuperLU<T>::solve_trs(Mat<T>& out_mat, Mat<T>&& in_mat) {
256  auto flag = 0;
257 
258  if constexpr(std::is_same_v<T, float>) {
259  wrap_b(in_mat);
260 
261  tri_solve<float>(flag);
262 
263  out_mat = std::move(in_mat);
264  }
265  else if(Precision::FULL == this->setting.precision) {
266  wrap_b(in_mat);
267 
268  tri_solve<double>(flag);
269 
270  out_mat = std::move(in_mat);
271  }
272  else {
273  out_mat.zeros(arma::size(in_mat));
274 
275  auto multiplier = arma::norm(in_mat);
276 
277  auto counter = 0u;
278  while(counter++ < this->setting.iterative_refinement) {
279  if(multiplier < this->setting.tolerance) break;
280 
281  auto residual = conv_to<fmat>::from(in_mat / multiplier);
282 
283  wrap_b(residual);
284 
285  tri_solve<float>(flag);
286 
287  if(0 != flag) break;
288 
289  const mat incre = multiplier * conv_to<mat>::from(residual);
290 
291  out_mat += incre;
292 
293  suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(in_mat -= this->operator*(incre)));
294  }
295  }
296 
297  return flag;
298 }
299 #endif
300 
A MetaMat class that holds matrices.
Definition: MetaMat.hpp:72
const uword n_cols
Definition: MetaMat.hpp:119
const uword n_rows
Definition: MetaMat.hpp:118
A SparseMat class that holds matrices.
Definition: SparseMat.hpp:34
void zeros() override
Definition: SparseMat.hpp:46
A SparseMatSuperLU class that holds matrices.
Definition: SparseMatSuperLU.hpp:37
SparseMatSuperLU(SparseMatSuperLU &&) noexcept=delete
int direct_solve(Mat< T > &out_mat, const Mat< T > &in_mat) override
Definition: SparseMatSuperLU.hpp:69
Definition: csc_form.hpp:25
int SUANPAN_NUM_THREADS
Definition: command.cpp:71
~SparseMatSuperLU() override
Definition: SparseMatSuperLU.hpp:203
unique_ptr< MetaMat< T > > make_copy() override
Definition: SparseMatSuperLU.hpp:213
void zeros() override
Definition: SparseMatSuperLU.hpp:208
SparseMatSuperLU(uword, uword, uword=0)
Definition: SparseMatSuperLU.hpp:171
double norm(const vec &)
Definition: tensor.cpp:370
concept sp_d
Definition: suanPan.h:330
#define suanpan_debug(...)
Definition: suanPan.h:307