33
33
import contextvars
34
34
import operator
35
35
import os
36
+ from numbers import Number
36
37
37
38
import mkl
38
39
import numpy as np
@@ -156,30 +157,65 @@ def _check_plan(plan):
156
157
)
157
158
158
159
159
- def _check_overwrite_x ( overwrite_x ):
160
- if overwrite_x :
161
- raise NotImplementedError (
162
- "Overwriting the content of `x` is currently not supported"
163
- )
160
+ # copied from scipy.fft._pocketfft.helper
161
+ # https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py
162
+ def _iterable_of_int ( x , name = None ):
163
+ if isinstance ( x , Number ):
164
+ x = ( x , )
164
165
166
+ try :
167
+ x = [operator .index (a ) for a in x ]
168
+ except TypeError as e :
169
+ name = name or "value"
170
+ raise ValueError (
171
+ f"{ name } must be a scalar or iterable of integers"
172
+ ) from e
165
173
166
- def _cook_nd_args (x , s = None , axes = None , invreal = False ):
167
- if s is None :
168
- shapeless = True
169
- if axes is None :
170
- s = list (x .shape )
171
- else :
172
- s = np .take (x .shape , axes )
174
+ return x
175
+
176
+
177
+ # copied and modified from scipy.fft._pocketfft.helper
178
+ # https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py
179
+ def _init_nd_shape_and_axes (x , shape , axes , invreal = False ):
180
+ noshape = shape is None
181
+ noaxes = axes is None
182
+
183
+ if not noaxes :
184
+ axes = _iterable_of_int (axes , "axes" )
185
+ axes = [a + x .ndim if a < 0 else a for a in axes ]
186
+
187
+ if any (a >= x .ndim or a < 0 for a in axes ):
188
+ raise ValueError ("axes exceeds dimensionality of input" )
189
+ if len (set (axes )) != len (axes ):
190
+ raise ValueError ("all axes must be unique" )
191
+
192
+ if not noshape :
193
+ shape = _iterable_of_int (shape , "shape" )
194
+
195
+ if axes and len (axes ) != len (shape ):
196
+ raise ValueError (
197
+ "when given, axes and shape arguments"
198
+ " have to be of the same length"
199
+ )
200
+ if noaxes :
201
+ if len (shape ) > x .ndim :
202
+ raise ValueError ("shape requires more axes than are present" )
203
+ axes = range (x .ndim - len (shape ), x .ndim )
204
+
205
+ shape = [x .shape [a ] if s == - 1 else s for s , a in zip (shape , axes )]
206
+ elif noaxes :
207
+ shape = list (x .shape )
208
+ axes = range (x .ndim )
173
209
else :
174
- shapeless = False
175
- s = list ( s )
176
- if axes is None :
177
- axes = list ( range ( - len ( s ), 0 ))
178
- if len ( s ) != len ( axes ):
179
- raise ValueError ( "Shape and axes have different lengths." )
180
- if invreal and shapeless :
181
- s [ - 1 ] = ( x . shape [ axes [ - 1 ]] - 1 ) * 2
182
- return s , axes
210
+ shape = [ x . shape [ a ] for a in axes ]
211
+
212
+ if noshape and invreal :
213
+ shape [ - 1 ] = ( x . shape [ axes [ - 1 ]] - 1 ) * 2
214
+
215
+ if any ( s < 1 for s in shape ):
216
+ raise ValueError ( f"invalid number of data points ( { shape } ) specified" )
217
+
218
+ return tuple ( shape ), list ( axes )
183
219
184
220
185
221
def _validate_input (x ):
@@ -301,7 +337,7 @@ def fftn(
301
337
"""
302
338
_check_plan (plan )
303
339
x = _validate_input (x )
304
- s , axes = _cook_nd_args (x , s , axes )
340
+ s , axes = _init_nd_shape_and_axes (x , s , axes )
305
341
fsc = _compute_fwd_scale (norm , s , x .shape )
306
342
307
343
with _Workers (workers ):
@@ -328,7 +364,7 @@ def ifftn(
328
364
"""
329
365
_check_plan (plan )
330
366
x = _validate_input (x )
331
- s , axes = _cook_nd_args (x , s , axes )
367
+ s , axes = _init_nd_shape_and_axes (x , s , axes )
332
368
fsc = _compute_fwd_scale (norm , s , x .shape )
333
369
334
370
with _Workers (workers ):
@@ -345,17 +381,13 @@ def rfft(
345
381
346
382
For full documentation refer to `scipy.fft.rfft`.
347
383
348
- Limitation
349
- -----------
350
- The kwarg `overwrite_x` is only supported with its default value.
351
-
352
384
"""
353
385
_check_plan (plan )
354
- _check_overwrite_x (overwrite_x )
355
386
x = _validate_input (x )
356
387
fsc = _compute_fwd_scale (norm , n , x .shape [axis ])
357
388
358
389
with _Workers (workers ):
390
+ # Note: overwrite_x is not utilized
359
391
return mkl_fft .rfft (x , n = n , axis = axis , fwd_scale = fsc )
360
392
361
393
@@ -367,17 +399,13 @@ def irfft(
367
399
368
400
For full documentation refer to `scipy.fft.irfft`.
369
401
370
- Limitation
371
- -----------
372
- The kwarg `overwrite_x` is only supported with its default value.
373
-
374
402
"""
375
403
_check_plan (plan )
376
- _check_overwrite_x (overwrite_x )
377
404
x = _validate_input (x )
378
405
fsc = _compute_fwd_scale (norm , n , 2 * (x .shape [axis ] - 1 ))
379
406
380
407
with _Workers (workers ):
408
+ # Note: overwrite_x is not utilized
381
409
return mkl_fft .irfft (x , n = n , axis = axis , fwd_scale = fsc )
382
410
383
411
@@ -396,10 +424,6 @@ def rfft2(
396
424
397
425
For full documentation refer to `scipy.fft.rfft2`.
398
426
399
- Limitation
400
- -----------
401
- The kwarg `overwrite_x` is only supported with its default value.
402
-
403
427
"""
404
428
return rfftn (
405
429
x ,
@@ -427,10 +451,6 @@ def irfft2(
427
451
428
452
For full documentation refer to `scipy.fft.irfft2`.
429
453
430
- Limitation
431
- -----------
432
- The kwarg `overwrite_x` is only supported with its default value.
433
-
434
454
"""
435
455
return irfftn (
436
456
x ,
@@ -458,18 +478,14 @@ def rfftn(
458
478
459
479
For full documentation refer to `scipy.fft.rfftn`.
460
480
461
- Limitation
462
- -----------
463
- The kwarg `overwrite_x` is only supported with its default value.
464
-
465
481
"""
466
482
_check_plan (plan )
467
- _check_overwrite_x (overwrite_x )
468
483
x = _validate_input (x )
469
- s , axes = _cook_nd_args (x , s , axes )
484
+ s , axes = _init_nd_shape_and_axes (x , s , axes )
470
485
fsc = _compute_fwd_scale (norm , s , x .shape )
471
486
472
487
with _Workers (workers ):
488
+ # Note: overwrite_x is not utilized
473
489
return mkl_fft .rfftn (x , s , axes , fwd_scale = fsc )
474
490
475
491
@@ -488,18 +504,14 @@ def irfftn(
488
504
489
505
For full documentation refer to `scipy.fft.irfftn`.
490
506
491
- Limitation
492
- -----------
493
- The kwarg `overwrite_x` is only supported with its default value.
494
-
495
507
"""
496
508
_check_plan (plan )
497
- _check_overwrite_x (overwrite_x )
498
509
x = _validate_input (x )
499
- s , axes = _cook_nd_args (x , s , axes , invreal = True )
510
+ s , axes = _init_nd_shape_and_axes (x , s , axes , invreal = True )
500
511
fsc = _compute_fwd_scale (norm , s , x .shape )
501
512
502
513
with _Workers (workers ):
514
+ # Note: overwrite_x is not utilized
503
515
return mkl_fft .irfftn (x , s , axes , fwd_scale = fsc )
504
516
505
517
@@ -512,20 +524,16 @@ def hfft(
512
524
513
525
For full documentation refer to `scipy.fft.hfft`.
514
526
515
- Limitation
516
- -----------
517
- The kwarg `overwrite_x` is only supported with its default value.
518
-
519
527
"""
520
528
_check_plan (plan )
521
- _check_overwrite_x (overwrite_x )
522
529
x = _validate_input (x )
523
530
norm = _swap_direction (norm )
524
531
x = np .array (x , copy = True )
525
532
np .conjugate (x , out = x )
526
533
fsc = _compute_fwd_scale (norm , n , 2 * (x .shape [axis ] - 1 ))
527
534
528
535
with _Workers (workers ):
536
+ # Note: overwrite_x is not utilized
529
537
return mkl_fft .irfft (x , n = n , axis = axis , fwd_scale = fsc )
530
538
531
539
@@ -537,18 +545,14 @@ def ihfft(
537
545
538
546
For full documentation refer to `scipy.fft.ihfft`.
539
547
540
- Limitation
541
- -----------
542
- The kwarg `overwrite_x` is only supported with its default value.
543
-
544
548
"""
545
549
_check_plan (plan )
546
- _check_overwrite_x (overwrite_x )
547
550
x = _validate_input (x )
548
551
norm = _swap_direction (norm )
549
552
fsc = _compute_fwd_scale (norm , n , x .shape [axis ])
550
553
551
554
with _Workers (workers ):
555
+ # Note: overwrite_x is not utilized
552
556
result = mkl_fft .rfft (x , n = n , axis = axis , fwd_scale = fsc )
553
557
554
558
np .conjugate (result , out = result )
@@ -570,10 +574,6 @@ def hfft2(
570
574
571
575
For full documentation refer to `scipy.fft.hfft2`.
572
576
573
- Limitation
574
- -----------
575
- The kwarg `overwrite_x` is only supported with its default value.
576
-
577
577
"""
578
578
return hfftn (
579
579
x ,
@@ -601,10 +601,6 @@ def ihfft2(
601
601
602
602
For full documentation refer to `scipy.fft.ihfft2`.
603
603
604
- Limitation
605
- -----------
606
- The kwarg `overwrite_x` is only supported with its default value.
607
-
608
604
"""
609
605
return ihfftn (
610
606
x ,
@@ -633,21 +629,17 @@ def hfftn(
633
629
634
630
For full documentation refer to `scipy.fft.hfftn`.
635
631
636
- Limitation
637
- -----------
638
- The kwarg `overwrite_x` is only supported with its default value.
639
-
640
632
"""
641
633
_check_plan (plan )
642
- _check_overwrite_x (overwrite_x )
643
634
x = _validate_input (x )
644
635
norm = _swap_direction (norm )
645
636
x = np .array (x , copy = True )
646
637
np .conjugate (x , out = x )
647
- s , axes = _cook_nd_args (x , s , axes , invreal = True )
638
+ s , axes = _init_nd_shape_and_axes (x , s , axes , invreal = True )
648
639
fsc = _compute_fwd_scale (norm , s , x .shape )
649
640
650
641
with _Workers (workers ):
642
+ # Note: overwrite_x is not utilized
651
643
return mkl_fft .irfftn (x , s , axes , fwd_scale = fsc )
652
644
653
645
@@ -666,19 +658,15 @@ def ihfftn(
666
658
667
659
For full documentation refer to `scipy.fft.ihfftn`.
668
660
669
- Limitation
670
- -----------
671
- The kwarg `overwrite_x` is only supported with its default value.
672
-
673
661
"""
674
662
_check_plan (plan )
675
- _check_overwrite_x (overwrite_x )
676
663
x = _validate_input (x )
677
664
norm = _swap_direction (norm )
678
- s , axes = _cook_nd_args (x , s , axes )
665
+ s , axes = _init_nd_shape_and_axes (x , s , axes )
679
666
fsc = _compute_fwd_scale (norm , s , x .shape )
680
667
681
668
with _Workers (workers ):
669
+ # Note: overwrite_x is not utilized
682
670
result = mkl_fft .rfftn (x , s , axes , fwd_scale = fsc )
683
671
684
672
np .conjugate (result , out = result )
0 commit comments