Skip to content

Commit

Permalink
be more clever about memory consumption
Browse files Browse the repository at this point in the history
  • Loading branch information
mreineck committed Feb 19, 2025
1 parent f54a1dd commit d50a912
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 54 deletions.
3 changes: 2 additions & 1 deletion include/finufft/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ static inline void finufft_fft_cleanup_threads [[maybe_unused]] () {
template<typename TF> struct FINUFFT_PLAN_T;
template<typename TF> std::vector<int> gridsize_for_fft(const FINUFFT_PLAN_T<TF> &p);
template<typename TF>
void do_fft(const FINUFFT_PLAN_T<TF> &p, std::complex<TF> *fwBatch, bool adjoint);
void do_fft(const FINUFFT_PLAN_T<TF> &p, std::complex<TF> *fwBatch, int ntrans_actual,
bool adjoint);

#endif // FINUFFT_INCLUDE_FINUFFT_FFT_H
14 changes: 7 additions & 7 deletions include/finufft/finufft_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@
#define FINUFFT_LIKELY(x) (x)
#endif

#include <array>
#include <finufft_errors.h>
#include <memory>
#include <xsimd/xsimd.hpp>

// All indexing in the library that can potentially exceed 2^31 uses 64-bit signed
// integers. This includes all calling arguments (e.g. M, N) that could be huge someday.
Expand Down Expand Up @@ -186,8 +186,8 @@ template<typename TF> struct FINUFFT_PLAN_T { // the main plan class, fully C++
std::array<std::vector<TF>, 3> XYZp; // internal primed NU points (x'_j, etc)
std::array<std::vector<TF>, 3> STUp; // internal primed targs (s'_k, etc)
type3params<TF> t3P; // groups together type 3 shift, scale, phase, parameters
std::unique_ptr<FINUFFT_PLAN_T<TF>> innerT2plan; // ptr used for type 2 in step 2 of
// type 3
std::unique_ptr<const FINUFFT_PLAN_T<TF>> innerT2plan; // ptr used for type 2 in step 2
// of type 3

// other internal structs
std::unique_ptr<Finufft_FFT_plan<TF>> fftPlan;
Expand All @@ -196,10 +196,10 @@ template<typename TF> struct FINUFFT_PLAN_T { // the main plan class, fully C++

// Remaining actions (not create/delete) in guru interface are now methods...
int setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF *s, TF *t, TF *u);
int execute(std::complex<TF> *cj, std::complex<TF> *fk, bool adjoint = false) const;
int execute_adjoint(std::complex<TF> *cj, std::complex<TF> *fk) const {
return execute(cj, fk, true);
}
int execute_internal(TC *cj, TC *fk, bool adjoint = false, int ntrans_actual = -1,
TC *aligned_scratch = nullptr, size_t scratch_size = 0) const;
int execute(TC *cj, TC *fk) const { return execute_internal(cj, fk, false); }
int execute_adjoint(TC *cj, TC *fk) const { return execute_internal(cj, fk, true); }
};

void finufft_default_opts_t(finufft_opts *o);
Expand Down
13 changes: 8 additions & 5 deletions src/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ template std::vector<int> gridsize_for_fft<float>(const FINUFFT_PLAN_T<float> &p
template std::vector<int> gridsize_for_fft<double>(const FINUFFT_PLAN_T<double> &p);

template<typename TF>
void do_fft(const FINUFFT_PLAN_T<TF> &p, std::complex<TF> *fwBatch, bool adjoint) {
void do_fft(const FINUFFT_PLAN_T<TF> &p, std::complex<TF> *fwBatch, int ntrans_actual,
bool adjoint) {
#ifdef FINUFFT_USE_DUCC0
size_t nthreads = min<size_t>(MY_OMP_GET_MAX_THREADS(), p.opts.nthreads);
const auto ns = gridsize_for_fft(p);
vector<size_t> arrdims, axes;
// FIXME: use thisBatchsize if it is smaller than p.batchSize!
arrdims.push_back(size_t(p.batchSize));
// ntrans_actual may be smaller than p.batchSize; using it as the leading (batch)
// dimension lets the ducc FFT transform only the vectors actually in use.
arrdims.push_back(size_t(ntrans_actual));
arrdims.push_back(size_t(ns[0]));
axes.push_back(1);
if (p.dim >= 2) {
Expand Down Expand Up @@ -118,6 +120,7 @@ void do_fft(const FINUFFT_PLAN_T<TF> &p, std::complex<TF> *fwBatch, bool adjoint
#endif
}
template void do_fft<float>(const FINUFFT_PLAN_T<float> &p, std::complex<float> *fwBatch,
bool adjoint);
int ntrans_actual, bool adjoint);
template void do_fft<double>(const FINUFFT_PLAN_T<double> &p,
std::complex<double> *fwBatch, bool adjoint);
std::complex<double> *fwBatch, int ntrans_actual,
bool adjoint);
110 changes: 69 additions & 41 deletions src/finufft_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <iomanip>
#include <memory>
#include <vector>
#include <xsimd/xsimd.hpp>

using namespace finufft;
using namespace finufft::utils;
Expand Down Expand Up @@ -929,15 +930,15 @@ int FINUFFT_PLAN_T<TF>::setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF
FINUFFT_PLAN_T<TF> *tmpplan;
int ier = finufft_makeplan_t<TF>(2, d, t2nmodes, fftSign, batchSize, tol, &tmpplan,
&t2opts);
innerT2plan.reset(tmpplan);
if (ier > 1) { // if merely warning, still proceed
fprintf(stderr, "[%s t3]: inner type 2 plan creation failed with ier=%d!\n",
__func__, ier);
return ier;
}
ier = innerT2plan->setpts(nk, STUp[0].data(), STUp[1].data(), STUp[2].data(), 0,
nullptr, nullptr,
nullptr); // note nk = # output points (not nj)
ier = tmpplan->setpts(nk, STUp[0].data(), STUp[1].data(), STUp[2].data(), 0, nullptr,
nullptr,
nullptr); // note nk = # output points (not nj)
innerT2plan.reset(tmpplan);
if (ier > 1) {
fprintf(stderr, "[%s t3]: inner type 2 setpts failed, ier=%d!\n", __func__, ier);
return ier;
Expand All @@ -956,8 +957,8 @@ template int FINUFFT_PLAN_T<double>::setpts(BIGINT nj, double *xj, double *yj, d

// EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE
template<typename TF>
int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk,
bool adjoint) const {
int FINUFFT_PLAN_T<TF>::execute_internal(TC *cj, TC *fk, bool adjoint, int ntrans_actual,
TC *aligned_scratch, size_t scratch_size) const {
/* See ../docs/cguru.doc for current documentation.
For given (stack of) weights cj or coefficients fk, performs NUFFTs with
Expand All @@ -977,50 +978,56 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk,
CNTime timer;
timer.start();

// a non-positive ntrans_actual means none was specified: fall back to the plan's ntrans
if (ntrans_actual <= 0) ntrans_actual = ntrans;

if (type != 3) { // --------------------- TYPE 1,2 EXEC ------------------

double t_sprint = 0.0, t_fft = 0.0, t_deconv = 0.0; // accumulated timing
if (opts.debug)
printf("[%s] start ntrans=%d (%d batches, bsize=%d)...\n", __func__, ntrans, nbatch,
batchSize);
printf("[%s] start ntrans=%d (%d batches, bsize=%d)...\n", __func__, ntrans_actual,
nbatch, batchSize);
// allocate temporary buffers
std::vector<TC, xsimd::aligned_allocator<TC, 64>> fwBatch(nf() * batchSize);
for (int b = 0; b * batchSize < ntrans; b++) { // .....loop b over batches
bool scratch_provided = scratch_size >= size_t(nf() * batchSize);
std::vector<TC, xsimd::aligned_allocator<TC, 64>> fwBatch_(
scratch_provided ? 0 : nf() * batchSize);
TC *fwBatch = scratch_provided ? aligned_scratch : fwBatch_.data();
for (int b = 0; b * batchSize < ntrans_actual; b++) { // .....loop b over batches

// current batch is either batchSize, or possibly truncated if last one
int thisBatchSize = std::min(ntrans - b * batchSize, batchSize);
int bB = b * batchSize; // index of vector, since batchsizes same
std::complex<TF> *cjb = cj + bB * nj; // point to batch of user weights
std::complex<TF> *fkb = fk + bB * N(); // point to batch of user mode coeffs
int thisBatchSize = std::min(ntrans_actual - b * batchSize, batchSize);
int bB = b * batchSize; // index of vector, since batchsizes same
TC *cjb = cj + bB * nj; // point to batch of user weights
TC *fkb = fk + bB * N(); // point to batch of user mode coeffs
if (opts.debug > 1)
printf("[%s] start batch %d (size %d):\n", __func__, b, thisBatchSize);

// STEP 1: (varies by type)
timer.restart();
// usually spread/interp to/from fwBatch (vs spreadinterponly: to/from user grid)
std::complex<TF> *fwBatch_or_fkb = opts.spreadinterponly ? fkb : fwBatch.data();
TC *fwBatch_or_fkb = opts.spreadinterponly ? fkb : fwBatch;
if ((type == 1) != adjoint) { // spread NU pts X, weights cj, to fw grid
spreadinterpSortedBatch<TF>(thisBatchSize, *this, fwBatch_or_fkb, cjb, adjoint);
t_sprint += timer.elapsedsec();
if (opts.spreadinterponly) // we're done (skip to next iteration of loop)
continue;
} else if (!opts.spreadinterponly) {
// amplify Fourier coeffs fk into 0-padded fw
deconvolveBatch<TF>(thisBatchSize, *this, fkb, fwBatch.data(), adjoint);
deconvolveBatch<TF>(thisBatchSize, *this, fkb, fwBatch, adjoint);
t_deconv += timer.elapsedsec();
}
if (!opts.spreadinterponly) { // Do FFT unless spread/interp only...
// STEP 2: call the FFT on this batch
timer.restart();

do_fft(*this, fwBatch.data(), adjoint);
do_fft(*this, fwBatch, thisBatchSize, adjoint);
t_fft += timer.elapsedsec();
if (opts.debug > 1) printf("\tFFT exec:\t\t%.3g s\n", timer.elapsedsec());
}
// STEP 3: (varies by type)
timer.restart();
if ((type == 1) != adjoint) { // deconvolve (amplify) fw and shuffle to fk
deconvolveBatch<TF>(thisBatchSize, *this, fkb, fwBatch.data(), adjoint);
deconvolveBatch<TF>(thisBatchSize, *this, fkb, fwBatch, adjoint);
t_deconv += timer.elapsedsec();
} else { // interpolate unif fw grid to NU target pts
spreadinterpSortedBatch<TF>(thisBatchSize, *this, fwBatch_or_fkb, cjb, adjoint);
Expand Down Expand Up @@ -1049,20 +1056,41 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk,
double t_pre = 0.0, t_spr = 0.0, t_t2 = 0.0,
t_deconv = 0.0; // accumulated timings
if (opts.debug)
printf("[%s t3] start ntrans=%d (%d batches, bsize=%d)...\n", __func__, ntrans,
nbatch, batchSize);
printf("[%s t3] start ntrans=%d (%d batches, bsize=%d)...\n", __func__,
ntrans_actual, nbatch, batchSize);

// allocate temporary buffers
std::vector<TC> CpBatch((adjoint ? nk : nj) * batchSize);
std::vector<TC, xsimd::aligned_allocator<TC, 64>> fwBatch(nf() * batchSize);
// we are trying to be clever here and re-use memory whenever possible
std::vector<TC, xsimd::aligned_allocator<TC, 64>> buf1, buf2, buf3;
TC *CpBatch, *fwBatch, *fwBatch_inner;
if (!adjoint) { // we can combine CpBatch and fwBatch_inner!
buf1.resize(std::max(nj * batchSize, innerT2plan->nf() * innerT2plan->batchSize));
CpBatch = fwBatch_inner = buf1.data();
buf2.resize(nf() * batchSize);
fwBatch = buf2.data();
} else { // we may be able to combine CpBatch and fwBatch!
if (innerT2plan->batchSize >= batchSize) {
buf1.resize(std::max(nk * batchSize, nf() * batchSize));
CpBatch = fwBatch = buf1.data();
buf2.resize(innerT2plan->nf() * innerT2plan->batchSize);
fwBatch_inner = buf2.data();
} else {
buf1.resize(nk * batchSize);
CpBatch = buf1.data();
buf2.resize(nf() * batchSize);
fwBatch = buf2.data();
buf3.resize(innerT2plan->nf() * innerT2plan->batchSize);
fwBatch_inner = buf3.data();
}
}

for (int b = 0; b * batchSize < ntrans; b++) { // .....loop b over batches
for (int b = 0; b * batchSize < ntrans_actual; b++) { // .....loop b over batches

// batching and pointers to this batch, identical to t1,2 above...
int thisBatchSize = std::min(ntrans - b * batchSize, batchSize);
int bB = b * batchSize;
std::complex<TF> *cjb = cj + bB * nj; // batch of input strengths
std::complex<TF> *fkb = fk + bB * nk; // batch of output strengths
int thisBatchSize = std::min(ntrans_actual - b * batchSize, batchSize);
int bB = b * batchSize;
TC *cjb = cj + bB * nj; // batch of input strengths
TC *fkb = fk + bB * nk; // batch of output strengths
if (opts.debug > 1)
printf("[%s t3] start batch %d (size %d):\n", __func__, b, thisBatchSize);

Expand All @@ -1080,18 +1108,16 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk,

// STEP 1: spread c'_j batch (x'_j NU pts) into internal fw batch grid...
timer.restart();
spreadinterpSortedBatch<TF>(thisBatchSize, *this, fwBatch.data(), CpBatch.data(),
adjoint); // X are primed // FIXME
spreadinterpSortedBatch<TF>(thisBatchSize, *this, fwBatch, CpBatch,
adjoint); // X are primed
t_spr += timer.elapsedsec();

// STEP 2: type 2 NUFFT from fw batch to user output fk array batch...
timer.restart();
// illegal possible shrink of ntrans *after* plan for smaller last batch:
// MR FIXME: this breaks immutability!
innerT2plan->ntrans = thisBatchSize; // do not try this at home!
/* (alarming that FFT not shrunk, but safe, because t2's fwBatch array
still the same size, as Andrea explained; just wastes a few flops) */
innerT2plan->execute(fkb, fwBatch.data(), adjoint); // FIXME
innerT2plan->execute_internal(fkb, fwBatch, adjoint, thisBatchSize, fwBatch_inner,
innerT2plan->nf() * innerT2plan->batchSize);
t_t2 += timer.elapsedsec();
// STEP 3: apply deconvolve (precomputed 1/phiHat(targ_k), phasing too)...
timer.restart();
Expand All @@ -1114,13 +1140,13 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk,
t_deconv += timer.elapsedsec();
// STEP 1: adjoint type 2 (i.e. type 1) NUFFT from CpBatch to fwBatch...
timer.restart();
// illegal possible shrink of ntrans *after* plan for smaller last batch:
innerT2plan->ntrans = thisBatchSize; // do not try this at home!
innerT2plan->execute(CpBatch.data(), fwBatch.data(), adjoint);
innerT2plan->execute_internal(CpBatch, fwBatch, adjoint, thisBatchSize,
fwBatch_inner,
innerT2plan->nf() * innerT2plan->batchSize);
t_t2 += timer.elapsedsec();
// STEP 2: interpolate fwBatch into user output array ...
timer.restart();
spreadinterpSortedBatch<TF>(thisBatchSize, *this, fwBatch.data(), cjb,
spreadinterpSortedBatch<TF>(thisBatchSize, *this, fwBatch, cjb,
adjoint); // X are primed
t_spr += timer.elapsedsec();
// STEP 3: post-phase (possibly) the c_j output strengths (in place) ...
Expand Down Expand Up @@ -1148,10 +1174,12 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk,

return 0;
}
template int FINUFFT_PLAN_T<float>::execute(std::complex<float> *cj,
std::complex<float> *fk, bool adjoint) const;
template int FINUFFT_PLAN_T<double>::execute(
std::complex<double> *cj, std::complex<double> *fk, bool adjoint) const;
template int FINUFFT_PLAN_T<float>::execute_internal(
std::complex<float> *cj, std::complex<float> *fk, bool adjoint, int ntrans_actual,
std::complex<float> *aligned_scratch, size_t scratch_size) const;
template int FINUFFT_PLAN_T<double>::execute_internal(
std::complex<double> *cj, std::complex<double> *fk, bool adjoint, int ntrans_actual,
std::complex<double> *aligned_scratch, size_t scratch_size) const;

// DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
template<typename TF> FINUFFT_PLAN_T<TF>::~FINUFFT_PLAN_T() {
Expand Down

0 comments on commit d50a912

Please sign in to comment.