suanPan
🧮 An Open Source, Parallel and Heterogeneous Finite Element Analysis Framework
Loading...
Searching...
No Matches
FullMatCluster.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2026 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef FULLMATCLUSTER_HPP
31#define FULLMATCLUSTER_HPP
32
33#include "../DenseMat.hpp"
34
35#include <ezp/ezp/pgesv.hpp>
36#include <ezp/ezp/pposv.hpp>
37
38template<sp_d T, typename solver_t> class FullMatBaseCluster : public DenseMat<T> {
39 solver_t solver;
40
41 int solve_trs(Mat<T>&, Mat<T>&&);
42
43protected:
45
46 int direct_solve(Mat<T>&, Mat<T>&&) override;
47
48public:
49 FullMatBaseCluster(const uword in_rows, const uword in_cols)
50 : DenseMat<T>(in_rows, in_cols, in_rows * in_cols)
51 , solver() {}
52
53 unique_ptr<MetaMat<T>> unique_copy() override { return std::make_unique<FullMatBaseCluster>(*this); }
54
55 void nullify(const uword K) override {
56 this->factored = false;
57 suanpan::for_each(this->n_rows, [&](const uword I) { this->at(I, K) = T(0); });
58 suanpan::for_each(this->n_cols, [&](const uword I) { this->at(K, I) = T(0); });
59 }
60
61 T operator()(const uword in_row, const uword in_col) const override { return this->memory[in_row + in_col * this->n_rows]; }
62
63 T& at(const uword in_row, const uword in_col) override {
64 this->factored = false;
65 return this->memory[in_row + in_col * this->n_rows];
66 }
67
68 Mat<T> operator*(const Mat<T>&) const override;
69};
70
71template<sp_d T, typename solver_t> Mat<T> FullMatBaseCluster<T, solver_t>::operator*(const Mat<T>& B) const {
72 static constexpr auto TRAN = 'N';
73 static constexpr T ALPHA{1}, BETA{0};
74
75 Mat<T> C(arma::size(B));
76
77 const auto M = static_cast<blas_int>(this->n_rows);
78 const auto N = static_cast<blas_int>(this->n_cols);
79
80 if(1 == B.n_cols) {
81 static constexpr blas_int INC = 1;
82
83 if constexpr(std::is_same_v<T, float>) arma_fortran(arma_sgemv)(&TRAN, &M, &N, &ALPHA, this->memptr(), &M, B.memptr(), &INC, &BETA, C.memptr(), &INC);
84 else arma_fortran(arma_dgemv)(&TRAN, &M, &N, &ALPHA, this->memptr(), &M, B.memptr(), &INC, &BETA, C.memptr(), &INC);
85 }
86 else {
87 const auto K = static_cast<blas_int>(B.n_cols);
88
89 if constexpr(std::is_same_v<T, float>) arma_fortran(arma_sgemm)(&TRAN, &TRAN, &M, &K, &N, &ALPHA, this->memptr(), &M, B.memptr(), &N, &BETA, C.memptr(), &M);
90 else arma_fortran(arma_dgemm)(&TRAN, &TRAN, &M, &K, &N, &ALPHA, this->memptr(), &M, B.memptr(), &N, &BETA, C.memptr(), &M);
91 }
92
93 return C;
94}
95
96#pragma GCC diagnostic push
97#pragma GCC diagnostic ignored "-Wnarrowing"
98template<sp_d T, typename solver_t> int FullMatBaseCluster<T, solver_t>::direct_solve(Mat<T>& X, Mat<T>&& B) {
99 if(this->factored) return this->solve_trs(X, std::move(B));
100
101 suanpan_assert([&] { if(this->n_rows != this->n_cols) throw std::invalid_argument("requires a square matrix"); });
102
103 this->factored = true;
104
105 const auto INFO = bcast_from_root(solver.solve({this->n_rows, this->n_cols, this->memptr()}, {B.n_rows, B.n_cols, B.memptr()}));
106
107 if(0 == INFO) bcast_from_root(X = std::move(B));
108 else suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
109
110 return INFO;
111}
112
113template<sp_d T, typename solver_t> int FullMatBaseCluster<T, solver_t>::solve_trs(Mat<T>& X, Mat<T>&& B) {
114 const auto INFO = bcast_from_root(solver.solve({B.n_rows, B.n_cols, B.memptr()}));
115
116 if(0 == INFO) bcast_from_root(X = std::move(B));
117 else suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
118
119 return INFO;
120}
121#pragma GCC diagnostic pop
122
125
126#endif
127
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
Definition FullMatCluster.hpp:38
void nullify(const uword K) override
Definition FullMatCluster.hpp:55
FullMatBaseCluster(const uword in_rows, const uword in_cols)
Definition FullMatCluster.hpp:49
unique_ptr< MetaMat< T > > unique_copy() override
Definition FullMatCluster.hpp:53
T & at(const uword in_row, const uword in_col) override
Access element for writing; this override performs no bound check and clears the factored flag.
Definition FullMatCluster.hpp:63
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only); this override indexes storage directly and performs no bound check.
Definition FullMatCluster.hpp:61
A FullMatCluster class that holds matrices.
const uword n_cols
Definition MetaMat.hpp:116
const uword n_rows
Definition MetaMat.hpp:115
bool factored
Definition MetaMat.hpp:76
FullMatBaseCluster< T, ezp::pposv< T, la_it > > FullSymmMatCluster
Definition FullMatCluster.hpp:124
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition FullMatCluster.hpp:98
Mat< T > operator*(const Mat< T > &) const override
Definition FullMatCluster.hpp:71
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:31
auto suanpan_assert(F &&handler)
Definition suanPan.h:339
auto bcast_from_root(T &&object)
Definition suanPan.h:238
#define suanpan_error(...)
Definition suanPan.h:349