Skip to content

Commit 9126810

Browse files
committed
Eig, EigRef
1 parent 43b0480 commit 9126810

File tree

1 file changed

+37
-39
lines changed

1 file changed

+37
-39
lines changed

lax/src/eig.rs

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -77,26 +77,28 @@ impl<T: Scalar> EigWork<T> {
7777
}
7878
}
7979

80+
#[derive(Debug, Clone, PartialEq)]
81+
pub struct Eig<T: Scalar> {
82+
pub eigs: Vec<T::Complex>,
83+
pub vr: Option<Vec<T::Complex>>,
84+
pub vl: Option<Vec<T::Complex>>,
85+
}
86+
87+
#[derive(Debug, Clone, PartialEq)]
88+
pub struct EigRef<'work, T: Scalar> {
89+
pub eigs: &'work [T::Complex],
90+
pub vr: Option<&'work [T::Complex]>,
91+
pub vl: Option<&'work [T::Complex]>,
92+
}
93+
8094
pub trait EigWorkImpl: Sized {
8195
type Elem: Scalar;
8296
/// Create new working memory for eigenvalues compution.
8397
fn new(calc_v: bool, l: MatrixLayout) -> Result<Self>;
8498
/// Compute eigenvalues and vectors on this working memory.
85-
fn calc<'work>(
86-
&'work mut self,
87-
a: &mut [Self::Elem],
88-
) -> Result<(
89-
&'work [<Self::Elem as Scalar>::Complex],
90-
Option<&'work [<Self::Elem as Scalar>::Complex]>,
91-
)>;
99+
fn calc<'work>(&'work mut self, a: &mut [Self::Elem]) -> Result<EigRef<'work, Self::Elem>>;
92100
/// Compute eigenvalues and vectors by consuming this working memory.
93-
fn eval(
94-
self,
95-
a: &mut [Self::Elem],
96-
) -> Result<(
97-
Vec<<Self::Elem as Scalar>::Complex>,
98-
Option<Vec<<Self::Elem as Scalar>::Complex>>,
99-
)>;
101+
fn eval(self, a: &mut [Self::Elem]) -> Result<Eig<Self::Elem>>;
100102
}
101103

102104
impl EigWorkImpl for EigWork<c64> {
@@ -157,7 +159,7 @@ impl EigWorkImpl for EigWork<c64> {
157159
})
158160
}
159161

160-
fn calc<'work>(&'work mut self, a: &mut [c64]) -> Result<(&'work [c64], Option<&'work [c64]>)> {
162+
fn calc<'work>(&'work mut self, a: &mut [c64]) -> Result<EigRef<'work, c64>> {
161163
let lwork = self.work.len().to_i32().unwrap();
162164
let mut info = 0;
163165
unsafe {
@@ -193,15 +195,20 @@ impl EigWorkImpl for EigWork<c64> {
193195
value.im = -value.im;
194196
}
195197
}
196-
let v = match (self.vl.as_ref(), self.vr.as_ref()) {
197-
(Some(v), None) | (None, Some(v)) => Some(unsafe { v.slice_assume_init_ref() }),
198-
(None, None) => None,
199-
_ => unreachable!(),
200-
};
201-
Ok((eigs, v))
198+
Ok(EigRef {
199+
eigs,
200+
vl: self
201+
.vl
202+
.as_ref()
203+
.map(|v| unsafe { v.slice_assume_init_ref() }),
204+
vr: self
205+
.vr
206+
.as_ref()
207+
.map(|v| unsafe { v.slice_assume_init_ref() }),
208+
})
202209
}
203210

204-
fn eval(mut self, a: &mut [c64]) -> Result<(Vec<c64>, Option<Vec<c64>>)> {
211+
fn eval(mut self, a: &mut [c64]) -> Result<Eig<c64>> {
205212
let lwork = self.work.len().to_i32().unwrap();
206213
let mut info = 0;
207214
unsafe {
@@ -232,12 +239,11 @@ impl EigWorkImpl for EigWork<c64> {
232239
value.im = -value.im;
233240
}
234241
}
235-
let v = match (self.vl, self.vr) {
236-
(Some(v), None) | (None, Some(v)) => Some(unsafe { v.assume_init() }),
237-
(None, None) => None,
238-
_ => unreachable!(),
239-
};
240-
Ok((eigs, v))
242+
Ok(Eig {
243+
eigs,
244+
vl: self.vl.map(|v| unsafe { v.assume_init() }),
245+
vr: self.vr.map(|v| unsafe { v.assume_init() }),
246+
})
241247
}
242248
}
243249

@@ -301,14 +307,11 @@ impl EigWorkImpl for EigWork<f64> {
301307
})
302308
}
303309

304-
fn calc<'work>(
305-
&'work mut self,
306-
_a: &mut [f64],
307-
) -> Result<(&'work [c64], Option<&'work [c64]>)> {
310+
fn calc<'work>(&'work mut self, _a: &mut [f64]) -> Result<EigRef<'work, f64>> {
308311
todo!()
309312
}
310313

311-
fn eval(mut self, a: &mut [f64]) -> Result<(Vec<c64>, Option<Vec<c64>>)> {
314+
fn eval(mut self, a: &mut [f64]) -> Result<Eig<f64>> {
312315
let lwork = self.work.len().to_i32().unwrap();
313316
let mut info = 0;
314317
unsafe {
@@ -343,12 +346,7 @@ impl EigWorkImpl for EigWork<f64> {
343346
.map(|(&re, &im)| c64::new(re, im))
344347
.collect();
345348

346-
if self.jobvl.is_calc() || self.jobvr.is_calc() {
347-
let eigvecs = reconstruct_eigenvectors(self.jobvl, self.n, &eig_im, vr, vl);
348-
Ok((eigs, Some(eigvecs)))
349-
} else {
350-
Ok((eigs, None))
351-
}
349+
Ok(Eig { eigs, vl, vr })
352350
}
353351
}
354352

0 commit comments

Comments
 (0)