Skip to content

Commit ce9e0d1

Browse files
added support for forward_scale parameter, and tests
1 parent d162835 commit ce9e0d1

File tree

2 files changed

+88
-36
lines changed

2 files changed

+88
-36
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,12 @@ cdef int _datacopied(cnp.ndarray arr, object orig):
151151
return 1 if (arr_obj.base is None) else 0
152152

153153

154-
def fft(x, n=None, axis=-1, overwrite_x=False):
155-
return _fft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=+1)
154+
def fft(x, n=None, axis=-1, overwrite_x=False, forward_scale=1.0):
155+
return _fft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=+1, fsc=forward_scale)
156156

157157

158-
def ifft(x, n=None, axis=-1, overwrite_x=False):
159-
return _fft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=-1)
158+
def ifft(x, n=None, axis=-1, overwrite_x=False, forward_scale=1.0):
159+
return _fft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=-1, fsc=forward_scale)
160160

161161

162162
cdef cnp.ndarray pad_array(cnp.ndarray x_arr, cnp.npy_intp n, int axis, int realQ):
@@ -403,14 +403,14 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs
403403
return f_arr
404404

405405

406-
def rfft(x, n=None, axis=-1, overwrite_x=False):
406+
def rfft(x, n=None, axis=-1, overwrite_x=False, forward_scale=1.0):
407407
"""Packed real-valued harmonics of FFT of a real sequence x"""
408-
return _rr_fft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x)
408+
return _rr_fft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x, fsc=forward_scale)
409409

410410

411-
def irfft(x, n=None, axis=-1, overwrite_x=False):
411+
def irfft(x, n=None, axis=-1, overwrite_x=False, forward_scale=1.0):
412412
"""Inverse FFT of a real sequence, takes packed real-valued harmonics of FFT"""
413-
return _rr_ifft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x)
413+
return _rr_ifft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x, fsc=forward_scale)
414414

415415

416416
cdef object _rc_to_rr(cnp.ndarray rc_arr, int n, int axis, int xnd, int x_type):
@@ -788,12 +788,12 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
788788
return f_arr
789789

790790

791-
def rfft_numpy(x, n=None, axis=-1):
792-
return _rc_fft1d_impl(x, n=n, axis=axis)
791+
def rfft_numpy(x, n=None, axis=-1, forward_scale=1.0):
792+
return _rc_fft1d_impl(x, n=n, axis=axis, fsc=forward_scale)
793793

794794

795-
def irfft_numpy(x, n=None, axis=-1):
796-
return _rc_ifft1d_impl(x, n=n, axis=axis)
795+
def irfft_numpy(x, n=None, axis=-1, forward_scale=1.0):
796+
return _rc_ifft1d_impl(x, n=n, axis=axis, fsc=forward_scale)
797797

798798

799799
# ============================== ND ====================================== #
@@ -902,12 +902,12 @@ def _cook_nd_args(a, s=None, axes=None, invreal=0):
902902
return s, axes
903903

904904

905-
def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_arg=False):
905+
def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_arg=False, scale_function=lambda n: 1.0):
906906
a = np.asarray(a)
907907
s, axes = _init_nd_shape_and_axes(a, s, axes)
908908
ovwr = overwrite_arg
909909
for ii in reversed(range(len(axes))):
910-
a = function(a, n = s[ii], axis = axes[ii], overwrite_x=ovwr)
910+
a = function(a, n = s[ii], axis = axes[ii], overwrite_x=ovwr, forward_scale=scale_function(s[ii]))
911911
ovwr = True
912912
return a
913913

@@ -1026,33 +1026,34 @@ def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, doubl
10261026
if _direct:
10271027
return _direct_fftnd(x, overwrite_arg=overwrite_x, direction=direction, fsc=fsc)
10281028
else:
1029+
sc = (<object> fsc)**(1/x.ndim)
10291030
return _iter_fftnd(x, s=shape, axes=axes,
1030-
overwrite_arg=overwrite_x, fsc=fsc,
1031+
overwrite_arg=overwrite_x, scale_function=lambda n: sc,
10311032
function=fft if direction == 1 else ifft)
10321033

10331034

1034-
def fft2(x, shape=None, axes=(-2,-1), overwrite_x=False):
1035-
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=+1)
1035+
def fft2(x, shape=None, axes=(-2,-1), overwrite_x=False, forward_scale=1.0):
1036+
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=forward_scale)
10361037

10371038

1038-
def ifft2(x, shape=None, axes=(-2,-1), overwrite_x=False):
1039-
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=-1)
1039+
def ifft2(x, shape=None, axes=(-2,-1), overwrite_x=False, forward_scale=1.0):
1040+
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=forward_scale)
10401041

10411042

1042-
def fftn(x, shape=None, axes=None, overwrite_x=False):
1043-
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=+1)
1043+
def fftn(x, shape=None, axes=None, overwrite_x=False, forward_scale=1.0):
1044+
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=forward_scale)
10441045

10451046

1046-
def ifftn(x, shape=None, axes=None, overwrite_x=False):
1047-
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=-1)
1047+
def ifftn(x, shape=None, axes=None, overwrite_x=False, forward_scale=1.0):
1048+
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=forward_scale)
10481049

10491050

1050-
def rfft2_numpy(x, s=None, axes=(-2,-1)):
1051-
return rfftn_numpy(x, s=s, axes=axes)
1051+
def rfft2_numpy(x, s=None, axes=(-2,-1), forward_scale=1.0):
1052+
return rfftn_numpy(x, s=s, axes=axes, fsc=forward_scale)
10521053

10531054

1054-
def irfft2_numpy(x, s=None, axes=(-2,-1)):
1055-
return irfftn_numpy(x, s=s, axes=axes)
1055+
def irfft2_numpy(x, s=None, axes=(-2,-1), forward_scale=1.0):
1056+
return irfftn_numpy(x, s=s, axes=axes, fsc=forward_scale)
10561057

10571058

10581059
def _remove_axis(s, axes, axis_to_remove):
@@ -1107,7 +1108,7 @@ def _fix_dimensions(cnp.ndarray arr, object s, object axes):
11071108
return np.pad(arr, tuple(pad_widths), 'constant')
11081109

11091110

1110-
def rfftn_numpy(x, s=None, axes=None):
1111+
def rfftn_numpy(x, s=None, axes=None, forward_scale=1.0):
11111112
a = np.asarray(x)
11121113
no_trim = (s is None) and (axes is None)
11131114
s, axes = _cook_nd_args(a, s, axes)
@@ -1116,7 +1117,7 @@ def rfftn_numpy(x, s=None, axes=None):
11161117
# unnecessary computations
11171118
if not no_trim:
11181119
a = _trim_array(a, s, axes)
1119-
a = rfft_numpy(a, n = s[-1], axis=la)
1120+
a = rfft_numpy(a, n = s[-1], axis=la, forward_scale=forward_scale)
11201121
if len(s) > 1:
11211122
if not no_trim:
11221123
ss = list(s)
@@ -1140,7 +1141,7 @@ def rfftn_numpy(x, s=None, axes=None):
11401141
return a
11411142

11421143

1143-
def irfftn_numpy(x, s=None, axes=None):
1144+
def irfftn_numpy(x, s=None, axes=None, forward_scale=1.0):
11441145
a = np.asarray(x)
11451146
no_trim = (s is None) and (axes is None)
11461147
s, axes = _cook_nd_args(a, s, axes, invreal=True)
@@ -1169,5 +1170,5 @@ def irfftn_numpy(x, s=None, axes=None):
11691170
for ii in range(len(axes)-1):
11701171
a = ifft(a, s[ii], axes[ii], overwrite_x=ovr_x)
11711172
ovr_x = True
1172-
a = irfft_numpy(a, n = s[-1], axis=la)
1173+
a = irfft_numpy(a, n = s[-1], axis=la, forward_scale=forward_scale)
11731174
return a

mkl_fft/tests/test_fftnd.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@
3838

3939
reps_64 = (2**11)*np.finfo(np.float64).eps
4040
reps_32 = (2**11)*np.finfo(np.float32).eps
41-
atol_64 = (2**8)*np.finfo(np.float64).eps
42-
atol_32 = (2**8)*np.finfo(np.float32).eps
41+
atol_64 = (2**9)*np.finfo(np.float64).eps
42+
atol_32 = (2**9)*np.finfo(np.float32).eps
4343

4444
def _get_rtol_atol(x):
4545
dt = x.dtype
46-
if dt == np.double or dt == np.complex128:
46+
if dt == np.float64 or dt == np.complex128:
4747
return reps_64, atol_64
48-
elif dt == np.single or dt == np.complex64:
48+
elif dt == np.float32 or dt == np.complex64:
4949
return reps_32, atol_32
5050
else:
51-
assert (dt == np.double or dt == np.complex128 or dt == np.single or dt == np.complex64), "Unexpected dtype {}".format(dt)
51+
assert (dt == np.float64 or dt == np.complex128 or dt == np.float32 or dt == np.complex64), "Unexpected dtype {}".format(dt)
5252
return reps_64, atol_64
5353

5454

@@ -128,3 +128,54 @@ def test_rfftn_numpy(self):
128128
rfft_tr = mkl_fft.rfftn_numpy(np.transpose(x, a))
129129
tr_rfft = np.transpose(mkl_fft.rfftn_numpy(x, axes=a), a)
130130
assert_allclose(rfft_tr, tr_rfft, rtol=r_tol, atol=a_tol)
131+
132+
class Test_Scales(TestCase):
133+
def setUp(self):
134+
pass
135+
136+
def test_scale_1d_vector(self):
137+
X = np.ones(128, dtype='d')
138+
f1 = mkl_fft.fft(X, forward_scale=0.25)
139+
f2 = mkl_fft.fft(X)
140+
r_tol, a_tol = _get_rtol_atol(X)
141+
assert_allclose(4*f1, f2, rtol=r_tol, atol=a_tol)
142+
143+
X1 = mkl_fft.ifft(f1, forward_scale=0.25)
144+
assert_allclose(X, X1, rtol=r_tol, atol=a_tol)
145+
146+
f3 = mkl_fft.rfft(X, forward_scale=0.5)
147+
X2 = mkl_fft.irfft(f3, forward_scale=0.5)
148+
assert_allclose(X, X2, rtol=r_tol, atol=a_tol)
149+
150+
def test_scale_1d_array(self):
151+
X = np.ones((8, 4, 4,), dtype='d')
152+
f1 = mkl_fft.fft(X, axis=1, forward_scale=0.25)
153+
f2 = mkl_fft.fft(X, axis=1)
154+
r_tol, a_tol = _get_rtol_atol(X)
155+
assert_allclose(4*f1, f2, rtol=r_tol, atol=a_tol)
156+
157+
X1 = mkl_fft.ifft(f1, axis=1, forward_scale=0.25)
158+
assert_allclose(X, X1, rtol=r_tol, atol=a_tol)
159+
160+
f3 = mkl_fft.rfft(X, axis=0, forward_scale=0.5)
161+
X2 = mkl_fft.irfft(f3, axis=0, forward_scale=0.5)
162+
assert_allclose(X, X2, rtol=r_tol, atol=a_tol)
163+
164+
def test_scale_nd(self):
165+
X = np.empty((2, 4, 8, 16), dtype='d')
166+
X.flat[:] = np.cbrt(np.arange(0, X.size, dtype=X.dtype))
167+
f = mkl_fft.fftn(X)
168+
f_scale = mkl_fft.fftn(X, forward_scale=0.2)
169+
170+
r_tol, a_tol = _get_rtol_atol(X)
171+
assert_allclose(f, 5*f_scale, rtol=r_tol, atol=a_tol)
172+
173+
def test_scale_nd_axes(self):
174+
X = np.empty((4, 2, 16, 8), dtype='d')
175+
X.flat[:] = np.cbrt(np.arange(X.size, dtype=X.dtype))
176+
f = mkl_fft.fftn(X, axes=(0, 1, 2, 3))
177+
f_scale = mkl_fft.fftn(X, axes=(0, 1, 2, 3), forward_scale=0.2)
178+
179+
r_tol, a_tol = _get_rtol_atol(X)
180+
assert_allclose(f, 5*f_scale, rtol=r_tol, atol=a_tol)
181+

0 commit comments

Comments
 (0)