cufinufft_wrapper.cc
#include "cufinufft_wrapper.h"
#include <complex>
#include "cufinufft.h"
namespace jax_finufft {
namespace gpu {
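// Explicit specializations routing each templated entry point to the
// single-precision (cufinufftf_*) or double-precision (cufinufft_*) cuFINUFFT C API.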
template <>
void default_opts<float>(cufinufft_opts* opts) {
  cufinufft_default_opts(opts);
}

template <>
void default_opts<double>(cufinufft_opts* opts) {
  cufinufft_default_opts(opts);
}
template <>
void update_opts<float>(cufinufft_opts* opts, int dim, cudaStream_t stream) {
  opts->gpu_stream = stream;
}

template <>
void update_opts<double>(cufinufft_opts* opts, int dim, cudaStream_t stream) {
  opts->gpu_stream = stream;

  // double precision in 3D blows out shared memory.
  // Fall back to a slower, non-shared memory algorithm
  // https://github.com/flatironinstitute/cufinufft/issues/58
  if (dim > 2) {
    opts->gpu_method = 1;
  }
}
template <>
int makeplan<float>(int type, int dim, const int64_t nmodes[3], int iflag, int ntr, float eps,
                    typename plan_type<float>::type* plan, cufinufft_opts* opts) {
  int64_t tmp_nmodes[3] = {nmodes[0], nmodes[1],
                           nmodes[2]};  // TODO: use const in cufinufftf_makeplan API
  return cufinufftf_makeplan(type, dim, tmp_nmodes, iflag, ntr, eps, plan, opts);
}

template <>
int makeplan<double>(int type, int dim, const int64_t nmodes[3], int iflag, int ntr, double eps,
                     typename plan_type<double>::type* plan, cufinufft_opts* opts) {
  int64_t tmp_nmodes[3] = {nmodes[0], nmodes[1], nmodes[2]};
  return cufinufft_makeplan(type, dim, tmp_nmodes, iflag, ntr, eps, plan, opts);
}
template <>
int setpts<float>(typename plan_type<float>::type plan, int64_t M, float* x, float* y, float* z,
                  int64_t N, float* s, float* t, float* u) {
  return cufinufftf_setpts(plan, M, x, y, z, N, s, t, u);
}

template <>
int setpts<double>(typename plan_type<double>::type plan, int64_t M, double* x, double* y,
                   double* z, int64_t N, double* s, double* t, double* u) {
  return cufinufft_setpts(plan, M, x, y, z, N, s, t, u);
}
template <>
int execute<float>(typename plan_type<float>::type plan, std::complex<float>* c,
                   std::complex<float>* f) {
  cuFloatComplex* _c = reinterpret_cast<cuFloatComplex*>(c);
  cuFloatComplex* _f = reinterpret_cast<cuFloatComplex*>(f);
  return cufinufftf_execute(plan, _c, _f);
}

template <>
int execute<double>(typename plan_type<double>::type plan, std::complex<double>* c,
                    std::complex<double>* f) {
  cuDoubleComplex* _c = reinterpret_cast<cuDoubleComplex*>(c);
  cuDoubleComplex* _f = reinterpret_cast<cuDoubleComplex*>(f);
  return cufinufft_execute(plan, _c, _f);
}
template <>
void destroy<float>(typename plan_type<float>::type plan) {
  cufinufftf_destroy(plan);
}

template <>
void destroy<double>(typename plan_type<double>::type plan) {
  cufinufft_destroy(plan);
}
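
// Coordinate accessors: there is no y coordinate in 1D, so y_index<1, T> returns a
// null pointer, while in 3D the z coordinate is indexed directly. The remaining cases
// are expected to be covered by the generic templates in cufinufft_wrapper.h.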
template <>
double* y_index<1, double>(double* y, int64_t index) {
  return NULL;
}

template <>
float* y_index<1, float>(float* y, int64_t index) {
  return NULL;
}

template <>
double* z_index<3, double>(double* z, int64_t index) {
  return &(z[index]);
}

template <>
float* z_index<3, float>(float* z, int64_t index) {
  return &(z[index]);
}
} // namespace gpu
} // namespace jax_finufft
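
For context, here is a minimal sketch of how these wrappers are typically chained by calling code. It is illustrative only and not part of cufinufft_wrapper.cc: the function run_type1_1d and the d_x, d_c, d_f device buffers are hypothetical, and plan_type<float>::type is assumed to be the single-precision plan handle declared in cufinufft_wrapper.h.

#include <complex>
#include <cstdint>
#include <cuda_runtime.h>
#include "cufinufft.h"
#include "cufinufft_wrapper.h"

// Hypothetical caller: 1D type-1 single-precision transform of M nonuniform points
// onto n_modes Fourier modes. d_x, d_c, d_f are pre-allocated device buffers.
int run_type1_1d(int64_t M, int64_t n_modes, float* d_x, std::complex<float>* d_c,
                 std::complex<float>* d_f, cudaStream_t stream) {
  using namespace jax_finufft::gpu;

  cufinufft_opts opts;
  default_opts<float>(&opts);
  update_opts<float>(&opts, /*dim=*/1, stream);  // attach the CUDA stream

  int64_t nmodes[3] = {n_modes, 1, 1};
  plan_type<float>::type plan;
  int ierr = makeplan<float>(/*type=*/1, /*dim=*/1, nmodes, /*iflag=*/1, /*ntr=*/1,
                             /*eps=*/1e-6f, &plan, &opts);
  if (ierr != 0) return ierr;

  // 1D: only x is used; unused coordinates and target frequencies are null.
  ierr = setpts<float>(plan, M, d_x, nullptr, nullptr, 0, nullptr, nullptr, nullptr);
  if (ierr == 0) ierr = execute<float>(plan, d_c, d_f);

  destroy<float>(plan);
  return ierr;
}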