From b734691edec378a513ab91d719ed5c1f537db3a9 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Sat, 24 Feb 2024 14:10:48 +0100 Subject: [PATCH 1/2] Remove redundant imports --- lax/src/eig.rs | 2 +- lax/src/eigh.rs | 2 +- lax/src/eigh_generalized.rs | 2 +- lax/src/opnorm.rs | 3 +-- lax/src/qr.rs | 2 +- lax/src/rcond.rs | 2 +- lax/src/solve.rs | 2 +- lax/src/solveh.rs | 2 +- lax/src/svddc.rs | 2 +- ndarray-linalg/tests/layout.rs | 1 - 10 files changed, 9 insertions(+), 11 deletions(-) diff --git a/lax/src/eig.rs b/lax/src/eig.rs index 710beb9c..8b6099b6 100644 --- a/lax/src/eig.rs +++ b/lax/src/eig.rs @@ -8,7 +8,7 @@ //! | sgeev | dgeev | cgeev | zgeev | //! -use crate::{error::*, layout::MatrixLayout, *}; +use crate::{error::*, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; diff --git a/lax/src/eigh.rs b/lax/src/eigh.rs index bb3ca500..afe58729 100644 --- a/lax/src/eigh.rs +++ b/lax/src/eigh.rs @@ -8,7 +8,7 @@ //! | ssyev | dsyev | cheev | zheev | use super::*; -use crate::{error::*, layout::MatrixLayout}; +use crate::error::*; use cauchy::*; use num_traits::{ToPrimitive, Zero}; diff --git a/lax/src/eigh_generalized.rs b/lax/src/eigh_generalized.rs index 5d4d83ca..628c9574 100644 --- a/lax/src/eigh_generalized.rs +++ b/lax/src/eigh_generalized.rs @@ -9,7 +9,7 @@ //! use super::*; -use crate::{error::*, layout::MatrixLayout}; +use crate::error::*; use cauchy::*; use num_traits::{ToPrimitive, Zero}; diff --git a/lax/src/opnorm.rs b/lax/src/opnorm.rs index 1789f385..92fe892d 100644 --- a/lax/src/opnorm.rs +++ b/lax/src/opnorm.rs @@ -1,7 +1,6 @@ //! Operator norm -use super::{AsPtr, NormType}; -use crate::{layout::MatrixLayout, *}; +use crate::*; use cauchy::*; pub struct OperatorNormWork { diff --git a/lax/src/qr.rs b/lax/src/qr.rs index f37bd579..99af8140 100644 --- a/lax/src/qr.rs +++ b/lax/src/qr.rs @@ -1,6 +1,6 @@ //! QR decomposition -use crate::{error::*, layout::MatrixLayout, *}; +use crate::{error::*, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; diff --git a/lax/src/rcond.rs b/lax/src/rcond.rs index 4d4a4c92..ca82022e 100644 --- a/lax/src/rcond.rs +++ b/lax/src/rcond.rs @@ -1,6 +1,6 @@ //! Reciprocal conditional number -use crate::{error::*, layout::MatrixLayout, *}; +use crate::{error::*, *}; use cauchy::*; use num_traits::Zero; diff --git a/lax/src/solve.rs b/lax/src/solve.rs index 63f69983..bdf3ab4b 100644 --- a/lax/src/solve.rs +++ b/lax/src/solve.rs @@ -1,6 +1,6 @@ //! Solve linear equations using LU-decomposition -use crate::{error::*, layout::MatrixLayout, *}; +use crate::{error::*, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index abb75cb8..f22904ce 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -3,7 +3,7 @@ //! [BK]: https://doi.org/10.2307/2005787 //! -use crate::{error::*, layout::MatrixLayout, *}; +use crate::{error::*, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; diff --git a/lax/src/svddc.rs b/lax/src/svddc.rs index c16db4bb..ae775630 100644 --- a/lax/src/svddc.rs +++ b/lax/src/svddc.rs @@ -8,7 +8,7 @@ //! | sgesdd | dgesdd | cgesdd | zgesdd | //! -use crate::{error::*, layout::MatrixLayout, *}; +use crate::{error::*, *}; use cauchy::*; use num_traits::{ToPrimitive, Zero}; diff --git a/ndarray-linalg/tests/layout.rs b/ndarray-linalg/tests/layout.rs index 2c253f5e..623ce2ee 100644 --- a/ndarray-linalg/tests/layout.rs +++ b/ndarray-linalg/tests/layout.rs @@ -1,5 +1,4 @@ use ndarray::*; -use ndarray_linalg::layout::MatrixLayout; use ndarray_linalg::*; #[test] From c22273c815a29bdb1ef9314885530046861bd267 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Sun, 25 Feb 2024 00:47:16 +0100 Subject: [PATCH 2/2] Implement tridiagonal solve for Array1 right hand sides --- ndarray-linalg/src/tridiagonal.rs | 674 +++++++++++----------------- ndarray-linalg/tests/tridiagonal.rs | 29 ++ 2 files changed, 294 insertions(+), 409 deletions(-) diff --git a/ndarray-linalg/src/tridiagonal.rs b/ndarray-linalg/src/tridiagonal.rs index b5aebbf2..4a8ea27d 100644 --- a/ndarray-linalg/src/tridiagonal.rs +++ b/ndarray-linalg/src/tridiagonal.rs @@ -109,437 +109,293 @@ pub trait SolveTridiagonalInplace { ) -> Result<&'a mut ArrayBase>; } -impl SolveTridiagonal for LUFactorizedTridiagonal -where - A: Scalar + Lapack, -{ - fn solve_tridiagonal>(&self, b: &ArrayBase) -> Result> { - let mut b = replicate(b); - self.solve_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_tridiagonal_into>( - &self, - mut b: ArrayBase, - ) -> Result> { - self.solve_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_t_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let mut b = replicate(b); - self.solve_t_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_t_tridiagonal_into>( - &self, - mut b: ArrayBase, - ) -> Result> { - self.solve_t_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_h_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let mut b = replicate(b); - self.solve_h_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_h_tridiagonal_into>( - &self, - mut b: ArrayBase, - ) -> Result> { - self.solve_h_tridiagonal_inplace(&mut b)?; - Ok(b) - } -} - -impl SolveTridiagonal for Tridiagonal -where - A: Scalar + Lapack, -{ - fn solve_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let mut b = replicate(b); - self.solve_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_tridiagonal_into>( - &self, - mut b: ArrayBase, - ) -> Result> { - self.solve_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_t_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let mut b = replicate(b); - self.solve_t_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_t_tridiagonal_into>( - &self, - mut b: ArrayBase, - ) -> Result> { - self.solve_t_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_h_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let mut b = replicate(b); - self.solve_h_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_h_tridiagonal_into>( - &self, - mut b: ArrayBase, - ) -> Result> { - self.solve_h_tridiagonal_inplace(&mut b)?; - Ok(b) - } -} - -impl SolveTridiagonal for ArrayBase -where - A: Scalar + Lapack, - S: Data, -{ - fn solve_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let mut b = replicate(b); - self.solve_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_tridiagonal_into>( - &self, - mut b: ArrayBase, - ) -> Result> { - self.solve_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_t_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let mut b = replicate(b); - self.solve_t_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_t_tridiagonal_into>( - &self, - mut b: ArrayBase, - ) -> Result> { - self.solve_t_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_h_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let mut b = replicate(b); - self.solve_h_tridiagonal_inplace(&mut b)?; - Ok(b) - } - fn solve_h_tridiagonal_into>( - &self, - mut b: ArrayBase, - ) -> Result> { - self.solve_h_tridiagonal_inplace(&mut b)?; - Ok(b) - } -} - -impl SolveTridiagonalInplace for LUFactorizedTridiagonal -where - A: Scalar + Lapack, -{ - fn solve_tridiagonal_inplace<'a, Sb>( - &self, - rhs: &'a mut ArrayBase, - ) -> Result<&'a mut ArrayBase> - where - Sb: DataMut, - { - A::solve_tridiagonal( - self, - rhs.layout()?, - Transpose::No, - rhs.as_slice_mut().unwrap(), - )?; - Ok(rhs) - } - fn solve_t_tridiagonal_inplace<'a, Sb>( - &self, - rhs: &'a mut ArrayBase, - ) -> Result<&'a mut ArrayBase> - where - Sb: DataMut, - { - A::solve_tridiagonal( - self, - rhs.layout()?, - Transpose::Transpose, - rhs.as_slice_mut().unwrap(), - )?; - Ok(rhs) - } - fn solve_h_tridiagonal_inplace<'a, Sb>( - &self, - rhs: &'a mut ArrayBase, - ) -> Result<&'a mut ArrayBase> +macro_rules! impl_traits { ($dim: ident, $layout: ident) => { + impl SolveTridiagonal for LUFactorizedTridiagonal where - Sb: DataMut, + A: Scalar + Lapack, { - A::solve_tridiagonal( - self, - rhs.layout()?, - Transpose::Hermite, - rhs.as_slice_mut().unwrap(), - )?; - Ok(rhs) + fn solve_tridiagonal>( + &self, + b: &ArrayBase + ) -> Result> { + let mut b = replicate(b); + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } } -} -impl SolveTridiagonalInplace for Tridiagonal -where - A: Scalar + Lapack, -{ - fn solve_tridiagonal_inplace<'a, Sb>( - &self, - rhs: &'a mut ArrayBase, - ) -> Result<&'a mut ArrayBase> - where - Sb: DataMut, - { - let f = self.factorize_tridiagonal()?; - f.solve_tridiagonal_inplace(rhs) - } - fn solve_t_tridiagonal_inplace<'a, Sb>( - &self, - rhs: &'a mut ArrayBase, - ) -> Result<&'a mut ArrayBase> + impl SolveTridiagonal for Tridiagonal where - Sb: DataMut, + A: Scalar + Lapack, { - let f = self.factorize_tridiagonal()?; - f.solve_t_tridiagonal_inplace(rhs) + fn solve_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } } - fn solve_h_tridiagonal_inplace<'a, Sb>( - &self, - rhs: &'a mut ArrayBase, - ) -> Result<&'a mut ArrayBase> + + impl SolveTridiagonal for ArrayBase where - Sb: DataMut, + A: Scalar + Lapack, + S: Data, { - let f = self.factorize_tridiagonal()?; - f.solve_h_tridiagonal_inplace(rhs) + fn solve_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_t_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_t_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal>( + &self, + b: &ArrayBase, + ) -> Result> { + let mut b = replicate(b); + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } + fn solve_h_tridiagonal_into>( + &self, + mut b: ArrayBase, + ) -> Result> { + self.solve_h_tridiagonal_inplace(&mut b)?; + Ok(b) + } } -} -impl SolveTridiagonalInplace for ArrayBase -where - A: Scalar + Lapack, - S: Data, -{ - fn solve_tridiagonal_inplace<'a, Sb>( - &self, - rhs: &'a mut ArrayBase, - ) -> Result<&'a mut ArrayBase> + impl SolveTridiagonalInplace for LUFactorizedTridiagonal where - Sb: DataMut, + A: Scalar + Lapack, { - let f = self.factorize_tridiagonal()?; - f.solve_tridiagonal_inplace(rhs) + fn solve_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + A::solve_tridiagonal( + self, + $layout!(rhs), + Transpose::No, + rhs.as_slice_mut().unwrap(), + )?; + Ok(rhs) + } + fn solve_t_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + A::solve_tridiagonal( + self, + $layout!(rhs), + Transpose::Transpose, + rhs.as_slice_mut().unwrap(), + )?; + Ok(rhs) + } + fn solve_h_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + A::solve_tridiagonal( + self, + $layout!(rhs), + Transpose::Hermite, + rhs.as_slice_mut().unwrap(), + )?; + Ok(rhs) + } } - fn solve_t_tridiagonal_inplace<'a, Sb>( - &self, - rhs: &'a mut ArrayBase, - ) -> Result<&'a mut ArrayBase> + + impl SolveTridiagonalInplace for Tridiagonal where - Sb: DataMut, + A: Scalar + Lapack, { - let f = self.factorize_tridiagonal()?; - f.solve_t_tridiagonal_inplace(rhs) + fn solve_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_tridiagonal_inplace(rhs) + } + fn solve_t_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_t_tridiagonal_inplace(rhs) + } + fn solve_h_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_h_tridiagonal_inplace(rhs) + } } - fn solve_h_tridiagonal_inplace<'a, Sb>( - &self, - rhs: &'a mut ArrayBase, - ) -> Result<&'a mut ArrayBase> + + impl SolveTridiagonalInplace for ArrayBase where - Sb: DataMut, + A: Scalar + Lapack, + S: Data, { - let f = self.factorize_tridiagonal()?; - f.solve_h_tridiagonal_inplace(rhs) - } -} - -impl SolveTridiagonal for LUFactorizedTridiagonal -where - A: Scalar + Lapack, -{ - fn solve_tridiagonal>(&self, b: &ArrayBase) -> Result> { - let b = b.to_owned(); - self.solve_tridiagonal_into(b) - } - fn solve_tridiagonal_into>( - &self, - b: ArrayBase, - ) -> Result> { - let b = into_col(b); - let b = self.solve_tridiagonal_into(b)?; - Ok(flatten(b)) - } - fn solve_t_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let b = b.to_owned(); - self.solve_t_tridiagonal_into(b) - } - fn solve_t_tridiagonal_into>( - &self, - b: ArrayBase, - ) -> Result> { - let b = into_col(b); - let b = self.solve_t_tridiagonal_into(b)?; - Ok(flatten(b)) - } - fn solve_h_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let b = b.to_owned(); - self.solve_h_tridiagonal_into(b) - } - fn solve_h_tridiagonal_into>( - &self, - b: ArrayBase, - ) -> Result> { - let b = into_col(b); - let b = self.solve_h_tridiagonal_into(b)?; - Ok(flatten(b)) + fn solve_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_tridiagonal_inplace(rhs) + } + fn solve_t_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_t_tridiagonal_inplace(rhs) + } + fn solve_h_tridiagonal_inplace<'a, Sb>( + &self, + rhs: &'a mut ArrayBase, + ) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorize_tridiagonal()?; + f.solve_h_tridiagonal_inplace(rhs) + } } -} +}} -impl SolveTridiagonal for Tridiagonal -where - A: Scalar + Lapack, -{ - fn solve_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let b = b.to_owned(); - self.solve_tridiagonal_into(b) - } - fn solve_tridiagonal_into>( - &self, - b: ArrayBase, - ) -> Result> { - let b = into_col(b); - let f = self.factorize_tridiagonal()?; - let b = f.solve_tridiagonal_into(b)?; - Ok(flatten(b)) - } - fn solve_t_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let b = b.to_owned(); - self.solve_t_tridiagonal_into(b) - } - fn solve_t_tridiagonal_into>( - &self, - b: ArrayBase, - ) -> Result> { - let b = into_col(b); - let f = self.factorize_tridiagonal()?; - let b = f.solve_t_tridiagonal_into(b)?; - Ok(flatten(b)) - } - fn solve_h_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let b = b.to_owned(); - self.solve_h_tridiagonal_into(b) - } - fn solve_h_tridiagonal_into>( - &self, - b: ArrayBase, - ) -> Result> { - let b = into_col(b); - let f = self.factorize_tridiagonal()?; - let b = f.solve_h_tridiagonal_into(b)?; - Ok(flatten(b)) - } -} +macro_rules! layoutIx1 { ($rhs: ident) => { + MatrixLayout::C { row: $rhs.dim() as i32, lda: 1 } +}} +impl_traits!(Ix1, layoutIx1); -impl SolveTridiagonal for ArrayBase -where - A: Scalar + Lapack, - S: Data, -{ - fn solve_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let b = b.to_owned(); - self.solve_tridiagonal_into(b) - } - fn solve_tridiagonal_into>( - &self, - b: ArrayBase, - ) -> Result> { - let b = into_col(b); - let f = self.factorize_tridiagonal()?; - let b = f.solve_tridiagonal_into(b)?; - Ok(flatten(b)) - } - fn solve_t_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let b = b.to_owned(); - self.solve_t_tridiagonal_into(b) - } - fn solve_t_tridiagonal_into>( - &self, - b: ArrayBase, - ) -> Result> { - let b = into_col(b); - let f = self.factorize_tridiagonal()?; - let b = f.solve_t_tridiagonal_into(b)?; - Ok(flatten(b)) - } - fn solve_h_tridiagonal>( - &self, - b: &ArrayBase, - ) -> Result> { - let b = b.to_owned(); - self.solve_h_tridiagonal_into(b) - } - fn solve_h_tridiagonal_into>( - &self, - b: ArrayBase, - ) -> Result> { - let b = into_col(b); - let f = self.factorize_tridiagonal()?; - let b = f.solve_h_tridiagonal_into(b)?; - Ok(flatten(b)) - } -} +macro_rules! layoutIx2 { ($rhs: ident) => { $rhs.layout()? }} +impl_traits!(Ix2, layoutIx2); /// An interface for computing LU factorizations of tridiagonal matrix refs. pub trait FactorizeTridiagonal { diff --git a/ndarray-linalg/tests/tridiagonal.rs b/ndarray-linalg/tests/tridiagonal.rs index 513d625b..721cd5f7 100644 --- a/ndarray-linalg/tests/tridiagonal.rs +++ b/ndarray-linalg/tests/tridiagonal.rs @@ -42,6 +42,35 @@ fn opnorm_tridiagonal() { assert_aclose!(a.opnorm_fro().unwrap(), t.opnorm_fro().unwrap(), 1e-7); } +#[test] +fn solve_tridiagonal_Ix2_Ix1_f64() { + let a: Array2 = arr2(&[ + [3.0, 2.1, 0.0, 0.0, 0.0], + [3.4, 2.3, -1.0, 0.0, 0.0], + [0.0, 3.6, -5.0, 1.9, 0.0], + [0.0, 0.0, 7.0, -0.9, 8.0], + [0.0, 0.0, 0.0, -6.0, 7.1], + ]); + let mut b: Array1 = arr1(&[ 2.7, -0.5, 2.6, 0.6, 2.7 ]); + let x: Array1 = arr1(&[-4.0, 7.0, 3.0, -4.0, -3.0 ]); + a.solve_tridiagonal_inplace(&mut b).unwrap(); + assert_close_l2!(&x, &b, 1e-7); +} + +#[test] +fn solve_tridiagonal_Ix1_f64() { + let a: Tridiagonal = Tridiagonal { + l: MatrixLayout::C { row: 5, lda: 5 }, + du: vec![ 2.1, -1.0, 1.9, 8.0 ], + d: vec![ 3.0, 2.3, -5.0, -0.9, 7.1 ], + dl: vec![ 3.4, 3.6, 7.0, -6.0 ], + }; + let mut b: Array1 = arr1(&[ 2.7, -0.5, 2.6, 0.6, 2.7 ]); + let x: Array1 = arr1(&[-4.0, 7.0, 3.0, -4.0, -3.0 ]); + a.solve_tridiagonal_inplace(&mut b).unwrap(); + assert_close_l2!(&x, &b, 1e-7); +} + #[test] fn solve_tridiagonal_f64() { // https://www.nag-j.co.jp/lapack/dgttrs.htm