Skip to content

Commit

Permalink
Merge pull request #6 from wgurecky/dense_orth_q
Browse files Browse the repository at this point in the history
expect orthonormal mat q to be dense
  • Loading branch information
wgurecky authored Apr 4, 2024
2 parents 2b16007 + 7754a52 commit b4667b6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 19 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ name = "faer_gmres"
path = 'src/lib.rs'

[dependencies]
thiserror = "1.0"
assert_approx_eq = "1.1.0"
num-traits = "0.2.18"
faer = {version = "0.18.2"}
49 changes: 31 additions & 18 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ use faer::prelude::*;
use faer::sparse::*;
use faer::mat;
use num_traits::Float;
use thiserror::Error;
use std::{error::Error, fmt};

#[derive(Error, Debug)]
#[derive(Debug)]
pub struct GmresError<T>
where
T: faer::RealField + Float
Expand All @@ -48,6 +48,20 @@ pub struct GmresError<T>
msg: String,
}

impl <T> Error for GmresError <T>
where
T: faer::RealField + Float
{}

impl <T> fmt::Display for GmresError<T>
where
T: faer::RealField + Float
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "GmresError")
}
}

pub trait LinOp<T>
where
T: faer::RealField + Float
Expand Down Expand Up @@ -238,31 +252,28 @@ pub fn gmres<'a, T>(
let h_sprs = SparseColMat::<usize, T>::try_new_from_triplets(
h_len, (&hs).len(), &h_triplets).unwrap();

// build full sparse Q matrix
let mut q_triplets = Vec::new();
let mut q_len = 0;
for (c, qvec) in (&qs).into_iter().enumerate() {
q_len = qvec.nrows();
for q_i in 0..q_len {
q_triplets.push((q_i, c, qvec.read(q_i, 0)));

// build full Q matrix
let mut q_out: Mat<T> = faer::Mat::zeros(qs[0].nrows(), qs.len());
for j in 0..q_out.ncols() {
for i in 0..q_out.nrows() {
q_out.write(i, j, qs[j].read(i, 0));
}
}
let q_sprs = SparseColMat::<usize, T>::try_new_from_triplets
(q_len, (&qs).len(), &q_triplets).unwrap();

// compute solution
let h_qr = h_sprs.sp_qr().unwrap();
let y = h_qr.solve(&beta.get(0..k_iters+1, 0..1));

let sol = x.as_ref() + q_sprs * y;
let sol = x.as_ref() + q_out * y;
if error <= threshold {
Ok((sol, error, k_iters))
} else {
Err(GmresError{
cur_x: sol,
error: error,
tol: threshold,
msg: "GMRES did not converge. Error: {:?}. Threshold: {:?}".to_string()}
msg: format!("GMRES did not converge. Error: {:?}. Threshold: {:?}", error, threshold)}
)
}
}
Expand All @@ -276,7 +287,7 @@ pub fn restarted_gmres<'a, T>(
max_iter_outer: usize,
threshold: T,
m: Option<&dyn LinOp<T>>
) -> Result<(Mat<T>, T, usize), String>
) -> Result<(Mat<T>, T, usize), GmresError<T>>
where
T: faer::RealField + Float
{
Expand Down Expand Up @@ -310,10 +321,12 @@ pub fn restarted_gmres<'a, T>(
if error <= threshold {
Ok((res_x, error, tot_iters))
} else {
Err(format!(
"GMRES did not converge. Error: {:?}. Threshold: {:?}",
error, threshold
))
Err(GmresError{
cur_x: res_x,
error: error,
tol: threshold,
msg: format!("GMRES did not converge. Error: {:?}. Threshold: {:?}", error, threshold)}
)
}
}

Expand Down

0 comments on commit b4667b6

Please sign in to comment.