Skip to content

Commit

Permalink
Adding advanced interface (#68)
Browse files Browse the repository at this point in the history
* Adding CPU options struct

* don't export enum values

* getting CPU opts working

* naming issue

* moving options building to python via pydantic

* getting opts implemented

* gpu bugs

* more bugs

* bugz bugz

* adding options tests

* fixing variable name

* moving include of descriptor

* includes

* lib: put cufinufft wrapper function declarations in their own header file

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lib: need to declare template specialization in header

---------

Co-authored-by: Lehman Garrison <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 15, 2024
1 parent 672247e commit ef69daa
Show file tree
Hide file tree
Showing 16 changed files with 469 additions and 126 deletions.
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF)

# Enable CUDA if requested and available
option(JAX_FINUFFT_USE_CUDA "Enable CUDA build" OFF)

if(JAX_FINUFFT_USE_CUDA)
include(CheckLanguage)
check_language(CUDA)

if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA compiler found; compiling with GPU support")
enable_language(CUDA)
Expand All @@ -28,7 +30,7 @@ else()
set(FINUFFT_USE_CUDA OFF)
endif()

if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
# TODO(dfm): OpenMP segfaults on my system - can we enable this somehow?
set(FINUFFT_USE_OPENMP OFF)
else()
Expand Down Expand Up @@ -63,6 +65,7 @@ if(FINUFFT_USE_CUDA)
)
pybind11_add_module(jax_finufft_gpu
${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_gpu.cc
${CMAKE_CURRENT_LIST_DIR}/lib/cufinufft_wrapper.cc
${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu)
target_include_directories(jax_finufft_gpu PUBLIC ${CUFINUFFT_INCLUDE_DIRS})
target_include_directories(jax_finufft_gpu PUBLIC ${CUFINUFFT_VENDORED_INCLUDE_DIRS})
Expand Down
55 changes: 12 additions & 43 deletions lib/jax_finufft_gpu.h → lib/cufinufft_wrapper.cc
Original file line number Diff line number Diff line change
@@ -1,37 +1,30 @@
#ifndef _JAX_FINUFFT_GPU_H_
#define _JAX_FINUFFT_GPU_H_
#include "cufinufft_wrapper.h"

#include <complex>

#include "cufinufft.h"

namespace jax_finufft {

template <typename T>
struct plan_type;
namespace gpu {

template <>
struct plan_type<double> {
typedef cufinufft_plan type;
};
void default_opts<float>(cufinufft_opts* opts) {
cufinufft_default_opts(opts);
}

template <>
struct plan_type<float> {
typedef cufinufftf_plan type;
};

template <typename T>
void default_opts(int type, int dim, cufinufft_opts* opts, cudaStream_t stream);
void default_opts<double>(cufinufft_opts* opts) {
cufinufft_default_opts(opts);
}

template <>
void default_opts<float>(int type, int dim, cufinufft_opts* opts, cudaStream_t stream) {
cufinufft_default_opts(opts);
void update_opts<float>(cufinufft_opts* opts, int dim, cudaStream_t stream) {
opts->gpu_stream = stream;
}

template <>
void default_opts<double>(int type, int dim, cufinufft_opts* opts, cudaStream_t stream) {
cufinufft_default_opts(opts);
void update_opts<double>(cufinufft_opts* opts, int dim, cudaStream_t stream) {
opts->gpu_stream = stream;

// double precision in 3D blows out shared memory.
Expand All @@ -42,10 +35,6 @@ void default_opts<double>(int type, int dim, cufinufft_opts* opts, cudaStream_t
}
}

template <typename T>
int makeplan(int type, int dim, const int64_t nmodes[3], int iflag, int ntr, T eps,
typename plan_type<T>::type* plan, cufinufft_opts* opts);

template <>
int makeplan<float>(int type, int dim, const int64_t nmodes[3], int iflag, int ntr, float eps,
typename plan_type<float>::type* plan, cufinufft_opts* opts) {
Expand All @@ -61,10 +50,6 @@ int makeplan<double>(int type, int dim, const int64_t nmodes[3], int iflag, int
return cufinufft_makeplan(type, dim, tmp_nmodes, iflag, ntr, eps, plan, opts);
}

template <typename T>
int setpts(typename plan_type<T>::type plan, int64_t M, T* x, T* y, T* z, int64_t N, T* s, T* t,
T* u);

template <>
int setpts<float>(typename plan_type<float>::type plan, int64_t M, float* x, float* y, float* z,
int64_t N, float* s, float* t, float* u) {
Expand All @@ -77,9 +62,6 @@ int setpts<double>(typename plan_type<double>::type plan, int64_t M, double* x,
return cufinufft_setpts(plan, M, x, y, z, N, s, t, u);
}

template <typename T>
int execute(typename plan_type<T>::type plan, std::complex<T>* c, std::complex<T>* f);

template <>
int execute<float>(typename plan_type<float>::type plan, std::complex<float>* c,
std::complex<float>* f) {
Expand All @@ -96,9 +78,6 @@ int execute<double>(typename plan_type<double>::type plan, std::complex<double>*
return cufinufft_execute(plan, _c, _f);
}

template <typename T>
void destroy(typename plan_type<T>::type plan);

template <>
void destroy<float>(typename plan_type<float>::type plan) {
cufinufftf_destroy(plan);
Expand All @@ -109,11 +88,6 @@ void destroy<double>(typename plan_type<double>::type plan) {
cufinufft_destroy(plan);
}

template <int ndim, typename T>
T* y_index(T* y, int64_t index) {
return &(y[index]);
}

template <>
double* y_index<1, double>(double* y, int64_t index) {
return NULL;
Expand All @@ -124,11 +98,6 @@ float* y_index<1, float>(float* y, int64_t index) {
return NULL;
}

template <int ndim, typename T>
T* z_index(T* z, int64_t index) {
return NULL;
}

template <>
double* z_index<3, double>(double* z, int64_t index) {
return &(z[index]);
Expand All @@ -139,6 +108,6 @@ float* z_index<3, float>(float* z, int64_t index) {
return &(z[index]);
}

} // namespace jax_finufft
} // namespace gpu

#endif
} // namespace jax_finufft
71 changes: 71 additions & 0 deletions lib/cufinufft_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#ifndef _CUFINUFFT_WRAPPER_H_
#define _CUFINUFFT_WRAPPER_H_

#include <complex>

#include "cufinufft.h"

namespace jax_finufft {

namespace gpu {

template <typename T>
struct plan_type;

template <>
struct plan_type<double> {
typedef cufinufft_plan type;
};

template <>
struct plan_type<float> {
typedef cufinufftf_plan type;
};

template <typename T>
void default_opts(cufinufft_opts* opts);

template <typename T>
void update_opts(cufinufft_opts* opts, int dim, cudaStream_t stream);

template <typename T>
int makeplan(int type, int dim, const int64_t nmodes[3], int iflag, int ntr, T eps,
typename plan_type<T>::type* plan, cufinufft_opts* opts);

template <typename T>
int setpts(typename plan_type<T>::type plan, int64_t M, T* x, T* y, T* z, int64_t N, T* s, T* t,
T* u);

template <typename T>
int execute(typename plan_type<T>::type plan, std::complex<T>* c, std::complex<T>* f);

template <typename T>
void destroy(typename plan_type<T>::type plan);

template <int ndim, typename T>
T* y_index(T* y, int64_t index) {
return &(y[index]);
}

template <int ndim, typename T>
T* z_index(T* z, int64_t index) {
return NULL;
}

template <>
double* y_index<1, double>(double* y, int64_t index);

template <>
float* y_index<1, float>(float* y, int64_t index);

template <>
double* z_index<3, double>(double* z, int64_t index);

template <>
float* z_index<3, float>(float* z, int64_t index);

} // namespace gpu

} // namespace jax_finufft

#endif
21 changes: 0 additions & 21 deletions lib/jax_finufft_common.h

This file was deleted.

80 changes: 64 additions & 16 deletions lib/jax_finufft_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,32 @@
#include "pybind11_kernel_helpers.h"

using namespace jax_finufft;
using namespace jax_finufft::cpu;
namespace py = pybind11;

namespace {

template <int ndim, typename T>
void run_nufft(int type, void *desc_in, T *x, T *y, T *z, std::complex<T> *c, std::complex<T> *F) {
const NufftDescriptor<T> *descriptor = unpack_descriptor<NufftDescriptor<T>>(
reinterpret_cast<const char *>(desc_in), sizeof(NufftDescriptor<T>));
const descriptor<T> *desc = unpack_descriptor<descriptor<T>>(
reinterpret_cast<const char *>(desc_in), sizeof(descriptor<T>));
int64_t n_k = 1;
for (int d = 0; d < ndim; ++d) n_k *= descriptor->n_k[d];

finufft_opts *opts = new finufft_opts;
default_opts<T>(opts);
for (int d = 0; d < ndim; ++d) n_k *= desc->n_k[d];
finufft_opts opts = desc->opts;

typename plan_type<T>::type plan;
makeplan<T>(type, ndim, const_cast<int64_t *>(descriptor->n_k), descriptor->iflag,
descriptor->n_transf, descriptor->eps, &plan, opts);
for (int64_t index = 0; index < descriptor->n_tot; ++index) {
int64_t i = index * descriptor->n_j;
int64_t j = i * descriptor->n_transf;
int64_t k = index * n_k * descriptor->n_transf;

setpts<T>(plan, descriptor->n_j, &(x[i]), y_index<ndim, T>(y, i), z_index<ndim, T>(z, i), 0,
NULL, NULL, NULL);
makeplan<T>(type, ndim, const_cast<int64_t *>(desc->n_k), desc->iflag, desc->n_transf, desc->eps,
&plan, &opts);
for (int64_t index = 0; index < desc->n_tot; ++index) {
int64_t i = index * desc->n_j;
int64_t j = i * desc->n_transf;
int64_t k = index * n_k * desc->n_transf;

setpts<T>(plan, desc->n_j, &(x[i]), y_index<ndim, T>(y, i), z_index<ndim, T>(z, i), 0, NULL,
NULL, NULL);
execute<T>(plan, &c[j], &F[k]);
}
destroy<T>(plan);
delete opts;
}

template <int ndim, typename T>
Expand Down Expand Up @@ -68,6 +67,40 @@ void nufft2(void *out, void **in) {
run_nufft<ndim, T>(2, in[0], x, y, z, c, F);
}

template <typename T>
py::bytes build_descriptor(T eps, int iflag, int64_t n_tot, int n_transf, int64_t n_j,
int64_t n_k_1, int64_t n_k_2, int64_t n_k_3, finufft_opts opts) {
return pack_descriptor(
descriptor<T>{eps, iflag, n_tot, n_transf, n_j, {n_k_1, n_k_2, n_k_3}, opts});
}

template <typename T>
finufft_opts *build_opts(bool modeord, bool chkbnds, int debug, int spread_debug, bool showwarn,
int nthreads, int fftw, int spread_sort, bool spread_kerevalmeth,
bool spread_kerpad, double upsampfac, int spread_thread, int maxbatchsize,
int spread_nthr_atomic, int spread_max_sp_size) {
finufft_opts *opts = new finufft_opts;
default_opts<T>(opts);

opts->modeord = int(modeord);
opts->chkbnds = int(chkbnds);
opts->debug = debug;
opts->spread_debug = spread_debug;
opts->showwarn = int(showwarn);
opts->nthreads = nthreads;
opts->fftw = fftw;
opts->spread_sort = spread_sort;
opts->spread_kerevalmeth = int(spread_kerevalmeth);
opts->spread_kerpad = int(spread_kerpad);
opts->upsampfac = upsampfac;
opts->spread_thread = int(spread_thread);
opts->maxbatchsize = maxbatchsize;
opts->spread_nthr_atomic = spread_nthr_atomic;
opts->spread_max_sp_size = spread_max_sp_size;

return opts;
}

pybind11::dict Registrations() {
pybind11::dict dict;

Expand All @@ -92,6 +125,21 @@ PYBIND11_MODULE(jax_finufft_cpu, m) {
m.def("registrations", &Registrations);
m.def("build_descriptorf", &build_descriptor<float>);
m.def("build_descriptor", &build_descriptor<double>);

m.attr("FFTW_ESTIMATE") = py::int_(FFTW_ESTIMATE);
m.attr("FFTW_MEASURE") = py::int_(FFTW_MEASURE);
m.attr("FFTW_PATIENT") = py::int_(FFTW_PATIENT);
m.attr("FFTW_EXHAUSTIVE") = py::int_(FFTW_EXHAUSTIVE);
m.attr("FFTW_WISDOM_ONLY") = py::int_(FFTW_WISDOM_ONLY);

py::class_<finufft_opts> opts(m, "FinufftOpts");
opts.def(py::init(&build_opts<double>), py::arg("modeord") = false, py::arg("chkbnds") = true,
py::arg("debug") = 0, py::arg("spread_debug") = 0, py::arg("showwarn") = false,
py::arg("nthreads") = 0, py::arg("fftw") = int(FFTW_ESTIMATE),
py::arg("spread_sort") = 2, py::arg("spread_kerevalmeth") = true,
py::arg("spread_kerpad") = true, py::arg("upsampfac") = 0.0,
py::arg("spread_thread") = 0, py::arg("maxbatchsize") = 0,
py::arg("spread_nthr_atomic") = -1, py::arg("spread_max_sp_size") = 0);
}

} // namespace
16 changes: 16 additions & 0 deletions lib/jax_finufft_cpu.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
#ifndef _JAX_FINUFFT_H_
#define _JAX_FINUFFT_H_

#include <fftw3.h>

#include <complex>

#include "finufft.h"

namespace jax_finufft {

namespace cpu {

template <typename T>
struct plan_type;

Expand Down Expand Up @@ -123,6 +127,18 @@ float* z_index<3, float>(float* z, int64_t index) {
return &(z[index]);
}

template <typename T>
struct descriptor {
T eps;
int iflag;
int64_t n_tot;
int n_transf;
int64_t n_j;
int64_t n_k[3];
finufft_opts opts;
};

} // namespace cpu
} // namespace jax_finufft

#endif
Loading

0 comments on commit ef69daa

Please sign in to comment.