suanPan
Loading...
Searching...
No Matches
BandSymmMat.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2025 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef BANDSYMMMAT_HPP
31#define BANDSYMMMAT_HPP
32
33#include "../DenseMat.hpp"
34
35template<sp_d T> class BandSymmMat final : public DenseMat<T> {
36 static constexpr char UPLO = 'L';
37
38 static T bin;
39
40 const uword band;
41 const uword m_rows; // memory block layout
42
43 int solve_trs(Mat<T>&, Mat<T>&&);
44
45protected:
47
48 int direct_solve(Mat<T>&, Mat<T>&&) override;
49
50public:
51 BandSymmMat(const uword in_size, const uword in_bandwidth)
52 : DenseMat<T>(in_size, in_size, (in_bandwidth + 1) * in_size)
53 , band(in_bandwidth)
54 , m_rows(in_bandwidth + 1) {}
55
56 unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<BandSymmMat>(*this); }
57
58 void nullify(const uword K) override {
59 this->factored = false;
60 suanpan::for_each(std::max(K, band) - band, K, [&](const uword I) { this->memory[K - I + I * m_rows] = T(0); });
61 const auto t_factor = K * m_rows - K;
62 suanpan::for_each(K, std::min(this->n_rows, K + band + 1), [&](const uword I) { this->memory[I + t_factor] = T(0); });
63 }
64
65 T operator()(const uword in_row, const uword in_col) const override {
66 if(in_row > band + in_col || in_col > in_row + band) [[unlikely]]
67 return bin = T(0);
68 return this->memory[in_row > in_col ? in_row - in_col + in_col * m_rows : in_col - in_row + in_row * m_rows];
69 }
70
71 T& unsafe_at(const uword in_row, const uword in_col) override {
72 this->factored = false;
73 return this->memory[in_row - in_col + in_col * m_rows];
74 }
75
76 T& at(const uword in_row, const uword in_col) override {
77 if(in_row > band + in_col || in_row < in_col) [[unlikely]]
78 return bin = T(0);
79 return this->unsafe_at(in_row, in_col);
80 }
81
82 Mat<T> operator*(const Mat<T>&) const override;
83
84 [[nodiscard]] int sign_det() const override { return 1; }
85};
86
87template<sp_d T> T BandSymmMat<T>::bin = T(0);
88
89template<sp_d T> Mat<T> BandSymmMat<T>::operator*(const Mat<T>& X) const {
90 Mat<T> Y(arma::size(X));
91
92 const auto N = static_cast<blas_int>(this->n_cols);
93 const auto K = static_cast<blas_int>(band);
94 const auto LDA = static_cast<blas_int>(m_rows);
95 constexpr blas_int INC = 1;
96 T ALPHA = T(1);
97 T BETA = T(0);
98
99 if constexpr(std::is_same_v<T, float>) {
100 using E = float;
101 suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_ssbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
102 }
103 else {
104 using E = double;
105 suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_dsbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
106 }
107
108 return Y;
109}
110
111template<sp_d T> int BandSymmMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
112 if(this->factored) return this->solve_trs(X, std::move(B));
113
114 suanpan_assert([&] { if(this->n_rows != this->n_cols) throw std::invalid_argument("requires a square matrix"); });
115
116 blas_int INFO = 0;
117
118 const auto N = static_cast<blas_int>(this->n_rows);
119 const auto KD = static_cast<blas_int>(band);
120 const auto NRHS = static_cast<blas_int>(B.n_cols);
121 const auto LDAB = static_cast<blas_int>(m_rows);
122 const auto LDB = static_cast<blas_int>(B.n_rows);
123 this->factored = true;
124
125 if constexpr(std::is_same_v<T, float>) {
126 using E = float;
127 arma_fortran(arma_spbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
128 X = std::move(B);
129 }
130 else if(Precision::FULL == this->setting.precision) {
131 using E = double;
132 arma_fortran(arma_dpbsv)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
133 X = std::move(B);
134 }
135 else {
136 this->s_memory = this->to_float();
137 arma_fortran(arma_spbtrf)(&UPLO, &N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
138 if(0 == INFO) INFO = this->solve_trs(X, std::move(B));
139 }
140
141 if(0 != INFO)
142 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
143
144 return INFO;
145}
146
147template<sp_d T> int BandSymmMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
148 blas_int INFO = 0;
149
150 const auto N = static_cast<blas_int>(this->n_rows);
151 const auto KD = static_cast<blas_int>(band);
152 const auto NRHS = static_cast<blas_int>(B.n_cols);
153 const auto LDAB = static_cast<blas_int>(m_rows);
154 const auto LDB = static_cast<blas_int>(B.n_rows);
155
156 if constexpr(std::is_same_v<T, float>) {
157 using E = float;
158 arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
159 X = std::move(B);
160 }
161 else if(Precision::FULL == this->setting.precision) {
162 using E = double;
163 arma_fortran(arma_dpbtrs)(&UPLO, &N, &KD, &NRHS, (E*)this->memptr(), &LDAB, (E*)B.memptr(), &LDB, &INFO);
164 X = std::move(B);
165 }
166 else
167 this->mixed_trs(X, std::move(B), [&](fmat& residual) {
168 arma_fortran(arma_spbtrs)(&UPLO, &N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
169 return INFO;
170 });
171
172 if(0 != INFO)
173 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
174
175 return INFO;
176}
177
178#endif
179
A BandSymmMat class that holds symmetric banded matrices (only the lower band is stored).
Definition BandSymmMat.hpp:35
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition BandSymmMat.hpp:65
BandSymmMat(const uword in_size, const uword in_bandwidth)
Definition BandSymmMat.hpp:51
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition BandSymmMat.hpp:76
unique_ptr< MetaMat< T > > make_copy() override
Definition BandSymmMat.hpp:56
int sign_det() const override
Definition BandSymmMat.hpp:84
void nullify(const uword K) override
Definition BandSymmMat.hpp:58
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition BandSymmMat.hpp:71
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
const uword n_rows
Definition MetaMat.hpp:115
bool factored
Definition MetaMat.hpp:76
Mat< T > operator*(const Mat< T > &) const override
Definition BandSymmMat.hpp:89
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition BandSymmMat.hpp:111
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:28
void suanpan_assert(const std::function< void()> &F)
Definition suanPan.h:363
#define suanpan_error(...)
Definition suanPan.h:376