Skip to content

Commit 6bde33e

Browse files
vtavanaCopilotantonwolfy
authored
fix an issue for real inputs of irfft (#180)
* fix an issue for real inputs of irfft * Update mkl_fft/tests/third_party/scipy/test_basic.py Co-authored-by: Copilot <[email protected]> * Update mkl_fft/tests/test_fft1d.py Co-authored-by: Anton <[email protected]> --------- Co-authored-by: Copilot <[email protected]> Co-authored-by: Anton <[email protected]>
1 parent 44d10a0 commit 6bde33e

File tree

5 files changed

+30
-21
lines changed

5 files changed

+30
-21
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616
* To set `mkl_fft` as the backend for SciPy is only possible through `mkl_fft.interfaces.scipy_fft` [gh-179](https://github.com/IntelPython/mkl_fft/pull/179)
1717
* SciPy interface `mkl_fft.interfaces.scipy_fft` uses the same function from SciPy for handling `s` and `axes` for N-D FFTs [gh-181](https://github.com/IntelPython/mkl_fft/pull/181)
1818

19+
### Fixed
20+
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with an empty axes [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
21+
* Fixed an issue for calling `mkl_fft.interfaces.numpy.fftn` with a zero-size array [gh-139](https://github.com/IntelPython/mkl_fft/pull/139)
22+
* Fixed inconsistency of input and output arrays dtype for `irfft` function [gh-180](https://github.com/IntelPython/mkl_fft/pull/180)
23+
1924
## [1.3.14] (04/10/2025)
2025

2126
resolves gh-152 by adding an explicit `mkl-service` dependency to `mkl-fft` when building the wheel

mkl_fft/_pydfti.pyx

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -665,13 +665,12 @@ def _r2c_fft1d_impl(
665665
return f_arr
666666

667667

668-
# this routine is functionally equivalent to numpy.fft.irfft
669668
def _c2r_fft1d_impl(
670669
x, n=None, axis=-1, overwrite_x=False, double fsc=1.0, out=None
671670
):
672671
"""
673-
Uses MKL to perform 1D FFT on the real input array x along the given axis,
674-
producing complex output, but giving only half of the harmonics.
672+
Uses MKL to perform 1D FFT on the real/complex input array x along the
673+
given axis, producing real output.
675674
676675
cf. numpy.fft.irfft
677676
"""
@@ -704,13 +703,13 @@ def _c2r_fft1d_impl(
704703
else:
705704
# we must cast the input and allocate the output,
706705
# so we cast to complex double and operate in place
707-
try:
706+
if x_type is cnp.NPY_FLOAT:
708707
x_arr = <cnp.ndarray> cnp.PyArray_FROM_OTF(
709-
x_arr, cnp.NPY_CDOUBLE, cnp.NPY_BEHAVED)
710-
except:
711-
raise ValueError(
712-
"First argument should be a real or "
713-
"a complex sequence of single or double precision"
708+
x_arr, cnp.NPY_CFLOAT, cnp.NPY_BEHAVED
709+
)
710+
else:
711+
x_arr = <cnp.ndarray> cnp.PyArray_FROM_OTF(
712+
x_arr, cnp.NPY_CDOUBLE, cnp.NPY_BEHAVED
714713
)
715714
x_type = cnp.PyArray_TYPE(x_arr)
716715
in_place = 1

mkl_fft/interfaces/_numpy_fft.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,11 @@ def hfft(a, n=None, axis=-1, norm=None, out=None):
295295
"""
296296
norm = _swap_direction(norm)
297297
x = _downcast_float128_array(a)
298-
x = np.array(x, copy=True)
299-
np.conjugate(x, out=x)
300298
fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1))
301299

302300
return _trycall(
303301
mkl_fft.irfft,
304-
(x,),
302+
(np.conjugate(x),),
305303
{"n": n, "axis": axis, "fwd_scale": fsc, "out": out},
306304
)
307305

@@ -317,9 +315,9 @@ def ihfft(a, n=None, axis=-1, norm=None, out=None):
317315
x = _downcast_float128_array(a)
318316
fsc = _compute_fwd_scale(norm, n, x.shape[axis])
319317

320-
output = _trycall(
318+
result = _trycall(
321319
mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc, "out": out}
322320
)
323321

324-
np.conjugate(output, out=output)
325-
return output
322+
np.conjugate(result, out=result)
323+
return result

mkl_fft/tests/test_fft1d.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,3 +457,12 @@ def test_irfft_out_strided(axis):
457457
expected = np.fft.irfft(x, axis=axis, out=out)
458458

459459
assert_allclose(result, expected)
460+
461+
462+
@requires_numpy_2
463+
@pytest.mark.parametrize("dt", ["i4", "f4", "f8", "c8", "c16"])
464+
def test_irfft_dtype(dt):
465+
x = np.array(rnd.random((20, 20)), dtype=dt)
466+
result = mkl_fft.irfft(x)
467+
expected = np.fft.irfft(x)
468+
assert_allclose(result, expected, rtol=1e-7, atol=1e-7, strict=True)

mkl_fft/tests/third_party/scipy/test_basic.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,11 @@
1414

1515
# pylint: disable=possibly-used-before-assignment
1616
if scipy.__version__ < "1.12":
17-
# scipy from Intel channel is 1.10
18-
pytest.skip(
19-
"This test file needs scipy>=1.12",
20-
allow_module_level=True,
21-
)
17+
# scipy from Intel channel is 1.10 with python 3.9 and 3.10
18+
pytest.skip("This test file needs scipy>=1.12", allow_module_level=True)
2219
elif scipy.__version__ < "1.14":
23-
# For python<=3.9, scipy<1.14 is installed
20+
# For python-3.11 and 3.12, scipy<1.14 is installed from Intel channel
21+
# For python<=3.9, scipy<1.14 is installed from conda channel
2422
# pylint: disable=no-name-in-module
2523
from scipy._lib._array_api import size as xp_size
2624
else:

0 commit comments

Comments
 (0)