30#ifndef BANDSYMMMATCLUSTER_HPP
31#define BANDSYMMMATCLUSTER_HPP
33#include "../DenseMat.hpp"
35#include <ezp/ezp/ppbsv.hpp>
// Triangle selector shared by the ezp solver type below and by the BLAS ?sbmv calls
// in operator*; 'L' selects the lower triangle of the symmetric band storage.
38 static constexpr char UPLO =
'L';
// Banded solver type. NOTE(review): presumably ezp's wrapper of ScaLAPACK p?pbsv
// (symmetric positive-definite band solve) — confirm against ezp documentation.
40 using solver_t = ezp::ppbsv<T, la_it, UPLO>;
// Maps (row, col) to an offset in the band storage; operator()/at() treat a
// negative offset as an entry outside the stored band.
41 using indexer_t =
typename solver_t::indexer;
// Solves with an already computed factorization; the solve routine dispatches
// here when `factored` is set.
50 int solve_trs(Mat<T>&, Mat<T>&&);
59 :
DenseMat<
T>(in_size, in_size, (in_bandwidth + 1) * in_size)
61 , indexer(in_size, in_bandwidth)
62 , band(in_bandwidth) {}
64 unique_ptr<MetaMat<T>>
make_copy()
override {
return std::make_unique<BandSymmMatCluster>(*
this); }
72 T operator()(
const uword in_row,
const uword in_col)
const override {
73 const auto pos = indexer(in_row, in_col);
74 if(pos < 0) [[unlikely]]
79 T&
unsafe_at(
const uword in_row,
const uword in_col)
override {
81 return this->
memory[indexer(in_row, in_col)];
84 T&
at(
const uword in_row,
const uword in_col)
override {
85 const auto pos = indexer(in_row, in_col);
86 if(pos < 0) [[unlikely]]
// Multiplies this symmetric band matrix A by the dense matrix X and yields A * X
// (the definition uses alpha = 1 and beta = 0).
// NOTE(review): the definition processes X one column at a time through BLAS
// ssbmv for float and dsbmv on the other visible branch, with leading dimension band + 1.
92 Mat<T>
operator*(
const Mat<T>&)
const override;
98 static constexpr blas_int INC = 1;
99 static constexpr T ALPHA =
T(1), BETA =
T(0);
101 Mat<T> Y(arma::size(X));
103 const auto N =
static_cast<blas_int
>(this->n_cols);
104 const auto K =
static_cast<blas_int
>(band);
105 const auto LDA =
K + 1;
107 if constexpr(std::is_same_v<T, float>) {
109 suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_ssbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
113 suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_dsbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
119#pragma GCC diagnostic push
120#pragma GCC diagnostic ignored "-Wnarrowing"
122 if(this->factored)
return this->solve_trs(X, std::move(B));
124 suanpan_assert([&] {
if(this->n_rows != this->n_cols)
throw std::invalid_argument(
"requires a square matrix"); });
126 this->factored =
true;
128 const auto INFO =
bcast_from_root(solver.solve({this->n_rows, this->n_cols, this->band, this->memptr()}, {B.n_rows, B.n_cols, B.memptr()}));
131 else suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
137 const auto INFO =
bcast_from_root(solver.solve({B.n_rows, B.n_cols, B.memptr()}));
140 else suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
144#pragma GCC diagnostic pop
A symmetric band matrix class whose linear systems are solved through the ezp::ppbsv solver, with the result broadcast from the root process.
Definition BandSymmMatCluster.hpp:37
unique_ptr< MetaMat< T > > make_copy() override
Definition BandSymmMatCluster.hpp:64
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition BandSymmMatCluster.hpp:72
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition BandSymmMatCluster.hpp:84
void nullify(const uword K) override
Definition BandSymmMatCluster.hpp:66
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition BandSymmMatCluster.hpp:79
BandSymmMatCluster(const uword in_size, const uword in_bandwidth)
Definition BandSymmMatCluster.hpp:58
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:28
void suanpan_assert(const std::function< void()> &F)
Definition suanPan.h:363
auto bcast_from_root(T &&object)
Definition suanPan.h:254
#define suanpan_error(...)
Definition suanPan.h:376