From fe0dcdc531f8a0e37832f4a2b8ebb25610119687 Mon Sep 17 00:00:00 2001 From: Martin Reinecke Date: Tue, 18 Feb 2025 17:20:14 +0100 Subject: [PATCH] add tests; add overlooked file --- python/finufft/test/test_finufft_plan.py | 47 ++++++++++++++++++++++++ src/fft.cpp | 15 ++++---- 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/python/finufft/test/test_finufft_plan.py b/python/finufft/test/test_finufft_plan.py index a062c57e3..bf1ac5009 100644 --- a/python/finufft/test/test_finufft_plan.py +++ b/python/finufft/test/test_finufft_plan.py @@ -38,6 +38,22 @@ def test_finufft1_plan(dtype, shape, n_pts, output_arg, modeord): utils.verify_type1(pts, coefs, shape, sig, 1e-6) + # test adjoint type 2 + plan = Plan(2, shape, dtype=dtype, modeord=modeord) + + plan.setpts(*pts) + + if not output_arg: + sig = plan.execute_adjoint(coefs) + else: + sig = np.empty(shape, dtype=dtype) + plan.execute_adjoint(coefs, out=sig) + + if modeord == 1: + sig = np.fft.fftshift(sig) + + utils.verify_type1(pts, coefs, shape, sig, 1e-6) + @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("shape", SHAPES) @@ -64,6 +80,24 @@ def test_finufft2_plan(dtype, shape, n_pts, output_arg, modeord): utils.verify_type2(pts, sig, coefs, 1e-6) + # test adjoint type 1 + plan = Plan(1, shape, dtype=dtype, modeord=modeord) + + plan.setpts(*pts) + + if modeord == 1: + _sig = np.fft.ifftshift(sig) + else: + _sig = sig + + if not output_arg: + coefs = plan.execute_adjoint(_sig) + else: + coefs = np.empty(n_pts, dtype=dtype) + plan.execute_adjoint(_sig, out=coefs) + + utils.verify_type2(pts, sig, coefs, 1e-6) + @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dim", list(set(len(shape) for shape in SHAPES))) @@ -86,6 +120,19 @@ def test_finufft3_plan(dtype, dim, n_source_pts, n_target_pts, output_arg): utils.verify_type3(source_pts, source_coefs, target_pts, target_coefs, 1e-6) + # test adjoint type 3 + plan = Plan(3, dim, dtype=dtype, isign=-1) + + plan.setpts(*target_pts, *((None,) * (3 - dim)), *source_pts) + + if not output_arg: + target_coefs = plan.execute_adjoint(source_coefs) + else: + target_coefs = np.empty(n_target_pts, dtype=dtype) + plan.execute_adjoint(source_coefs, out=target_coefs) + + utils.verify_type3(source_pts, source_coefs, target_pts, target_coefs, 1e-6) + def test_finufft_plan_errors(): with pytest.raises(RuntimeError, match="must be single or double"): diff --git a/src/fft.cpp b/src/fft.cpp index 47f28fcf6..be05a0073 100644 --- a/src/fft.cpp +++ b/src/fft.cpp @@ -36,13 +36,14 @@ void do_fft(const FINUFFT_PLAN_T &p, std::complex *fwBatch, bool adjoint arrdims.push_back(size_t(ns[2])); axes.push_back(3); } - bool forward = (p.fftSign < 0) != adjoint; + bool forward = (p.fftSign < 0) != adjoint; + bool spreading = (p.type == 1) != adjoint; ducc0::vfmav> data(fwBatch, arrdims); // FIXME #ifdef FINUFFT_NO_DUCC0_TWEAKS ducc0::c2c(data, data, axes, forward, TF(1), nthreads); #else - /* For type 1 NUFFTs, only the low-frequency parts of the output fine grid are - going to be used, and for type 2 NUFFTs, the high frequency parts of the + /* When spreading, only the low-frequency parts of the output fine grid are + going to be used, and when interpolating, the high frequency parts of the input fine grid are zero by definition. This can be used to reduce the total FFT work for 2D and 3D NUFFTs. One of the FFT axes always has to be transformed fully (that's why there is no savings for 1D NUFFTs), for the @@ -61,13 +62,13 @@ void do_fft(const FINUFFT_PLAN_T &p, std::complex *fwBatch, bool adjoint auto sub1 = ducc0::subarray(data, {{}, {}, {0, y_lo}}); // the next line is analogous to the Python statement "sub2 = data[:, :, y_hi:]" auto sub2 = ducc0::subarray(data, {{}, {}, {y_hi, ducc0::MAXIDX}}); - if (p.type == 1) // spreading, not all parts of the output array are needed + if (spreading) // spreading, not all parts of the output array are needed // do axis 2 in full ducc0::c2c(data, data, {2}, forward, TF(1), nthreads); // do only parts of axis 1 ducc0::c2c(sub1, sub1, {1}, forward, TF(1), nthreads); ducc0::c2c(sub2, sub2, {1}, forward, TF(1), nthreads); - if (p.type == 2) // interpolation, parts of the input array are zero + if (!spreading) // interpolation, parts of the input array are zero // do axis 2 in full ducc0::c2c(data, data, {2}, forward, TF(1), nthreads); } @@ -85,7 +86,7 @@ void do_fft(const FINUFFT_PLAN_T &p, std::complex *fwBatch, bool adjoint auto sub4 = ducc0::subarray(sub1, {{}, {}, {y_hi, ducc0::MAXIDX}, {}}); auto sub5 = ducc0::subarray(sub2, {{}, {}, {0, y_lo}, {}}); auto sub6 = ducc0::subarray(sub2, {{}, {}, {y_hi, ducc0::MAXIDX}, {}}); - if (p.type == 1) { // spreading, not all parts of the output array are needed + if (spreading) { // spreading, not all parts of the output array are needed // do axis 3 in full ducc0::c2c(data, data, {3}, forward, TF(1), nthreads); // do only parts of axis 2 @@ -97,7 +98,7 @@ void do_fft(const FINUFFT_PLAN_T &p, std::complex *fwBatch, bool adjoint ducc0::c2c(sub4, sub4, {1}, forward, TF(1), nthreads); ducc0::c2c(sub5, sub5, {1}, forward, TF(1), nthreads); ducc0::c2c(sub6, sub6, {1}, forward, TF(1), nthreads); - if (p.type == 2) { // interpolation, parts of the input array are zero + if (!spreading) { // interpolation, parts of the input array are zero // do only parts of axis 2 ducc0::c2c(sub1, sub1, {2}, forward, TF(1), nthreads); ducc0::c2c(sub2, sub2, {2}, forward, TF(1), nthreads);