30 #ifndef FULLMATCUDA_HPP 31 #define FULLMATCUDA_HPP 35 #include <cuda_runtime.h> 36 #include <cusolverDn.h> 40 cusolverDnHandle_t handle =
nullptr;
41 cudaStream_t stream =
nullptr;
46 void* buffer =
nullptr;
61 int solve(Mat<
T>&, Mat<T>&&) override;
62 int solve(Mat<T>&, const Mat<T>&) override;
66 cusolverDnCreate(&handle);
67 cudaStreamCreate(&stream);
68 cusolverDnSetStream(handle, stream);
70 cudaMalloc(&info,
sizeof(
int));
71 cudaMemset(info, 0,
sizeof(
int));
72 cudaMalloc(&ipiv,
sizeof(
int) * this->n_rows);
74 if(
int bufferSize = 0; std::is_same_v<T, float> ||
Precision::MIXED == this->precision) {
75 cudaMalloc(&d_A,
sizeof(
float) * this->n_elem);
76 cusolverDnSgetrf_bufferSize(handle,
int(this->n_rows),
int(this->n_cols), (
float*)d_A,
int(this->n_elem), &bufferSize);
77 cudaMalloc(&buffer,
sizeof(
float) * bufferSize);
80 cudaMalloc(&d_A,
sizeof(
double) * this->n_elem);
81 cusolverDnDgetrf_bufferSize(handle,
int(this->n_rows),
int(this->n_cols), (
double*)d_A,
int(this->n_elem), &bufferSize);
82 cudaMalloc(&buffer,
sizeof(
double) * bufferSize);
87 if(handle) cusolverDnDestroy(handle);
88 if(stream) cudaStreamDestroy(stream);
90 if(info) cudaFree(info);
91 if(d_A) cudaFree(d_A);
92 if(buffer) cudaFree(buffer);
93 if(ipiv) cudaFree(ipiv);
97 :
FullMat<T>(in_rows, in_cols) { acquire(); }
100 :
FullMat<T>(other) { acquire(); }
109 if(std::is_same_v<T, float>) {
112 cudaMemcpyAsync(d_A, this->
memptr(),
sizeof(
float) * this->
n_elem, cudaMemcpyHostToDevice, stream);
113 cusolverDnSgetrf(handle,
int(this->
n_rows),
int(this->
n_cols), (
float*)d_A,
int(this->
n_rows), (
float*)buffer, ipiv, info);
118 const size_t byte_size =
sizeof(float) * B.n_elem;
121 cudaMalloc(&d_x, byte_size);
122 cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
123 cusolverDnSgetrs(handle, CUBLAS_OP_N,
int(this->
n_rows),
int(B.n_cols), (
float*)d_A,
int(this->n_rows), ipiv, (
float*)d_x,
int(this->n_rows), info);
125 X.set_size(arma::size(B));
127 cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
129 cudaDeviceSynchronize();
131 if(d_x) cudaFree(d_x);
138 cudaMemcpyAsync(d_A, this->
s_memory.memptr(),
sizeof(float) * this->
s_memory.n_elem, cudaMemcpyHostToDevice, stream);
139 cusolverDnSgetrf(handle,
int(this->
n_rows),
int(this->
n_cols), (
float*)d_A,
int(this->
n_rows), (
float*)buffer, ipiv, info);
144 const size_t byte_size =
sizeof(float) * B.n_elem;
147 cudaMalloc(&d_x, byte_size);
149 X = arma::zeros(B.n_rows, B.n_cols);
151 mat full_residual = B;
153 auto multiplier =
norm(full_residual);
159 auto residual = conv_to<fmat>::from(full_residual / multiplier);
161 cudaMemcpyAsync(d_x, residual.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
162 cusolverDnSgetrs(handle, CUBLAS_OP_N,
int(this->
n_rows),
int(B.n_cols), (
float*)d_A,
int(this->n_rows), ipiv, (
float*)d_x,
int(this->n_rows), info);
163 cudaMemcpyAsync(residual.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
165 cudaDeviceSynchronize();
167 const mat incre = multiplier * conv_to<mat>::from(residual);
171 suanpan_debug(
"mixed precision algorithm multiplier: %.5E\n", multiplier =
arma::norm(full_residual -= this->
operator*(incre)));
174 if(d_x) cudaFree(d_x);
179 cudaMemcpyAsync(d_A, this->
memptr(),
sizeof(
double) * this->
n_elem, cudaMemcpyHostToDevice, stream);
180 cusolverDnDgetrf(handle,
int(this->
n_rows),
int(this->
n_cols), (
double*)d_A,
int(this->
n_rows), (
double*)buffer, ipiv, info);
185 const size_t byte_size =
sizeof(double) * B.n_elem;
188 cudaMalloc(&d_x, byte_size);
189 cudaMemcpyAsync(d_x, B.memptr(), byte_size, cudaMemcpyHostToDevice, stream);
190 cusolverDnDgetrs(handle, CUBLAS_OP_N,
int(this->
n_rows),
int(B.n_cols), (
double*)d_A,
int(this->n_rows), ipiv, (
double*)d_x,
int(this->n_rows), info);
192 X.set_size(arma::size(B));
194 cudaMemcpyAsync(X.memptr(), d_x, byte_size, cudaMemcpyDeviceToHost, stream);
196 cudaDeviceSynchronize();
198 if(d_x) cudaFree(d_x);
double norm(const vec &)
Definition: tensorToolbox.cpp:302
podarray< float > s_memory
Definition: DenseMat.hpp:39
void suanpan_debug(const char *M,...)
Definition: print.cpp:64
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
A FullMatCUDA class that holds matrices.
unique_ptr< Material > make_copy(const shared_ptr< Material > &)
Definition: Material.cpp:359
concept sp_d
Definition: suanPan.h:231