Skip to content

Commit 3fd8c0b

Browse files
committed
Merge SVD_ to Lapack
1 parent f9f16e2 commit 3fd8c0b

File tree

3 files changed

+29
-140
lines changed

3 files changed

+29
-140
lines changed

lax/src/lib.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
//! Singular Value Decomposition
6666
//! -----------------------------
6767
//!
68-
//! - [SVD_] trait provides methods for singular value decomposition for general matrix
68+
//! - [svd] module for singular value decomposition (SVD) for general matrix
6969
//! - [SVDDC_] trait provides methods for singular value decomposition for general matrix
7070
//! with divided-and-conquer algorithm
7171
//! - [LeastSquaresSvdDivideConquer_] trait provides methods
@@ -91,6 +91,7 @@ pub mod eig;
9191
pub mod eigh;
9292
pub mod eigh_generalized;
9393
pub mod qr;
94+
pub mod svd;
9495

9596
mod alloc;
9697
mod cholesky;
@@ -99,7 +100,6 @@ mod opnorm;
99100
mod rcond;
100101
mod solve;
101102
mod solveh;
102-
mod svd;
103103
mod svddc;
104104
mod triangular;
105105
mod tridiagonal;
@@ -111,7 +111,7 @@ pub use self::opnorm::*;
111111
pub use self::rcond::*;
112112
pub use self::solve::*;
113113
pub use self::solveh::*;
114-
pub use self::svd::*;
114+
pub use self::svd::SvdOwned;
115115
pub use self::svddc::*;
116116
pub use self::triangular::*;
117117
pub use self::tridiagonal::*;
@@ -125,7 +125,6 @@ pub type Pivot = Vec<i32>;
125125
/// Trait for primitive types which implements LAPACK subroutines
126126
pub trait Lapack:
127127
OperatorNorm_
128-
+ SVD_
129128
+ SVDDC_
130129
+ Solve_
131130
+ Solveh_
@@ -170,6 +169,9 @@ pub trait Lapack:
170169

171170
/// Execute QR-decomposition at once
172171
fn qr(l: MatrixLayout, a: &mut [Self]) -> Result<Vec<Self>>;
172+
173+
/// Compute singular-value decomposition (SVD)
174+
fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result<SvdOwned<Self>>;
173175
}
174176

175177
macro_rules! impl_lapack {
@@ -228,6 +230,17 @@ macro_rules! impl_lapack {
228230
Self::q(l, a, &tau)?;
229231
Ok(r)
230232
}
233+
234+
fn svd(
235+
l: MatrixLayout,
236+
calc_u: bool,
237+
calc_vt: bool,
238+
a: &mut [Self],
239+
) -> Result<SvdOwned<Self>> {
240+
use svd::*;
241+
let work = SvdWork::<$s>::new(l, calc_u, calc_vt)?;
242+
work.eval(a)
243+
}
231244
}
232245
};
233246
}

lax/src/svd.rs

Lines changed: 8 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,17 @@
11
//! Singular-value decomposition
2+
//!
3+
//! LAPACK correspondance
4+
//! ----------------------
5+
//!
6+
//! | f32 | f64 | c32 | c64 |
7+
//! |:-------|:-------|:-------|:-------|
8+
//! | sgesvd | dgesvd | cgesvd | zgesvd |
9+
//!
210
311
use super::{error::*, layout::*, *};
412
use cauchy::*;
513
use num_traits::{ToPrimitive, Zero};
614

7-
/// Result of SVD
8-
pub struct SVDOutput<A: Scalar> {
9-
/// diagonal values
10-
pub s: Vec<A::Real>,
11-
/// Unitary matrix for destination space
12-
pub u: Option<Vec<A>>,
13-
/// Unitary matrix for departure space
14-
pub vt: Option<Vec<A>>,
15-
}
16-
17-
#[cfg_attr(doc, katexit::katexit)]
18-
/// Singular value decomposition
19-
pub trait SVD_: Scalar {
20-
/// Compute singular value decomposition $A = U \Sigma V^T$
21-
///
22-
/// LAPACK correspondance
23-
/// ----------------------
24-
///
25-
/// | f32 | f64 | c32 | c64 |
26-
/// |:-------|:-------|:-------|:-------|
27-
/// | sgesvd | dgesvd | cgesvd | zgesvd |
28-
///
29-
fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self])
30-
-> Result<SVDOutput<Self>>;
31-
}
32-
3315
pub struct SvdWork<T: Scalar> {
3416
pub ju: JobSvd,
3517
pub jvt: JobSvd,
@@ -330,109 +312,3 @@ macro_rules! impl_svd_work_r {
330312
}
331313
impl_svd_work_r!(f64, lapack_sys::dgesvd_);
332314
impl_svd_work_r!(f32, lapack_sys::sgesvd_);
333-
334-
macro_rules! impl_svd {
335-
(@real, $scalar:ty, $gesvd:path) => {
336-
impl_svd!(@body, $scalar, $gesvd, );
337-
};
338-
(@complex, $scalar:ty, $gesvd:path) => {
339-
impl_svd!(@body, $scalar, $gesvd, rwork);
340-
};
341-
(@body, $scalar:ty, $gesvd:path, $($rwork_ident:ident),*) => {
342-
impl SVD_ for $scalar {
343-
fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self],) -> Result<SVDOutput<Self>> {
344-
let ju = match l {
345-
MatrixLayout::F { .. } => JobSvd::from_bool(calc_u),
346-
MatrixLayout::C { .. } => JobSvd::from_bool(calc_vt),
347-
};
348-
let jvt = match l {
349-
MatrixLayout::F { .. } => JobSvd::from_bool(calc_vt),
350-
MatrixLayout::C { .. } => JobSvd::from_bool(calc_u),
351-
};
352-
353-
let m = l.lda();
354-
let mut u = match ju {
355-
JobSvd::All => Some(vec_uninit( (m * m) as usize)),
356-
JobSvd::None => None,
357-
_ => unimplemented!("SVD with partial vector output is not supported yet")
358-
};
359-
360-
let n = l.len();
361-
let mut vt = match jvt {
362-
JobSvd::All => Some(vec_uninit( (n * n) as usize)),
363-
JobSvd::None => None,
364-
_ => unimplemented!("SVD with partial vector output is not supported yet")
365-
};
366-
367-
let k = std::cmp::min(m, n);
368-
let mut s = vec_uninit( k as usize);
369-
370-
$(
371-
let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = vec_uninit(5 * k as usize);
372-
)*
373-
374-
// eval work size
375-
let mut info = 0;
376-
let mut work_size = [Self::zero()];
377-
unsafe {
378-
$gesvd(
379-
ju.as_ptr(),
380-
jvt.as_ptr(),
381-
&m,
382-
&n,
383-
AsPtr::as_mut_ptr(a),
384-
&m,
385-
AsPtr::as_mut_ptr(&mut s),
386-
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
387-
&m,
388-
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
389-
&n,
390-
AsPtr::as_mut_ptr(&mut work_size),
391-
&(-1),
392-
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
393-
&mut info,
394-
);
395-
}
396-
info.as_lapack_result()?;
397-
398-
// calc
399-
let lwork = work_size[0].to_usize().unwrap();
400-
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork);
401-
unsafe {
402-
$gesvd(
403-
ju.as_ptr(),
404-
jvt.as_ptr() ,
405-
&m,
406-
&n,
407-
AsPtr::as_mut_ptr(a),
408-
&m,
409-
AsPtr::as_mut_ptr(&mut s),
410-
AsPtr::as_mut_ptr(u.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
411-
&m,
412-
AsPtr::as_mut_ptr(vt.as_mut().map(|x| x.as_mut_slice()).unwrap_or(&mut [])),
413-
&n,
414-
AsPtr::as_mut_ptr(&mut work),
415-
&(lwork as i32),
416-
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
417-
&mut info,
418-
);
419-
}
420-
info.as_lapack_result()?;
421-
422-
let s = unsafe { s.assume_init() };
423-
let u = u.map(|v| unsafe { v.assume_init() });
424-
let vt = vt.map(|v| unsafe { v.assume_init() });
425-
426-
match l {
427-
MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
428-
MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
429-
}
430-
}
431-
}
432-
};
433-
} // impl_svd!
434-
435-
impl_svd!(@real, f64, lapack_sys::dgesvd_);
436-
impl_svd!(@real, f32, lapack_sys::sgesvd_);
437-
impl_svd!(@complex, c64, lapack_sys::zgesvd_);
438-
impl_svd!(@complex, c32, lapack_sys::cgesvd_);

lax/src/svddc.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ pub trait SVDDC_: Scalar {
1414
/// |:-------|:-------|:-------|:-------|
1515
/// | sgesdd | dgesdd | cgesdd | zgesdd |
1616
///
17-
fn svddc(l: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result<SVDOutput<Self>>;
17+
fn svddc(l: MatrixLayout, jobz: JobSvd, a: &mut [Self]) -> Result<SvdOwned<Self>>;
1818
}
1919

2020
macro_rules! impl_svddc {
@@ -26,7 +26,7 @@ macro_rules! impl_svddc {
2626
};
2727
(@body, $scalar:ty, $gesdd:path, $($rwork_ident:ident),*) => {
2828
impl SVDDC_ for $scalar {
29-
fn svddc(l: MatrixLayout, jobz: JobSvd, a: &mut [Self],) -> Result<SVDOutput<Self>> {
29+
fn svddc(l: MatrixLayout, jobz: JobSvd, a: &mut [Self],) -> Result<SvdOwned<Self>> {
3030
let m = l.lda();
3131
let n = l.len();
3232
let k = m.min(n);
@@ -112,8 +112,8 @@ macro_rules! impl_svddc {
112112
let vt = vt.map(|v| unsafe { v.assume_init() });
113113

114114
match l {
115-
MatrixLayout::F { .. } => Ok(SVDOutput { s, u, vt }),
116-
MatrixLayout::C { .. } => Ok(SVDOutput { s, u: vt, vt: u }),
115+
MatrixLayout::F { .. } => Ok(SvdOwned { s, u, vt }),
116+
MatrixLayout::C { .. } => Ok(SvdOwned { s, u: vt, vt: u }),
117117
}
118118
}
119119
}

0 commit comments

Comments
 (0)