Skip to content

Commit 38c64c9

Browse files
committed
Merge SVDDC_ to Lapack trait
1 parent f30931d commit 38c64c9

File tree

2 files changed

+21
-128
lines changed

2 files changed

+21
-128
lines changed

lax/src/lib.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@
6666
//! -----------------------------
6767
//!
6868
//! - [svd] module for singular value decomposition (SVD) for general matrix
69-
//! - [SVDDC_] trait provides methods for singular value decomposition for general matrix
70-
//! with divided-and-conquer algorithm
69+
//! - [svddc] module for singular value decomposition (SVD) with divided-and-conquer algorithm for general matrix
7170
//! - [LeastSquaresSvdDivideConquer_] trait provides methods
7271
//! for solving least square problem by SVD
7372
//!
@@ -92,6 +91,7 @@ pub mod eigh;
9291
pub mod eigh_generalized;
9392
pub mod qr;
9493
pub mod svd;
94+
pub mod svddc;
9595

9696
mod alloc;
9797
mod cholesky;
@@ -100,7 +100,6 @@ mod opnorm;
100100
mod rcond;
101101
mod solve;
102102
mod solveh;
103-
mod svddc;
104103
mod triangular;
105104
mod tridiagonal;
106105

@@ -112,7 +111,6 @@ pub use self::rcond::*;
112111
pub use self::solve::*;
113112
pub use self::solveh::*;
114113
pub use self::svd::{SvdOwned, SvdRef};
115-
pub use self::svddc::*;
116114
pub use self::triangular::*;
117115
pub use self::tridiagonal::*;
118116

@@ -125,7 +123,6 @@ pub type Pivot = Vec<i32>;
125123
/// Trait for primitive types which implements LAPACK subroutines
126124
pub trait Lapack:
127125
OperatorNorm_
128-
+ SVDDC_
129126
+ Solve_
130127
+ Solveh_
131128
+ Cholesky_
@@ -172,6 +169,9 @@ pub trait Lapack:
172169

173170
/// Compute singular-value decomposition (SVD)
174171
fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result<SvdOwned<Self>>;
172+
173+
/// Compute singular value decomposition (SVD) with divide-and-conquer algorithm
174+
fn svddc(layout: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result<SvdOwned<Self>>;
175175
}
176176

177177
macro_rules! impl_lapack {
@@ -241,6 +241,12 @@ macro_rules! impl_lapack {
241241
let work = SvdWork::<$s>::new(l, calc_u, calc_vt)?;
242242
work.eval(a)
243243
}
244+
245+
fn svddc(layout: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result<SvdOwned<Self>> {
246+
use svddc::*;
247+
let work = SvdDcWork::<$s>::new(layout, jobz)?;
248+
work.eval(a)
249+
}
244250
}
245251
};
246252
}

lax/src/svddc.rs

Lines changed: 10 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,17 @@
1+
//! Compute singular value decomposition with divide-and-conquer algorithm
2+
//!
3+
//! LAPACK correspondance
4+
//! ----------------------
5+
//!
6+
//! | f32 | f64 | c32 | c64 |
7+
//! |:-------|:-------|:-------|:-------|
8+
//! | sgesdd | dgesdd | cgesdd | zgesdd |
9+
//!
10+
111
use crate::{error::*, layout::MatrixLayout, *};
212
use cauchy::*;
313
use num_traits::{ToPrimitive, Zero};
414

5-
#[cfg_attr(doc, katexit::katexit)]
6-
/// Singular value decomposition with divide-and-conquer method
7-
pub trait SVDDC_: Scalar {
8-
/// Compute singular value decomposition $A = U \Sigma V^T$
9-
///
10-
/// LAPACK correspondance
11-
/// ----------------------
12-
///
13-
/// | f32 | f64 | c32 | c64 |
14-
/// |:-------|:-------|:-------|:-------|
15-
/// | sgesdd | dgesdd | cgesdd | zgesdd |
16-
///
17-
fn svddc(layout: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result<SvdOwned<Self>>;
18-
}
19-
2015
pub struct SvdDcWork<T: Scalar> {
2116
pub jobz: JobSvd,
2217
pub layout: MatrixLayout,
@@ -310,111 +305,3 @@ macro_rules! impl_svd_dc_work_r {
310305
}
311306
impl_svd_dc_work_r!(f64, lapack_sys::dgesdd_);
312307
impl_svd_dc_work_r!(f32, lapack_sys::sgesdd_);
313-
314-
macro_rules! impl_svddc {
315-
(@real, $scalar:ty, $gesdd:path) => {
316-
impl_svddc!(@body, $scalar, $gesdd, );
317-
};
318-
(@complex, $scalar:ty, $gesdd:path) => {
319-
impl_svddc!(@body, $scalar, $gesdd, rwork);
320-
};
321-
(@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => {
322-
impl SVDDC_ for $scalar {
323-
fn svddc(l: MatrixLayout, jobz: JobSvd, a: &mut [Self],) -> Result<SvdOwned<Self>> {
324-
let m = l.lda();
325-
let n = l.len();
326-
let k = m.min(n);
327-
let mut s = vec_uninit(k as usize);
328-
329-
let (u_col, vt_row) = match jobz {
330-
JobSvd::All | JobSvd::None => (m, n),
331-
JobSvd::Some => (k, k),
332-
};
333-
let (mut u, mut vt) = match jobz {
334-
JobSvd::All => (
335-
Some(vec_uninit((m * m) as usize)),
336-
Some(vec_uninit((n * n) as usize)),
337-
),
338-
JobSvd::Some => (
339-
Some(vec_uninit((m * u_col) as usize)),
340-
Some(vec_uninit((n * vt_row) as usize)),
341-
),
342-
JobSvd::None => (None, None),
343-
};
344-
345-
$( // for complex only
346-
let mx = n.max(m) as usize;
347-
let mn = n.min(m) as usize;
348-
let lrwork = match jobz {
349-
JobSvd::None => 7 * mn,
350-
_ => std::cmp::max(5*mn*mn + 5*mn, 2*mx*mn + 2*mn*mn + mn),
351-
};
352-
let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = vec_uninit(lrwork);
353-
)*
354-
355-
// eval work size
356-
let mut info = 0;
357-
let mut iwork: Vec<MaybeUninit<i32>> = vec_uninit(8 * k as usize);
358-
let mut work_size = [Self::zero()];
359-
unsafe {
360-
$gesdd(
361-
jobz.as_ptr(),
362-
&m,
363-
&n,
364-
AsPtr::as_mut_ptr(a),
365-
&m,
366-
AsPtr::as_mut_ptr(&mut s),
367-
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
368-
&m,
369-
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
370-
&vt_row,
371-
AsPtr::as_mut_ptr(&mut work_size),
372-
&(-1),
373-
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
374-
AsPtr::as_mut_ptr(&mut iwork),
375-
&mut info,
376-
);
377-
}
378-
info.as_lapack_result()?;
379-
380-
// do svd
381-
let lwork = work_size[0].to_usize().unwrap();
382-
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork);
383-
unsafe {
384-
$gesdd(
385-
jobz.as_ptr(),
386-
&m,
387-
&n,
388-
AsPtr::as_mut_ptr(a),
389-
&m,
390-
AsPtr::as_mut_ptr(&mut s),
391-
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
392-
&m,
393-
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
394-
&vt_row,
395-
AsPtr::as_mut_ptr(&mut work),
396-
&(lwork as i32),
397-
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
398-
AsPtr::as_mut_ptr(&mut iwork),
399-
&mut info,
400-
);
401-
}
402-
info.as_lapack_result()?;
403-
404-
let s = unsafe { s.assume_init() };
405-
let u = u.map(|v| unsafe { v.assume_init() });
406-
let vt = vt.map(|v| unsafe { v.assume_init() });
407-
408-
match l {
409-
MatrixLayout::F { .. } => Ok(SvdOwned { s, u, vt }),
410-
MatrixLayout::C { .. } => Ok(SvdOwned { s, u: vt, vt: u }),
411-
}
412-
}
413-
}
414-
};
415-
}
416-
417-
impl_svddc!(@real, f32, lapack_sys::sgesdd_);
418-
impl_svddc!(@real, f64, lapack_sys::dgesdd_);
419-
impl_svddc!(@complex, c32, lapack_sys::cgesdd_);
420-
impl_svddc!(@complex, c64, lapack_sys::zgesdd_);

0 commit comments

Comments
 (0)