suanPan
BandMatSpike.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2022 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef BANDMATSPIKE_HPP
31 #define BANDMATSPIKE_HPP
32 
33 #include <feast/spike.h>
34 #include "DenseMat.hpp"
35 
36 template<sp_d T> class BandMatSpike final : public DenseMat<T> {
37  static constexpr char TRAN = 'N';
38 
39  static T bin;
40 
41  const uword l_band;
42  const uword u_band;
43  const uword m_rows; // memory block layout
44 
45  podarray<int> SPIKE = podarray<int>(64);
46  podarray<T> WORK;
47  podarray<float> SWORK;
48 
49  void init_spike();
50 
51  int solve_trs(Mat<T>&, Mat<T>&&);
52  int solve_trs(Mat<T>&, const Mat<T>&);
53 
54 public:
55  BandMatSpike(uword, uword, uword);
56 
57  unique_ptr<MetaMat<T>> make_copy() override;
58 
59  void unify(uword) override;
60  void nullify(uword) override;
61 
62  const T& operator()(uword, uword) const override;
63  T& at(uword, uword) override;
64 
65  Mat<T> operator*(const Mat<T>&) const override;
66 
67  int direct_solve(Mat<T>&, Mat<T>&&) override;
68  int direct_solve(Mat<T>&, const Mat<T>&) override;
69 
70  [[nodiscard]] int sign_det() const override;
71 };
72 
73 template<sp_d T> T BandMatSpike<T>::bin = 0.;
74 
75 template<sp_d T> void BandMatSpike<T>::init_spike() {
76  auto N = static_cast<int>(this->n_rows);
77  auto KLU = static_cast<int>(std::max(l_band, u_band));
78 
79  spikeinit_(SPIKE.memptr(), &N, &KLU);
80 
81  std::is_same_v<T, float> ? sspike_tune_(SPIKE.memptr()) : dspike_tune_(SPIKE.memptr());
82 }
83 
84 template<sp_d T> BandMatSpike<T>::BandMatSpike(const uword in_size, const uword in_l, const uword in_u)
85  : DenseMat<T>(in_size, in_size, (in_l + in_u + 1) * in_size)
86  , l_band(in_l)
87  , u_band(in_u)
88  , m_rows(in_l + in_u + 1) { init_spike(); }
89 
90 template<sp_d T> unique_ptr<MetaMat<T>> BandMatSpike<T>::make_copy() { return std::make_unique<BandMatSpike<T>>(*this); }
91 
92 template<sp_d T> void BandMatSpike<T>::unify(const uword K) {
93  nullify(K);
94  access::rw(this->memory[u_band + K * m_rows]) = 1.;
95 }
96 
97 template<sp_d T> void BandMatSpike<T>::nullify(const uword K) {
98  suanpan_for(std::max(K, u_band) - u_band, std::min(this->n_rows, K + l_band + 1), [&](const uword I) { access::rw(this->memory[I + u_band + K * (m_rows - 1)]) = 0.; });
99  suanpan_for(std::max(K, l_band) - l_band, std::min(this->n_cols, K + u_band + 1), [&](const uword I) { access::rw(this->memory[K + u_band + I * (m_rows - 1)]) = 0.; });
100 
101  this->factored = false;
102 }
103 
104 template<sp_d T> const T& BandMatSpike<T>::operator()(const uword in_row, const uword in_col) const {
105  if(in_row > in_col + l_band || in_row + u_band < in_col) return bin = 0.;
106  return this->memory[in_row + u_band + in_col * (m_rows - 1)];
107 }
108 
109 template<sp_d T> T& BandMatSpike<T>::at(const uword in_row, const uword in_col) {
110  if(in_row > in_col + l_band || in_row + u_band < in_col) return bin = 0.;
111  this->factored = false;
112  return access::rw(this->memory[in_row + u_band + in_col * (m_rows - 1)]);
113 }
114 
115 template<sp_d T> Mat<T> BandMatSpike<T>::operator*(const Mat<T>& X) const {
116  Mat<T> Y(arma::size(X));
117 
118  const auto M = static_cast<int>(this->n_rows);
119  const auto N = static_cast<int>(this->n_cols);
120  const auto KL = static_cast<int>(l_band);
121  const auto KU = static_cast<int>(u_band);
122  const auto LDA = static_cast<int>(m_rows);
123  const auto INC = 1;
124  T ALPHA = 1.;
125  T BETA = 0.;
126 
127  if(std::is_same_v<T, float>) {
128  using E = float;
129  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
130  }
131  else if(std::is_same_v<T, double>) {
132  using E = double;
133  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
134  }
135 
136  return Y;
137 }
138 
139 template<sp_d T> int BandMatSpike<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
140  if(!this->factored) {
141  auto N = static_cast<int>(this->n_rows);
142  auto KL = static_cast<int>(l_band);
143  auto KU = static_cast<int>(u_band);
144  auto LDAB = static_cast<int>(m_rows);
145  const auto KLU = std::max(l_band, u_band);
146  auto INFO = 0;
147 
148  if(std::is_same_v<T, float>) {
149  using E = float;
150  WORK.zeros(KLU * KLU * SPIKE(9));
151  sspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), &INFO);
152  }
153  else if(Precision::FULL == this->setting.precision) {
154  using E = double;
155  WORK.zeros(KLU * KLU * SPIKE(9));
156  dspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), &INFO);
157  }
158  else {
159  this->s_memory = this->to_float();
160  SWORK.zeros(KLU * KLU * SPIKE(9));
161  sspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, this->s_memory.mem, &LDAB, SWORK.memptr(), &INFO);
162  }
163 
164  if(INFO != 0) {
165  suanpan_error("solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
166  return INFO;
167  }
168 
169  this->factored = true;
170  }
171 
172  return this->solve_trs(X, B);
173 }
174 
175 template<sp_d T> int BandMatSpike<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
176  auto N = static_cast<int>(this->n_rows);
177  auto KL = static_cast<int>(l_band);
178  auto KU = static_cast<int>(u_band);
179  auto NRHS = static_cast<int>(B.n_cols);
180  auto LDAB = static_cast<int>(m_rows);
181  auto LDB = static_cast<int>(B.n_rows);
182 
183  if(std::is_same_v<T, float>) {
184  using E = float;
185  X = B;
186  sspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)X.memptr(), &LDB);
187  }
188  else if(Precision::FULL == this->setting.precision) {
189  using E = double;
190  X = B;
191  dspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)X.memptr(), &LDB);
192  }
193  else {
194  X = arma::zeros(B.n_rows, B.n_cols);
195 
196  mat full_residual = B;
197 
198  auto multiplier = norm(full_residual);
199 
200  auto counter = 0u;
201  while(counter++ < this->setting.iterative_refinement) {
202  if(multiplier < this->setting.tolerance) break;
203 
204  auto residual = conv_to<fmat>::from(full_residual / multiplier);
205 
206  sspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, SWORK.memptr(), residual.memptr(), &LDB);
207 
208  const mat incre = multiplier * conv_to<mat>::from(residual);
209 
210  X += incre;
211 
212  suanpan_debug("mixed precision algorithm multiplier: %.5E.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
213  }
214  }
215 
216  return SUANPAN_SUCCESS;
217 }
218 
219 template<sp_d T> int BandMatSpike<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
220  if(!this->factored) {
221  auto N = static_cast<int>(this->n_rows);
222  auto KL = static_cast<int>(l_band);
223  auto KU = static_cast<int>(u_band);
224  auto LDAB = static_cast<int>(m_rows);
225  const auto KLU = std::max(l_band, u_band);
226  auto INFO = 0;
227 
228  if(std::is_same_v<T, float>) {
229  using E = float;
230  WORK.zeros(KLU * KLU * SPIKE(9));
231  sspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), &INFO);
232  }
233  else if(Precision::FULL == this->setting.precision) {
234  using E = double;
235  WORK.zeros(KLU * KLU * SPIKE(9));
236  dspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), &INFO);
237  }
238  else {
239  this->s_memory = this->to_float();
240  SWORK.zeros(KLU * KLU * SPIKE(9));
241  sspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, this->s_memory.mem, &LDAB, SWORK.memptr(), &INFO);
242  }
243 
244  if(INFO != 0) {
245  suanpan_error("solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
246  return INFO;
247  }
248 
249  this->factored = true;
250  }
251 
252  return this->solve_trs(X, std::forward<Mat<T>>(B));
253 }
254 
255 template<sp_d T> int BandMatSpike<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
256  auto N = static_cast<int>(this->n_rows);
257  auto KL = static_cast<int>(l_band);
258  auto KU = static_cast<int>(u_band);
259  auto NRHS = static_cast<int>(B.n_cols);
260  auto LDAB = static_cast<int>(m_rows);
261  auto LDB = static_cast<int>(B.n_rows);
262 
263  if(std::is_same_v<T, float>) {
264  using E = float;
265  sspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)B.memptr(), &LDB);
266  X = std::move(B);
267  }
268  else if(Precision::FULL == this->setting.precision) {
269  using E = double;
270  dspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)B.memptr(), &LDB);
271  X = std::move(B);
272  }
273  else {
274  X = arma::zeros(B.n_rows, B.n_cols);
275 
276  auto multiplier = norm(B);
277 
278  auto counter = 0u;
279  while(counter++ < this->setting.iterative_refinement) {
280  if(multiplier < this->setting.tolerance) break;
281 
282  auto residual = conv_to<fmat>::from(B / multiplier);
283 
284  sspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, SWORK.memptr(), residual.memptr(), &LDB);
285 
286  const mat incre = multiplier * conv_to<mat>::from(residual);
287 
288  X += incre;
289 
290  suanpan_debug("mixed precision algorithm multiplier: %.5E.\n", multiplier = arma::norm(B -= this->operator*(incre)));
291  }
292  }
293 
294  return SUANPAN_SUCCESS;
295 }
296 
297 template<sp_d T> int BandMatSpike<T>::sign_det() const { throw invalid_argument("not supported"); }
298 
299 #endif
300 
A BandMatSpike class that holds matrices.
Definition: BandMatSpike.hpp:36
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
void unify(uword) override
Definition: BandMatSpike.hpp:92
Mat< T > operator*(const Mat< T > &) const override
Definition: BandMatSpike.hpp:115
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: BandMatSpike.hpp:219
int sign_det() const override
Definition: BandMatSpike.hpp:297
const T & operator()(uword, uword) const override
Definition: BandMatSpike.hpp:104
void nullify(uword) override
Definition: BandMatSpike.hpp:97
unique_ptr< MetaMat< T > > make_copy() override
Definition: BandMatSpike.hpp:90
BandMatSpike(uword, uword, uword)
Definition: BandMatSpike.hpp:84
T & at(uword, uword) override
Definition: BandMatSpike.hpp:109
double norm(const vec &)
Definition: tensorToolbox.cpp:302
void suanpan_debug(const char *M,...)
Definition: print.cpp:64
void suanpan_error(const char *M,...)
Definition: print.cpp:116
constexpr auto SUANPAN_SUCCESS
Definition: suanPan.h:161
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:24