30#ifndef BANDMATCLUSTER_HPP
31#define BANDMATCLUSTER_HPP
33#include "../DenseMat.hpp"
35#include <ezp/ezp/pgbsv.hpp>
38 using solver_t = ezp::pgbsv<T, la_it>;
39 using indexer_t = solver_t::indexer;
/// Solve with an existing factorisation. The solve path calls this as
/// `solve_trs(X, std::move(B))` when `this->factored` is already set, so the first
/// argument appears to receive the solution and the second is the right-hand side.
/// NOTE(review): the definition is outside this view; confirm the return-code meaning.
46 int solve_trs(Mat<T>&, Mat<T>&&);
57 :
DenseMat<T>(in_size, in_size, (2 * (in_l + in_u) + 1) * in_size)
59 , indexer(in_size, in_l, in_u)
62 if(2 * (in_l + in_u) + 1 >= in_size)
63 suanpan_warning(
"The storage requirement for the banded matrix is larger than that of a full matrix, consider using a full/sparse matrix instead.\n");
66 unique_ptr<MetaMat<T>>
unique_copy()
override {
return std::make_unique<BandMatCluster>(*
this); }
74 T
operator()(
const uword in_row,
const uword in_col)
const override {
75 const auto pos = indexer(in_row, in_col);
76 if(pos < 0) [[unlikely]]
81 T&
unsafe_at(
const uword in_row,
const uword in_col)
override {
83 return this->
memory[indexer(in_row, in_col)];
86 T&
at(
const uword in_row,
const uword in_col)
override {
87 const auto pos = indexer(in_row, in_col);
88 if(pos < 0) [[unlikely]]
/// Multiply this banded matrix by a dense matrix and return the product.
/// The out-of-class definition (only partly visible) calls BLAS ?gbmv once per
/// column of the argument with TRANS = 'N', alpha = 1 and beta = 0.
94 Mat<T>
operator*(
const Mat<T>&)
const override;
100 static constexpr auto TRAN =
'N';
101 static constexpr blas_int INC = 1;
102 static constexpr T ALPHA{1}, BETA{0};
104 Mat<T> Y(arma::size(X));
106 const auto s_band = l_band + u_band;
108 const auto M =
static_cast<blas_int
>(this->n_rows);
109 const auto N =
static_cast<blas_int
>(this->n_cols);
110 const auto KL =
static_cast<blas_int
>(l_band);
111 const auto KU =
static_cast<blas_int
>(u_band);
112 const auto LDA =
static_cast<blas_int
>(2 * s_band + 1);
114 if constexpr(std::is_same_v<T, float>)
suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &
M, &
N, &KL, &KU, &ALPHA, this->memptr() + s_band, &LDA, X.colptr(I), &INC, &BETA, Y.colptr(I), &INC); });
115 else suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, &ALPHA, this->memptr() + s_band, &LDA, X.colptr(I), &INC, &BETA, Y.colptr(I), &INC); });
120#pragma GCC diagnostic push
121#pragma GCC diagnostic ignored "-Wnarrowing"
123 if(this->factored)
return this->solve_trs(X, std::move(B));
125 suanpan_assert([&] {
if(this->n_rows != this->n_cols)
throw std::invalid_argument(
"requires a square matrix"); });
127 this->factored =
true;
129 const auto INFO =
bcast_from_root(solver.solve({this->n_rows, this->n_cols, this->l_band, this->u_band, this->memptr()}, {B.n_rows, B.n_cols, B.memptr()}));
132 else suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
138 const auto INFO =
bcast_from_root(solver.solve({B.n_rows, B.n_cols, B.memptr()}));
141 else suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
145#pragma GCC diagnostic pop
A BandMatCluster class that holds a banded matrix and solves linear systems with it through the `ezp::pgbsv` solver.
Definition BandMatCluster.hpp:37
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only); returns zero if out of bounds.
Definition BandMatCluster.hpp:74
void nullify(const uword K) override
Definition BandMatCluster.hpp:68
const uword l_band
Definition BandMatCluster.hpp:49
BandMatCluster(const uword in_size, const uword in_l, const uword in_u)
Definition BandMatCluster.hpp:56
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bounds check.
Definition BandMatCluster.hpp:81
unique_ptr< MetaMat< T > > unique_copy() override
Definition BandMatCluster.hpp:66
const uword u_band
Definition BandMatCluster.hpp:49
T & at(const uword in_row, const uword in_col) override
Access element with bounds check.
Definition BandMatCluster.hpp:86
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:31
auto suanpan_assert(F &&handler)
Definition suanPan.h:339
#define suanpan_warning(...)
Definition suanPan.h:348
auto bcast_from_root(T &&object)
Definition suanPan.h:238
#define suanpan_error(...)
Definition suanPan.h:349