Skip to content

Commit

Permalink
Merge pull request #2 from wgurecky/precon_gmres
Browse files Browse the repository at this point in the history
Precon gmres
  • Loading branch information
wgurecky authored Mar 13, 2024
2 parents f2c0c3f + 84089bf commit d9dc63c
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 11 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ name = "faer_gmres"
path = 'src/lib.rs'

[dependencies]
thiserror = "1.0"
assert_approx_eq = "1.1.0"
num-traits = "0.2.18"
faer = {version = "0.18.2"}
24 changes: 23 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,33 @@ Example use:
[0.0],
];

let (res_x, err, iters) = gmres(a_test.as_ref(), b.as_ref(), x0.as_ref(), 10, 1e-8).unwrap();
// the final None arg means do not apply left preconditioning
let (res_x, err, iters) = gmres(a_test.as_ref(), b.as_ref(), x0.as_ref(), 10, 1e-8, None).unwrap();
println!("Result x: {:?}", res_x);
println!("Error x: {:?}", err);
println!("Iters : {:?}", iters);

## Preconditioned GMRES:

A preconditioner can be supplied:

// continued from above...
use faer_gmres::{JacobiPreconLinOp, LinOp};
let jacobi_pre = JacobiPreconLinOp::new(a_test.as_ref());
let (res_x, err, iters) = gmres(a_test.as_ref(), b.as_ref(), x0.as_ref(), 10, 1e-8, Some(&jacobi_pre)).unwrap();

## Restarted GMRES:

A restarted GMRES routine is provided:

use faer_gmres::restarted_gmres;
let max_inner = 30;
let max_outer = 50;
let (res_x, err, iters) = restarted_gmres(
a_test.as_ref(), b.as_ref(), x0.as_ref(), max_inner, max_outer, 1e-8, None).unwrap();

This will repeatedly call the inner GMRES routine, using the previous outer iteration's solution as the inital guess for the next outer solve. The current implementation of restarted GMRES in this package can reduce the memory requirements needed, but slow convergence.

TODO
====

Expand Down
246 changes: 236 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Basic GMRES implementation from the wiki:
// https://en.wikipedia.org/wiki/Generalized_minimal_residual_method
//
// Includes restarted GMRES implementation for reduced memory requirements.
//
// Uses the Faer library for sparse matricies and sparse solver.
//
// Specifically the givens_rotation, apply_givens_rotation and part of the
Expand Down Expand Up @@ -33,6 +35,56 @@ use faer::prelude::*;
use faer::sparse::*;
use faer::mat;
use num_traits::Float;
use thiserror::Error;

#[derive(Error, Debug)]
pub struct GmresError<T>
where
T: faer::RealField + Float
{
cur_x: Mat<T>,
error: T,
tol: T,
msg: String,
}

pub trait LinOp<T>
where
T: faer::RealField + Float
{
fn apply_linop_to_vec(&self, target: MatMut<T>);
}

#[derive(Clone)]
pub struct JacobiPreconLinOp<'a, T>
where
T: faer::RealField + Float
{
m: SparseColMatRef<'a, usize, T>,
}
impl <'a, T> LinOp<T> for JacobiPreconLinOp<'a, T>
where
T: faer::RealField + Float + faer::SimpleEntity
{
fn apply_linop_to_vec(&self, mut target: MatMut<T>) {
for i in 0..self.m.nrows()
{
let v = target.read(i, 0);
target.write(i, 0,
v * (T::from(1.0).unwrap() / self.m[(i, i)] ));
}
}
}
impl <'a, T> JacobiPreconLinOp <'a, T>
where
T: faer::RealField + Float
{
pub fn new(m_in: SparseColMatRef<'a, usize, T>) -> Self {
Self {
m: m_in,
}
}
}


/// Calculate the givens rotation matrix
Expand Down Expand Up @@ -67,7 +119,20 @@ fn apply_givens_rotation<T>(h: &mut Vec<T>, cs: &mut Vec<T>, sn: &mut Vec<T>, k:
}

/// Arnoldi decomposition for sparse matrices
fn arnoldi<T>(a: SparseColMatRef<usize, T>, q: &Vec<Mat<T>>, k: usize) -> (Vec<T>, Mat<T>)
///
/// # Arguments
/// * `a`- The sparse matrix used to build the krylov subspace by forming [k, Ak, A^2k, A^3k...]
/// * `q`- Vector of all prior krylov column vecs
/// * `k`- Current iteration
/// * `m`- An optional preconditioner that is applied to the original system such that
/// the new krylov subspace built is [M^{-1}k, M^{-1}Ak, M^{-1}A^2k, ...].
/// If None, no preconditioner is applied.
fn arnoldi<'a, T>(
a: SparseColMatRef<'a, usize, T>,
q: &Vec<Mat<T>>,
k: usize,
m: Option<&dyn LinOp<T>>
) -> (Vec<T>, Mat<T>)
where
T: faer::RealField + Float
{
Expand All @@ -79,6 +144,13 @@ fn arnoldi<T>(a: SparseColMatRef<usize, T>, q: &Vec<Mat<T>>, k: usize) -> (Vec<T
let mut qv: Mat<T> = faer::Mat::zeros(q_col.nrows(), 1);
linalg::matmul::sparse_dense_matmul(
qv.as_mut(), a.as_ref(), q_col.as_ref(), None, T::from(1.0).unwrap(), faer::get_global_parallelism());

// Apply left preconditioner if supplied
match m {
Some(m) => m.apply_linop_to_vec(qv.as_mut()),
_ => {}
}

let mut h = Vec::with_capacity(k + 2);
for i in 0..=k {
let qci: MatRef<T> = q[i].as_ref();
Expand All @@ -94,18 +166,23 @@ fn arnoldi<T>(a: SparseColMatRef<usize, T>, q: &Vec<Mat<T>>, k: usize) -> (Vec<T


/// Generalized minimal residual method
pub fn gmres<T>(
a: SparseColMatRef<usize, T>,
pub fn gmres<'a, T>(
a: SparseColMatRef<'a, usize, T>,
b: MatRef<T>,
x: MatRef<T>,
max_iter: usize,
threshold: T,
) -> Result<(Mat<T>, T, usize), String>
m: Option<&dyn LinOp<T>>
) -> Result<(Mat<T>, T, usize), GmresError<T>>
where
T: faer::RealField + Float
{
// compute initial residual
let r = b - a * x.as_ref();
let mut r = b - a * x.as_ref();
match &m {
Some(m) => (&m).apply_linop_to_vec(r.as_mut()),
_ => {}
}

let b_norm = b.norm_l2();
let r_norm = r.norm_l2();
Expand All @@ -127,7 +204,7 @@ pub fn gmres<T>(

let mut k_iters = 0;
for k in 0..max_iter {
let (mut hk, qk) = arnoldi(a, &qs, k);
let (mut hk, qk) = arnoldi(a, &qs, k, m);
apply_givens_rotation(&mut hk, &mut cs, &mut sn, k);
hs.push(hk);
qs.push(qk);
Expand Down Expand Up @@ -174,8 +251,61 @@ pub fn gmres<T>(
let h_qr = h_sprs.sp_qr().unwrap();
let y = h_qr.solve(&beta.get(0..k_iters+1, 0..1));

let sol = x.as_ref() + q_sprs * y;
if error <= threshold {
Ok((x.as_ref() + q_sprs * y, error, k_iters))
Ok((sol, error, k_iters))
} else {
Err(GmresError{
cur_x: sol,
error: error,
tol: threshold,
msg: "GMRES did not converge. Error: {:?}. Threshold: {:?}".to_string()}
)
}
}

/// Restarted Generalized minimal residual method
pub fn restarted_gmres<'a, T>(
a: SparseColMatRef<'a, usize, T>,
b: MatRef<T>,
x: MatRef<T>,
max_iter_inner: usize,
max_iter_outer: usize,
threshold: T,
m: Option<&dyn LinOp<T>>
) -> Result<(Mat<T>, T, usize), String>
where
T: faer::RealField + Float
{
let mut res_x = x.to_owned();
let mut error = T::from(1e20).unwrap();
let mut tot_iters = 0;
let mut iters = 0;
for _ko in 0..max_iter_outer {
let res = gmres(
a.as_ref(), b.as_ref(), res_x.as_ref(),
max_iter_inner, threshold, m);
match res {
// done
Ok(res) => {
(res_x, error, iters) = res;
tot_iters += iters;
break;
}
// failed to converge move to next outer iter
// store current solution for next outer iter
Err(res) => {
res_x = res.cur_x;
error = res.error;
tot_iters += max_iter_inner;
}
}
if error <= threshold {
break;
}
}
if error <= threshold {
Ok((res_x, error, tot_iters))
} else {
Err(format!(
"GMRES did not converge. Error: {:?}. Threshold: {:?}",
Expand Down Expand Up @@ -217,7 +347,49 @@ mod test_faer_gmres {
[0.0],
];

let (res_x, err, iters) = gmres(a_test.as_ref(), b.as_ref(), x0.as_ref(), 10, 1e-8).unwrap();
let (res_x, err, iters) = gmres(a_test.as_ref(), b.as_ref(), x0.as_ref(), 10, 1e-8, None).unwrap();
println!("Result x: {:?}", res_x);
println!("Error x: {:?}", err);
println!("Iters : {:?}", iters);
assert!(err < 1e-4);
assert!(iters < 10);

// expect result for x to be [2,1,2/3]
assert_approx_eq!(res_x.read(0, 0), 2.0, 1e-12);
assert_approx_eq!(res_x.read(1, 0), 1.0, 1e-12);
assert_approx_eq!(res_x.read(2, 0), 2.0/3.0, 1e-12);
}

#[test]
fn test_gmres_1b() {
let a_test_triplets = vec![
(0, 0, 1.0),
(1, 1, 2.0),
(2, 2, 3.0),
];
let a_test = SparseColMat::<usize, f64>::try_new_from_triplets(
3, 3,
&a_test_triplets).unwrap();

// rhs
let b = faer::mat![
[2.0],
[2.0],
[2.0],
];

// initia sol guess
let x0 = faer::mat![
[0.0],
[0.0],
[0.0],
];

// preconditioner
let jacobi_pre = JacobiPreconLinOp::new(a_test.as_ref());

let (res_x, err, iters) = gmres(a_test.as_ref(), b.as_ref(), x0.as_ref(), 10, 1e-8,
Some(&jacobi_pre)).unwrap();
println!("Result x: {:?}", res_x);
println!("Error x: {:?}", err);
println!("Iters : {:?}", iters);
Expand Down Expand Up @@ -268,7 +440,7 @@ mod test_faer_gmres {
[0.0],
];

let (res_x, err, iters) = gmres(a_test.as_ref(), b.as_ref(), x0.as_ref(), 100, 1e-6).unwrap();
let (res_x, err, iters) = gmres(a_test.as_ref(), b.as_ref(), x0.as_ref(), 100, 1e-6, None).unwrap();
println!("Result x: {:?}", res_x);
println!("Error x: {:?}", err);
println!("Iters : {:?}", iters);
Expand Down Expand Up @@ -321,7 +493,7 @@ mod test_faer_gmres {
[0.0],
];

let (res_x, err, iters) = gmres(a_test.as_ref(), b.as_ref(), x0.as_ref(), 100, 1e-6).unwrap();
let (res_x, err, iters) = gmres(a_test.as_ref(), b.as_ref(), x0.as_ref(), 100, 1e-6, None).unwrap();
println!("Result x: {:?}", res_x);
println!("Error x: {:?}", err);
println!("Iters : {:?}", iters);
Expand All @@ -336,6 +508,60 @@ mod test_faer_gmres {
assert_approx_eq!(res_x.read(4, 0), 0.292447, 1e-4);
}


#[test]
fn test_restarted_gmres_4() {
let a: Mat<f32> = faer::mat![
[0.888641, 0.477151, 0.764081, 0.244348, 0.662542],
[0.695741, 0.991383, 0.800932, 0.089616, 0.250400],
[0.149974, 0.584978, 0.937576, 0.870798, 0.990016],
[0.429292, 0.459984, 0.056629, 0.567589, 0.048561],
[0.454428, 0.253192, 0.173598, 0.321640, 0.632031],
];

let mut a_test_triplets = vec![];
for i in 0..a.nrows() {
for j in 0..a.ncols() {
a_test_triplets.push((i, j, a.read(i, j)));
}
}
let a_test = SparseColMat::<usize, f32>::try_new_from_triplets(
5, 5,
&a_test_triplets).unwrap();

// rhs
let b: Mat<f32> = faer::mat![
[0.104594],
[0.437549],
[0.040264],
[0.298842],
[0.254451]
];

// initia sol guess
let x0: Mat<f32> = faer::mat![
[0.0],
[0.0],
[0.0],
[0.0],
[0.0],
];

let (res_x, err, iters) = restarted_gmres(
a_test.as_ref(), b.as_ref(), x0.as_ref(), 3, 30,
1e-6, None).unwrap();
println!("Result x: {:?}", res_x);
println!("Error x: {:?}", err);
println!("Iters : {:?}", iters);
assert!(err < 1e-4);
assert!(iters < 100);
assert_approx_eq!(res_x.read(0, 0), 0.037919, 1e-4);
assert_approx_eq!(res_x.read(1, 0), 0.888551, 1e-4);
assert_approx_eq!(res_x.read(2, 0), -0.657575, 1e-4);
assert_approx_eq!(res_x.read(3, 0), -0.181680, 1e-4);
assert_approx_eq!(res_x.read(4, 0), 0.292447, 1e-4);
}

#[test]
fn test_arnoldi() {
}
Expand Down

0 comments on commit d9dc63c

Please sign in to comment.