suanPan
🧮 An Open Source, Parallel and Heterogeneous Finite Element Analysis Framework
Loading...
Searching...
No Matches
SymmPackMat.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2026 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef SYMMPACKMAT_HPP
31#define SYMMPACKMAT_HPP
32
33#include "../DenseMat.hpp"
34
35template<sp_d T> class SymmPackMat final : public DenseMat<T> {
36 static constexpr auto UPLO = 'L';
37
38 static T bin;
39
40 const uword length; // 2n-1
41
42 int solve_trs(Mat<T>&, Mat<T>&&);
43
44protected:
46
47 int direct_solve(Mat<T>&, Mat<T>&&) override;
48
49public:
50 explicit SymmPackMat(const uword in_size)
51 : DenseMat<T>(in_size, in_size, (in_size + 1) * in_size / 2)
52 , length(2 * in_size - 1) {}
53
54 unique_ptr<MetaMat<T>> unique_copy() override { return std::make_unique<SymmPackMat>(*this); }
55
56 void nullify(const uword K) override {
57 this->factored = false;
58 suanpan::for_each(K, [&](const uword I) { this->memory[K + (length - I) * I / 2] = T(0); });
59 const auto t_factor = (length - K) * K / 2;
60 suanpan::for_each(K, this->n_rows, [&](const uword I) { this->memory[I + t_factor] = T(0); });
61 }
62
63 T operator()(const uword in_row, const uword in_col) const override { return this->memory[in_row >= in_col ? in_row + (length - in_col) * in_col / 2 : in_col + (length - in_row) * in_row / 2]; }
64
65 T& unsafe_at(const uword in_row, const uword in_col) override {
66 this->factored = false;
67 return this->memory[in_row + (length - in_col) * in_col / 2];
68 }
69
70 T& at(const uword in_row, const uword in_col) override {
71 if(in_row < in_col) [[unlikely]]
72 return bin = T(0);
73 return this->unsafe_at(in_row, in_col);
74 }
75
76 Mat<T> operator*(const Mat<T>&) const override;
77
78 [[nodiscard]] int sign_det() const override { return 1; }
79};
80
81template<sp_d T> T SymmPackMat<T>::bin = T(0);
82
83template<sp_d T> Mat<T> SymmPackMat<T>::operator*(const Mat<T>& X) const {
84 auto Y = Mat<T>(arma::size(X), fill::none);
85
86 const auto N = static_cast<blas_int>(this->n_rows);
87 static constexpr blas_int INC = 1;
88 static constexpr T ALPHA{1}, BETA{0};
89
90 if constexpr(std::is_same_v<T, float>) suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_sspmv)(&UPLO, &N, &ALPHA, this->memptr(), X.colptr(I), &INC, &BETA, Y.colptr(I), &INC); });
91 else suanpan::for_each(X.n_cols, [&](const uword I) { arma_fortran(arma_dspmv)(&UPLO, &N, &ALPHA, this->memptr(), X.colptr(I), &INC, &BETA, Y.colptr(I), &INC); });
92
93 return Y;
94}
95
96template<sp_d T> int SymmPackMat<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
97 if(this->factored) return this->solve_trs(X, std::move(B));
98
99 suanpan_assert([&] { if(this->n_rows != this->n_cols) throw std::invalid_argument("requires a square matrix"); });
100
101 blas_int INFO = 0;
102
103 const auto N = static_cast<blas_int>(this->n_rows);
104 const auto NRHS = static_cast<blas_int>(B.n_cols);
105 const auto LDB = static_cast<blas_int>(B.n_rows);
106 this->factored = true;
107
108 if constexpr(std::is_same_v<T, float>) {
109 arma_fortran(arma_sppsv)(&UPLO, &N, &NRHS, this->memptr(), B.memptr(), &LDB, &INFO);
110 X = std::move(B);
111 }
112 else if(Precision::FULL == this->setting.precision) {
113 arma_fortran(arma_dppsv)(&UPLO, &N, &NRHS, this->memptr(), B.memptr(), &LDB, &INFO);
114 X = std::move(B);
115 }
116 else {
117 this->s_memory = this->to_float();
118 arma_fortran(arma_spptrf)(&UPLO, &N, this->s_memory.memptr(), &INFO);
119 if(0 == INFO) INFO = this->solve_trs(X, std::move(B));
120 }
121
122 if(0 != INFO)
123 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
124
125 return INFO;
126}
127
128template<sp_d T> int SymmPackMat<T>::solve_trs(Mat<T>& X, Mat<T>&& B) {
129 blas_int INFO = 0;
130
131 const auto N = static_cast<blas_int>(this->n_rows);
132 const auto NRHS = static_cast<blas_int>(B.n_cols);
133 const auto LDB = static_cast<blas_int>(B.n_rows);
134
135 if constexpr(std::is_same_v<T, float>) {
136 arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, this->memptr(), B.memptr(), &LDB, &INFO);
137 X = std::move(B);
138 }
139 else if(Precision::FULL == this->setting.precision) {
140 arma_fortran(arma_dpptrs)(&UPLO, &N, &NRHS, this->memptr(), B.memptr(), &LDB, &INFO);
141 X = std::move(B);
142 }
143 else
144 this->mixed_trs(X, std::move(B), [&](fmat& residual) {
145 arma_fortran(arma_spptrs)(&UPLO, &N, &NRHS, this->s_memory.memptr(), residual.memptr(), &LDB, &INFO);
146 return INFO;
147 });
148
149 if(0 != INFO)
150 suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
151
152 return INFO;
153}
154
155#endif
156
A DenseMat class that holds matrices.
Definition DenseMat.hpp:39
std::unique_ptr< T[]> memory
Definition DenseMat.hpp:48
const uword n_rows
Definition MetaMat.hpp:115
bool factored
Definition MetaMat.hpp:76
A SymmPackMat class that holds symmetric matrices, storing only the lower triangle in LAPACK packed format.
Definition SymmPackMat.hpp:35
T operator()(const uword in_row, const uword in_col) const override
Access element (read-only); an upper-triangle index is mirrored onto the stored lower triangle. No bound check is performed.
Definition SymmPackMat.hpp:63
unique_ptr< MetaMat< T > > unique_copy() override
Definition SymmPackMat.hpp:54
void nullify(const uword K) override
Definition SymmPackMat.hpp:56
SymmPackMat(const uword in_size)
Definition SymmPackMat.hpp:50
T & unsafe_at(const uword in_row, const uword in_col) override
Access element without bound check.
Definition SymmPackMat.hpp:65
int sign_det() const override
Definition SymmPackMat.hpp:78
T & at(const uword in_row, const uword in_col) override
Access element; an upper-triangle index returns a reference to a shared dummy value reset to zero, so writes there are discarded.
Definition SymmPackMat.hpp:70
int direct_solve(Mat< T > &, Mat< T > &&) override
Definition SymmPackMat.hpp:96
Mat< T > operator*(const Mat< T > &) const override
Definition SymmPackMat.hpp:83
void for_each(const IT start, const IT end, F &&FN)
Definition utility.h:31
auto suanpan_assert(F &&handler)
Definition suanPan.h:339
#define suanpan_error(...)
Definition suanPan.h:349