suanPan
DenseMat.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2023 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 #ifndef DENSEMAT_HPP
30 #define DENSEMAT_HPP
31 
32 #include "MetaMat.hpp"
33 
34 template<sp_d T> uword round_up(const uword in_size) {
35  constexpr auto multiple = 64llu / sizeof(T);
36  return (in_size + multiple - 1llu) / multiple * multiple;
37 }
38 
39 template<sp_d T> class DenseMat : public MetaMat<T> {
40 protected:
42 
43  int direct_solve(Mat<T>& X, const Mat<T>& B) override { return this->direct_solve(X, Mat<T>(B)); }
44 
45  podarray<int> pivot;
46  podarray<float> s_memory; // float storage used in mixed precision algorithm
47 
48  std::unique_ptr<T[]> memory = nullptr;
49 
50  podarray<float> to_float() {
51  podarray<float> f_memory(this->n_elem);
52  suanpan_for(0llu, this->n_elem, [&](const uword I) { f_memory(I) = static_cast<float>(memory[I]); });
53  return f_memory;
54  }
55 
56 public:
57  DenseMat(const uword in_rows, const uword in_cols, const uword in_elem)
58  : MetaMat<T>(in_rows, in_cols, in_elem)
59  , memory(std::unique_ptr<T[]>(new T[this->n_elem])) { DenseMat::zeros(); }
60 
61  DenseMat(const DenseMat& old_mat)
62  : MetaMat<T>(old_mat)
63  , pivot(old_mat.pivot)
64  , s_memory(old_mat.s_memory)
65  , memory(std::unique_ptr<T[]>(new T[this->n_elem])) { suanpan_for(0llu, this->n_elem, [&](const uword I) { memory[I] = old_mat.memory[I]; }); }
66 
67  DenseMat(DenseMat&&) noexcept = delete;
68  DenseMat& operator=(const DenseMat&) = delete;
69  DenseMat& operator=(DenseMat&&) noexcept = delete;
70  ~DenseMat() override = default;
71 
72  [[nodiscard]] bool is_empty() const override { return 0 == this->n_elem; }
73 
74  void zeros() override {
75  this->factored = false;
76  arrayops::fill_zeros(memptr(), this->n_elem);
77  }
78 
79  [[nodiscard]] T max() const override {
80  T max_value = T(1);
81  for(uword I = 0; I < std::min(this->n_rows, this->n_cols); ++I) if(const auto t_val = this->operator()(I, I); t_val > max_value) max_value = t_val;
82  return max_value;
83  }
84 
85  [[nodiscard]] Col<T> diag() const override {
86  Col<T> diag_vec(std::min(this->n_rows, this->n_cols), fill::none);
87  suanpan_for(0llu, diag_vec.n_elem, [&](const uword I) { diag_vec(I) = this->operator()(I, I); });
88  return diag_vec;
89  }
90 
91  [[nodiscard]] const T* memptr() const override { return memory.get(); }
92 
93  T* memptr() override { return memory.get(); }
94 
95  void operator+=(const shared_ptr<MetaMat<T>>& M) override {
96  if(nullptr == M) return;
97  if(!M->triplet_mat.is_empty()) return this->operator+=(M->triplet_mat);
98  if(this->n_rows != M->n_rows || this->n_cols != M->n_cols || this->n_elem != M->n_elem) throw invalid_argument("size mismatch");
99  if(nullptr == M->memptr()) return;
100  this->factored = false;
101  arrayops::inplace_plus(memptr(), M->memptr(), this->n_elem);
102  }
103 
104  void operator-=(const shared_ptr<MetaMat<T>>& M) override {
105  if(nullptr == M) return;
106  if(!M->triplet_mat.is_empty()) return this->operator-=(M->triplet_mat);
107  if(this->n_rows != M->n_rows || this->n_cols != M->n_cols || this->n_elem != M->n_elem) throw invalid_argument("size mismatch");
108  if(nullptr == M->memptr()) return;
109  this->factored = false;
110  arrayops::inplace_minus(memptr(), M->memptr(), this->n_elem);
111  }
112 
113  void operator+=(const triplet_form<T, uword>& M) override {
114  if(this->n_rows != M.n_rows || this->n_cols != M.n_cols) throw invalid_argument("size mismatch");
115  this->factored = false;
116  const auto row = M.row_mem();
117  const auto col = M.col_mem();
118  const auto val = M.val_mem();
119  for(uword I = 0llu; I < M.n_elem; ++I) this->at(row[I], col[I]) += val[I];
120  }
121 
122  void operator-=(const triplet_form<T, uword>& M) override {
123  if(this->n_rows != M.n_rows || this->n_cols != M.n_cols) throw invalid_argument("size mismatch");
124  this->factored = false;
125  const auto row = M.row_mem();
126  const auto col = M.col_mem();
127  const auto val = M.val_mem();
128  for(uword I = 0llu; I < M.n_elem; ++I) this->at(row[I], col[I]) -= val[I];
129  }
130 
131  void operator*=(const T value) override {
132  this->factored = false;
133  arrayops::inplace_mul(memptr(), value, this->n_elem);
134  }
135 
136  [[nodiscard]] int sign_det() const override {
137  if(IterativeSolver::NONE != this->setting.iterative_solver) throw invalid_argument("analysis requires the sign of determinant but iterative solver does not support it");
138  auto det_sign = 1;
139  for(unsigned I = 0; I < pivot.n_elem; ++I) if((this->operator()(I, I) < T(0)) ^ (static_cast<int>(I) + 1 != pivot(I))) det_sign = -det_sign;
140  return det_sign;
141  }
142 };
143 
144 #endif
145 
A DenseMat class that holds matrices.
Definition: DenseMat.hpp:39
T * memptr() override
Definition: DenseMat.hpp:93
podarray< float > to_float()
Definition: DenseMat.hpp:50
const T * memptr() const override
Definition: DenseMat.hpp:91
podarray< int > pivot
Definition: DenseMat.hpp:45
podarray< float > s_memory
Definition: DenseMat.hpp:46
void zeros() override
Definition: DenseMat.hpp:74
DenseMat(const DenseMat &old_mat)
Definition: DenseMat.hpp:61
void operator-=(const triplet_form< T, uword > &M) override
Definition: DenseMat.hpp:122
DenseMat(const uword in_rows, const uword in_cols, const uword in_elem)
Definition: DenseMat.hpp:57
std::unique_ptr< T[]> memory
Definition: DenseMat.hpp:48
void operator+=(const shared_ptr< MetaMat< T >> &M) override
Definition: DenseMat.hpp:95
void operator*=(const T value) override
Definition: DenseMat.hpp:131
T max() const override
Definition: DenseMat.hpp:79
void operator-=(const shared_ptr< MetaMat< T >> &M) override
Definition: DenseMat.hpp:104
bool is_empty() const override
Definition: DenseMat.hpp:72
int sign_det() const override
Definition: DenseMat.hpp:136
Col< T > diag() const override
Definition: DenseMat.hpp:85
DenseMat(DenseMat &&) noexcept=delete
int direct_solve(Mat< T > &X, const Mat< T > &B) override
Definition: DenseMat.hpp:43
void operator+=(const triplet_form< T, uword > &M) override
Definition: DenseMat.hpp:113
A MetaMat class that holds matrices.
Definition: MetaMat.hpp:39
const uword n_cols
Definition: MetaMat.hpp:86
virtual T & at(uword, uword)=0
Access element with bound check.
const uword n_rows
Definition: MetaMat.hpp:85
bool factored
Definition: MetaMat.hpp:41
const uword n_elem
Definition: MetaMat.hpp:87
SolverSetting< T > setting
Definition: MetaMat.hpp:43
const index_t n_rows
Definition: triplet_form.hpp:128
uword round_up(const uword in_size)
Definition: DenseMat.hpp:34
IterativeSolver iterative_solver
Definition: SolverSetting.hpp:46
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:27