suanPan
BandSymmMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2023 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef BANDSYMMMAT_HPP
31 #define BANDSYMMMAT_HPP
32 
33 #include "DenseMat.hpp"
34 
35 template<sp_d T> class BandSymmMat final : public DenseMat<T> {
36  static constexpr char UPLO = 'L';
37 
38  static T bin;
39 
40  const uword band;
41  const uword m_rows; // memory block layout
42 
43  int solve_trs(Mat<T>&, Mat<T>&&);
44  int solve_trs(Mat<T>&, const Mat<T>&);
45 
46 public:
47  BandSymmMat(uword, uword);
48 
49  unique_ptr<MetaMat<T>> make_copy() override;
50 
51  void unify(uword) override;
52  void nullify(uword) override;
53 
54  const T& operator()(uword, uword) const override;
55  T& unsafe_at(uword, uword) override;
56  T& at(uword, uword) override;
57 
58  Mat<T> operator*(const Mat<T>&) const override;
59 
60  int direct_solve(Mat<T>&, Mat<T>&&) override;
61  int direct_solve(Mat<T>&, const Mat<T>&) override;
62 };
63 
64 template<sp_d T> T BandSymmMat<T>::bin = 0.;
65 
66 template<sp_d T> BandSymmMat<T>::BandSymmMat(const uword in_size, const uword in_bandwidth)
67  : DenseMat<T>(in_size, in_size, (in_bandwidth + 1) * in_size)
68  , band(in_bandwidth)
69  , m_rows(in_bandwidth + 1) {}
70 
71 template<sp_d T> unique_ptr<MetaMat<T>> BandSymmMat<T>::make_copy() { return std::make_unique<BandSymmMat<T>>(*this); }
72 
73 template<sp_d T> void BandSymmMat<T>::unify(const uword K) {
74  nullify(K);
75  access::rw(this->memory[K * m_rows]) = 1.;
76 }
77 
78 template<sp_d T> void BandSymmMat<T>::nullify(const uword K) {
79  suanpan_for(std::max(band, K) - band, K, [&](const uword I) { access::rw(this->memory[K - I + I * m_rows]) = 0.; });
80  const auto t_factor = K * m_rows - K;
81  suanpan_for(K, std::min(this->n_rows, K + band + 1), [&](const uword I) { access::rw(this->memory[I + t_factor]) = 0.; });
82 
83  this->factored = false;
84 }
85 
86 template<sp_d T> const T& BandSymmMat<T>::operator()(const uword in_row, const uword in_col) const {
87  if(in_row > band + in_col) return bin = 0.;
88  return this->memory[in_row > in_col ? in_row - in_col + in_col * m_rows : in_col - in_row + in_row * m_rows];
89 }
90 
91 template<sp_d T> T& BandSymmMat<T>::unsafe_at(const uword in_row, const uword in_col) {
92  this->factored = false;
93  return access::rw(this->memory[in_row - in_col + in_col * m_rows]);
94 }
95 
96 template<sp_d T> T& BandSymmMat<T>::at(const uword in_row, const uword in_col) {
97  if(in_row > band + in_col || in_row < in_col) [[unlikely]] return bin = 0.;
98  this->factored = false;
99  return access::rw(this->memory[in_row - in_col + in_col * m_rows]);
100 }
101 
102 template<sp_d T> Mat<T> BandSymmMat<T>::operator*(const Mat<T>& X) const {
103  Mat<T> Y(arma::size(X));
104 
105  const auto N = static_cast<int>(this->n_cols);
106  const auto K = static_cast<int>(band);
107  const auto LDA = static_cast<int>(m_rows);
108  const auto INC = 1;
109  T ALPHA = 1.;
110  T BETA = 0.;
111 
112  if(std::is_same_v<T, float>) {
113  using E = float;
114  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_ssbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
115  }
116  else if(std::is_same_v<T, double>) {
117  using E = double;
118  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_dsbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
119  }
120 
121  return Y;
122 }
123 
124 template<sp_d T> int BandSymmMat<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
125  if(this->factored) return this->solve_trs(X, B);
126 
127  const auto N = static_cast<int>(this->n_rows);
128  const auto KD = static_cast<int>(band);
129  const auto NRHS = static_cast<int>(B.n_cols);
130  const auto LDAB = static_cast<int>(m_rows);
131  const auto LDB = static_cast<int>(B.n_rows);
132  auto INFO = 0;
133 
134  this->factored = true;
135 
136  if(std::is_same_v<T, float>) {
137  using E = float;
138  X = B;
139  arma_fortran(arma_spbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)X.memptr(), &LDB, &INFO);
140  }
141  else if(Precision::FULL == this->setting.precision) {
142  using E = double;
143  X = B;
144  arma_fortran(arma_dpbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)X.memptr(), &LDB, &INFO);
145  }
146  else {
147  this->s_memory = this->to_float();
148  arma_fortran(arma_spbtrf)(&UPLO, &N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
149  if(0 == INFO) INFO = this->solve_trs(X, B);
150  }
151 
152  if(0 != INFO)
153  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
154 
155  return INFO;
156 }
157 
158 template<sp_d T> int BandSymmMat<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
159  const auto N = static_cast<int>(this->n_rows);
160  const auto KD = static_cast<int>(band);
161  const auto NRHS = static_cast<int>(B.n_cols);
162  const auto LDAB = static_cast<int>(m_rows);
163  const auto LDB = static_cast<int>(B.n_rows);
164  auto INFO = 0;
165 
166  if(std::is_same_v<T, float>) {
167  using E = float;
168  X = B;
169  arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)X.memptr(), &LDB, &INFO);
170  }
171  else if(Precision::FULL == this->setting.precision) {
172  using E = double;
173  X = B;
174  arma_fortran(arma_dpbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)X.memptr(), &LDB, &INFO);
175  }
176  else {
177  X = arma::zeros(B.n_rows, B.n_cols);
178 
179  mat full_residual = B;
180 
181  auto multiplier = norm(full_residual);
182 
183  auto counter = 0u;
184  while(counter++ < this->setting.iterative_refinement) {
185  if(multiplier < this->setting.tolerance) break;
186 
187  auto residual = conv_to<fmat>::from(full_residual / multiplier);
188 
189  arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
190  if(0 != INFO) break;
191 
192  const mat incre = multiplier * conv_to<mat>::from(residual);
193 
194  X += incre;
195 
196  suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
197  }
198  }
199 
200  if(0 != INFO)
201  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
202 
203  return INFO;
204 }
205 
206 template<sp_d T> int BandSymmMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
207  if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
208 
209  const auto N = static_cast<int>(this->n_rows);
210  const auto KD = static_cast<int>(band);
211  const auto NRHS = static_cast<int>(B.n_cols);
212  const auto LDAB = static_cast<int>(m_rows);
213  const auto LDB = static_cast<int>(B.n_rows);
214  auto INFO = 0;
215 
216  this->factored = true;
217 
218  if(std::is_same_v<T, float>) {
219  using E = float;
220  arma_fortran(arma_spbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
221  X = std::move(B);
222  }
223  else if(Precision::FULL == this->setting.precision) {
224  using E = double;
225  arma_fortran(arma_dpbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
226  X = std::move(B);
227  }
228  else {
229  this->s_memory = this->to_float();
230  arma_fortran(arma_spbtrf)(&UPLO, &N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
231  if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
232  }
233 
234  if(0 != INFO)
235  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
236 
237  return INFO;
238 }
239 
240 template<sp_d T> int BandSymmMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
241  const auto N = static_cast<int>(this->n_rows);
242  const auto KD = static_cast<int>(band);
243  const auto NRHS = static_cast<int>(B.n_cols);
244  const auto LDAB = static_cast<int>(m_rows);
245  const auto LDB = static_cast<int>(B.n_rows);
246  auto INFO = 0;
247 
248  if(std::is_same_v<T, float>) {
249  using E = float;
250  arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
251  X = std::move(B);
252  }
253  else if(Precision::FULL == this->setting.precision) {
254  using E = double;
255  arma_fortran(arma_dpbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
256  X = std::move(B);
257  }
258  else {
259  X = arma::zeros(B.n_rows, B.n_cols);
260 
261  auto multiplier = norm(B);
262 
263  auto counter = 0u;
264  while(counter++ < this->setting.iterative_refinement) {
265  if(multiplier < this->setting.tolerance) break;
266 
267  auto residual = conv_to<fmat>::from(B / multiplier);
268 
269  arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
270  if(0 != INFO) break;
271 
272  const mat incre = multiplier * conv_to<mat>::from(residual);
273 
274  X += incre;
275 
276  suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->operator*(incre)));
277  }
278  }
279 
280  if(0 != INFO)
281  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
282 
283  return INFO;
284 }
285 
286 #endif
287 
A BandSymmMat class that holds symmetric banded matrices, storing only the lower band.
Definition: BandSymmMat.hpp:35
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
T & at(uword, uword) override
Access element with bound check.
Definition: BandSymmMat.hpp:96
const T & operator()(uword, uword) const override
Access element (read-only), returns zero if out-of-bound.
Definition: BandSymmMat.hpp:86
void nullify(uword) override
Definition: BandSymmMat.hpp:78
BandSymmMat(uword, uword)
Definition: BandSymmMat.hpp:66
Mat< T > operator*(const Mat< T > &) const override
Definition: BandSymmMat.hpp:102
unique_ptr< MetaMat< T > > make_copy() override
Definition: BandSymmMat.hpp:71
void unify(uword) override
Definition: BandSymmMat.hpp:73
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: BandSymmMat.hpp:206
T & unsafe_at(uword, uword) override
Access element without bound check.
Definition: BandSymmMat.hpp:91
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27