suanPan
BandMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2024 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef BANDMAT_HPP
31 #define BANDMAT_HPP
32 
33 #include "DenseMat.hpp"
34 
35 template<sp_d T> class BandMat : public DenseMat<T> {
36  static constexpr char TRAN = 'N';
37 
38  static T bin;
39 
40  const uword s_band;
41  const uword m_rows; // memory block layout
42 
43  int solve_trs(Mat<T>&, Mat<T>&&);
44 
45 protected:
46  const uword l_band;
47  const uword u_band;
48 
50 
51  int direct_solve(Mat<T>&, Mat<T>&&) override;
52 
53 public:
54  BandMat(const uword in_size, const uword in_l, const uword in_u)
55  : DenseMat<T>(in_size, in_size, (2 * in_l + in_u + 1) * in_size)
56  , s_band(in_l + in_u)
57  , m_rows(2 * in_l + in_u + 1)
58  , l_band(in_l)
59  , u_band(in_u) {
60  if(m_rows >= in_size)
61  suanpan_warning("The storage requirement for the banded matrix is larger than that of a full matrix, consider using a full/sparse matrix instead.\n");
62  }
63 
64  unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<BandMat>(*this); }
65 
66  void nullify(const uword K) override {
67  this->factored = false;
68  suanpan::for_each(std::max(K, u_band) - u_band, std::min(this->n_rows, K + l_band + 1), [&](const uword I) { this->memory[I + s_band + K * (m_rows - 1)] = T(0); });
69  suanpan::for_each(std::max(K, l_band) - l_band, std::min(this->n_cols, K + u_band + 1), [&](const uword I) { this->memory[K + s_band + I * (m_rows - 1)] = T(0); });
70  }
71 
72  T operator()(const uword in_row, const uword in_col) const override {
73  if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]] return bin = T(0);
74  return this->memory[in_row + s_band + in_col * (m_rows - 1)];
75  }
76 
77  T& unsafe_at(const uword in_row, const uword in_col) override {
78  this->factored = false;
79  return this->memory[in_row + s_band + in_col * (m_rows - 1)];
80  }
81 
82  T& at(const uword in_row, const uword in_col) override {
83  if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]] return bin = T(0);
84  return this->unsafe_at(in_row, in_col);
85  }
86 
87  Mat<T> operator*(const Mat<T>&) const override;
88 };
89 
90 template<sp_d T> T BandMat<T>::bin = T(0);
91 
92 template<sp_d T> Mat<T> BandMat<T>::operator*(const Mat<T>& X) const {
93  Mat<T> Y(arma::size(X));
94 
95  const auto M = static_cast<int>(this->n_rows);
96  const auto N = static_cast<int>(this->n_cols);
97  const auto KL = static_cast<int>(l_band);
98  const auto KU = static_cast<int>(u_band);
99  const auto LDA = static_cast<int>(m_rows);
100  constexpr auto INC = 1;
101  T ALPHA = T(1);
102  T BETA = T(0);
103 
104  if constexpr(std::is_same_v<T, float>) {
105  using E = float;
106  suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + l_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
107  }
108  else {
109  using E = double;
110  suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + l_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
111  }
112 
113  return Y;
114 }
115 
116 template<sp_d T> int BandMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
117  if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
118 
119  suanpan_assert([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
120 
121  auto INFO = 0;
122 
123  auto N = static_cast<int>(this->n_rows);
124  const auto KL = static_cast<int>(l_band);
125  const auto KU = static_cast<int>(u_band);
126  const auto NRHS = static_cast<int>(B.n_cols);
127  const auto LDAB = static_cast<int>(m_rows);
128  const auto LDB = static_cast<int>(B.n_rows);
129  this->pivot.zeros(N);
130  this->factored = true;
131 
132  if constexpr(std::is_same_v<T, float>) {
133  using E = float;
134  arma_fortran(arma_sgbsv)(&N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
135  X = std::move(B);
136  }
137  else if(Precision::FULL == this->setting.precision) {
138  using E = double;
139  arma_fortran(arma_dgbsv)(&N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
140  X = std::move(B);
141  }
142  else {
143  this->s_memory = this->to_float();
144  arma_fortran(arma_sgbtrf)(&N, &N, &KL, &KU, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), &INFO);
145  if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
146  }
147 
148  if(0 != INFO)
149  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
150 
151  return INFO;
152 }
153 
154 template<sp_d T> int BandMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
155  auto INFO = 0;
156 
157  const auto N = static_cast<int>(this->n_rows);
158  const auto KL = static_cast<int>(l_band);
159  const auto KU = static_cast<int>(u_band);
160  const auto NRHS = static_cast<int>(B.n_cols);
161  const auto LDAB = static_cast<int>(m_rows);
162  const auto LDB = static_cast<int>(B.n_rows);
163 
164  if constexpr(std::is_same_v<T, float>) {
165  using E = float;
166  arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
167  X = std::move(B);
168  }
169  else if(Precision::FULL == this->setting.precision) {
170  using E = double;
171  arma_fortran(arma_dgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, (E*)this->memptr(), &LDAB, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
172  X = std::move(B);
173  }
174  else
175  this->mixed_trs(X, std::forward<Mat<T>>(B), [&](fmat& residual) {
176  arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
177  return INFO;
178  });
179 
180  if(0 != INFO)
181  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
182 
183  return INFO;
184 }
185 
186 #endif
187 
A BandMat class that holds banded matrices.
Definition: BandMat.hpp:35
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition: BandMat.hpp:77
unique_ptr< MetaMat< T > > make_copy() override
Definition: BandMat.hpp:64
BandMat(const uword in_size, const uword in_l, const uword in_u)
Definition: BandMat.hpp:54
const uword u_band
Definition: BandMat.hpp:47
T & at(const uword in_row, const uword in_col) override
Access element with band check; positions outside the band yield a reference to a zeroed scratch value.
Definition: BandMat.hpp:82
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only); returns zero if the position lies outside the band.
Definition: BandMat.hpp:72
const uword l_band
Definition: BandMat.hpp:46
void nullify(const uword K) override
Definition: BandMat.hpp:66
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition: DenseMat.hpp:48
const uword n_cols
Definition: MetaMat.hpp:119
const uword n_rows
Definition: MetaMat.hpp:118
bool factored
Definition: MetaMat.hpp:74
Mat< T > operator*(const Mat< T > &) const override
Definition: BandMat.hpp:92
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: BandMat.hpp:116
void for_each(const IT start, const IT end, F &&FN)
Definition: utility.h:28
#define suanpan_warning(...)
Definition: suanPan.h:308
void suanpan_assert(const std::function< void()> &F)
Definition: suanPan.h:296
#define suanpan_error(...)
Definition: suanPan.h:309