suanPan
BandSymmMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2024 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef BANDSYMMMAT_HPP
31 #define BANDSYMMMAT_HPP
32 
33 #include "DenseMat.hpp"
34 
35 template<sp_d T> class BandSymmMat final : public DenseMat<T> {
36  static constexpr char UPLO = 'L';
37 
38  static T bin;
39 
40  const uword band;
41  const uword m_rows; // memory block layout
42 
43  int solve_trs(Mat<T>&, Mat<T>&&);
44 
45 protected:
47 
48  int direct_solve(Mat<T>&, Mat<T>&&) override;
49 
50 public:
51  BandSymmMat(const uword in_size, const uword in_bandwidth)
52  : DenseMat<T>(in_size, in_size, (in_bandwidth + 1) * in_size)
53  , band(in_bandwidth)
54  , m_rows(in_bandwidth + 1) {}
55 
56  unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<BandSymmMat>(*this); }
57 
58  void nullify(const uword K) override {
59  this->factored = false;
60  suanpan::for_each(std::max(band, K) - band, K, [&](const uword I) { this->memory[K - I + I * m_rows] = T(0); });
61  const auto t_factor = K * m_rows - K;
62  suanpan::for_each(K, std::min(this->n_rows, K + band + 1), [&](const uword I) { this->memory[I + t_factor] = T(0); });
63  }
64 
65  T operator()(const uword in_row, const uword in_col) const override {
66  if(in_row > band + in_col || in_col > in_row + band) [[unlikely]] return bin = T(0);
67  return this->memory[in_row > in_col ? in_row - in_col + in_col * m_rows : in_col - in_row + in_row * m_rows];
68  }
69 
70  T& unsafe_at(const uword in_row, const uword in_col) override {
71  this->factored = false;
72  return this->memory[in_row - in_col + in_col * m_rows];
73  }
74 
75  T& at(const uword in_row, const uword in_col) override {
76  if(in_row > band + in_col || in_row < in_col) [[unlikely]] return bin = T(0);
77  return this->unsafe_at(in_row, in_col);
78  }
79 
80  Mat<T> operator*(const Mat<T>&) const override;
81 };
82 
83 template<sp_d T> T BandSymmMat<T>::bin = T(0);
84 
85 template<sp_d T> Mat<T> BandSymmMat<T>::operator*(const Mat<T>& X) const {
86  Mat<T> Y(arma::size(X));
87 
88  const auto N = static_cast<int>(this->n_cols);
89  const auto K = static_cast<int>(band);
90  const auto LDA = static_cast<int>(m_rows);
91  constexpr auto INC = 1;
92  T ALPHA = T(1);
93  T BETA = T(0);
94 
95  if constexpr(std::is_same_v<T, float>) {
96  using E = float;
97  suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_ssbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
98  }
99  else {
100  using E = double;
101  suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_dsbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
102  }
103 
104  return Y;
105 }
106 
107 template<sp_d T> int BandSymmMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
108  if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
109 
110  suanpan_assert([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
111 
112  auto INFO = 0;
113 
114  const auto N = static_cast<int>(this->n_rows);
115  const auto KD = static_cast<int>(band);
116  const auto NRHS = static_cast<int>(B.n_cols);
117  const auto LDAB = static_cast<int>(m_rows);
118  const auto LDB = static_cast<int>(B.n_rows);
119  this->factored = true;
120 
121  if constexpr(std::is_same_v<T, float>) {
122  using E = float;
123  arma_fortran(arma_spbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
124  X = std::move(B);
125  }
126  else if(Precision::FULL == this->setting.precision) {
127  using E = double;
128  arma_fortran(arma_dpbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
129  X = std::move(B);
130  }
131  else {
132  this->s_memory = this->to_float();
133  arma_fortran(arma_spbtrf)(&UPLO, &N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
134  if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
135  }
136 
137  if(0 != INFO)
138  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
139 
140  return INFO;
141 }
142 
143 template<sp_d T> int BandSymmMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
144  auto INFO = 0;
145 
146  const auto N = static_cast<int>(this->n_rows);
147  const auto KD = static_cast<int>(band);
148  const auto NRHS = static_cast<int>(B.n_cols);
149  const auto LDAB = static_cast<int>(m_rows);
150  const auto LDB = static_cast<int>(B.n_rows);
151 
152  if constexpr(std::is_same_v<T, float>) {
153  using E = float;
154  arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
155  X = std::move(B);
156  }
157  else if(Precision::FULL == this->setting.precision) {
158  using E = double;
159  arma_fortran(arma_dpbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
160  X = std::move(B);
161  }
162  else
163  this->mixed_trs(X, std::forward<Mat<T>>(B), [&](fmat& residual) {
164  arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
165  return INFO;
166  });
167 
168  if(0 != INFO)
169  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
170 
171  return INFO;
172 }
173 
174 #endif
175 
A BandSymmMat class that holds symmetric banded matrices.
Definition: BandSymmMat.hpp:35
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if the element lies outside the band.
Definition: BandSymmMat.hpp:65
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition: BandSymmMat.hpp:70
BandSymmMat(const uword in_size, const uword in_bandwidth)
Definition: BandSymmMat.hpp:51
T & at(const uword in_row, const uword in_col) override
Access element with band check; entries outside the stored lower band yield a reference to a zeroed dummy.
Definition: BandSymmMat.hpp:75
void nullify(const uword K) override
Definition: BandSymmMat.hpp:58
unique_ptr< MetaMat< T > > make_copy() override
Definition: BandSymmMat.hpp:56
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition: DenseMat.hpp:48
const uword n_rows
Definition: MetaMat.hpp:118
bool factored
Definition: MetaMat.hpp:74
Mat< T > operator*(const Mat< T > &) const override
Definition: BandSymmMat.hpp:85
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: BandSymmMat.hpp:107
void for_each(const IT start, const IT end, F &&FN)
Definition: utility.h:28
void suanpan_assert(const std::function< void()> &F)
Definition: suanPan.h:296
#define suanpan_error(...)
Definition: suanPan.h:309