18 #ifndef OPERATOR_TIMES_HPP
19 #define OPERATOR_TIMES_HPP
25 if(
nullptr ==
M)
return nullptr;
27 auto N =
M->make_copy();
33 if(
nullptr ==
M)
return nullptr;
35 auto N =
M->make_copy();
40 template<sp_d T>
const shared_ptr<MetaMat<T>>&
operator*=(
const shared_ptr<
MetaMat<T>>& M,
const T value) {
61 M->operator+=(std::forward<unique_ptr<MetaMat<T>>>(
A));
66 if(
nullptr ==
A &&
nullptr == B)
return nullptr;
69 A->operator+=(std::forward<unique_ptr<MetaMat<T>>>(B));
70 return std::forward<unique_ptr<MetaMat<T>>>(
A);
73 return std::forward<unique_ptr<MetaMat<T>>>(B);
78 return std::forward<unique_ptr<MetaMat<T>>>(B);
83 return std::forward<unique_ptr<MetaMat<T>>>(B);
92 M->operator-=(std::forward<unique_ptr<MetaMat<T>>>(
A));
97 if(
nullptr ==
A &&
nullptr == B)
return nullptr;
100 A->operator-=(std::forward<unique_ptr<MetaMat<T>>>(B));
101 return std::forward<unique_ptr<MetaMat<T>>>(
A);
104 return std::forward<unique_ptr<MetaMat<T>>>(B);
109 return std::forward<unique_ptr<MetaMat<T>>>(B);
114 return std::forward<unique_ptr<MetaMat<T>>>(B);
118 if(
nullptr ==
M)
return nullptr;
120 return M->operator*(
A);
124 if(
nullptr ==
M)
return nullptr;
126 return M->operator*(
A);
130 Mat<T> C(
A.n_rows,
A.n_cols);
132 constexpr
auto TRAN =
'N';
134 const auto M =
static_cast<int>(
A.n_rows);
135 const auto N =
static_cast<int>(B.
n_cols);
136 const auto K =
static_cast<int>(
A.n_cols);
143 if(std::is_same_v<T, float>) {
145 arma_fortran(arma_sgemm)(&TRAN, &TRAN, &
M, &
N, &
K, (
E*)&ALPHA, (
E*)
A.memptr(), &LDA, (
E*)B.
memptr(), &LDB, (
E*)&BETA, (
E*)C.memptr(), &LDC);
147 else if(std::is_same_v<T, double>) {
149 arma_fortran(arma_dgemm)(&TRAN, &TRAN, &
M, &
N, &
K, (
E*)&ALPHA, (
E*)
A.memptr(), &LDA, (
E*)B.
memptr(), &LDB, (
E*)&BETA, (
E*)C.memptr(), &LDC);
160 constexpr
auto UPLO =
'U';
162 auto M =
static_cast<int>(
A.n_rows);
165 if constexpr(SIDE ==
'L') PT += 1;
166 if constexpr(TRAN ==
'T') PT += 10;
172 N =
static_cast<int>(B.n_cols);
177 N =
static_cast<int>(B.n_rows);
182 N =
static_cast<int>(B.n_rows);
187 N =
static_cast<int>(B.n_cols);
196 const auto LDB =
static_cast<int>(B.n_rows);
199 if(std::is_same_v<T1, float>) {
201 arma_fortran(arma_sspmm)(&SIDE, &UPLO, &TRAN, &
M, &
N, (
E*)
A.memptr(), (
E*)&ALPHA, (
E*)B.memptr(), &LDB, (
E*)&BETA, (
E*)C.memptr(), &LDC);
203 else if(std::is_same_v<T1, double>) {
205 arma_fortran(arma_dspmm)(&SIDE, &UPLO, &TRAN, &
M, &
N, (
E*)
A.memptr(), (
E*)&ALPHA, (
E*)B.memptr(), &LDB, (
E*)&BETA, (
E*)C.memptr(), &LDC);
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
A SymmPackMat class that holds matrices.
Definition: SymmPackMat.hpp:35
const shared_ptr< MetaMat< T > > & operator*=(const shared_ptr< MetaMat< T >> &M, const T value)
Definition: operator_times.hpp:40
unique_ptr< MetaMat< T > > operator-(unique_ptr< MetaMat< T >> &&A, unique_ptr< MetaMat< T >> &&B)
Definition: operator_times.hpp:96
unique_ptr< MetaMat< T > > operator*(const T value, const unique_ptr< MetaMat< T >> &M)
Definition: operator_times.hpp:24
const shared_ptr< MetaMat< T > > & operator+=(const shared_ptr< MetaMat< T >> &M, const shared_ptr< MetaMat< T >> &A)
Definition: operator_times.hpp:45
const shared_ptr< MetaMat< T > > & operator-=(const shared_ptr< MetaMat< T >> &M, const shared_ptr< MetaMat< T >> &A)
Definition: operator_times.hpp:86
unique_ptr< MetaMat< T > > operator+(unique_ptr< MetaMat< T >> &&A, unique_ptr< MetaMat< T >> &&B)
Definition: operator_times.hpp:65