Skip to content

Commit 44d10a0

Browse files
vtavanaCopilot
andauthored
update handling shape and axes of scipy interface (#181)
* update handling shape and axes of scipy interface * Apply suggestions from code review Co-authored-by: Copilot <[email protected]> * include a few skipped tests --------- Co-authored-by: Copilot <[email protected]>
1 parent 05232e0 commit 44d10a0

File tree

3 files changed

+72
-89
lines changed

3 files changed

+72
-89
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
### Changed
1515
* NumPy interface `mkl_fft.interfaces.numpy_fft` is aligned with numpy-2.x.x [gh-139](https://github.com/IntelPython/mkl_fft/pull/139), [gh-157](https://github.com/IntelPython/mkl_fft/pull/157)
1616
* To set `mkl_fft` as the backend for SciPy is only possible through `mkl_fft.interfaces.scipy_fft` [gh-179](https://github.com/IntelPython/mkl_fft/pull/179)
17+
* SciPy interface `mkl_fft.interfaces.scipy_fft` uses the same function from SciPy for handling `s` and `axes` for N-D FFTs [gh-181](https://github.com/IntelPython/mkl_fft/pull/181)
1718

1819
## [1.3.14] (04/10/2025)
1920

mkl_fft/interfaces/_scipy_fft.py

Lines changed: 71 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import contextvars
3434
import operator
3535
import os
36+
from numbers import Number
3637

3738
import mkl
3839
import numpy as np
@@ -156,30 +157,65 @@ def _check_plan(plan):
156157
)
157158

158159

159-
def _check_overwrite_x(overwrite_x):
160-
if overwrite_x:
161-
raise NotImplementedError(
162-
"Overwriting the content of `x` is currently not supported"
163-
)
160+
# copied from scipy.fft._pocketfft.helper
161+
# https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py
162+
def _iterable_of_int(x, name=None):
163+
if isinstance(x, Number):
164+
x = (x,)
164165

166+
try:
167+
x = [operator.index(a) for a in x]
168+
except TypeError as e:
169+
name = name or "value"
170+
raise ValueError(
171+
f"{name} must be a scalar or iterable of integers"
172+
) from e
165173

166-
def _cook_nd_args(x, s=None, axes=None, invreal=False):
167-
if s is None:
168-
shapeless = True
169-
if axes is None:
170-
s = list(x.shape)
171-
else:
172-
s = np.take(x.shape, axes)
174+
return x
175+
176+
177+
# copied and modified from scipy.fft._pocketfft.helper
178+
# https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py
179+
def _init_nd_shape_and_axes(x, shape, axes, invreal=False):
180+
noshape = shape is None
181+
noaxes = axes is None
182+
183+
if not noaxes:
184+
axes = _iterable_of_int(axes, "axes")
185+
axes = [a + x.ndim if a < 0 else a for a in axes]
186+
187+
if any(a >= x.ndim or a < 0 for a in axes):
188+
raise ValueError("axes exceeds dimensionality of input")
189+
if len(set(axes)) != len(axes):
190+
raise ValueError("all axes must be unique")
191+
192+
if not noshape:
193+
shape = _iterable_of_int(shape, "shape")
194+
195+
if axes and len(axes) != len(shape):
196+
raise ValueError(
197+
"when given, axes and shape arguments"
198+
" have to be of the same length"
199+
)
200+
if noaxes:
201+
if len(shape) > x.ndim:
202+
raise ValueError("shape requires more axes than are present")
203+
axes = range(x.ndim - len(shape), x.ndim)
204+
205+
shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)]
206+
elif noaxes:
207+
shape = list(x.shape)
208+
axes = range(x.ndim)
173209
else:
174-
shapeless = False
175-
s = list(s)
176-
if axes is None:
177-
axes = list(range(-len(s), 0))
178-
if len(s) != len(axes):
179-
raise ValueError("Shape and axes have different lengths.")
180-
if invreal and shapeless:
181-
s[-1] = (x.shape[axes[-1]] - 1) * 2
182-
return s, axes
210+
shape = [x.shape[a] for a in axes]
211+
212+
if noshape and invreal:
213+
shape[-1] = (x.shape[axes[-1]] - 1) * 2
214+
215+
if any(s < 1 for s in shape):
216+
raise ValueError(f"invalid number of data points ({shape}) specified")
217+
218+
return tuple(shape), list(axes)
183219

184220

185221
def _validate_input(x):
@@ -301,7 +337,7 @@ def fftn(
301337
"""
302338
_check_plan(plan)
303339
x = _validate_input(x)
304-
s, axes = _cook_nd_args(x, s, axes)
340+
s, axes = _init_nd_shape_and_axes(x, s, axes)
305341
fsc = _compute_fwd_scale(norm, s, x.shape)
306342

307343
with _Workers(workers):
@@ -328,7 +364,7 @@ def ifftn(
328364
"""
329365
_check_plan(plan)
330366
x = _validate_input(x)
331-
s, axes = _cook_nd_args(x, s, axes)
367+
s, axes = _init_nd_shape_and_axes(x, s, axes)
332368
fsc = _compute_fwd_scale(norm, s, x.shape)
333369

334370
with _Workers(workers):
@@ -345,17 +381,13 @@ def rfft(
345381
346382
For full documentation refer to `scipy.fft.rfft`.
347383
348-
Limitation
349-
-----------
350-
The kwarg `overwrite_x` is only supported with its default value.
351-
352384
"""
353385
_check_plan(plan)
354-
_check_overwrite_x(overwrite_x)
355386
x = _validate_input(x)
356387
fsc = _compute_fwd_scale(norm, n, x.shape[axis])
357388

358389
with _Workers(workers):
390+
# Note: overwrite_x is not utilized
359391
return mkl_fft.rfft(x, n=n, axis=axis, fwd_scale=fsc)
360392

361393

@@ -367,17 +399,13 @@ def irfft(
367399
368400
For full documentation refer to `scipy.fft.irfft`.
369401
370-
Limitation
371-
-----------
372-
The kwarg `overwrite_x` is only supported with its default value.
373-
374402
"""
375403
_check_plan(plan)
376-
_check_overwrite_x(overwrite_x)
377404
x = _validate_input(x)
378405
fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1))
379406

380407
with _Workers(workers):
408+
# Note: overwrite_x is not utilized
381409
return mkl_fft.irfft(x, n=n, axis=axis, fwd_scale=fsc)
382410

383411

@@ -396,10 +424,6 @@ def rfft2(
396424
397425
For full documentation refer to `scipy.fft.rfft2`.
398426
399-
Limitation
400-
-----------
401-
The kwarg `overwrite_x` is only supported with its default value.
402-
403427
"""
404428
return rfftn(
405429
x,
@@ -427,10 +451,6 @@ def irfft2(
427451
428452
For full documentation refer to `scipy.fft.irfft2`.
429453
430-
Limitation
431-
-----------
432-
The kwarg `overwrite_x` is only supported with its default value.
433-
434454
"""
435455
return irfftn(
436456
x,
@@ -458,18 +478,14 @@ def rfftn(
458478
459479
For full documentation refer to `scipy.fft.rfftn`.
460480
461-
Limitation
462-
-----------
463-
The kwarg `overwrite_x` is only supported with its default value.
464-
465481
"""
466482
_check_plan(plan)
467-
_check_overwrite_x(overwrite_x)
468483
x = _validate_input(x)
469-
s, axes = _cook_nd_args(x, s, axes)
484+
s, axes = _init_nd_shape_and_axes(x, s, axes)
470485
fsc = _compute_fwd_scale(norm, s, x.shape)
471486

472487
with _Workers(workers):
488+
# Note: overwrite_x is not utilized
473489
return mkl_fft.rfftn(x, s, axes, fwd_scale=fsc)
474490

475491

@@ -488,18 +504,14 @@ def irfftn(
488504
489505
For full documentation refer to `scipy.fft.irfftn`.
490506
491-
Limitation
492-
-----------
493-
The kwarg `overwrite_x` is only supported with its default value.
494-
495507
"""
496508
_check_plan(plan)
497-
_check_overwrite_x(overwrite_x)
498509
x = _validate_input(x)
499-
s, axes = _cook_nd_args(x, s, axes, invreal=True)
510+
s, axes = _init_nd_shape_and_axes(x, s, axes, invreal=True)
500511
fsc = _compute_fwd_scale(norm, s, x.shape)
501512

502513
with _Workers(workers):
514+
# Note: overwrite_x is not utilized
503515
return mkl_fft.irfftn(x, s, axes, fwd_scale=fsc)
504516

505517

@@ -512,20 +524,16 @@ def hfft(
512524
513525
For full documentation refer to `scipy.fft.hfft`.
514526
515-
Limitation
516-
-----------
517-
The kwarg `overwrite_x` is only supported with its default value.
518-
519527
"""
520528
_check_plan(plan)
521-
_check_overwrite_x(overwrite_x)
522529
x = _validate_input(x)
523530
norm = _swap_direction(norm)
524531
x = np.array(x, copy=True)
525532
np.conjugate(x, out=x)
526533
fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1))
527534

528535
with _Workers(workers):
536+
# Note: overwrite_x is not utilized
529537
return mkl_fft.irfft(x, n=n, axis=axis, fwd_scale=fsc)
530538

531539

@@ -537,18 +545,14 @@ def ihfft(
537545
538546
For full documentation refer to `scipy.fft.ihfft`.
539547
540-
Limitation
541-
-----------
542-
The kwarg `overwrite_x` is only supported with its default value.
543-
544548
"""
545549
_check_plan(plan)
546-
_check_overwrite_x(overwrite_x)
547550
x = _validate_input(x)
548551
norm = _swap_direction(norm)
549552
fsc = _compute_fwd_scale(norm, n, x.shape[axis])
550553

551554
with _Workers(workers):
555+
# Note: overwrite_x is not utilized
552556
result = mkl_fft.rfft(x, n=n, axis=axis, fwd_scale=fsc)
553557

554558
np.conjugate(result, out=result)
@@ -570,10 +574,6 @@ def hfft2(
570574
571575
For full documentation refer to `scipy.fft.hfft2`.
572576
573-
Limitation
574-
-----------
575-
The kwarg `overwrite_x` is only supported with its default value.
576-
577577
"""
578578
return hfftn(
579579
x,
@@ -601,10 +601,6 @@ def ihfft2(
601601
602602
For full documentation refer to `scipy.fft.ihfft2`.
603603
604-
Limitation
605-
-----------
606-
The kwarg `overwrite_x` is only supported with its default value.
607-
608604
"""
609605
return ihfftn(
610606
x,
@@ -633,21 +629,17 @@ def hfftn(
633629
634630
For full documentation refer to `scipy.fft.hfftn`.
635631
636-
Limitation
637-
-----------
638-
The kwarg `overwrite_x` is only supported with its default value.
639-
640632
"""
641633
_check_plan(plan)
642-
_check_overwrite_x(overwrite_x)
643634
x = _validate_input(x)
644635
norm = _swap_direction(norm)
645636
x = np.array(x, copy=True)
646637
np.conjugate(x, out=x)
647-
s, axes = _cook_nd_args(x, s, axes, invreal=True)
638+
s, axes = _init_nd_shape_and_axes(x, s, axes, invreal=True)
648639
fsc = _compute_fwd_scale(norm, s, x.shape)
649640

650641
with _Workers(workers):
642+
# Note: overwrite_x is not utilized
651643
return mkl_fft.irfftn(x, s, axes, fwd_scale=fsc)
652644

653645

@@ -666,19 +658,15 @@ def ihfftn(
666658
667659
For full documentation refer to `scipy.fft.ihfftn`.
668660
669-
Limitation
670-
-----------
671-
The kwarg `overwrite_x` is only supported with its default value.
672-
673661
"""
674662
_check_plan(plan)
675-
_check_overwrite_x(overwrite_x)
676663
x = _validate_input(x)
677664
norm = _swap_direction(norm)
678-
s, axes = _cook_nd_args(x, s, axes)
665+
s, axes = _init_nd_shape_and_axes(x, s, axes)
679666
fsc = _compute_fwd_scale(norm, s, x.shape)
680667

681668
with _Workers(workers):
669+
# Note: overwrite_x is not utilized
682670
result = mkl_fft.rfftn(x, s, axes, fwd_scale=fsc)
683671

684672
np.conjugate(result, out=result)

mkl_fft/tests/third_party/scipy/test_basic.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,6 @@ def test_irfftn(self, xp):
230230
for norm in ["backward", "ortho", "forward"]:
231231
xp_assert_close(fft.irfftn(fft.rfftn(x, norm=norm), norm=norm), x)
232232

233-
@pytest.mark.skip("hfft is not supported")
234233
def test_hfft(self, xp):
235234
x = random(14) + 1j * random(14)
236235
x_herm = np.concatenate((random(1), x, random(1)))
@@ -246,7 +245,6 @@ def test_hfft(self, xp):
246245
)
247246
xp_assert_close(fft.hfft(x_herm, norm="forward"), expect / 30)
248247

249-
@pytest.mark.skip("ihfft is not supported")
250248
def test_ihfft(self, xp):
251249
x = random(14) + 1j * random(14)
252250
x_herm = np.concatenate((random(1), x, random(1)))
@@ -259,14 +257,12 @@ def test_ihfft(self, xp):
259257
fft.ihfft(fft.hfft(x_herm, norm=norm), norm=norm), x_herm
260258
)
261259

262-
@pytest.mark.skip("hfft2 is not supported")
263260
def test_hfft2(self, xp):
264261
x = xp.asarray(random((30, 20)))
265262
xp_assert_close(fft.hfft2(fft.ihfft2(x)), x)
266263
for norm in ["backward", "ortho", "forward"]:
267264
xp_assert_close(fft.hfft2(fft.ihfft2(x, norm=norm), norm=norm), x)
268265

269-
@pytest.mark.skip("ihfft2 is not supported")
270266
def test_ihfft2(self, xp):
271267
x = xp.asarray(random((30, 20)), dtype=xp.float64)
272268
expect = fft.ifft2(xp.asarray(x, dtype=xp.complex128))[:, :11]
@@ -278,14 +274,12 @@ def test_ihfft2(self, xp):
278274
)
279275
xp_assert_close(fft.ihfft2(x, norm="forward"), expect * (30 * 20))
280276

281-
@pytest.mark.skip("hfftn is not supported")
282277
def test_hfftn(self, xp):
283278
x = xp.asarray(random((30, 20, 10)))
284279
xp_assert_close(fft.hfftn(fft.ihfftn(x)), x)
285280
for norm in ["backward", "ortho", "forward"]:
286281
xp_assert_close(fft.hfftn(fft.ihfftn(x, norm=norm), norm=norm), x)
287282

288-
@pytest.mark.skip("ihfftn is not supported")
289283
def test_ihfftn(self, xp):
290284
x = xp.asarray(random((30, 20, 10)), dtype=xp.float64)
291285
expect = fft.ifftn(xp.asarray(x, dtype=xp.complex128))[:, :, :6]

0 commit comments

Comments
 (0)