suanPan
FullMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2023 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef FULLMAT_HPP
31 #define FULLMAT_HPP
32 
33 #include "DenseMat.hpp"
34 
35 template<sp_d T> class FullMat : public DenseMat<T> {
36  static constexpr char TRAN = 'N';
37 
38  int solve_trs(Mat<T>&, Mat<T>&&);
39  int solve_trs(Mat<T>&, const Mat<T>&);
40 
41 public:
42  FullMat(uword, uword);
43 
44  unique_ptr<MetaMat<T>> make_copy() override;
45 
46  void unify(uword) override;
47  void nullify(uword) override;
48 
49  const T& operator()(uword, uword) const override;
50 
51  T& at(uword, uword) override;
52 
53  Mat<T> operator*(const Mat<T>&) const override;
54 
55  int direct_solve(Mat<T>&, Mat<T>&&) override;
56  int direct_solve(Mat<T>&, const Mat<T>&) override;
57 };
58 
59 template<sp_d T> FullMat<T>::FullMat(const uword in_rows, const uword in_cols)
60  : DenseMat<T>(in_rows, in_cols, in_rows * in_cols) {}
61 
62 template<sp_d T> unique_ptr<MetaMat<T>> FullMat<T>::make_copy() { return std::make_unique<FullMat<T>>(*this); }
63 
64 template<sp_d T> void FullMat<T>::unify(const uword K) {
65  nullify(K);
66  at(K, K) = 1.;
67 }
68 
69 template<sp_d T> void FullMat<T>::nullify(const uword K) {
70  suanpan_for(0llu, this->n_rows, [&](const uword I) { at(I, K) = 0.; });
71  suanpan_for(0llu, this->n_cols, [&](const uword I) { at(K, I) = 0.; });
72 
73  this->factored = false;
74 }
75 
76 template<sp_d T> const T& FullMat<T>::operator()(const uword in_row, const uword in_col) const { return this->memory[in_row + in_col * this->n_rows]; }
77 
78 template<sp_d T> T& FullMat<T>::at(const uword in_row, const uword in_col) {
79  this->factored = false;
80  return access::rw(this->memory[in_row + in_col * this->n_rows]);
81 }
82 
83 template<sp_d T> Mat<T> FullMat<T>::operator*(const Mat<T>& B) const {
84  Mat<T> C(arma::size(B));
85 
86  const auto M = static_cast<int>(this->n_rows);
87  const auto N = static_cast<int>(this->n_cols);
88 
89  T ALPHA = 1., BETA = 0.;
90 
91  if(1 == B.n_cols) {
92  constexpr auto INCX = 1, INCY = 1;
93 
94  if(std::is_same_v<T, float>) {
95  using E = float;
96  arma_fortran(arma_sgemv)(&TRAN, &M, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &INCX, (E*)&BETA, (E*)C.memptr(), &INCY);
97  }
98  else if(std::is_same_v<T, double>) {
99  using E = double;
100  arma_fortran(arma_dgemv)(&TRAN, &M, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &INCX, (E*)&BETA, (E*)C.memptr(), &INCY);
101  }
102  }
103  else {
104  const auto K = static_cast<int>(B.n_cols);
105 
106  if(std::is_same_v<T, float>) {
107  using E = float;
108  arma_fortran(arma_sgemm)(&TRAN, &TRAN, &M, &K, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &N, (E*)&BETA, (E*)C.memptr(), &M);
109  }
110  else if(std::is_same_v<T, double>) {
111  using E = double;
112  arma_fortran(arma_dgemm)(&TRAN, &TRAN, &M, &K, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &N, (E*)&BETA, (E*)C.memptr(), &M);
113  }
114  }
115 
116  return C;
117 }
118 
119 template<sp_d T> int FullMat<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
120  if(this->factored) return this->solve_trs(X, B);
121 
122  auto N = static_cast<int>(this->n_rows);
123  const auto NRHS = static_cast<int>(B.n_cols);
124  const auto LDB = static_cast<int>(B.n_rows);
125  auto INFO = 0;
126  this->pivot.zeros(N);
127  this->factored = true;
128 
129  if(std::is_same_v<T, float>) {
130  using E = float;
131  X = B;
132  arma_fortran(arma_sgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
133  }
134  else if(Precision::FULL == this->setting.precision) {
135  using E = double;
136  X = B;
137  arma_fortran(arma_dgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
138  }
139  else {
140  this->s_memory = this->to_float();
141  arma_fortran(arma_sgetrf)(&N, &N, this->s_memory.memptr(), &N, this->pivot.memptr(), &INFO);
142  if(0 == INFO) INFO = this->solve_trs(X, B);
143  }
144 
145  if(0 != INFO)
146  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
147 
148  return INFO;
149 }
150 
151 template<sp_d T> int FullMat<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
152  const auto N = static_cast<int>(this->n_rows);
153  const auto NRHS = static_cast<int>(B.n_cols);
154  const auto LDB = static_cast<int>(B.n_rows);
155  auto INFO = 0;
156 
157  if(std::is_same_v<T, float>) {
158  using E = float;
159  X = B;
160  arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
161  }
162  else if(Precision::FULL == this->setting.precision) {
163  using E = double;
164  X = B;
165  arma_fortran(arma_dgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
166  }
167  else {
168  X = arma::zeros(B.n_rows, B.n_cols);
169 
170  mat full_residual = B;
171 
172  auto multiplier = norm(full_residual);
173 
174  auto counter = 0u;
175  while(counter++ < this->setting.iterative_refinement) {
176  if(multiplier < this->setting.tolerance) break;
177 
178  auto residual = conv_to<fmat>::from(full_residual / multiplier);
179 
180  arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, this->s_memory.memptr(), &N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
181  if(0 != INFO) break;
182 
183  const mat incre = multiplier * conv_to<mat>::from(residual);
184 
185  X += incre;
186 
187  suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
188  }
189  }
190 
191  return INFO;
192 }
193 
194 template<sp_d T> int FullMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
195  if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
196 
197  auto N = static_cast<int>(this->n_rows);
198  const auto NRHS = static_cast<int>(B.n_cols);
199  const auto LDB = static_cast<int>(B.n_rows);
200  auto INFO = 0;
201 
202  this->pivot.zeros(N);
203 
204  this->factored = true;
205 
206  if(std::is_same_v<T, float>) {
207  using E = float;
208  arma_fortran(arma_sgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
209  X = std::move(B);
210  }
211  else if(Precision::FULL == this->setting.precision) {
212  using E = double;
213  arma_fortran(arma_dgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
214  X = std::move(B);
215  }
216  else {
217  this->s_memory = this->to_float();
218  arma_fortran(arma_sgetrf)(&N, &N, this->s_memory.memptr(), &N, this->pivot.memptr(), &INFO);
219  if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
220  }
221 
222  if(0 != INFO)
223  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
224 
225  return INFO;
226 }
227 
228 template<sp_d T> int FullMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
229  const auto N = static_cast<int>(this->n_rows);
230  const auto NRHS = static_cast<int>(B.n_cols);
231  const auto LDB = static_cast<int>(B.n_rows);
232  auto INFO = 0;
233 
234  if(std::is_same_v<T, float>) {
235  using E = float;
236  arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
237  X = std::move(B);
238  }
239  else if(Precision::FULL == this->setting.precision) {
240  using E = double;
241  arma_fortran(arma_dgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
242  X = std::move(B);
243  }
244  else {
245  X = arma::zeros(B.n_rows, B.n_cols);
246 
247  auto multiplier = arma::norm(B);
248 
249  auto counter = 0u;
250  while(counter++ < this->setting.iterative_refinement) {
251  if(multiplier < this->setting.tolerance) break;
252 
253  auto residual = conv_to<fmat>::from(B / multiplier);
254 
255  arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, this->s_memory.memptr(), &N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
256  if(0 != INFO) break;
257 
258  const mat incre = multiplier * conv_to<mat>::from(residual);
259 
260  X += incre;
261 
262  suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->operator*(incre)));
263  }
264  }
265 
266  return INFO;
267 }
268 
269 #endif
270 
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: FullMat.hpp:194
T & at(uword, uword) override
Access element for writing; this override performs no bound check and marks any cached factorization as stale.
Definition: FullMat.hpp:78
const T & operator()(uword, uword) const override
Access element (read-only); this override reads the column-major buffer directly and performs no bound check.
Definition: FullMat.hpp:76
void unify(uword) override
Definition: FullMat.hpp:64
void nullify(uword) override
Definition: FullMat.hpp:69
FullMat(uword, uword)
Definition: FullMat.hpp:59
Mat< T > operator*(const Mat< T > &) const override
Definition: FullMat.hpp:83
unique_ptr< MetaMat< T > > make_copy() override
Definition: FullMat.hpp:62
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27