Skip to content

Commit fa42151

Browse files
committed
get rid of overwrite_x kwarg in mkl_fft, instead utilize out kwarg
1 parent 1503f18 commit fa42151

File tree

7 files changed

+124
-125
lines changed

7 files changed

+124
-125
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ While using these interfaces is the easiest way to leverage `mk_fft`, one can al
5353

5454
### complex-to-complex (c2c) transforms:
5555

56-
`fft(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0, out=None)` - 1D FFT, similar to `scipy.fft.fft`
56+
`fft(x, n=None, axis=-1, fwd_scale=1.0, out=None)` - 1D FFT, similar to `numpy.fft.fft`
5757

58-
`fft2(x, s=None, axes=(-2, -1), overwrite_x=False, fwd_scale=1.0, out=None)` - 2D FFT, similar to `scipy.fft.fft2`
58+
`fft2(x, s=None, axes=(-2, -1), fwd_scale=1.0, out=None)` - 2D FFT, similar to `numpy.fft.fft2`
5959

60-
`fftn(x, s=None, axes=None, overwrite_x=False, fwd_scale=1.0, out=None)` - ND FFT, similar to `scipy.fft.fftn`
60+
`fftn(x, s=None, axes=None, fwd_scale=1.0, out=None)` - ND FFT, similar to `numpy.fft.fftn`
6161

6262
and similar inverse FFT (`ifft*`) functions.
6363

mkl_fft/_fft_utils.py

Lines changed: 67 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -262,23 +262,43 @@ def _iter_fftnd(
262262
axes=None,
263263
out=None,
264264
direction=+1,
265-
overwrite_x=False,
266-
scale_function=lambda n, ind: 1.0,
265+
scale_function=lambda ind: 1.0,
267266
):
268267
a = np.asarray(a)
269268
s, axes = _init_nd_shape_and_axes(a, s, axes)
270-
ovwr = overwrite_x
271-
for ii in reversed(range(len(axes))):
269+
270+
# Combine the two, but in reverse, to end with the first axis given.
271+
axes_and_s = list(zip(axes, s))[::-1]
272+
# We try to use in-place calculations where possible, which is
273+
# everywhere except when the size changes after the first FFT.
274+
size_changes = [axis for axis, n in axes_and_s[1:] if a.shape[axis] != n]
275+
276+
# If there are any size changes, we cannot use out
277+
res = None if size_changes else out
278+
for ind, (axis, n) in enumerate(axes_and_s):
279+
if axis in size_changes:
280+
if axis == size_changes[-1]:
281+
# Last size change, so any output should now be OK
282+
# (an error will be raised if not), and if no output is
283+
# required, we want a freshly allocated array of the right size.
284+
res = out
285+
elif res is not None and n < res.shape[axis]:
286+
# For an intermediate step where we return fewer elements, we
287+
# can use a smaller view of the previous array.
288+
res = res[(slice(None),) * axis + (slice(n),)]
289+
else:
290+
# If we need more elements, we cannot use res.
291+
res = None
272292
a = _c2c_fft1d_impl(
273293
a,
274-
n=s[ii],
275-
axis=axes[ii],
276-
overwrite_x=ovwr,
294+
n=n,
295+
axis=axis,
277296
direction=direction,
278-
fsc=scale_function(s[ii], ii),
279-
out=out,
297+
fsc=scale_function(ind),
298+
out=res,
280299
)
281-
ovwr = True
300+
# Default output for next iteration.
301+
res = a
282302
return a
283303

284304

@@ -360,7 +380,6 @@ def _c2c_fftnd_impl(
360380
x,
361381
s=None,
362382
axes=None,
363-
overwrite_x=False,
364383
direction=+1,
365384
fsc=1.0,
366385
out=None,
@@ -385,7 +404,6 @@ def _c2c_fftnd_impl(
385404
if _direct:
386405
return _direct_fftnd(
387406
x,
388-
overwrite_x=overwrite_x,
389407
direction=direction,
390408
fsc=fsc,
391409
out=out,
@@ -403,11 +421,7 @@ def _c2c_fftnd_impl(
403421
x,
404422
axes,
405423
_direct_fftnd,
406-
{
407-
"overwrite_x": overwrite_x,
408-
"direction": direction,
409-
"fsc": fsc,
410-
},
424+
{"direction": direction, "fsc": fsc},
411425
res,
412426
)
413427
else:
@@ -418,8 +432,7 @@ def _c2c_fftnd_impl(
418432
axes=axes,
419433
out=out,
420434
direction=direction,
421-
overwrite_x=overwrite_x,
422-
scale_function=lambda n, i: fsc if i == 0 else 1.0,
435+
scale_function=lambda i: fsc if i == 0 else 1.0,
423436
)
424437

425438

@@ -449,16 +462,30 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
449462
ind[la] = ii
450463
tind = tuple(ind)
451464
a_inp = a[tind]
452-
res = out[tind] if out is not None else None
465+
res = out[tind] if out is not None else a_inp
453466
a_res = _c2c_fftnd_impl(
454-
a_inp, s=ss, axes=aa, overwrite_x=True, direction=1, out=res
467+
a_inp, s=ss, axes=aa, direction=1, out=res
455468
)
456469
if a_res is not a_inp:
457470
a[tind] = a_res # copy in place
458471
else:
459472
# a series of 1D c2c FFTs along all axes except last
460-
for ii in range(len(axes) - 2, -1, -1):
461-
a = _c2c_fft1d_impl(a, s[ii], axes[ii], overwrite_x=True)
473+
axes_and_s = list(zip(axes, s))[-2::-1]
474+
size_changes = [
475+
axis for axis, n in axes_and_s[1:] if a.shape[axis] != n
476+
]
477+
res = None if size_changes else out
478+
479+
for axis, n in axes_and_s:
480+
if axis in size_changes:
481+
if axis == size_changes[-1]:
482+
res = out
483+
elif res is not None and n < res.shape[axis]:
484+
res = res[(slice(None),) * axis + (slice(n),)]
485+
else:
486+
res = None
487+
a = _c2c_fft1d_impl(a, n, axis, out=res)
488+
res = a
462489
return a
463490

464491

@@ -472,21 +499,17 @@ def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
472499
if len(s) > 1:
473500
if not no_trim:
474501
a = _pad_array(a, s, axes)
475-
ovr_x = True if _datacopied(a, x) else False
476502
len_axes = len(axes)
477503
if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2:
478504
# a series of ND c2c FFTs along last axis
479505
# due to need to write into a, we must copy
480-
if not ovr_x:
481-
a = a.copy()
482-
ovr_x = True
506+
a = a if _datacopied(a, x) else a.copy()
483507
if not np.issubdtype(a.dtype, np.complexfloating):
484508
# complex output will be copied to input, copy is needed
485509
if a.dtype == np.float32:
486510
a = a.astype(np.complex64)
487511
else:
488512
a = a.astype(np.complex128)
489-
ovr_x = True
490513
ss, aa = _remove_axis(s, axes, -1)
491514
ind = [
492515
slice(None, None, 1),
@@ -497,18 +520,27 @@ def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
497520
a_inp = a[tind]
498521
# out has real dtype and cannot be used in intermediate steps
499522
a_res = _c2c_fftnd_impl(
500-
a_inp, s=ss, axes=aa, overwrite_x=True, direction=-1
523+
a_inp, s=ss, axes=aa, out=a_inp, direction=-1
501524
)
502525
if a_res is not a_inp:
503526
a[tind] = a_res # copy in place
504527
else:
505528
# a series of 1D c2c FFTs along all axes except last
506-
for ii in range(len(axes) - 1):
507-
# out has real dtype and cannot be used in intermediate steps
508-
a = _c2c_fft1d_impl(
509-
a, s[ii], axes[ii], overwrite_x=ovr_x, direction=-1
510-
)
511-
ovr_x = True
529+
axes_and_s = list(zip(axes, s))[-2::-1]
530+
size_changes = [
531+
axis for axis, n in axes_and_s[1:] if a.shape[axis] != n
532+
]
533+
# out has real dtype cannot be used for intermediate steps
534+
res = None
535+
for axis, n in axes_and_s:
536+
if axis in size_changes:
537+
if res is not None and n < res.shape[axis]:
538+
# pylint: disable=unsubscriptable-object
539+
res = res[(slice(None),) * axis + (slice(n),)]
540+
else:
541+
res = None
542+
a = _c2c_fft1d_impl(a, n, axis, out=res, direction=-1)
543+
res = a
512544
# c2r along last axis
513545
a = _c2r_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=out)
514546
return a

mkl_fft/_mkl_fft.py

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -45,63 +45,35 @@
4545
]
4646

4747

48-
def fft(x, n=None, axis=-1, out=None, overwrite_x=False, fwd_scale=1.0):
48+
def fft(x, n=None, axis=-1, out=None, fwd_scale=1.0):
4949
return _c2c_fft1d_impl(
50-
x,
51-
n=n,
52-
axis=axis,
53-
out=out,
54-
overwrite_x=overwrite_x,
55-
direction=+1,
56-
fsc=fwd_scale,
50+
x, n=n, axis=axis, out=out, direction=+1, fsc=fwd_scale
5751
)
5852

5953

60-
def ifft(x, n=None, axis=-1, out=None, overwrite_x=False, fwd_scale=1.0):
54+
def ifft(x, n=None, axis=-1, out=None, fwd_scale=1.0):
6155
return _c2c_fft1d_impl(
62-
x,
63-
n=n,
64-
axis=axis,
65-
out=out,
66-
overwrite_x=overwrite_x,
67-
direction=-1,
68-
fsc=fwd_scale,
56+
x, n=n, axis=axis, out=out, direction=-1, fsc=fwd_scale
6957
)
7058

7159

72-
def fft2(x, s=None, axes=(-2, -1), out=None, overwrite_x=False, fwd_scale=1.0):
73-
return fftn(
74-
x, s=s, axes=axes, out=out, overwrite_x=overwrite_x, fwd_scale=fwd_scale
75-
)
60+
def fft2(x, s=None, axes=(-2, -1), out=None, fwd_scale=1.0):
61+
return fftn(x, s=s, axes=axes, out=out, fwd_scale=fwd_scale)
7662

7763

78-
def ifft2(x, s=None, axes=(-2, -1), out=None, overwrite_x=False, fwd_scale=1.0):
79-
return ifftn(
80-
x, s=s, axes=axes, out=out, overwrite_x=overwrite_x, fwd_scale=fwd_scale
81-
)
64+
def ifft2(x, s=None, axes=(-2, -1), out=None, fwd_scale=1.0):
65+
return ifftn(x, s=s, axes=axes, out=out, fwd_scale=fwd_scale)
8266

8367

84-
def fftn(x, s=None, axes=None, out=None, overwrite_x=False, fwd_scale=1.0):
68+
def fftn(x, s=None, axes=None, out=None, fwd_scale=1.0):
8569
return _c2c_fftnd_impl(
86-
x,
87-
s=s,
88-
axes=axes,
89-
out=out,
90-
overwrite_x=overwrite_x,
91-
direction=+1,
92-
fsc=fwd_scale,
70+
x, s=s, axes=axes, out=out, direction=+1, fsc=fwd_scale
9371
)
9472

9573

96-
def ifftn(x, s=None, axes=None, out=None, overwrite_x=False, fwd_scale=1.0):
74+
def ifftn(x, s=None, axes=None, out=None, fwd_scale=1.0):
9775
return _c2c_fftnd_impl(
98-
x,
99-
s=s,
100-
axes=axes,
101-
out=out,
102-
overwrite_x=overwrite_x,
103-
direction=-1,
104-
fsc=fwd_scale,
76+
x, s=s, axes=axes, out=out, direction=-1, fsc=fwd_scale
10577
)
10678

10779

0 commit comments

Comments
 (0)