Skip to content

Commit 85d4f60

Browse files
committed
EigWorkImpl for c64
1 parent 7e5cf1c commit 85d4f60

File tree

1 file changed

+164
-0
lines changed

1 file changed

+164
-0
lines changed

lax/src/eig.rs

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,170 @@ pub struct EigWork<T: Scalar> {
5757
pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
5858
}
5959

60+
pub trait EigWorkImpl: Sized {
61+
type Elem: Scalar;
62+
/// Create new working memory for eigenvalues compution.
63+
fn new(calc_v: bool, l: MatrixLayout) -> Result<Self>;
64+
/// Compute eigenvalues and vectors on this working memory.
65+
fn calc<'work>(
66+
&'work mut self,
67+
a: &mut [Self::Elem],
68+
) -> Result<(
69+
&'work [<Self::Elem as Scalar>::Complex],
70+
Option<&'work [<Self::Elem as Scalar>::Complex]>,
71+
)>;
72+
/// Compute eigenvalues and vectors by consuming this working memory.
73+
fn eval(
74+
self,
75+
a: &mut [Self::Elem],
76+
) -> Result<(
77+
Vec<<Self::Elem as Scalar>::Complex>,
78+
Option<Vec<<Self::Elem as Scalar>::Complex>>,
79+
)>;
80+
}
81+
82+
impl EigWorkImpl for EigWork<c64> {
83+
type Elem = c64;
84+
85+
fn new(calc_v: bool, l: MatrixLayout) -> Result<Self> {
86+
let (n, _) = l.size();
87+
let (jobvl, jobvr) = if calc_v {
88+
match l {
89+
MatrixLayout::C { .. } => (JobEv::All, JobEv::None),
90+
MatrixLayout::F { .. } => (JobEv::None, JobEv::All),
91+
}
92+
} else {
93+
(JobEv::None, JobEv::None)
94+
};
95+
let mut eigs: Vec<MaybeUninit<c64>> = vec_uninit(n as usize);
96+
let mut rwork: Vec<MaybeUninit<f64>> = vec_uninit(2 * n as usize);
97+
98+
let mut vl: Option<Vec<MaybeUninit<c64>>> = jobvl.then(|| vec_uninit((n * n) as usize));
99+
let mut vr: Option<Vec<MaybeUninit<c64>>> = jobvr.then(|| vec_uninit((n * n) as usize));
100+
101+
// calc work size
102+
let mut info = 0;
103+
let mut work_size = [c64::zero()];
104+
unsafe {
105+
lapack_sys::zgeev_(
106+
jobvl.as_ptr(),
107+
jobvr.as_ptr(),
108+
&n,
109+
std::ptr::null_mut(),
110+
&n,
111+
AsPtr::as_mut_ptr(&mut eigs),
112+
AsPtr::as_mut_ptr(vl.as_deref_mut().unwrap_or(&mut [])),
113+
&n,
114+
AsPtr::as_mut_ptr(vr.as_deref_mut().unwrap_or(&mut [])),
115+
&n,
116+
AsPtr::as_mut_ptr(&mut work_size),
117+
&(-1),
118+
AsPtr::as_mut_ptr(&mut rwork),
119+
&mut info,
120+
)
121+
};
122+
info.as_lapack_result()?;
123+
124+
let lwork = work_size[0].to_usize().unwrap();
125+
let work: Vec<MaybeUninit<c64>> = vec_uninit(lwork);
126+
Ok(Self {
127+
n,
128+
jobvl,
129+
jobvr,
130+
eigs: Some(eigs),
131+
eigs_re: None,
132+
eigs_im: None,
133+
rwork: Some(rwork),
134+
vl,
135+
vr,
136+
work,
137+
})
138+
}
139+
140+
fn calc<'work>(&'work mut self, a: &mut [c64]) -> Result<(&'work [c64], Option<&'work [c64]>)> {
141+
let lwork = self.work.len().to_i32().unwrap();
142+
let mut info = 0;
143+
unsafe {
144+
lapack_sys::zgeev_(
145+
self.jobvl.as_ptr(),
146+
self.jobvr.as_ptr(),
147+
&self.n,
148+
AsPtr::as_mut_ptr(a),
149+
&self.n,
150+
AsPtr::as_mut_ptr(self.eigs.as_mut().unwrap()),
151+
AsPtr::as_mut_ptr(self.vl.as_deref_mut().unwrap_or(&mut [])),
152+
&self.n,
153+
AsPtr::as_mut_ptr(self.vr.as_deref_mut().unwrap_or(&mut [])),
154+
&self.n,
155+
AsPtr::as_mut_ptr(&mut self.work),
156+
&lwork,
157+
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
158+
&mut info,
159+
)
160+
};
161+
info.as_lapack_result()?;
162+
163+
let eigs = self
164+
.eigs
165+
.as_ref()
166+
.map(|v| unsafe { v.slice_assume_init_ref() })
167+
.unwrap();
168+
169+
// Hermite conjugate
170+
if let Some(vl) = self.vl.as_mut() {
171+
for value in vl {
172+
let value = unsafe { value.assume_init_mut() };
173+
value.im = -value.im;
174+
}
175+
}
176+
let v = match (self.vl.as_ref(), self.vr.as_ref()) {
177+
(Some(v), None) | (None, Some(v)) => Some(unsafe { v.slice_assume_init_ref() }),
178+
(None, None) => None,
179+
_ => unreachable!(),
180+
};
181+
Ok((eigs, v))
182+
}
183+
184+
fn eval(mut self, a: &mut [c64]) -> Result<(Vec<c64>, Option<Vec<c64>>)> {
185+
let lwork = self.work.len().to_i32().unwrap();
186+
let mut info = 0;
187+
unsafe {
188+
lapack_sys::zgeev_(
189+
self.jobvl.as_ptr(),
190+
self.jobvr.as_ptr(),
191+
&self.n,
192+
AsPtr::as_mut_ptr(a),
193+
&self.n,
194+
AsPtr::as_mut_ptr(self.eigs.as_mut().unwrap()),
195+
AsPtr::as_mut_ptr(self.vl.as_deref_mut().unwrap_or(&mut [])),
196+
&self.n,
197+
AsPtr::as_mut_ptr(self.vr.as_deref_mut().unwrap_or(&mut [])),
198+
&self.n,
199+
AsPtr::as_mut_ptr(&mut self.work),
200+
&lwork,
201+
AsPtr::as_mut_ptr(self.rwork.as_mut().unwrap()),
202+
&mut info,
203+
)
204+
};
205+
info.as_lapack_result()?;
206+
let eigs = self.eigs.map(|v| unsafe { v.assume_init() }).unwrap();
207+
208+
// Hermite conjugate
209+
if let Some(vl) = self.vl.as_mut() {
210+
for value in vl {
211+
let value = unsafe { value.assume_init_mut() };
212+
value.im = -value.im;
213+
}
214+
}
215+
let v = match (self.vl, self.vr) {
216+
(Some(v), None) | (None, Some(v)) => Some(unsafe { v.assume_init() }),
217+
(None, None) => None,
218+
_ => unreachable!(),
219+
};
220+
Ok((eigs, v))
221+
}
222+
}
223+
60224
macro_rules! impl_eig_complex {
61225
($scalar:ty, $ev:path) => {
62226
impl Eig_ for $scalar {

0 commit comments

Comments
 (0)