suanPan
🧮 An Open Source, Parallel and Heterogeneous Finite Element Analysis Framework
Loading...
Searching...
No Matches
BandMat.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2026 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef BANDMAT_HPP
31#define BANDMAT_HPP
32
33#include "../DenseMat.hpp"
34
35template<sp_d T> class BandMat : public DenseMat<T> {
36 static constexpr auto TRAN = 'N';
37
38 static T bin;
39
40 const uword s_band;
41
42 int solve_trs(Mat<T>&, Mat<T>&&);
43
44protected:
45 const uword m_rows; // memory block layout
46
47 const uword l_band;
48 const uword u_band;
49
51
52 int direct_solve(Mat<T>&, Mat<T>&&) override;
53
54public:
55 BandMat(const uword in_size, const uword in_l, const uword in_u)
56 : DenseMat<T>(in_size, in_size, (2 * in_l + in_u + 1) * in_size)
57 , s_band(in_l + in_u)
58 , m_rows(2 * in_l + in_u + 1)
59 , l_band(in_l)
60 , u_band(in_u) {
61 if(m_rows >= in_size)
62 suanpan_warning("The storage requirement for the banded matrix is larger than that of a full matrix, consider using a full/sparse matrix instead.\n");
63 }
64
65 unique_ptr<MetaMat<T>> unique_copy() override { return std::make_unique<BandMat>(*this); }
66
67 void nullify(const uword K) override {
68 this->factored = false;
69 suanpan::for_each(std::max(K, u_band) - u_band, std::min(this->n_rows, K + l_band + 1), [&](const uword I) { this->memory[I + s_band + K * (m_rows - 1)] = T(0); });
70 suanpan::for_each(std::max(K, l_band) - l_band, std::min(this->n_cols, K + u_band + 1), [&](const uword I) { this->memory[K + s_band + I * (m_rows - 1)] = T(0); });
71 }
72
73 T operator()(const uword in_row, const uword in_col) const override {
74 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]]
75 return bin = T(0);
76 return this->memory[in_row + s_band + in_col * (m_rows - 1)];
77 }
78
79 T& unsafe_at(const uword in_row, const uword in_col) override {
80 this->factored = false;
81 return this->memory[in_row + s_band + in_col * (m_rows - 1)];
82 }
83
84 T& at(const uword in_row, const uword in_col) override {
85 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]]
86 return bin = T(0);
87 return this->unsafe_at(in_row, in_col);
88 }
89
90 Mat<T> operator*(const Mat<T>&) const override;
91
92 [[nodiscard]] int sign_det() const override {
93 std::function<bool(uword)> neg_diag;
94 if(Precision::FULL == this->setting.precision) neg_diag = [&](const uword i) { return this->memory[s_band + i * m_rows] < 0.; };
95 else neg_diag = [&](const uword i) { return this->s_memory[s_band + i * m_rows] < 0.f; };
96
97 auto det_sign = 1;
98 for(unsigned I = 0; I < this->pivot.n_elem; ++I)
99 if(neg_diag(I) ^ (static_cast<int>(I) + 1 != this->pivot(I))) det_sign = -det_sign;
100 return det_sign;
101 }
102};
103
104template<sp_d T> T BandMat<T>::bin = T(0);
105
106template<sp_d T> Mat<T> BandMat<T>::operator*(const Mat<T>& X) const {
107 Mat<T> Y(arma::size(X));
108
109 const auto M = static_cast<blas_int>(this->n_rows);
110 const auto N = static_cast<blas_int>(this->n_cols);
111 const auto KL = static_cast<blas_int>(l_band);
112 const auto KU = static_cast<blas_int>(u_band);
113 const auto LDA = static_cast<blas_int>(m_rows);
114 static constexpr blas_int INC = 1;
115 static constexpr T ALPHA{1}, BETA{0};
116
117 if constexpr(std::is_same_v<T, float>) suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, &ALPHA, this->memptr() + l_band, &LDA, X.colptr(I), &INC, &BETA, Y.colptr(I), &INC); });
118 else suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, &ALPHA, this->memptr() + l_band, &LDA, X.colptr(I), &INC, &BETA, Y.colptr(I), &INC); });
119 return Y;
120}
121
122template<sp_d T> int BandMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
123 if(this->factored) return this->solve_trs(X, std::move(B));
124
125 suanpan_assert([&] { if(this->n_rows != this->n_cols) throw std::invalid_argument("requires a square matrix"); });
126
127 blas_int INFO = 0;
128
129 auto N = static_cast<blas_int>(this->n_rows);
130 const auto KL = static_cast<blas_int>(l_band);
131 const auto KU = static_cast<blas_int>(u_band);
132 const auto NRHS = static_cast<blas_int>(B.n_cols);
133 const auto LDAB = static_cast<blas_int>(m_rows);
134 const auto LDB = static_cast<blas_int>(B.n_rows);
135 this->pivot.zeros(N);
136 this->factored = true;
137
138 if constexpr(std::is_same_v<T, float>) {
139 arma_fortran(arma_sgbsv)(&N, &KL, &KU, &NRHS, this->memptr(), &LDAB, this->pivot.memptr(), B.memptr(), &LDB, &INFO);
140 X = std::move(B);
141 }
142 else if(Precision::FULL == this->setting.precision) {
143 arma_fortran(arma_dgbsv)(&N, &KL, &KU, &NRHS, this->memptr(), &LDAB, this->pivot.memptr(), B.memptr(), &LDB, &INFO);
144 X = std::move(B);
145 }
146 else {
147 this->s_memory = this->to_float();
148 arma_fortran(arma_sgbtrf)(&N, &N, &KL, &KU, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), &INFO);
149 if(0 == INFO) INFO = this->solve_trs(X, std::move(B));
150 }
151
152 if(0 != INFO)
153 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
154
155 return INFO;
156}
157
158template<sp_d T> int BandMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
159 blas_int INFO = 0;
160
161 const auto N = static_cast<blas_int>(this->n_rows);
162 const auto KL = static_cast<blas_int>(l_band);
163 const auto KU = static_cast<blas_int>(u_band);
164 const auto NRHS = static_cast<blas_int>(B.n_cols);
165 const auto LDAB = static_cast<blas_int>(m_rows);
166 const auto LDB = static_cast<blas_int>(B.n_rows);
167
168 if constexpr(std::is_same_v<T, float>) {
169 arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, this->memptr(), &LDAB, this->pivot.memptr(), B.memptr(), &LDB, &INFO);
170 X = std::move(B);
171 }
172 else if(Precision::FULL == this->setting.precision) {
173 arma_fortran(arma_dgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, this->memptr(), &LDAB, this->pivot.memptr(), B.memptr(), &LDB, &INFO);
174 X = std::move(B);
175 }
176 else
177 this->mixed_trs(X, std::move(B), [&](fmat& residual) {
178 arma_fortran(arma_sgbtrs)(&TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
179 return INFO;
180 });
181
182 if(0 != INFO)
183 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
184
185 return INFO;
186}
187
188#endif
189
A BandMat class that holds matrices.
Definition BandMat.hpp:35
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition BandMat.hpp:79
BandMat(const uword in_size, const uword in_l, const uword in_u)
Definition BandMat.hpp:55
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition BandMat.hpp:84
const uword u_band
Definition BandMat.hpp:48
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition BandMat.hpp:73
unique_ptr< MetaMat< T > > unique_copy() override
Definition BandMat.hpp:65
const uword l_band
Definition BandMat.hpp:47
const uword m_rows
Definition BandMat.hpp:45
int sign_det() const override
Definition BandMat.hpp:92
void nullify(const uword K) override
Definition BandMat.hpp:67
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
podarray< float > s_memory
Definition DenseMat.hpp:46
podarray< blas_int > pivot
Definition DenseMat.hpp:45
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
const uword n_cols
Definition MetaMat.hpp:116
const uword n_rows
Definition MetaMat.hpp:115
bool factored
Definition MetaMat.hpp:76
SolverSetting< T > setting
Definition MetaMat.hpp:78
Mat< T > operator*(const Mat< T > &) const override
Definition BandMat.hpp:106
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition BandMat.hpp:122
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:31
Precision precision
Definition SolverSetting.hpp:32
auto suanpan_assert(F &&handler)
Definition suanPan.h:339
#define suanpan_warning(...)
Definition suanPan.h:348
#define suanpan_error(...)
Definition suanPan.h:349