suanPan
MetaMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2024 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 #ifndef METAMAT_HPP
30 #define METAMAT_HPP
31 
32 #include "triplet_form.hpp"
33 #include "IterativeSolver.hpp"
34 #include "ILU.hpp"
35 #include "Jacobi.hpp"
36 
37 template<typename T, typename U> concept ArmaContainer = std::is_floating_point_v<U> && (std::is_convertible_v<T, Mat<U>> || std::is_convertible_v<T, SpMat<U>>) ;
38 
39 template<sp_d T> class MetaMat;
40 
41 template<sp_d T> class op_add {
42  friend MetaMat<T>;
43 
44  shared_ptr<MetaMat<T>> mat_a, mat_b;
45 
46 public:
47  explicit op_add(const shared_ptr<MetaMat<T>>& A)
48  : mat_a(A)
49  , mat_b(nullptr) {}
50 
51  op_add(const shared_ptr<MetaMat<T>>& A, const shared_ptr<MetaMat<T>>& B)
52  : mat_a(A)
53  , mat_b(B) {}
54 };
55 
56 template<sp_d T> class op_scale {
57  friend MetaMat<T>;
58 
59  T scalar;
60  op_add<T> bracket;
61 
62 public:
63  op_scale(const T A, const shared_ptr<MetaMat<T>>& B)
64  : scalar(A)
65  , bracket(B) {}
66 
67  op_scale(const T A, op_add<T>&& B)
68  : scalar(A)
69  , bracket(std::forward<op_add<T>>(B)) {}
70 };
71 
72 template<sp_d T> class MetaMat {
73 protected:
74  bool factored = false;
75 
77 
78  virtual int direct_solve(Mat<T>&, const Mat<T>&) = 0;
79 
80  virtual int direct_solve(Mat<T>&, Mat<T>&&) = 0;
81 
82  int direct_solve(Mat<T>& X, const SpMat<T>& B) { return this->direct_solve(X, Mat<T>(B)); }
83 
84  int direct_solve(Mat<T>& X, SpMat<T>&& B) { return this->direct_solve(X, B); }
85 
86  int iterative_solve(Mat<T>&, const Mat<T>&);
87 
88  int iterative_solve(Mat<T>& X, const SpMat<T>& B) { return this->iterative_solve(X, Mat<T>(B)); }
89 
90  template<std::invocable<fmat&> F> int mixed_trs(mat& X, mat&& B, F trs) {
91  auto INFO = 0;
92 
93  X = arma::zeros(size(B));
94 
95  auto multiplier = norm(B);
96 
97  auto counter = 0u;
98  while(counter++ < this->setting.iterative_refinement) {
99  if(multiplier < this->setting.tolerance) break;
100 
101  auto residual = conv_to<fmat>::from(B / multiplier);
102 
103  if(0 != (INFO = trs(residual))) break;
104 
105  const mat incre = multiplier * conv_to<mat>::from(residual);
106 
107  X += incre;
108 
109  suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->operator*(incre)));
110  }
111 
112  return INFO;
113  }
114 
115 public:
117 
118  const uword n_rows;
119  const uword n_cols;
120  const uword n_elem;
121 
122  MetaMat(const uword in_rows, const uword in_cols, const uword in_elem)
123  : triplet_mat(in_rows, in_cols)
124  , n_rows(in_rows)
125  , n_cols(in_cols)
126  , n_elem(in_elem) {}
127 
128  MetaMat(const MetaMat&) = default;
129  MetaMat(MetaMat&&) noexcept = delete;
130  MetaMat& operator=(const MetaMat&) = delete;
131  MetaMat& operator=(MetaMat&&) noexcept = delete;
132  virtual ~MetaMat() = default;
133 
135 
136  [[nodiscard]] SolverSetting<T>& get_solver_setting() { return setting; }
137 
138  void set_factored(const bool F) { factored = F; }
139 
140  [[nodiscard]] virtual bool is_empty() const = 0;
141  virtual void zeros() = 0;
142 
143  virtual unique_ptr<MetaMat> make_copy() = 0;
144 
145  void unify(const uword K) {
146  this->nullify(K);
147  this->at(K, K) = T(1);
148  }
149 
150  virtual void nullify(uword) = 0;
151 
152  [[nodiscard]] virtual T max() const = 0;
153  [[nodiscard]] virtual Col<T> diag() const = 0;
154 
159  virtual T operator()(uword, uword) const = 0;
164  virtual T& unsafe_at(const uword I, const uword J) { return this->at(I, J); }
165 
170  virtual T& at(uword, uword) = 0;
171 
172  [[nodiscard]] virtual const T* memptr() const = 0;
173  virtual T* memptr() = 0;
174 
175  virtual void scale_accu(T, const shared_ptr<MetaMat>&) = 0;
176  virtual void scale_accu(T, const triplet_form<T, uword>&) = 0;
177 
178  void operator+=(const shared_ptr<MetaMat>& M) { return this->scale_accu(1., M); }
179 
180  void operator-=(const shared_ptr<MetaMat>& M) { return this->scale_accu(-1., M); }
181 
182  void operator+=(const op_scale<T>& M) {
183  const auto& bracket = M.bracket;
184  if(nullptr != bracket.mat_a) this->scale_accu(M.scalar, bracket.mat_a);
185  if(nullptr != bracket.mat_b) this->scale_accu(M.scalar, bracket.mat_b);
186  }
187 
188  void operator-=(const op_scale<T>& M) {
189  const auto& bracket = M.bracket;
190  if(nullptr != bracket.mat_a) this->scale_accu(-M.scalar, bracket.mat_a);
191  if(nullptr != bracket.mat_b) this->scale_accu(-M.scalar, bracket.mat_b);
192  }
193 
194  void operator+=(const triplet_form<T, uword>& M) { return this->scale_accu(1., M); }
195 
196  void operator-=(const triplet_form<T, uword>& M) { return this->scale_accu(-1., M); }
197 
198  virtual Mat<T> operator*(const Mat<T>&) const = 0;
199 
200  virtual void operator*=(T) = 0;
201 
202  template<ArmaContainer<T> C> int solve(Mat<T>& X, C&& B) { return IterativeSolver::NONE == this->setting.iterative_solver ? this->direct_solve(X, std::forward<C>(B)) : this->iterative_solve(X, std::forward<C>(B)); }
203 
204  template<ArmaContainer<T> C> Mat<T> solve(C&& B) {
205  Mat<T> X;
206 
207  if(SUANPAN_SUCCESS != this->solve(X, std::forward<C>(B))) throw std::runtime_error("fail to solve the system");
208 
209  return X;
210  }
211 
212  [[nodiscard]] virtual int sign_det() const = 0;
213 
214  void save(const char* name) {
215  if(!to_mat(*this).save(name, raw_ascii))
216  suanpan_error("Cannot save to file \"{}\".\n", name);
217  }
218 
219  virtual void csc_condense() {}
220 
221  virtual void csr_condense() {}
222 
223  [[nodiscard]] Col<T> evaluate(const Col<T>& X) const { return this->operator*(X); }
224 };
225 
226 template<sp_d T> int MetaMat<T>::iterative_solve(Mat<T>& X, const Mat<T>& B) {
227  this->csc_condense();
228 
229  X.zeros(arma::size(B));
230 
231  unique_ptr<Preconditioner<T>> preconditioner;
232  if(PreconditionerType::JACOBI == this->setting.preconditioner_type) preconditioner = std::make_unique<Jacobi<T>>(this->diag());
233 #ifndef SUANPAN_SUPERLUMT
234  else if(PreconditionerType::ILU == this->setting.preconditioner_type) {
235  if(this->triplet_mat.is_empty()) preconditioner = std::make_unique<ILU<T>>(to_triplet_form<T, int>(this));
236  else preconditioner = std::make_unique<ILU<T>>(this->triplet_mat);
237  }
238 #endif
239  else if(PreconditionerType::NONE == this->setting.preconditioner_type) preconditioner = std::make_unique<UnityPreconditioner<T>>();
240 
241  if(SUANPAN_SUCCESS != preconditioner->init()) return SUANPAN_FAIL;
242 
243  this->setting.preconditioner = preconditioner.get();
244 
245  std::atomic_int code = 0;
246 
247  if(IterativeSolver::GMRES == setting.iterative_solver)
248  suanpan::for_each(B.n_cols, [&](const uword I) {
249  Col<T> sub_x(X.colptr(I), X.n_rows, false, true);
250  const Col<T> sub_b(B.colptr(I), B.n_rows);
251  auto col_setting = setting;
252  code += GMRES(this, sub_x, sub_b, col_setting);
253  });
254  else if(IterativeSolver::BICGSTAB == setting.iterative_solver)
255  suanpan::for_each(B.n_cols, [&](const uword I) {
256  Col<T> sub_x(X.colptr(I), X.n_rows, false, true);
257  const Col<T> sub_b(B.colptr(I), B.n_rows);
258  auto col_setting = setting;
259  code += BiCGSTAB(this, sub_x, sub_b, col_setting);
260  });
261  else throw invalid_argument("no proper iterative solver assigned but somehow iterative solving is called");
262 
263  return 0 == code ? SUANPAN_SUCCESS : SUANPAN_FAIL;
264 }
265 
266 template<sp_d T> Mat<T> to_mat(const MetaMat<T>& in_mat) {
267  Mat<T> out_mat(in_mat.n_rows, in_mat.n_cols);
268  for(uword J = 0; J < in_mat.n_cols; ++J) for(uword I = 0; I < in_mat.n_rows; ++I) out_mat(I, J) = in_mat(I, J);
269  return out_mat;
270 }
271 
272 template<sp_d T> Mat<T> to_mat(const shared_ptr<MetaMat<T>>& in_mat) { return to_mat(*in_mat); }
273 
274 template<sp_d data_t, sp_i index_t> Mat<data_t> to_mat(const triplet_form<data_t, index_t>& in_mat) {
275  Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
276  for(index_t I = 0; I < in_mat.n_elem; ++I) out_mat(in_mat.row(I), in_mat.col(I)) += in_mat.val(I);
277  return out_mat;
278 }
279 
280 template<sp_d data_t, sp_i index_t> Mat<data_t> to_mat(const csr_form<data_t, index_t>& in_mat) {
281  Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
282 
283  index_t c_idx = 1;
284  for(index_t I = 0; I < in_mat.n_elem; ++I) {
285  if(I >= in_mat.row_mem()[c_idx]) ++c_idx;
286  out_mat(c_idx - 1, in_mat.col_mem()[I]) += in_mat.val_mem()[I];
287  }
288 
289  return out_mat;
290 }
291 
292 template<sp_d data_t, sp_i index_t> Mat<data_t> to_mat(const csc_form<data_t, index_t>& in_mat) {
293  Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
294 
295  index_t c_idx = 1;
296  for(index_t I = 0; I < in_mat.n_elem; ++I) {
297  if(I >= in_mat.col_mem()[c_idx]) ++c_idx;
298  out_mat(in_mat.row_mem()[I], c_idx - 1) += in_mat.val_mem()[I];
299  }
300 
301  return out_mat;
302 }
303 
304 template<sp_d data_t, sp_i index_t> triplet_form<data_t, index_t> to_triplet_form(MetaMat<data_t>* in_mat) {
305  if(!in_mat->triplet_mat.is_empty()) return triplet_form<data_t, index_t>(in_mat->triplet_mat);
306 
307  const sp_i auto n_rows = index_t(in_mat->n_rows);
308  const sp_i auto n_cols = index_t(in_mat->n_cols);
309  const sp_i auto n_elem = index_t(in_mat->n_elem);
310 
311  triplet_form<data_t, index_t> out_mat(n_rows, n_cols, n_elem);
312  for(index_t J = 0; J < n_cols; ++J) for(index_t I = 0; I < n_rows; ++I) out_mat.at(I, J) = in_mat->operator()(I, J);
313 
314  return out_mat;
315 }
316 
317 template<sp_d data_t, sp_i index_t> triplet_form<data_t, index_t> to_triplet_form(const shared_ptr<MetaMat<data_t>>& in_mat) { return to_triplet_form<data_t, index_t>(in_mat.get()); }
318 
319 #endif
320 
An ILU class.
Definition: ILU.hpp:40
A MetaMat class that holds matrices.
Definition: MetaMat.hpp:72
Col< T > evaluate(const Col< T > &X) const
Definition: MetaMat.hpp:223
triplet_form< T, uword > triplet_mat
Definition: MetaMat.hpp:116
MetaMat(const MetaMat &)=default
virtual T & unsafe_at(const uword I, const uword J)
Access element without bounds checking.
Definition: MetaMat.hpp:164
int direct_solve(Mat< T > &X, SpMat< T > &&B)
Definition: MetaMat.hpp:84
virtual unique_ptr< MetaMat > make_copy()=0
virtual T max() const =0
virtual int sign_det() const =0
const uword n_cols
Definition: MetaMat.hpp:119
void unify(const uword K)
Definition: MetaMat.hpp:145
int solve(Mat< T > &X, C &&B)
Definition: MetaMat.hpp:202
MetaMat(const uword in_rows, const uword in_cols, const uword in_elem)
Definition: MetaMat.hpp:122
Mat< T > solve(C &&B)
Definition: MetaMat.hpp:204
virtual T & at(uword, uword)=0
Access element with bounds checking.
virtual Col< T > diag() const =0
void operator-=(const shared_ptr< MetaMat > &M)
Definition: MetaMat.hpp:180
virtual bool is_empty() const =0
void set_factored(const bool F)
Definition: MetaMat.hpp:138
virtual Mat< T > operator*(const Mat< T > &) const =0
const uword n_rows
Definition: MetaMat.hpp:118
void save(const char *name)
Definition: MetaMat.hpp:214
virtual void scale_accu(T, const shared_ptr< MetaMat > &)=0
virtual const T * memptr() const =0
virtual void scale_accu(T, const triplet_form< T, uword > &)=0
virtual void nullify(uword)=0
void operator-=(const triplet_form< T, uword > &M)
Definition: MetaMat.hpp:196
MetaMat(MetaMat &&) noexcept=delete
void operator+=(const shared_ptr< MetaMat > &M)
Definition: MetaMat.hpp:178
virtual int direct_solve(Mat< T > &, Mat< T > &&)=0
virtual void csc_condense()
Definition: MetaMat.hpp:219
void operator-=(const op_scale< T > &M)
Definition: MetaMat.hpp:188
virtual void csr_condense()
Definition: MetaMat.hpp:221
int iterative_solve(Mat< T > &X, const SpMat< T > &B)
Definition: MetaMat.hpp:88
bool factored
Definition: MetaMat.hpp:74
virtual T * memptr()=0
virtual int direct_solve(Mat< T > &, const Mat< T > &)=0
void set_solver_setting(const SolverSetting< T > &SS)
Definition: MetaMat.hpp:134
virtual T operator()(uword, uword) const =0
Access element (read-only); returns zero if out of bounds.
SolverSetting< T > & get_solver_setting()
Definition: MetaMat.hpp:136
void operator+=(const triplet_form< T, uword > &M)
Definition: MetaMat.hpp:194
const uword n_elem
Definition: MetaMat.hpp:120
int mixed_trs(mat &X, mat &&B, F trs)
Definition: MetaMat.hpp:90
void operator+=(const op_scale< T > &M)
Definition: MetaMat.hpp:182
SolverSetting< T > setting
Definition: MetaMat.hpp:76
virtual void operator*=(T)=0
virtual void zeros()=0
int direct_solve(Mat< T > &X, const SpMat< T > &B)
Definition: MetaMat.hpp:82
Definition: csc_form.hpp:25
const index_t n_rows
Definition: csc_form.hpp:50
const index_t n_cols
Definition: csc_form.hpp:51
const data_t * val_mem() const
Definition: csc_form.hpp:65
const index_t n_elem
Definition: csc_form.hpp:52
const index_t * col_mem() const
Definition: csc_form.hpp:63
const index_t * row_mem() const
Definition: csc_form.hpp:61
Definition: csr_form.hpp:25
const index_t * col_mem() const
Definition: csr_form.hpp:63
const index_t n_rows
Definition: csr_form.hpp:50
const index_t n_cols
Definition: csr_form.hpp:51
const index_t * row_mem() const
Definition: csr_form.hpp:61
const data_t * val_mem() const
Definition: csr_form.hpp:65
const index_t n_elem
Definition: csr_form.hpp:52
Definition: MetaMat.hpp:41
op_add(const shared_ptr< MetaMat< T >> &A, const shared_ptr< MetaMat< T >> &B)
Definition: MetaMat.hpp:51
op_add(const shared_ptr< MetaMat< T >> &A)
Definition: MetaMat.hpp:47
Definition: MetaMat.hpp:56
op_scale(const T A, const shared_ptr< MetaMat< T >> &B)
Definition: MetaMat.hpp:63
op_scale(const T A, op_add< T > &&B)
Definition: MetaMat.hpp:67
const index_t n_rows
Definition: triplet_form.hpp:128
bool is_empty() const
Definition: triplet_form.hpp:169
data_t & at(index_t, index_t)
Definition: triplet_form.hpp:384
const index_t n_cols
Definition: triplet_form.hpp:129
const index_t n_elem
Definition: triplet_form.hpp:130
index_t col(const index_t I) const
Definition: triplet_form.hpp:161
data_t val(const index_t I) const
Definition: triplet_form.hpp:163
index_t row(const index_t I) const
Definition: triplet_form.hpp:159
triplet_form< data_t, index_t > to_triplet_form(MetaMat< data_t > *in_mat)
Definition: MetaMat.hpp:304
concept ArmaContainer
Definition: MetaMat.hpp:37
Mat< T > to_mat(const MetaMat< T > &in_mat)
Definition: MetaMat.hpp:266
int iterative_solve(Mat< T > &, const Mat< T > &)
Definition: MetaMat.hpp:226
void for_each(const IT start, const IT end, F &&FN)
Definition: utility.h:28
double norm(const vec &)
Definition: tensor.cpp:370
unsigned iterative_refinement
Definition: SolverSetting.hpp:44
data_t tolerance
Definition: SolverSetting.hpp:43
IterativeSolver iterative_solver
Definition: SolverSetting.hpp:46
#define suanpan_debug(...)
Definition: suanPan.h:307
constexpr auto SUANPAN_SUCCESS
Definition: suanPan.h:172
constexpr auto SUANPAN_FAIL
Definition: suanPan.h:173
#define suanpan_error(...)
Definition: suanPan.h:309
concept sp_i
Definition: suanPan.h:331