suanPan
🧮 An Open Source, Parallel and Heterogeneous Finite Element Analysis Framework
Loading...
Searching...
No Matches
MetaMat.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2026 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29#ifndef METAMAT_HPP
30#define METAMAT_HPP
31
32#include "SolverSetting.hpp"
33#include "triplet_form.hpp"
34
// Signed integer index type selected at compile time: builds that define
// SUANSPAN_64BIT_INT get 64-bit indices, every other build uses 32-bit ones.
#ifndef SUANSPAN_64BIT_INT
using la_it = std::int32_t;
#else
using la_it = std::int64_t;
#endif
40
41template<sp_d T> class MetaMat;
42
43template<sp_d T> class op_add {
44 friend MetaMat<T>;
45
46 shared_ptr<MetaMat<T>> mat_a, mat_b;
47
48public:
49 explicit op_add(const shared_ptr<MetaMat<T>>& A)
50 : mat_a(A)
51 , mat_b(nullptr) {}
52
53 op_add(const shared_ptr<MetaMat<T>>& A, const shared_ptr<MetaMat<T>>& B)
54 : mat_a(A)
55 , mat_b(B) {}
56};
57
58template<sp_d T> class op_scale {
59 friend MetaMat<T>;
60
61 T scalar;
62 op_add<T> bracket;
63
64public:
65 op_scale(const T A, const shared_ptr<MetaMat<T>>& B)
66 : scalar(A)
67 , bracket(B) {}
68
69 op_scale(const T A, op_add<T>&& B)
70 : scalar(A)
71 , bracket(std::move(B)) {}
72};
73
74template<sp_d T> class MetaMat {
75protected:
76 bool factored = false;
77
79
80 virtual int direct_solve(Mat<T>&, const Mat<T>&) = 0;
81
82 virtual int direct_solve(Mat<T>&, Mat<T>&&) = 0;
83
84 int direct_solve(Mat<T>& X, const SpMat<T>& B) { return this->direct_solve(X, Mat<T>(B)); }
85
86 int direct_solve(Mat<T>& X, SpMat<T>&& B) { return this->direct_solve(X, B); }
87
88 template<std::invocable<fmat&> F> int mixed_trs(mat& X, mat&& B, F&& trs) {
89 auto INFO = 0;
90
91 X = arma::zeros(size(B));
92
93 std::uint8_t counter{0};
94 while(counter++ < this->setting.iterative_refinement) {
95 const auto multiplier = norm(B);
96 if(multiplier < this->setting.tolerance) break;
97 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier);
98
99 auto residual = conv_to<fmat>::from(B / multiplier);
100
101 if(0 != (INFO = trs(residual))) break;
102
103 const mat incre = multiplier * conv_to<mat>::from(residual);
104
105 X += incre;
106 B -= this->operator*(incre);
107 }
108
109 return INFO;
110 }
111
112public:
114
115 const uword n_rows;
116 const uword n_cols;
117 const uword n_elem;
118
119 MetaMat(const uword in_rows, const uword in_cols, const uword in_elem)
120 : triplet_mat(in_rows, in_cols)
121 , n_rows(in_rows)
122 , n_cols(in_cols)
123 , n_elem(in_elem) {}
124
125 MetaMat(const MetaMat&) = default;
126 MetaMat(MetaMat&&) = delete;
127 MetaMat& operator=(const MetaMat&) = delete;
129 virtual ~MetaMat() = default;
130
132
133 [[nodiscard]] SolverSetting<T>& get_solver_setting() { return setting; }
134
135 void set_factored(const bool F) { factored = F; }
136
137 [[nodiscard]] virtual bool is_empty() const = 0;
138 virtual void zeros() = 0;
139
140 virtual unique_ptr<MetaMat> unique_copy() = 0;
141
142 void unify(const uword K) {
143 this->nullify(K);
144 this->at(K, K) = T(1);
145 }
146
147 virtual void nullify(uword) = 0;
148
149 void unify(const uvec& list) {
150 for(const auto index : list) unify(index);
151 }
152
153 void nullify(const uvec& list) {
154 for(const auto index : list) nullify(index);
155 }
156
157 [[nodiscard]] virtual T max() const = 0;
158
163 virtual T operator()(uword, uword) const = 0;
168 virtual T& unsafe_at(const uword I, const uword J) { return this->at(I, J); }
169
174 virtual T& at(uword, uword) = 0;
175
176 [[nodiscard]] virtual const T* memptr() const = 0;
177 virtual T* memptr() = 0;
178
179 virtual void scale_accu(T, const shared_ptr<MetaMat>&) = 0;
180 virtual void scale_accu(T, const triplet_form<T, uword>&) = 0;
181
182 void operator+=(const shared_ptr<MetaMat>& M) { return this->scale_accu(1., M); }
183
184 void operator-=(const shared_ptr<MetaMat>& M) { return this->scale_accu(-1., M); }
185
186 void operator+=(const op_scale<T>& M) {
187 const auto& bracket = M.bracket;
188 if(nullptr != bracket.mat_a) this->scale_accu(M.scalar, bracket.mat_a);
189 if(nullptr != bracket.mat_b) this->scale_accu(M.scalar, bracket.mat_b);
190 }
191
192 void operator-=(const op_scale<T>& M) {
193 const auto& bracket = M.bracket;
194 if(nullptr != bracket.mat_a) this->scale_accu(-M.scalar, bracket.mat_a);
195 if(nullptr != bracket.mat_b) this->scale_accu(-M.scalar, bracket.mat_b);
196 }
197
198 void operator+=(const triplet_form<T, uword>& M) { return this->scale_accu(1., M); }
199
200 void operator-=(const triplet_form<T, uword>& M) { return this->scale_accu(-1., M); }
201
202 virtual Mat<T> operator*(const Mat<T>&) const = 0;
203
204 virtual void operator*=(T) = 0;
205
206 template<typename C> requires is_arma_mat<T, C>
207 int solve(Mat<T>& X, C&& B) { return this->direct_solve(X, std::forward<C>(B)); }
208
209 template<typename C> requires is_arma_mat<T, C>
210 Mat<T> solve(C&& B) {
211 Mat<T> X;
212
213 if(SUANPAN_SUCCESS != this->solve(X, std::forward<C>(B))) throw std::runtime_error("fail to solve the system");
214
215 return X;
216 }
217
218 [[nodiscard]] virtual int sign_det() const { throw std::runtime_error("not supported"); }
219
220 virtual void allreduce() = 0;
221
222 void save(const char* name) {
223 if(!to_mat(*this).save(name, raw_ascii))
224 suanpan_error("Cannot save to file \"{}\".\n", name);
225 }
226
227 virtual void csc_condense() {}
228
229 virtual void csr_condense() {}
230};
231
232template<sp_d T> Mat<T> to_mat(const MetaMat<T>& in_mat) {
233 Mat<T> out_mat(in_mat.n_rows, in_mat.n_cols);
234 for(uword J = 0; J < in_mat.n_cols; ++J)
235 for(uword I = 0; I < in_mat.n_rows; ++I) out_mat(I, J) = in_mat(I, J);
236 return out_mat;
237}
238
239template<sp_d T> Mat<T> to_mat(const shared_ptr<MetaMat<T>>& in_mat) { return to_mat(*in_mat); }
240
241template<sp_d data_t, sp_i index_t> Mat<data_t> to_mat(const triplet_form<data_t, index_t>& in_mat) {
242 Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
243 for(index_t I = 0; I < in_mat.n_elem; ++I) out_mat(in_mat.row(I), in_mat.col(I)) += in_mat.val(I);
244 return out_mat;
245}
246
247template<sp_d data_t, sp_i index_t> Mat<data_t> to_mat(const csr_form<data_t, index_t>& in_mat) {
248 Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
249
250 index_t c_idx = 1;
251 for(index_t I = 0; I < in_mat.n_elem; ++I) {
252 if(I >= in_mat.row_mem()[c_idx]) ++c_idx;
253 out_mat(c_idx - 1, in_mat.col_mem()[I]) += in_mat.val_mem()[I];
254 }
255
256 return out_mat;
257}
258
259template<sp_d data_t, sp_i index_t> Mat<data_t> to_mat(const csc_form<data_t, index_t>& in_mat) {
260 Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
261
262 index_t c_idx = 1;
263 for(index_t I = 0; I < in_mat.n_elem; ++I) {
264 if(I >= in_mat.col_mem()[c_idx]) ++c_idx;
265 out_mat(in_mat.row_mem()[I], c_idx - 1) += in_mat.val_mem()[I];
266 }
267
268 return out_mat;
269}
270
271template<sp_d data_t, sp_i index_t> triplet_form<data_t, index_t> to_triplet_form(MetaMat<data_t>* in_mat) {
272 if(!in_mat->triplet_mat.is_empty()) return triplet_form<data_t, index_t>(in_mat->triplet_mat);
273
274 const sp_i auto n_rows = index_t(in_mat->n_rows);
275 const sp_i auto n_cols = index_t(in_mat->n_cols);
276 const sp_i auto n_elem = index_t(in_mat->n_elem);
277
278 triplet_form<data_t, index_t> out_mat(n_rows, n_cols, n_elem);
279 for(index_t J = 0; J < n_cols; ++J)
280 for(index_t I = 0; I < n_rows; ++I) out_mat.at(I, J) = in_mat->operator()(I, J);
281
282 return out_mat;
283}
284
285template<sp_d data_t, sp_i index_t> triplet_form<data_t, index_t> to_triplet_form(const shared_ptr<MetaMat<data_t>>& in_mat) { return to_triplet_form<data_t, index_t>(in_mat.get()); }
286
287#endif
288
A MetaMat class that holds matrices.
Definition MetaMat.hpp:74
Mat< T > solve(C &&B)
Definition MetaMat.hpp:210
virtual unique_ptr< MetaMat > unique_copy()=0
triplet_form< T, uword > triplet_mat
Definition MetaMat.hpp:113
MetaMat(const MetaMat &)=default
MetaMat & operator=(const MetaMat &)=delete
int direct_solve(Mat< T > &X, SpMat< T > &&B)
Definition MetaMat.hpp:86
virtual T max() const =0
const uword n_cols
Definition MetaMat.hpp:116
void unify(const uword K)
Definition MetaMat.hpp:142
virtual int sign_det() const
Definition MetaMat.hpp:218
MetaMat(const uword in_rows, const uword in_cols, const uword in_elem)
Definition MetaMat.hpp:119
virtual const T * memptr() const =0
int solve(Mat< T > &X, C &&B)
Definition MetaMat.hpp:207
void operator-=(const shared_ptr< MetaMat > &M)
Definition MetaMat.hpp:184
virtual bool is_empty() const =0
void set_factored(const bool F)
Definition MetaMat.hpp:135
int mixed_trs(mat &X, mat &&B, F &&trs)
Definition MetaMat.hpp:88
const uword n_rows
Definition MetaMat.hpp:115
void save(const char *name)
Definition MetaMat.hpp:222
virtual void scale_accu(T, const shared_ptr< MetaMat > &)=0
virtual T & unsafe_at(const uword I, const uword J)
Access element without bounds check.
Definition MetaMat.hpp:168
SolverSetting< T > & get_solver_setting()
Definition MetaMat.hpp:133
virtual T * memptr()=0
virtual void scale_accu(T, const triplet_form< T, uword > &)=0
virtual void nullify(uword)=0
virtual ~MetaMat()=default
void operator-=(const triplet_form< T, uword > &M)
Definition MetaMat.hpp:200
void operator+=(const shared_ptr< MetaMat > &M)
Definition MetaMat.hpp:182
virtual int direct_solve(Mat< T > &, Mat< T > &&)=0
virtual void csc_condense()
Definition MetaMat.hpp:227
void operator-=(const op_scale< T > &M)
Definition MetaMat.hpp:192
virtual void csr_condense()
Definition MetaMat.hpp:229
bool factored
Definition MetaMat.hpp:76
MetaMat & operator=(MetaMat &&)=delete
virtual int direct_solve(Mat< T > &, const Mat< T > &)=0
void nullify(const uvec &list)
Definition MetaMat.hpp:153
void set_solver_setting(const SolverSetting< T > &SS)
Definition MetaMat.hpp:131
virtual T operator()(uword, uword) const =0
Access element (read-only); returns zero if out of bounds.
void unify(const uvec &list)
Definition MetaMat.hpp:149
void operator+=(const triplet_form< T, uword > &M)
Definition MetaMat.hpp:198
virtual Mat< T > operator*(const Mat< T > &) const =0
const uword n_elem
Definition MetaMat.hpp:117
void operator+=(const op_scale< T > &M)
Definition MetaMat.hpp:186
MetaMat(MetaMat &&)=delete
SolverSetting< T > setting
Definition MetaMat.hpp:78
virtual void allreduce()=0
virtual void operator*=(T)=0
virtual T & at(uword, uword)=0
Access element with bounds check.
virtual void zeros()=0
int direct_solve(Mat< T > &X, const SpMat< T > &B)
Definition MetaMat.hpp:84
Definition csc_form.hpp:25
const index_t n_rows
Definition csc_form.hpp:50
const index_t n_cols
Definition csc_form.hpp:51
const index_t * col_mem() const
Definition csc_form.hpp:63
const index_t * row_mem() const
Definition csc_form.hpp:61
const data_t * val_mem() const
Definition csc_form.hpp:65
const index_t n_elem
Definition csc_form.hpp:52
Definition csr_form.hpp:25
const index_t * row_mem() const
Definition csr_form.hpp:61
const index_t n_rows
Definition csr_form.hpp:50
const index_t n_cols
Definition csr_form.hpp:51
const data_t * val_mem() const
Definition csr_form.hpp:65
const index_t * col_mem() const
Definition csr_form.hpp:63
const index_t n_elem
Definition csr_form.hpp:52
Definition MetaMat.hpp:43
op_add(const shared_ptr< MetaMat< T > > &A, const shared_ptr< MetaMat< T > > &B)
Definition MetaMat.hpp:53
op_add(const shared_ptr< MetaMat< T > > &A)
Definition MetaMat.hpp:49
Definition MetaMat.hpp:58
op_scale(const T A, const shared_ptr< MetaMat< T > > &B)
Definition MetaMat.hpp:65
op_scale(const T A, op_add< T > &&B)
Definition MetaMat.hpp:69
Definition triplet_form.hpp:62
const index_t n_rows
Definition triplet_form.hpp:128
bool is_empty() const
Definition triplet_form.hpp:172
data_t & at(index_t, index_t)
Definition triplet_form.hpp:415
const index_t n_cols
Definition triplet_form.hpp:129
const index_t n_elem
Definition triplet_form.hpp:130
index_t col(const index_t I) const
Definition triplet_form.hpp:164
data_t val(const index_t I) const
Definition triplet_form.hpp:166
index_t row(const index_t I) const
Definition triplet_form.hpp:162
Definition suanPan.h:360
Definition suanPan.h:358
std::int32_t la_it
Definition MetaMat.hpp:38
Mat< T > to_mat(const MetaMat< T > &in_mat)
Definition MetaMat.hpp:232
triplet_form< data_t, index_t > to_triplet_form(MetaMat< data_t > *in_mat)
Definition MetaMat.hpp:271
Definition SolverSetting.hpp:28
data_t tolerance
Definition SolverSetting.hpp:30
std::uint8_t iterative_refinement
Definition SolverSetting.hpp:31
#define suanpan_debug(...)
Definition suanPan.h:347
constexpr auto SUANPAN_SUCCESS
Definition suanPan.h:166
#define suanpan_error(...)
Definition suanPan.h:349