30#ifndef BANDMATSPIKE_HPP
31#define BANDMATSPIKE_HPP
33#include "../DenseMat.hpp"
40void dspike_gbtrs_(
la_it*,
const char*,
la_it*,
la_it*,
la_it*,
la_it*,
double*,
la_it*,
double*,
double*,
la_it*);
44void sspike_gbtrs_(
la_it*,
const char*,
la_it*,
la_it*,
la_it*,
la_it*,
float*,
la_it*,
float*,
float*,
la_it*);
48 static constexpr auto TRAN =
'N';
50 static la_it SPROTO, DPROTO;
61 podarray<float> SWORK;
65 auto KLU =
static_cast<la_it>(std::max(l_band, u_band));
69 SPIKE[6] = std::is_same_v<T, float> ? SPROTO : DPROTO;
70 SPIKE[4] = SPIKE[6] + SPIKE[6] / 2 + 10;
71 SPIKE[3] = SPIKE[4] / 2;
74 int solve_trs(Mat<T>&, Mat<T>&&);
82 BandMatSpike(
const uword in_size,
const uword in_l,
const uword in_u)
83 :
DenseMat<T>(in_size, in_size, (in_l + in_u + 1) * in_size)
86 , m_rows(in_l + in_u + 1) { init_spike(); }
90 , l_band(other.l_band)
91 , u_band(other.u_band)
92 , m_rows(other.m_rows) { init_spike(); }
98 unique_ptr<MetaMat<T>>
unique_copy()
override {
return std::make_unique<BandMatSpike>(*
this); }
102 suanpan::for_each(std::max(
K, u_band) - u_band, std::min(this->
n_rows, K + l_band + 1), [&](
const uword I) { this->
memory[I + u_band +
K * (m_rows - 1)] = T(0); });
103 suanpan::for_each(std::max(
K, l_band) - l_band, std::min(this->
n_cols, K + u_band + 1), [&](
const uword I) { this->
memory[K + u_band + I * (m_rows - 1)] = T(0); });
106 T
operator()(
const uword in_row,
const uword in_col)
const override {
107 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]]
109 return this->
memory[in_row + u_band + in_col * (m_rows - 1)];
112 T&
unsafe_at(
const uword in_row,
const uword in_col)
override {
114 return this->
memory[in_row + u_band + in_col * (m_rows - 1)];
117 T&
at(
const uword in_row,
const uword in_col)
override {
118 if(in_row > in_col + l_band || in_row + u_band < in_col) [[unlikely]]
123 Mat<T>
operator*(
const Mat<T>&)
const override;
141 Mat<T> Y(arma::size(X));
143 const auto M =
static_cast<blas_int
>(this->n_rows);
144 const auto N =
static_cast<blas_int
>(this->n_cols);
145 const auto KL =
static_cast<blas_int
>(l_band);
146 const auto KU =
static_cast<blas_int
>(u_band);
147 const auto LDA =
static_cast<blas_int
>(m_rows);
148 static constexpr blas_int INC = 1;
149 static constexpr T ALPHA{1}, BETA{0};
151 if constexpr(std::is_same_v<T, float>)
suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &
M, &
N, &KL, &KU, &ALPHA, this->memptr(), &LDA, X.colptr(I), &INC, &BETA, Y.colptr(I), &INC); });
152 else suanpan::for_each(X.n_cols, [&](
const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, &ALPHA, this->memptr(), &LDA, X.colptr(I), &INC, &BETA, Y.colptr(I), &INC); });
157 if(!this->factored) {
158 suanpan_assert([&] {
if(this->n_rows != this->n_cols)
throw std::invalid_argument(
"requires a square matrix"); });
162 auto N =
static_cast<la_it>(this->n_rows);
163 auto KL =
static_cast<la_it>(l_band);
164 auto KU =
static_cast<la_it>(u_band);
165 auto LDAB =
static_cast<la_it>(m_rows);
166 const auto KLU = std::max(l_band, u_band);
167 this->factored =
true;
169 if constexpr(std::is_same_v<T, float>) {
170 WORK.zeros(KLU * KLU *
SPIKE[9]);
174 WORK.zeros(KLU * KLU *
SPIKE[9]);
178 this->s_memory = this->to_float();
179 SWORK.zeros(KLU * KLU *
SPIKE[9]);
184 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
189 return this->solve_trs(X, std::move(B));
193 auto N =
static_cast<la_it>(this->n_rows);
194 auto KL =
static_cast<la_it>(l_band);
195 auto KU =
static_cast<la_it>(u_band);
196 auto NRHS =
static_cast<la_it>(B.n_cols);
197 auto LDAB =
static_cast<la_it>(m_rows);
198 auto LDB =
static_cast<la_it>(B.n_rows);
200 if constexpr(std::is_same_v<T, float>) {
201 sspike_gbtrs_(
SPIKE, &TRAN, &
N, &KL, &KU, &NRHS, this->memptr(), &LDAB, WORK.memptr(), B.memptr(), &LDB);
205 dspike_gbtrs_(
SPIKE, &TRAN, &
N, &KL, &KU, &NRHS, this->memptr(), &LDAB, WORK.memptr(), B.memptr(), &LDB);
209 this->mixed_trs(X, std::move(B), [&](fmat& residual) {
210 sspike_gbtrs_(
SPIKE, &TRAN, &
N, &KL, &KU, &NRHS, this->s_memory.memptr(), &LDAB, SWORK.memptr(), residual.memptr(), &LDB);
A BandMatSpike class that holds matrices.
Definition BandMatSpike.hpp:47
BandMatSpike & operator=(const BandMatSpike &)=delete
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition BandMatSpike.hpp:117
BandMatSpike(const uword in_size, const uword in_l, const uword in_u)
Definition BandMatSpike.hpp:82
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition BandMatSpike.hpp:106
void nullify(const uword K) override
Definition BandMatSpike.hpp:100
BandMatSpike & operator=(BandMatSpike &&)=delete
unique_ptr< MetaMat< T > > unique_copy() override
Definition BandMatSpike.hpp:98
BandMatSpike(BandMatSpike &&)=delete
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition BandMatSpike.hpp:112
BandMatSpike(const BandMatSpike &other)
Definition BandMatSpike.hpp:88
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:31
auto suanpan_assert(F &&handler)
Definition suanPan.h:339
constexpr auto SUANPAN_SUCCESS
Definition suanPan.h:166
#define suanpan_error(...)
Definition suanPan.h:349