30 #ifndef SYMMPACKMAT_HPP
31 #define SYMMPACKMAT_HPP
36 static constexpr
char UPLO =
'U';
40 int solve_trs(Mat<T>&, Mat<T>&&);
41 int solve_trs(Mat<T>&,
const Mat<T>&);
46 unique_ptr<MetaMat<T>>
make_copy()
override;
48 void unify(uword)
override;
52 T&
at(uword, uword)
override;
54 Mat<T>
operator*(
const Mat<T>&)
const override;
63 :
DenseMat<
T>(in_size, in_size, (in_size + 1) * in_size / 2) {}
69 access::rw(this->memory[(
K *
K + 3 *
K) / 2]) = 1.;
73 suanpan_for(0llu,
K, [&](
const uword I) { access::rw(this->memory[(
K *
K +
K) / 2 + I]) = 0.; });
74 suanpan_for(
K, this->n_rows, [&](
const uword I) { access::rw(this->memory[(I * I + I) / 2 +
K]) = 0.; });
76 this->factored =
false;
79 template<sp_d T>
const T&
SymmPackMat<T>::operator()(
const uword in_row,
const uword in_col)
const {
return this->memory[in_col > in_row ? (in_col * in_col + in_col) / 2 + in_row : (in_row * in_row + in_row) / 2 + in_col]; }
82 if(in_col < in_row)
return bin;
83 this->factored =
false;
84 return access::rw(this->memory[(in_col * in_col + in_col) / 2 + in_row]);
// Forward declaration of the function template spmm: given a packed symmetric
// matrix A and a dense matrix B (both with element type T1), it returns a
// dense Mat<T1>. It is declared here so that operator* can dispatch to
// spmm<'R', 'N'>(*this, X) when the operand is not a column vector.
// NOTE(review): S and T are compile-time character flags; presumably they
// select the side and the transpose mode of the product -- confirm against
// the definition of spmm, which is not visible in this section.
87 template<const
char S, const
char T, sp_d T1> Mat<T1>
spmm(
const SymmPackMat<T1>& A,
const Mat<T1>& B);
90 if(!X.is_colvec())
return spmm<'R', 'N'>(*
this, X);
94 const auto N =
static_cast<int>(this->n_rows);
95 constexpr
auto INC = 1;
99 if(std::is_same_v<T, float>) {
101 arma_fortran(arma_sspmv)(&UPLO, &
N, (
E*)&ALPHA, (
E*)this->memptr(), (
E*)X.memptr(), &INC, (
E*)&BETA, (
E*)Y.memptr(), &INC);
103 else if(std::is_same_v<T, double>) {
105 arma_fortran(arma_dspmv)(&UPLO, &
N, (
E*)&ALPHA, (
E*)this->memptr(), (
E*)X.memptr(), &INC, (
E*)&BETA, (
E*)Y.memptr(), &INC);
112 if(this->factored)
return this->solve_trs(X, B);
114 const auto N =
static_cast<int>(this->n_rows);
115 const auto NRHS =
static_cast<int>(B.n_cols);
116 const auto LDB =
static_cast<int>(B.n_rows);
119 this->factored =
true;
121 if(std::is_same_v<T, float>) {
124 arma_fortran(arma_sppsv)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)X.memptr(), &LDB, &INFO);
129 arma_fortran(arma_dppsv)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)X.memptr(), &LDB, &INFO);
132 this->s_memory = this->to_float();
133 arma_fortran(arma_spptrf)(&UPLO, &
N, this->s_memory.memptr(), &INFO);
134 if(0 == INFO) INFO = this->solve_trs(X, B);
137 if(0 != INFO)
suanpan_error(
"solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
143 const auto N =
static_cast<int>(this->n_rows);
144 const auto NRHS =
static_cast<int>(B.n_cols);
145 const auto LDB =
static_cast<int>(B.n_rows);
148 if(std::is_same_v<T, float>) {
151 arma_fortran(arma_spptrs)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)X.memptr(), &LDB, &INFO);
156 arma_fortran(arma_dpptrs)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)X.memptr(), &LDB, &INFO);
159 X = arma::zeros(B.n_rows, B.n_cols);
161 mat full_residual = B;
163 auto multiplier =
norm(full_residual);
166 while(counter++ < this->setting.iterative_refinement) {
167 if(multiplier < this->setting.tolerance)
break;
169 auto residual = conv_to<fmat>::from(full_residual / multiplier);
171 arma_fortran(arma_spptrs)(&UPLO, &
N, &NRHS, this->s_memory.memptr(), residual.memptr(), &LDB, &INFO);
174 const mat incre = multiplier * conv_to<mat>::from(residual);
178 suanpan_debug(
"mixed precision algorithm multiplier: %.5E.\n", multiplier =
norm(full_residual -= this->
operator*(incre)));
182 if(INFO != 0)
suanpan_error(
"solve() receives error code %u from base driver, the matrix is probably singular.\n", INFO);
188 if(this->factored)
return this->solve_trs(X, std::forward<Mat<T>>(B));
190 const auto N =
static_cast<int>(this->n_rows);
191 const auto NRHS =
static_cast<int>(B.n_cols);
192 const auto LDB =
static_cast<int>(B.n_rows);
195 this->factored =
true;
197 if(std::is_same_v<T, float>) {
199 arma_fortran(arma_sppsv)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)B.memptr(), &LDB, &INFO);
204 arma_fortran(arma_dppsv)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)B.memptr(), &LDB, &INFO);
208 this->s_memory = this->to_float();
209 arma_fortran(arma_spptrf)(&UPLO, &
N, this->s_memory.memptr(), &INFO);
210 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
213 if(0 != INFO)
suanpan_error(
"solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
219 const auto N =
static_cast<int>(this->n_rows);
220 const auto NRHS =
static_cast<int>(B.n_cols);
221 const auto LDB =
static_cast<int>(B.n_rows);
224 if(std::is_same_v<T, float>) {
226 arma_fortran(arma_spptrs)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)B.memptr(), &LDB, &INFO);
231 arma_fortran(arma_dpptrs)(&UPLO, &
N, &NRHS, (
E*)this->memptr(), (
E*)B.memptr(), &LDB, &INFO);
235 X = arma::zeros(B.n_rows, B.n_cols);
240 while(counter++ < this->setting.iterative_refinement) {
241 if(multiplier < this->setting.tolerance)
break;
243 auto residual = conv_to<fmat>::from(B / multiplier);
245 arma_fortran(arma_spptrs)(&UPLO, &
N, &NRHS, this->s_memory.memptr(), residual.memptr(), &LDB, &INFO);
248 const mat incre = multiplier * conv_to<mat>::from(residual);
252 suanpan_debug(
"mixed precision algorithm multiplier: %.5E.\n", multiplier =
norm(B -= this->
operator*(incre)));
256 if(INFO != 0)
suanpan_error(
"solve() receives error code %u from base driver, the matrix is probably singular.\n", INFO);
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
A SymmPackMat class that holds matrices.
Definition: SymmPackMat.hpp:35
double norm(const vec &)
Definition: tensorToolbox.cpp:302
void suanpan_debug(const char *M,...)
Definition: print.cpp:64
void suanpan_error(const char *M,...)
Definition: print.cpp:116
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:24