suanPan
BandMatSpike.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2023 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef BANDMATSPIKE_HPP
31 #define BANDMATSPIKE_HPP
32 
33 #include <feast/spike.h>
34 #include "DenseMat.hpp"
35 
36 template<sp_d T> class BandMatSpike final : public DenseMat<T> {
37  static constexpr char TRAN = 'N';
38 
39  static T bin;
40 
41  const uword l_band;
42  const uword u_band;
43  const uword m_rows; // memory block layout
44 
45  podarray<int> SPIKE{64};
46  podarray<T> WORK;
47  podarray<float> SWORK;
48 
49  void init_spike() {
50  auto N = static_cast<int>(this->n_rows);
51  auto KLU = static_cast<int>(std::max(l_band, u_band));
52 
53  spikeinit_(SPIKE.memptr(), &N, &KLU);
54 
55  std::is_same_v<T, float> ? sspike_tune_(SPIKE.memptr()) : dspike_tune_(SPIKE.memptr());
56  }
57 
58  int solve_trs(Mat<T>&, Mat<T>&&);
59 
60 protected:
62 
63  int direct_solve(Mat<T>&, Mat<T>&&) override;
64 
65 public:
66  BandMatSpike(const uword in_size, const uword in_l, const uword in_u)
67  : DenseMat<T>(in_size, in_size, (in_l + in_u + 1) * in_size)
68  , l_band(in_l)
69  , u_band(in_u)
70  , m_rows(in_l + in_u + 1) { init_spike(); }
71 
72  unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<BandMatSpike>(*this); }
73 
74  void nullify(const uword K) override {
75  this->factored = false;
76  suanpan_for(std::max(K, u_band) - u_band, std::min(this->n_rows, K + l_band + 1), [&](const uword I) { this->memory[I + u_band + K * (m_rows - 1)] = T(0); });
77  suanpan_for(std::max(K, l_band) - l_band, std::min(this->n_cols, K + u_band + 1), [&](const uword I) { this->memory[K + u_band + I * (m_rows - 1)] = T(0); });
78  }
79 
80  T operator()(const uword in_row, const uword in_col) const override {
81  if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]] return bin = T(0);
82  return this->memory[in_row + u_band + in_col * (m_rows - 1)];
83  }
84 
85  T& unsafe_at(const uword in_row, const uword in_col) override {
86  this->factored = false;
87  return this->memory[in_row + u_band + in_col * (m_rows - 1)];
88  }
89 
90  T& at(const uword in_row, const uword in_col) override {
91  if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]] return bin = T(0);
92  return this->unsafe_at(in_row, in_col);
93  }
94 
95  Mat<T> operator*(const Mat<T>&) const override;
96 
97  [[nodiscard]] int sign_det() const override { throw invalid_argument("not supported"); }
98 };
99 
100 template<sp_d T> T BandMatSpike<T>::bin = T(0);
101 
102 template<sp_d T> Mat<T> BandMatSpike<T>::operator*(const Mat<T>& X) const {
103  Mat<T> Y(arma::size(X));
104 
105  const auto M = static_cast<int>(this->n_rows);
106  const auto N = static_cast<int>(this->n_cols);
107  const auto KL = static_cast<int>(l_band);
108  const auto KU = static_cast<int>(u_band);
109  const auto LDA = static_cast<int>(m_rows);
110  constexpr auto INC = 1;
111  T ALPHA = T(1);
112  T BETA = T(0);
113 
114  if constexpr(std::is_same_v<T, float>) {
115  using E = float;
116  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
117  }
118  else {
119  using E = double;
120  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
121  }
122 
123  return Y;
124 }
125 
126 template<sp_d T> int BandMatSpike<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
127  if(!this->factored) {
128  suanpan_assert([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
129 
130  auto INFO = 0;
131 
132  auto N = static_cast<int>(this->n_rows);
133  auto KL = static_cast<int>(l_band);
134  auto KU = static_cast<int>(u_band);
135  auto LDAB = static_cast<int>(m_rows);
136  const auto KLU = std::max(l_band, u_band);
137  this->factored = true;
138 
139  if constexpr(std::is_same_v<T, float>) {
140  using E = float;
141  WORK.zeros(KLU * KLU * SPIKE(9));
142  sspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), &INFO);
143  }
144  else if(Precision::FULL == this->setting.precision) {
145  using E = double;
146  WORK.zeros(KLU * KLU * SPIKE(9));
147  dspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), &INFO);
148  }
149  else {
150  this->s_memory = this->to_float();
151  SWORK.zeros(KLU * KLU * SPIKE(9));
152  sspike_gbtrf_(SPIKE.memptr(), &N, &KL, &KU, this->s_memory.memptr(), &LDAB, SWORK.memptr(), &INFO);
153  }
154 
155  if(0 != INFO) {
156  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
157  return INFO;
158  }
159  }
160 
161  return this->solve_trs(X, std::forward<Mat<T>>(B));
162 }
163 
164 template<sp_d T> int BandMatSpike<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
165  auto N = static_cast<int>(this->n_rows);
166  auto KL = static_cast<int>(l_band);
167  auto KU = static_cast<int>(u_band);
168  auto NRHS = static_cast<int>(B.n_cols);
169  auto LDAB = static_cast<int>(m_rows);
170  auto LDB = static_cast<int>(B.n_rows);
171 
172  if constexpr(std::is_same_v<T, float>) {
173  using E = float;
174  sspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)B.memptr(), &LDB);
175  X = std::move(B);
176  }
177  else if(Precision::FULL == this->setting.precision) {
178  using E = double;
179  dspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, (E*)WORK.memptr(), (E*)B.memptr(), &LDB);
180  X = std::move(B);
181  }
182  else
183  this->mixed_trs(X, std::forward<Mat<T>>(B), [&](fmat& residual) {
184  sspike_gbtrs_(SPIKE.memptr(), &TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, SWORK.memptr(), residual.memptr(), &LDB);
185  return 0;
186  });
187 
188  return SUANPAN_SUCCESS;
189 }
190 
191 #endif
192 
A BandMatSpike class that holds matrices.
Definition: BandMatSpike.hpp:36
BandMatSpike(const uword in_size, const uword in_l, const uword in_u)
Definition: BandMatSpike.hpp:66
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition: BandMatSpike.hpp:80
int sign_det() const override
Definition: BandMatSpike.hpp:97
void nullify(const uword K) override
Definition: BandMatSpike.hpp:74
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition: BandMatSpike.hpp:90
unique_ptr< MetaMat< T > > make_copy() override
Definition: BandMatSpike.hpp:72
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition: BandMatSpike.hpp:85
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition: DenseMat.hpp:48
const uword n_cols
Definition: MetaMat.hpp:86
const uword n_rows
Definition: MetaMat.hpp:85
bool factored
Definition: MetaMat.hpp:41
Mat< T > operator*(const Mat< T > &) const override
Definition: BandMatSpike.hpp:102
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: BandMatSpike.hpp:126
constexpr auto SUANPAN_SUCCESS
Definition: suanPan.h:172
void suanpan_assert(const std::function< void()> &F)
Definition: suanPan.h:296
#define suanpan_error(...)
Definition: suanPan.h:309
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27