suanPan
csr_form.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2022 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
17 
18 #ifndef CSR_FORM_HPP
19 #define CSR_FORM_HPP
20 
21 #include "triplet_form.hpp"
22 
23 template<sp_d data_t, sp_i index_t> class csc_form;
24 
25 template<sp_d data_t, sp_i index_t> class csr_form final {
26  const data_t bin = data_t(0);
27 
28  using index_ptr = std::unique_ptr<index_t[]>;
29  using data_ptr = std::unique_ptr<data_t[]>;
30 
31  index_ptr row_ptr = nullptr; // index storage
32  index_ptr col_idx = nullptr; // index storage
33  data_ptr val_idx = nullptr; // value storage
34 
35  template<sp_d in_dt, sp_i in_it> void copy_to(in_it* const new_row_ptr, in_it* const new_col_idx, in_dt* const new_val_idx) const {
36  suanpan_for(index_t(0), n_rows + 1, [&](const index_t I) { new_row_ptr[I] = in_it(row_ptr[I]); });
37  suanpan_for(index_t(0), n_elem, [&](const index_t I) {
38  new_col_idx[I] = in_it(col_idx[I]);
39  new_val_idx[I] = in_dt(val_idx[I]);
40  });
41  }
42 
43  void init(const index_t in_elem) {
44  row_ptr = std::move(index_ptr(new index_t[n_rows + 1]));
45  col_idx = std::move(index_ptr(new index_t[in_elem]));
46  val_idx = std::move(data_ptr(new data_t[in_elem]));
47  }
48 
49 public:
50  const index_t n_rows = 0;
51  const index_t n_cols = 0;
52  const index_t n_elem = 0;
53 
54  csr_form() = default;
55  csr_form(const csr_form&);
56  csr_form(csr_form&&) noexcept;
57  csr_form& operator=(const csr_form&);
58  csr_form& operator=(csr_form&&) noexcept;
59  ~csr_form() = default;
60 
61  [[nodiscard]] const index_t* row_mem() const { return row_ptr.get(); }
62 
63  [[nodiscard]] const index_t* col_mem() const { return col_idx.get(); }
64 
65  [[nodiscard]] const data_t* val_mem() const { return val_idx.get(); }
66 
67  [[nodiscard]] index_t* row_mem() { return row_ptr.get(); }
68 
69  [[nodiscard]] index_t* col_mem() { return col_idx.get(); }
70 
71  [[nodiscard]] data_t* val_mem() { return val_idx.get(); }
72 
73  [[nodiscard]] data_t max() const {
74  if(0 == n_elem) return data_t(0);
75  return *std::max_element(val_idx.get(), val_idx.get() + n_elem);
76  }
77 
78  void print() const;
79 
80  template<sp_d T2> csr_form<data_t, index_t> operator*(const T2 scalar) const {
81  csr_form<data_t, index_t> copy = *this;
82  return copy *= scalar;
83  }
84 
85  template<sp_d T2> csr_form<data_t, index_t> operator/(const T2 scalar) const {
86  csr_form<data_t, index_t> copy = *this;
87  return copy /= scalar;
88  }
89 
90  template<sp_d T2> csr_form<data_t, index_t>& operator*=(const T2 scalar) {
91  suanpan_for_each(val_idx.get(), val_idx.get() + n_elem, [=](data_t& I) { I *= data_t(scalar); });
92  return *this;
93  }
94 
95  template<sp_d T2> csr_form<data_t, index_t>& operator/=(const T2 scalar) {
96  suanpan_for_each(val_idx.get(), val_idx.get() + n_elem, [=](data_t& I) { I /= data_t(scalar); });
97  return *this;
98  }
99 
100  template<sp_d in_dt, sp_i in_it> explicit csr_form(triplet_form<in_dt, in_it>&, SparseBase = SparseBase::ZERO, bool = false);
101 
102  template<sp_d in_dt, sp_i in_it> csr_form& operator=(triplet_form<in_dt, in_it>&);
103 
104  const data_t& operator()(const index_t in_row, const index_t in_col) const {
105  if(in_row < n_rows && in_col < n_cols) for(auto I = row_ptr[in_row]; I < row_ptr[in_row + 1]; ++I) if(in_col == col_idx[I]) return val_idx[I];
106  return access::rw(bin) = data_t(0);
107  }
108 
109  Mat<data_t> operator*(const Col<data_t>& in_mat) const {
110  Mat<data_t> out_mat = arma::zeros<Mat<data_t>>(in_mat.n_rows, 1);
111 
112  suanpan_for(index_t(0), n_rows, [&](const index_t I) { for(auto J = row_ptr[I]; J < row_ptr[I + 1]; ++J) out_mat(I) += val_idx[J] * in_mat(col_idx[J]); });
113 
114  return out_mat;
115  }
116 
117  Mat<data_t> operator*(const Mat<data_t>& in_mat) const {
118  Mat<data_t> out_mat = arma::zeros<Mat<data_t>>(in_mat.n_rows, in_mat.n_cols);
119 
120  suanpan_for(index_t(0), n_rows, [&](const index_t I) { for(auto J = row_ptr[I]; J < row_ptr[I + 1]; ++J) out_mat.row(I) += val_idx[J] * in_mat.row(col_idx[J]); });
121 
122  return out_mat;
123  }
124 };
125 
126 template<sp_d data_t, sp_i index_t> csr_form<data_t, index_t>::csr_form(const csr_form& in_mat)
127  : n_rows{in_mat.n_rows}
128  , n_cols{in_mat.n_cols}
129  , n_elem{in_mat.n_elem} {
130  init(n_elem);
131  in_mat.copy_to(row_ptr.get(), col_idx.get(), val_idx.get());
132 }
133 
134 template<sp_d data_t, sp_i index_t> csr_form<data_t, index_t>::csr_form(csr_form&& in_mat) noexcept
135  : row_ptr{std::move(in_mat.row_ptr)}
136  , col_idx{std::move(in_mat.col_idx)}
137  , val_idx{std::move(in_mat.val_idx)}
138  , n_rows{in_mat.n_rows}
139  , n_cols{in_mat.n_cols}
140  , n_elem{in_mat.n_elem} {}
141 
142 template<sp_d data_t, sp_i index_t> csr_form<data_t, index_t>& csr_form<data_t, index_t>::operator=(const csr_form& in_mat) {
143  if(this == &in_mat) return *this;
144  access::rw(n_rows) = in_mat.n_rows;
145  access::rw(n_cols) = in_mat.n_cols;
146  access::rw(n_elem) = in_mat.n_elem;
147  init(n_elem);
148  in_mat.copy_to(row_ptr.get(), col_idx.get(), val_idx.get());
149  return *this;
150 }
151 
152 template<sp_d data_t, sp_i index_t> csr_form<data_t, index_t>& csr_form<data_t, index_t>::operator=(csr_form&& in_mat) noexcept {
153  if(this == &in_mat) return *this;
154  access::rw(n_rows) = in_mat.n_rows;
155  access::rw(n_cols) = in_mat.n_cols;
156  access::rw(n_elem) = in_mat.n_elem;
157  row_ptr = std::move(in_mat.row_ptr);
158  col_idx = std::move(in_mat.col_idx);
159  val_idx = std::move(in_mat.val_idx);
160  return *this;
161 }
162 
163 template<sp_d data_t, sp_i index_t> void csr_form<data_t, index_t>::print() const {
164  suanpan_info("A sparse matrix in triplet form with size of %u by %u, the sparsity of %.3f%%.\n", static_cast<unsigned>(n_rows), static_cast<unsigned>(n_cols), 1E2 - static_cast<double>(n_elem) / static_cast<double>(n_rows) / static_cast<double>(n_cols) * 1E2);
165  if(n_elem > index_t(1000)) {
166  suanpan_info("more than 1000 elements exist.\n");
167  return;
168  }
169 
170  index_t c_idx = 1;
171  for(index_t I = 0; I < n_elem; ++I) {
172  if(I >= row_ptr[c_idx]) ++c_idx;
173  suanpan_info("(%3u, %3u) ===> %+.4E\n", static_cast<unsigned>(c_idx) - 1, static_cast<unsigned>(col_idx[I]), val_idx[I]);
174  }
175 }
176 
177 template<sp_d data_t, sp_i index_t> template<sp_d in_dt, sp_i in_it> csr_form<data_t, index_t>::csr_form(triplet_form<in_dt, in_it>& in_mat, const SparseBase base, const bool full)
178  : n_rows(index_t(in_mat.n_rows))
179  , n_cols(index_t(in_mat.n_cols)) {
180  if(full) in_mat.full_csr_condense();
181  else in_mat.csr_condense();
182 
183  init(access::rw(n_elem) = index_t(in_mat.n_elem));
184 
185  const sp_i auto shift = index_t(base);
186 
187  suanpan_for(in_it(0), in_mat.n_elem, [&](const in_it I) {
188  col_idx[I] = index_t(in_mat.col_idx[I]) + shift;
189  val_idx[I] = data_t(in_mat.val_idx[I]);
190  });
191 
192  in_it current_pos = 0, current_row = 0;
193 
194  while(current_pos < in_mat.n_elem)
195  if(in_mat.row_idx[current_pos] < current_row) ++current_pos;
196  else row_ptr[current_row++] = index_t(current_pos) + shift;
197 
198  row_ptr[0] = shift;
199  row_ptr[n_rows] = n_elem + shift;
200 }
201 
202 template<sp_d data_t, sp_i index_t> template<sp_d in_dt, sp_i in_it> csr_form<data_t, index_t>& csr_form<data_t, index_t>::operator=(triplet_form<in_dt, in_it>& in_mat) {
203  in_mat.csr_condense();
204 
205  access::rw(n_rows) = index_t(in_mat.n_rows);
206  access::rw(n_cols) = index_t(in_mat.n_cols);
207 
208  init(access::rw(n_elem) = index_t(in_mat.n_elem));
209 
210  suanpan_for(in_it(0), in_mat.n_elem, [&](const in_it I) {
211  col_idx[I] = index_t(in_mat.col_idx[I]);
212  val_idx[I] = data_t(in_mat.val_idx[I]);
213  });
214 
215  in_it current_pos = 0, current_row = 0;
216 
217  while(current_pos < in_mat.n_elem)
218  if(in_mat.row_idx[current_pos] < current_row) ++current_pos;
219  else row_ptr[current_row++] = index_t(current_pos);
220 
221  row_ptr[0] = index_t(0);
222  row_ptr[n_rows] = n_elem;
223 
224  return *this;
225 }
226 
227 #endif
Definition: csc_form.hpp:25
Definition: csr_form.hpp:25
csr_form< data_t, index_t > operator*(const T2 scalar) const
Definition: csr_form.hpp:80
index_t * col_mem()
Definition: csr_form.hpp:69
data_t max() const
Definition: csr_form.hpp:73
csr_form()=default
index_t * row_mem()
Definition: csr_form.hpp:67
const index_t * col_mem() const
Definition: csr_form.hpp:63
const data_t & operator()(const index_t in_row, const index_t in_col) const
Definition: csr_form.hpp:104
const index_t n_rows
Definition: csr_form.hpp:50
csr_form< data_t, index_t > & operator/=(const T2 scalar)
Definition: csr_form.hpp:95
const index_t n_cols
Definition: csr_form.hpp:51
csr_form< data_t, index_t > operator/(const T2 scalar) const
Definition: csr_form.hpp:85
const index_t * row_mem() const
Definition: csr_form.hpp:61
csr_form & operator=(const csr_form &)
Definition: csr_form.hpp:142
void print() const
Definition: csr_form.hpp:163
data_t * val_mem()
Definition: csr_form.hpp:71
const data_t * val_mem() const
Definition: csr_form.hpp:65
const index_t n_elem
Definition: csr_form.hpp:52
Mat< data_t > operator*(const Mat< data_t > &in_mat) const
Definition: csr_form.hpp:117
csr_form< data_t, index_t > & operator*=(const T2 scalar)
Definition: csr_form.hpp:90
csr_form & operator=(triplet_form< in_dt, in_it > &)
Mat< data_t > operator*(const Col< data_t > &in_mat) const
Definition: csr_form.hpp:109
Definition: triplet_form.hpp:62
const index_t n_rows
Definition: triplet_form.hpp:128
void full_csr_condense()
Definition: triplet_form.hpp:214
void csr_condense()
Definition: triplet_form.hpp:204
const index_t n_cols
Definition: triplet_form.hpp:129
const index_t n_elem
Definition: triplet_form.hpp:130
void suanpan_info(const char *M,...)
Definition: print.cpp:47
#define suanpan_for_each
Definition: suanPan.h:176
concept sp_i
Definition: suanPan.h:228
SparseBase
Definition: triplet_form.hpp:27
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:24