Skip to content

Commit

Permalink
add tests; add overlooked file
Browse files Browse the repository at this point in the history
  • Loading branch information
mreineck committed Feb 18, 2025
1 parent e7bd659 commit fe0dcdc
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 7 deletions.
47 changes: 47 additions & 0 deletions python/finufft/test/test_finufft_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,22 @@ def test_finufft1_plan(dtype, shape, n_pts, output_arg, modeord):

utils.verify_type1(pts, coefs, shape, sig, 1e-6)

# test adjoint type 2
plan = Plan(2, shape, dtype=dtype, modeord=modeord)

plan.setpts(*pts)

if not output_arg:
sig = plan.execute_adjoint(coefs)
else:
sig = np.empty(shape, dtype=dtype)
plan.execute_adjoint(coefs, out=sig)

if modeord == 1:
sig = np.fft.fftshift(sig)

utils.verify_type1(pts, coefs, shape, sig, 1e-6)


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
Expand All @@ -64,6 +80,24 @@ def test_finufft2_plan(dtype, shape, n_pts, output_arg, modeord):

utils.verify_type2(pts, sig, coefs, 1e-6)

# test adjoint type 1
plan = Plan(1, shape, dtype=dtype, modeord=modeord)

plan.setpts(*pts)

if modeord == 1:
_sig = np.fft.ifftshift(sig)
else:
_sig = sig

if not output_arg:
coefs = plan.execute_adjoint(_sig)
else:
coefs = np.empty(n_pts, dtype=dtype)
plan.execute_adjoint(_sig, out=coefs)

utils.verify_type2(pts, sig, coefs, 1e-6)


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("dim", list(set(len(shape) for shape in SHAPES)))
Expand All @@ -86,6 +120,19 @@ def test_finufft3_plan(dtype, dim, n_source_pts, n_target_pts, output_arg):

utils.verify_type3(source_pts, source_coefs, target_pts, target_coefs, 1e-6)

# test adjoint type 3
plan = Plan(3, dim, dtype=dtype, isign=-1)

plan.setpts(*target_pts, *((None,) * (3 - dim)), *source_pts)

if not output_arg:
target_coefs = plan.execute_adjoint(source_coefs)
else:
target_coefs = np.empty(n_target_pts, dtype=dtype)
plan.execute_adjoint(source_coefs, out=target_coefs)

utils.verify_type3(source_pts, source_coefs, target_pts, target_coefs, 1e-6)


def test_finufft_plan_errors():
with pytest.raises(RuntimeError, match="must be single or double"):
Expand Down
15 changes: 8 additions & 7 deletions src/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ void do_fft(const FINUFFT_PLAN_T<TF> &p, std::complex<TF> *fwBatch, bool adjoint
arrdims.push_back(size_t(ns[2]));
axes.push_back(3);
}
bool forward = (p.fftSign < 0) != adjoint;
bool forward = (p.fftSign < 0) != adjoint;
bool spreading = (p.type == 1) != adjoint;
ducc0::vfmav<std::complex<TF>> data(fwBatch, arrdims); // FIXME
#ifdef FINUFFT_NO_DUCC0_TWEAKS
ducc0::c2c(data, data, axes, forward, TF(1), nthreads);
#else
/* For type 1 NUFFTs, only the low-frequency parts of the output fine grid are
going to be used, and for type 2 NUFFTs, the high frequency parts of the
/* When spreading, only the low-frequency parts of the output fine grid are
going to be used, and when interpolating, the high frequency parts of the
input fine grid are zero by definition. This can be used to reduce the
total FFT work for 2D and 3D NUFFTs. One of the FFT axes always has to be
transformed fully (that's why there is no savings for 1D NUFFTs), for the
Expand All @@ -61,13 +62,13 @@ void do_fft(const FINUFFT_PLAN_T<TF> &p, std::complex<TF> *fwBatch, bool adjoint
auto sub1 = ducc0::subarray(data, {{}, {}, {0, y_lo}});
// the next line is analogous to the Python statement "sub2 = data[:, :, y_hi:]"
auto sub2 = ducc0::subarray(data, {{}, {}, {y_hi, ducc0::MAXIDX}});
if (p.type == 1) // spreading, not all parts of the output array are needed
if (spreading) // spreading, not all parts of the output array are needed
// do axis 2 in full
ducc0::c2c(data, data, {2}, forward, TF(1), nthreads);
// do only parts of axis 1
ducc0::c2c(sub1, sub1, {1}, forward, TF(1), nthreads);
ducc0::c2c(sub2, sub2, {1}, forward, TF(1), nthreads);
if (p.type == 2) // interpolation, parts of the input array are zero
if (!spreading) // interpolation, parts of the input array are zero
// do axis 2 in full
ducc0::c2c(data, data, {2}, forward, TF(1), nthreads);
}
Expand All @@ -85,7 +86,7 @@ void do_fft(const FINUFFT_PLAN_T<TF> &p, std::complex<TF> *fwBatch, bool adjoint
auto sub4 = ducc0::subarray(sub1, {{}, {}, {y_hi, ducc0::MAXIDX}, {}});
auto sub5 = ducc0::subarray(sub2, {{}, {}, {0, y_lo}, {}});
auto sub6 = ducc0::subarray(sub2, {{}, {}, {y_hi, ducc0::MAXIDX}, {}});
if (p.type == 1) { // spreading, not all parts of the output array are needed
if (spreading) { // spreading, not all parts of the output array are needed
// do axis 3 in full
ducc0::c2c(data, data, {3}, forward, TF(1), nthreads);
// do only parts of axis 2
Expand All @@ -97,7 +98,7 @@ void do_fft(const FINUFFT_PLAN_T<TF> &p, std::complex<TF> *fwBatch, bool adjoint
ducc0::c2c(sub4, sub4, {1}, forward, TF(1), nthreads);
ducc0::c2c(sub5, sub5, {1}, forward, TF(1), nthreads);
ducc0::c2c(sub6, sub6, {1}, forward, TF(1), nthreads);
if (p.type == 2) { // interpolation, parts of the input array are zero
if (!spreading) { // interpolation, parts of the input array are zero
// do only parts of axis 2
ducc0::c2c(sub1, sub1, {2}, forward, TF(1), nthreads);
ducc0::c2c(sub2, sub2, {2}, forward, TF(1), nthreads);
Expand Down

0 comments on commit fe0dcdc

Please sign in to comment.