suanPan
SymmPackMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2022 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef SYMMPACKMAT_HPP
31 #define SYMMPACKMAT_HPP
32 
33 #include "DenseMat.hpp"
34 
35 template<sp_d T> class SymmPackMat final : public DenseMat<T> {
36  static constexpr char UPLO = 'U';
37 
38  static T bin;
39 
40  int solve_trs(Mat<T>&, Mat<T>&&);
41  int solve_trs(Mat<T>&, const Mat<T>&);
42 
43 public:
44  explicit SymmPackMat(uword);
45 
46  unique_ptr<MetaMat<T>> make_copy() override;
47 
48  void unify(uword) override;
49  void nullify(uword) override;
50 
51  const T& operator()(uword, uword) const override;
52  T& at(uword, uword) override;
53 
54  Mat<T> operator*(const Mat<T>&) const override;
55 
56  int direct_solve(Mat<T>&, Mat<T>&&) override;
57  int direct_solve(Mat<T>&, const Mat<T>&) override;
58 };
59 
60 template<sp_d T> T SymmPackMat<T>::bin = 0.;
61 
62 template<sp_d T> SymmPackMat<T>::SymmPackMat(const uword in_size)
63  : DenseMat<T>(in_size, in_size, (in_size + 1) * in_size / 2) {}
64 
65 template<sp_d T> unique_ptr<MetaMat<T>> SymmPackMat<T>::make_copy() { return std::make_unique<SymmPackMat<T>>(*this); }
66 
67 template<sp_d T> void SymmPackMat<T>::unify(const uword K) {
68  nullify(K);
69  access::rw(this->memory[(K * K + 3 * K) / 2]) = 1.;
70 }
71 
72 template<sp_d T> void SymmPackMat<T>::nullify(const uword K) {
73  suanpan_for(0llu, K, [&](const uword I) { access::rw(this->memory[(K * K + K) / 2 + I]) = 0.; });
74  suanpan_for(K, this->n_rows, [&](const uword I) { access::rw(this->memory[(I * I + I) / 2 + K]) = 0.; });
75 
76  this->factored = false;
77 }
78 
79 template<sp_d T> const T& SymmPackMat<T>::operator()(const uword in_row, const uword in_col) const { return this->memory[in_col > in_row ? (in_col * in_col + in_col) / 2 + in_row : (in_row * in_row + in_row) / 2 + in_col]; }
80 
81 template<sp_d T> T& SymmPackMat<T>::at(const uword in_row, const uword in_col) {
82  if(in_col < in_row) return bin;
83  this->factored = false;
84  return access::rw(this->memory[(in_col * in_col + in_col) / 2 + in_row]);
85 }
86 
87 template<const char S, const char T, sp_d T1> Mat<T1> spmm(const SymmPackMat<T1>& A, const Mat<T1>& B);
88 
89 template<sp_d T> Mat<T> SymmPackMat<T>::operator*(const Mat<T>& X) const {
90  if(!X.is_colvec()) return spmm<'R', 'N'>(*this, X);
91 
92  auto Y = X;
93 
94  const auto N = static_cast<int>(this->n_rows);
95  constexpr auto INC = 1;
96  T ALPHA = 1.;
97  T BETA = 0.;
98 
99  if(std::is_same_v<T, float>) {
100  using E = float;
101  arma_fortran(arma_sspmv)(&UPLO, &N, (E*)&ALPHA, (E*)this->memptr(), (E*)X.memptr(), &INC, (E*)&BETA, (E*)Y.memptr(), &INC);
102  }
103  else if(std::is_same_v<T, double>) {
104  using E = double;
105  arma_fortran(arma_dspmv)(&UPLO, &N, (E*)&ALPHA, (E*)this->memptr(), (E*)X.memptr(), &INC, (E*)&BETA, (E*)Y.memptr(), &INC);
106  }
107 
108  return Y;
109 }
110 
111 template<sp_d T> int SymmPackMat<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
112  if(this->factored) return this->solve_trs(X, B);
113 
114  const auto N = static_cast<int>(this->n_rows);
115  const auto NRHS = static_cast<int>(B.n_cols);
116  const auto LDB = static_cast<int>(B.n_rows);
117  auto INFO = 0;
118 
119  this->factored = true;
120 
121  if(std::is_same_v<T, float>) {
122  using E = float;
123  X = B;
124  arma_fortran(arma_sppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)X.memptr(), &LDB, &INFO);
125  }
126  else if(Precision::FULL == this->setting.precision) {
127  using E = double;
128  X = B;
129  arma_fortran(arma_dppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)X.memptr(), &LDB, &INFO);
130  }
131  else {
132  this->s_memory = this->to_float();
133  arma_fortran(arma_spptrf)(&UPLO, &N, this->s_memory.memptr(), &INFO);
134  if(0 == INFO) INFO = this->solve_trs(X, B);
135  }
136 
137  if(0 != INFO) suanpan_error("solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
138 
139  return INFO;
140 }
141 
142 template<sp_d T> int SymmPackMat<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
143  const auto N = static_cast<int>(this->n_rows);
144  const auto NRHS = static_cast<int>(B.n_cols);
145  const auto LDB = static_cast<int>(B.n_rows);
146  auto INFO = 0;
147 
148  if(std::is_same_v<T, float>) {
149  using E = float;
150  X = B;
151  arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)X.memptr(), &LDB, &INFO);
152  }
153  else if(Precision::FULL == this->setting.precision) {
154  using E = double;
155  X = B;
156  arma_fortran(arma_dpptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)X.memptr(), &LDB, &INFO);
157  }
158  else {
159  X = arma::zeros(B.n_rows, B.n_cols);
160 
161  mat full_residual = B;
162 
163  auto multiplier = norm(full_residual);
164 
165  auto counter = 0u;
166  while(counter++ < this->setting.iterative_refinement) {
167  if(multiplier < this->setting.tolerance) break;
168 
169  auto residual = conv_to<fmat>::from(full_residual / multiplier);
170 
171  arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, this->s_memory.memptr(), residual.memptr(), &LDB, &INFO);
172  if(0 != INFO) break;
173 
174  const mat incre = multiplier * conv_to<mat>::from(residual);
175 
176  X += incre;
177 
178  suanpan_debug("mixed precision algorithm multiplier: %.5E.\n", multiplier = norm(full_residual -= this->operator*(incre)));
179  }
180  }
181 
182  if(INFO != 0) suanpan_error("solve() receives error code %u from base driver, the matrix is probably singular.\n", INFO);
183 
184  return INFO;
185 }
186 
187 template<sp_d T> int SymmPackMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
188  if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
189 
190  const auto N = static_cast<int>(this->n_rows);
191  const auto NRHS = static_cast<int>(B.n_cols);
192  const auto LDB = static_cast<int>(B.n_rows);
193  auto INFO = 0;
194 
195  this->factored = true;
196 
197  if(std::is_same_v<T, float>) {
198  using E = float;
199  arma_fortran(arma_sppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
200  X = std::move(B);
201  }
202  else if(Precision::FULL == this->setting.precision) {
203  using E = double;
204  arma_fortran(arma_dppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
205  X = std::move(B);
206  }
207  else {
208  this->s_memory = this->to_float();
209  arma_fortran(arma_spptrf)(&UPLO, &N, this->s_memory.memptr(), &INFO);
210  if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
211  }
212 
213  if(0 != INFO) suanpan_error("solve() receives error code %u from the base driver, the matrix is probably singular.\n", INFO);
214 
215  return INFO;
216 }
217 
218 template<sp_d T> int SymmPackMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
219  const auto N = static_cast<int>(this->n_rows);
220  const auto NRHS = static_cast<int>(B.n_cols);
221  const auto LDB = static_cast<int>(B.n_rows);
222  auto INFO = 0;
223 
224  if(std::is_same_v<T, float>) {
225  using E = float;
226  arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
227  X = std::move(B);
228  }
229  else if(Precision::FULL == this->setting.precision) {
230  using E = double;
231  arma_fortran(arma_dpptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
232  X = std::move(B);
233  }
234  else {
235  X = arma::zeros(B.n_rows, B.n_cols);
236 
237  auto multiplier = arma::norm(B);
238 
239  auto counter = 0u;
240  while(counter++ < this->setting.iterative_refinement) {
241  if(multiplier < this->setting.tolerance) break;
242 
243  auto residual = conv_to<fmat>::from(B / multiplier);
244 
245  arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, this->s_memory.memptr(), residual.memptr(), &LDB, &INFO);
246  if(0 != INFO) break;
247 
248  const mat incre = multiplier * conv_to<mat>::from(residual);
249 
250  X += incre;
251 
252  suanpan_debug("mixed precision algorithm multiplier: %.5E.\n", multiplier = norm(B -= this->operator*(incre)));
253  }
254  }
255 
256  if(INFO != 0) suanpan_error("solve() receives error code %u from base driver, the matrix is probably singular.\n", INFO);
257 
258  return INFO;
259 }
260 
261 #endif
262 
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
A SymmPackMat class that holds matrices.
Definition: SymmPackMat.hpp:35
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: SymmPackMat.hpp:187
void nullify(uword) override
Definition: SymmPackMat.hpp:72
unique_ptr< MetaMat< T > > make_copy() override
Definition: SymmPackMat.hpp:65
Mat< T > operator*(const Mat< T > &) const override
Definition: SymmPackMat.hpp:89
void unify(uword) override
Definition: SymmPackMat.hpp:67
T & at(uword, uword) override
Definition: SymmPackMat.hpp:81
Mat< T1 > spmm(const SymmPackMat< T1 > &A, const Mat< T1 > &B)
Definition: operator_times.hpp:155
SymmPackMat(uword)
Definition: SymmPackMat.hpp:62
const T & operator()(uword, uword) const override
Definition: SymmPackMat.hpp:79
double norm(const vec &)
Definition: tensorToolbox.cpp:302
void suanpan_debug(const char *M,...)
Definition: print.cpp:64
void suanpan_error(const char *M,...)
Definition: print.cpp:116
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:24