suanPan
Loading...
Searching...
No Matches
BandMatCluster.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2025 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef BANDMATCLUSTER_HPP
31#define BANDMATCLUSTER_HPP
32
33#include "../DenseMat.hpp"
34
35#include <ezp/ezp/pgbsv.hpp>
36
37template<sp_d T> class BandMatCluster : public DenseMat<T> {
38 using solver_t = ezp::pgbsv<T, la_it>;
39 using indexer_t = typename solver_t::indexer;
40
41 static T bin;
42
43 solver_t solver;
44 indexer_t indexer;
45
46 int solve_trs(Mat<T>&, Mat<T>&&);
47
48protected:
49 const uword l_band, u_band;
50
52
53 int direct_solve(Mat<T>&, Mat<T>&&) override;
54
55public:
56 BandMatCluster(const uword in_size, const uword in_l, const uword in_u)
57 : DenseMat<T>(in_size, in_size, (2 * (in_l + in_u) + 1) * in_size)
58 , solver()
59 , indexer(in_size, in_l, in_u)
60 , l_band(in_l)
61 , u_band(in_u) {
62 if(2 * (in_l + in_u) + 1 >= in_size)
63 suanpan_warning("The storage requirement for the banded matrix is larger than that of a full matrix, consider using a full/sparse matrix instead.\n");
64 }
65
66 unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<BandMatCluster>(*this); }
67
68 void nullify(const uword K) override {
69 this->factored = false;
70 suanpan::for_each(std::max(K, u_band) - u_band, std::min(this->n_rows, K + l_band + 1), [&](const uword I) { this->memory[2 * u_band + l_band + I + 2 * K * (l_band + u_band)] = T(0); });
71 suanpan::for_each(std::max(K, l_band) - l_band, std::min(this->n_cols, K + u_band + 1), [&](const uword I) { this->memory[2 * u_band + l_band + K + 2 * I * (l_band + u_band)] = T(0); });
72 }
73
74 T operator()(const uword in_row, const uword in_col) const override {
75 const auto pos = indexer(in_row, in_col);
76 if(pos < 0) [[unlikely]]
77 return bin = T(0);
78 return this->memory[pos];
79 }
80
81 T& unsafe_at(const uword in_row, const uword in_col) override {
82 this->factored = false;
83 return this->memory[indexer(in_row, in_col)];
84 }
85
86 T& at(const uword in_row, const uword in_col) override {
87 const auto pos = indexer(in_row, in_col);
88 if(pos < 0) [[unlikely]]
89 return bin = T(0);
90 this->factored = false;
91 return this->memory[pos];
92 }
93
94 Mat<T> operator*(const Mat<T>&) const override;
95};
96
97template<sp_d T> T BandMatCluster<T>::bin = T(0);
98
99template<sp_d T> Mat<T> BandMatCluster<T>::operator*(const Mat<T>& X) const {
100 static constexpr char TRAN = 'N';
101 static constexpr blas_int INC = 1;
102 static constexpr T ALPHA = T(1), BETA = T(0);
103
104 Mat<T> Y(arma::size(X));
105
106 const auto s_band = l_band + u_band;
107
108 const auto M = static_cast<blas_int>(this->n_rows);
109 const auto N = static_cast<blas_int>(this->n_cols);
110 const auto KL = static_cast<blas_int>(l_band);
111 const auto KU = static_cast<blas_int>(u_band);
112 const auto LDA = static_cast<blas_int>(2 * s_band + 1);
113
114 if constexpr(std::is_same_v<T, float>) {
115 using E = float;
116 suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_sgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + s_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
117 }
118 else {
119 using E = double;
120 suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_dgbmv)(&TRAN, &M, &N, &KL, &KU, (E*)&ALPHA, (E*)(this->memptr() + s_band), &LDA, (E*)X.colptr(I), &INC, (E*)&BETA, (E*)Y.colptr(I), &INC); });
121 }
122
123 return Y;
124}
125
126#pragma GCC diagnostic push
127#pragma GCC diagnostic ignored "-Wnarrowing"
128template<sp_d T> int BandMatCluster<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
129 if(this->factored) return this->solve_trs(X, std::move(B));
130
131 suanpan_assert([&] { if(this->n_rows != this->n_cols) throw std::invalid_argument("requires a square matrix"); });
132
133 this->factored = true;
134
135 const auto INFO = bcast_from_root(solver.solve({this->n_rows, this->n_cols, this->l_band, this->u_band, this->memptr()}, {B.n_rows, B.n_cols, B.memptr()}));
136
137 if(0 == INFO) bcast_from_root(X = std::move(B));
138 else suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
139
140 return INFO;
141}
142
143template<sp_d T> int BandMatCluster<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
144 const auto INFO = bcast_from_root(solver.solve({B.n_rows, B.n_cols, B.memptr()}));
145
146 if(0 == INFO) bcast_from_root(X = std::move(B));
147 else suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
148
149 return INFO;
150}
151#pragma GCC diagnostic pop
152
153#endif
154
A BandMatCluster class that holds banded matrices and solves them with ezp::pgbsv.
Definition BandMatCluster.hpp:37
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only); returns zero for entries outside the stored band.
Definition BandMatCluster.hpp:74
void nullify(const uword K) override
Definition BandMatCluster.hpp:68
const uword l_band
Definition BandMatCluster.hpp:49
BandMatCluster(const uword in_size, const uword in_l, const uword in_u)
Definition BandMatCluster.hpp:56
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition BandMatCluster.hpp:81
unique_ptr< MetaMat< T > > make_copy() override
Definition BandMatCluster.hpp:66
const uword u_band
Definition BandMatCluster.hpp:49
T & at(const uword in_row, const uword in_col) override
Access element with bound check.
Definition BandMatCluster.hpp:86
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
const uword n_cols
Definition MetaMat.hpp:116
const uword n_rows
Definition MetaMat.hpp:115
bool factored
Definition MetaMat.hpp:76
Mat< T > operator*(const Mat< T > &) const override
Definition BandMatCluster.hpp:99
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition BandMatCluster.hpp:128
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:28
#define suanpan_warning(...)
Definition suanPan.h:375
void suanpan_assert(const std::function< void()> &F)
Definition suanPan.h:363
auto bcast_from_root(T &&object)
Definition suanPan.h:254
#define suanpan_error(...)
Definition suanPan.h:376