suanPan
Loading...
Searching...
No Matches
FullMatCUDA.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (C) 2017-2025 Theodore Chang
3 *
4 * This program is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 3 of the License, or
7 * (at your option) any later version.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 *
14 * You should have received a copy of the GNU General Public License
15 * along with this program. If not, see <http://www.gnu.org/licenses/>.
16 ******************************************************************************/
29// ReSharper disable CppCStyleCast
30#ifndef FULLMATCUDA_HPP
31#define FULLMATCUDA_HPP
32
33#include "../cuda_ptr.hpp"
34#include "FullMat.hpp"
35
36#include <cusolverDn.h>
37
38template<sp_d T> class FullMatCUDA final : public FullMat<T> {
39 cusolverDnHandle_t handle = nullptr;
40 cudaStream_t stream = nullptr;
41
42 cuda_ptr info{sizeof(int), 1}, d_ipiv{sizeof(int), static_cast<int>(this->n_rows)}, d_A{}, d_work{};
43
44 void init_config() {
45 cusolverDnCreate(&handle);
46 cudaStreamCreate(&stream);
47 cusolverDnSetStream(handle, stream);
48
49 int work_size = 0;
50 if(std::is_same_v<T, float> || Precision::MIXED == this->setting.precision) {
51 d_A = cuda_ptr(sizeof(float), static_cast<int>(this->n_elem));
52 cusolverDnSgetrf_bufferSize(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), d_A.get<float>(), d_A.size, &work_size);
53 d_work = cuda_ptr(sizeof(float), work_size);
54 }
55 else {
56 d_A = cuda_ptr(sizeof(double), static_cast<int>(this->n_elem));
57 cusolverDnDgetrf_bufferSize(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), d_A.get<double>(), d_A.size, &work_size);
58 d_work = cuda_ptr(sizeof(double), work_size);
59 }
60
61 this->factored = false;
62 }
63
64 void release() const {
65 if(handle) cusolverDnDestroy(handle);
66 if(stream) cudaStreamDestroy(stream);
67 }
68
69protected:
70 int direct_solve(Mat<T>& X, Mat<T>&& B) override { return this->direct_solve(X, B); }
71
72 int direct_solve(Mat<T>&, const Mat<T>&) override;
73
74public:
75 FullMatCUDA(const uword in_rows, const uword in_cols)
76 : FullMat<T>(in_rows, in_cols) { init_config(); }
77
79 : FullMat<T>(other) { init_config(); }
80
84
85 ~FullMatCUDA() override { release(); }
86
87 unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<FullMatCUDA>(*this); }
88};
89
90template<sp_d T> int FullMatCUDA<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
91 const auto NROW = static_cast<int>(this->n_rows), NCOL = static_cast<int>(this->n_cols);
92
93 int flag;
94
95 if constexpr(std::is_same_v<T, float>) {
96 // pure float
97 if(!this->factored) {
98 this->factored = true;
99 d_A.copy_from(this->memptr(), stream);
100 cusolverDnSgetrf(handle, NROW, NCOL, d_A.get<float>(), NROW, d_work.get<float>(), d_ipiv.get(), info.get());
101 }
102
103 const cuda_ptr d_x{sizeof(float), static_cast<int>(B.n_elem)};
104 d_x.copy_from(B.memptr(), stream);
105
106 cusolverDnSgetrs(handle, CUBLAS_OP_N, NROW, static_cast<int>(B.n_cols), d_A.get<float>(), NROW, d_ipiv.get(), d_x.get<float>(), NROW, info.get());
107
108 X.set_size(arma::size(B));
109 d_x.copy_to(X.memptr(), stream);
110 }
111 else if(Precision::MIXED == this->setting.precision) {
112 // mixed precision
113 if(!this->factored) {
114 this->factored = true;
115 this->s_memory = this->to_float();
116 d_A.copy_from(this->s_memory.memptr(), stream);
117 cusolverDnSgetrf(handle, NROW, NCOL, d_A.get<float>(), NROW, d_work.get<float>(), d_ipiv.get(), info.get());
118 }
119
120 const cuda_ptr d_x{sizeof(float), static_cast<int>(B.n_elem)};
121
122 X = arma::zeros(B.n_rows, B.n_cols);
123
124 mat full_residual = B;
125
126 std::uint8_t counter{0};
127 while(counter++ < this->setting.iterative_refinement) {
128 const auto multiplier = norm(full_residual);
129 if(multiplier < this->setting.tolerance) break;
130 suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier);
131
132 auto residual = conv_to<fmat>::from(full_residual / multiplier);
133 d_x.copy_from(residual.memptr(), stream);
134
135 cusolverDnSgetrs(handle, CUBLAS_OP_N, NROW, static_cast<int>(B.n_cols), d_A.get<float>(), NROW, d_ipiv.get(), d_x.get<float>(), NROW, info.get());
136
137 d_x.copy_to(residual.memptr(), stream);
138 full_residual = B - this->operator*(X += multiplier * conv_to<mat>::from(residual));
139 }
140 }
141 else {
142 // pure double
143 if(!this->factored) {
144 this->factored = true;
145 d_A.copy_from(this->memptr(), stream);
146 cusolverDnDgetrf(handle, NROW, NCOL, d_A.get<double>(), NROW, d_work.get<double>(), d_ipiv.get(), info.get());
147 }
148
149 const cuda_ptr d_x{sizeof(float), static_cast<int>(B.n_elem)};
150 d_x.copy_from(B.memptr(), stream);
151
152 cusolverDnDgetrs(handle, CUBLAS_OP_N, NROW, static_cast<int>(B.n_cols), d_A.get<double>(), NROW, d_ipiv.get(), d_x.get<double>(), NROW, info.get());
153
154 X.set_size(arma::size(B));
155 d_x.copy_to(X.memptr(), stream);
156 }
157
158 info.copy_to(&flag, stream);
159
160 return flag;
161}
162
163#endif
164
A FullMatCUDA class that holds matrices.
Definition FullMatCUDA.hpp:38
FullMatCUDA(const uword in_rows, const uword in_cols)
Definition FullMatCUDA.hpp:75
~FullMatCUDA() override
Definition FullMatCUDA.hpp:85
FullMatCUDA & operator=(FullMatCUDA &&)=delete
FullMatCUDA & operator=(const FullMatCUDA &)=delete
int direct_solve(Mat< T > &X, Mat< T > &&B) override
Definition FullMatCUDA.hpp:70
FullMatCUDA(FullMatCUDA &&)=delete
FullMatCUDA(const FullMatCUDA &other)
Definition FullMatCUDA.hpp:78
unique_ptr< MetaMat< T > > make_copy() override
Definition FullMatCUDA.hpp:87
A FullMat class that holds matrices.
Definition FullMat.hpp:35
const uword n_cols
Definition MetaMat.hpp:116
const uword n_rows
Definition MetaMat.hpp:115
bool factored
Definition MetaMat.hpp:76
const uword n_elem
Definition MetaMat.hpp:117
SolverSetting< T > setting
Definition MetaMat.hpp:78
Definition cuda_ptr.hpp:32
auto copy_from(const void *src, const cudaStream_t s) const
Definition cuda_ptr.hpp:67
T * get(const unsigned long long offset=0) const
Definition cuda_ptr.hpp:65
int size
Definition cuda_ptr.hpp:40
op_scale< T > operator*(const T value, const shared_ptr< MetaMat< T > > &M)
Definition operator_times.hpp:23
Precision precision
Definition SolverSetting.hpp:32
#define suanpan_debug(...)
Definition suanPan.h:374