// cuSOLVER dense handle; starts null and is filled in by cusolverDnCreate.
39 cusolverDnHandle_t handle =
nullptr;
// CUDA stream that the handle is attached to via cusolverDnSetStream; starts null.
40 cudaStream_t stream =
nullptr;
// Buffers held through the project's cuda_ptr wrapper
// (presumably device allocations of element-size x count bytes — confirm against cuda_ptr):
//   info   - a single int, passed as the status output argument of getrf/getrs,
//   d_ipiv - n_rows ints, passed as the pivot array of getrf/getrs,
//   d_A    - matrix storage, default-constructed here,
//   d_work - factorization workspace, default-constructed here.
42 cuda_ptr info{
sizeof(int), 1}, d_ipiv{
sizeof(int),
static_cast<int>(this->
n_rows)}, d_A{}, d_work{};
// Create the solver handle and a stream, then route the handle's work onto that stream.
// NOTE(review): the status codes returned by these three calls are discarded.
45 cusolverDnCreate(&handle);
46 cudaStreamCreate(&stream);
47 cusolverDnSetStream(handle, stream);
52 cusolverDnSgetrf_bufferSize(handle,
static_cast<int>(this->
n_rows),
static_cast<int>(this->
n_cols), d_A.
get<
float>(), d_A.
size, &work_size);
53 d_work =
cuda_ptr(
sizeof(
float), work_size);
57 cusolverDnDgetrf_bufferSize(handle,
static_cast<int>(this->
n_rows),
static_cast<int>(this->
n_cols), d_A.
get<
double>(), d_A.
size, &work_size);
58 d_work =
cuda_ptr(
sizeof(
double), work_size);
// Destroy the solver handle and the stream, each only if it is non-null.
// NOTE(review): the method is const, so both members keep their old
// (now destroyed) values afterwards; calling it twice would destroy them twice.
64 void release()
const {
65 if(handle) cusolverDnDestroy(handle);
66 if(stream) cudaStreamDestroy(stream);
// Constructor initializer (header not shown in this excerpt): forwards
// in_rows and in_cols to the FullMat<T> base, then runs init_config().
76 :
FullMat<T>(in_rows, in_cols) { init_config(); }
// Constructor initializer (header not shown in this excerpt): passes other
// to the FullMat<T> base, then runs init_config() so this object gets its
// own handle, stream and buffers.
79 :
FullMat<T>(other) { init_config(); }
87 unique_ptr<MetaMat<T>>
unique_copy()
override {
return std::make_unique<FullMatCUDA>(*
this); }
// Matrix dimensions narrowed to int, the index type the cuSOLVER calls take.
91 const auto NROW =
static_cast<int>(this->n_rows), NCOL =
static_cast<int>(this->n_cols);
// Path compiled when T is float: factorize and solve entirely in single precision.
95 if constexpr(std::is_same_v<T, float>) {
// Mark the matrix as factored, copy its storage into d_A and LU-factorize it in place
// (pivots go to d_ipiv, status to info).
// NOTE(review): whatever guards these lines is not visible in this excerpt.
98 this->factored =
true;
99 d_A.copy_from(this->memptr(), stream);
100 cusolverDnSgetrf(handle, NROW, NCOL, d_A.get<
float>(), NROW, d_work.get<
float>(), d_ipiv.get(), info.get());
// Buffer with one float per element of B; used as the right-hand side / solution of getrs.
103 const cuda_ptr d_x{
sizeof(float),
static_cast<int>(B.n_elem)};
// Solve with the stored LU factors, no transpose, B.n_cols right-hand sides, leading dimensions NROW.
106 cusolverDnSgetrs(handle, CUBLAS_OP_N, NROW,
static_cast<int>(B.n_cols), d_A.get<
float>(), NROW, d_ipiv.get(), d_x.get<
float>(), NROW, info.get());
// Give X the shape of B and copy the solution buffer into it.
108 X.set_size(arma::size(B));
109 d_x.copy_to(X.memptr(), stream);
// Mixed-precision path: factorize a float copy of the matrix once, then refine the solution iteratively.
// On first use, mark as factored, convert the matrix to float (kept in s_memory),
// copy it into d_A and LU-factorize it in single precision.
113 if(!this->factored) {
114 this->factored =
true;
115 this->s_memory = this->to_float();
116 d_A.copy_from(this->s_memory.memptr(), stream);
117 cusolverDnSgetrf(handle, NROW, NCOL, d_A.get<
float>(), NROW, d_work.get<
float>(), d_ipiv.get(), info.get());
// Float buffer with one entry per element of B, reused for every refinement solve.
120 const cuda_ptr d_x{
sizeof(float),
static_cast<int>(B.n_elem)};
// Start from a zero solution, so the initial residual equals B.
122 X = arma::zeros(B.n_rows, B.n_cols);
124 mat full_residual = B;
// Iteration counter, compared against setting.iterative_refinement.
// NOTE(review): counter is 8 bits wide; if iterative_refinement can exceed 255 the
// counter wraps and only the tolerance check can end the loop — confirm the setting's range.
126 std::uint8_t counter{0};
127 while(counter++ < this->setting.iterative_refinement) {
// The residual norm is both the convergence measure and the scaling factor.
128 const auto multiplier = norm(full_residual);
129 if(multiplier < this->setting.tolerance)
break;
130 suanpan_debug(
"Mixed precision algorithm multiplier: {:.5E}.\n", multiplier);
// Scale the residual by its norm before converting it to float, then copy it into d_x.
132 auto residual = conv_to<fmat>::from(full_residual / multiplier);
133 d_x.copy_from(residual.memptr(), stream);
// Solve for the correction in single precision with the stored LU factors.
135 cusolverDnSgetrs(handle, CUBLAS_OP_N, NROW,
static_cast<int>(B.n_cols), d_A.get<
float>(), NROW, d_ipiv.get(), d_x.get<
float>(), NROW, info.get());
// Read the correction back, undo the scaling while accumulating it into X,
// and recompute the residual as B minus this matrix applied to the updated X.
137 d_x.copy_to(residual.memptr(), stream);
138 full_residual = B - this->
operator*(X += multiplier * conv_to<mat>::from(residual));
// Double-precision path: on first use, mark as factored, copy the matrix storage into d_A
// and LU-factorize it in place (pivots go to d_ipiv, status to info).
143 if(!this->factored) {
144 this->factored =
true;
145 d_A.copy_from(this->memptr(), stream);
146 cusolverDnDgetrf(handle, NROW, NCOL, d_A.get<
double>(), NROW, d_work.get<
double>(), d_ipiv.get(), info.get());
149 const cuda_ptr d_x{
sizeof(float),
static_cast<int>(B.n_elem)};
// Solve with the stored double-precision LU factors, no transpose, B.n_cols right-hand sides,
// leading dimensions NROW; the solution overwrites d_x.
152 cusolverDnDgetrs(handle, CUBLAS_OP_N, NROW,
static_cast<int>(B.n_cols), d_A.get<
double>(), NROW, d_ipiv.get(), d_x.get<
double>(), NROW, info.get());
// Give X the shape of B and copy the solution buffer into it.
154 X.set_size(arma::size(B));
155 d_x.copy_to(X.memptr(), stream);
// Copy the solver status integer from info into the host variable flag.
// NOTE(review): the copy is issued with a stream argument; confirm that cuda_ptr::copy_to
// (or a later call) synchronizes the stream before flag is read.
158 info.copy_to(&flag, stream);