suanPan
🧮 An Open Source, Parallel and Heterogeneous Finite Element Analysis Framework
Loading...
Searching...
No Matches
BandMatSpike.hpp
Go to the documentation of this file.
/*******************************************************************************
 * Copyright (C) 2017-2026 Theodore Chang
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef BANDMATSPIKE_HPP
31#define BANDMATSPIKE_HPP
32
33#include "../DenseMat.hpp"
34
35extern "C" {
38void dspike_gbsv_(la_it*, la_it*, la_it*, la_it*, la_it*, double*, la_it*, double*, la_it*, la_it*);
39void dspike_gbtrf_(la_it*, la_it*, la_it*, la_it*, double*, la_it*, double*, la_it*);
40void dspike_gbtrs_(la_it*, const char*, la_it*, la_it*, la_it*, la_it*, double*, la_it*, double*, double*, la_it*);
42void sspike_gbsv_(la_it*, la_it*, la_it*, la_it*, la_it*, float*, la_it*, float*, la_it*, la_it*);
43void sspike_gbtrf_(la_it*, la_it*, la_it*, la_it*, float*, la_it*, float*, la_it*);
44void sspike_gbtrs_(la_it*, const char*, la_it*, la_it*, la_it*, la_it*, float*, la_it*, float*, float*, la_it*);
45}
46
47template<sp_d T> class BandMatSpike final : public DenseMat<T> {
48 static constexpr auto TRAN = 'N';
49
50 static la_it SPROTO, DPROTO;
51
52 static T bin;
53
54 const uword l_band;
55 const uword u_band;
56 const uword m_rows; // memory block layout
57
58 la_it SPIKE[64]{};
59
60 podarray<T> WORK;
61 podarray<float> SWORK;
62
63 void init_spike() {
64 auto N = static_cast<la_it>(this->n_rows);
65 auto KLU = static_cast<la_it>(std::max(l_band, u_band));
66
67 spikeinit_(SPIKE, &N, &KLU);
68
69 SPIKE[6] = std::is_same_v<T, float> ? SPROTO : DPROTO;
70 SPIKE[4] = SPIKE[6] + SPIKE[6] / 2 + 10;
71 SPIKE[3] = SPIKE[4] / 2;
72 }
73
74 int solve_trs(Mat<T>&, Mat<T>&&);
75
76protected:
78
79 int direct_solve(Mat<T>&, Mat<T>&&) override;
80
81public:
82 BandMatSpike(const uword in_size, const uword in_l, const uword in_u)
83 : DenseMat<T>(in_size, in_size, (in_l + in_u + 1) * in_size)
84 , l_band(in_l)
85 , u_band(in_u)
86 , m_rows(in_l + in_u + 1) { init_spike(); }
87
89 : DenseMat<T>(other)
90 , l_band(other.l_band)
91 , u_band(other.u_band)
92 , m_rows(other.m_rows) { init_spike(); }
93
97
98 unique_ptr<MetaMat<T>> unique_copy() override { return std::make_unique<BandMatSpike>(*this); }
99
100 void nullify(const uword K) override {
101 this->factored = false;
102 suanpan::for_each(std::max(K, u_band) - u_band, std::min(this->n_rows, K + l_band + 1), [&](const uword I) { this->memory[I + u_band + K * (m_rows - 1)] = T(0); });
103 suanpan::for_each(std::max(K, l_band) - l_band, std::min(this->n_cols, K + u_band + 1), [&](const uword I) { this->memory[K + u_band + I * (m_rows - 1)] = T(0); });
104 }
105
106 T operator()(const uword in_row, const uword in_col) const override {
107 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]]
108 return bin = T(0);
109 return this->memory[in_row + u_band + in_col * (m_rows - 1)];
110 }
111
112 T& unsafe_at(const uword in_row, const uword in_col) override {
113 this->factored = false;
114 return this->memory[in_row + u_band + in_col * (m_rows - 1)];
115 }
116
117 T& at(const uword in_row, const uword in_col) override {
118 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]]
119 return bin = T(0);
120 return this->unsafe_at(in_row, in_col);
121 }
122
123 Mat<T> operator*(const Mat<T>&) const override;
124};
125
126template<sp_d T> la_it BandMatSpike<T>::SPROTO = [] {
127 la_it PROTO[64]{};
128 sspike_tune_(PROTO);
129 return PROTO[6];
130}();
131
132template<sp_d T> la_it BandMatSpike<T>::DPROTO = [] {
133 la_it PROTO[64]{};
134 dspike_tune_(PROTO);
135 return PROTO[6];
136}();
137
138template<sp_d T> T BandMatSpike<T>::bin = T(0);
139
140template<sp_d T> Mat<T> BandMatSpike<T>::operator*(const Mat<T>& X) const {
141 Mat<T> Y(arma::size(X));
142
143 const auto M = static_cast<blas_int>(this->n_rows);
144 const auto N = static_cast<blas_int>(this->n_cols);
145 const auto KL = static_cast<blas_int>(l_band);
146 const auto KU = static_cast<blas_int>(u_band);
147 const auto LDA = static_cast<blas_int>(m_rows);
148 static constexpr blas_int INC = 1;
149 static constexpr T ALPHA{1}, BETA{0};
150
151 if constexpr(std::is_same_v<T, float>) suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, &ALPHA, this->memptr(), &LDA, X.colptr(I), &INC, &BETA, Y.colptr(I), &INC); });
152 else suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, &ALPHA, this->memptr(), &LDA, X.colptr(I), &INC, &BETA, Y.colptr(I), &INC); });
153 return Y;
154}
155
156template<sp_d T> int BandMatSpike<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
157 if(!this->factored) {
158 suanpan_assert([&] { if(this->n_rows != this->n_cols) throw std::invalid_argument("requires a square matrix"); });
159
160 la_it INFO = 0;
161
162 auto N = static_cast<la_it>(this->n_rows);
163 auto KL = static_cast<la_it>(l_band);
164 auto KU = static_cast<la_it>(u_band);
165 auto LDAB = static_cast<la_it>(m_rows);
166 const auto KLU = std::max(l_band, u_band);
167 this->factored = true;
168
169 if constexpr(std::is_same_v<T, float>) {
170 WORK.zeros(KLU * KLU * SPIKE[9]);
171 sspike_gbtrf_(SPIKE, &N, &KL, &KU, this->memptr(), &LDAB, WORK.memptr(), &INFO);
172 }
173 else if(Precision::FULL == this->setting.precision) {
174 WORK.zeros(KLU * KLU * SPIKE[9]);
175 dspike_gbtrf_(SPIKE, &N, &KL, &KU, this->memptr(), &LDAB, WORK.memptr(), &INFO);
176 }
177 else {
178 this->s_memory = this->to_float();
179 SWORK.zeros(KLU * KLU * SPIKE[9]);
180 sspike_gbtrf_(SPIKE, &N, &KL, &KU, this->s_memory.memptr(), &LDAB, SWORK.memptr(), &INFO);
181 }
182
183 if(0 != INFO) {
184 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
185 return INFO;
186 }
187 }
188
189 return this->solve_trs(X, std::move(B));
190}
191
192template<sp_d T> int BandMatSpike<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
193 auto N = static_cast<la_it>(this->n_rows);
194 auto KL = static_cast<la_it>(l_band);
195 auto KU = static_cast<la_it>(u_band);
196 auto NRHS = static_cast<la_it>(B.n_cols);
197 auto LDAB = static_cast<la_it>(m_rows);
198 auto LDB = static_cast<la_it>(B.n_rows);
199
200 if constexpr(std::is_same_v<T, float>) {
201 sspike_gbtrs_(SPIKE, &TRAN, &N, &KL, &KU, &NRHS, this->memptr(), &LDAB, WORK.memptr(), B.memptr(), &LDB);
202 X = std::move(B);
203 }
204 else if(Precision::FULL == this->setting.precision) {
205 dspike_gbtrs_(SPIKE, &TRAN, &N, &KL, &KU, &NRHS, this->memptr(), &LDAB, WORK.memptr(), B.memptr(), &LDB);
206 X = std::move(B);
207 }
208 else
209 this->mixed_trs(X, std::move(B), [&](fmat& residual) {
210 sspike_gbtrs_(SPIKE, &TRAN, &N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, SWORK.memptr(), residual.memptr(), &LDB);
211 return 0;
212 });
213
214 return SUANPAN_SUCCESS;
215}
216
217#endif
218
A BandMatSpike class that holds matrices.
Definition BandMatSpike.hpp:47
BandMatSpike & operator=(const BandMatSpike &)=delete
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition BandMatSpike.hpp:117
BandMatSpike(const uword in_size, const uword in_l, const uword in_u)
Definition BandMatSpike.hpp:82
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition BandMatSpike.hpp:106
void nullify(const uword K) override
Definition BandMatSpike.hpp:100
BandMatSpike & operator=(BandMatSpike &&)=delete
unique_ptr< MetaMat< T > > unique_copy() override
Definition BandMatSpike.hpp:98
BandMatSpike(BandMatSpike &&)=delete
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition BandMatSpike.hpp:112
BandMatSpike(const BandMatSpike &other)
Definition BandMatSpike.hpp:88
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
const uword n_cols
Definition MetaMat.hpp:116
const uword n_rows
Definition MetaMat.hpp:115
bool factored
Definition MetaMat.hpp:76
void sspike_gbtrs_(la_it *, const char *, la_it *, la_it *, la_it *, la_it *, float *, la_it *, float *, float *, la_it *)
void dspike_gbtrf_(la_it *, la_it *, la_it *, la_it *, double *, la_it *, double *, la_it *)
void spikeinit_(la_it *, la_it *, la_it *)
void dspike_gbsv_(la_it *, la_it *, la_it *, la_it *, la_it *, double *, la_it *, double *, la_it *, la_it *)
std::int32_t la_it
Definition MetaMat.hpp:38
void dspike_gbtrs_(la_it *, const char *, la_it *, la_it *, la_it *, la_it *, double *, la_it *, double *, double *, la_it *)
Mat< T > operator*(const Mat< T > &) const override
Definition BandMatSpike.hpp:140
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition BandMatSpike.hpp:156
void dspike_tune_(la_it *)
void sspike_tune_(la_it *)
void sspike_gbsv_(la_it *, la_it *, la_it *, la_it *, la_it *, float *, la_it *, float *, la_it *, la_it *)
void sspike_gbtrf_(la_it *, la_it *, la_it *, la_it *, float *, la_it *, float *, la_it *)
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:31
auto suanpan_assert(F &&handler)
Definition suanPan.h:339
constexpr auto SUANPAN_SUCCESS
Definition suanPan.h:166
#define suanpan_error(...)
Definition suanPan.h:349