30#ifndef BANDMATSPIKE_HPP
31#define BANDMATSPIKE_HPP
33#include "../DenseMat.hpp"
40void dspike_gbtrs_(
la_it*,
const char*,
la_it*,
la_it*,
la_it*,
la_it*,
double*,
la_it*,
double*,
double*,
la_it*);
44void sspike_gbtrs_(
la_it*,
const char*,
la_it*,
la_it*,
la_it*,
la_it*,
float*,
la_it*,
float*,
float*,
la_it*);
48 static constexpr char TRAN =
'N';
50 static la_it SPROTO, DPROTO;
61 podarray<float> SWORK;
65 auto KLU =
static_cast<la_it>(std::max(l_band, u_band));
69 SPIKE[6] = std::is_same_v<T, float> ? SPROTO : DPROTO;
70 SPIKE[4] = SPIKE[6] + SPIKE[6] / 2 + 10;
71 SPIKE[3] = SPIKE[4] / 2;
74 int solve_trs(Mat<T>&, Mat<T>&&);
82 BandMatSpike(
const uword in_size,
const uword in_l,
const uword in_u)
83 :
DenseMat<
T>(in_size, in_size, (in_l + in_u + 1) * in_size)
86 , m_rows(in_l + in_u + 1) { init_spike(); }
90 , l_band(other.l_band)
91 , u_band(other.u_band)
92 , m_rows(other.m_rows) { init_spike(); }
98 unique_ptr<MetaMat<T>>
make_copy()
override {
return std::make_unique<BandMatSpike>(*
this); }
102 suanpan::for_each(std::max(
K, u_band) - u_band, std::min(this->
n_rows, K + l_band + 1), [&](
const uword I) { this->
memory[I + u_band +
K * (m_rows - 1)] =
T(0); });
103 suanpan::for_each(std::max(
K, l_band) - l_band, std::min(this->
n_cols, K + u_band + 1), [&](
const uword I) { this->
memory[K + u_band + I * (m_rows - 1)] =
T(0); });
106 T operator()(
const uword in_row,
const uword in_col)
const override {
107 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]]
109 return this->
memory[in_row + u_band + in_col * (m_rows - 1)];
112 T&
unsafe_at(
const uword in_row,
const uword in_col)
override {
114 return this->
memory[in_row + u_band + in_col * (m_rows - 1)];
117 T&
at(
const uword in_row,
const uword in_col)
override {
118 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]]
123 Mat<T>
operator*(
const Mat<T>&)
const override;
141 Mat<T> Y(arma::size(X));
143 const auto M =
static_cast<blas_int
>(this->n_rows);
144 const auto N =
static_cast<blas_int
>(this->n_cols);
145 const auto KL =
static_cast<blas_int
>(l_band);
146 const auto KU =
static_cast<blas_int
>(u_band);
147 const auto LDA =
static_cast<blas_int
>(m_rows);
148 constexpr blas_int INC = 1;
152 if constexpr(std::is_same_v<T, float>) {
154 suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
158 suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
165 if(!this->factored) {
166 suanpan_assert([&] {
if(this->n_rows != this->n_cols)
throw std::invalid_argument(
"requires a square matrix"); });
170 auto N =
static_cast<la_it>(this->n_rows);
171 auto KL =
static_cast<la_it>(l_band);
172 auto KU =
static_cast<la_it>(u_band);
173 auto LDAB =
static_cast<la_it>(m_rows);
174 const auto KLU = std::max(l_band, u_band);
175 this->factored =
true;
177 if constexpr(std::is_same_v<T, float>) {
179 WORK.zeros(KLU * KLU *
SPIKE[9]);
184 WORK.zeros(KLU * KLU *
SPIKE[9]);
188 this->s_memory = this->to_float();
189 SWORK.zeros(KLU * KLU *
SPIKE[9]);
194 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
199 return this->solve_trs(X, std::move(B));
203 auto N =
static_cast<la_it>(this->n_rows);
204 auto KL =
static_cast<la_it>(l_band);
205 auto KU =
static_cast<la_it>(u_band);
206 auto NRHS =
static_cast<la_it>(B.n_cols);
207 auto LDAB =
static_cast<la_it>(m_rows);
208 auto LDB =
static_cast<la_it>(B.n_rows);
210 if constexpr(std::is_same_v<T, float>) {
212 sspike_gbtrs_(
SPIKE, &TRAN, &
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)WORK.memptr(), (
E*)B.memptr(), &LDB);
217 dspike_gbtrs_(
SPIKE, &TRAN, &
N, &KL, &KU, &NRHS, (
E*)this->memptr(), &LDAB, (
E*)WORK.memptr(), (
E*)B.memptr(), &LDB);
221 this->mixed_trs(X, std::move(B), [&](fmat& residual) {
222 sspike_gbtrs_(
SPIKE, &TRAN, &
N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, SWORK.memptr(), residual.memptr(), &LDB);
A BandMatSpike class that holds matrices.
Definition BandMatSpike.hpp:47
BandMatSpike & operator=(const BandMatSpike &)=delete
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition BandMatSpike.hpp:117
BandMatSpike(const uword in_size, const uword in_l, const uword in_u)
Definition BandMatSpike.hpp:82
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition BandMatSpike.hpp:106
void nullify(const uword K) override
Definition BandMatSpike.hpp:100
BandMatSpike & operator=(BandMatSpike &&)=delete
BandMatSpike(BandMatSpike &&)=delete
unique_ptr< MetaMat< T > > make_copy() override
Definition BandMatSpike.hpp:98
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition BandMatSpike.hpp:112
BandMatSpike(const BandMatSpike &other)
Definition BandMatSpike.hpp:88
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:28
constexpr auto SUANPAN_SUCCESS
Definition suanPan.h:180
void suanpan_assert(const std::function< void()> &F)
Definition suanPan.h:363
#define suanpan_error(...)
Definition suanPan.h:376