suanPan
operator_times.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2023 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
17 
18 #ifndef OPERATOR_TIMES_HPP
19 #define OPERATOR_TIMES_HPP
20 
#include "FullMat.hpp"
#include "SymmPackMat.hpp"

#include <stdexcept>
23 
24 template<sp_d T> unique_ptr<MetaMat<T>> operator*(const T value, const unique_ptr<MetaMat<T>>& M) {
25  if(nullptr == M) return nullptr;
26 
27  auto N = M->make_copy();
28  N->operator*=(value);
29  return N;
30 }
31 
32 template<sp_d T> unique_ptr<MetaMat<T>> operator*(const T value, const shared_ptr<MetaMat<T>>& M) {
33  if(nullptr == M) return nullptr;
34 
35  auto N = M->make_copy();
36  N->operator*=(value);
37  return N;
38 }
39 
40 template<sp_d T> const shared_ptr<MetaMat<T>>& operator*=(const shared_ptr<MetaMat<T>>& M, const T value) {
41  M->operator*=(value);
42  return M;
43 }
44 
45 template<sp_d T> const unique_ptr<MetaMat<T>>& operator*=(const unique_ptr<MetaMat<T>>& M, const T value) {
46  M->operator*=(value);
47  return M;
48 }
49 
50 template<sp_d T> const shared_ptr<MetaMat<T>>& operator+=(const shared_ptr<MetaMat<T>>& M, const shared_ptr<MetaMat<T>>& A) {
51  M->operator+=(A);
52  return M;
53 }
54 
55 template<sp_d T> const unique_ptr<MetaMat<T>>& operator+=(const unique_ptr<MetaMat<T>>& M, const shared_ptr<MetaMat<T>>& A) {
56  M->operator+=(A);
57  return M;
58 }
59 
60 template<sp_d DT, sp_i IT> const unique_ptr<MetaMat<DT>>& operator+=(const unique_ptr<MetaMat<DT>>& M, const triplet_form<DT, IT>& A) {
61  M->operator+=(A);
62  return M;
63 }
64 
65 template<sp_d T> const shared_ptr<MetaMat<T>>& operator+=(const shared_ptr<MetaMat<T>>& M, unique_ptr<MetaMat<T>>&& A) {
66  M->operator+=(std::forward<unique_ptr<MetaMat<T>>>(A));
67  return M;
68 }
69 
70 template<sp_d T> unique_ptr<MetaMat<T>> operator+(unique_ptr<MetaMat<T>>&& A, unique_ptr<MetaMat<T>>&& B) {
71  if(nullptr == A && nullptr == B) return nullptr;
72 
73  if(nullptr != A) {
74  A->operator+=(std::forward<unique_ptr<MetaMat<T>>>(B));
75  return std::forward<unique_ptr<MetaMat<T>>>(A);
76  }
77 
78  return std::forward<unique_ptr<MetaMat<T>>>(B);
79 }
80 
81 template<sp_d T> unique_ptr<MetaMat<T>> operator+(const shared_ptr<MetaMat<T>>& A, unique_ptr<MetaMat<T>>&& B) {
82  B->operator+=(A);
83  return std::forward<unique_ptr<MetaMat<T>>>(B);
84 }
85 
86 template<sp_d T> unique_ptr<MetaMat<T>> operator+(unique_ptr<MetaMat<T>>&& B, const shared_ptr<MetaMat<T>>& A) {
87  B->operator+=(A);
88  return std::forward<unique_ptr<MetaMat<T>>>(B);
89 }
90 
91 template<sp_d T> const shared_ptr<MetaMat<T>>& operator-=(const shared_ptr<MetaMat<T>>& M, const shared_ptr<MetaMat<T>>& A) {
92  M->operator-=(A);
93  return M;
94 }
95 
96 template<sp_d T> const shared_ptr<MetaMat<T>>& operator-=(const shared_ptr<MetaMat<T>>& M, unique_ptr<MetaMat<T>>&& A) {
97  M->operator-=(std::forward<unique_ptr<MetaMat<T>>>(A));
98  return M;
99 }
100 
101 template<sp_d T> unique_ptr<MetaMat<T>> operator-(unique_ptr<MetaMat<T>>&& A, unique_ptr<MetaMat<T>>&& B) {
102  if(nullptr == A && nullptr == B) return nullptr;
103 
104  if(nullptr != A) {
105  A->operator-=(std::forward<unique_ptr<MetaMat<T>>>(B));
106  return std::forward<unique_ptr<MetaMat<T>>>(A);
107  }
108 
109  return std::forward<unique_ptr<MetaMat<T>>>(B);
110 }
111 
112 template<sp_d T> unique_ptr<MetaMat<T>> operator-(const shared_ptr<MetaMat<T>>& A, unique_ptr<MetaMat<T>>&& B) {
113  B->operator-=(A);
114  return std::forward<unique_ptr<MetaMat<T>>>(B);
115 }
116 
117 template<sp_d T> unique_ptr<MetaMat<T>> operator-(unique_ptr<MetaMat<T>>&& B, const shared_ptr<MetaMat<T>>& A) {
118  B->operator-=(A);
119  return std::forward<unique_ptr<MetaMat<T>>>(B);
120 }
121 
122 template<sp_d T> Mat<T> operator*(const shared_ptr<MetaMat<T>>& M, const Mat<T>& A) {
123  if(nullptr == M) return nullptr;
124 
125  return M->operator*(A);
126 }
127 
128 template<sp_d T> Mat<T> operator*(const unique_ptr<MetaMat<T>>& M, const Mat<T>& A) {
129  if(nullptr == M) return nullptr;
130 
131  return M->operator*(A);
132 }
133 
134 template<sp_d T> Mat<T> operator*(const Mat<T>& A, const FullMat<T>& B) {
135  Mat<T> C(A.n_rows, A.n_cols);
136 
137  constexpr auto TRAN = 'N';
138 
139  const auto M = static_cast<int>(A.n_rows);
140  const auto N = static_cast<int>(B.n_cols);
141  const auto K = static_cast<int>(A.n_cols);
142  T ALPHA = 1.;
143  const auto LDA = M;
144  const auto LDB = K;
145  T BETA = 0.;
146  const auto LDC = M;
147 
148  if(std::is_same_v<T, float>) {
149  using E = float;
150  arma_fortran(arma_sgemm)(&TRAN, &TRAN, &M, &N, &K, (E*)&ALPHA, (E*)A.memptr(), &LDA, (E*)B.memptr(), &LDB, (E*)&BETA, (E*)C.memptr(), &LDC);
151  }
152  else if(std::is_same_v<T, double>) {
153  using E = double;
154  arma_fortran(arma_dgemm)(&TRAN, &TRAN, &M, &N, &K, (E*)&ALPHA, (E*)A.memptr(), &LDA, (E*)B.memptr(), &LDB, (E*)&BETA, (E*)C.memptr(), &LDC);
155  }
156 
157  return C;
158 }
159 
160 template<sp_d T, sp_i IT> triplet_form<T, IT> operator*(const T value, const triplet_form<T, IT>& M) {
161  auto N = M;
162  N *= value;
163  return N;
164 }
165 
166 template<sp_d T, sp_i IT> triplet_form<T, IT> operator*(const T value, triplet_form<T, IT>&& M) {
167  M *= value;
168  return M;
169 }
170 
171 #endif // OPERATOR_TIMES_HPP
A FullMat class that holds matrices.
Definition: FullMat.hpp:35
A MetaMat class that holds matrices.
Definition: MetaMat.hpp:39
const uword n_cols
Definition: MetaMat.hpp:49
Definition: triplet_form.hpp:62
const T * memptr() const override
Definition: DenseMat.hpp:154
const shared_ptr< MetaMat< T > > & operator*=(const shared_ptr< MetaMat< T >> &M, const T value)
Definition: operator_times.hpp:40
unique_ptr< MetaMat< T > > operator-(unique_ptr< MetaMat< T >> &&A, unique_ptr< MetaMat< T >> &&B)
Definition: operator_times.hpp:101
unique_ptr< MetaMat< T > > operator*(const T value, const unique_ptr< MetaMat< T >> &M)
Definition: operator_times.hpp:24
const shared_ptr< MetaMat< T > > & operator+=(const shared_ptr< MetaMat< T >> &M, const shared_ptr< MetaMat< T >> &A)
Definition: operator_times.hpp:50
const shared_ptr< MetaMat< T > > & operator-=(const shared_ptr< MetaMat< T >> &M, const shared_ptr< MetaMat< T >> &A)
Definition: operator_times.hpp:91
unique_ptr< MetaMat< T > > operator+(unique_ptr< MetaMat< T >> &&A, unique_ptr< MetaMat< T >> &&B)
Definition: operator_times.hpp:70