suanPan
operator_times.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2022 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
17 
18 #ifndef OPERATOR_TIMES_HPP
19 #define OPERATOR_TIMES_HPP
20 
21 #include "FullMat.hpp"
22 #include "SymmPackMat.hpp"
23 
24 template<sp_d T> unique_ptr<MetaMat<T>> operator*(const T value, const unique_ptr<MetaMat<T>>& M) {
25  if(nullptr == M) return nullptr;
26 
27  auto N = M->make_copy();
28  N->operator*=(value);
29  return N;
30 }
31 
32 template<sp_d T> unique_ptr<MetaMat<T>> operator*(const T value, const shared_ptr<MetaMat<T>>& M) {
33  if(nullptr == M) return nullptr;
34 
35  auto N = M->make_copy();
36  N->operator*=(value);
37  return N;
38 }
39 
40 template<sp_d T> const shared_ptr<MetaMat<T>>& operator*=(const shared_ptr<MetaMat<T>>& M, const T value) {
41  M->operator*=(value);
42  return M;
43 }
44 
45 template<sp_d T> const shared_ptr<MetaMat<T>>& operator+=(const shared_ptr<MetaMat<T>>& M, const shared_ptr<MetaMat<T>>& A) {
46  M->operator+=(A);
47  return M;
48 }
49 
50 template<sp_d T> const unique_ptr<MetaMat<T>>& operator+=(const unique_ptr<MetaMat<T>>& M, const shared_ptr<MetaMat<T>>& A) {
51  M->operator+=(A);
52  return M;
53 }
54 
55 template<sp_d DT, sp_i IT> const unique_ptr<MetaMat<DT>>& operator+=(const unique_ptr<MetaMat<DT>>& M, const triplet_form<DT, IT>& A) {
56  M->operator+=(A);
57  return M;
58 }
59 
60 template<sp_d T> const shared_ptr<MetaMat<T>>& operator+=(const shared_ptr<MetaMat<T>>& M, unique_ptr<MetaMat<T>>&& A) {
61  M->operator+=(std::forward<unique_ptr<MetaMat<T>>>(A));
62  return M;
63 }
64 
65 template<sp_d T> unique_ptr<MetaMat<T>> operator+(unique_ptr<MetaMat<T>>&& A, unique_ptr<MetaMat<T>>&& B) {
66  if(nullptr == A && nullptr == B) return nullptr;
67 
68  if(nullptr != A) {
69  A->operator+=(std::forward<unique_ptr<MetaMat<T>>>(B));
70  return std::forward<unique_ptr<MetaMat<T>>>(A);
71  }
72 
73  return std::forward<unique_ptr<MetaMat<T>>>(B);
74 }
75 
76 template<sp_d T> unique_ptr<MetaMat<T>> operator+(const shared_ptr<MetaMat<T>>& A, unique_ptr<MetaMat<T>>&& B) {
77  B->operator+=(A);
78  return std::forward<unique_ptr<MetaMat<T>>>(B);
79 }
80 
81 template<sp_d T> unique_ptr<MetaMat<T>> operator+(unique_ptr<MetaMat<T>>&& B, const shared_ptr<MetaMat<T>>& A) {
82  B->operator+=(A);
83  return std::forward<unique_ptr<MetaMat<T>>>(B);
84 }
85 
86 template<sp_d T> const shared_ptr<MetaMat<T>>& operator-=(const shared_ptr<MetaMat<T>>& M, const shared_ptr<MetaMat<T>>& A) {
87  M->operator-=(A);
88  return M;
89 }
90 
91 template<sp_d T> const shared_ptr<MetaMat<T>>& operator-=(const shared_ptr<MetaMat<T>>& M, unique_ptr<MetaMat<T>>&& A) {
92  M->operator-=(std::forward<unique_ptr<MetaMat<T>>>(A));
93  return M;
94 }
95 
96 template<sp_d T> unique_ptr<MetaMat<T>> operator-(unique_ptr<MetaMat<T>>&& A, unique_ptr<MetaMat<T>>&& B) {
97  if(nullptr == A && nullptr == B) return nullptr;
98 
99  if(nullptr != A) {
100  A->operator-=(std::forward<unique_ptr<MetaMat<T>>>(B));
101  return std::forward<unique_ptr<MetaMat<T>>>(A);
102  }
103 
104  return std::forward<unique_ptr<MetaMat<T>>>(B);
105 }
106 
107 template<sp_d T> unique_ptr<MetaMat<T>> operator-(const shared_ptr<MetaMat<T>>& A, unique_ptr<MetaMat<T>>&& B) {
108  B->operator-=(A);
109  return std::forward<unique_ptr<MetaMat<T>>>(B);
110 }
111 
112 template<sp_d T> unique_ptr<MetaMat<T>> operator-(unique_ptr<MetaMat<T>>&& B, const shared_ptr<MetaMat<T>>& A) {
113  B->operator-=(A);
114  return std::forward<unique_ptr<MetaMat<T>>>(B);
115 }
116 
117 template<sp_d T> Mat<T> operator*(const shared_ptr<MetaMat<T>>& M, const Mat<T>& A) {
118  if(nullptr == M) return nullptr;
119 
120  return M->operator*(A);
121 }
122 
123 template<sp_d T> Mat<T> operator*(const unique_ptr<MetaMat<T>>& M, const Mat<T>& A) {
124  if(nullptr == M) return nullptr;
125 
126  return M->operator*(A);
127 }
128 
129 template<sp_d T> Mat<T> operator*(const Mat<T>& A, const FullMat<T>& B) {
130  Mat<T> C(A.n_rows, A.n_cols);
131 
132  constexpr auto TRAN = 'N';
133 
134  const auto M = static_cast<int>(A.n_rows);
135  const auto N = static_cast<int>(B.n_cols);
136  const auto K = static_cast<int>(A.n_cols);
137  T ALPHA = 1.;
138  const auto LDA = M;
139  const auto LDB = K;
140  T BETA = 0.;
141  const auto LDC = M;
142 
143  if(std::is_same_v<T, float>) {
144  using E = float;
145  arma_fortran(arma_sgemm)(&TRAN, &TRAN, &M, &N, &K, (E*)&ALPHA, (E*)A.memptr(), &LDA, (E*)B.memptr(), &LDB, (E*)&BETA, (E*)C.memptr(), &LDC);
146  }
147  else if(std::is_same_v<T, double>) {
148  using E = double;
149  arma_fortran(arma_dgemm)(&TRAN, &TRAN, &M, &N, &K, (E*)&ALPHA, (E*)A.memptr(), &LDA, (E*)B.memptr(), &LDB, (E*)&BETA, (E*)C.memptr(), &LDC);
150  }
151 
152  return C;
153 }
154 
155 template<char S, const char T, sp_d T1> Mat<T1> spmm(const SymmPackMat<T1>& A, const Mat<T1>& B) {
156  Mat<T1> C;
157 
158  const auto SIDE = S;
159  const auto TRAN = T;
160  constexpr auto UPLO = 'U';
161 
162  auto M = static_cast<int>(A.n_rows);
163 
164  auto PT = 0;
165  if constexpr(SIDE == 'L') PT += 1;
166  if constexpr(TRAN == 'T') PT += 10;
167 
168  int N, LDC;
169 
170  switch(PT) {
171  case 0: // A*B
172  N = static_cast<int>(B.n_cols);
173  C.set_size(M, N);
174  LDC = M;
175  break;
176  case 1: // B*A
177  N = static_cast<int>(B.n_rows);
178  C.set_size(N, M);
179  LDC = N;
180  break;
181  case 10: // A*B**T
182  N = static_cast<int>(B.n_rows);
183  C.set_size(M, N);
184  LDC = M;
185  break;
186  case 11: // B**T*A
187  N = static_cast<int>(B.n_cols);
188  C.set_size(N, M);
189  LDC = N;
190  break;
191  default:
192  break;
193  }
194 
195  T1 ALPHA = 1.;
196  const auto LDB = static_cast<int>(B.n_rows);
197  T1 BETA = 0.;
198 
199  if(std::is_same_v<T1, float>) {
200  using E = float;
201  arma_fortran(arma_sspmm)(&SIDE, &UPLO, &TRAN, &M, &N, (E*)A.memptr(), (E*)&ALPHA, (E*)B.memptr(), &LDB, (E*)&BETA, (E*)C.memptr(), &LDC);
202  }
203  else if(std::is_same_v<T1, double>) {
204  using E = double;
205  arma_fortran(arma_dspmm)(&SIDE, &UPLO, &TRAN, &M, &N, (E*)A.memptr(), (E*)&ALPHA, (E*)B.memptr(), &LDB, (E*)&BETA, (E*)C.memptr(), &LDC);
206  }
207 
208  return C;
209 }
210 
211 template<sp_d T> Mat<T> operator*(const Mat<T>& A, const SymmPackMat<T>& B) { return spmm<'L', 'N'>(B, A); }
212 
213 template<sp_d T> Mat<T> operator*(const Op<Mat<T>, op_htrans>& A, const SymmPackMat<T>& B) { return spmm<'L', 'T'>(B, A.m); }
214 
215 template<sp_d T, sp_i IT> triplet_form<T, IT> operator*(const T value, const triplet_form<T, IT>& M) {
216  auto N = M;
217  N *= value;
218  return N;
219 }
220 
221 template<sp_d T, sp_i IT> triplet_form<T, IT> operator*(const T value, triplet_form<T, IT>&& M) {
222  M *= value;
223  return M;
224 }
225 
226 #endif // OPERATOR_TIMES_HPP
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
A MetaMat class that holds matrices.
Definition: MetaMat.hpp:39
const uword n_cols
Definition: MetaMat.hpp:49
A SymmPackMat class that holds matrices.
Definition: SymmPackMat.hpp:35
Definition: triplet_form.hpp:62
const T * memptr() const override
Definition: DenseMat.hpp:149
Mat< T1 > spmm(const SymmPackMat< T1 > &A, const Mat< T1 > &B)
Definition: operator_times.hpp:155
const shared_ptr< MetaMat< T > > & operator*=(const shared_ptr< MetaMat< T >> &M, const T value)
Definition: operator_times.hpp:40
unique_ptr< MetaMat< T > > operator-(unique_ptr< MetaMat< T >> &&A, unique_ptr< MetaMat< T >> &&B)
Definition: operator_times.hpp:96
unique_ptr< MetaMat< T > > operator*(const T value, const unique_ptr< MetaMat< T >> &M)
Definition: operator_times.hpp:24
const shared_ptr< MetaMat< T > > & operator+=(const shared_ptr< MetaMat< T >> &M, const shared_ptr< MetaMat< T >> &A)
Definition: operator_times.hpp:45
const shared_ptr< MetaMat< T > > & operator-=(const shared_ptr< MetaMat< T >> &M, const shared_ptr< MetaMat< T >> &A)
Definition: operator_times.hpp:86
unique_ptr< MetaMat< T > > operator+(unique_ptr< MetaMat< T >> &&A, unique_ptr< MetaMat< T >> &&B)
Definition: operator_times.hpp:65