suanPan
FullMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2022 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef FULLMAT_HPP
31 #define FULLMAT_HPP
32 
33 #include "DenseMat.hpp"
34 
35 template<sp_d T> class FullMat : public DenseMat<T> {
36  static constexpr char TRAN = 'N';
37 
38  int solve_trs(Mat<T>&, Mat<T>&&);
39  int solve_trs(Mat<T>&, const Mat<T>&);
40 
41 public:
42  FullMat(uword, uword);
43 
44  unique_ptr<MetaMat<T>> make_copy() override;
45 
46  void unify(uword) override;
47  void nullify(uword) override;
48 
49  const T& operator()(uword, uword) const override;
50 
51  T& at(uword, uword) override;
52 
53  Mat<T> operator*(const Mat<T>&) const override;
54 
55  int direct_solve(Mat<T>&, Mat<T>&&) override;
56  int direct_solve(Mat<T>&, const Mat<T>&) override;
57 };
58 
59 template<sp_d T> FullMat<T>::FullMat(const uword in_rows, const uword in_cols)
60  : DenseMat<T>(in_rows, in_cols, in_rows * in_cols) {}
61 
62 template<sp_d T> unique_ptr<MetaMat<T>> FullMat<T>::make_copy() { return std::make_unique<FullMat<T>>(*this); }
63 
64 template<sp_d T> void FullMat<T>::unify(const uword K) {
65  nullify(K);
66  at(K, K) = 1.;
67 }
68 
69 template<sp_d T> void FullMat<T>::nullify(const uword K) {
70  suanpan_for(0llu, this->n_rows, [&](const uword I) { at(I, K) = 0.; });
71  suanpan_for(0llu, this->n_cols, [&](const uword I) { at(K, I) = 0.; });
72 
73  this->factored = false;
74 }
75 
76 template<sp_d T> const T& FullMat<T>::operator()(const uword in_row, const uword in_col) const { return this->memory[in_row + in_col * this->n_rows]; }
77 
78 template<sp_d T> T& FullMat<T>::at(const uword in_row, const uword in_col) {
79  this->factored = false;
80  return access::rw(this->memory[in_row + in_col * this->n_rows]);
81 }
82 
83 template<sp_d T> Mat<T> FullMat<T>::operator*(const Mat<T>& B) const {
84  Mat<T> C(arma::size(B));
85 
86  const auto M = static_cast<int>(this->n_rows);
87  const auto N = static_cast<int>(this->n_cols);
88 
89  T ALPHA = 1., BETA = 0.;
90 
91  if(1 == B.n_cols) {
92  constexpr auto INCX = 1, INCY = 1;
93 
94  if(std::is_same_v<T, float>) {
95  using E = float;
96  arma_fortran(arma_sgemv)(&TRAN, &M, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &INCX, (E*)&BETA, (E*)C.memptr(), &INCY);
97  }
98  else if(std::is_same_v<T, double>) {
99  using E = double;
100  arma_fortran(arma_dgemv)(&TRAN, &M, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &INCX, (E*)&BETA, (E*)C.memptr(), &INCY);
101  }
102  }
103  else {
104  const auto K = static_cast<int>(B.n_cols);
105 
106  if(std::is_same_v<T, float>) {
107  using E = float;
108  arma_fortran(arma_sgemm)(&TRAN, &TRAN, &M, &K, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &N, (E*)&BETA, (E*)C.memptr(), &M);
109  }
110  else if(std::is_same_v<T, double>) {
111  using E = double;
112  arma_fortran(arma_dgemm)(&TRAN, &TRAN, &M, &K, &N, (E*)&ALPHA, (E*)this->memptr(), &M, (E*)B.memptr(), &N, (E*)&BETA, (E*)C.memptr(), &M);
113  }
114  }
115 
116  return C;
117 }
118 
119 template<sp_d T> int FullMat<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
120  if(this->factored) return this->solve_trs(X, B);
121 
122  auto N = static_cast<int>(this->n_rows);
123  const auto NRHS = static_cast<int>(B.n_cols);
124  const auto LDB = static_cast<int>(B.n_rows);
125  auto INFO = 0;
126  this->pivot.zeros(N);
127  this->factored = true;
128 
129  if(std::is_same_v<T, float>) {
130  using E = float;
131  X = B;
132  arma_fortran(arma_sgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
133  }
134  else if(Precision::FULL == this->setting.precision) {
135  using E = double;
136  X = B;
137  arma_fortran(arma_dgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
138  }
139  else {
140  this->s_memory = this->to_float();
141  arma_fortran(arma_sgetrf)(&N, &N, this->s_memory.memptr(), &N, this->pivot.memptr(), &INFO);
142  if(0 == INFO) INFO = this->solve_trs(X, B);
143  }
144 
145  if(0 != INFO) suanpan_error("solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
146 
147  return INFO;
148 }
149 
150 template<sp_d T> int FullMat<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
151  const auto N = static_cast<int>(this->n_rows);
152  const auto NRHS = static_cast<int>(B.n_cols);
153  const auto LDB = static_cast<int>(B.n_rows);
154  auto INFO = 0;
155 
156  if(std::is_same_v<T, float>) {
157  using E = float;
158  X = B;
159  arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
160  }
161  else if(Precision::FULL == this->setting.precision) {
162  using E = double;
163  X = B;
164  arma_fortran(arma_dgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)X.memptr(), &LDB, &INFO);
165  }
166  else {
167  X = arma::zeros(B.n_rows, B.n_cols);
168 
169  mat full_residual = B;
170 
171  auto multiplier = norm(full_residual);
172 
173  auto counter = 0u;
174  while(counter++ < this->setting.iterative_refinement) {
175  if(multiplier < this->setting.tolerance) break;
176 
177  auto residual = conv_to<fmat>::from(full_residual / multiplier);
178 
179  arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, this->s_memory.memptr(), &N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
180  if(0 != INFO) break;
181 
182  const mat incre = multiplier * conv_to<mat>::from(residual);
183 
184  X += incre;
185 
186  suanpan_debug("mixed precision algorithm multiplier: %.5E.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
187  }
188  }
189 
190  return INFO;
191 }
192 
193 template<sp_d T> int FullMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
194  if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
195 
196  auto N = static_cast<int>(this->n_rows);
197  const auto NRHS = static_cast<int>(B.n_cols);
198  const auto LDB = static_cast<int>(B.n_rows);
199  auto INFO = 0;
200 
201  this->pivot.zeros(N);
202 
203  this->factored = true;
204 
205  if(std::is_same_v<T, float>) {
206  using E = float;
207  arma_fortran(arma_sgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
208  X = std::move(B);
209  }
210  else if(Precision::FULL == this->setting.precision) {
211  using E = double;
212  arma_fortran(arma_dgesv)(&N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
213  X = std::move(B);
214  }
215  else {
216  this->s_memory = this->to_float();
217  arma_fortran(arma_sgetrf)(&N, &N, this->s_memory.memptr(), &N, this->pivot.memptr(), &INFO);
218  if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
219  }
220 
221  if(0 != INFO) suanpan_error("solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
222 
223  return INFO;
224 }
225 
226 template<sp_d T> int FullMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
227  const auto N = static_cast<int>(this->n_rows);
228  const auto NRHS = static_cast<int>(B.n_cols);
229  const auto LDB = static_cast<int>(B.n_rows);
230  auto INFO = 0;
231 
232  if(std::is_same_v<T, float>) {
233  using E = float;
234  arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
235  X = std::move(B);
236  }
237  else if(Precision::FULL == this->setting.precision) {
238  using E = double;
239  arma_fortran(arma_dgetrs)(&TRAN, &N, &NRHS, (E*)this->memptr(), &N, this->pivot.memptr(), (E*)B.memptr(), &LDB, &INFO);
240  X = std::move(B);
241  }
242  else {
243  X = arma::zeros(B.n_rows, B.n_cols);
244 
245  auto multiplier = arma::norm(B);
246 
247  auto counter = 0u;
248  while(counter++ < this->setting.iterative_refinement) {
249  if(multiplier < this->setting.tolerance) break;
250 
251  auto residual = conv_to<fmat>::from(B / multiplier);
252 
253  arma_fortran(arma_sgetrs)(&TRAN, &N, &NRHS, this->s_memory.memptr(), &N, this->pivot.memptr(), residual.memptr(), &LDB, &INFO);
254  if(0 != INFO) break;
255 
256  const mat incre = multiplier * conv_to<mat>::from(residual);
257 
258  X += incre;
259 
260  suanpan_debug("mixed precision algorithm multiplier: %.5E.\n", multiplier = arma::norm(B -= this->operator*(incre)));
261  }
262  }
263 
264  return INFO;
265 }
266 
267 #endif
268 
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
A FullMat class that holds a dense matrix, storing all rows × cols entries in column-major order.
Definition: FullMat.hpp:35
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: FullMat.hpp:193
T & at(uword, uword) override
Definition: FullMat.hpp:78
const T & operator()(uword, uword) const override
Definition: FullMat.hpp:76
void unify(uword) override
Definition: FullMat.hpp:64
void nullify(uword) override
Definition: FullMat.hpp:69
FullMat(uword, uword)
Definition: FullMat.hpp:59
Mat< T > operator*(const Mat< T > &) const override
Definition: FullMat.hpp:83
unique_ptr< MetaMat< T > > make_copy() override
Definition: FullMat.hpp:62
double norm(const vec &)
Definition: tensorToolbox.cpp:302
void suanpan_debug(const char *M,...)
Definition: print.cpp:64
void suanpan_error(const char *M,...)
Definition: print.cpp:116
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:24