Skip to content

Commit 60946d1

Browse files
committed
EigWorkImpl for f64
1 parent ea9d443 commit 60946d1

File tree

1 file changed

+243
-48
lines changed

1 file changed

+243
-48
lines changed

lax/src/eig.rs

Lines changed: 243 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,179 @@ impl EigWorkImpl for EigWork<c64> {
227227
}
228228
}
229229

230+
impl EigWorkImpl for EigWork<f64> {
231+
type Elem = f64;
232+
233+
fn new(calc_v: bool, l: MatrixLayout) -> Result<Self> {
234+
let (n, _) = l.size();
235+
let (jobvl, jobvr) = if calc_v {
236+
match l {
237+
MatrixLayout::C { .. } => (JobEv::All, JobEv::None),
238+
MatrixLayout::F { .. } => (JobEv::None, JobEv::All),
239+
}
240+
} else {
241+
(JobEv::None, JobEv::None)
242+
};
243+
let mut eigs_re: Vec<MaybeUninit<f64>> = vec_uninit(n as usize);
244+
let mut eigs_im: Vec<MaybeUninit<f64>> = vec_uninit(n as usize);
245+
246+
let mut vr_l: Option<Vec<MaybeUninit<f64>>> = jobvl.then(|| vec_uninit((n * n) as usize));
247+
let mut vr_r: Option<Vec<MaybeUninit<f64>>> = jobvr.then(|| vec_uninit((n * n) as usize));
248+
let vc_l: Option<Vec<MaybeUninit<c64>>> = jobvl.then(|| vec_uninit((n * n) as usize));
249+
let vc_r: Option<Vec<MaybeUninit<c64>>> = jobvr.then(|| vec_uninit((n * n) as usize));
250+
251+
// calc work size
252+
let mut info = 0;
253+
let mut work_size: [f64; 1] = [0.0];
254+
unsafe {
255+
lapack_sys::dgeev_(
256+
jobvl.as_ptr(),
257+
jobvr.as_ptr(),
258+
&n,
259+
std::ptr::null_mut(),
260+
&n,
261+
AsPtr::as_mut_ptr(&mut eigs_re),
262+
AsPtr::as_mut_ptr(&mut eigs_im),
263+
AsPtr::as_mut_ptr(vr_l.as_deref_mut().unwrap_or(&mut [])),
264+
&n,
265+
AsPtr::as_mut_ptr(vr_r.as_deref_mut().unwrap_or(&mut [])),
266+
&n,
267+
AsPtr::as_mut_ptr(&mut work_size),
268+
&(-1),
269+
&mut info,
270+
)
271+
};
272+
info.as_lapack_result()?;
273+
274+
// actual ev
275+
let lwork = work_size[0].to_usize().unwrap();
276+
let work: Vec<MaybeUninit<f64>> = vec_uninit(lwork);
277+
278+
Ok(Self {
279+
n,
280+
jobvr,
281+
jobvl,
282+
eigs: vec_uninit(n as usize),
283+
eigs_re: Some(eigs_re),
284+
eigs_im: Some(eigs_im),
285+
rwork: None,
286+
vr_l,
287+
vr_r,
288+
vc_l,
289+
vc_r,
290+
work,
291+
})
292+
}
293+
294+
fn calc<'work>(&'work mut self, a: &mut [f64]) -> Result<EigRef<'work, f64>> {
295+
let lwork = self.work.len().to_i32().unwrap();
296+
let mut info = 0;
297+
unsafe {
298+
lapack_sys::dgeev_(
299+
self.jobvl.as_ptr(),
300+
self.jobvr.as_ptr(),
301+
&self.n,
302+
AsPtr::as_mut_ptr(a),
303+
&self.n,
304+
AsPtr::as_mut_ptr(self.eigs_re.as_mut().unwrap()),
305+
AsPtr::as_mut_ptr(self.eigs_im.as_mut().unwrap()),
306+
AsPtr::as_mut_ptr(self.vr_l.as_deref_mut().unwrap_or(&mut [])),
307+
&self.n,
308+
AsPtr::as_mut_ptr(self.vr_r.as_deref_mut().unwrap_or(&mut [])),
309+
&self.n,
310+
AsPtr::as_mut_ptr(&mut self.work),
311+
&lwork,
312+
&mut info,
313+
)
314+
};
315+
info.as_lapack_result()?;
316+
317+
let eigs_re: &[f64] = self
318+
.eigs_re
319+
.as_ref()
320+
.map(|e| unsafe { e.slice_assume_init_ref() })
321+
.unwrap();
322+
let eigs_im: &[f64] = self
323+
.eigs_im
324+
.as_ref()
325+
.map(|e| unsafe { e.slice_assume_init_ref() })
326+
.unwrap();
327+
reconstruct_eigs(eigs_re, eigs_im, &mut self.eigs);
328+
329+
if let Some(v) = self.vr_l.as_ref() {
330+
let v = unsafe { v.slice_assume_init_ref() };
331+
reconstruct_eigenvectors(false, eigs_im, v, self.vc_l.as_mut().unwrap());
332+
}
333+
if let Some(v) = self.vr_r.as_ref() {
334+
let v = unsafe { v.slice_assume_init_ref() };
335+
reconstruct_eigenvectors(false, eigs_im, v, self.vc_l.as_mut().unwrap());
336+
}
337+
338+
Ok(EigRef {
339+
eigs: unsafe { self.eigs.slice_assume_init_ref() },
340+
vl: self
341+
.vc_l
342+
.as_ref()
343+
.map(|v| unsafe { v.slice_assume_init_ref() }),
344+
vr: self
345+
.vc_r
346+
.as_ref()
347+
.map(|v| unsafe { v.slice_assume_init_ref() }),
348+
})
349+
}
350+
351+
fn eval(mut self, a: &mut [f64]) -> Result<Eig<f64>> {
352+
let lwork = self.work.len().to_i32().unwrap();
353+
let mut info = 0;
354+
unsafe {
355+
lapack_sys::dgeev_(
356+
self.jobvl.as_ptr(),
357+
self.jobvr.as_ptr(),
358+
&self.n,
359+
AsPtr::as_mut_ptr(a),
360+
&self.n,
361+
AsPtr::as_mut_ptr(self.eigs_re.as_mut().unwrap()),
362+
AsPtr::as_mut_ptr(self.eigs_im.as_mut().unwrap()),
363+
AsPtr::as_mut_ptr(self.vr_l.as_deref_mut().unwrap_or(&mut [])),
364+
&self.n,
365+
AsPtr::as_mut_ptr(self.vr_r.as_deref_mut().unwrap_or(&mut [])),
366+
&self.n,
367+
AsPtr::as_mut_ptr(&mut self.work),
368+
&lwork,
369+
&mut info,
370+
)
371+
};
372+
info.as_lapack_result()?;
373+
374+
let eigs_re: &[f64] = self
375+
.eigs_re
376+
.as_ref()
377+
.map(|e| unsafe { e.slice_assume_init_ref() })
378+
.unwrap();
379+
let eigs_im: &[f64] = self
380+
.eigs_im
381+
.as_ref()
382+
.map(|e| unsafe { e.slice_assume_init_ref() })
383+
.unwrap();
384+
reconstruct_eigs(eigs_re, eigs_im, &mut self.eigs);
385+
386+
if let Some(v) = self.vr_l.as_ref() {
387+
let v = unsafe { v.slice_assume_init_ref() };
388+
reconstruct_eigenvectors(false, eigs_im, v, self.vc_l.as_mut().unwrap());
389+
}
390+
if let Some(v) = self.vr_r.as_ref() {
391+
let v = unsafe { v.slice_assume_init_ref() };
392+
reconstruct_eigenvectors(false, eigs_im, v, self.vc_l.as_mut().unwrap());
393+
}
394+
395+
Ok(Eig {
396+
eigs: unsafe { self.eigs.assume_init() },
397+
vl: self.vc_l.map(|v| unsafe { v.assume_init() }),
398+
vr: self.vc_r.map(|v| unsafe { v.assume_init() }),
399+
})
400+
}
401+
}
402+
230403
macro_rules! impl_eig_complex {
231404
($scalar:ty, $ev:path) => {
232405
impl Eig_ for $scalar {
@@ -429,59 +602,81 @@ macro_rules! impl_eig_real {
429602
.map(|(&re, &im)| Self::complex(re, im))
430603
.collect();
431604

432-
if !calc_v {
433-
return Ok((eigs, Vec::new()));
434-
}
435-
436-
// Reconstruct eigenvectors into complex-array
437-
// --------------------------------------------
438-
//
439-
// From LAPACK API https://software.intel.com/en-us/node/469230
440-
//
441-
// - If the j-th eigenvalue is real,
442-
// - v(j) = VR(:,j), the j-th column of VR.
443-
//
444-
// - If the j-th and (j+1)-st eigenvalues form a complex conjugate pair,
445-
// - v(j) = VR(:,j) + i*VR(:,j+1)
446-
// - v(j+1) = VR(:,j) - i*VR(:,j+1).
447-
//
448-
// In the C-layout case, we need the conjugates of the left
449-
// eigenvectors, so the signs should be reversed.
450-
451-
let n = n as usize;
452-
let v = vr.or(vl).unwrap();
453-
let mut eigvecs: Vec<MaybeUninit<Self::Complex>> = vec_uninit(n * n);
454-
let mut col = 0;
455-
while col < n {
456-
if eig_im[col] == 0. {
457-
// The corresponding eigenvalue is real.
458-
for row in 0..n {
459-
let re = v[row + col * n];
460-
eigvecs[row + col * n].write(Self::complex(re, 0.));
461-
}
462-
col += 1;
463-
} else {
464-
// This is a complex conjugate pair.
465-
assert!(col + 1 < n);
466-
for row in 0..n {
467-
let re = v[row + col * n];
468-
let mut im = v[row + (col + 1) * n];
469-
if jobvl.is_calc() {
470-
im = -im;
471-
}
472-
eigvecs[row + col * n].write(Self::complex(re, im));
473-
eigvecs[row + (col + 1) * n].write(Self::complex(re, -im));
474-
}
475-
col += 2;
476-
}
605+
if calc_v {
606+
let mut eigvecs = vec_uninit((n * n) as usize);
607+
reconstruct_eigenvectors(
608+
jobvl.is_calc(),
609+
&eig_im,
610+
&vr.or(vl).unwrap(),
611+
&mut eigvecs,
612+
);
613+
Ok((eigs, unsafe { eigvecs.assume_init() }))
614+
} else {
615+
Ok((eigs, Vec::new()))
477616
}
478-
let eigvecs = unsafe { eigvecs.assume_init() };
479-
480-
Ok((eigs, eigvecs))
481617
}
482618
}
483619
};
484620
}
485621

486622
impl_eig_real!(f64, lapack_sys::dgeev_);
487623
impl_eig_real!(f32, lapack_sys::sgeev_);
624+
625+
/// Reconstruct eigenvectors into complex-array
626+
///
627+
/// From LAPACK API https://software.intel.com/en-us/node/469230
628+
///
629+
/// - If the j-th eigenvalue is real,
630+
/// - v(j) = VR(:,j), the j-th column of VR.
631+
///
632+
/// - If the j-th and (j+1)-st eigenvalues form a complex conjugate pair,
633+
/// - v(j) = VR(:,j) + i*VR(:,j+1)
634+
/// - v(j+1) = VR(:,j) - i*VR(:,j+1).
635+
///
636+
/// In the C-layout case, we need the conjugates of the left
637+
/// eigenvectors, so the signs should be reversed.
638+
fn reconstruct_eigenvectors<T: Scalar>(
639+
take_hermite_conjugate: bool,
640+
eig_im: &[T],
641+
vr: &[T],
642+
vc: &mut [MaybeUninit<T::Complex>],
643+
) {
644+
let n = eig_im.len();
645+
assert_eq!(vr.len(), n * n);
646+
assert_eq!(vc.len(), n * n);
647+
648+
let mut col = 0;
649+
while col < n {
650+
if eig_im[col].is_zero() {
651+
// The corresponding eigenvalue is real.
652+
for row in 0..n {
653+
let re = vr[row + col * n];
654+
vc[row + col * n].write(T::complex(re, T::zero()));
655+
}
656+
col += 1;
657+
} else {
658+
// This is a complex conjugate pair.
659+
assert!(col + 1 < n);
660+
for row in 0..n {
661+
let re = vr[row + col * n];
662+
let mut im = vr[row + (col + 1) * n];
663+
if take_hermite_conjugate {
664+
im = -im;
665+
}
666+
vc[row + col * n].write(T::complex(re, im));
667+
vc[row + (col + 1) * n].write(T::complex(re, -im));
668+
}
669+
col += 2;
670+
}
671+
}
672+
}
673+
674+
/// Create complex eigenvalues from real and imaginary parts.
675+
fn reconstruct_eigs<T: Scalar>(re: &[T], im: &[T], eigs: &mut [MaybeUninit<T::Complex>]) {
676+
let n = eigs.len();
677+
assert_eq!(re.len(), n);
678+
assert_eq!(im.len(), n);
679+
for i in 0..n {
680+
eigs[i].write(T::complex(re[i], im[i]));
681+
}
682+
}

0 commit comments

Comments
 (0)