1
- #include " lapack.h"
2
1
#include " nanobind/nanobind.h"
3
2
#include " xla/ffi/api/ffi.h"
4
3
4
+ using lapack_int = int ;
5
+
5
6
namespace nb = nanobind;
6
7
using namespace ::xla;
7
8
8
- nb::module_ cython_lapack = nb::module_::import_(" scipy.linalg.cython_lapack" );
9
-
10
9
inline constexpr auto LapackIntDtype = ffi::DataType::S32;
11
10
static_assert (std::is_same_v<::xla::ffi::NativeType<LapackIntDtype>, lapack_int>);
12
11
@@ -30,8 +29,8 @@ static ffi::Error SvdOnlyVtImpl(
30
29
MachineType* work, lapack_int const * lwork,
31
30
RealType* rwork,
32
31
lapack_int* iwork,
33
- lapack_int* info,
34
- size_t strlen ),
32
+ lapack_int* info
33
+ ),
35
34
void (char const * jobz,
36
35
lapack_int const * m, lapack_int const * n,
37
36
MachineType* A, lapack_int const * lda,
@@ -40,22 +39,39 @@ static ffi::Error SvdOnlyVtImpl(
40
39
MachineType* VT, lapack_int const * ldvt,
41
40
MachineType* work, lapack_int const * lwork,
42
41
lapack_int* iwork,
43
- lapack_int* info,
44
- size_t strlen )>;
42
+ lapack_int* info
43
+ )>;
45
44
46
45
FnSig* fn = nullptr ;
47
46
48
- if constexpr (dtype == ffi::DataType::F32) {
49
- fn = sgesdd_;
50
- }
51
- if constexpr (dtype == ffi::DataType::F64) {
52
- fn = dgesdd_;
53
- }
54
- if constexpr (dtype == ffi::DataType::C64) {
55
- fn = reinterpret_cast <FnSig*>(cgesdd_);
56
- }
57
- if constexpr (dtype == ffi::DataType::C128) {
58
- fn = reinterpret_cast <FnSig*>(zgesdd_);
47
+ try {
48
+ PyGILState_STATE state = PyGILState_Ensure ();
49
+
50
+ nb::module_ cython_lapack = nb::module_::import_ (" scipy.linalg.cython_lapack" );
51
+
52
+ nb::dict lapack_capi = cython_lapack.attr (" __pyx_capi__" );
53
+
54
+ auto get_lapack_ptr = [&](const char * name) {
55
+ return nb::cast<nb::capsule>(lapack_capi[name]).data ();
56
+ };
57
+
58
+ if constexpr (dtype == ffi::DataType::F32) {
59
+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" sgesdd" ));
60
+ }
61
+ if constexpr (dtype == ffi::DataType::F64) {
62
+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" dgesdd" ));
63
+ }
64
+ if constexpr (dtype == ffi::DataType::C64) {
65
+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" cgesdd" ));
66
+ }
67
+ if constexpr (dtype == ffi::DataType::C128) {
68
+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" zgesdd" ));
69
+ }
70
+
71
+ PyGILState_Release (state);
72
+ } catch (const nb::python_error &e) {
73
+ std::cerr << e.what () << std::endl;
74
+ throw ;
59
75
}
60
76
61
77
const auto lapack_int_max = std::numeric_limits<lapack_int>::max ();
@@ -96,12 +112,14 @@ static ffi::Error SvdOnlyVtImpl(
96
112
fn (&jobz, &x_rows_lapack, &x_cols_lapack, nullptr ,
97
113
&x_rows_lapack, nullptr , nullptr ,
98
114
&ldu, nullptr , &x_cols_lapack, &work_size,
99
- &lwork, nullptr , nullptr , info_data, 1 );
115
+ &lwork, nullptr , nullptr , info_data
116
+ );
100
117
} else {
101
118
fn (&jobz, &x_rows_lapack, &x_cols_lapack, nullptr ,
102
119
&x_rows_lapack, nullptr , nullptr ,
103
120
&ldu, nullptr , &x_cols_lapack,
104
- &work_size, &lwork, nullptr , info_data, 1 );
121
+ &work_size, &lwork, nullptr , info_data
122
+ );
105
123
}
106
124
107
125
if (*info_data != 0 ) {
@@ -131,12 +149,14 @@ static ffi::Error SvdOnlyVtImpl(
131
149
fn (&jobz, &x_rows_lapack, &x_cols_lapack, x_out_data,
132
150
&x_rows_lapack, s_data, nullptr ,
133
151
&ldu, vt_data, &x_cols_lapack, work.get (),
134
- &lwork, rwork.get (), iwork.get (), info_data, 1 );
152
+ &lwork, rwork.get (), iwork.get (), info_data
153
+ );
135
154
} else {
136
155
fn (&jobz, &x_rows_lapack, &x_cols_lapack, x_out_data,
137
156
&x_rows_lapack, s_data, nullptr ,
138
157
&ldu, vt_data, &x_cols_lapack,
139
- work.get (), &lwork, iwork.get (), info_data, 1 );
158
+ work.get (), &lwork, iwork.get (), info_data
159
+ );
140
160
}
141
161
142
162
if (*info_data != 0 ) {
@@ -164,31 +184,48 @@ static ffi::Error SvdOnlyVtQRImpl(
164
184
MachineType* VT, lapack_int const * ldvt,
165
185
MachineType* work, lapack_int const * lwork,
166
186
RealType* rwork,
167
- lapack_int* info,
168
- size_t strlen1, size_t strlen2 ),
187
+ lapack_int* info
188
+ ),
169
189
void (char const * jobu, char const * jobvt,
170
190
lapack_int const * m, lapack_int const * n,
171
191
MachineType* A, lapack_int const * lda,
172
192
RealType* S,
173
193
MachineType* U, lapack_int const * ldu,
174
194
MachineType* VT, lapack_int const * ldvt,
175
195
MachineType* work, lapack_int const * lwork,
176
- lapack_int* info,
177
- size_t strlen1, size_t strlen2 )>;
196
+ lapack_int* info
197
+ )>;
178
198
179
199
FnSig* fn = nullptr ;
180
200
181
- if constexpr (dtype == ffi::DataType::F32) {
182
- fn = sgesvd_;
183
- }
184
- if constexpr (dtype == ffi::DataType::F64) {
185
- fn = dgesvd_;
186
- }
187
- if constexpr (dtype == ffi::DataType::C64) {
188
- fn = reinterpret_cast <FnSig*>(cgesvd_);
189
- }
190
- if constexpr (dtype == ffi::DataType::C128) {
191
- fn = reinterpret_cast <FnSig*>(zgesvd_);
201
+ try {
202
+ PyGILState_STATE state = PyGILState_Ensure ();
203
+
204
+ nb::module_ cython_lapack = nb::module_::import_ (" scipy.linalg.cython_lapack" );
205
+
206
+ nb::dict lapack_capi = cython_lapack.attr (" __pyx_capi__" );
207
+
208
+ auto get_lapack_ptr = [&](const char * name) {
209
+ return nb::cast<nb::capsule>(lapack_capi[name]).data ();
210
+ };
211
+
212
+ if constexpr (dtype == ffi::DataType::F32) {
213
+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" sgesvd" ));
214
+ }
215
+ if constexpr (dtype == ffi::DataType::F64) {
216
+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" dgesvd" ));
217
+ }
218
+ if constexpr (dtype == ffi::DataType::C64) {
219
+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" cgesvd" ));
220
+ }
221
+ if constexpr (dtype == ffi::DataType::C128) {
222
+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" zgesvd" ));
223
+ }
224
+
225
+ PyGILState_Release (state);
226
+ } catch (const nb::python_error &e) {
227
+ std::cerr << e.what () << std::endl;
228
+ throw ;
192
229
}
193
230
194
231
const auto lapack_int_max = std::numeric_limits<lapack_int>::max ();
@@ -230,12 +267,14 @@ static ffi::Error SvdOnlyVtQRImpl(
230
267
fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, nullptr ,
231
268
&x_rows_lapack, nullptr , nullptr ,
232
269
&ldu, nullptr , &x_cols_lapack, &work_size,
233
- &lwork, nullptr , info_data, 1 , 1 );
270
+ &lwork, nullptr , info_data
271
+ );
234
272
} else {
235
273
fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, nullptr ,
236
274
&x_rows_lapack, nullptr , nullptr ,
237
275
&ldu, nullptr , &x_cols_lapack,
238
- &work_size, &lwork, info_data, 1 , 1 );
276
+ &work_size, &lwork, info_data
277
+ );
239
278
}
240
279
241
280
if (*info_data != 0 ) {
@@ -262,12 +301,14 @@ static ffi::Error SvdOnlyVtQRImpl(
262
301
fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
263
302
&x_rows_lapack, s_data, nullptr ,
264
303
&ldu, nullptr , &x_cols_lapack, work.get (),
265
- &lwork, rwork.get (), info_data, 1 , 1 );
304
+ &lwork, rwork.get (), info_data
305
+ );
266
306
} else {
267
307
fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
268
308
&x_rows_lapack, s_data, nullptr ,
269
309
&ldu, nullptr , &x_cols_lapack,
270
- work.get (), &lwork, info_data, 1 , 1 );
310
+ work.get (), &lwork, info_data
311
+ );
271
312
}
272
313
273
314
if (*info_data != 0 ) {
@@ -333,12 +374,12 @@ nb::capsule EncapsulateFfiCall(T *fn) {
333
374
}
334
375
335
376
NB_MODULE (_svd_only_vt, m) {
336
- m.def (" svd_only_vt_f32" , []() { return EncapsulateFfiCall (svd_only_vt_f32); });
337
- m.def (" svd_only_vt_f64" , []() { return EncapsulateFfiCall (svd_only_vt_f64); });
338
- m.def (" svd_only_vt_c64" , []() { return EncapsulateFfiCall (svd_only_vt_c64); });
339
- m.def (" svd_only_vt_c128" , []() { return EncapsulateFfiCall (svd_only_vt_c128); });
340
- m.def (" svd_only_vt_qr_f32" , []() { return EncapsulateFfiCall (svd_only_vt_qr_f32); });
341
- m.def (" svd_only_vt_qr_f64" , []() { return EncapsulateFfiCall (svd_only_vt_qr_f64); });
342
- m.def (" svd_only_vt_qr_c64" , []() { return EncapsulateFfiCall (svd_only_vt_qr_c64); });
343
- m.def (" svd_only_vt_qr_c128" , []() { return EncapsulateFfiCall (svd_only_vt_qr_c128); });
377
+ m.def (" svd_only_vt_f32" , []() { return EncapsulateFfiCall (svd_only_vt_f32); }, nb::call_guard<nb::gil_scoped_release>() );
378
+ m.def (" svd_only_vt_f64" , []() { return EncapsulateFfiCall (svd_only_vt_f64); }, nb::call_guard<nb::gil_scoped_release>() );
379
+ m.def (" svd_only_vt_c64" , []() { return EncapsulateFfiCall (svd_only_vt_c64); }, nb::call_guard<nb::gil_scoped_release>() );
380
+ m.def (" svd_only_vt_c128" , []() { return EncapsulateFfiCall (svd_only_vt_c128); }, nb::call_guard<nb::gil_scoped_release>() );
381
+ m.def (" svd_only_vt_qr_f32" , []() { return EncapsulateFfiCall (svd_only_vt_qr_f32); }, nb::call_guard<nb::gil_scoped_release>() );
382
+ m.def (" svd_only_vt_qr_f64" , []() { return EncapsulateFfiCall (svd_only_vt_qr_f64); }, nb::call_guard<nb::gil_scoped_release>() );
383
+ m.def (" svd_only_vt_qr_c64" , []() { return EncapsulateFfiCall (svd_only_vt_qr_c64); }, nb::call_guard<nb::gil_scoped_release>() );
384
+ m.def (" svd_only_vt_qr_c128" , []() { return EncapsulateFfiCall (svd_only_vt_qr_c128); }, nb::call_guard<nb::gil_scoped_release>() );
344
385
}
0 commit comments