suanPan
FullMatCUDA.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2023 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef FULLMATCUDA_HPP
31 #define FULLMATCUDA_HPP
32 
33 #ifdef SUANPAN_CUDA
34 
35 #include <cuda_runtime.h>
36 #include <cusolverDn.h>
37 #include "FullMat.hpp"
38 
39 template<sp_d T> class FullMatCUDA final : public FullMat<T> {
40  cusolverDnHandle_t handle = nullptr;
41  cudaStream_t stream = nullptr;
42 
43  int* info = nullptr;
44  int* ipiv = nullptr;
45  void* d_A = nullptr;
46  void* buffer = nullptr;
47 
48  void acquire();
49  void release() const;
50 
51 public:
52  FullMatCUDA(uword, uword);
53  FullMatCUDA(const FullMatCUDA&);
54  FullMatCUDA(FullMatCUDA&&) noexcept = delete;
55  FullMatCUDA& operator=(const FullMatCUDA&) = delete;
56  FullMatCUDA& operator=(FullMatCUDA&&) noexcept = delete;
57  ~FullMatCUDA() override;
58 
59  unique_ptr<MetaMat<T>> make_copy() override;
60 
61  int direct_solve(Mat<T>&, Mat<T>&&) override;
62  int direct_solve(Mat<T>&, const Mat<T>&) override;
63 };
64 
65 template<sp_d T> void FullMatCUDA<T>::acquire() {
66  cusolverDnCreate(&handle);
67  cudaStreamCreate(&stream);
68  cusolverDnSetStream(handle, stream);
69 
70  cudaMalloc(&info, sizeof(int));
71  cudaMemset(info, 0, sizeof(int));
72  cudaMalloc(&ipiv, sizeof(int) * this->n_rows);
73 
74  if(int bufferSize = 0; std::is_same_v<T, float> || Precision::MIXED == this->setting.precision) {
75  cudaMalloc(&d_A, sizeof(float) * this->n_elem);
76  cusolverDnSgetrf_bufferSize(handle, int(this->n_rows), int(this->n_cols), (float*)d_A, int(this->n_elem), &bufferSize);
77  cudaMalloc(&buffer, sizeof(float) * bufferSize);
78  }
79  else {
80  cudaMalloc(&d_A, sizeof(double) * this->n_elem);
81  cusolverDnDgetrf_bufferSize(handle, int(this->n_rows), int(this->n_cols), (double*)d_A, int(this->n_elem), &bufferSize);
82  cudaMalloc(&buffer, sizeof(double) * bufferSize);
83  }
84 }
85 
86 template<sp_d T> void FullMatCUDA<T>::release() const {
87  if(handle) cusolverDnDestroy(handle);
88  if(stream) cudaStreamDestroy(stream);
89 
90  if(info) cudaFree(info);
91  if(d_A) cudaFree(d_A);
92  if(buffer) cudaFree(buffer);
93  if(ipiv) cudaFree(ipiv);
94 }
95 
96 template<sp_d T> FullMatCUDA<T>::FullMatCUDA(const uword in_rows, const uword in_cols)
97  : FullMat<T>(in_rows, in_cols) { acquire(); }
98 
99 template<sp_d T> FullMatCUDA<T>::FullMatCUDA(const FullMatCUDA& other)
100  : FullMat<T>(other) { acquire(); }
101 
102 template<sp_d T> FullMatCUDA<T>::~FullMatCUDA() { release(); }
103 
104 template<sp_d T> unique_ptr<MetaMat<T>> FullMatCUDA<T>::make_copy() { return make_unique<FullMatCUDA<T>>(*this); }
105 
106 template<sp_d T> int FullMatCUDA<T>::direct_solve(Mat<T>& X, Mat<T>&& B) { return direct_solve(X, B); }
107 
108 template<sp_d T> int FullMatCUDA<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
109  if(std::is_same_v<T, float>) {
110  // pure float
111  if(!this->factored) {
112  cudaMemcpyAsync(d_A, this->memptr(), sizeof(float) * this->n_elem, cudaMemcpyHostToDevice, stream);
113  cusolverDnSgetrf(handle, int(this->n_rows), int(this->n_cols), (float*)d_A, int(this->n_rows), (float*)buffer, ipiv, info);
114 
115  this->factored = true;
116  }
117 
118  const size_t byte_size = sizeof(float) * B.n_elem;
119 
120  void* d_x = nullptr;
121  cudaMalloc(&d_x, byte_size);
122  cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
123  cusolverDnSgetrs(handle, CUBLAS_OP_N, int(this->n_rows), int(B.n_cols), (float*)d_A, int(this->n_rows), ipiv, (float*)d_x, int(this->n_rows), info);
124 
125  X.set_size(arma::size(B));
126 
127  cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
128 
129  cudaDeviceSynchronize();
130 
131  if(d_x) cudaFree(d_x);
132  }
133  else if(Precision::MIXED == this->setting.precision) {
134  // mixed precision
135  if(!this->factored) {
136  this->s_memory = this->to_float();
137 
138  cudaMemcpyAsync(d_A, this->s_memory.memptr(), sizeof(float) * this->s_memory.n_elem, cudaMemcpyHostToDevice, stream);
139  cusolverDnSgetrf(handle, int(this->n_rows), int(this->n_cols), (float*)d_A, int(this->n_rows), (float*)buffer, ipiv, info);
140 
141  this->factored = true;
142  }
143 
144  const size_t byte_size = sizeof(float) * B.n_elem;
145 
146  void* d_x = nullptr;
147  cudaMalloc(&d_x, byte_size);
148 
149  X = arma::zeros(B.n_rows, B.n_cols);
150 
151  mat full_residual = B;
152 
153  auto multiplier = norm(full_residual);
154 
155  auto counter = 0u;
156  while(counter++ < this->setting.iterative_refinement) {
157  if(multiplier < this->setting.tolerance) break;
158 
159  auto residual = conv_to<fmat>::from(full_residual / multiplier);
160 
161  cudaMemcpyAsync(d_x, residual.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
162  cusolverDnSgetrs(handle, CUBLAS_OP_N, int(this->n_rows), int(B.n_cols), (float*)d_A, int(this->n_rows), ipiv, (float*)d_x, int(this->n_rows), info);
163  cudaMemcpyAsync(residual.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
164 
165  cudaDeviceSynchronize();
166 
167  const mat incre = multiplier * conv_to<mat>::from(residual);
168 
169  X += incre;
170 
171  suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
172  }
173 
174  if(d_x) cudaFree(d_x);
175  }
176  else {
177  // pure double
178  if(!this->factored) {
179  cudaMemcpyAsync(d_A, this->memptr(), sizeof(double) * this->n_elem, cudaMemcpyHostToDevice, stream);
180  cusolverDnDgetrf(handle, int(this->n_rows), int(this->n_cols), (double*)d_A, int(this->n_rows), (double*)buffer, ipiv, info);
181 
182  this->factored = true;
183  }
184 
185  const size_t byte_size = sizeof(double) * B.n_elem;
186 
187  void* d_x = nullptr;
188  cudaMalloc(&d_x, byte_size);
189  cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
190  cusolverDnDgetrs(handle, CUBLAS_OP_N, int(this->n_rows), int(B.n_cols), (double*)d_A, int(this->n_rows), ipiv, (double*)d_x, int(this->n_rows), info);
191 
192  X.set_size(arma::size(B));
193 
194  cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
195 
196  cudaDeviceSynchronize();
197 
198  if(d_x) cudaFree(d_x);
199  }
200 
201  return 0;
202 }
203 
204 #endif
205 
206 #endif
207 
A FullMatCUDA class that holds matrices.
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
A MetaMat class that holds matrices.
Definition: MetaMat.hpp:39
unique_ptr< Material > make_copy(const shared_ptr< Material > &)
Definition: Material.cpp:357
void info(const std::string_view format_str, const T &... args)
Definition: suanPan.h:237
double norm(const vec &)
Definition: tensor.cpp:302
concept sp_d
Definition: suanPan.h:318
#define suanpan_debug(...)
Definition: suanPan.h:295