30 #ifndef BANDSYMMMAT_HPP
31 #define BANDSYMMMAT_HPP
36 static constexpr
char UPLO =
'L';
43 int solve_trs(Mat<T>&, Mat<T>&&);
44 int solve_trs(Mat<T>&,
const Mat<T>&);
49 unique_ptr<MetaMat<T>>
make_copy()
override;
51 void unify(uword)
override;
55 T&
at(uword, uword)
override;
57 Mat<T>
operator*(
const Mat<T>&)
const override;
66 :
DenseMat<
T>(in_size, in_size, (in_bandwidth + 1) * in_size)
68 , m_rows(in_bandwidth + 1) {}
74 access::rw(this->memory[
K * m_rows]) = 1.;
78 suanpan_for(std::max(band,
K) - band,
K, [&](
const uword I) { access::rw(this->memory[
K - I + I * m_rows]) = 0.; });
79 suanpan_for(
K, std::min(this->n_rows,
K + band + 1), [&](
const uword I) { access::rw(this->memory[I -
K +
K * m_rows]) = 0.; });
81 this->factored =
false;
85 if(in_row > band + in_col)
return bin = 0.;
86 return this->memory[in_row > in_col ? in_row - in_col + in_col * m_rows : in_col - in_row + in_row * m_rows];
90 if(in_row > band + in_col || in_row < in_col)
return bin = 0.;
91 this->factored =
false;
92 return access::rw(this->memory[in_row - in_col + in_col * m_rows]);
96 Mat<T> Y(arma::size(X));
98 const auto N =
static_cast<int>(this->n_cols);
99 const auto K =
static_cast<int>(band);
100 const auto LDA =
static_cast<int>(m_rows);
105 if(std::is_same_v<T, float>) {
107 suanpan_for(0llu, X.n_cols, [&](
const uword I) { arma_fortran(arma_ssbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
109 else if(std::is_same_v<T, double>) {
111 suanpan_for(0llu, X.n_cols, [&](
const uword I) { arma_fortran(arma_dsbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
118 if(this->factored)
return this->solve_trs(X, B);
120 const auto N =
static_cast<int>(this->n_rows);
121 const auto KD =
static_cast<int>(band);
122 const auto NRHS =
static_cast<int>(B.n_cols);
123 const auto LDAB =
static_cast<int>(m_rows);
124 const auto LDB =
static_cast<int>(B.n_rows);
127 this->factored =
true;
129 if(std::is_same_v<T, float>) {
132 arma_fortran(arma_spbsv)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)X.memptr(), &LDB, &INFO);
137 arma_fortran(arma_dpbsv)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)X.memptr(), &LDB, &INFO);
140 this->s_memory = this->to_float();
141 arma_fortran(arma_spbtrf)(&UPLO, &
N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
142 if(0 == INFO) INFO = this->solve_trs(X, B);
145 if(0 != INFO)
suanpan_error(
"solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
151 const auto N =
static_cast<int>(this->n_rows);
152 const auto KD =
static_cast<int>(band);
153 const auto NRHS =
static_cast<int>(B.n_cols);
154 const auto LDAB =
static_cast<int>(m_rows);
155 const auto LDB =
static_cast<int>(B.n_rows);
158 if(std::is_same_v<T, float>) {
161 arma_fortran(arma_spbtrs)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)X.memptr(), &LDB, &INFO);
166 arma_fortran(arma_dpbtrs)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)X.memptr(), &LDB, &INFO);
169 X = arma::zeros(B.n_rows, B.n_cols);
171 mat full_residual = B;
173 auto multiplier =
norm(full_residual);
176 while(counter++ < this->setting.iterative_refinement) {
177 if(multiplier < this->setting.tolerance)
break;
179 auto residual = conv_to<fmat>::from(full_residual / multiplier);
181 arma_fortran(arma_spbtrs)(&UPLO, &
N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
184 const mat incre = multiplier * conv_to<mat>::from(residual);
188 suanpan_debug(
"mixed precision algorithm multiplier: %.5E.\n", multiplier =
arma::norm(full_residual -= this->
operator*(incre)));
192 if(0 != INFO)
suanpan_error(
"solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
198 if(this->factored)
return this->solve_trs(X, std::forward<Mat<T>>(B));
200 const auto N =
static_cast<int>(this->n_rows);
201 const auto KD =
static_cast<int>(band);
202 const auto NRHS =
static_cast<int>(B.n_cols);
203 const auto LDAB =
static_cast<int>(m_rows);
204 const auto LDB =
static_cast<int>(B.n_rows);
207 this->factored =
true;
209 if(std::is_same_v<T, float>) {
211 arma_fortran(arma_spbsv)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)B.memptr(), &LDB, &INFO);
216 arma_fortran(arma_dpbsv)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)B.memptr(), &LDB, &INFO);
220 this->s_memory = this->to_float();
221 arma_fortran(arma_spbtrf)(&UPLO, &
N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
222 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
225 if(0 != INFO)
suanpan_error(
"solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
231 const auto N =
static_cast<int>(this->n_rows);
232 const auto KD =
static_cast<int>(band);
233 const auto NRHS =
static_cast<int>(B.n_cols);
234 const auto LDAB =
static_cast<int>(m_rows);
235 const auto LDB =
static_cast<int>(B.n_rows);
238 if(std::is_same_v<T, float>) {
240 arma_fortran(arma_spbtrs)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)B.memptr(), &LDB, &INFO);
245 arma_fortran(arma_dpbtrs)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)B.memptr(), &LDB, &INFO);
249 X = arma::zeros(B.n_rows, B.n_cols);
251 auto multiplier =
norm(B);
254 while(counter++ < this->setting.iterative_refinement) {
255 if(multiplier < this->setting.tolerance)
break;
257 auto residual = conv_to<fmat>::from(B / multiplier);
259 arma_fortran(arma_spbtrs)(&UPLO, &
N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
262 const mat incre = multiplier * conv_to<mat>::from(residual);
266 suanpan_debug(
"mixed precision algorithm multiplier: %.5E.\n", multiplier =
arma::norm(B -= this->
operator*(incre)));
270 if(0 != INFO)
suanpan_error(
"solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
A BandSymmMat class that holds matrices.
Definition: BandSymmMat.hpp:35
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
double norm(const vec &)
Definition: tensorToolbox.cpp:302
void suanpan_debug(const char *M,...)
Definition: print.cpp:64
void suanpan_error(const char *M,...)
Definition: print.cpp:116
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:24