suanPan
BandSymmMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2022 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef BANDSYMMMAT_HPP
31 #define BANDSYMMMAT_HPP
32 
33 #include "DenseMat.hpp"
34 
35 template<sp_d T> class BandSymmMat final : public DenseMat<T> {
36  static constexpr char UPLO = 'L';
37 
38  static T bin;
39 
40  const uword band;
41  const uword m_rows; // memory block layout
42 
43  int solve_trs(Mat<T>&, Mat<T>&&);
44  int solve_trs(Mat<T>&, const Mat<T>&);
45 
46 public:
47  BandSymmMat(uword, uword);
48 
49  unique_ptr<MetaMat<T>> make_copy() override;
50 
51  void unify(uword) override;
52  void nullify(uword) override;
53 
54  const T& operator()(uword, uword) const override;
55  T& at(uword, uword) override;
56 
57  Mat<T> operator*(const Mat<T>&) const override;
58 
59  int direct_solve(Mat<T>&, Mat<T>&&) override;
60  int direct_solve(Mat<T>&, const Mat<T>&) override;
61 };
62 
63 template<sp_d T> T BandSymmMat<T>::bin = 0.;
64 
65 template<sp_d T> BandSymmMat<T>::BandSymmMat(const uword in_size, const uword in_bandwidth)
66  : DenseMat<T>(in_size, in_size, (in_bandwidth + 1) * in_size)
67  , band(in_bandwidth)
68  , m_rows(in_bandwidth + 1) {}
69 
70 template<sp_d T> unique_ptr<MetaMat<T>> BandSymmMat<T>::make_copy() { return std::make_unique<BandSymmMat<T>>(*this); }
71 
72 template<sp_d T> void BandSymmMat<T>::unify(const uword K) {
73  nullify(K);
74  access::rw(this->memory[K * m_rows]) = 1.;
75 }
76 
77 template<sp_d T> void BandSymmMat<T>::nullify(const uword K) {
78  suanpan_for(std::max(band, K) - band, K, [&](const uword I) { access::rw(this->memory[K - I + I * m_rows]) = 0.; });
79  suanpan_for(K, std::min(this->n_rows, K + band + 1), [&](const uword I) { access::rw(this->memory[I - K + K * m_rows]) = 0.; });
80 
81  this->factored = false;
82 }
83 
84 template<sp_d T> const T& BandSymmMat<T>::operator()(const uword in_row, const uword in_col) const {
85  if(in_row > band + in_col) return bin = 0.;
86  return this->memory[in_row > in_col ? in_row - in_col + in_col * m_rows : in_col - in_row + in_row * m_rows];
87 }
88 
89 template<sp_d T> T& BandSymmMat<T>::at(const uword in_row, const uword in_col) {
90  if(in_row > band + in_col || in_row < in_col) return bin = 0.;
91  this->factored = false;
92  return access::rw(this->memory[in_row - in_col + in_col * m_rows]);
93 }
94 
95 template<sp_d T> Mat<T> BandSymmMat<T>::operator*(const Mat<T>& X) const {
96  Mat<T> Y(arma::size(X));
97 
98  const auto N = static_cast<int>(this->n_cols);
99  const auto K = static_cast<int>(band);
100  const auto LDA = static_cast<int>(m_rows);
101  const auto INC = 1;
102  T ALPHA = 1.;
103  T BETA = 0.;
104 
105  if(std::is_same_v<T, float>) {
106  using E = float;
107  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_ssbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
108  }
109  else if(std::is_same_v<T, double>) {
110  using E = double;
111  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_dsbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
112  }
113 
114  return Y;
115 }
116 
117 template<sp_d T> int BandSymmMat<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
118  if(this->factored) return this->solve_trs(X, B);
119 
120  const auto N = static_cast<int>(this->n_rows);
121  const auto KD = static_cast<int>(band);
122  const auto NRHS = static_cast<int>(B.n_cols);
123  const auto LDAB = static_cast<int>(m_rows);
124  const auto LDB = static_cast<int>(B.n_rows);
125  auto INFO = 0;
126 
127  this->factored = true;
128 
129  if(std::is_same_v<T, float>) {
130  using E = float;
131  X = B;
132  arma_fortran(arma_spbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)X.memptr(), &LDB, &INFO);
133  }
134  else if(Precision::FULL == this->setting.precision) {
135  using E = double;
136  X = B;
137  arma_fortran(arma_dpbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)X.memptr(), &LDB, &INFO);
138  }
139  else {
140  this->s_memory = this->to_float();
141  arma_fortran(arma_spbtrf)(&UPLO, &N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
142  if(0 == INFO) INFO = this->solve_trs(X, B);
143  }
144 
145  if(0 != INFO) suanpan_error("solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
146 
147  return INFO;
148 }
149 
150 template<sp_d T> int BandSymmMat<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
151  const auto N = static_cast<int>(this->n_rows);
152  const auto KD = static_cast<int>(band);
153  const auto NRHS = static_cast<int>(B.n_cols);
154  const auto LDAB = static_cast<int>(m_rows);
155  const auto LDB = static_cast<int>(B.n_rows);
156  auto INFO = 0;
157 
158  if(std::is_same_v<T, float>) {
159  using E = float;
160  X = B;
161  arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)X.memptr(), &LDB, &INFO);
162  }
163  else if(Precision::FULL == this->setting.precision) {
164  using E = double;
165  X = B;
166  arma_fortran(arma_dpbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)X.memptr(), &LDB, &INFO);
167  }
168  else {
169  X = arma::zeros(B.n_rows, B.n_cols);
170 
171  mat full_residual = B;
172 
173  auto multiplier = norm(full_residual);
174 
175  auto counter = 0u;
176  while(counter++ < this->setting.iterative_refinement) {
177  if(multiplier < this->setting.tolerance) break;
178 
179  auto residual = conv_to<fmat>::from(full_residual / multiplier);
180 
181  arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
182  if(0 != INFO) break;
183 
184  const mat incre = multiplier * conv_to<mat>::from(residual);
185 
186  X += incre;
187 
188  suanpan_debug("mixed precision algorithm multiplier: %.5E.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
189  }
190  }
191 
192  if(0 != INFO) suanpan_error("solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
193 
194  return INFO;
195 }
196 
197 template<sp_d T> int BandSymmMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
198  if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
199 
200  const auto N = static_cast<int>(this->n_rows);
201  const auto KD = static_cast<int>(band);
202  const auto NRHS = static_cast<int>(B.n_cols);
203  const auto LDAB = static_cast<int>(m_rows);
204  const auto LDB = static_cast<int>(B.n_rows);
205  auto INFO = 0;
206 
207  this->factored = true;
208 
209  if(std::is_same_v<T, float>) {
210  using E = float;
211  arma_fortran(arma_spbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
212  X = std::move(B);
213  }
214  else if(Precision::FULL == this->setting.precision) {
215  using E = double;
216  arma_fortran(arma_dpbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
217  X = std::move(B);
218  }
219  else {
220  this->s_memory = this->to_float();
221  arma_fortran(arma_spbtrf)(&UPLO, &N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
222  if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
223  }
224 
225  if(0 != INFO) suanpan_error("solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
226 
227  return INFO;
228 }
229 
230 template<sp_d T> int BandSymmMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
231  const auto N = static_cast<int>(this->n_rows);
232  const auto KD = static_cast<int>(band);
233  const auto NRHS = static_cast<int>(B.n_cols);
234  const auto LDAB = static_cast<int>(m_rows);
235  const auto LDB = static_cast<int>(B.n_rows);
236  auto INFO = 0;
237 
238  if(std::is_same_v<T, float>) {
239  using E = float;
240  arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
241  X = std::move(B);
242  }
243  else if(Precision::FULL == this->setting.precision) {
244  using E = double;
245  arma_fortran(arma_dpbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
246  X = std::move(B);
247  }
248  else {
249  X = arma::zeros(B.n_rows, B.n_cols);
250 
251  auto multiplier = norm(B);
252 
253  auto counter = 0u;
254  while(counter++ < this->setting.iterative_refinement) {
255  if(multiplier < this->setting.tolerance) break;
256 
257  auto residual = conv_to<fmat>::from(B / multiplier);
258 
259  arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
260  if(0 != INFO) break;
261 
262  const mat incre = multiplier * conv_to<mat>::from(residual);
263 
264  X += incre;
265 
266  suanpan_debug("mixed precision algorithm multiplier: %.5E.\n", multiplier = arma::norm(B -= this->operator*(incre)));
267  }
268  }
269 
270  if(0 != INFO) suanpan_error("solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
271 
272  return INFO;
273 }
274 
275 #endif
276 
A BandSymmMat class that holds matrices.
Definition: BandSymmMat.hpp:35
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
T & at(uword, uword) override
Definition: BandSymmMat.hpp:89
const T & operator()(uword, uword) const override
Definition: BandSymmMat.hpp:84
void nullify(uword) override
Definition: BandSymmMat.hpp:77
BandSymmMat(uword, uword)
Definition: BandSymmMat.hpp:65
Mat< T > operator*(const Mat< T > &) const override
Definition: BandSymmMat.hpp:95
unique_ptr< MetaMat< T > > make_copy() override
Definition: BandSymmMat.hpp:70
void unify(uword) override
Definition: BandSymmMat.hpp:72
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: BandSymmMat.hpp:197
double norm(const vec &)
Definition: tensorToolbox.cpp:302
void suanpan_debug(const char *M,...)
Definition: print.cpp:64
void suanpan_error(const char *M,...)
Definition: print.cpp:116
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:24