Skip to content

Commit d162835

Browse files
All routins in src/mklfft.c aquire fsc parameter, defaulting to 1.0
Forward/backward FFT scate are determined as fsc, 1.0/(fsc*n)
1 parent 20b641b commit d162835

File tree

3 files changed

+134
-133
lines changed

3 files changed

+134
-133
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 77 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -94,42 +94,42 @@ cdef extern from "src/mklfft.h":
9494
void * hand
9595
int initialized
9696
int _free_dfti_cache(DftiCache *)
97-
int cdouble_mkl_fft1d_in(cnp.ndarray, int, int, DftiCache*)
98-
int cfloat_mkl_fft1d_in(cnp.ndarray, int, int, DftiCache*)
99-
int float_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
100-
int cfloat_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
101-
int double_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
102-
int cdouble_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
103-
104-
int cdouble_mkl_ifft1d_in(cnp.ndarray, int, int, DftiCache*)
105-
int cfloat_mkl_ifft1d_in(cnp.ndarray, int, int, DftiCache*)
106-
int float_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
107-
int cfloat_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarra, DftiCache*)
108-
int double_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
109-
int cdouble_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
110-
111-
int double_mkl_rfft_in(cnp.ndarray, int, int, DftiCache*)
112-
int double_mkl_irfft_in(cnp.ndarray, int, int, DftiCache*)
113-
int float_mkl_rfft_in(cnp.ndarray, int, int, DftiCache*)
114-
int float_mkl_irfft_in(cnp.ndarray, int, int, DftiCache*)
115-
116-
int cdouble_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
117-
int cfloat_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
118-
119-
int cdouble_cdouble_mkl_fftnd_in(cnp.ndarray)
120-
int cdouble_cdouble_mkl_ifftnd_in(cnp.ndarray)
121-
int cfloat_cfloat_mkl_fftnd_in(cnp.ndarray)
122-
int cfloat_cfloat_mkl_ifftnd_in(cnp.ndarray)
123-
124-
int cdouble_cdouble_mkl_fftnd_out(cnp.ndarray, cnp.ndarray)
125-
int cdouble_cdouble_mkl_ifftnd_out(cnp.ndarray, cnp.ndarray)
126-
int cfloat_cfloat_mkl_fftnd_out(cnp.ndarray, cnp.ndarray)
127-
int cfloat_cfloat_mkl_ifftnd_out(cnp.ndarray, cnp.ndarray)
128-
129-
int float_cfloat_mkl_fftnd_out(cnp.ndarray, cnp.ndarray)
130-
int double_cdouble_mkl_fftnd_out(cnp.ndarray, cnp.ndarray)
131-
int float_cfloat_mkl_ifftnd_out(cnp.ndarray, cnp.ndarray)
132-
int double_cdouble_mkl_ifftnd_out(cnp.ndarray, cnp.ndarray)
97+
int cdouble_mkl_fft1d_in(cnp.ndarray, int, int, double, DftiCache*)
98+
int cfloat_mkl_fft1d_in(cnp.ndarray, int, int, double, DftiCache*)
99+
int float_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, double, DftiCache*)
100+
int cfloat_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, double, DftiCache*)
101+
int double_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, double, DftiCache*)
102+
int cdouble_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, double, DftiCache*)
103+
104+
int cdouble_mkl_ifft1d_in(cnp.ndarray, int, int, double, DftiCache*)
105+
int cfloat_mkl_ifft1d_in(cnp.ndarray, int, int, double, DftiCache*)
106+
int float_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, double, DftiCache*)
107+
int cfloat_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, double, DftiCache*)
108+
int double_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, double, DftiCache*)
109+
int cdouble_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, double, DftiCache*)
110+
111+
int double_mkl_rfft_in(cnp.ndarray, int, int, double, DftiCache*)
112+
int double_mkl_irfft_in(cnp.ndarray, int, int, double, DftiCache*)
113+
int float_mkl_rfft_in(cnp.ndarray, int, int, double, DftiCache*)
114+
int float_mkl_irfft_in(cnp.ndarray, int, int, double, DftiCache*)
115+
116+
int cdouble_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, double, DftiCache*)
117+
int cfloat_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, double, DftiCache*)
118+
119+
int cdouble_cdouble_mkl_fftnd_in(cnp.ndarray, double)
120+
int cdouble_cdouble_mkl_ifftnd_in(cnp.ndarray, double)
121+
int cfloat_cfloat_mkl_fftnd_in(cnp.ndarray, double)
122+
int cfloat_cfloat_mkl_ifftnd_in(cnp.ndarray, double)
123+
124+
int cdouble_cdouble_mkl_fftnd_out(cnp.ndarray, cnp.ndarray, double)
125+
int cdouble_cdouble_mkl_ifftnd_out(cnp.ndarray, cnp.ndarray, double)
126+
int cfloat_cfloat_mkl_fftnd_out(cnp.ndarray, cnp.ndarray, double)
127+
int cfloat_cfloat_mkl_ifftnd_out(cnp.ndarray, cnp.ndarray, double)
128+
129+
int float_cfloat_mkl_fftnd_out(cnp.ndarray, cnp.ndarray, double)
130+
int double_cdouble_mkl_fftnd_out(cnp.ndarray, cnp.ndarray, double)
131+
int float_cfloat_mkl_ifftnd_out(cnp.ndarray, cnp.ndarray, double)
132+
int double_cdouble_mkl_ifftnd_out(cnp.ndarray, cnp.ndarray, double)
133133
char * mkl_dfti_error(int)
134134

135135
# Initialize numpy
@@ -289,7 +289,7 @@ cdef cnp.ndarray __allocate_result(cnp.ndarray x_arr, long n_, long axis_, int f
289289
# Float/double inputs are not cast to complex, but are effectively
290290
# treated as complexes with zero imaginary parts.
291291
# All other types are cast to complex double.
292-
def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
292+
def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fsc=1.0):
293293
"""
294294
Uses MKL to perform 1D FFT on the input array x along the given axis.
295295
"""
@@ -333,14 +333,14 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
333333
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
334334
if x_type is cnp.NPY_CDOUBLE:
335335
if dir_ < 0:
336-
status = cdouble_mkl_ifft1d_in(x_arr, n_, <int> axis_, _cache)
336+
status = cdouble_mkl_ifft1d_in(x_arr, n_, <int> axis_, fsc, _cache)
337337
else:
338-
status = cdouble_mkl_fft1d_in(x_arr, n_, <int> axis_, _cache)
338+
status = cdouble_mkl_fft1d_in(x_arr, n_, <int> axis_, fsc, _cache)
339339
elif x_type is cnp.NPY_CFLOAT:
340340
if dir_ < 0:
341-
status = cfloat_mkl_ifft1d_in(x_arr, n_, <int> axis_, _cache)
341+
status = cfloat_mkl_ifft1d_in(x_arr, n_, <int> axis_, fsc, _cache)
342342
else:
343-
status = cfloat_mkl_fft1d_in(x_arr, n_, <int> axis_, _cache)
343+
status = cfloat_mkl_fft1d_in(x_arr, n_, <int> axis_, fsc, _cache)
344344
else:
345345
status = 1
346346

@@ -368,32 +368,32 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
368368
if x_type is cnp.NPY_DOUBLE:
369369
if dir_ < 0:
370370
status = double_cdouble_mkl_ifft1d_out(
371-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
371+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, fsc, _cache)
372372
else:
373373
status = double_cdouble_mkl_fft1d_out(
374-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
374+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, fsc, _cache)
375375
elif x_type is cnp.NPY_CDOUBLE:
376376
if dir_ < 0:
377377
status = cdouble_cdouble_mkl_ifft1d_out(
378-
x_arr, n_, <int> axis_, f_arr, _cache)
378+
x_arr, n_, <int> axis_, f_arr, fsc, _cache)
379379
else:
380380
status = cdouble_cdouble_mkl_fft1d_out(
381-
x_arr, n_, <int> axis_, f_arr, _cache)
381+
x_arr, n_, <int> axis_, f_arr, fsc, _cache)
382382
else:
383383
if x_type is cnp.NPY_FLOAT:
384384
if dir_ < 0:
385385
status = float_cfloat_mkl_ifft1d_out(
386-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
386+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, fsc, _cache)
387387
else:
388388
status = float_cfloat_mkl_fft1d_out(
389-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
389+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, fsc, _cache)
390390
elif x_type is cnp.NPY_CFLOAT:
391391
if dir_ < 0:
392392
status = cfloat_cfloat_mkl_ifft1d_out(
393-
x_arr, n_, <int> axis_, f_arr, _cache)
393+
x_arr, n_, <int> axis_, f_arr, fsc, _cache)
394394
else:
395395
status = cfloat_cfloat_mkl_fft1d_out(
396-
x_arr, n_, <int> axis_, f_arr, _cache)
396+
x_arr, n_, <int> axis_, f_arr, fsc, _cache)
397397

398398
if (status):
399399
c_error_msg = mkl_dfti_error(status)
@@ -515,7 +515,7 @@ def _repack_rc_to_rr(x, n, axis):
515515
return _rc_to_rr(x, n_, axis_, cnp.PyArray_NDIM(x_arr), x_type)
516516

517517

518-
def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False):
518+
def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
519519
"""
520520
Uses MKL to perform real packed 1D FFT on the input array x along the given axis.
521521
@@ -558,9 +558,9 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False):
558558
_cache_capsule = _tls_dfti_cache_capsule()
559559
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
560560
if x_type is cnp.NPY_DOUBLE:
561-
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)
561+
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, fsc, _cache)
562562
else:
563-
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)
563+
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, fsc, _cache)
564564

565565
if (status):
566566
c_error_msg = mkl_dfti_error(status)
@@ -571,7 +571,7 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False):
571571
return _rc_to_rr(f_arr, n_, axis_, xnd, x_type)
572572

573573

574-
def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False):
574+
def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
575575
"""
576576
Uses MKL to perform real packed 1D FFT on the input array x along the given axis.
577577
@@ -623,11 +623,11 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False):
623623
if rc_type is cnp.NPY_CFLOAT:
624624
_cache_capsule = _tls_dfti_cache_capsule()
625625
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
626-
status = cfloat_float_mkl_irfft_out(rc_arr, n_, <int> axis_, f_arr, _cache)
626+
status = cfloat_float_mkl_irfft_out(rc_arr, n_, <int> axis_, f_arr, fsc, _cache)
627627
elif rc_type is cnp.NPY_CDOUBLE:
628628
_cache_capsule = _tls_dfti_cache_capsule()
629629
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
630-
status = cdouble_double_mkl_irfft_out(rc_arr, n_, <int> axis_, f_arr, _cache)
630+
status = cdouble_double_mkl_irfft_out(rc_arr, n_, <int> axis_, f_arr, fsc, _cache)
631631
else:
632632
raise ValueError("Internal mkl_fft error occurred: Unrecognized rc_type")
633633

@@ -640,7 +640,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False):
640640

641641

642642
# this routine is functionally equivalent to numpy.fft.rfft
643-
def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
643+
def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
644644
"""
645645
Uses MKL to perform 1D FFT on the real input array x along the given axis,
646646
producing complex output, but giving only half of the harmonics.
@@ -689,11 +689,11 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
689689
if x_type is cnp.NPY_FLOAT:
690690
_cache_capsule = _tls_dfti_cache_capsule()
691691
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
692-
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)
692+
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, fsc, _cache)
693693
else:
694694
_cache_capsule = _tls_dfti_cache_capsule()
695695
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
696-
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)
696+
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, fsc, _cache)
697697

698698
if (status):
699699
c_error_msg = mkl_dfti_error(status)
@@ -718,7 +718,7 @@ cdef int _is_integral(object num):
718718

719719

720720
# this routine is functionally equivalent to numpy.fft.irfft
721-
def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
721+
def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
722722
"""
723723
Uses MKL to perform 1D FFT on the real input array x along the given axis,
724724
producing complex output, but giving only half of the harmonics.
@@ -774,11 +774,11 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
774774
if x_type is cnp.NPY_CFLOAT:
775775
_cache_capsule = _tls_dfti_cache_capsule()
776776
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
777-
status = cfloat_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
777+
status = cfloat_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, fsc, _cache)
778778
else:
779779
_cache_capsule = _tls_dfti_cache_capsule()
780780
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
781-
status = cdouble_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
781+
status = cdouble_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, fsc, _cache)
782782

783783
if (status):
784784
c_error_msg = mkl_dfti_error(status)
@@ -912,7 +912,7 @@ def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_arg=False):
912912
return a
913913

914914

915-
def _direct_fftnd(x, overwrite_arg=False, direction=+1):
915+
def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
916916
"""Perform n-dimensional FFT over all axes"""
917917
cdef int err
918918
cdef long n_max = 0
@@ -948,14 +948,14 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1):
948948
if in_place:
949949
if x_type == cnp.NPY_CDOUBLE:
950950
if dir_ == 1:
951-
err = cdouble_cdouble_mkl_fftnd_in(x_arr)
951+
err = cdouble_cdouble_mkl_fftnd_in(x_arr, fsc)
952952
else:
953-
err = cdouble_cdouble_mkl_ifftnd_in(x_arr)
953+
err = cdouble_cdouble_mkl_ifftnd_in(x_arr, fsc)
954954
elif x_type == cnp.NPY_CFLOAT:
955955
if dir_ == 1:
956-
err = cfloat_cfloat_mkl_fftnd_in(x_arr)
956+
err = cfloat_cfloat_mkl_fftnd_in(x_arr, fsc)
957957
else:
958-
err = cfloat_cfloat_mkl_ifftnd_in(x_arr)
958+
err = cfloat_cfloat_mkl_ifftnd_in(x_arr, fsc)
959959
else:
960960
raise ValueError("An input argument x is not complex type array")
961961

@@ -965,24 +965,24 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1):
965965
f_arr = __allocate_result(x_arr, -1, 0, f_type);
966966
if x_type == cnp.NPY_CDOUBLE:
967967
if dir_ == 1:
968-
err = cdouble_cdouble_mkl_fftnd_out(x_arr, f_arr)
968+
err = cdouble_cdouble_mkl_fftnd_out(x_arr, f_arr, fsc)
969969
else:
970-
err = cdouble_cdouble_mkl_ifftnd_out(x_arr, f_arr)
970+
err = cdouble_cdouble_mkl_ifftnd_out(x_arr, f_arr, fsc)
971971
elif x_type == cnp.NPY_CFLOAT:
972972
if dir_ == 1:
973-
err = cfloat_cfloat_mkl_fftnd_out(x_arr, f_arr)
973+
err = cfloat_cfloat_mkl_fftnd_out(x_arr, f_arr, fsc)
974974
else:
975-
err = cfloat_cfloat_mkl_ifftnd_out(x_arr, f_arr)
975+
err = cfloat_cfloat_mkl_ifftnd_out(x_arr, f_arr, fsc)
976976
elif x_type == cnp.NPY_DOUBLE:
977977
if dir_ == 1:
978-
err = double_cdouble_mkl_fftnd_out(x_arr, f_arr)
978+
err = double_cdouble_mkl_fftnd_out(x_arr, f_arr, fsc)
979979
else:
980-
err = double_cdouble_mkl_ifftnd_out(x_arr, f_arr)
980+
err = double_cdouble_mkl_ifftnd_out(x_arr, f_arr, fsc)
981981
elif x_type == cnp.NPY_FLOAT:
982982
if dir_ == 1:
983-
err = float_cfloat_mkl_fftnd_out(x_arr, f_arr)
983+
err = float_cfloat_mkl_fftnd_out(x_arr, f_arr, fsc)
984984
else:
985-
err = float_cfloat_mkl_ifftnd_out(x_arr, f_arr)
985+
err = float_cfloat_mkl_ifftnd_out(x_arr, f_arr, fsc)
986986
else:
987987
raise ValueError("An input argument x is not complex type array")
988988

@@ -1006,7 +1006,7 @@ def _check_shapes_for_direct(xs, shape, axes):
10061006
return True
10071007

10081008

1009-
def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1):
1009+
def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, double fsc=1.0):
10101010
if direction not in [-1, +1]:
10111011
raise ValueError("Direction of FFT should +1 or -1")
10121012

@@ -1024,10 +1024,10 @@ def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1):
10241024
_direct = False
10251025

10261026
if _direct:
1027-
return _direct_fftnd(x, overwrite_arg=overwrite_x, direction=direction)
1027+
return _direct_fftnd(x, overwrite_arg=overwrite_x, direction=direction, fsc=fsc)
10281028
else:
10291029
return _iter_fftnd(x, s=shape, axes=axes,
1030-
overwrite_arg=overwrite_x,
1030+
overwrite_arg=overwrite_x, fsc=fsc,
10311031
function=fft if direction == 1 else ifft)
10321032

10331033

0 commit comments

Comments
 (0)