-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathkernels.h
34 lines (26 loc) · 1.1 KB
/
kernels.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#ifndef _JAX_FINUFFT_KERNELS_H_
#define _JAX_FINUFFT_KERNELS_H_
#include <cuda_runtime_api.h>
#include <cstddef>
#include <cstdint>
namespace jax_finufft {

// Parameter block describing one batched cuFINUFFT transform, templated on
// the real scalar type T (presumably float or double, matching the `f` and
// non-`f` entry points below — TODO confirm).
//
// NOTE(review): the per-field meanings below are inferred from the field
// names and FINUFFT's naming conventions; confirm against the code that
// fills and reads this struct.
//
// NOTE(review): `cufinufft_opts` is not declared by any header included
// here, so every translation unit must include the cuFINUFFT header before
// this one. Consider including it directly so this header is self-contained.
//
// Fixed-width fields use `std::int64_t`. <cstdint> only guarantees the
// `std::`-qualified names; the global `int64_t` is not guaranteed. The type,
// and therefore the struct layout, is the same as before.
template <typename T>
struct descriptor {
  T eps;                // presumably the requested accuracy tolerance
  int iflag;            // presumably the sign of the exponent in the transform
  std::int64_t n_tot;   // presumably the total number of transforms to run
  int n_transf;         // presumably the number of transforms per plan call
  std::int64_t n_j;     // presumably the number of nonuniform points
  std::int64_t n_k[3];  // presumably the number of modes along each axis;
                        // three slots, presumably only two used in 2-D
  cufinufft_opts opts;  // cuFINUFFT options struct for this transform
};

// GPU entry points. Their signature matches XLA's GPU custom-call
// convention (stream, buffer array, opaque bytes, opaque length), so they
// are presumably registered as custom-call targets — TODO confirm at the
// registration site.
//
//   stream      CUDA stream the work is expected to be enqueued on.
//   buffers     array of device buffer pointers; under the XLA convention
//               operands come first, followed by results.
//   opaque      caller-supplied bytes; presumably a serialized
//               descriptor<double> / descriptor<float> (verify).
//   opaque_len  length of `opaque` in bytes.
//
// Naming (inferred from the names, not shown here): nufft<D>d<K> is a
// D-dimensional type-K transform, and the trailing `f` presumably selects
// single precision.
void nufft2d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft2d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft3d1(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft3d2(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);

void nufft2d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft2d2f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft3d1f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);
void nufft3d2f(cudaStream_t stream, void** buffers, const char* opaque, std::size_t opaque_len);

}  // namespace jax_finufft
#endif