suanPan
Loading...
Searching...
No Matches
FullMat.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2025 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef FULLMAT_HPP
31#define FULLMAT_HPP
32
33#include "../DenseMat.hpp"
34
35template<sp_d T> class FullMat : public DenseMat<T> {
36 static constexpr char TRAN = 'N';
37
38 int solve_trs(Mat<T>&, Mat<T>&&);
39
40protected:
42
43 int direct_solve(Mat<T>&, Mat<T>&&) override;
44
45public:
46 FullMat(const uword in_rows, const uword in_cols)
47 : DenseMat<T>(in_rows, in_cols, in_rows * in_cols) {}
48
49 unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<FullMat>(*this); }
50
51 void nullify(const uword K) override {
52 this->factored = false;
53 suanpan::for_each(this->n_rows, [&](const uword I) { this->at(I, K) = T(0); });
54 suanpan::for_each(this->n_cols, [&](const uword I) { this->at(K, I) = T(0); });
55 }
56
57 T operator()(const uword in_row, const uword in_col) const override { return this->memory[in_row + in_col * this->n_rows]; }
58
59 T& at(const uword in_row, const uword in_col) override {
60 this->factored = false;
61 return this->memory[in_row + in_col * this->n_rows];
62 }
63
64 Mat<T> operator*(const Mat<T>&) const override;
65
66 [[nodiscard]] int sign_det() const override {
67 std::function<bool(uword)> neg_diag;
68 if(Precision::FULL == this->setting.precision) neg_diag = [&](const uword i) { return this->memory[i + i * this->n_rows] < 0.; };
69 else neg_diag = [&](const uword i) { return this->s_memory[i + i * this->n_rows] < 0.f; };
70
71 auto det_sign = 1;
72 for(unsigned I = 0; I < this->pivot.n_elem; ++I)
73 if(neg_diag(I) ^ (static_cast<int>(I) + 1 != this->pivot(I))) det_sign = -det_sign;
74 return det_sign;
75 }
76};
77
78template<sp_d T> Mat<T> FullMat<T>::operator*(const Mat<T>& B) const {
79 Mat<T> C(arma::size(B));
80
81 const auto M = static_cast<blas_int>(this->n_rows);
82 const auto N = static_cast<blas_int>(this->n_cols);
83
84 T ALPHA = T(1), BETA = T(0);
85
86 if(1 == B.n_cols) {
87 constexpr blas_int INC = 1;
88
89 if constexpr(std::is_same_v<T, float>) {
90 using E = float;
91 arma_fortran(arma_sgemv)(&TRAN, &M, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &INC, (E*)&BETA, (E*)C.memptr(), &INC);
92 }
93 else {
94 using E = double;
95 arma_fortran(arma_dgemv)(&TRAN, &M, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &INC, (E*)&BETA, (E*)C.memptr(), &INC);
96 }
97 }
98 else {
99 const auto K = static_cast<blas_int>(B.n_cols);
100
101 if constexpr(std::is_same_v<T, float>) {
102 using E = float;
103 arma_fortran(arma_sgemm)(&TRAN, &TRAN, &M, &K, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &N, (E*)&BETA, (E*)C.memptr(), &M);
104 }
105 else {
106 using E = double;
107 arma_fortran(arma_dgemm)(&TRAN, &TRAN, &M, &K, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &N, (E*)&BETA, (E*)C.memptr(), &M);
108 }
109 }
110
111 return C;
112}
113
114template<sp_d T> int FullMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
115 if(this->factored) return this->solve_trs(X, std::move(B));
116
117 suanpan_assert([&] { if(this->n_rows != this->n_cols) throw std::invalid_argument("requires a square matrix"); });
118
119 blas_int INFO = 0;
120
121 auto N = static_cast<blas_int>(this->n_rows);
122 const auto NRHS = static_cast<blas_int>(B.n_cols);
123 const auto LDB = static_cast<blas_int>(B.n_rows);
124 this->pivot.zeros(N);
125 this->factored = true;
126
127 if constexpr(std::is_same_v<T, float>) {
128 using E = float;
129 arma_fortran(arma_sgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
130 X = std::move(B);
131 }
132 else if(Precision::FULL == this->setting.precision) {
133 using E = double;
134 arma_fortran(arma_dgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
135 X = std::move(B);
136 }
137 else {
138 this->s_memory = this->to_float();
139 arma_fortran(arma_sgetrf)(&N, &N, this->s_memory.memptr(), &N, this->pivot.memptr(), &INFO);
140 if(0 == INFO) INFO = this->solve_trs(X, std::move(B));
141 }
142
143 if(0 != INFO)
144 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
145
146 return INFO;
147}
148
149template<sp_d T> int FullMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
150 blas_int INFO = 0;
151
152 const auto N = static_cast<blas_int>(this->n_rows);
153 const auto NRHS = static_cast<blas_int>(B.n_cols);
154 const auto LDB = static_cast<blas_int>(B.n_rows);
155
156 if constexpr(std::is_same_v<T, float>) {
157 using E = float;
158 arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
159 X = std::move(B);
160 }
161 else if(Precision::FULL == this->setting.precision) {
162 using E = double;
163 arma_fortran(arma_dgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
164 X = std::move(B);
165 }
166 else
167 this->mixed_trs(X, std::move(B), [&](fmat& residual) {
168 arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, this->s_memory.memptr(), &N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
169 return INFO;
170 });
171
172 return INFO;
173}
174
175#endif
176
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
podarray< float > s_memory
Definition DenseMat.hpp:46
podarray< blas_int > pivot
Definition DenseMat.hpp:45
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
A FullMat class that holds matrices.
Definition FullMat.hpp:35
int sign_det() const override
Definition FullMat.hpp:66
FullMat(const uword in_rows, const uword in_cols)
Definition FullMat.hpp:46
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition FullMat.hpp:57
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition FullMat.hpp:59
void nullify(const uword K) override
Definition FullMat.hpp:51
unique_ptr< MetaMat< T > > make_copy() override
Definition FullMat.hpp:49
const uword n_cols
Definition MetaMat.hpp:116
const uword n_rows
Definition MetaMat.hpp:115
bool factored
Definition MetaMat.hpp:76
SolverSetting< T > setting
Definition MetaMat.hpp:78
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition FullMat.hpp:114
Mat< T > operator*(const Mat< T > &) const override
Definition FullMat.hpp:78
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:28
Precision precision
Definition SolverSetting.hpp:32
void suanpan_assert(const std::function< void()> &F)
Definition suanPan.h:363
#define suanpan_error(...)
Definition suanPan.h:376