suanPan
BandMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2022 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef BANDMAT_HPP
31 #define BANDMAT_HPP
32 
33 #include "DenseMat.hpp"
34 
35 template<sp_d T> class BandMat final : public DenseMat<T> {
36  static constexpr char TRAN = 'N';
37 
38  static T bin;
39 
40  const uword l_band;
41  const uword u_band;
42  const uword s_band;
43  const uword m_rows; // memory block layout
44 
45  int solve_trs(Mat<T>&, Mat<T>&&);
46  int solve_trs(Mat<T>&, const Mat<T>&);
47 
48 public:
49  BandMat(uword, uword, uword);
50 
51  unique_ptr<MetaMat<T>> make_copy() override;
52 
53  void unify(uword) override;
54  void nullify(uword) override;
55 
56  const T& operator()(uword, uword) const override;
57  T& at(uword, uword) override;
58 
59  Mat<T> operator*(const Mat<T>&) const override;
60 
61  int direct_solve(Mat<T>&, Mat<T>&&) override;
62  int direct_solve(Mat<T>&, const Mat<T>&) override;
63 };
64 
65 template<sp_d T> T BandMat<T>::bin = 0.;
66 
67 template<sp_d T> BandMat<T>::BandMat(const uword in_size, const uword in_l, const uword in_u)
68  : DenseMat<T>(in_size, in_size, (2 * in_l + in_u + 1) * in_size)
69  , l_band(in_l)
70  , u_band(in_u)
71  , s_band(in_l + in_u)
72  , m_rows(2 * in_l + in_u + 1) {}
73 
74 template<sp_d T> unique_ptr<MetaMat<T>> BandMat<T>::make_copy() { return std::make_unique<BandMat<T>>(*this); }
75 
76 template<sp_d T> void BandMat<T>::unify(const uword K) {
77  nullify(K);
78  access::rw(this->memory[s_band + K * m_rows]) = 1.;
79 }
80 
81 template<sp_d T> void BandMat<T>::nullify(const uword K) {
82  suanpan_for(std::max(K, u_band) - u_band, std::min(this->n_rows, K + l_band + 1), [&](const uword I) { access::rw(this->memory[I + s_band + K * (m_rows - 1)]) = 0.; });
83  suanpan_for(std::max(K, l_band) - l_band, std::min(this->n_cols, K + u_band + 1), [&](const uword I) { access::rw(this->memory[K + s_band + I * (m_rows - 1)]) = 0.; });
84 
85  this->factored = false;
86 }
87 
88 template<sp_d T> const T& BandMat<T>::operator()(const uword in_row, const uword in_col) const {
89  if(in_row > in_col + l_band || in_row + u_band < in_col) return bin = 0.;
90  return this->memory[in_row + s_band + in_col * (m_rows - 1)];
91 }
92 
93 template<sp_d T> T& BandMat<T>::at(const uword in_row, const uword in_col) {
94  if(in_row > in_col + l_band || in_row + u_band < in_col) return bin = 0.;
95  this->factored = false;
96  return access::rw(this->memory[in_row + s_band + in_col * (m_rows - 1)]);
97 }
98 
99 template<sp_d T> Mat<T> BandMat<T>::operator*(const Mat<T>& X) const {
100  Mat<T> Y(arma::size(X));
101 
102  const auto M = static_cast<int>(this->n_rows);
103  const auto N = static_cast<int>(this->n_cols);
104  const auto KL = static_cast<int>(l_band);
105  const auto KU = static_cast<int>(u_band);
106  const auto LDA = static_cast<int>(m_rows);
107  const auto INC = 1;
108  T ALPHA = 1.;
109  T BETA = 0.;
110 
111  if(std::is_same_v<T, float>) {
112  using E = float;
113  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + l_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
114  }
115  else if(std::is_same_v<T, double>) {
116  using E = double;
117  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + l_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
118  }
119 
120  return Y;
121 }
122 
123 template<sp_d T> int BandMat<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
124  if(this->factored) return this->solve_trs(X, B);
125 
126  suanpan_debug([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
127 
128  auto INFO = 0;
129 
130  auto N = static_cast<int>(this->n_rows);
131  const auto KL = static_cast<int>(l_band);
132  const auto KU = static_cast<int>(u_band);
133  const auto NRHS = static_cast<int>(B.n_cols);
134  const auto LDAB = static_cast<int>(m_rows);
135  const auto LDB = static_cast<int>(B.n_rows);
136  this->pivot.zeros(N);
137  this->factored = true;
138 
139  if(std::is_same_v<T, float>) {
140  using E = float;
141  X = B;
142  arma_fortran(arma_sgbsv)(&N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
143  }
144  else if(Precision::FULL == this->setting.precision) {
145  using E = double;
146  X = B;
147  arma_fortran(arma_dgbsv)(&N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
148  }
149  else {
150  this->s_memory = this->to_float();
151  arma_fortran(arma_sgbtrf)(&N, &N, &KL, &KU, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), &INFO);
152  if(0 == INFO) INFO = this->solve_trs(X, B);
153  }
154 
155  if(0 != INFO) suanpan_error("solve() receives error code %u from base driver, the matrix is probably singular.\n", INFO);
156 
157  return INFO;
158 }
159 
160 template<sp_d T> int BandMat<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
161  auto INFO = 0;
162 
163  const auto N = static_cast<int>(this->n_rows);
164  const auto KL = static_cast<int>(l_band);
165  const auto KU = static_cast<int>(u_band);
166  const auto NRHS = static_cast<int>(B.n_cols);
167  const auto LDAB = static_cast<int>(m_rows);
168  const auto LDB = static_cast<int>(B.n_rows);
169 
170  if(std::is_same_v<T, float>) {
171  using E = float;
172  X = B;
173  arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
174  }
175  else if(Precision::FULL == this->setting.precision) {
176  using E = double;
177  X = B;
178  arma_fortran(arma_dgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
179  }
180  else {
181  X = arma::zeros(B.n_rows, B.n_cols);
182 
183  mat full_residual = B;
184 
185  auto multiplier = norm(full_residual);
186 
187  auto counter = 0u;
188  while(counter++ < this->setting.iterative_refinement) {
189  if(multiplier < this->setting.tolerance) break;
190 
191  auto residual = conv_to<fmat>::from(full_residual / multiplier);
192 
193  arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
194  if(0 != INFO) break;
195 
196  const mat incre = multiplier * conv_to<mat>::from(residual);
197 
198  X += incre;
199 
200  suanpan_debug("mixed precision algorithm multiplier: %.5E.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
201  }
202  }
203 
204  if(INFO != 0) suanpan_error("solve() receives error code %u from base driver, the matrix is probably singular.\n", INFO);
205 
206  return INFO;
207 }
208 
209 template<sp_d T> int BandMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
210  if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
211 
212  suanpan_debug([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
213 
214  auto INFO = 0;
215 
216  auto N = static_cast<int>(this->n_rows);
217  const auto KL = static_cast<int>(l_band);
218  const auto KU = static_cast<int>(u_band);
219  const auto NRHS = static_cast<int>(B.n_cols);
220  const auto LDAB = static_cast<int>(m_rows);
221  const auto LDB = static_cast<int>(B.n_rows);
222  this->pivot.zeros(N);
223  this->factored = true;
224 
225  if(std::is_same_v<T, float>) {
226  using E = float;
227  arma_fortran(arma_sgbsv)(&N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
228  X = std::move(B);
229  }
230  else if(Precision::FULL == this->setting.precision) {
231  using E = double;
232  arma_fortran(arma_dgbsv)(&N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
233  X = std::move(B);
234  }
235  else {
236  this->s_memory = this->to_float();
237  arma_fortran(arma_sgbtrf)(&N, &N, &KL, &KU, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), &INFO);
238  if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
239  }
240 
241  if(0 != INFO) suanpan_error("solve() receives error code %u from base driver, the matrix is probably singular.\n", INFO);
242 
243  return INFO;
244 }
245 
246 template<sp_d T> int BandMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
247  auto INFO = 0;
248 
249  const auto N = static_cast<int>(this->n_rows);
250  const auto KL = static_cast<int>(l_band);
251  const auto KU = static_cast<int>(u_band);
252  const auto NRHS = static_cast<int>(B.n_cols);
253  const auto LDAB = static_cast<int>(m_rows);
254  const auto LDB = static_cast<int>(B.n_rows);
255 
256  if(std::is_same_v<T, float>) {
257  using E = float;
258  arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
259  X = std::move(B);
260  }
261  else if(Precision::FULL == this->setting.precision) {
262  using E = double;
263  arma_fortran(arma_dgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
264  X = std::move(B);
265  }
266  else {
267  X = arma::zeros(B.n_rows, B.n_cols);
268 
269  auto multiplier = norm(B);
270 
271  auto counter = 0u;
272  while(counter++ < this->setting.iterative_refinement) {
273  if(multiplier < this->setting.tolerance) break;
274 
275  auto residual = conv_to<fmat>::from(B / multiplier);
276 
277  arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
278  if(0 != INFO) break;
279 
280  const mat incre = multiplier * conv_to<mat>::from(residual);
281 
282  X += incre;
283 
284  suanpan_debug("mixed precision algorithm multiplier: %.5E.\n", multiplier = arma::norm(B -= this->operator*(incre)));
285  }
286  }
287 
288  if(INFO != 0) suanpan_error("solve() receives error code %u from base driver, the matrix is probably singular.\n", INFO);
289 
290  return INFO;
291 }
292 
293 #endif
294 
A BandMat class that holds matrices.
Definition: BandMat.hpp:35
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
void unify(uword) override
Definition: BandMat.hpp:76
const T & operator()(uword, uword) const override
Definition: BandMat.hpp:88
Mat< T > operator*(const Mat< T > &) const override
Definition: BandMat.hpp:99
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: BandMat.hpp:209
T & at(uword, uword) override
Definition: BandMat.hpp:93
unique_ptr< MetaMat< T > > make_copy() override
Definition: BandMat.hpp:74
BandMat(uword, uword, uword)
Definition: BandMat.hpp:67
void nullify(uword) override
Definition: BandMat.hpp:81
double norm(const vec &)
Definition: tensorToolbox.cpp:302
void suanpan_debug(const char *M,...)
Definition: print.cpp:64
void suanpan_error(const char *M,...)
Definition: print.cpp:116
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:24