suanPan
FullMatCUDA.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2023 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef FULLMATCUDA_HPP
31 #define FULLMATCUDA_HPP
32 
33 #ifdef SUANPAN_CUDA
34 
35 #include <cuda_runtime.h>
36 #include <cusolverDn.h>
37 #include "FullMat.hpp"
38 
39 template<sp_d T> class FullMatCUDA final : public FullMat<T> {
40  cusolverDnHandle_t handle = nullptr;
41  cudaStream_t stream = nullptr;
42 
43  int* info = nullptr;
44  int* ipiv = nullptr;
45  void* d_A = nullptr;
46  void* buffer = nullptr;
47 
48  void acquire() {
49  cusolverDnCreate(&handle);
50  cudaStreamCreate(&stream);
51  cusolverDnSetStream(handle, stream);
52 
53  cudaMalloc(&info, sizeof(int));
54  cudaMemset(info, 0, sizeof(int));
55  cudaMalloc(&ipiv, sizeof(int) * this->n_rows);
56 
57  int bufferSize = 0;
58  if constexpr(std::is_same_v<T, float>) {
59  cudaMalloc(&d_A, sizeof(float) * this->n_elem);
60  cusolverDnSgetrf_bufferSize(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), (float*)d_A, static_cast<int>(this->n_elem), &bufferSize);
61  cudaMalloc(&buffer, sizeof(float) * bufferSize);
62  }
63  else if(Precision::MIXED == this->setting.precision) {
64  cudaMalloc(&d_A, sizeof(float) * this->n_elem);
65  cusolverDnSgetrf_bufferSize(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), (float*)d_A, static_cast<int>(this->n_elem), &bufferSize);
66  cudaMalloc(&buffer, sizeof(float) * bufferSize);
67  }
68  else {
69  cudaMalloc(&d_A, sizeof(double) * this->n_elem);
70  cusolverDnDgetrf_bufferSize(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), (double*)d_A, static_cast<int>(this->n_elem), &bufferSize);
71  cudaMalloc(&buffer, sizeof(double) * bufferSize);
72  }
73  }
74 
75  void release() const {
76  if(handle) cusolverDnDestroy(handle);
77  if(stream) cudaStreamDestroy(stream);
78 
79  if(info) cudaFree(info);
80  if(d_A) cudaFree(d_A);
81  if(buffer) cudaFree(buffer);
82  if(ipiv) cudaFree(ipiv);
83  }
84 
85 protected:
86  int direct_solve(Mat<T>& X, Mat<T>&& B) override { return this->direct_solve(X, B); }
87 
88  int direct_solve(Mat<T>&, const Mat<T>&) override;
89 
90 public:
91  FullMatCUDA(const uword in_rows, const uword in_cols)
92  : FullMat<T>(in_rows, in_cols) { acquire(); }
93 
94  FullMatCUDA(const FullMatCUDA& other)
95  : FullMat<T>(other) { acquire(); }
96 
97  FullMatCUDA(FullMatCUDA&&) noexcept = delete;
98  FullMatCUDA& operator=(const FullMatCUDA&) = delete;
99  FullMatCUDA& operator=(FullMatCUDA&&) noexcept = delete;
100 
101  ~FullMatCUDA() override { release(); }
102 
103  unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<FullMatCUDA>(*this); }
104 };
105 
106 template<sp_d T> int FullMatCUDA<T>::direct_solve(Mat<T>& X, const Mat<T>& B) {
107  if constexpr(std::is_same_v<T, float>) {
108  // pure float
109  if(!this->factored) {
110  this->factored = true;
111  cudaMemcpyAsync(d_A, this->memptr(), sizeof(float) * this->n_elem, cudaMemcpyHostToDevice, stream);
112  cusolverDnSgetrf(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), (float*)d_A, static_cast<int>(this->n_rows), (float*)buffer, ipiv, info);
113  }
114 
115  const size_t byte_size = sizeof(float) * B.n_elem;
116 
117  void* d_x = nullptr;
118  cudaMalloc(&d_x, byte_size);
119  cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
120  cusolverDnSgetrs(handle, CUBLAS_OP_N, static_cast<int>(this->n_rows), static_cast<int>(B.n_cols), (float*)d_A, static_cast<int>(this->n_rows), ipiv, (float*)d_x, static_cast<int>(this->n_rows), info);
121 
122  X.set_size(arma::size(B));
123 
124  cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
125 
126  cudaDeviceSynchronize();
127 
128  if(d_x) cudaFree(d_x);
129  }
130  else if(Precision::MIXED == this->setting.precision) {
131  // mixed precision
132  if(!this->factored) {
133  this->factored = true;
134  this->s_memory = this->to_float();
135  cudaMemcpyAsync(d_A, this->s_memory.memptr(), sizeof(float) * this->s_memory.n_elem, cudaMemcpyHostToDevice, stream);
136  cusolverDnSgetrf(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), (float*)d_A, static_cast<int>(this->n_rows), (float*)buffer, ipiv, info);
137  }
138 
139  const size_t byte_size = sizeof(float) * B.n_elem;
140 
141  void* d_x = nullptr;
142  cudaMalloc(&d_x, byte_size);
143 
144  X = arma::zeros(B.n_rows, B.n_cols);
145 
146  mat full_residual = B;
147 
148  auto multiplier = norm(full_residual);
149 
150  auto counter = 0u;
151  while(counter++ < this->setting.iterative_refinement) {
152  if(multiplier < this->setting.tolerance) break;
153 
154  auto residual = conv_to<fmat>::from(full_residual / multiplier);
155 
156  cudaMemcpyAsync(d_x, residual.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
157  cusolverDnSgetrs(handle, CUBLAS_OP_N, static_cast<int>(this->n_rows), static_cast<int>(B.n_cols), (float*)d_A, static_cast<int>(this->n_rows), ipiv, (float*)d_x, static_cast<int>(this->n_rows), info);
158  cudaMemcpyAsync(residual.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
159 
160  cudaDeviceSynchronize();
161 
162  const mat incre = multiplier * conv_to<mat>::from(residual);
163 
164  X += incre;
165 
166  suanpan_debug("Mixed precision algorithm multiplier: {:.5E}.\n", multiplier = arma::norm(full_residual -= this->operator*(incre)));
167  }
168 
169  if(d_x) cudaFree(d_x);
170  }
171  else {
172  // pure double
173  if(!this->factored) {
174  this->factored = true;
175  cudaMemcpyAsync(d_A, this->memptr(), sizeof(double) * this->n_elem, cudaMemcpyHostToDevice, stream);
176  cusolverDnDgetrf(handle, static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), (double*)d_A, static_cast<int>(this->n_rows), (double*)buffer, ipiv, info);
177  }
178 
179  const size_t byte_size = sizeof(double) * B.n_elem;
180 
181  void* d_x = nullptr;
182  cudaMalloc(&d_x, byte_size);
183  cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
184  cusolverDnDgetrs(handle, CUBLAS_OP_N, static_cast<int>(this->n_rows), static_cast<int>(B.n_cols), (double*)d_A, static_cast<int>(this->n_rows), ipiv, (double*)d_x, static_cast<int>(this->n_rows), info);
185 
186  X.set_size(arma::size(B));
187 
188  cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
189 
190  cudaDeviceSynchronize();
191 
192  if(d_x) cudaFree(d_x);
193  }
194 
195  return 0;
196 }
197 
198 #endif
199 
200 #endif
201 
A FullMatCUDA class that holds matrices.
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
unique_ptr< Material > make_copy(const shared_ptr< Material > &)
Definition: Material.cpp:370
void info(const std::string_view format_str, const T &... args)
Definition: suanPan.h:249
double norm(const vec &)
Definition: tensor.cpp:370
#define suanpan_debug(...)
Definition: suanPan.h:307