suanPan
Loading...
Searching...
No Matches
BandSymmMatCluster.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2025 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef BANDSYMMMATCLUSTER_HPP
31#define BANDSYMMMATCLUSTER_HPP
32
33#include "../DenseMat.hpp"
34
35#include <ezp/ezp/ppbsv.hpp>
36
37template<sp_d T> class BandSymmMatCluster final : public DenseMat<T> {
38 static constexpr char UPLO = 'L';
39
40 using solver_t = ezp::ppbsv<T, la_it, UPLO>;
41 using indexer_t = typename solver_t::indexer;
42
43 static T bin;
44
45 solver_t solver;
46 indexer_t indexer;
47
48 const uword band;
49
50 int solve_trs(Mat<T>&, Mat<T>&&);
51
52protected:
54
55 int direct_solve(Mat<T>&, Mat<T>&&) override;
56
57public:
58 BandSymmMatCluster(const uword in_size, const uword in_bandwidth)
59 : DenseMat<T>(in_size, in_size, (in_bandwidth + 1) * in_size)
60 , solver()
61 , indexer(in_size, in_bandwidth)
62 , band(in_bandwidth) {}
63
64 unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<BandSymmMatCluster>(*this); }
65
66 void nullify(const uword K) override {
67 this->factored = false;
68 suanpan::for_each(std::max(K, band) - band, K, [&](const uword I) { this->memory[indexer(K, I)] = T(0); });
69 suanpan::for_each(K, std::min(this->n_rows, K + band + 1), [&](const uword I) { this->memory[indexer(I, K)] = T(0); });
70 }
71
72 T operator()(const uword in_row, const uword in_col) const override {
73 const auto pos = indexer(in_row, in_col);
74 if(pos < 0) [[unlikely]]
75 return bin = T(0);
76 return this->memory[pos];
77 }
78
79 T& unsafe_at(const uword in_row, const uword in_col) override {
80 this->factored = false;
81 return this->memory[indexer(in_row, in_col)];
82 }
83
84 T& at(const uword in_row, const uword in_col) override {
85 const auto pos = indexer(in_row, in_col);
86 if(pos < 0) [[unlikely]]
87 return bin = T(0);
88 this->factored = false;
89 return this->memory[pos];
90 }
91
92 Mat<T> operator*(const Mat<T>&) const override;
93};
94
95template<sp_d T> T BandSymmMatCluster<T>::bin = T(0);
96
97template<sp_d T> Mat<T> BandSymmMatCluster<T>::operator*(const Mat<T>& X) const {
98 static constexpr blas_int INC = 1;
99 static constexpr T ALPHA = T(1), BETA = T(0);
100
101 Mat<T> Y(arma::size(X));
102
103 const auto N = static_cast<blas_int>(this->n_cols);
104 const auto K = static_cast<blas_int>(band);
105 const auto LDA = K + 1;
106
107 if constexpr(std::is_same_v<T, float>) {
108 using E = float;
109 suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_ssbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
110 }
111 else {
112 using E = double;
113 suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_dsbmv)(&UPLO, &N, &K, (E*)&ALPHA, (E*)this->memptr(), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
114 }
115
116 return Y;
117}
118
119#pragma GCC diagnostic push
120#pragma GCC diagnostic ignored "-Wnarrowing"
121template<sp_d T> int BandSymmMatCluster<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
122 if(this->factored) return this->solve_trs(X, std::move(B));
123
124 suanpan_assert([&] { if(this->n_rows != this->n_cols) throw std::invalid_argument("requires a square matrix"); });
125
126 this->factored = true;
127
128 const auto INFO = bcast_from_root(solver.solve({this->n_rows, this->n_cols, this->band, this->memptr()}, {B.n_rows, B.n_cols, B.memptr()}));
129
130 if(0 == INFO) bcast_from_root(X = std::move(B));
131 else suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
132
133 return INFO;
134}
135
136template<sp_d T> int BandSymmMatCluster<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
137 const auto INFO = bcast_from_root(solver.solve({B.n_rows, B.n_cols, B.memptr()}));
138
139 if(0 == INFO) bcast_from_root(X = std::move(B));
140 else suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
141
142 return INFO;
143}
144#pragma GCC diagnostic pop
145
146#endif
147
A BandSymmMatCluster class that holds matrices.
Definition BandSymmMatCluster.hpp:37
unique_ptr< MetaMat< T > > make_copy() override
Definition BandSymmMatCluster.hpp:64
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only), returns zero if out-of-bound.
Definition BandSymmMatCluster.hpp:72
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition BandSymmMatCluster.hpp:84
void nullify(const uword K) override
Definition BandSymmMatCluster.hpp:66
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition BandSymmMatCluster.hpp:79
BandSymmMatCluster(const uword in_size, const uword in_bandwidth)
Definition BandSymmMatCluster.hpp:58
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
const uword n_rows
Definition MetaMat.hpp:115
bool factored
Definition MetaMat.hpp:76
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition BandSymmMatCluster.hpp:121
Mat< T > operator*(const Mat< T > &) const override
Definition BandSymmMatCluster.hpp:97
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:28
void suanpan_assert(const std::function< void()> &F)
Definition suanPan.h:363
auto bcast_from_root(T &&object)
Definition suanPan.h:254
#define suanpan_error(...)
Definition suanPan.h:376