-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcufinufft_wrapper.h
71 lines (49 loc) · 1.45 KB
/
cufinufft_wrapper.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#ifndef _CUFINUFFT_WRAPPER_H_
#define _CUFINUFFT_WRAPPER_H_
#include <complex>
#include <cstdint>

#include <cuda_runtime.h>

#include "cufinufft.h"
namespace jax_finufft {
namespace gpu {

// Maps a floating-point scalar type T to the matching cuFINUFFT plan handle
// type: double -> cufinufft_plan, float -> cufinufftf_plan. Only these two
// specializations exist; for any other T, plan_type<T> stays incomplete, so
// using it is a compile-time error.
template <typename T>
struct plan_type;

template <>
struct plan_type<double> {
  typedef cufinufft_plan type;
};

template <>
struct plan_type<float> {
  typedef cufinufftf_plan type;
};

// ---------------------------------------------------------------------------
// Precision-templated wrappers over the cuFINUFFT API. The functions below are
// only declared in this header; their definitions are not visible here
// (presumably explicit specializations for float and double in a separate
// translation unit — TODO confirm).
// ---------------------------------------------------------------------------

// Initializes *opts with default options.
// NOTE(review): presumably forwards to cufinufft_default_opts — confirm.
template <typename T>
void default_opts(cufinufft_opts* opts);

// Adjusts *opts for a transform of dimension `dim` that will run on `stream`.
// NOTE(review): the exact fields modified are defined out of line — confirm.
template <typename T>
void update_opts(cufinufft_opts* opts, int dim, cudaStream_t stream);

// Creates a transform plan and stores the handle through `plan`.
//   type   : transform type.
//   dim    : number of spatial dimensions.
//   nmodes : per-dimension mode counts; always three entries (presumably the
//            entries beyond `dim` are ignored — confirm).
//   iflag, ntr, eps : presumably the exponent sign flag, number of transforms,
//            and requested tolerance, following cuFINUFFT naming — confirm.
//   opts   : options used to build the plan.
// Returns an int status (presumably a cuFINUFFT error code, 0 on success).
template <typename T>
int makeplan(int type, int dim, const int64_t nmodes[3], int iflag, int ntr, T eps,
             typename plan_type<T>::type* plan, cufinufft_opts* opts);

// Sets the nonuniform points on `plan`: M points with coordinates x/y/z, and
// N points with coordinates s/t/u (presumably only meaningful for type-3
// transforms — confirm). Coordinate arrays for unused dimensions are
// presumably passed as null, e.g. via y_index / z_index below.
// Returns an int status (presumably a cuFINUFFT error code).
template <typename T>
int setpts(typename plan_type<T>::type plan, int64_t M, T* x, T* y, T* z, int64_t N, T* s, T* t,
           T* u);

// Executes the planned transform on the data arrays `c` and `f`.
// NOTE(review): which array is input vs. output presumably depends on the
// transform type — confirm against the definition.
// Returns an int status (presumably a cuFINUFFT error code).
template <typename T>
int execute(typename plan_type<T>::type plan, std::complex<T>* c, std::complex<T>* f);

// Releases `plan` (presumably via cufinufft_destroy — confirm).
template <typename T>
void destroy(typename plan_type<T>::type plan);

// Returns a pointer to element `index` of the y-coordinate array.
// This primary template is used for every (ndim, T) pair except the
// <1, double> and <1, float> specializations declared below, which are
// defined out of line (presumably returning null — confirm).
template <int ndim, typename T>
T* y_index(T* y, int64_t index) {
  return y + index;
}

// Primary template: always returns a null pointer; both arguments are
// ignored (left unnamed to avoid unused-parameter warnings). The <3, double>
// and <3, float> specializations declared below are defined out of line
// (presumably returning a pointer to z[index] — confirm).
template <int ndim, typename T>
T* z_index(T* /*z*/, int64_t /*index*/) {
  return nullptr;
}

// Explicit specializations; definitions live outside this header.
template <>
double* y_index<1, double>(double* y, int64_t index);
template <>
float* y_index<1, float>(float* y, int64_t index);
template <>
double* z_index<3, double>(double* z, int64_t index);
template <>
float* z_index<3, float>(float* z, int64_t index);

}  // namespace gpu
}  // namespace jax_finufft
#endif