@@ -262,23 +262,43 @@ def _iter_fftnd(
262
262
axes = None ,
263
263
out = None ,
264
264
direction = + 1 ,
265
- overwrite_x = False ,
266
- scale_function = lambda n , ind : 1.0 ,
265
+ scale_function = lambda ind : 1.0 ,
267
266
):
268
267
a = np .asarray (a )
269
268
s , axes = _init_nd_shape_and_axes (a , s , axes )
270
- ovwr = overwrite_x
271
- for ii in reversed (range (len (axes ))):
269
+
270
+ # Combine the two, but in reverse, to end with the first axis given.
271
+ axes_and_s = list (zip (axes , s ))[::- 1 ]
272
+ # We try to use in-place calculations where possible, which is
273
+ # everywhere except when the size changes after the first FFT.
274
+ size_changes = [axis for axis , n in axes_and_s [1 :] if a .shape [axis ] != n ]
275
+
276
+ # If there are any size changes, we cannot use out
277
+ res = None if size_changes else out
278
+ for ind , (axis , n ) in enumerate (axes_and_s ):
279
+ if axis in size_changes :
280
+ if axis == size_changes [- 1 ]:
281
+ # Last size change, so any output should now be OK
282
+ # (an error will be raised if not), and if no output is
283
+ # required, we want a freshly allocated array of the right size.
284
+ res = out
285
+ elif res is not None and n < res .shape [axis ]:
286
+ # For an intermediate step where we return fewer elements, we
287
+ # can use a smaller view of the previous array.
288
+ res = res [(slice (None ),) * axis + (slice (n ),)]
289
+ else :
290
+ # If we need more elements, we cannot use res.
291
+ res = None
272
292
a = _c2c_fft1d_impl (
273
293
a ,
274
- n = s [ii ],
275
- axis = axes [ii ],
276
- overwrite_x = ovwr ,
294
+ n = n ,
295
+ axis = axis ,
277
296
direction = direction ,
278
- fsc = scale_function (s [ ii ], ii ),
279
- out = out ,
297
+ fsc = scale_function (ind ),
298
+ out = res ,
280
299
)
281
- ovwr = True
300
+ # Default output for next iteration.
301
+ res = a
282
302
return a
283
303
284
304
@@ -360,7 +380,6 @@ def _c2c_fftnd_impl(
360
380
x ,
361
381
s = None ,
362
382
axes = None ,
363
- overwrite_x = False ,
364
383
direction = + 1 ,
365
384
fsc = 1.0 ,
366
385
out = None ,
@@ -385,7 +404,6 @@ def _c2c_fftnd_impl(
385
404
if _direct :
386
405
return _direct_fftnd (
387
406
x ,
388
- overwrite_x = overwrite_x ,
389
407
direction = direction ,
390
408
fsc = fsc ,
391
409
out = out ,
@@ -403,11 +421,7 @@ def _c2c_fftnd_impl(
403
421
x ,
404
422
axes ,
405
423
_direct_fftnd ,
406
- {
407
- "overwrite_x" : overwrite_x ,
408
- "direction" : direction ,
409
- "fsc" : fsc ,
410
- },
424
+ {"direction" : direction , "fsc" : fsc },
411
425
res ,
412
426
)
413
427
else :
@@ -418,8 +432,7 @@ def _c2c_fftnd_impl(
418
432
axes = axes ,
419
433
out = out ,
420
434
direction = direction ,
421
- overwrite_x = overwrite_x ,
422
- scale_function = lambda n , i : fsc if i == 0 else 1.0 ,
435
+ scale_function = lambda i : fsc if i == 0 else 1.0 ,
423
436
)
424
437
425
438
@@ -449,16 +462,30 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
449
462
ind [la ] = ii
450
463
tind = tuple (ind )
451
464
a_inp = a [tind ]
452
- res = out [tind ] if out is not None else None
465
+ res = out [tind ] if out is not None else a_inp
453
466
a_res = _c2c_fftnd_impl (
454
- a_inp , s = ss , axes = aa , overwrite_x = True , direction = 1 , out = res
467
+ a_inp , s = ss , axes = aa , direction = 1 , out = res
455
468
)
456
469
if a_res is not a_inp :
457
470
a [tind ] = a_res # copy in place
458
471
else :
459
472
# a series of 1D c2c FFTs along all axes except last
460
- for ii in range (len (axes ) - 2 , - 1 , - 1 ):
461
- a = _c2c_fft1d_impl (a , s [ii ], axes [ii ], overwrite_x = True )
473
+ axes_and_s = list (zip (axes , s ))[- 2 ::- 1 ]
474
+ size_changes = [
475
+ axis for axis , n in axes_and_s [1 :] if a .shape [axis ] != n
476
+ ]
477
+ res = None if size_changes else out
478
+
479
+ for axis , n in axes_and_s :
480
+ if axis in size_changes :
481
+ if axis == size_changes [- 1 ]:
482
+ res = out
483
+ elif res is not None and n < res .shape [axis ]:
484
+ res = res [(slice (None ),) * axis + (slice (n ),)]
485
+ else :
486
+ res = None
487
+ a = _c2c_fft1d_impl (a , n , axis , out = res )
488
+ res = a
462
489
return a
463
490
464
491
@@ -472,21 +499,17 @@ def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
472
499
if len (s ) > 1 :
473
500
if not no_trim :
474
501
a = _pad_array (a , s , axes )
475
- ovr_x = True if _datacopied (a , x ) else False
476
502
len_axes = len (axes )
477
503
if len (set (axes )) == len_axes and len_axes == a .ndim and len_axes > 2 :
478
504
# a series of ND c2c FFTs along last axis
479
505
# due to need to write into a, we must copy
480
- if not ovr_x :
481
- a = a .copy ()
482
- ovr_x = True
506
+ a = a if _datacopied (a , x ) else a .copy ()
483
507
if not np .issubdtype (a .dtype , np .complexfloating ):
484
508
# complex output will be copied to input, copy is needed
485
509
if a .dtype == np .float32 :
486
510
a = a .astype (np .complex64 )
487
511
else :
488
512
a = a .astype (np .complex128 )
489
- ovr_x = True
490
513
ss , aa = _remove_axis (s , axes , - 1 )
491
514
ind = [
492
515
slice (None , None , 1 ),
@@ -497,18 +520,27 @@ def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None):
497
520
a_inp = a [tind ]
498
521
# out has real dtype and cannot be used in intermediate steps
499
522
a_res = _c2c_fftnd_impl (
500
- a_inp , s = ss , axes = aa , overwrite_x = True , direction = - 1
523
+ a_inp , s = ss , axes = aa , out = a_inp , direction = - 1
501
524
)
502
525
if a_res is not a_inp :
503
526
a [tind ] = a_res # copy in place
504
527
else :
505
528
# a series of 1D c2c FFTs along all axes except last
506
- for ii in range (len (axes ) - 1 ):
507
- # out has real dtype and cannot be used in intermediate steps
508
- a = _c2c_fft1d_impl (
509
- a , s [ii ], axes [ii ], overwrite_x = ovr_x , direction = - 1
510
- )
511
- ovr_x = True
529
+ axes_and_s = list (zip (axes , s ))[- 2 ::- 1 ]
530
+ size_changes = [
531
+ axis for axis , n in axes_and_s [1 :] if a .shape [axis ] != n
532
+ ]
533
+ # out has real dtype cannot be used for intermediate steps
534
+ res = None
535
+ for axis , n in axes_and_s :
536
+ if axis in size_changes :
537
+ if res is not None and n < res .shape [axis ]:
538
+ # pylint: disable=unsubscriptable-object
539
+ res = res [(slice (None ),) * axis + (slice (n ),)]
540
+ else :
541
+ res = None
542
+ a = _c2c_fft1d_impl (a , n , axis , out = res , direction = - 1 )
543
+ res = a
512
544
# c2r along last axis
513
545
a = _c2r_fft1d_impl (a , n = s [- 1 ], axis = la , fsc = fsc , out = out )
514
546
return a
0 commit comments