33#include "../DenseMat.hpp"
36 static constexpr auto TRAN =
'N';
42 int solve_trs(Mat<T>&, Mat<T>&&);
55 BandMat(
const uword in_size,
const uword in_l,
const uword in_u)
56 :
DenseMat<T>(in_size, in_size, (2 * in_l + in_u + 1) * in_size)
58 ,
m_rows(2 * in_l + in_u + 1)
62 suanpan_warning(
"The storage requirement for the banded matrix is larger than that of a full matrix, consider using a full/sparse matrix instead.\n");
65 unique_ptr<MetaMat<T>>
unique_copy()
override {
return std::make_unique<BandMat>(*
this); }
// Read-only element access by (in_row, in_col).
//
// The guard tests whether the requested entry lies outside the band, i.e. more
// than l_band rows below the diagonal or more than u_band rows above it.
// NOTE(review): listing line 75 — the statement executed for that out-of-band
// case — is absent from this extract; the generated brief for this function
// says it "returns zero if out-of-bound". Confirm against the full source.
//
// For in-band entries the flat index
//     in_row + s_band + in_col * (m_rows - 1)
// equals (s_band + in_row - in_col) + in_col * m_rows, i.e. offset
// (s_band + in_row - in_col) inside the in_col-th run of m_rows entries.
73 T
operator()(
const uword in_row,
const uword in_col)
const override {
74 if(in_row > in_col +
l_band || in_row +
u_band < in_col) [[unlikely]]
76 return this->
memory[in_row + s_band + in_col * (
m_rows - 1)];
79 T&
unsafe_at(
const uword in_row,
const uword in_col)
override {
81 return this->
memory[in_row + s_band + in_col * (
m_rows - 1)];
84 T&
at(
const uword in_row,
const uword in_col)
override {
85 if(in_row > in_col +
l_band || in_row +
u_band < in_col) [[unlikely]]
90 Mat<T>
operator*(
const Mat<T>&)
const override;
93 std::function<bool(uword)> neg_diag;
95 else neg_diag = [&](
const uword i) {
return this->
s_memory[s_band + i *
m_rows] < 0.f; };
98 for(
unsigned I = 0; I < this->
pivot.n_elem; ++I)
99 if(neg_diag(I) ^ (
static_cast<int>(I) + 1 != this->
pivot(I))) det_sign = -det_sign;
// Banded matrix times dense matrix, evaluated one right-hand-side column at a
// time with the BLAS ?gbmv routine.
// NOTE(review): the enclosing function header and the statement returning Y
// are not part of this extract.

// Result container with the same dimensions as X.
107 Mat<T> Y(arma::size(X));
// ?gbmv arguments: matrix dimensions (M, N), number of sub-/super-diagonals
// (KL, KU) and the leading dimension of the band storage (LDA = m_rows).
109 const auto M =
static_cast<blas_int
>(this->n_rows);
110 const auto N =
static_cast<blas_int
>(this->n_cols);
111 const auto KL =
static_cast<blas_int
>(l_band);
112 const auto KU =
static_cast<blas_int
>(u_band);
113 const auto LDA =
static_cast<blas_int
>(m_rows);
// Unit stride for both vectors; with ALPHA = 1 and BETA = 0 each call
// evaluates y = A * x without reading the previous content of y.
114 static constexpr blas_int INC = 1;
115 static constexpr T ALPHA{1}, BETA{0};
// Single precision goes through sgbmv, every other T through dgbmv; column I
// of X is the input vector and column I of Y receives the product.
// The matrix pointer is advanced by l_band entries while LDA remains m_rows.
// NOTE(review): presumably this skips the extra l_band leading rows that the
// LU-oriented band layout (m_rows = 2 * l_band + u_band + 1) reserves for
// fill-in, so that ?gbmv sees its expected (KL + KU + 1)-row layout — confirm
// against the storage convention used by the solver routines.
117 if constexpr(std::is_same_v<T, float>)
suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &
M, &
N, &KL, &KU, &ALPHA, this->memptr() + l_band, &LDA, X.colptr(I), &INC, &BETA, Y.colptr(I), &INC); });
118 else suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, &ALPHA, this->memptr() + l_band, &LDA, X.colptr(I), &INC, &BETA, Y.colptr(I), &INC); });
123 if(this->factored)
return this->solve_trs(X, std::move(B));
125 suanpan_assert([&] {
if(this->n_rows != this->n_cols)
throw std::invalid_argument(
"requires a square matrix"); });
129 auto N =
static_cast<blas_int
>(this->n_rows);
130 const auto KL =
static_cast<blas_int
>(l_band);
131 const auto KU =
static_cast<blas_int
>(u_band);
132 const auto NRHS =
static_cast<blas_int
>(B.n_cols);
133 const auto LDAB =
static_cast<blas_int
>(m_rows);
134 const auto LDB =
static_cast<blas_int
>(B.n_rows);
135 this->pivot.zeros(
N);
136 this->factored =
true;
138 if constexpr(std::is_same_v<T, float>) {
139 arma_fortran(arma_sgbsv)(&
N, &KL, &KU, &NRHS, this->memptr(), &LDAB, this->pivot.memptr(), B.memptr(), &LDB, &INFO);
143 arma_fortran(arma_dgbsv)(&
N, &KL, &KU, &NRHS, this->memptr(), &LDAB, this->pivot.memptr(), B.memptr(), &LDB, &INFO);
147 this->s_memory = this->to_float();
148 arma_fortran(arma_sgbtrf)(&
N, &
N, &KL, &KU, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), &INFO);
149 if(0 == INFO) INFO = this->solve_trs(X, std::move(B));
153 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
161 const auto N =
static_cast<blas_int
>(this->n_rows);
162 const auto KL =
static_cast<blas_int
>(l_band);
163 const auto KU =
static_cast<blas_int
>(u_band);
164 const auto NRHS =
static_cast<blas_int
>(B.n_cols);
165 const auto LDAB =
static_cast<blas_int
>(m_rows);
166 const auto LDB =
static_cast<blas_int
>(B.n_rows);
168 if constexpr(std::is_same_v<T, float>) {
169 arma_fortran(arma_sgbtrs)(&TRAN, &
N, &KL, &KU, &NRHS, this->memptr(), &LDAB, this->pivot.memptr(), B.memptr(), &LDB, &INFO);
173 arma_fortran(arma_dgbtrs)(&TRAN, &
N, &KL, &KU, &NRHS, this->memptr(), &LDAB, this->pivot.memptr(), B.memptr(), &LDB, &INFO);
177 this->mixed_trs(X, std::move(B), [&](fmat& residual) {
178 arma_fortran(arma_sgbtrs)(&TRAN, &
N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
183 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
A BandMat class that holds banded matrices.
Definition BandMat.hpp:35
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition BandMat.hpp:79
BandMat(const uword in_size, const uword in_l, const uword in_u)
Definition BandMat.hpp:55
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition BandMat.hpp:84
const uword u_band
Definition BandMat.hpp:48
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition BandMat.hpp:73
unique_ptr< MetaMat< T > > unique_copy() override
Definition BandMat.hpp:65
const uword l_band
Definition BandMat.hpp:47
const uword m_rows
Definition BandMat.hpp:45
int sign_det() const override
Definition BandMat.hpp:92
void nullify(const uword K) override
Definition BandMat.hpp:67
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
podarray< float > s_memory
Definition DenseMat.hpp:46
podarray< blas_int > pivot
Definition DenseMat.hpp:45
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:31
Precision precision
Definition SolverSetting.hpp:32
auto suanpan_assert(F &&handler)
Definition suanPan.h:339
#define suanpan_warning(...)
Definition suanPan.h:348
#define suanpan_error(...)
Definition suanPan.h:349