@@ -594,6 +594,85 @@ def irfft_numpy(x, n=None, axis=-1):
594
594
595
595
# ============================== ND ====================================== #
596
596
597
+ # copied from scipy.fftpack.helper
598
+ def _init_nd_shape_and_axes (x , shape , axes ):
599
+ """Handle shape and axes arguments for n-dimensional transforms.
600
+ Returns the shape and axes in a standard form, taking into account negative
601
+ values and checking for various potential errors.
602
+ Parameters
603
+ ----------
604
+ x : array_like
605
+ The input array.
606
+ shape : int or array_like of ints or None
607
+ The shape of the result. If both `shape` and `axes` (see below) are
608
+ None, `shape` is ``x.shape``; if `shape` is None but `axes` is
609
+ not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``.
610
+ If `shape` is -1, the size of the corresponding dimension of `x` is
611
+ used.
612
+ axes : int or array_like of ints or None
613
+ Axes along which the calculation is computed.
614
+ The default is over all axes.
615
+ Negative indices are automatically converted to their positive
616
+ counterpart.
617
+ Returns
618
+ -------
619
+ shape : array
620
+ The shape of the result. It is a 1D integer array.
621
+ axes : array
622
+ The shape of the result. It is a 1D integer array.
623
+ """
624
+ x = np .asarray (x )
625
+ noshape = shape is None
626
+ noaxes = axes is None
627
+
628
+ if noaxes :
629
+ axes = np .arange (x .ndim , dtype = np .intc )
630
+ else :
631
+ axes = np .atleast_1d (axes )
632
+
633
+ if axes .size == 0 :
634
+ axes = axes .astype (np .intc )
635
+
636
+ if not axes .ndim == 1 :
637
+ raise ValueError ("when given, axes values must be a scalar or vector" )
638
+ if not np .issubdtype (axes .dtype , np .integer ):
639
+ raise ValueError ("when given, axes values must be integers" )
640
+
641
+ axes = np .where (axes < 0 , axes + x .ndim , axes )
642
+
643
+ if axes .size != 0 and (axes .max () >= x .ndim or axes .min () < 0 ):
644
+ raise ValueError ("axes exceeds dimensionality of input" )
645
+ if axes .size != 0 and np .unique (axes ).shape != axes .shape :
646
+ raise ValueError ("all axes must be unique" )
647
+
648
+ if not noshape :
649
+ shape = np .atleast_1d (shape )
650
+ elif np .isscalar (x ):
651
+ shape = np .array ([], dtype = np .intc )
652
+ elif noaxes :
653
+ shape = np .array (x .shape , dtype = np .intc )
654
+ else :
655
+ shape = np .take (x .shape , axes )
656
+
657
+ if shape .size == 0 :
658
+ shape = shape .astype (np .intc )
659
+
660
+ if shape .ndim != 1 :
661
+ raise ValueError ("when given, shape values must be a scalar or vector" )
662
+ if not np .issubdtype (shape .dtype , np .integer ):
663
+ raise ValueError ("when given, shape values must be integers" )
664
+ if axes .shape != shape .shape :
665
+ raise ValueError ("when given, axes and shape arguments"
666
+ " have to be of the same length" )
667
+
668
+ shape = np .where (shape == - 1 , np .array (x .shape )[axes ], shape )
669
+
670
+ if shape .size != 0 and (shape < 1 ).any ():
671
+ raise ValueError (
672
+ "invalid number of data points ({0}) specified" .format (shape ))
673
+
674
+ return shape , axes
675
+
597
676
598
677
def _cook_nd_args (a , s = None , axes = None , invreal = 0 ):
599
678
if s is None :
@@ -621,7 +700,7 @@ def _cook_nd_args(a, s=None, axes=None, invreal=0):
621
700
622
701
def _iter_fftnd (a , s = None , axes = None , function = fft , overwrite_arg = False ):
623
702
a = np .asarray (a )
624
- s , axes = _cook_nd_args (a , s , axes )
703
+ s , axes = _init_nd_shape_and_axes (a , s , axes )
625
704
ovwr = overwrite_arg
626
705
for ii in reversed (range (len (axes ))):
627
706
a = function (a , n = s [ii ], axis = axes [ii ], overwrite_x = ovwr )
0 commit comments