Skip to content

Commit 2715ff5

Browse files
used init_nd_shape_and_axes in _iter_fftnd to pass scipy's tests expecting errors raised with specific text
1 parent 277c932 commit 2715ff5

File tree

1 file changed

+80
-1
lines changed

1 file changed

+80
-1
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,85 @@ def irfft_numpy(x, n=None, axis=-1):
594594

595595
# ============================== ND ====================================== #
596596

597+
# copied from scipy.fftpack.helper
598+
def _init_nd_shape_and_axes(x, shape, axes):
599+
"""Handle shape and axes arguments for n-dimensional transforms.
600+
Returns the shape and axes in a standard form, taking into account negative
601+
values and checking for various potential errors.
602+
Parameters
603+
----------
604+
x : array_like
605+
The input array.
606+
shape : int or array_like of ints or None
607+
The shape of the result. If both `shape` and `axes` (see below) are
608+
None, `shape` is ``x.shape``; if `shape` is None but `axes` is
609+
not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``.
610+
If `shape` is -1, the size of the corresponding dimension of `x` is
611+
used.
612+
axes : int or array_like of ints or None
613+
Axes along which the calculation is computed.
614+
The default is over all axes.
615+
Negative indices are automatically converted to their positive
616+
counterpart.
617+
Returns
618+
-------
619+
shape : array
620+
The shape of the result. It is a 1D integer array.
621+
axes : array
622+
The shape of the result. It is a 1D integer array.
623+
"""
624+
x = np.asarray(x)
625+
noshape = shape is None
626+
noaxes = axes is None
627+
628+
if noaxes:
629+
axes = np.arange(x.ndim, dtype=np.intc)
630+
else:
631+
axes = np.atleast_1d(axes)
632+
633+
if axes.size == 0:
634+
axes = axes.astype(np.intc)
635+
636+
if not axes.ndim == 1:
637+
raise ValueError("when given, axes values must be a scalar or vector")
638+
if not np.issubdtype(axes.dtype, np.integer):
639+
raise ValueError("when given, axes values must be integers")
640+
641+
axes = np.where(axes < 0, axes + x.ndim, axes)
642+
643+
if axes.size != 0 and (axes.max() >= x.ndim or axes.min() < 0):
644+
raise ValueError("axes exceeds dimensionality of input")
645+
if axes.size != 0 and np.unique(axes).shape != axes.shape:
646+
raise ValueError("all axes must be unique")
647+
648+
if not noshape:
649+
shape = np.atleast_1d(shape)
650+
elif np.isscalar(x):
651+
shape = np.array([], dtype=np.intc)
652+
elif noaxes:
653+
shape = np.array(x.shape, dtype=np.intc)
654+
else:
655+
shape = np.take(x.shape, axes)
656+
657+
if shape.size == 0:
658+
shape = shape.astype(np.intc)
659+
660+
if shape.ndim != 1:
661+
raise ValueError("when given, shape values must be a scalar or vector")
662+
if not np.issubdtype(shape.dtype, np.integer):
663+
raise ValueError("when given, shape values must be integers")
664+
if axes.shape != shape.shape:
665+
raise ValueError("when given, axes and shape arguments"
666+
" have to be of the same length")
667+
668+
shape = np.where(shape == -1, np.array(x.shape)[axes], shape)
669+
670+
if shape.size != 0 and (shape < 1).any():
671+
raise ValueError(
672+
"invalid number of data points ({0}) specified".format(shape))
673+
674+
return shape, axes
675+
597676

598677
def _cook_nd_args(a, s=None, axes=None, invreal=0):
599678
if s is None:
@@ -621,7 +700,7 @@ def _cook_nd_args(a, s=None, axes=None, invreal=0):
621700

622701
def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_arg=False):
623702
a = np.asarray(a)
624-
s, axes = _cook_nd_args(a, s, axes)
703+
s, axes = _init_nd_shape_and_axes(a, s, axes)
625704
ovwr = overwrite_arg
626705
for ii in reversed(range(len(axes))):
627706
a = function(a, n = s[ii], axis = axes[ii], overwrite_x=ovwr)

0 commit comments

Comments
 (0)