suanPan
SymmPackMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2024 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef SYMMPACKMAT_HPP
31 #define SYMMPACKMAT_HPP
32 
33 #include "DenseMat.hpp"
34 
35 template<sp_d T> class SymmPackMat final : public DenseMat<T> {
36  static constexpr char UPLO = 'L';
37 
38  static T bin;
39 
40  const uword length; // 2n-1
41 
42  int solve_trs(Mat<T>&, Mat<T>&&);
43 
44 protected:
46 
47  int direct_solve(Mat<T>&, Mat<T>&&) override;
48 
49 public:
50  explicit SymmPackMat(const uword in_size)
51  : DenseMat<T>(in_size, in_size, (in_size + 1) * in_size / 2)
52  , length(2 * in_size - 1) {}
53 
54  unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<SymmPackMat>(*this); }
55 
56  void nullify(const uword K) override {
57  this->factored = false;
58  suanpan::for_each(K, [&](const uword I) { this->memory[K + (length - I) * I / 2] = T(0); });
59  const auto t_factor = (length - K) * K / 2;
60  suanpan::for_each(K, this->n_rows, [&](const uword I) { this->memory[I + t_factor] = T(0); });
61  }
62 
63  T operator()(const uword in_row, const uword in_col) const override { return this->memory[in_row >= in_col ? in_row + (length - in_col) * in_col / 2 : in_col + (length - in_row) * in_row / 2]; }
64 
65  T& unsafe_at(const uword in_row, const uword in_col) override {
66  this->factored = false;
67  return this->memory[in_row + (length - in_col) * in_col / 2];
68  }
69 
70  T& at(const uword in_row, const uword in_col) override {
71  if(in_row < in_col) [[unlikely]] return bin = T(0);
72  return this->unsafe_at(in_row, in_col);
73  }
74 
75  Mat<T> operator*(const Mat<T>&) const override;
76 };
77 
78 template<sp_d T> T SymmPackMat<T>::bin = T(0);
79 
80 template<sp_d T> Mat<T> SymmPackMat<T>::operator*(const Mat<T>& X) const {
81  auto Y = Mat<T>(arma::size(X), fill::none);
82 
83  const auto N = static_cast<int>(this->n_rows);
84  constexpr auto INC = 1;
85  T ALPHA = T(1);
86  T BETA = T(0);
87 
88  if constexpr(std::is_same_v<T, float>) {
89  using E = float;
90  suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_sspmv)(&UPLO, &N, (E*)&ALPHA, (E*)this->memptr(), (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
91  }
92  else {
93  using E = double;
94  suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_dspmv)(&UPLO, &N, (E*)&ALPHA, (E*)this->memptr(), (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
95  }
96 
97  return Y;
98 }
99 
100 template<sp_d T> int SymmPackMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
101  if(this->factored) return this->solve_trs(X, std::forward<Mat<T>>(B));
102 
103  suanpan_assert([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
104 
105  auto INFO = 0;
106 
107  const auto N = static_cast<int>(this->n_rows);
108  const auto NRHS = static_cast<int>(B.n_cols);
109  const auto LDB = static_cast<int>(B.n_rows);
110  this->factored = true;
111 
112  if constexpr(std::is_same_v<T, float>) {
113  using E = float;
114  arma_fortran(arma_sppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
115  X = std::move(B);
116  }
117  else if(Precision::FULL == this->setting.precision) {
118  using E = double;
119  arma_fortran(arma_dppsv)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
120  X = std::move(B);
121  }
122  else {
123  this->s_memory = this->to_float();
124  arma_fortran(arma_spptrf)(&UPLO, &N, this->s_memory.memptr(), &INFO);
125  if(0 == INFO) INFO = this->solve_trs(X, std::forward<Mat<T>>(B));
126  }
127 
128  if(0 != INFO)
129  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
130 
131  return INFO;
132 }
133 
134 template<sp_d T> int SymmPackMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
135  auto INFO = 0;
136 
137  const auto N = static_cast<int>(this->n_rows);
138  const auto NRHS = static_cast<int>(B.n_cols);
139  const auto LDB = static_cast<int>(B.n_rows);
140 
141  if constexpr(std::is_same_v<T, float>) {
142  using E = float;
143  arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
144  X = std::move(B);
145  }
146  else if(Precision::FULL == this->setting.precision) {
147  using E = double;
148  arma_fortran(arma_dpptrs)(&UPLO, &N, &NRHS, (E*)this->memptr(), (E*)B.memptr(), &LDB, &INFO);
149  X = std::move(B);
150  }
151  else
152  this->mixed_trs(X, std::forward<Mat<T>>(B), [&](fmat& residual) {
153  arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, this->s_memory.memptr(), residual.memptr(), &LDB, &INFO);
154  return INFO;
155  });
156 
157  if(0 != INFO)
158  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
159 
160  return INFO;
161 }
162 
163 #endif
164 
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition: DenseMat.hpp:48
const uword n_rows
Definition: MetaMat.hpp:118
bool factored
Definition: MetaMat.hpp:74
A SymmPackMat class that holds symmetric matrices in packed (lower-triangular) storage.
Definition: SymmPackMat.hpp:35
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition: SymmPackMat.hpp:63
void nullify(const uword K) override
Definition: SymmPackMat.hpp:56
SymmPackMat(const uword in_size)
Definition: SymmPackMat.hpp:50
unique_ptr< MetaMat< T > > make_copy() override
Definition: SymmPackMat.hpp:54
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition: SymmPackMat.hpp:70
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition: SymmPackMat.hpp:65
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition: SymmPackMat.hpp:100
Mat< T > operator*(const Mat< T > &) const override
Definition: SymmPackMat.hpp:80
void for_each(const IT start, const IT end, F &&FN)
Definition: utility.h:28
void suanpan_assert(const std::function< void()> &F)
Definition: suanPan.h:296
#define suanpan_error(...)
Definition: suanPan.h:309