suanPan
SymmPackMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2023 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef SYMMPACKMAT_HPP
31 #define SYMMPACKMAT_HPP
32 
33 #include "DenseMat.hpp"
34 
35 template<sp_d T> class SymmPackMat final : public DenseMat<T> {
36  static constexpr char UPLO = 'L';
37 
38  static T bin;
39 
40  const uword length; // 2n-1
41 
42  int solve_trs(Mat<T>&, Mat<T>&&);
43  int solve_trs(Mat<T>&, const Mat<T>&);
44 
45 public:
46  explicit SymmPackMat(uword);
47 
48  unique_ptr<MetaMat<T>> make_copy() override;
49 
50  void unify(uword) override;
51  void nullify(uword) override;
52 
53  const T& operator()(uword, uword) const override;
54  T& unsafe_at(uword, uword) override;
55  T& at(uword, uword) override;
56 
57  Mat<T> operator*(const Mat<T>&) const override;
58 
59  int direct_solve(Mat<T>&, Mat<T>&&) override;
60  int direct_solve(Mat<T>&, const Mat<T>&) override;
61 };
62 
63 template<sp_d T> T SymmPackMat<T>::bin = 0.;
64 
65 template<sp_d T> SymmPackMat<T>::SymmPackMat(const uword in_size)
66  : DenseMat<T>(in_size, in_size, (in_size + 1) * in_size / 2)
67  , length(2 * in_size - 1) {}
68 
69 template<sp_d T> unique_ptr<MetaMat<T>> SymmPackMat<T>::make_copy() { return std::make_unique<SymmPackMat<T>>(*this); }
70 
71 template<sp_d T> void SymmPackMat<T>::unify(const uword K) {
72  nullify(K);
73  access::rw(this->memory[(length - K + 2) * K / 2]) = 1.;
74 }
75 
76 template<sp_d T> void SymmPackMat<T>::nullify(const uword K) {
77  suanpan_for(0llu, K, [&](const uword I) { access::rw(this->memory[K + (length - I) * I / 2]) = 0.; });
78  const auto t_factor = (length - K) * K / 2;
79  suanpan_for(K, this->n_rows, [&](const uword I) { access::rw(this->memory[I + t_factor]) = 0.; });
80 
81  this->factored = false;
82 }
83 
84 template<sp_d T> const T& SymmPackMat<T>::operator()(const uword in_row, const uword in_col) const { return this->memory[in_row >= in_col ? in_row + (length - in_col) * in_col / 2 : in_col + (length - in_row) * in_row / 2]; }
85 
86 template<sp_d T> T& SymmPackMat<T>::unsafe_at(const uword in_row, const uword in_col) {
87  this->factored = false;
88  return access::rw(this->memory[in_row + (length - in_col) * in_col / 2]);
89 }
90 
91 template<sp_d T> T& SymmPackMat<T>::at(const uword in_row, const uword in_col) {
92  if(in_row < in_col) [[unlikely]] return bin;
93  this->factored = false;
94  return access::rw(this->memory[in_row + (length - in_col) * in_col / 2]);
95 }
96 
97 template<sp_d T> Mat<T> SymmPackMat<T>::operator*(const Mat<T>& X) const {
98  auto Y = Mat<T>(arma::size(X), fill::none);
99 
100  const auto N = static_cast<int>(this->n_rows);
101  const auto INC = 1;
102  T ALPHA = 1.;
103  T BETA = 0.;
104 
105  if(std::is_same_v<T, float>) {
106  using E = float;
107  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_sspmv)(&UPLO, &N, (E*)&ALPHA, (E*)this->memptr(), (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
108  }
109  else if(std::is_same_v<T, double>) {
110  using E = double;
111  suanpan_for(0llu, X.n_cols, [&](const uword I) { arma_fortran(arma_dspmv)(&UPLO, &N, (E*)&ALPHA, (E*)this->memptr(), (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
112  }
113 
114  return Y;
115 }
116 
117 template<sp_d T> int SymmPackMat<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
118  if(this->factored) return this->solve_trs(X, B);
119 
120  const auto N = static_cast<int>(this->n_rows);
121  const auto NRHS = static_cast<int>(B.n_cols);
122  const auto LDB = static_cast<int>(B.n_rows);
123  auto INFO = 0;
124 
125  this->factored = true;
126 
127  if(std::is_same_v<T, float>) {
128  using E = float;
129  X = B;
130  arma_fortran(arma_sppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)X.memptr(), &LDB, &INFO);
131  }
132  else if(Precision::FULL == this->setting.precision) {
133  using E = double;
134  X = B;
135  arma_fortran(arma_dppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)X.memptr(), &LDB, &INFO);
136  }
137  else {
138  this->s_memory = this->to_float();
139  arma_fortran(arma_spptrf)(&UPLO, &N, this->s_memory.memptr(), &INFO);
140  if(0 == INFO) INFO = this->solve_trs(X, B);
141  }
142 
143  if(0 != INFO)
144  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
145 
146  return INFO;
147 }
148 
149 template<sp_d T> int SymmPackMat<T>::solve_trs(Mat<T>& X, const Mat<T>& B) {
150  const auto N = static_cast<int>(this->n_rows);
151  const auto NRHS = static_cast<int>(B.n_cols);
152  const auto LDB = static_cast<int>(B.n_rows);
153  auto INFO = 0;
154 
155  if(std::is_same_v<T, float>) {
156  using E = float;
157  X = B;
158  arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)X.memptr(), &LDB, &INFO);
159  }
160  else if(Precision::FULL == this->setting.precision) {
161  using E = double;
162  X = B;
163  arma_fortran(arma_dpptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)X.memptr(), &LDB, &INFO);
164  }
165  else {
166  X = arma::zeros(B.n_rows, B.n_cols);
167 
168  mat full_residual = B;
169 
170  auto multiplier = norm(full_residual);
171 
172  auto counter = 0u;
173  while(counter++ < this->setting.iterative_refinement) {
174  if(multiplier < this->setting.tolerance) break;
175 
176  auto residual = conv_to<fmat>::from(full_residual / multiplier);
177 
178  arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, this->s_memory.memptr(), residual.memptr(), &LDB, &INFO);
179  if(0 != INFO) break;
180 
181  const mat incre = multiplier * conv_to<mat>::from(residual);
182 
183  X += incre;
184 
185  suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
186  }
187  }
188 
189  if(0 != INFO)
190  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
191 
192  return INFO;
193 }
194 
195 template<sp_d T> int SymmPackMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
196  if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
197 
198  const auto N = static_cast<int>(this->n_rows);
199  const auto NRHS = static_cast<int>(B.n_cols);
200  const auto LDB = static_cast<int>(B.n_rows);
201  auto INFO = 0;
202 
203  this->factored = true;
204 
205  if(std::is_same_v<T, float>) {
206  using E = float;
207  arma_fortran(arma_sppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
208  X = std::move(B);
209  }
210  else if(Precision::FULL == this->setting.precision) {
211  using E = double;
212  arma_fortran(arma_dppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
213  X = std::move(B);
214  }
215  else {
216  this->s_memory = this->to_float();
217  arma_fortran(arma_spptrf)(&UPLO, &N, this->s_memory.memptr(), &INFO);
218  if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
219  }
220 
221  if(0 != INFO)
222  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
223 
224  return INFO;
225 }
226 
227 template<sp_d T> int SymmPackMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
228  const auto N = static_cast<int>(this->n_rows);
229  const auto NRHS = static_cast<int>(B.n_cols);
230  const auto LDB = static_cast<int>(B.n_rows);
231  auto INFO = 0;
232 
233  if(std::is_same_v<T, float>) {
234  using E = float;
235  arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
236  X = std::move(B);
237  }
238  else if(Precision::FULL == this->setting.precision) {
239  using E = double;
240  arma_fortran(arma_dpptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
241  X = std::move(B);
242  }
243  else {
244  X = arma::zeros(B.n_rows, B.n_cols);
245 
246  auto multiplier = arma::norm(B);
247 
248  auto counter = 0u;
249  while(counter++ < this->setting.iterative_refinement) {
250  if(multiplier < this->setting.tolerance) break;
251 
252  auto residual = conv_to<fmat>::from(B / multiplier);
253 
254  arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, this->s_memory.memptr(), residual.memptr(), &LDB, &INFO);
255  if(0 != INFO) break;
256 
257  const mat incre = multiplier * conv_to<mat>::from(residual);
258 
259  X += incre;
260 
261  suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(B -= this->operator*(incre)));
262  }
263  }
264 
265  if(0 != INFO)
266  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
267 
268  return INFO;
269 }
270 
271 #endif
272 
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:34
A SymmPackMat class that holds matrices.
Definition: SymmPackMat.hpp:35
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: SymmPackMat.hpp:195
T & unsafe_at(uword, uword) override
Access element without bound check.
Definition: SymmPackMat.hpp:86
void nullify(uword) override
Definition: SymmPackMat.hpp:76
unique_ptr< MetaMat< T > > make_copy() override
Definition: SymmPackMat.hpp:69
Mat< T > operator*(const Mat< T > &) const override
Definition: SymmPackMat.hpp:97
void unify(uword) override
Definition: SymmPackMat.hpp:71
T & at(uword, uword) override
Access element with bound check.
Definition: SymmPackMat.hpp:91
SymmPackMat(uword)
Definition: SymmPackMat.hpp:65
const T & operator()(uword, uword) const override
Access element (read-only), returns zero if out-of-bound.
Definition: SymmPackMat.hpp:84
double norm(const vec &)
Definition: tensor.cpp:302
#define suanpan_debug(...)
Definition: suanPan.h:295
#define suanpan_error(...)
Definition: suanPan.h:297
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27