suanPan
FullMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2024 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef FULLMAT_HPP
31 #define FULLMAT_HPP
32 
33 #include "DenseMat.hpp"
34 
// FullMat<T>: dense general matrix stored column-major, i.e. element (i, j) is memory[i + j * n_rows].
// Rows and columns may differ; direct_solve() additionally requires a square matrix.
35 template<sp_d T> class FullMat : public DenseMat<T> {
// Transpose flag handed to the BLAS/LAPACK calls ('N': the stored matrix is used untransposed).
36  static constexpr char TRAN = 'N';
37 
// Back-substitution with the factors produced by direct_solve(); defined out of class.
38  int solve_trs(Mat<T>&, Mat<T>&&);
39 
40 protected:
42 
// Factorizes the matrix on first use and solves A * X = B; defined out of class.
43  int direct_solve(Mat<T>&, Mat<T>&&) override;
44 
45 public:
// Creates an in_rows x in_cols matrix; DenseMat is asked for in_rows * in_cols elements of storage.
46  FullMat(const uword in_rows, const uword in_cols)
47  : DenseMat<T>(in_rows, in_cols, in_rows * in_cols) {}
48 
// Polymorphic copy: returns a new FullMat copy-constructed from this one.
49  unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<FullMat>(*this); }
50 
// Zeros row K and column K, and marks the matrix as not factored.
// NOTE(review): assumes suanpan::for_each(n, f) calls f for every index in [0, n) — confirm in utility.h.
51  void nullify(const uword K) override {
52  this->factored = false;
53  suanpan::for_each(this->n_rows, [&](const uword I) { at(I, K) = T(0); });
54  suanpan::for_each(this->n_cols, [&](const uword I) { at(K, I) = T(0); });
55  }
56 
// Read-only element access using the column-major index; performs no bound check.
57  T operator()(const uword in_row, const uword in_col) const override { return this->memory[in_row + in_col * this->n_rows]; }
58 
// Writable element access using the column-major index; performs no bound check.
// Clears `factored` because the returned reference may be used to modify the matrix.
59  T& at(const uword in_row, const uword in_col) override {
60  this->factored = false;
61  return this->memory[in_row + in_col * this->n_rows];
62  }
63 
// Dense product (this matrix) * B; defined out of class.
64  Mat<T> operator*(const Mat<T>&) const override;
65 };
66 
67 template<sp_d T> Mat<T> FullMat<T>::operator*(const Mat<T>& B) const {
68  Mat<T> C(arma::size(B));
69 
70  const auto M = static_cast<int>(this->n_rows);
71  const auto N = static_cast<int>(this->n_cols);
72 
73  T ALPHA = T(1), BETA = T(0);
74 
75  if(1 == B.n_cols) {
76  constexpr auto INC = 1;
77 
78  if constexpr(std::is_same_v<T, float>) {
79  using E = float;
80  arma_fortran(arma_sgemv)(&TRAN, &M, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &INC, (E*)&BETA, (E*)C.memptr(), &INC);
81  }
82  else {
83  using E = double;
84  arma_fortran(arma_dgemv)(&TRAN, &M, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &INC, (E*)&BETA, (E*)C.memptr(), &INC);
85  }
86  }
87  else {
88  const auto K = static_cast<int>(B.n_cols);
89 
90  if constexpr(std::is_same_v<T, float>) {
91  using E = float;
92  arma_fortran(arma_sgemm)(&TRAN, &TRAN, &M, &K, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &N, (E*)&BETA, (E*)C.memptr(), &M);
93  }
94  else {
95  using E = double;
96  arma_fortran(arma_dgemm)(&TRAN, &TRAN, &M, &K, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &N, (E*)&BETA, (E*)C.memptr(), &M);
97  }
98  }
99 
100  return C;
101 }
102 
103 template<sp_d T> int FullMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
104  if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
105 
106  suanpan_assert([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
107 
108  auto INFO = 0;
109 
110  auto N = static_cast<int>(this->n_rows);
111  const auto NRHS = static_cast<int>(B.n_cols);
112  const auto LDB = static_cast<int>(B.n_rows);
113  this->pivot.zeros(N);
114  this->factored = true;
115 
116  if constexpr(std::is_same_v<T, float>) {
117  using E = float;
118  arma_fortran(arma_sgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
119  X = std::move(B);
120  }
121  else if(Precision::FULL == this->setting.precision) {
122  using E = double;
123  arma_fortran(arma_dgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
124  X = std::move(B);
125  }
126  else {
127  this->s_memory = this->to_float();
128  arma_fortran(arma_sgetrf)(&N, &N, this->s_memory.memptr(), &N, this->pivot.memptr(), &INFO);
129  if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
130  }
131 
132  if(0 != INFO)
133  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
134 
135  return INFO;
136 }
137 
138 template<sp_d T> int FullMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
139  auto INFO = 0;
140 
141  const auto N = static_cast<int>(this->n_rows);
142  const auto NRHS = static_cast<int>(B.n_cols);
143  const auto LDB = static_cast<int>(B.n_rows);
144 
145  if constexpr(std::is_same_v<T, float>) {
146  using E = float;
147  arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
148  X = std::move(B);
149  }
150  else if(Precision::FULL == this->setting.precision) {
151  using E = double;
152  arma_fortran(arma_dgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
153  X = std::move(B);
154  }
155  else
156  this->mixed_trs(X, std::forward<Mat<T>>(B), [&](fmat& residual) {
157  arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, this->s_memory.memptr(), &N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
158  return INFO;
159  });
160 
161  return INFO;
162 }
163 
164 #endif
165 
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition: DenseMat.hpp:48
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
FullMat(const uword in_rows, const uword in_cols)
Definition: FullMat.hpp:46
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only). Note: the FullMat override performs no bound check, so an out-of-bound index is not detected.
Definition: FullMat.hpp:57
T & at(const uword in_row, const uword in_col) override
Access element by reference. Note: the FullMat override performs no bound check and marks the matrix as not factored.
Definition: FullMat.hpp:59
void nullify(const uword K) override
Definition: FullMat.hpp:51
unique_ptr< MetaMat< T > > make_copy() override
Definition: FullMat.hpp:49
const uword n_cols
Definition: MetaMat.hpp:119
const uword n_rows
Definition: MetaMat.hpp:118
bool factored
Definition: MetaMat.hpp:74
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: FullMat.hpp:103
Mat< T > operator*(const Mat< T > &) const override
Definition: FullMat.hpp:67
void for_each(const IT start, const IT end, F &&FN)
Definition: utility.h:28
void suanpan_assert(const std::function< void()> &F)
Definition: suanPan.h:296
#define suanpan_error(...)
Definition: suanPan.h:309