33#include "../DenseMat.hpp"
36 static constexpr auto TRAN =
'N';
38 int solve_trs(Mat<T>&, Mat<T>&&);
46 FullMat(
const uword in_rows,
const uword in_cols)
47 :
DenseMat<T>(in_rows, in_cols, in_rows * in_cols) {}
49 unique_ptr<MetaMat<T>>
unique_copy()
override {
return std::make_unique<FullMat>(*
this); }
57 T
operator()(
const uword in_row,
const uword in_col)
const override {
return this->
memory[in_row + in_col * this->
n_rows]; }
59 T&
at(
const uword in_row,
const uword in_col)
override {
64 Mat<T>
operator*(
const Mat<T>&)
const override;
67 std::function<bool(uword)> neg_diag;
69 else neg_diag = [&](
const uword i) {
return this->
s_memory[i + i * this->
n_rows] < 0.f; };
72 for(
unsigned I = 0; I < this->
pivot.n_elem; ++I)
73 if(neg_diag(I) ^ (
static_cast<int>(I) + 1 != this->
pivot(I))) det_sign = -det_sign;
79 Mat<T> C(arma::size(B));
81 const auto M =
static_cast<blas_int
>(this->n_rows);
82 const auto N =
static_cast<blas_int
>(this->n_cols);
84 static constexpr T ALPHA{1}, BETA{0};
87 static constexpr blas_int INC = 1;
89 if constexpr(std::is_same_v<T, float>) arma_fortran(arma_sgemv)(&TRAN, &
M, &
N, &ALPHA, this->memptr(), &
M, B.memptr(), &INC, &BETA, C.memptr(), &INC);
90 else arma_fortran(arma_dgemv)(&TRAN, &
M, &
N, &ALPHA, this->memptr(), &
M, B.memptr(), &INC, &BETA, C.memptr(), &INC);
93 const auto K =
static_cast<blas_int
>(B.n_cols);
95 if constexpr(std::is_same_v<T, float>) arma_fortran(arma_sgemm)(&TRAN, &TRAN, &
M, &
K, &
N, &ALPHA, this->memptr(), &
M, B.memptr(), &
N, &BETA, C.memptr(), &
M);
96 else arma_fortran(arma_dgemm)(&TRAN, &TRAN, &
M, &
K, &
N, &ALPHA, this->memptr(), &
M, B.memptr(), &
N, &BETA, C.memptr(), &
M);
103 if(this->factored)
return this->solve_trs(X, std::move(B));
105 suanpan_assert([&] {
if(this->n_rows != this->n_cols)
throw std::invalid_argument(
"requires a square matrix"); });
109 auto N =
static_cast<blas_int
>(this->n_rows);
110 const auto NRHS =
static_cast<blas_int
>(B.n_cols);
111 const auto LDB =
static_cast<blas_int
>(B.n_rows);
112 this->pivot.zeros(
N);
113 this->factored =
true;
115 if constexpr(std::is_same_v<T, float>) {
116 arma_fortran(arma_sgesv)(&
N, &NRHS, this->memptr(), &
N, this->pivot.memptr(), B.memptr(), &LDB, &INFO);
120 arma_fortran(arma_dgesv)(&
N, &NRHS, this->memptr(), &
N, this->pivot.memptr(), B.memptr(), &LDB, &INFO);
124 this->s_memory = this->to_float();
125 arma_fortran(arma_sgetrf)(&
N, &
N, this->s_memory.memptr(), &
N, this->pivot.memptr(), &INFO);
126 if(0 == INFO) INFO = this->solve_trs(X, std::move(B));
130 suanpan_error(
"Error code {} received, the matrix is probably singular.\n", INFO);
138 const auto N =
static_cast<blas_int
>(this->n_rows);
139 const auto NRHS =
static_cast<blas_int
>(B.n_cols);
140 const auto LDB =
static_cast<blas_int
>(B.n_rows);
142 if constexpr(std::is_same_v<T, float>) {
143 arma_fortran(arma_sgetrs)(&TRAN, &
N, &NRHS, this->memptr(), &
N, this->pivot.memptr(), B.memptr(), &LDB, &INFO);
147 arma_fortran(arma_dgetrs)(&TRAN, &
N, &NRHS, this->memptr(), &
N, this->pivot.memptr(), B.memptr(), &LDB, &INFO);
151 this->mixed_trs(X, std::move(B), [&](fmat& residual) {
152 arma_fortran(arma_sgetrs)(&TRAN, &
N, &NRHS, this->s_memory.memptr(), &
N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
podarray< float > s_memory
Definition DenseMat.hpp:46
podarray< blas_int > pivot
Definition DenseMat.hpp:45
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
A FullMat class that holds matrices.
Definition FullMat.hpp:35
unique_ptr< MetaMat< T > > unique_copy() override
Definition FullMat.hpp:49
int sign_det() const override
Definition FullMat.hpp:66
FullMat(const uword in_rows, const uword in_cols)
Definition FullMat.hpp:46
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only); FullMat stores every entry and performs no bound check, so indices must lie within the matrix dimensions.
Definition FullMat.hpp:57
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition FullMat.hpp:59
void nullify(const uword K) override
Definition FullMat.hpp:51
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:31
Precision precision
Definition SolverSetting.hpp:32
auto suanpan_assert(F &&handler)
Definition suanPan.h:339
#define suanpan_error(...)
Definition suanPan.h:349