Skip to content

Commit 7f742ab

Browse files
committed
WIP: LeastSquaresWork
1 parent ac2f7bc commit 7f742ab

File tree

1 file changed

+91
-1
lines changed

1 file changed

+91
-1
lines changed

lax/src/least_squares.rs

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ pub struct LeastSquaresOwned<A: Scalar> {
1212
pub rank: i32,
1313
}
1414

15+
/// Result of LeastSquares
16+
pub struct LeastSquaresRef<'work, A: Scalar> {
17+
/// singular values
18+
pub singular_values: &'work [A::Real],
19+
/// The rank of the input matrix A
20+
pub rank: i32,
21+
}
22+
1523
#[cfg_attr(doc, katexit::katexit)]
1624
/// Solve least square problem
1725
pub trait LeastSquaresSvdDivideConquer_: Scalar {
@@ -29,7 +37,89 @@ pub trait LeastSquaresSvdDivideConquer_: Scalar {
2937
a: &mut [Self],
3038
b_layout: MatrixLayout,
3139
b: &mut [Self],
32-
) -> Result<LeastSquaresOutput<Self>>;
40+
) -> Result<LeastSquaresOwned<Self>>;
41+
}
42+
43+
pub struct LeastSquaresWork<T: Scalar> {
44+
pub a_layout: MatrixLayout,
45+
pub b_layout: MatrixLayout,
46+
pub singular_values: Vec<MaybeUninit<T::Real>>,
47+
pub work: Vec<MaybeUninit<T>>,
48+
pub iwork: Vec<MaybeUninit<i32>>,
49+
pub rwork: Option<Vec<MaybeUninit<T::Real>>>,
50+
}
51+
52+
pub trait LeastSquaresWorkImpl: Sized {
53+
type Elem: Scalar;
54+
fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result<Self>;
55+
fn calc(&mut self, a: &mut [Self], b: &mut [Self]) -> Result<LeastSquaresRef<Self::Elem>>;
56+
fn eval(self, a: &mut [Self], b: &mut [Self]) -> Result<LeastSquaresOwned<Self::Elem>>;
57+
}
58+
59+
impl LeastSquaresWorkImpl for LeastSquaresWork<c64> {
60+
type Elem = c64;
61+
62+
fn new(a_layout: MatrixLayout, b_layout: MatrixLayout) -> Result<Self> {
63+
let (m, n) = a_layout.size();
64+
let (m_, nrhs) = b_layout.size();
65+
let k = m.min(n);
66+
assert!(m_ >= m);
67+
68+
let rcond = -1.;
69+
let mut singular_values = vec_uninit(k as usize);
70+
let mut rank: i32 = 0;
71+
72+
// eval work size
73+
let mut info = 0;
74+
let mut work_size = [Self::Elem::zero()];
75+
let mut iwork_size = [0];
76+
let mut rwork = [<Self::Elem as Scalar>::Real::zero()];
77+
unsafe {
78+
lapack_sys::zgelsd_(
79+
&m,
80+
&n,
81+
&nrhs,
82+
std::ptr::null_mut(),
83+
&a_layout.lda(),
84+
std::ptr::null_mut(),
85+
&b_layout.lda(),
86+
AsPtr::as_mut_ptr(&mut singular_values),
87+
&rcond,
88+
&mut rank,
89+
AsPtr::as_mut_ptr(&mut work_size),
90+
&(-1),
91+
AsPtr::as_mut_ptr(&mut rwork),
92+
iwork_size.as_mut_ptr(),
93+
&mut info,
94+
)
95+
};
96+
info.as_lapack_result()?;
97+
98+
let lwork = work_size[0].to_usize().unwrap();
99+
let liwork = iwork_size[0].to_usize().unwrap();
100+
let lrwork = rwork[0].to_usize().unwrap();
101+
102+
let work = vec_uninit(lwork);
103+
let iwork = vec_uninit(liwork);
104+
let rwork = vec_uninit(lrwork);
105+
106+
Ok(LeastSquaresWork {
107+
a_layout,
108+
b_layout,
109+
work,
110+
iwork,
111+
rwork: Some(rwork),
112+
singular_values,
113+
})
114+
}
115+
116+
fn calc(&mut self, a: &mut [Self], b: &mut [Self]) -> Result<LeastSquaresRef<Self::Elem>> {
117+
todo!()
118+
}
119+
120+
fn eval(self, a: &mut [Self], b: &mut [Self]) -> Result<LeastSquaresOwned<Self::Elem>> {
121+
todo!()
122+
}
33123
}
34124

35125
macro_rules! impl_least_squares {

0 commit comments

Comments
 (0)