18 #ifndef TRIPLET_FORM_HPP 19 #define TRIPLET_FORM_HPP 24 template<sp_d data_t, sp_i index_t>
class csc_form;
25 template<sp_d data_t, sp_i index_t>
class csr_form;
33 const index_t*
const row_idx;
34 const index_t*
const col_idx;
37 csr_comparator(
const index_t*
const in_row_idx,
const index_t*
const in_col_idx)
39 , col_idx(in_col_idx) {}
41 bool operator()(
const index_t idx_a,
const index_t idx_b)
const {
42 if(row_idx[idx_a] == row_idx[idx_b])
return col_idx[idx_a] < col_idx[idx_b];
43 return row_idx[idx_a] < row_idx[idx_b];
48 const index_t*
const row_idx;
49 const index_t*
const col_idx;
52 csc_comparator(
const index_t*
const in_row_idx,
const index_t*
const in_col_idx)
54 , col_idx(in_col_idx) {}
56 bool operator()(
const index_t idx_a,
const index_t idx_b)
const {
57 if(col_idx[idx_a] == col_idx[idx_b])
return row_idx[idx_a] < row_idx[idx_b];
58 return col_idx[idx_a] < col_idx[idx_b];
63 const data_t bin = data_t(0);
65 using index_ptr = std::unique_ptr<index_t[]>;
66 using data_ptr = std::unique_ptr<data_t[]>;
68 index_ptr row_idx =
nullptr;
69 index_ptr col_idx =
nullptr;
70 data_ptr val_idx =
nullptr;
72 bool csc_sorted =
false;
73 bool csr_sorted =
false;
75 template<sp_d in_dt, sp_i in_it>
void copy_to(
const std::unique_ptr<in_it[]>& new_row_idx,
const std::unique_ptr<in_it[]>& new_col_idx,
const std::unique_ptr<in_dt[]>& new_val_idx,
const index_t
begin,
const index_t row_offset,
const index_t col_offset,
const data_t scalar)
const { copy_to(new_row_idx.get(), new_col_idx.get(), new_val_idx.get(),
begin, row_offset, col_offset, scalar); }
// Copies every stored triplet into caller-provided raw arrays.
// Entry I of this container is written to slot (I + begin) of each destination:
//   - row index:    row_idx[I] + row_offset, converted to in_it
//   - column index: col_idx[I] + col_offset, converted to in_it
//   - value:        scalar * val_idx[I], converted to in_dt
// No bounds are checked here; the destinations are indexed up to the last
// position suanpan_for visits plus begin (presumably n_elem - 1 + begin,
// assuming suanpan_for iterates the half-open range [0, n_elem) — TODO confirm).
77 template<sp_d in_dt, sp_i in_it>
void copy_to(in_it*
const new_row_idx, in_it*
const new_col_idx, in_dt*
const new_val_idx,
const index_t begin,
const index_t row_offset,
const index_t col_offset,
const data_t scalar)
const {
78 suanpan_for(index_t(0), n_elem, [&](
const index_t I) {
79 new_row_idx[I +
begin] = in_it(row_idx[I] + row_offset);
80 new_col_idx[I +
begin] = in_it(col_idx[I] + col_offset);
81 new_val_idx[I +
begin] = in_dt(scalar * val_idx[I]);
// Ensures the three buffers can hold at least in_elem triplets.
// Does nothing when the requested size already fits (in_elem <= n_alloc).
88 void reserve(
const index_t in_elem) {
89 if(in_elem <= n_alloc)
return;
// New capacity is 2^(ceil(log2(in_elem)) + 1): the smallest power of two
// that is >= in_elem, then doubled.
91 access::rw(n_alloc) = index_t(std::pow(2., std::ceil(std::log2(in_elem)) + 1));
// Allocate fresh buffers of the new capacity.
93 index_ptr new_row_idx(
new index_t[n_alloc]);
94 index_ptr new_col_idx(
new index_t[n_alloc]);
95 data_ptr new_val_idx(
new data_t[n_alloc]);
// Copy the existing entries as they are: destination starts at slot 0,
// no row/column shift, values scaled by 1.
97 copy_to(new_row_idx, new_col_idx, new_val_idx, 0, 0, 0, 1);
// Adopt the new buffers; move-assigning the unique_ptr members releases the old ones.
99 row_idx = std::move(new_row_idx);
100 col_idx = std::move(new_col_idx);
101 val_idx = std::move(new_val_idx);
104 void invalidate_sorting_flag() { csc_sorted = csr_sorted =
false; }
// Declaration only; the definition is not adjacent to this line.
// Takes one unnamed bool that defaults to false.
// NOTE(review): the meaning of the flag cannot be determined from this declaration — confirm against the definition.
106 void condense(
bool =
false);
// Appends explicit zero-valued entries on the main diagonal.
// t_elem = min(n_rows, n_cols) entries (I, I) with value data_t(0) are written
// after the current n_elem entries; existing entries are left untouched, so a
// diagonal position that already has a value gains an additional zero entry.
108 void populate_diagonal() {
109 const auto t_elem = std::min(n_rows, n_cols);
// Make room for the extra entries before writing past the current end.
110 reserve(n_elem + t_elem);
111 suanpan_for(index_t(0), t_elem, [&](
const index_t I) {
112 row_idx[n_elem + I] = I;
113 col_idx[n_elem + I] = I;
114 val_idx[n_elem + I] = data_t(0);
// Publish the new element count, then drop the sorted flags because the
// appended entries were not placed in any particular order.
116 access::rw(n_elem) += t_elem;
117 invalidate_sorting_flag();
124 template<sp_d in_dt, sp_i in_it>
friend class csc_form;
125 template<sp_d in_dt, sp_i in_it>
friend class csr_form;
// Move assignment, declared noexcept; declared here without a body.
137 triplet_form& operator=(triplet_form&&) noexcept;
// Defaulted destructor; the unique_ptr buffer members release their arrays.
138 ~triplet_form() = default;
140 triplet_form(const index_t in_rows, const index_t in_cols, const index_t in_elem = index_t(0))
142 , n_cols(in_cols) {
init(in_elem); }
// Converting constructor from a SpMat<in_dt>; marked explicit so the
// conversion never happens implicitly. Declared here without a body.
// NOTE(review): SpMat is presumably the Armadillo sparse matrix type — confirm.
144 template<sp_d in_dt>
explicit triplet_form(
const SpMat<in_dt>&);
147 [[nodiscard]]
const index_t*
row_mem()
const {
return row_idx.get(); }
149 [[nodiscard]]
const index_t*
col_mem()
const {
return col_idx.get(); }
151 [[nodiscard]]
const data_t*
val_mem()
const {
return val_idx.get(); }
153 [[nodiscard]] index_t*
row_mem() {
return row_idx.get(); }
155 [[nodiscard]] index_t*
col_mem() {
return col_idx.get(); }
157 [[nodiscard]] data_t*
val_mem() {
return val_idx.get(); }
159 [[nodiscard]] index_t
row(
const index_t I)
const {
return row_idx[I]; }
161 [[nodiscard]] index_t
col(
const index_t I)
const {
return col_idx[I]; }
163 [[nodiscard]] data_t
val(
const index_t I)
const {
return val_idx[I]; }
171 [[nodiscard]] data_t
max()
const {
173 return *std::max_element(val_idx.get(), val_idx.get() +
n_elem);
177 access::rw(n_elem) = 0;
178 invalidate_sorting_flag();
181 void init(
const index_t in_elem) {
186 void init(
const index_t in_rows,
const index_t in_cols,
const index_t in_elem) {
187 access::rw(n_rows) = in_rows;
188 access::rw(n_cols) = in_cols;
193 for(index_t I = 0; I <
n_elem; ++I)
if(row == row_idx[I] && col == col_idx[I])
return val_idx[I];
194 return access::rw(bin) = 0.;
// Returns a mutable reference to a stored value addressed by two indices.
// Declared here without a body; the parameters are unnamed
// (presumably row then column — confirm against the definition).
197 data_t&
at(index_t, index_t);
// Assembles a dense Mat<data_t> into this container using a Col<uword> of indices.
// Declared here without a body.
// NOTE(review): presumably the index vector maps local rows/columns of the dense
// matrix to global positions — confirm against the definition.
226 void assemble(
const Mat<data_t>&,
const Col<uword>&);
// Assembles several shifted and scaled copies of in_mat into this container.
// For each position I, the four-argument assemble overload is called with
// row_shift[I], col_shift[I] and scalar[I]; the three vectors are expected to
// have equal length.
229 template<sp_d in_dt, sp_i in_it>
void assemble(
const triplet_form<in_dt, in_it>& in_mat,
const std::vector<index_t>& row_shift,
const std::vector<index_t>& col_shift,
const std::vector<data_t>& scalar) {
// Length check wrapped in suanpan_debug (presumably only evaluated in debug
// builds — confirm); throws invalid_argument on mismatch.
230 suanpan_debug([&] {
if(scalar.size() != row_shift.size() || scalar.size() != col_shift.size())
throw invalid_argument(
"size mismatch detected"); });
// Reserve room for all copies up front: current entries plus
// scalar.size() times the entries of in_mat.
232 reserve(n_elem + index_t(scalar.size()) * index_t(in_mat.
n_elem));
234 for(
size_t I = 0; I < scalar.size(); ++I)
assemble(in_mat, row_shift[I], col_shift[I], scalar[I]);
238 Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
240 for(index_t I = 0; I <
n_elem; ++I) out_mat(row_idx[I]) += val_idx[I] * in_mat(col_idx[I]);
246 Mat<data_t> out_mat(in_mat.n_rows, in_mat.n_cols, fill::zeros);
248 for(index_t I = 0; I <
n_elem; ++I) out_mat.row(row_idx[I]) += val_idx[I] * in_mat.row(col_idx[I]);
260 return copy += in_mat;
265 return copy -= in_mat;
279 auto last_row = row_idx[0], last_col = col_idx[0];
281 sp_i auto current_pos = index_t(0);
282 sp_d auto last_sum = data_t(0);
284 auto populate = [&] {
286 row_idx[current_pos] = last_row;
287 col_idx[current_pos] = last_col;
288 val_idx[current_pos] = last_sum;
290 last_sum = data_t(0);
293 for(index_t I = 0; I <
n_elem; ++I) {
294 if(last_row != row_idx[I] || last_col != col_idx[I]) {
296 last_row = row_idx[I];
297 last_col = col_idx[I];
299 last_sum += val_idx[I];
304 access::rw(n_elem) = current_pos;
308 : csc_sorted{in_mat.csc_sorted}
309 , csr_sorted{in_mat.csr_sorted}
312 init(in_mat.n_alloc);
313 in_mat.copy_to(row_idx, col_idx, val_idx, 0, 0, 0, 1);
314 access::rw(
n_elem) = in_mat.n_elem;
318 : row_idx{std::move(in_mat.row_idx)}
319 , col_idx{std::move(in_mat.col_idx)}
320 , val_idx{std::move(in_mat.val_idx)}
321 , csc_sorted{in_mat.csc_sorted}
322 , csr_sorted{in_mat.csr_sorted}
329 if(
this == &in_mat)
return *
this;
330 csc_sorted = in_mat.csc_sorted;
331 csr_sorted = in_mat.csr_sorted;
335 in_mat.copy_to(row_idx, col_idx, val_idx, 0, 0, 0, 1);
341 if(
this == &in_mat)
return *
this;
342 csc_sorted = in_mat.csc_sorted;
343 csr_sorted = in_mat.csr_sorted;
344 access::rw(
n_rows) = in_mat.n_rows;
345 access::rw(
n_cols) = in_mat.n_cols;
346 access::rw(
n_elem) = in_mat.n_elem;
347 access::rw(
n_alloc) = in_mat.n_alloc;
348 row_idx = std::move(in_mat.row_idx);
349 col_idx = std::move(in_mat.col_idx);
350 val_idx = std::move(in_mat.val_idx);
357 init(index_t(in_mat.n_nonzero));
359 for(
auto I = in_mat.begin(); I != in_mat.end(); ++I)
at(I.row(), I.col()) = *I;
363 : csc_sorted(in_mat.csc_sorted)
364 , csr_sorted(in_mat.csr_sorted)
374 const sp_i auto shift = index_t(base);
376 in_mat.copy_to(row_idx, col_idx, val_idx, 0, shift, shift, 1);
384 invalidate_sorting_flag();
388 return val_idx[access::rw(
n_elem)++] = data_t(0);
392 suanpan_info(
"A sparse matrix in triplet form with size of %u by %u, the density of %.3f%%.\n", static_cast<unsigned>(
n_rows), static_cast<unsigned>(
n_cols), static_cast<double>(
n_elem) / static_cast<double>(
n_rows) / static_cast<double>(
n_cols) * 1E2);
393 if(
n_elem > index_t(1000)) {
394 suanpan_info(
"Not going to print all elements as more than 1000 elements exist.\n");
397 for(index_t I = 0; I <
n_elem; ++I)
suanpan_info(
"(%3u, %3u) ===> %+.10E\n", static_cast<unsigned>(row_idx[I]), static_cast<unsigned>(col_idx[I]), val_idx[I]);
401 if(csr_sorted)
return;
403 std::vector<index_t> index(
n_elem);
404 std::iota(index.begin(), index.end(), index_t(0));
408 index_ptr new_row_idx(
new index_t[
n_alloc]);
409 index_ptr new_col_idx(
new index_t[n_alloc]);
410 data_ptr new_val_idx(
new data_t[n_alloc]);
413 new_row_idx[I] = row_idx[index[I]];
414 new_col_idx[I] = col_idx[index[I]];
415 new_val_idx[I] = val_idx[index[I]];
418 row_idx = std::move(new_row_idx);
419 col_idx = std::move(new_col_idx);
420 val_idx = std::move(new_val_idx);
427 if(csc_sorted)
return;
429 std::vector<index_t> index(
n_elem);
430 std::iota(index.begin(), index.end(), index_t(0));
434 index_ptr new_row_idx(
new index_t[
n_alloc]);
435 index_ptr new_col_idx(
new index_t[n_alloc]);
436 data_ptr new_val_idx(
new data_t[n_alloc]);
439 new_row_idx[I] = row_idx[index[I]];
440 new_col_idx[I] = col_idx[index[I]];
441 new_val_idx[I] = val_idx[index[I]];
444 row_idx = std::move(new_row_idx);
445 col_idx = std::move(new_col_idx);
446 val_idx = std::move(new_val_idx);
453 if(in_mat.empty())
return;
455 invalidate_sorting_flag();
457 const auto t_elem =
n_elem + index_t(in_mat.n_elem);
461 suanpan_for(static_cast<uword>(0), in_mat.n_elem, [&](
const uword I) {
462 row_idx[n_elem + I] = index_t(in_dof(I % in_dof.n_elem));
463 col_idx[n_elem + I] = index_t(in_dof(I / in_dof.n_elem));
464 val_idx[n_elem + I] = in_mat(I);
467 access::rw(
n_elem) = t_elem;
473 invalidate_sorting_flag();
479 in_mat.copy_to(row_idx, col_idx, val_idx,
n_elem, row_shift, col_shift, scalar);
481 access::rw(
n_elem) = t_elem;
494 suanpan_for_each(copy.val_idx.get(), copy.val_idx.get() + copy.
n_elem, [=](data_t& I) { I *= data_t(scalar); });
507 suanpan_for_each(copy.val_idx.get(), copy.val_idx.get() + copy.
n_elem, [=](data_t& I) { I /= data_t(scalar); });
538 invalidate_sorting_flag();
544 in_mat.copy_to(row_idx, col_idx, val_idx,
n_elem, 0, 0, 1);
546 access::rw(
n_elem) = t_elem;
554 invalidate_sorting_flag();
560 in_mat.copy_to(row_idx, col_idx, val_idx,
n_elem, 0, 0, -1);
562 access::rw(
n_elem) = t_elem;
568 auto out_mat = *
this;
570 suanpan_for(index_t(0), out_mat.n_elem, [&](
const index_t I) { out_mat.val_idx[I] *= out_mat.row(I) == out_mat.col(I); });
576 auto out_mat = *
this;
578 suanpan_for(index_t(0), out_mat.n_elem, [&](
const index_t I) { out_mat.val_idx[I] *= out_mat.row(I) < out_mat.col(I); });
584 auto out_mat = *
this;
586 suanpan_for(index_t(0), out_mat.n_elem, [&](
const index_t I) { out_mat.val_idx[I] *= out_mat.col(I) < out_mat.row(I); });
599 return std::forward<triplet_form<data_t, index_t>>(mat_a);
604 return std::forward<triplet_form<data_t, index_t>>(mat_b);
609 return std::forward<triplet_form<data_t, index_t>>(mat_a);
bool operator()(const index_t idx_a, const index_t idx_b) const
Definition: triplet_form.hpp:41
const shared_ptr< MetaMat< T > > & operator*=(const shared_ptr< MetaMat< T >> &M, const T value)
Definition: operator_times.hpp:40
concept sp_i
Definition: suanPan.h:232
std::enable_if_t<!std::numeric_limits< T >::is_integer, bool > approx_equal(T x, T y, int ulp=2)
Definition: utility.h:46
unique_ptr< MetaMat< T > > operator*(const T value, const unique_ptr< MetaMat< T >> &M)
Definition: operator_times.hpp:24
void suanpan_info(const char *M,...)
Definition: print.cpp:47
#define suanpan_for_each
Definition: suanPan.h:180
csr_comparator(const index_t *const in_row_idx, const index_t *const in_col_idx)
Definition: triplet_form.hpp:37
void suanpan_debug(const char *M,...)
Definition: print.cpp:64
Definition: triplet_form.hpp:47
const shared_ptr< MetaMat< T > > & operator-=(const shared_ptr< MetaMat< T >> &M, const shared_ptr< MetaMat< T >> &A)
Definition: operator_times.hpp:86
void suanpan_for(const IT start, const IT end, F &&FN)
Definition: utility.h:24
#define suanpan_sort
Definition: suanPan.h:179
csc_comparator(const index_t *const in_row_idx, const index_t *const in_col_idx)
Definition: triplet_form.hpp:52
Storage< T >::iterator begin(Storage< T > &S)
Definition: Storage.hpp:200
const shared_ptr< MetaMat< T > > & operator+=(const shared_ptr< MetaMat< T >> &M, const shared_ptr< MetaMat< T >> &A)
Definition: operator_times.hpp:45
Definition: triplet_form.hpp:32
concept sp_d
Definition: suanPan.h:231
bool operator()(const index_t idx_a, const index_t idx_b) const
Definition: triplet_form.hpp:56