suanPan
IterativeSolver.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2  * Copyright (C) 2017-2022 Theodore Chang
3  *
4  * This program is free software: you can redistribute it and/or modify
5  * it under the terms of the GNU General Public License as published by
6  * the Free Software Foundation, either version 3 of the License, or
7  * (at your option) any later version.
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12  * GNU General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program. If not, see <http://www.gnu.org/licenses/>.
16  ******************************************************************************/
17 
18 #ifndef ITERATIVESOLVER_HPP
19 #define ITERATIVESOLVER_HPP
20 
21 #include <Toolbox/utility.h>
22 #include "SolverSetting.hpp"
23 
24 template<typename T, typename data_t> concept HasEvaluate = requires(T* t, const Col<data_t>& x) { { t->evaluate(x) } -> std::convertible_to<Col<data_t>> ; };
25 
26 template<sp_d data_t, HasEvaluate<data_t> System> int GMRES(const System* system, Col<data_t>& x, const Col<data_t>& b, SolverSetting<data_t>& setting) {
27  constexpr sp_d auto ZERO = data_t(0);
28  constexpr sp_d auto ONE = data_t(1);
29 
30  const auto& conditioner = setting.preconditioner;
31 
32  auto generate_rotation = [](const data_t dx, const data_t dy, data_t& cs, data_t& sn) -> void {
33  if(suanpan::approx_equal(dy, ZERO)) {
34  cs = ONE;
35  sn = ZERO;
36  }
37  else if(std::fabs(dy) > std::fabs(dx)) {
38  const data_t fraction = dx / dy;
39  sn = ONE / std::sqrt(ONE + fraction * fraction);
40  cs = fraction * sn;
41  }
42  else {
43  const data_t fraction = dy / dx;
44  cs = ONE / std::sqrt(ONE + fraction * fraction);
45  sn = fraction * cs;
46  }
47  };
48 
49  auto apply_rotation = [](data_t& dx, data_t& dy, const data_t cs, const data_t sn) -> void {
50  const data_t factor = cs * dx + sn * dy;
51  dy = cs * dy - sn * dx;
52  dx = factor;
53  };
54 
55  if(x.empty()) x = conditioner->apply(b);
56  else x.zeros(arma::size(b));
57 
58  const auto mp = setting.restart + 1;
59 
60  Mat<data_t> hessenberg(mp, setting.restart, fill::zeros);
61 
62  auto counter = 1;
63  data_t beta, residual;
64  Col<data_t> s(mp, fill::none), cs(mp, fill::none), sn(mp, fill::none), r;
65 
66  auto norm_b = arma::norm(conditioner->apply(b));
67  if(suanpan::approx_equal(norm_b, ZERO)) norm_b = ONE;
68 
69  auto stop_criterion = [&] {
70  residual = (beta = arma::norm(r = conditioner->apply(b - system->evaluate(x)))) / norm_b;
71  suanpan_debug("GMRES solver local residual: %.4e.\n", residual);
72  if(residual > setting.tolerance) return SUANPAN_FAIL;
73  setting.tolerance = residual;
74  setting.max_iteration = counter;
75  return SUANPAN_SUCCESS;
76  };
77 
78  if(SUANPAN_SUCCESS == stop_criterion()) return SUANPAN_SUCCESS;
79 
80  Mat<data_t> v(b.n_rows, mp, fill::none);
81 
82  auto update = [&](const int k) -> Col<data_t> {
83  Col<data_t> y = s.head(k + 1llu);
84 
85  for(auto i = k; i >= 0; --i) {
86  y(i) /= hessenberg(i, i);
87  y.head(i) -= hessenberg.col(i).head(i) * y(i);
88  }
89 
90  return v.head_cols(k + 1llu) * y;
91  };
92 
93  while(counter <= setting.max_iteration) {
94  v.col(0) = r / beta;
95  s.zeros();
96  s(0) = beta;
97 
98  for(auto i = 0, j = 1; i < setting.restart && counter <= setting.max_iteration; ++i, ++j, ++counter) {
99  auto w = conditioner->apply(system->evaluate(v.col(i)));
100  for(auto k = 0; k <= i; ++k) w -= (hessenberg(k, i) = arma::dot(w, v.col(k))) * v.col(k);
101  v.col(j) = w / (hessenberg(j, i) = arma::norm(w));
102 
103  for(auto k = 0; k < i; ++k) apply_rotation(hessenberg(k, i), hessenberg(k + 1llu, i), cs(k), sn(k));
104 
105  generate_rotation(hessenberg(i, i), hessenberg(j, i), cs(i), sn(i));
106  apply_rotation(hessenberg(i, i), hessenberg(j, i), cs(i), sn(i));
107  apply_rotation(s(i), s(j), cs(i), sn(i));
108 
109  residual = std::fabs(s(j)) / norm_b;
110  suanpan_debug("GMRES solver local residual: %.4e.\n", residual);
111  if(residual < setting.tolerance) {
112  x += update(i);
113  setting.tolerance = residual;
114  setting.max_iteration = counter;
115  return SUANPAN_SUCCESS;
116  }
117  }
118 
119  x += update(setting.restart - 1);
120  if(SUANPAN_SUCCESS == stop_criterion()) return SUANPAN_SUCCESS;
121  }
122 
123  setting.tolerance = residual;
124  return SUANPAN_FAIL;
125 }
126 
127 template<sp_d data_t, HasEvaluate<data_t> System> int BiCGSTAB(const System* system, Col<data_t>& x, const Col<data_t>& b, SolverSetting<data_t>& setting) {
128  constexpr sp_d auto ZERO = data_t(0);
129  constexpr sp_d auto ONE = data_t(1);
130 
131  const auto& conditioner = setting.preconditioner;
132 
133  data_t norm_b = arma::norm(b);
134  if(suanpan::approx_equal(norm_b, ZERO)) norm_b = ONE;
135 
136  if(x.empty()) x = conditioner->apply(b);
137  else x.zeros(arma::size(b));
138 
139  Col<data_t> r = b - system->evaluate(x);
140  const auto initial_r = r;
141 
142  data_t residual = arma::norm(r) / norm_b;
143  suanpan_debug("BiCGSTAB solver local residual: %.4e.\n", residual);
144  if(residual < setting.tolerance) {
145  setting.tolerance = residual;
146  setting.max_iteration = 0;
147  return 0;
148  }
149 
150  sp_d auto pre_rho = ZERO, alpha = ZERO, omega = ZERO;
151  Col<data_t> v, p;
152 
153  for(auto i = 1; i <= setting.max_iteration; ++i) {
154  const auto rho = arma::dot(initial_r, r);
155  if(suanpan::approx_equal(rho, ZERO)) {
156  setting.tolerance = residual;
157  setting.max_iteration = i;
158  return SUANPAN_FAIL;
159  }
160 
161  if(1 == i) p = r;
162  else p = r + rho / pre_rho * alpha / omega * (p - omega * v);
163 
164  const auto phat = conditioner->apply(p);
165  v = system->evaluate(phat);
166  alpha = rho / arma::dot(initial_r, v);
167  const Col<data_t> s = r - alpha * v;
168 
169  suanpan_debug("BiCGSTAB solver local residual: %.4e.\n", residual = arma::norm(s) / norm_b);
170  if(residual < setting.tolerance) {
171  x += alpha * phat;
172  setting.tolerance = residual;
173  setting.max_iteration = i;
174  return SUANPAN_SUCCESS;
175  }
176 
177  const auto shat = conditioner->apply(s);
178  const Col<data_t> t = system->evaluate(shat);
179  omega = arma::dot(t, s) / arma::dot(t, t);
180  x += alpha * phat + omega * shat;
181  r = s - omega * t;
182 
183  pre_rho = rho;
184 
185  suanpan_debug("BiCGSTAB solver local residual: %.4e.\n", residual = arma::norm(r) / norm_b);
186  if(residual < setting.tolerance) {
187  setting.tolerance = residual;
188  setting.max_iteration = i;
189  return SUANPAN_SUCCESS;
190  }
191 
192  if(suanpan::approx_equal(omega, ZERO)) {
193  setting.tolerance = residual;
194  setting.max_iteration = i;
195  return SUANPAN_FAIL;
196  }
197  }
198 
199  setting.tolerance = residual;
200  return SUANPAN_FAIL;
201 }
202 
203 #endif
int BiCGSTAB(const System *system, Col< data_t > &x, const Col< data_t > &b, SolverSetting< data_t > &setting)
Definition: IterativeSolver.hpp:127
int GMRES(const System *system, Col< data_t > &x, const Col< data_t > &b, SolverSetting< data_t > &setting)
Definition: IterativeSolver.hpp:26
concept HasEvaluate
Definition: IterativeSolver.hpp:24
std::enable_if_t<!std::numeric_limits< T >::is_integer, bool > approx_equal(T x, T y, int ulp=2)
Definition: utility.h:46
double norm(const vec &)
Definition: tensorToolbox.cpp:302
void suanpan_debug(const char *M,...)
Definition: print.cpp:64
Definition: SolverSetting.hpp:40
data_t tolerance
Definition: SolverSetting.hpp:43
int restart
Definition: SolverSetting.hpp:41
int max_iteration
Definition: SolverSetting.hpp:42
Preconditioner< data_t > * preconditioner
Definition: SolverSetting.hpp:48
Definition: TestSolver.h:6
vec evaluate(const vec &x) const
Definition: TestSolver.h:12
concept sp_d
Definition: suanPan.h:227
constexpr auto SUANPAN_SUCCESS
Definition: suanPan.h:161
constexpr auto SUANPAN_FAIL
Definition: suanPan.h:162