-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathkernels.cc.cu
122 lines (97 loc) · 3.93 KB
/
kernels.cc.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include <complex>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>

#include "cufinufft_wrapper.h"
#include "kernel_helpers.h"
#include "kernels.h"
using namespace jax_finufft::gpu;
namespace jax_finufft {

// Converts a CUDA runtime status into a C++ exception.
//
// Returns normally when `error` is cudaSuccess; otherwise throws std::runtime_error whose
// message is cudaGetErrorString(error).
void ThrowIfError(cudaError_t error) {
  if (error != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(error));
  }
}

namespace {

// True when a status code returned by a cufinufft wrapper call is fatal.
//
// finufft/cufinufft use 0 for success and 1 for a warning (requested tolerance too small to
// achieve) after which the plan remains usable; only codes above 1 are treated as errors.
bool IsCufinufftFailure(int ier) { return ier > 1; }

// Throws std::runtime_error naming the failed cufinufft `step` when `ier` is fatal.
void ThrowIfCufinufftError(int ier, const char *step) {
  if (IsCufinufftFailure(ier)) {
    throw std::runtime_error(std::string("cufinufft ") + step + " failed with error code " +
                             std::to_string(ier));
  }
}

// Coordinate pointers of the non-uniform points for one call; unused dimensions are NULL.
template <typename T>
struct PointBuffers {
  T *x;
  T *y;
  T *z;
};

// Extracts the coordinate operands from the custom-call buffer list.
//
// Layout used by nufft1/nufft2: buffers[0] is the input array, buffers[1..ndim] hold the
// x (then y, then z) coordinates, and buffers[ndim + 1] is the output array.
template <int ndim, typename T>
PointBuffers<T> unpack_points(void **buffers) {
  static_assert(ndim >= 1 && ndim <= 3, "only 1, 2 and 3 dimensional transforms are supported");
  PointBuffers<T> pts = {reinterpret_cast<T *>(buffers[1]), NULL, NULL};
  if (ndim > 1) pts.y = reinterpret_cast<T *>(buffers[2]);
  if (ndim > 2) pts.z = reinterpret_cast<T *>(buffers[3]);
  return pts;
}

}  // namespace

// Plans and executes a batch of type-`type` cufinufft transforms on `stream`.
//
// One plan is built and reused for descriptor->n_tot independent problems. Problem `index`
// reads n_j points starting at offset index * n_j in x (and y, z), and uses the slices of `c`
// and `F` starting at index * n_j * n_transf and index * prod(n_k) * n_transf respectively.
// Synchronizes `stream` before returning so the plan can be destroyed safely.
//
// Throws std::runtime_error if a cufinufft call fails or the stream reports a CUDA error.
// Once makeplan has succeeded, the plan is destroyed before any error is raised.
template <int ndim, typename T>
void run_nufft(int type, const descriptor<T> *descriptor, T *x, T *y, T *z, std::complex<T> *c,
               std::complex<T> *F, cudaStream_t stream) {
  int64_t n_k = 1;
  for (int d = 0; d < ndim; ++d) n_k *= descriptor->n_k[d];

  cufinufft_opts opts = descriptor->opts;
  update_opts<T>(&opts, ndim, stream);

  typename plan_type<T>::type plan;
  const int ier_plan = makeplan<T>(type, ndim, descriptor->n_k, descriptor->iflag,
                                   descriptor->n_transf, descriptor->eps, &plan, &opts);
  // A failed makeplan leaves `plan` unusable, so stop before setpts/execute/destroy touch it.
  ThrowIfCufinufftError(ier_plan, "makeplan");

  // Stop at the first failing step, but fall through so the plan is still released below.
  const char *failed_step = NULL;
  int ier_step = 0;
  for (int64_t index = 0; index < descriptor->n_tot; ++index) {
    const int64_t i = index * descriptor->n_j;
    const int64_t j = i * descriptor->n_transf;
    const int64_t k = index * n_k * descriptor->n_transf;
    ier_step = setpts<T>(plan, descriptor->n_j, &(x[i]), y_index<ndim, T>(y, i),
                         z_index<ndim, T>(z, i), 0, NULL, NULL, NULL);
    if (IsCufinufftFailure(ier_step)) {
      failed_step = "setpts";
      break;
    }
    ier_step = execute<T>(plan, &c[j], &F[k]);
    if (IsCufinufftFailure(ier_step)) {
      failed_step = "execute";
      break;
    }
  }

  // Don't free resources like the cuFFT plan until the stream is done. The sync status is
  // captured instead of thrown immediately so the plan is destroyed on every path.
  const cudaError_t sync_status = cudaStreamSynchronize(stream);
  destroy<T>(plan);

  if (failed_step != NULL) ThrowIfCufinufftError(ier_step, failed_step);
  ThrowIfError(sync_status);
}

// Type-1 transform (non-uniform points -> uniform modes).
// Buffers: [c, x, (y), (z), F]; `c` is the input strengths and `F` the output modes.
template <int ndim, typename T>
void nufft1(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
  const descriptor<T> *desc = unpack_descriptor<descriptor<T>>(opaque, opaque_len);
  std::complex<T> *c = reinterpret_cast<std::complex<T> *>(buffers[0]);
  const PointBuffers<T> pts = unpack_points<ndim, T>(buffers);
  std::complex<T> *F = reinterpret_cast<std::complex<T> *>(buffers[ndim + 1]);
  run_nufft<ndim, T>(1, desc, pts.x, pts.y, pts.z, c, F, stream);
  ThrowIfError(cudaGetLastError());
}

// Type-2 transform (uniform modes -> non-uniform points).
// Buffers: [F, x, (y), (z), c]; `F` is the input modes and `c` the output values.
template <int ndim, typename T>
void nufft2(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
  const descriptor<T> *desc = unpack_descriptor<descriptor<T>>(opaque, opaque_len);
  std::complex<T> *F = reinterpret_cast<std::complex<T> *>(buffers[0]);
  const PointBuffers<T> pts = unpack_points<ndim, T>(buffers);
  std::complex<T> *c = reinterpret_cast<std::complex<T> *>(buffers[ndim + 1]);
  run_nufft<ndim, T>(2, desc, pts.x, pts.y, pts.z, c, F, stream);
  ThrowIfError(cudaGetLastError());
}

// Concrete entry points: nufft<ndim>d<type>, with an `f` suffix for single precision.
void nufft2d1(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
  nufft1<2, double>(stream, buffers, opaque, opaque_len);
}
void nufft2d2(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
  nufft2<2, double>(stream, buffers, opaque, opaque_len);
}
void nufft3d1(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
  nufft1<3, double>(stream, buffers, opaque, opaque_len);
}
void nufft3d2(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
  nufft2<3, double>(stream, buffers, opaque, opaque_len);
}
void nufft2d1f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
  nufft1<2, float>(stream, buffers, opaque, opaque_len);
}
void nufft2d2f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
  nufft2<2, float>(stream, buffers, opaque, opaque_len);
}
void nufft3d1f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
  nufft1<3, float>(stream, buffers, opaque, opaque_len);
}
void nufft3d2f(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) {
  nufft2<3, float>(stream, buffers, opaque, opaque_len);
}

}  // namespace jax_finufft