Skip to content

Commit ea1f57c

Browse files
committed
Use lapack functions from the scipy cython lapack interface
1 parent 89d6f23 commit ea1f57c

File tree

1 file changed

+90
-49
lines changed

1 file changed

+90
-49
lines changed

varipeps/utils/extensions/svd_ffi.cpp

Lines changed: 90 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,11 @@
1-
#include "lapack.h"
21
#include "nanobind/nanobind.h"
32
#include "xla/ffi/api/ffi.h"
43

4+
using lapack_int = int;
5+
56
namespace nb = nanobind;
67
using namespace ::xla;
78

8-
nb::module_ cython_lapack = nb::module_::import_("scipy.linalg.cython_lapack");
9-
109
inline constexpr auto LapackIntDtype = ffi::DataType::S32;
1110
static_assert(std::is_same_v<::xla::ffi::NativeType<LapackIntDtype>, lapack_int>);
1211

@@ -30,8 +29,8 @@ static ffi::Error SvdOnlyVtImpl(
3029
MachineType* work, lapack_int const* lwork,
3130
RealType* rwork,
3231
lapack_int* iwork,
33-
lapack_int* info,
34-
size_t strlen),
32+
lapack_int* info
33+
),
3534
void(char const* jobz,
3635
lapack_int const* m, lapack_int const* n,
3736
MachineType* A, lapack_int const* lda,
@@ -40,22 +39,39 @@ static ffi::Error SvdOnlyVtImpl(
4039
MachineType* VT, lapack_int const* ldvt,
4140
MachineType* work, lapack_int const* lwork,
4241
lapack_int* iwork,
43-
lapack_int* info,
44-
size_t strlen)>;
42+
lapack_int* info
43+
)>;
4544

4645
FnSig* fn = nullptr;
4746

48-
if constexpr (dtype == ffi::DataType::F32) {
49-
fn = sgesdd_;
50-
}
51-
if constexpr (dtype == ffi::DataType::F64) {
52-
fn = dgesdd_;
53-
}
54-
if constexpr (dtype == ffi::DataType::C64) {
55-
fn = reinterpret_cast<FnSig*>(cgesdd_);
56-
}
57-
if constexpr (dtype == ffi::DataType::C128) {
58-
fn = reinterpret_cast<FnSig*>(zgesdd_);
47+
try {
48+
PyGILState_STATE state = PyGILState_Ensure();
49+
50+
nb::module_ cython_lapack = nb::module_::import_("scipy.linalg.cython_lapack");
51+
52+
nb::dict lapack_capi = cython_lapack.attr("__pyx_capi__");
53+
54+
auto get_lapack_ptr = [&](const char* name) {
55+
return nb::cast<nb::capsule>(lapack_capi[name]).data();
56+
};
57+
58+
if constexpr (dtype == ffi::DataType::F32) {
59+
fn = reinterpret_cast<FnSig*>(get_lapack_ptr("sgesdd"));
60+
}
61+
if constexpr (dtype == ffi::DataType::F64) {
62+
fn = reinterpret_cast<FnSig*>(get_lapack_ptr("dgesdd"));
63+
}
64+
if constexpr (dtype == ffi::DataType::C64) {
65+
fn = reinterpret_cast<FnSig*>(get_lapack_ptr("cgesdd"));
66+
}
67+
if constexpr (dtype == ffi::DataType::C128) {
68+
fn = reinterpret_cast<FnSig*>(get_lapack_ptr("zgesdd"));
69+
}
70+
71+
PyGILState_Release(state);
72+
} catch (const nb::python_error &e) {
73+
std::cerr << e.what() << std::endl;
74+
throw;
5975
}
6076

6177
const auto lapack_int_max = std::numeric_limits<lapack_int>::max();
@@ -96,12 +112,14 @@ static ffi::Error SvdOnlyVtImpl(
96112
fn(&jobz, &x_rows_lapack, &x_cols_lapack, nullptr,
97113
&x_rows_lapack, nullptr, nullptr,
98114
&ldu, nullptr, &x_cols_lapack, &work_size,
99-
&lwork, nullptr, nullptr, info_data, 1);
115+
&lwork, nullptr, nullptr, info_data
116+
);
100117
} else {
101118
fn(&jobz, &x_rows_lapack, &x_cols_lapack, nullptr,
102119
&x_rows_lapack, nullptr, nullptr,
103120
&ldu, nullptr, &x_cols_lapack,
104-
&work_size, &lwork, nullptr, info_data, 1);
121+
&work_size, &lwork, nullptr, info_data
122+
);
105123
}
106124

107125
if (*info_data != 0) {
@@ -131,12 +149,14 @@ static ffi::Error SvdOnlyVtImpl(
131149
fn(&jobz, &x_rows_lapack, &x_cols_lapack, x_out_data,
132150
&x_rows_lapack, s_data, nullptr,
133151
&ldu, vt_data, &x_cols_lapack, work.get(),
134-
&lwork, rwork.get(), iwork.get(), info_data, 1);
152+
&lwork, rwork.get(), iwork.get(), info_data
153+
);
135154
} else {
136155
fn(&jobz, &x_rows_lapack, &x_cols_lapack, x_out_data,
137156
&x_rows_lapack, s_data, nullptr,
138157
&ldu, vt_data, &x_cols_lapack,
139-
work.get(), &lwork, iwork.get(), info_data, 1);
158+
work.get(), &lwork, iwork.get(), info_data
159+
);
140160
}
141161

142162
if (*info_data != 0) {
@@ -164,31 +184,48 @@ static ffi::Error SvdOnlyVtQRImpl(
164184
MachineType* VT, lapack_int const* ldvt,
165185
MachineType* work, lapack_int const* lwork,
166186
RealType* rwork,
167-
lapack_int* info,
168-
size_t strlen1, size_t strlen2),
187+
lapack_int* info
188+
),
169189
void(char const* jobu, char const* jobvt,
170190
lapack_int const* m, lapack_int const* n,
171191
MachineType* A, lapack_int const* lda,
172192
RealType* S,
173193
MachineType* U, lapack_int const* ldu,
174194
MachineType* VT, lapack_int const* ldvt,
175195
MachineType* work, lapack_int const* lwork,
176-
lapack_int* info,
177-
size_t strlen1, size_t strlen2)>;
196+
lapack_int* info
197+
)>;
178198

179199
FnSig* fn = nullptr;
180200

181-
if constexpr (dtype == ffi::DataType::F32) {
182-
fn = sgesvd_;
183-
}
184-
if constexpr (dtype == ffi::DataType::F64) {
185-
fn = dgesvd_;
186-
}
187-
if constexpr (dtype == ffi::DataType::C64) {
188-
fn = reinterpret_cast<FnSig*>(cgesvd_);
189-
}
190-
if constexpr (dtype == ffi::DataType::C128) {
191-
fn = reinterpret_cast<FnSig*>(zgesvd_);
201+
try {
202+
PyGILState_STATE state = PyGILState_Ensure();
203+
204+
nb::module_ cython_lapack = nb::module_::import_("scipy.linalg.cython_lapack");
205+
206+
nb::dict lapack_capi = cython_lapack.attr("__pyx_capi__");
207+
208+
auto get_lapack_ptr = [&](const char* name) {
209+
return nb::cast<nb::capsule>(lapack_capi[name]).data();
210+
};
211+
212+
if constexpr (dtype == ffi::DataType::F32) {
213+
fn = reinterpret_cast<FnSig*>(get_lapack_ptr("sgesvd"));
214+
}
215+
if constexpr (dtype == ffi::DataType::F64) {
216+
fn = reinterpret_cast<FnSig*>(get_lapack_ptr("dgesvd"));
217+
}
218+
if constexpr (dtype == ffi::DataType::C64) {
219+
fn = reinterpret_cast<FnSig*>(get_lapack_ptr("cgesvd"));
220+
}
221+
if constexpr (dtype == ffi::DataType::C128) {
222+
fn = reinterpret_cast<FnSig*>(get_lapack_ptr("zgesvd"));
223+
}
224+
225+
PyGILState_Release(state);
226+
} catch (const nb::python_error &e) {
227+
std::cerr << e.what() << std::endl;
228+
throw;
192229
}
193230

194231
const auto lapack_int_max = std::numeric_limits<lapack_int>::max();
@@ -230,12 +267,14 @@ static ffi::Error SvdOnlyVtQRImpl(
230267
fn(&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, nullptr,
231268
&x_rows_lapack, nullptr, nullptr,
232269
&ldu, nullptr, &x_cols_lapack, &work_size,
233-
&lwork, nullptr, info_data, 1, 1);
270+
&lwork, nullptr, info_data
271+
);
234272
} else {
235273
fn(&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, nullptr,
236274
&x_rows_lapack, nullptr, nullptr,
237275
&ldu, nullptr, &x_cols_lapack,
238-
&work_size, &lwork, info_data, 1, 1);
276+
&work_size, &lwork, info_data
277+
);
239278
}
240279

241280
if (*info_data != 0) {
@@ -262,12 +301,14 @@ static ffi::Error SvdOnlyVtQRImpl(
262301
fn(&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
263302
&x_rows_lapack, s_data, nullptr,
264303
&ldu, nullptr, &x_cols_lapack, work.get(),
265-
&lwork, rwork.get(), info_data, 1, 1);
304+
&lwork, rwork.get(), info_data
305+
);
266306
} else {
267307
fn(&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
268308
&x_rows_lapack, s_data, nullptr,
269309
&ldu, nullptr, &x_cols_lapack,
270-
work.get(), &lwork, info_data, 1, 1);
310+
work.get(), &lwork, info_data
311+
);
271312
}
272313

273314
if (*info_data != 0) {
@@ -333,12 +374,12 @@ nb::capsule EncapsulateFfiCall(T *fn) {
333374
}
334375

335376
NB_MODULE(_svd_only_vt, m) {
336-
m.def("svd_only_vt_f32", []() { return EncapsulateFfiCall(svd_only_vt_f32); });
337-
m.def("svd_only_vt_f64", []() { return EncapsulateFfiCall(svd_only_vt_f64); });
338-
m.def("svd_only_vt_c64", []() { return EncapsulateFfiCall(svd_only_vt_c64); });
339-
m.def("svd_only_vt_c128", []() { return EncapsulateFfiCall(svd_only_vt_c128); });
340-
m.def("svd_only_vt_qr_f32", []() { return EncapsulateFfiCall(svd_only_vt_qr_f32); });
341-
m.def("svd_only_vt_qr_f64", []() { return EncapsulateFfiCall(svd_only_vt_qr_f64); });
342-
m.def("svd_only_vt_qr_c64", []() { return EncapsulateFfiCall(svd_only_vt_qr_c64); });
343-
m.def("svd_only_vt_qr_c128", []() { return EncapsulateFfiCall(svd_only_vt_qr_c128); });
377+
m.def("svd_only_vt_f32", []() { return EncapsulateFfiCall(svd_only_vt_f32); }, nb::call_guard<nb::gil_scoped_release>());
378+
m.def("svd_only_vt_f64", []() { return EncapsulateFfiCall(svd_only_vt_f64); }, nb::call_guard<nb::gil_scoped_release>());
379+
m.def("svd_only_vt_c64", []() { return EncapsulateFfiCall(svd_only_vt_c64); }, nb::call_guard<nb::gil_scoped_release>());
380+
m.def("svd_only_vt_c128", []() { return EncapsulateFfiCall(svd_only_vt_c128); }, nb::call_guard<nb::gil_scoped_release>());
381+
m.def("svd_only_vt_qr_f32", []() { return EncapsulateFfiCall(svd_only_vt_qr_f32); }, nb::call_guard<nb::gil_scoped_release>());
382+
m.def("svd_only_vt_qr_f64", []() { return EncapsulateFfiCall(svd_only_vt_qr_f64); }, nb::call_guard<nb::gil_scoped_release>());
383+
m.def("svd_only_vt_qr_c64", []() { return EncapsulateFfiCall(svd_only_vt_qr_c64); }, nb::call_guard<nb::gil_scoped_release>());
384+
m.def("svd_only_vt_qr_c128", []() { return EncapsulateFfiCall(svd_only_vt_qr_c128); }, nb::call_guard<nb::gil_scoped_release>());
344385
}

0 commit comments

Comments
 (0)