30 #ifndef SYMMPACKMAT_HPP 31 #define SYMMPACKMAT_HPP 36 static constexpr
char UPLO =
'U';
40 int solve_trs(Mat<T>&, Mat<T>&&);
41 int solve_trs(Mat<T>&,
const Mat<T>&);
46 unique_ptr<MetaMat<T>>
make_copy()
override;
48 void unify(uword)
override;
52 T&
at(uword, uword)
override;
56 int solve(Mat<T>&, Mat<T>&&)
override;
57 int solve(Mat<T>&,
const Mat<T>&)
override;
63 :
DenseMat<
T>(in_size, in_size, (in_size + 1) * in_size / 2) {}
69 access::rw(this->
memory[(K * K + 3 * K) / 2]) = 1.;
73 suanpan_for(0llu, K, [&](
const uword I) { access::rw(this->
memory[(K * K + K) / 2 + I]) = 0.; });
79 template<sp_d T>
const T&
SymmPackMat<T>::operator()(
const uword in_row,
const uword in_col)
const {
return this->
memory[in_col > in_row ? (in_col * in_col + in_col) / 2 + in_row : (in_row * in_row + in_row) / 2 + in_col]; }
82 if(in_col < in_row)
return bin;
84 return access::rw(this->
memory[(in_col * in_col + in_col) / 2 + in_row]);
87 template<const
char S, const
char T, sp_d T1> Mat<T1>
spmm(
const SymmPackMat<T1>&
A,
const Mat<T1>& B);
90 if(!X.is_colvec())
return spmm<'R', 'N'>(*
this, X);
94 const auto N =
static_cast<int>(this->
n_rows);
95 constexpr
auto INC = 1;
99 if(std::is_same_v<T, float>) {
101 arma_fortran(arma_sspmv)(&UPLO, &
N, (
E*)&ALPHA, (
E*)this->
memptr(), (
E*)X.memptr(), &INC, (
E*)&BETA, (
E*)Y.memptr(), &INC);
103 else if(std::is_same_v<T, double>) {
105 arma_fortran(arma_dspmv)(&UPLO, &
N, (
E*)&ALPHA, (
E*)this->
memptr(), (
E*)X.memptr(), &INC, (
E*)&BETA, (
E*)Y.memptr(), &INC);
112 if(this->
factored)
return this->solve_trs(X, B);
114 const auto N =
static_cast<int>(this->
n_rows);
115 const auto NRHS =
static_cast<int>(B.n_cols);
116 const auto LDB =
static_cast<int>(B.n_rows);
121 if(std::is_same_v<T, float>) {
124 arma_fortran(arma_sppsv)(&UPLO, &
N, &NRHS, (
E*)this->
memptr(), (
E*)X.memptr(), &LDB, &INFO);
129 arma_fortran(arma_dppsv)(&UPLO, &
N, &NRHS, (
E*)this->
memptr(), (
E*)X.memptr(), &LDB, &INFO);
133 arma_fortran(arma_spptrf)(&UPLO, &
N, this->
s_memory.memptr(), &INFO);
134 if(0 == INFO) INFO = this->solve_trs(X, B);
137 if(0 != INFO)
suanpan_error(
"solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
143 const auto N =
static_cast<int>(this->
n_rows);
144 const auto NRHS =
static_cast<int>(B.n_cols);
145 const auto LDB =
static_cast<int>(B.n_rows);
148 if(std::is_same_v<T, float>) {
151 arma_fortran(arma_spptrs)(&UPLO, &
N, &NRHS, (
E*)this->
memptr(), (
E*)X.memptr(), &LDB, &INFO);
156 arma_fortran(arma_dpptrs)(&UPLO, &
N, &NRHS, (
E*)this->
memptr(), (
E*)X.memptr(), &LDB, &INFO);
159 X = arma::zeros(B.n_rows, B.n_cols);
161 mat full_residual = B;
163 auto multiplier =
norm(full_residual);
169 auto residual = conv_to<fmat>::from(full_residual / multiplier);
171 arma_fortran(arma_spptrs)(&UPLO, &
N, &NRHS, this->
s_memory.memptr(), residual.memptr(), &LDB, &INFO);
174 const mat incre = multiplier * conv_to<mat>::from(residual);
178 suanpan_debug(
"mixed precision algorithm multiplier: %.5E.\n", multiplier =
norm(full_residual -= this->
operator*(incre)));
182 if(INFO != 0)
suanpan_error(
"solve() receives error code %u from base driver, the matrix is probably singular.\n", INFO);
188 if(this->
factored)
return this->solve_trs(X, std::forward<Mat<T>>(B));
190 const auto N =
static_cast<int>(this->
n_rows);
191 const auto NRHS =
static_cast<int>(B.n_cols);
192 const auto LDB =
static_cast<int>(B.n_rows);
197 if(std::is_same_v<T, float>) {
199 arma_fortran(arma_sppsv)(&UPLO, &
N, &NRHS, (
E*)this->
memptr(), (
E*)B.memptr(), &LDB, &INFO);
204 arma_fortran(arma_dppsv)(&UPLO, &
N, &NRHS, (
E*)this->
memptr(), (
E*)B.memptr(), &LDB, &INFO);
209 arma_fortran(arma_spptrf)(&UPLO, &
N, this->
s_memory.memptr(), &INFO);
210 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
213 if(0 != INFO)
suanpan_error(
"solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
219 const auto N =
static_cast<int>(this->
n_rows);
220 const auto NRHS =
static_cast<int>(B.n_cols);
221 const auto LDB =
static_cast<int>(B.n_rows);
224 if(std::is_same_v<T, float>) {
226 arma_fortran(arma_spptrs)(&UPLO, &
N, &NRHS, (
E*)this->
memptr(), (
E*)B.memptr(), &LDB, &INFO);
231 arma_fortran(arma_dpptrs)(&UPLO, &
N, &NRHS, (
E*)this->
memptr(), (
E*)B.memptr(), &LDB, &INFO);
235 X = arma::zeros(B.n_rows, B.n_cols);
243 auto residual = conv_to<fmat>::from(B / multiplier);
245 arma_fortran(arma_spptrs)(&UPLO, &
N, &NRHS, this->
s_memory.memptr(), residual.memptr(), &LDB, &INFO);
248 const mat incre = multiplier * conv_to<mat>::from(residual);
252 suanpan_debug(
"mixed precision algorithm multiplier: %.5E.\n", multiplier =
norm(B -= this->
operator*(incre)));
256 if(INFO != 0)
suanpan_error(
"solve() receives error code %u from base driver, the matrix is probably singular.\n", INFO);
void suanpan_error(const char *M,...)
Definition: print.cpp:116
const T *const memory
Definition: DenseMat.hpp:43
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
double norm(const vec &)
Definition: tensorToolbox.cpp:302
podarray< float > s_memory
Definition: DenseMat.hpp:39
void suanpan_debug(const char *M,...)
Definition: print.cpp:64
A SymmPackMat class that holds matrices.
Definition: SymmPackMat.hpp:35
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:24