jax_finufft_cpu.cc
// This file defines the Python interface to the XLA custom call implemented on the CPU.
// It is exposed as a standard nanobind module defining "capsule" objects containing our
// methods. For simplicity, we export a separate capsule for each supported dtype.
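//
// A sketch of how the Python side typically consumes these registrations (illustrative
// only; the actual wrapper code and import path in the package may differ):
//
//   from jax.lib import xla_client
//   import jax_finufft_cpu  # the extension module built from this file
//
//   for name, fn in jax_finufft_cpu.registrations().items():
//       xla_client.register_custom_call_target(name, fn, platform="cpu")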
#include "jax_finufft_cpu.h"
#include "nanobind_kernel_helpers.h"
using namespace jax_finufft;
using namespace jax_finufft::cpu;
namespace nb = nanobind;
namespace {
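
// Shared implementation for both transform types: unpack the descriptor, build a single
// FINUFFT plan, then loop over the outer batch dimension, calling setpts/execute once per
// batch element before destroying the plan.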
template <int ndim, typename T>
void run_nufft(int type, void *desc_in, T *x, T *y, T *z, std::complex<T> *c, std::complex<T> *F) {
  const descriptor<T> *desc = unpack_descriptor<descriptor<T>>(
      reinterpret_cast<const char *>(desc_in), sizeof(descriptor<T>));
  int64_t n_k = 1;
  for (int d = 0; d < ndim; ++d) n_k *= desc->n_k[d];

  finufft_opts opts = desc->opts;
  typename plan_type<T>::type plan;
  makeplan<T>(type, ndim, const_cast<int64_t *>(desc->n_k), desc->iflag, desc->n_transf, desc->eps,
              &plan, &opts);
  for (int64_t index = 0; index < desc->n_tot; ++index) {
    int64_t i = index * desc->n_j;
    int64_t j = i * desc->n_transf;
    int64_t k = index * n_k * desc->n_transf;

    setpts<T>(plan, desc->n_j, &(x[i]), y_index<ndim, T>(y, i), z_index<ndim, T>(z, i), 0, NULL,
              NULL, NULL);
    execute<T>(plan, &c[j], &F[k]);
  }

  destroy<T>(plan);
}
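
// XLA CPU custom call wrapper for the type 1 (nonuniform to uniform) transform. The input
// buffers are the packed descriptor, the source strengths, and the nonuniform coordinates
// (x, plus y and z when the dimension requires them); the Fourier modes are written to out.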
template <int ndim, typename T>
void nufft1(void *out, void **in) {
  std::complex<T> *c = reinterpret_cast<std::complex<T> *>(in[1]);
  T *x = reinterpret_cast<T *>(in[2]);
  T *y = NULL;
  T *z = NULL;
  if (ndim > 1) {
    y = reinterpret_cast<T *>(in[3]);
  }
  if (ndim > 2) {
    z = reinterpret_cast<T *>(in[4]);
  }
  std::complex<T> *F = reinterpret_cast<std::complex<T> *>(out);

  run_nufft<ndim, T>(1, in[0], x, y, z, c, F);
}
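
// XLA CPU custom call wrapper for the type 2 (uniform to nonuniform) transform. Same input
// layout as nufft1, except that in[1] holds the Fourier modes and the output buffer
// receives the values at the nonuniform points.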
template <int ndim, typename T>
void nufft2(void *out, void **in) {
  std::complex<T> *F = reinterpret_cast<std::complex<T> *>(in[1]);
  T *x = reinterpret_cast<T *>(in[2]);
  T *y = NULL;
  T *z = NULL;
  if (ndim > 1) {
    y = reinterpret_cast<T *>(in[3]);
  }
  if (ndim > 2) {
    z = reinterpret_cast<T *>(in[4]);
  }
  std::complex<T> *c = reinterpret_cast<std::complex<T> *>(out);

  run_nufft<ndim, T>(2, in[0], x, y, z, c, F);
}
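
// Packs the problem sizes, tolerance, and FINUFFT options into an opaque byte string that
// XLA passes back to the kernels above as their descriptor.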
template <typename T>
nb::bytes build_descriptor(T eps, int iflag, int64_t n_tot, int n_transf, int64_t n_j,
                           int64_t n_k_1, int64_t n_k_2, int64_t n_k_3, finufft_opts opts) {
  return pack_descriptor(
      descriptor<T>{eps, iflag, n_tot, n_transf, n_j, {n_k_1, n_k_2, n_k_3}, opts});
}
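
// Maps operation names to capsules wrapping the kernels above. The names encode the
// dimension and transform type (e.g. "nufft2d1" is the 2-D type 1 transform), with a
// trailing "f" marking the single-precision variants.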
nb::dict Registrations() {
  nb::dict dict;

  dict["nufft1d1f"] = encapsulate_function(nufft1<1, float>);
  dict["nufft1d2f"] = encapsulate_function(nufft2<1, float>);
  dict["nufft2d1f"] = encapsulate_function(nufft1<2, float>);
  dict["nufft2d2f"] = encapsulate_function(nufft2<2, float>);
  dict["nufft3d1f"] = encapsulate_function(nufft1<3, float>);
  dict["nufft3d2f"] = encapsulate_function(nufft2<3, float>);

  dict["nufft1d1"] = encapsulate_function(nufft1<1, double>);
  dict["nufft1d2"] = encapsulate_function(nufft2<1, double>);
  dict["nufft2d1"] = encapsulate_function(nufft1<2, double>);
  dict["nufft2d2"] = encapsulate_function(nufft2<2, double>);
  dict["nufft3d1"] = encapsulate_function(nufft1<3, double>);
  dict["nufft3d2"] = encapsulate_function(nufft2<3, double>);

  return dict;
}
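
// Module definition: exposes the registrations, the descriptor builders, the FFTW planner
// flags, and a small wrapper around finufft_opts so the Python layer can tune the library.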
NB_MODULE(jax_finufft_cpu, m) {
  m.def("registrations", &Registrations);
  m.def("build_descriptorf", &build_descriptor<float>);
  m.def("build_descriptor", &build_descriptor<double>);

  m.def("_omp_compile_check", []() {
#ifdef FINUFFT_USE_OPENMP
    return true;
#else
    return false;
#endif
  });

  m.attr("FFTW_ESTIMATE") = nb::int_(FFTW_ESTIMATE);
  m.attr("FFTW_MEASURE") = nb::int_(FFTW_MEASURE);
  m.attr("FFTW_PATIENT") = nb::int_(FFTW_PATIENT);
  m.attr("FFTW_EXHAUSTIVE") = nb::int_(FFTW_EXHAUSTIVE);
  m.attr("FFTW_WISDOM_ONLY") = nb::int_(FFTW_WISDOM_ONLY);

  nb::class_<finufft_opts> opts(m, "FinufftOpts");
  opts.def("__init__",
           [](finufft_opts *self, bool modeord, bool chkbnds, int debug, int spread_debug,
              bool showwarn, int nthreads, int fftw, int spread_sort, bool spread_kerevalmeth,
              bool spread_kerpad, double upsampfac, int spread_thread, int maxbatchsize,
              int spread_nthr_atomic, int spread_max_sp_size) {
             new (self) finufft_opts;
             default_opts<double>(self);

             self->modeord = int(modeord);
             self->chkbnds = int(chkbnds);
             self->debug = debug;
             self->spread_debug = spread_debug;
             self->showwarn = int(showwarn);
             self->nthreads = nthreads;
             self->fftw = fftw;
             self->spread_sort = spread_sort;
             self->spread_kerevalmeth = int(spread_kerevalmeth);
             self->spread_kerpad = int(spread_kerpad);
             self->upsampfac = upsampfac;
             self->spread_thread = int(spread_thread);
             self->maxbatchsize = maxbatchsize;
             self->spread_nthr_atomic = spread_nthr_atomic;
             self->spread_max_sp_size = spread_max_sp_size;
           });
}
} // namespace