36 static constexpr
char TRAN =
'N';
38 int solve_trs(Mat<T>&, Mat<T>&&);
39 int solve_trs(Mat<T>&,
const Mat<T>&);
44 unique_ptr<MetaMat<T>>
make_copy()
override;
46 void unify(uword)
override;
51 T&
at(uword, uword)
override;
53 Mat<T>
operator*(
const Mat<T>&)
const override;
60 :
DenseMat<
T>(in_rows, in_cols, in_rows * in_cols) {}
70 suanpan_for(0llu, this->n_rows, [&](
const uword I) { at(I,
K) = 0.; });
71 suanpan_for(0llu, this->n_cols, [&](
const uword I) { at(
K, I) = 0.; });
73 this->factored =
false;
76 template<sp_d T>
const T&
FullMat<T>::operator()(
const uword in_row,
const uword in_col)
const {
return this->memory[in_row + in_col * this->n_rows]; }
79 this->factored =
false;
80 return access::rw(this->memory[in_row + in_col * this->n_rows]);
84 Mat<T> C(arma::size(B));
86 const auto M =
static_cast<int>(this->n_rows);
87 const auto N =
static_cast<int>(this->n_cols);
89 T ALPHA = 1., BETA = 0.;
92 constexpr
auto INCX = 1, INCY = 1;
94 if(std::is_same_v<T, float>) {
96 arma_fortran(arma_sgemv)(&TRAN, &
M, &
N, (
E*)&ALPHA, (
E*)this->memptr(), &
M, (
E*)B.memptr(), &INCX, (
E*)&BETA, (
E*)C.memptr(), &INCY);
98 else if(std::is_same_v<T, double>) {
100 arma_fortran(arma_dgemv)(&TRAN, &
M, &
N, (
E*)&ALPHA, (
E*)this->memptr(), &
M, (
E*)B.memptr(), &INCX, (
E*)&BETA, (
E*)C.memptr(), &INCY);
104 const auto K =
static_cast<int>(B.n_cols);
106 if(std::is_same_v<T, float>) {
108 arma_fortran(arma_sgemm)(&TRAN, &TRAN, &
M, &
K, &
N, (
E*)&ALPHA, (
E*)this->memptr(), &
M, (
E*)B.memptr(), &
N, (
E*)&BETA, (
E*)C.memptr(), &
M);
110 else if(std::is_same_v<T, double>) {
112 arma_fortran(arma_dgemm)(&TRAN, &TRAN, &
M, &
K, &
N, (
E*)&ALPHA, (
E*)this->memptr(), &
M, (
E*)B.memptr(), &
N, (
E*)&BETA, (
E*)C.memptr(), &
M);
120 if(this->factored)
return this->solve_trs(X, B);
122 auto N =
static_cast<int>(this->n_rows);
123 const auto NRHS =
static_cast<int>(B.n_cols);
124 const auto LDB =
static_cast<int>(B.n_rows);
126 this->pivot.zeros(
N);
127 this->factored =
true;
129 if(std::is_same_v<T, float>) {
132 arma_fortran(arma_sgesv)(&
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)X.memptr(), &LDB, &INFO);
137 arma_fortran(arma_dgesv)(&
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)X.memptr(), &LDB, &INFO);
140 this->s_memory = this->to_float();
141 arma_fortran(arma_sgetrf)(&
N, &
N, this->s_memory.memptr(), &
N, this->pivot.memptr(), &INFO);
142 if(0 == INFO) INFO = this->solve_trs(X, B);
145 if(0 != INFO)
suanpan_error(
"solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
151 const auto N =
static_cast<int>(this->n_rows);
152 const auto NRHS =
static_cast<int>(B.n_cols);
153 const auto LDB =
static_cast<int>(B.n_rows);
156 if(std::is_same_v<T, float>) {
159 arma_fortran(arma_sgetrs)(&TRAN, &
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)X.memptr(), &LDB, &INFO);
164 arma_fortran(arma_dgetrs)(&TRAN, &
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)X.memptr(), &LDB, &INFO);
167 X = arma::zeros(B.n_rows, B.n_cols);
169 mat full_residual = B;
171 auto multiplier =
norm(full_residual);
174 while(counter++ < this->setting.iterative_refinement) {
175 if(multiplier < this->setting.tolerance)
break;
177 auto residual = conv_to<fmat>::from(full_residual / multiplier);
179 arma_fortran(arma_sgetrs)(&TRAN, &
N, &NRHS, this->s_memory.memptr(), &
N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
182 const mat incre = multiplier * conv_to<mat>::from(residual);
186 suanpan_debug(
"mixed precision algorithm multiplier: %.5E.\n", multiplier =
arma::norm(full_residual -= this->
operator*(incre)));
194 if(this->factored)
return this->solve_trs(X, std::forward<Mat<T>>(B));
196 auto N =
static_cast<int>(this->n_rows);
197 const auto NRHS =
static_cast<int>(B.n_cols);
198 const auto LDB =
static_cast<int>(B.n_rows);
201 this->pivot.zeros(
N);
203 this->factored =
true;
205 if(std::is_same_v<T, float>) {
207 arma_fortran(arma_sgesv)(&
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
212 arma_fortran(arma_dgesv)(&
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
216 this->s_memory = this->to_float();
217 arma_fortran(arma_sgetrf)(&
N, &
N, this->s_memory.memptr(), &
N, this->pivot.memptr(), &INFO);
218 if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
221 if(0 != INFO)
suanpan_error(
"solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
227 const auto N =
static_cast<int>(this->n_rows);
228 const auto NRHS =
static_cast<int>(B.n_cols);
229 const auto LDB =
static_cast<int>(B.n_rows);
232 if(std::is_same_v<T, float>) {
234 arma_fortran(arma_sgetrs)(&TRAN, &
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
239 arma_fortran(arma_dgetrs)(&TRAN, &
N, &NRHS, (
E*)this->memptr(), &
N, this->pivot.memptr(), (
E*)B.memptr(), &LDB, &INFO);
243 X = arma::zeros(B.n_rows, B.n_cols);
248 while(counter++ < this->setting.iterative_refinement) {
249 if(multiplier < this->setting.tolerance)
break;
251 auto residual = conv_to<fmat>::from(B / multiplier);
253 arma_fortran(arma_sgetrs)(&TRAN, &
N, &NRHS, this->s_memory.memptr(), &
N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
256 const mat incre = multiplier * conv_to<mat>::from(residual);
260 suanpan_debug(
"mixed precision algorithm multiplier: %.5E.\n", multiplier =
arma::norm(B -= this->
operator*(incre)));
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
double norm(const vec &)
Definition: tensorToolbox.cpp:302
void suanpan_debug(const char *M,...)
Definition: print.cpp:64
void suanpan_error(const char *M,...)
Definition: print.cpp:116
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:24