30#ifndef BANDSYMMMAT_HPP
31#define BANDSYMMMAT_HPP
33#include "../DenseMat.hpp"
36 static constexpr char UPLO =
'L';
43 int solve_trs(Mat<T>&, Mat<T>&&);
52 :
DenseMat<
T>(in_size, in_size, (in_bandwidth + 1) * in_size)
54 , m_rows(in_bandwidth + 1) {}
56 unique_ptr<MetaMat<T>>
make_copy()
override {
return std::make_unique<BandSymmMat>(*
this); }
61 const auto t_factor =
K * m_rows -
K;
65 T operator()(
const uword in_row,
const uword in_col)
const override {
66 if(in_row > band + in_col || in_col > in_row + band) [[unlikely]]
68 return this->
memory[in_row > in_col ? in_row - in_col + in_col * m_rows : in_col - in_row + in_row * m_rows];
71 T&
unsafe_at(
const uword in_row,
const uword in_col)
override {
73 return this->
memory[in_row - in_col + in_col * m_rows];
76 T&
at(
const uword in_row,
const uword in_col)
override {
77 if(in_row > band + in_col || in_row < in_col) [[unlikely]]
82 Mat<T>
operator*(
const Mat<T>&)
const override;
84 [[nodiscard]]
int sign_det()
const override {
return 1; }
90 Mat<T> Y(arma::size(X));
92 const auto N =
static_cast<blas_int
>(this->n_cols);
93 const auto K =
static_cast<blas_int
>(band);
94 const auto LDA =
static_cast<blas_int
>(m_rows);
95 constexpr blas_int INC = 1;
99 if constexpr(std::is_same_v<T, float>) {
101 suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_ssbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
105 suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_dsbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
112 if(this->factored)
return this->solve_trs(X, std::move(B));
114 suanpan_assert([&] {
if(this->n_rows != this->n_cols)
throw std::invalid_argument(
"requires a square matrix"); });
118 const auto N =
static_cast<blas_int
>(this->n_rows);
119 const auto KD =
static_cast<blas_int
>(band);
120 const auto NRHS =
static_cast<blas_int
>(B.n_cols);
121 const auto LDAB =
static_cast<blas_int
>(m_rows);
122 const auto LDB =
static_cast<blas_int
>(B.n_rows);
123 this->factored =
true;
125 if constexpr(std::is_same_v<T, float>) {
127 arma_fortran(arma_spbsv)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)B.memptr(), &LDB, &INFO);
132 arma_fortran(arma_dpbsv)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)B.memptr(), &LDB, &INFO);
136 this->s_memory = this->to_float();
137 arma_fortran(arma_spbtrf)(&UPLO, &
N, &KD, this->s_memory.memptr(), &LDAB, &INFO);
138 if(0 == INFO) INFO = this->solve_trs(X, std::move(B));
142 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
150 const auto N =
static_cast<blas_int
>(this->n_rows);
151 const auto KD =
static_cast<blas_int
>(band);
152 const auto NRHS =
static_cast<blas_int
>(B.n_cols);
153 const auto LDAB =
static_cast<blas_int
>(m_rows);
154 const auto LDB =
static_cast<blas_int
>(B.n_rows);
156 if constexpr(std::is_same_v<T, float>) {
158 arma_fortran(arma_spbtrs)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)B.memptr(), &LDB, &INFO);
163 arma_fortran(arma_dpbtrs)(&UPLO, &
N, &KD, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)B.memptr(), &LDB, &INFO);
167 this->mixed_trs(X, std::move(B), [&](fmat& residual) {
168 arma_fortran(arma_spbtrs)(&UPLO, &
N, &KD, &NRHS, this->s_memory.memptr(), &LDAB, residual.memptr(), &LDB, &INFO);
173 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
A BandSymmMat class that holds matrices.
Definition BandSymmMat.hpp:35
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition BandSymmMat.hpp:65
BandSymmMat(const uword in_size, const uword in_bandwidth)
Definition BandSymmMat.hpp:51
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition BandSymmMat.hpp:76
unique_ptr< MetaMat< T > > make_copy() override
Definition BandSymmMat.hpp:56
int sign_det() const override
Definition BandSymmMat.hpp:84
void nullify(const uword K) override
Definition BandSymmMat.hpp:58
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition BandSymmMat.hpp:71
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:28
void suanpan_assert(const std::function< void()> &F)
Definition suanPan.h:363
#define suanpan_error(...)
Definition suanPan.h:376