Skip to content

Commit af69789

Browse files
BUG: fixed blunder with use of _remove_axis
The axis to be removed must be -1, not 'la'. This was changed for rfttn_numpy in a prior commit, and I attempted a refactoring, although I evidently missed this one.
1 parent a900c34 commit af69789

File tree

4 files changed

+75
-40
lines changed

4 files changed

+75
-40
lines changed

mkl_fft/_float16_utils.py renamed to mkl_fft/_float_utils.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,21 @@
2424
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

27-
from numpy import half, float32, asarray, ndarray, longdouble, float64
27+
from numpy import (half, float32, asarray, ndarray,
28+
longdouble, float64, longcomplex, complex_)
2829

29-
__all__ = ['__upcast_float16_array']
30+
__all__ = ['__upcast_float16_array', '__downcast_float128_array']
3031

3132
def __upcast_float16_array(x):
33+
"""
34+
Used in _scipy_fft to upcast float16 to float32,
35+
instead of float64, as mkl_fft would do"""
3236
if hasattr(x, "dtype"):
3337
xdt = x.dtype
3438
if xdt == half:
3539
# no half-precision routines, so convert to single precision
3640
return asarray(x, dtype=float32)
37-
if xdt == longdouble and not xdt == float64 :
41+
if xdt == longdouble and not xdt == float64:
3842
raise ValueError("type %s is not supported" % xdt)
3943
if not isinstance(x, ndarray):
4044
__x = asarray(x)
@@ -46,3 +50,24 @@ def __upcast_float16_array(x):
4650
raise ValueError("type %s is not supported" % xdt)
4751
return __x
4852
return x
53+
54+
55+
def __downcast_float128_array(x):
56+
"""
57+
Used in _numpy_fft to unsafely downcast float128/complex256 to
58+
complex128, instead of raising an error"""
59+
if hasattr(x, "dtype"):
60+
xdt = x.dtype
61+
if xdt == longdouble and not xdt == float64:
62+
return asarray(x, dtype=float64)
63+
elif xdt == longcomplex and not xdt == complex_:
64+
return asarray(x, dtype=complex_)
65+
if not isinstance(x, ndarray):
66+
__x = asarray(x)
67+
xdt = __x.dtype
68+
if xdt == longdouble and not xdt == float64:
69+
return asarray(x, dtype=float64)
70+
elif xdt == longcomplex and not xdt == complex_:
71+
return asarray(x, dtype=complex_)
72+
return __x
73+
return x

mkl_fft/_numpy_fft.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161

6262
import numpy
6363
from . import _pydfti as mkl_fft
64+
from . import _float_utils
6465

6566

6667
def _unitary(norm):
@@ -155,7 +156,8 @@ def fft(a, n=None, axis=-1, norm=None):
155156
the `numpy.fft` documentation.
156157
157158
"""
158-
output = mkl_fft.fft(a, n, axis)
159+
x = _float_utils.__downcast_float128_array(a)
160+
output = mkl_fft.fft(x, n, axis)
159161
if _unitary(norm):
160162
output *= 1 / sqrt(output.shape[axis])
161163
return output
@@ -241,7 +243,8 @@ def ifft(a, n=None, axis=-1, norm=None):
241243
242244
"""
243245
unitary = _unitary(norm)
244-
output = mkl_fft.ifft(a, n, axis)
246+
x = _float_utils.__downcast_float128_array(a)
247+
output = mkl_fft.ifft(x, n, axis)
245248
if unitary:
246249
output *= sqrt(output.shape[axis])
247250
return output
@@ -325,10 +328,11 @@ def rfft(a, n=None, axis=-1, norm=None):
325328
326329
"""
327330
unitary = _unitary(norm)
331+
x = _float_utils.__downcast_float128_array(a)
328332
if unitary and n is None:
329-
a = asarray(a)
330-
n = a.shape[axis]
331-
output = mkl_fft.rfft_numpy(a, n=n, axis=axis)
333+
x = asarray(x)
334+
n = x.shape[axis]
335+
output = mkl_fft.rfft_numpy(x, n=n, axis=axis)
332336
if unitary:
333337
output *= 1 / sqrt(n)
334338
return output
@@ -413,7 +417,8 @@ def irfft(a, n=None, axis=-1, norm=None):
413417
specified, and the output array is purely real.
414418
415419
"""
416-
output = mkl_fft.irfft_numpy(a, n=n, axis=axis)
420+
x = _float_utils.__downcast_float128_array(a)
421+
output = mkl_fft.irfft_numpy(x, n=n, axis=axis)
417422
if _unitary(norm):
418423
output *= sqrt(output.shape[axis])
419424
return output
@@ -488,12 +493,12 @@ def hfft(a, n=None, axis=-1, norm=None):
488493
[ 2., -2.]])
489494
490495
"""
491-
# The copy may be required for multithreading.
492-
a = array(a, copy=True, dtype=complex)
496+
x = _float_utils.__downcast_float128_array(a)
497+
x = array(x, copy=True, dtype=complex)
493498
if n is None:
494-
n = (a.shape[axis] - 1) * 2
499+
n = (x.shape[axis] - 1) * 2
495500
unitary = _unitary(norm)
496-
return irfft(conjugate(a), n, axis) * (sqrt(n) if unitary else n)
501+
return irfft(conjugate(x), n, axis) * (sqrt(n) if unitary else n)
497502

498503

499504
def ihfft(a, n=None, axis=-1, norm=None):
@@ -547,11 +552,12 @@ def ihfft(a, n=None, axis=-1, norm=None):
547552
548553
"""
549554
# The copy may be required for multithreading.
550-
a = array(a, copy=True, dtype=float)
555+
x = _float_utils.__downcast_float128_array(a)
556+
x = array(x, copy=True, dtype=float)
551557
if n is None:
552-
n = a.shape[axis]
558+
n = x.shape[axis]
553559
unitary = _unitary(norm)
554-
output = conjugate(rfft(a, n, axis))
560+
output = conjugate(rfft(x, n, axis))
555561
return output * (1 / (sqrt(n) if unitary else n))
556562

557563

@@ -673,7 +679,8 @@ def fftn(a, s=None, axes=None, norm=None):
673679
>>> plt.show()
674680
675681
"""
676-
output = mkl_fft.fftn(a, s, axes)
682+
x = _float_utils.__downcast_float128_array(a)
683+
output = mkl_fft.fftn(x, s, axes)
677684
if _unitary(norm):
678685
output *= 1 / sqrt(_tot_size(output, axes))
679686
return output
@@ -772,7 +779,8 @@ def ifftn(a, s=None, axes=None, norm=None):
772779
773780
"""
774781
unitary = _unitary(norm)
775-
output = mkl_fft.ifftn(a, s, axes)
782+
x = _float_utils.__downcast_float128_array(a)
783+
output = mkl_fft.ifftn(x, s, axes)
776784
if unitary:
777785
output *= sqrt(_tot_size(output, axes))
778786
return output
@@ -863,8 +871,8 @@ def fft2(a, s=None, axes=(-2, -1), norm=None):
863871
0.0 +0.j , 0.0 +0.j ]])
864872
865873
"""
866-
867-
return fftn(a, s=s, axes=axes, norm=norm)
874+
x = _float_utils.__downcast_float128_array(a)
875+
return fftn(x, s=s, axes=axes, norm=norm)
868876

869877

870878
def ifft2(a, s=None, axes=(-2, -1), norm=None):
@@ -949,8 +957,8 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None):
949957
[ 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j]])
950958
951959
"""
952-
953-
return ifftn(a, s=s, axes=axes, norm=norm)
960+
x = _float_utils.__downcast_float128_array(a)
961+
return ifftn(x, s=s, axes=axes, norm=norm)
954962

955963

956964
def rfftn(a, s=None, axes=None, norm=None):
@@ -1036,11 +1044,12 @@ def rfftn(a, s=None, axes=None, norm=None):
10361044
10371045
"""
10381046
unitary = _unitary(norm)
1047+
x = _float_utils.__downcast_float128_array(a)
10391048
if unitary:
1040-
a = asarray(a)
1041-
s, axes = _cook_nd_args(a, s, axes)
1049+
x = asarray(x)
1050+
s, axes = _cook_nd_args(x, s, axes)
10421051

1043-
output = mkl_fft.rfftn_numpy(a, s, axes)
1052+
output = mkl_fft.rfftn_numpy(x, s, axes)
10441053
if unitary:
10451054
n_tot = prod(asarray(s, dtype=output.dtype))
10461055
output *= 1 / sqrt(n_tot)
@@ -1079,8 +1088,8 @@ def rfft2(a, s=None, axes=(-2, -1), norm=None):
10791088
For more details see `rfftn`.
10801089
10811090
"""
1082-
1083-
return rfftn(a, s, axes, norm)
1091+
x = _float_utils.__downcast_float128_array(a)
1092+
return rfftn(x, s, axes, norm)
10841093

10851094

10861095
def irfftn(a, s=None, axes=None, norm=None):
@@ -1167,7 +1176,8 @@ def irfftn(a, s=None, axes=None, norm=None):
11671176
[ 1., 1.]]])
11681177
11691178
"""
1170-
output = mkl_fft.irfftn_numpy(a, s, axes)
1179+
x = _float_utils.__downcast_float128_array(a)
1180+
output = mkl_fft.irfftn_numpy(x, s, axes)
11711181
if _unitary(norm):
11721182
output *= sqrt(_tot_size(output, axes))
11731183
return output
@@ -1205,6 +1215,6 @@ def irfft2(a, s=None, axes=(-2, -1), norm=None):
12051215
For more details see `irfftn`.
12061216
12071217
"""
1208-
1209-
return irfftn(a, s, axes, norm)
1218+
x = _float_utils.__downcast_float128_array(a)
1219+
return irfftn(x, s, axes, norm)
12101220

mkl_fft/_pydfti.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ def irfftn_numpy(x, s=None, axes=None):
971971
if not ovr_x:
972972
a = a.copy()
973973
ovr_x = True
974-
ss, aa = _remove_axis(s, axes, la)
974+
ss, aa = _remove_axis(s, axes, -1)
975975
ind = [slice(None,None,1),] * len(s)
976976
for ii in range(a.shape[la]):
977977
ind[la] = ii

mkl_fft/_scipy_fft.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,46 +25,46 @@
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

2727
from . import _pydfti
28-
from . import _float16_utils
28+
from . import _float_utils
2929

3030
__all__ = ['fft', 'ifft', 'fftn', 'ifftn', 'fft2', 'ifft2', 'rfft', 'irfft']
3131

3232

3333
def fft(a, n=None, axis=-1, overwrite_x=False):
34-
x = _float16_utils.__upcast_float16_array(a)
34+
x = _float_utils.__upcast_float16_array(a)
3535
return _pydfti.fft(x, n=n, axis=axis, overwrite_x=overwrite_x)
3636

3737

3838
def ifft(a, n=None, axis=-1, overwrite_x=False):
39-
x = _float16_utils.__upcast_float16_array(a)
39+
x = _float_utils.__upcast_float16_array(a)
4040
return _pydfti.ifft(x, n=n, axis=axis, overwrite_x=overwrite_x)
4141

4242

4343
def fftn(a, shape=None, axes=None, overwrite_x=False):
44-
x = _float16_utils.__upcast_float16_array(a)
44+
x = _float_utils.__upcast_float16_array(a)
4545
return _pydfti.fftn(x, shape=shape, axes=axes, overwrite_x=overwrite_x)
4646

4747

4848
def ifftn(a, shape=None, axes=None, overwrite_x=False):
49-
x = _float16_utils.__upcast_float16_array(a)
49+
x = _float_utils.__upcast_float16_array(a)
5050
return _pydfti.ifftn(x, shape=shape, axes=axes, overwrite_x=overwrite_x)
5151

5252

5353
def fft2(a, shape=None, axes=(-2,-1), overwrite_x=False):
54-
x = _float16_utils.__upcast_float16_array(a)
54+
x = _float_utils.__upcast_float16_array(a)
5555
return _pydfti.fftn(x, shape=shape, axes=axes, overwrite_x=overwrite_x)
5656

5757

5858
def ifft2(a, shape=None, axes=(-2,-1), overwrite_x=False):
59-
x = _float16_utils.__upcast_float16_array(a)
59+
x = _float_utils.__upcast_float16_array(a)
6060
return _pydfti.ifftn(x, shape=shape, axes=axes, overwrite_x=overwrite_x)
6161

6262

6363
def rfft(a, n=None, axis=-1, overwrite_x=False):
64-
x = _float16_utils.__upcast_float16_array(a)
64+
x = _float_utils.__upcast_float16_array(a)
6565
return _pydfti.rfft(a, n=n, axis=axis, overwrite_x=overwrite_x)
6666

6767

6868
def irfft(a, n=None, axis=-1, overwrite_x=False):
69-
x = _float16_utils.__upcast_float16_array(a)
69+
x = _float_utils.__upcast_float16_array(a)
7070
return _pydfti.irfft(a, n=n, axis=axis, overwrite_x=overwrite_x)

0 commit comments

Comments
 (0)