suanPan
🧮 An Open Source, Parallel and Heterogeneous Finite Element Analysis Framework
Loading...
Searching...
No Matches
DenseMat.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2026 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29#ifndef DENSEMAT_HPP
30#define DENSEMAT_HPP
31
32#include "MetaMat.hpp"
33
34template<sp_d T> uword round_up(const uword in_size) {
35 constexpr auto multiple = 64llu / sizeof(T);
36 return (in_size + multiple - 1llu) / multiple * multiple;
37}
38
39template<sp_d T> class DenseMat : public MetaMat<T> {
40protected:
41 using MetaMat<T>::direct_solve;
42
43 int direct_solve(Mat<T>& X, const Mat<T>& B) override { return this->direct_solve(X, Mat<T>(B)); }
44
45 podarray<blas_int> pivot;
46 podarray<float> s_memory; // float storage used in mixed precision algorithm
47
48 std::unique_ptr<T[]> memory = nullptr;
49
50 podarray<float> to_float() {
51 podarray<float> f_memory(this->n_elem);
52 suanpan::for_each(this->n_elem, [&](const uword I) { f_memory(I) = static_cast<float>(memory[I]); });
53 return f_memory;
54 }
55
56public:
57 DenseMat(const uword in_rows, const uword in_cols, const uword in_elem)
58 : MetaMat<T>(in_rows, in_cols, in_elem)
59 , memory(new T[this->n_elem]) {
60 if(in_elem > std::numeric_limits<la_it>::max()) throw std::runtime_error("matrix size exceeds limit, please enable 64-bit indexing");
62 }
63
64 DenseMat(const DenseMat& old_mat)
65 : MetaMat<T>(old_mat)
66 , pivot(old_mat.pivot)
67 , s_memory(old_mat.s_memory)
68 , memory(new T[this->n_elem]) {
69 suanpan::for_each(this->n_elem, [&](const uword I) { memory[I] = old_mat.memory[I]; });
70 }
71
72 DenseMat(DenseMat&&) = delete;
73 DenseMat& operator=(const DenseMat&) = delete;
75 ~DenseMat() override = default;
76
77 [[nodiscard]] bool is_empty() const override { return 0 == this->n_elem; }
78
79 void zeros() override {
80 this->factored = false;
81 arrayops::fill_zeros(memptr(), this->n_elem);
82 }
83
84 [[nodiscard]] T max() const override {
85 T max_value = T(1);
86 for(uword I = 0; I < std::min(this->n_rows, this->n_cols); ++I)
87 if(const auto t_val = this->operator()(I, I); t_val > max_value) max_value = t_val;
88 return max_value;
89 }
90
91 [[nodiscard]] const T* memptr() const override { return memory.get(); }
92
93 T* memptr() override { return memory.get(); }
94
95 void scale_accu(const T scalar, const shared_ptr<MetaMat<T>>& M) override {
96 if(nullptr == M) return;
97 if(!M->triplet_mat.is_empty()) return this->scale_accu(scalar, M->triplet_mat);
98 if(this->n_rows != M->n_rows || this->n_cols != M->n_cols || this->n_elem != M->n_elem) throw std::invalid_argument("size mismatch");
99 if(nullptr == M->memptr()) return;
100 this->factored = false;
101 if(1. == scalar) arrayops::inplace_plus(memptr(), M->memptr(), this->n_elem);
102 else if(-1. == scalar) arrayops::inplace_minus(memptr(), M->memptr(), this->n_elem);
103 else suanpan::for_each(this->n_elem, [&](const uword I) { memptr()[I] += scalar * M->memptr()[I]; });
104 }
105
106 void scale_accu(const T scalar, const triplet_form<T, uword>& M) override {
107 if(this->n_rows != M.n_rows || this->n_cols != M.n_cols) throw std::invalid_argument("size mismatch");
108 this->factored = false;
109 const auto row = M.row_mem();
110 const auto col = M.col_mem();
111 const auto val = M.val_mem();
112 if(1. == scalar)
113 for(auto I = 0llu; I < M.n_elem; ++I) this->at(row[I], col[I]) += val[I];
114 else if(-1. == scalar)
115 for(auto I = 0llu; I < M.n_elem; ++I) this->at(row[I], col[I]) -= val[I];
116 else
117 for(auto I = 0llu; I < M.n_elem; ++I) this->at(row[I], col[I]) += scalar * val[I];
118 }
119
120 void operator*=(const T value) override {
121 this->factored = false;
122 arrayops::inplace_mul(memptr(), value, this->n_elem);
123 }
124
125 void allreduce() override {
126#ifdef SUANPAN_DISTRIBUTED
127 comm_world.allreduce(mpl::plus<T>(), memory.get(), mpl::contiguous_layout<T>{this->n_elem});
128#endif
129 }
130};
131
132#endif
133
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
podarray< float > to_float()
Definition DenseMat.hpp:50
const T * memptr() const override
Definition DenseMat.hpp:91
podarray< float > s_memory
Definition DenseMat.hpp:46
podarray< blas_int > pivot
Definition DenseMat.hpp:45
void zeros() override
Definition DenseMat.hpp:79
DenseMat(const DenseMat &old_mat)
Definition DenseMat.hpp:64
DenseMat(const uword in_rows, const uword in_cols, const uword in_elem)
Definition DenseMat.hpp:57
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
~DenseMat() override=default
void scale_accu(const T scalar, const shared_ptr< MetaMat< T > > &M) override
Definition DenseMat.hpp:95
void operator*=(const T value) override
Definition DenseMat.hpp:120
T max() const override
Definition DenseMat.hpp:84
void scale_accu(const T scalar, const triplet_form< T, uword > &M) override
Definition DenseMat.hpp:106
DenseMat(DenseMat &&)=delete
bool is_empty() const override
Definition DenseMat.hpp:77
void allreduce() override
Definition DenseMat.hpp:125
DenseMat & operator=(const DenseMat &)=delete
DenseMat & operator=(DenseMat &&)=delete
T * memptr() override
Definition DenseMat.hpp:93
int direct_solve(Mat< T > &X, const Mat< T > &B) override
Definition DenseMat.hpp:43
A MetaMat class that holds matrices.
Definition MetaMat.hpp:74
const uword n_cols
Definition MetaMat.hpp:116
const uword n_rows
Definition MetaMat.hpp:115
bool factored
Definition MetaMat.hpp:76
const uword n_elem
Definition MetaMat.hpp:117
virtual T & at(uword, uword)=0
Access element with bound check.
Definition triplet_form.hpp:62
const index_t n_rows
Definition triplet_form.hpp:128
uword round_up(const uword in_size)
Definition DenseMat.hpp:34
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:31