suanPan
BandMatCUDA.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2024 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
29 // ReSharper disable CppCStyleCast
30 #ifndef BANDMATCUDA_HPP
31 #define BANDMATCUDA_HPP
32 
33 #ifdef SUANPAN_CUDA
34 
35 #include "BandMat.hpp"
36 #include <cusolverSp.h>
37 #include <cusparse.h>
38 #include "csr_form.hpp"
39 
40 template<sp_d T> class BandMatCUDA final : public BandMat<T> {
41  cusolverSpHandle_t handle = nullptr;
42  cudaStream_t stream = nullptr;
43  cusparseMatDescr_t descr = nullptr;
44 
45  void* d_val_idx = nullptr;
46  void* d_col_idx = nullptr;
47  void* d_row_ptr = nullptr;
48 
49  triplet_form<float, int> s_mat{static_cast<int>(this->n_rows), static_cast<int>(this->n_cols), static_cast<int>(this->n_elem)};
50 
51  void acquire() {
52  cusolverSpCreate(&handle);
53  cudaStreamCreate(&stream);
54  cusolverSpSetStream(handle, stream);
55  cusparseCreateMatDescr(&descr);
56  cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
57  cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
58  cudaMalloc(&d_row_ptr, sizeof(int) * (this->n_rows + 1));
59  }
60 
61  void release() const {
62  if(handle) cusolverSpDestroy(handle);
63  if(stream) cudaStreamDestroy(stream);
64  if(descr) cusparseDestroyMatDescr(descr);
65  if(d_row_ptr) cudaFree(d_row_ptr);
66  }
67 
68  void device_alloc(csr_form<float, int>&& csr_mat) {
69  const size_t n_val = sizeof(float) * csr_mat.n_elem;
70  const size_t n_col = sizeof(int) * csr_mat.n_elem;
71 
72  cudaMalloc(&d_val_idx, n_val);
73  cudaMalloc(&d_col_idx, n_col);
74 
75  cudaMemcpyAsync(d_val_idx, csr_mat.val_mem(), n_val, cudaMemcpyHostToDevice, stream);
76  cudaMemcpyAsync(d_col_idx, csr_mat.col_mem(), n_col, cudaMemcpyHostToDevice, stream);
77  cudaMemcpyAsync(d_row_ptr, csr_mat.row_mem(), sizeof(int) * (csr_mat.n_rows + 1llu), cudaMemcpyHostToDevice, stream);
78  }
79 
80  void device_dealloc() const {
81  if(d_val_idx) cudaFree(d_val_idx);
82  if(d_col_idx) cudaFree(d_col_idx);
83  }
84 
85 protected:
87 
88  int direct_solve(Mat<T>&, Mat<T>&&) override;
89 
90 public:
91  BandMatCUDA(const uword in_size, const uword in_l, const uword in_u)
92  : BandMat<T>(in_size, in_l, in_u) { acquire(); }
93 
94  BandMatCUDA(const BandMatCUDA& other)
95  : BandMat<T>(other) { acquire(); }
96 
97  BandMatCUDA(BandMatCUDA&&) noexcept = delete;
98  BandMatCUDA& operator=(const BandMatCUDA&) = delete;
99  BandMatCUDA& operator=(BandMatCUDA&&) noexcept = delete;
100 
101  ~BandMatCUDA() override {
102  release();
103  device_dealloc();
104  }
105 
106  unique_ptr<MetaMat<T>> make_copy() override { return std::make_unique<BandMatCUDA>(*this); }
107 };
108 
109 template<sp_d T> int BandMatCUDA<T>::direct_solve(Mat<T>& X, Mat<T>&& B) {
110  suanpan_assert([&] { if(this->n_rows != this->n_cols) throw invalid_argument("requires a square matrix"); });
111 
112  if(!this->factored) {
113  this->factored = true;
114 
115  device_dealloc();
116 
117  s_mat.zeros();
118  for(auto I = 0; I < static_cast<int>(this->n_rows); ++I) for(auto J = std::max(0, I - static_cast<int>(this->u_band)); J <= std::min(static_cast<int>(this->n_rows) - 1, I + static_cast<int>(this->l_band)); ++J) s_mat.at(J, I) = static_cast<float>(this->at(J, I));
119 
120  device_alloc(csr_form<float, int>(s_mat));
121  }
122 
123  const size_t n_rhs = sizeof(float) * B.n_elem;
124 
125  void* d_b = nullptr;
126  void* d_x = nullptr;
127 
128  cudaMalloc(&d_b, n_rhs);
129  cudaMalloc(&d_x, n_rhs);
130 
131  auto INFO = this->mixed_trs(X, std::move(B), [&](fmat& residual) {
132  cudaMemcpyAsync(d_b, residual.memptr(), n_rhs, cudaMemcpyHostToDevice, stream);
133 
134  int singularity;
135 
136  auto code = 0;
137  for(auto I = 0llu; I < residual.n_elem; I += residual.n_rows) code += cusolverSpScsrlsvqr(handle, static_cast<int>(this->n_rows), static_cast<int>(this->s_mat.n_elem), descr, (float*)d_val_idx, (int*)d_row_ptr, (int*)d_col_idx, (float*)d_b + I, static_cast<float>(this->setting.tolerance), 3, (float*)d_x + I, &singularity);
138 
139  cudaMemcpyAsync(residual.memptr(), d_x, n_rhs, cudaMemcpyDeviceToHost, stream);
140 
141  cudaDeviceSynchronize();
142 
143  return code;
144  });
145 
146  if(0 != INFO)
147  suanpan_error("Error code {} received, the matrix is probably singular.\n", INFO);
148 
149  return INFO;
150 }
151 
152 #endif
153 
154 #endif
155 
A BandMatCUDA class that holds band matrices and solves them on the GPU with cuSOLVER.
A BandMat class that holds matrices.
Definition: BandMat.hpp:35
Definition: csr_form.hpp:25
unique_ptr< Material > make_copy(const shared_ptr< Material > &)
Definition: Material.cpp:370
void suanpan_assert(const std::function< void()> &F)
Definition: suanPan.h:296
#define suanpan_error(...)
Definition: suanPan.h:309