Skip to content

Commit f672b07

Browse files
authored
Merge pull request #341 from rust-ndarray/lax-eigh-generalized-work
Merge `Eigh_` into `Lapack` trait, add working memory management
2 parents 33e2dc3 + bef1083 commit f672b07

File tree

3 files changed

+388
-122
lines changed

3 files changed

+388
-122
lines changed

lax/src/eigh.rs

Lines changed: 129 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,180 +1,190 @@
1+
//! Eigenvalue problem for symmetric/Hermite matricies
2+
//!
3+
//! LAPACK correspondance
4+
//! ----------------------
5+
//!
6+
//! | f32 | f64 | c32 | c64 |
7+
//! |:------|:------|:------|:------|
8+
//! | ssyev | dsyev | cheev | zheev |
9+
110
use super::*;
211
use crate::{error::*, layout::MatrixLayout};
312
use cauchy::*;
413
use num_traits::{ToPrimitive, Zero};
514

6-
#[cfg_attr(doc, katexit::katexit)]
7-
/// Eigenvalue problem for symmetric/hermite matrix
8-
pub trait Eigh_: Scalar {
9-
/// Compute right eigenvalue and eigenvectors $Ax = \lambda x$
10-
///
11-
/// LAPACK correspondance
12-
/// ----------------------
13-
///
14-
/// | f32 | f64 | c32 | c64 |
15-
/// |:------|:------|:------|:------|
16-
/// | ssyev | dsyev | cheev | zheev |
17-
///
18-
fn eigh(
19-
calc_eigenvec: bool,
20-
layout: MatrixLayout,
21-
uplo: UPLO,
22-
a: &mut [Self],
23-
) -> Result<Vec<Self::Real>>;
15+
pub struct EighWork<T: Scalar> {
16+
pub n: i32,
17+
pub jobz: JobEv,
18+
pub eigs: Vec<MaybeUninit<T::Real>>,
19+
pub work: Vec<MaybeUninit<T>>,
20+
pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
21+
}
2422

25-
/// Compute generalized right eigenvalue and eigenvectors $Ax = \lambda B x$
26-
///
27-
/// LAPACK correspondance
28-
/// ----------------------
29-
///
30-
/// | f32 | f64 | c32 | c64 |
31-
/// |:------|:------|:------|:------|
32-
/// | ssygv | dsygv | chegv | zhegv |
33-
///
34-
fn eigh_generalized(
35-
calc_eigenvec: bool,
36-
layout: MatrixLayout,
37-
uplo: UPLO,
38-
a: &mut [Self],
39-
b: &mut [Self],
40-
) -> Result<Vec<Self::Real>>;
23+
pub trait EighWorkImpl: Sized {
24+
type Elem: Scalar;
25+
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self>;
26+
fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem])
27+
-> Result<&[<Self::Elem as Scalar>::Real]>;
28+
fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Vec<<Self::Elem as Scalar>::Real>>;
4129
}
4230

43-
macro_rules! impl_eigh {
44-
(@real, $scalar:ty, $ev:path, $evg:path) => {
45-
impl_eigh!(@body, $scalar, $ev, $evg, );
46-
};
47-
(@complex, $scalar:ty, $ev:path, $evg:path) => {
48-
impl_eigh!(@body, $scalar, $ev, $evg, rwork);
49-
};
50-
(@body, $scalar:ty, $ev:path, $evg:path, $($rwork_ident:ident),*) => {
51-
impl Eigh_ for $scalar {
52-
fn eigh(
53-
calc_v: bool,
54-
layout: MatrixLayout,
55-
uplo: UPLO,
56-
a: &mut [Self],
57-
) -> Result<Vec<Self::Real>> {
31+
macro_rules! impl_eigh_work_c {
32+
($c:ty, $ev:path) => {
33+
impl EighWorkImpl for EighWork<$c> {
34+
type Elem = $c;
35+
36+
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
5837
assert_eq!(layout.len(), layout.lda());
5938
let n = layout.len();
60-
let jobz = if calc_v { JobEv::All } else { JobEv::None };
61-
let mut eigs: Vec<MaybeUninit<Self::Real>> = vec_uninit(n as usize);
62-
63-
$(
64-
let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = vec_uninit(3 * n as usize - 2 as usize);
65-
)*
66-
67-
// calc work size
39+
let jobz = if calc_eigenvectors {
40+
JobEv::All
41+
} else {
42+
JobEv::None
43+
};
44+
let mut eigs = vec_uninit(n as usize);
45+
let mut rwork = vec_uninit(3 * n as usize - 2 as usize);
6846
let mut info = 0;
69-
let mut work_size = [Self::zero()];
47+
let mut work_size = [Self::Elem::zero()];
7048
unsafe {
7149
$ev(
72-
jobz.as_ptr() ,
73-
uplo.as_ptr(),
50+
jobz.as_ptr(),
51+
UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO
7452
&n,
75-
AsPtr::as_mut_ptr(a),
53+
std::ptr::null_mut(),
7654
&n,
7755
AsPtr::as_mut_ptr(&mut eigs),
7856
AsPtr::as_mut_ptr(&mut work_size),
7957
&(-1),
80-
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
58+
AsPtr::as_mut_ptr(&mut rwork),
8159
&mut info,
8260
);
8361
}
8462
info.as_lapack_result()?;
85-
86-
// actual ev
8763
let lwork = work_size[0].to_usize().unwrap();
88-
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork);
89-
let lwork = lwork as i32;
64+
let work = vec_uninit(lwork);
65+
Ok(EighWork {
66+
n,
67+
eigs,
68+
jobz,
69+
work,
70+
rwork: Some(rwork),
71+
})
72+
}
73+
74+
fn calc(
75+
&mut self,
76+
uplo: UPLO,
77+
a: &mut [Self::Elem],
78+
) -> Result<&[<Self::Elem as Scalar>::Real]> {
79+
let lwork = self.work.len().to_i32().unwrap();
80+
let mut info = 0;
9081
unsafe {
9182
$ev(
92-
jobz.as_ptr(),
83+
self.jobz.as_ptr(),
9384
uplo.as_ptr(),
94-
&n,
85+
&self.n,
9586
AsPtr::as_mut_ptr(a),
96-
&n,
97-
AsPtr::as_mut_ptr(&mut eigs),
98-
AsPtr::as_mut_ptr(&mut work),
87+
&self.n,
88+
AsPtr::as_mut_ptr(&mut self.eigs),
89+
AsPtr::as_mut_ptr(&mut self.work),
9990
&lwork,
100-
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
91+
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
10192
&mut info,
10293
);
10394
}
10495
info.as_lapack_result()?;
105-
106-
let eigs = unsafe { eigs.assume_init() };
107-
Ok(eigs)
96+
Ok(unsafe { self.eigs.slice_assume_init_ref() })
10897
}
10998

110-
fn eigh_generalized(
111-
calc_v: bool,
112-
layout: MatrixLayout,
99+
fn eval(
100+
mut self,
113101
uplo: UPLO,
114-
a: &mut [Self],
115-
b: &mut [Self],
116-
) -> Result<Vec<Self::Real>> {
117-
assert_eq!(layout.len(), layout.lda());
118-
let n = layout.len();
119-
let jobz = if calc_v { JobEv::All } else { JobEv::None };
120-
let mut eigs: Vec<MaybeUninit<Self::Real>> = vec_uninit(n as usize);
102+
a: &mut [Self::Elem],
103+
) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
104+
let _eig = self.calc(uplo, a)?;
105+
Ok(unsafe { self.eigs.assume_init() })
106+
}
107+
}
108+
};
109+
}
110+
impl_eigh_work_c!(c64, lapack_sys::zheev_);
111+
impl_eigh_work_c!(c32, lapack_sys::cheev_);
121112

122-
$(
123-
let mut $rwork_ident: Vec<MaybeUninit<Self::Real>> = vec_uninit(3 * n as usize - 2);
124-
)*
113+
macro_rules! impl_eigh_work_r {
114+
($f:ty, $ev:path) => {
115+
impl EighWorkImpl for EighWork<$f> {
116+
type Elem = $f;
125117

126-
// calc work size
118+
fn new(calc_eigenvectors: bool, layout: MatrixLayout) -> Result<Self> {
119+
assert_eq!(layout.len(), layout.lda());
120+
let n = layout.len();
121+
let jobz = if calc_eigenvectors {
122+
JobEv::All
123+
} else {
124+
JobEv::None
125+
};
126+
let mut eigs = vec_uninit(n as usize);
127127
let mut info = 0;
128-
let mut work_size = [Self::zero()];
128+
let mut work_size = [Self::Elem::zero()];
129129
unsafe {
130-
$evg(
131-
&1, // ITYPE A*x = (lambda)*B*x
130+
$ev(
132131
jobz.as_ptr(),
133-
uplo.as_ptr(),
132+
UPLO::Upper.as_ptr(), // dummy, working memory is not affected by UPLO
134133
&n,
135-
AsPtr::as_mut_ptr(a),
136-
&n,
137-
AsPtr::as_mut_ptr(b),
134+
std::ptr::null_mut(),
138135
&n,
139136
AsPtr::as_mut_ptr(&mut eigs),
140137
AsPtr::as_mut_ptr(&mut work_size),
141138
&(-1),
142-
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
143139
&mut info,
144140
);
145141
}
146142
info.as_lapack_result()?;
147-
148-
// actual evg
149143
let lwork = work_size[0].to_usize().unwrap();
150-
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork);
151-
let lwork = lwork as i32;
144+
let work = vec_uninit(lwork);
145+
Ok(EighWork {
146+
n,
147+
eigs,
148+
jobz,
149+
work,
150+
rwork: None,
151+
})
152+
}
153+
154+
fn calc(
155+
&mut self,
156+
uplo: UPLO,
157+
a: &mut [Self::Elem],
158+
) -> Result<&[<Self::Elem as Scalar>::Real]> {
159+
let lwork = self.work.len().to_i32().unwrap();
160+
let mut info = 0;
152161
unsafe {
153-
$evg(
154-
&1, // ITYPE A*x = (lambda)*B*x
155-
jobz.as_ptr(),
162+
$ev(
163+
self.jobz.as_ptr(),
156164
uplo.as_ptr(),
157-
&n,
165+
&self.n,
158166
AsPtr::as_mut_ptr(a),
159-
&n,
160-
AsPtr::as_mut_ptr(b),
161-
&n,
162-
AsPtr::as_mut_ptr(&mut eigs),
163-
AsPtr::as_mut_ptr(&mut work),
167+
&self.n,
168+
AsPtr::as_mut_ptr(&mut self.eigs),
169+
AsPtr::as_mut_ptr(&mut self.work),
164170
&lwork,
165-
$(AsPtr::as_mut_ptr(&mut $rwork_ident),)*
166171
&mut info,
167172
);
168173
}
169174
info.as_lapack_result()?;
170-
let eigs = unsafe { eigs.assume_init() };
171-
Ok(eigs)
175+
Ok(unsafe { self.eigs.slice_assume_init_ref() })
176+
}
177+
178+
fn eval(
179+
mut self,
180+
uplo: UPLO,
181+
a: &mut [Self::Elem],
182+
) -> Result<Vec<<Self::Elem as Scalar>::Real>> {
183+
let _eig = self.calc(uplo, a)?;
184+
Ok(unsafe { self.eigs.assume_init() })
172185
}
173186
}
174187
};
175-
} // impl_eigh!
176-
177-
impl_eigh!(@real, f64, lapack_sys::dsyev_, lapack_sys::dsygv_);
178-
impl_eigh!(@real, f32, lapack_sys::ssyev_, lapack_sys::ssygv_);
179-
impl_eigh!(@complex, c64, lapack_sys::zheev_, lapack_sys::zhegv_);
180-
impl_eigh!(@complex, c32, lapack_sys::cheev_, lapack_sys::chegv_);
188+
}
189+
impl_eigh_work_r!(f64, lapack_sys::dsyev_);
190+
impl_eigh_work_r!(f32, lapack_sys::ssyev_);

0 commit comments

Comments
 (0)