suanPan
BandMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2023 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef BANDMAT_HPP
31 #define BANDMAT_HPP
32 
33 #include "DenseMat.hpp"
34 
35 template<sp_d T> class BandMat final : public DenseMat<T> {
36  static constexpr char TRAN = 'N';
37 
38  static T bin;
39 
40  const uword l_band;
41  const uword u_band;
42  const uword s_band;
43  const uword m_rows; // memory block layout
44 
45  int solve_trs(Mat<T>&, Mat<T>&&);
46  int solve_trs(Mat<T>&, const Mat<T>&);
47 
48 public:
49  BandMat(uword, uword, uword);
50 
51  unique_ptr<MetaMat<T>> make_copy() override;
52 
53  void unify(uword) override;
54  void nullify(uword) override;
55 
56  const T& operator()(uword, uword) const override;
57  T& unsafe_at(uword, uword) override;
58  T& at(uword, uword) override;
59 
60  Mat<T> operator*(const Mat<T>&) const override;
61 
62  int direct_solve(Mat<T>&, Mat<T>&&) override;
63  int direct_solve(Mat<T>&, const Mat<T>&) override;
64 };
65 
66 template<sp_d T> T BandMat<T>::bin = 0.;
67 
68 template<sp_d T> BandMat<T>::BandMat(const uword in_size, const uword in_l, const uword in_u)
69  : DenseMat<T>(in_size, in_size, (2 * in_l + in_u + 1) * in_size)
70  , l_band(in_l)
71  , u_band(in_u)
72  , s_band(in_l + in_u)
73  , m_rows(2 * in_l + in_u + 1) {}
74 
75 template<sp_d T> unique_ptr<MetaMat<T>> BandMat<T>::make_copy() { return std::make_unique<BandMat<T>>(*this); }
76 
77 template<sp_d T> void BandMat<T>::unify(const uword K) {
78  nullify(K);
79  access::rw(this->memory[s_band + K * m_rows]) = 1.;
80 }
81 
82 template<sp_d T> void BandMat<T>::nullify(const uword K) {
83  suanpan_for(std::max(K, u_band) - u_band, std::min(this->n_rows, K + l_band + 1), [&](const uword I) { access::rw(this->memory[I + s_band + K * (m_rows - 1)]) = 0.; });
84  suanpan_for(std::max(K, l_band) - l_band, std::min(this->n_cols, K + u_band + 1), [&](const uword I) { access::rw(this->memory[K + s_band + I * (m_rows - 1)]) = 0.; });
85 
86  this->factored = false;
87 }
88 
89 template<sp_d T> const T& BandMat<T>::operator()(const uword in_row, const uword in_col) const {
90  if(in_row > in_col + l_band || in_row + u_band < in_col) return bin = 0.;
91  return this->memory[in_row + s_band + in_col * (m_rows - 1)];
92 }
93 
94 template<sp_d T> T& BandMat<T>::unsafe_at(const uword in_row, const uword in_col) {
95  this->factored = false;
96  return access::rw(this->memory[in_row + s_band + in_col * (m_rows - 1)]);
97 }
98 
99 template<sp_d T> T& BandMat<T>::at(const uword in_row, const uword in_col) {
100  if(in_row > in_col + l_band || in_row + u_band < in_col) return bin = 0.;
101  this->factored = false;
102  return access::rw(this->memory[in_row + s_band + in_col * (m_rows - 1)]);
103 }
104 
105 template<sp_d T> Mat<T> BandMat<T>::operator*(const Mat<T>& X) const {
106  Mat<T> Y(arma::size(X));
107 
108  const auto M = static_cast<int>(this->n_rows);
109  const auto N = static_cast<int>(this->n_cols);
110  const auto KL = static_cast<int>(l_band);
111  const auto KU = static_cast<int>(u_band);
112  const auto LDA = static_cast<int>(m_rows);
113  const auto INC = 1;
114  T ALPHA = 1.;
115  T BETA = 0.;
116 
117  if(std::is_same_v<T, float>) {
118  using E = float;
119  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + l_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
120  }
121  else if(std::is_same_v<T, double>) {
122  using E = double;
123  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + l_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
124  }
125 
126  return Y;
127 }
128 
129 template<sp_d T> int BandMat<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
130  if(this->factored) return this->solve_trs(X, B);
131 
132  suanpan_assert([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
133 
134  auto INFO = 0;
135 
136  auto N = static_cast<int>(this->n_rows);
137  const auto KL = static_cast<int>(l_band);
138  const auto KU = static_cast<int>(u_band);
139  const auto NRHS = static_cast<int>(B.n_cols);
140  const auto LDAB = static_cast<int>(m_rows);
141  const auto LDB = static_cast<int>(B.n_rows);
142  this->pivot.zeros(N);
143  this->factored = true;
144 
145  if(std::is_same_v<T, float>) {
146  using E = float;
147  X = B;
148  arma_fortran(arma_sgbsv)(&N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
149  }
150  else if(Precision::FULL == this->setting.precision) {
151  using E = double;
152  X = B;
153  arma_fortran(arma_dgbsv)(&N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
154  }
155  else {
156  this->s_memory = this->to_float();
157  arma_fortran(arma_sgbtrf)(&N, &N, &KL, &KU, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), &INFO);
158  if(0 == INFO) INFO = this->solve_trs(X, B);
159  }
160 
161  if(0 != INFO)
162  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
163 
164  return INFO;
165 }
166 
167 template<sp_d T> int BandMat<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
168  auto INFO = 0;
169 
170  const auto N = static_cast<int>(this->n_rows);
171  const auto KL = static_cast<int>(l_band);
172  const auto KU = static_cast<int>(u_band);
173  const auto NRHS = static_cast<int>(B.n_cols);
174  const auto LDAB = static_cast<int>(m_rows);
175  const auto LDB = static_cast<int>(B.n_rows);
176 
177  if(std::is_same_v<T, float>) {
178  using E = float;
179  X = B;
180  arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
181  }
182  else if(Precision::FULL == this->setting.precision) {
183  using E = double;
184  X = B;
185  arma_fortran(arma_dgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
186  }
187  else {
188  X = arma::zeros(B.n_rows, B.n_cols);
189 
190  mat full_residual = B;
191 
192  auto multiplier = norm(full_residual);
193 
194  auto counter = 0u;
195  while(counter++ < this->setting.iterative_refinement) {
196  if(multiplier < this->setting.tolerance) break;
197 
198  auto residual = conv_to<fmat>::from(full_residual / multiplier);
199 
200  arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
201  if(0 != INFO) break;
202 
203  const mat incre = multiplier * conv_to<mat>::from(residual);
204 
205  X += incre;
206 
207  suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
208  }
209  }
210 
211  if(0 != INFO)
212  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
213 
214  return INFO;
215 }
216 
217 template<sp_d T> int BandMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
218  if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
219 
220  suanpan_assert([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
221 
222  auto INFO = 0;
223 
224  auto N = static_cast<int>(this->n_rows);
225  const auto KL = static_cast<int>(l_band);
226  const auto KU = static_cast<int>(u_band);
227  const auto NRHS = static_cast<int>(B.n_cols);
228  const auto LDAB = static_cast<int>(m_rows);
229  const auto LDB = static_cast<int>(B.n_rows);
230  this->pivot.zeros(N);
231  this->factored = true;
232 
233  if(std::is_same_v<T, float>) {
234  using E = float;
235  arma_fortran(arma_sgbsv)(&N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
236  X = std::move(B);
237  }
238  else if(Precision::FULL == this->setting.precision) {
239  using E = double;
240  arma_fortran(arma_dgbsv)(&N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
241  X = std::move(B);
242  }
243  else {
244  this->s_memory = this->to_float();
245  arma_fortran(arma_sgbtrf)(&N, &N, &KL, &KU, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), &INFO);
246  if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
247  }
248 
249  if(0 != INFO)
250  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
251 
252  return INFO;
253 }
254 
255 template<sp_d T> int BandMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
256  auto INFO = 0;
257 
258  const auto N = static_cast<int>(this->n_rows);
259  const auto KL = static_cast<int>(l_band);
260  const auto KU = static_cast<int>(u_band);
261  const auto NRHS = static_cast<int>(B.n_cols);
262  const auto LDAB = static_cast<int>(m_rows);
263  const auto LDB = static_cast<int>(B.n_rows);
264 
265  if(std::is_same_v<T, float>) {
266  using E = float;
267  arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
268  X = std::move(B);
269  }
270  else if(Precision::FULL == this->setting.precision) {
271  using E = double;
272  arma_fortran(arma_dgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
273  X = std::move(B);
274  }
275  else {
276  X = arma::zeros(B.n_rows, B.n_cols);
277 
278  auto multiplier = norm(B);
279 
280  auto counter = 0u;
281  while(counter++ < this->setting.iterative_refinement) {
282  if(multiplier < this->setting.tolerance) break;
283 
284  auto residual = conv_to<fmat>::from(B / multiplier);
285 
286  arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
287  if(0 != INFO) break;
288 
289  const mat incre = multiplier * conv_to<mat>::from(residual);
290 
291  X += incre;
292 
293  suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->operator*(incre)));
294  }
295  }
296 
297  if(0 != INFO)
298  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
299 
300  return INFO;
301 }
302 
303 #endif
304 
A BandMat class that holds matrices.
Definition: BandMat.hpp:35
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
void unify(uword) override
Definition: BandMat.hpp:77
const T & operator()(uword, uword) const override
Access element (read-only), returns zero if out-of-bound.
Definition: BandMat.hpp:89
Mat< T > operator*(const Mat< T > &) const override
Definition: BandMat.hpp:105
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: BandMat.hpp:217
T & at(uword, uword) override
Access element with bound check.
Definition: BandMat.hpp:99
unique_ptr< MetaMat< T > > make_copy() override
Definition: BandMat.hpp:75
BandMat(uword, uword, uword)
Definition: BandMat.hpp:68
T & unsafe_at(uword, uword) override
Access element without bound check.
Definition: BandMat.hpp:94
void nullify(uword) override
Definition: BandMat.hpp:82
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
void suanpan_assert(const std::function< void()> &F)
Definition: suanPan.h:284
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27