Skip to content

Commit e031fa0

Browse files
committed
add dask interface
1 parent 6bde33e commit e031fa0

File tree

13 files changed

+289
-10
lines changed

13 files changed

+289
-10
lines changed

.github/workflows/build-with-clang.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,5 @@ jobs:
7373
- name: Run mkl_fft tests
7474
run: |
7575
source ${{ env.ONEAPI_ROOT }}/setvars.sh
76-
pip install scipy mkl-service pytest
76+
pip install pytest mkl-service scipy dask
7777
pytest -s -v --pyargs mkl_fft

.github/workflows/conda-package-cf.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ jobs:
132132
- name: Install mkl_fft
133133
run: |
134134
CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}"
135-
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python_ver }} ${{ matrix.numpy }} $PACKAGE_NAME pytest scipy $CHANNELS
135+
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python_ver }} ${{ matrix.numpy }} $PACKAGE_NAME pytest scipy dask $CHANNELS
136136
# Test installed packages
137137
conda list -n ${{ env.TEST_ENV_NAME }}
138138
@@ -295,7 +295,7 @@ jobs:
295295
FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO (
296296
SET PACKAGE_VERSION=%%F
297297
)
298-
SET "TEST_DEPENDENCIES=pytest scipy"
298+
SET "TEST_DEPENDENCIES=pytest scipy dask"
299299
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} ${{ matrix.numpy }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
300300
301301
- name: Report content of test environment

.github/workflows/conda-package.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ jobs:
132132
run: |
133133
CHANNELS="-c $GITHUB_WORKSPACE/channel ${{ env.CHANNELS }}"
134134
conda create -n ${{ env.TEST_ENV_NAME }} python=${{ matrix.python }} "scipy>=1.10" $CHANNELS
135-
conda install -n ${{ env.TEST_ENV_NAME }} $PACKAGE_NAME pytest $CHANNELS
135+
conda install -n ${{ env.TEST_ENV_NAME }} dask $PACKAGE_NAME pytest $CHANNELS
136136
# Test installed packages
137137
conda list -n ${{ env.TEST_ENV_NAME }}
138138
@@ -296,7 +296,7 @@ jobs:
296296
FOR /F "tokens=* USEBACKQ" %%F IN (`python -c "%SCRIPT%"`) DO (
297297
SET PACKAGE_VERSION=%%F
298298
)
299-
SET "TEST_DEPENDENCIES=pytest scipy"
299+
SET "TEST_DEPENDENCIES=pytest scipy dask"
300300
conda install -n ${{ env.TEST_ENV_NAME }} ${{ env.PACKAGE_NAME }}=%PACKAGE_VERSION% %TEST_DEPENDENCIES% python=${{ matrix.python }} -c ${{ env.workdir }}/channel ${{ env.CHANNELS }}
301301
302302
- name: Report content of test environment

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
* Added Hermitian FFT functions to SciPy interface `mkl_fft.interfaces.scipy_fft`: `hfft`, `ihfft`, `hfftn`, `ihfftn`, `hfft2`, and `ihfft2` [gh-161](https://github.com/IntelPython/mkl_fft/pull/161)
1111
* Added support for `out` kwarg to all FFT functions in `mkl_fft` and `mkl_fft.interfaces.numpy_fft` [gh-157](https://github.com/IntelPython/mkl_fft/pull/157)
1212
* Added `fftfreq`, `fftshift`, `ifftshift`, and `rfftfreq` to both NumPy and SciPy interfaces [gh-179](https://github.com/IntelPython/mkl_fft/pull/179)
13+
* Added a new interface for Dask to the package accessible through `mkl_fft.interfaces.dask_fft` [gh-184](https://github.com/IntelPython/mkl_fft/pull/184)
1314

1415
### Changed
1516
* NumPy interface `mkl_fft.interfaces.numpy_fft` is aligned with numpy-2.x.x [gh-139](https://github.com/IntelPython/mkl_fft/pull/139), [gh-157](https://github.com/IntelPython/mkl_fft/pull/157)

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ More details can be found in [SciPy 2017 conference proceedings](https://github.
4747

4848
---
4949

50-
The `mkl_fft` package offers interfaces that act as drop-in replacements for equivalent functions in NumPy and SciPy. Learn more about these interfaces [here](https://github.com/IntelPython/mkl_fft/blob/master/mkl_fft/interfaces/README.md).
50+
The `mkl_fft` package offers interfaces that act as drop-in replacements for equivalent functions in NumPy, SciPy, and Dask. Learn more about these interfaces [here](https://github.com/IntelPython/mkl_fft/blob/master/mkl_fft/interfaces/README.md).
5151

5252
While using these interfaces is the easiest way to leverage `mk_fft`, one can also use `mkl_fft` directly with the following FFT functions:
5353

conda-recipe-cf/meta.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@ test:
3333
requires:
3434
- pytest
3535
- scipy
36+
- dask
3637
imports:
3738
- mkl_fft
3839
- mkl_fft.interfaces
3940
- mkl_fft.interfaces.numpy_fft
4041
- mkl_fft.interfaces.scipy_fft
42+
- mkl_fft.interfaces.dask_fft
4143

4244
about:
4345
home: http://github.com/IntelPython/mkl_fft

conda-recipe/meta.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@ test:
3333
requires:
3434
- pytest
3535
- scipy
36+
- dask
3637
imports:
3738
- mkl_fft
3839
- mkl_fft.interfaces
3940
- mkl_fft.interfaces.numpy_fft
4041
- mkl_fft.interfaces.scipy_fft
42+
- mkl_fft.interfaces.dask_fft
4143

4244
about:
4345
home: http://github.com/IntelPython/mkl_fft

mkl_fft/interfaces/README.md

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Interfaces
2-
The `mkl_fft` package provides interfaces that serve as drop-in replacements for equivalent functions in NumPy and SciPy.
2+
The `mkl_fft` package provides interfaces that serve as drop-in replacements for equivalent functions in NumPy, SciPy, and Dask.
33

44
---
55

@@ -102,3 +102,43 @@ with scipy.fft.set_backend(mkl_backend, only=True):
102102
print(f"Time with OneMKL FFT backend installed: {t2:.1f} seconds")
103103
# Time with MKL FFT backend installed: 9.1 seconds
104104
```
105+
106+
---
107+
108+
## Dask interface - `mkl_fft.interfaces.dask_fft`
109+
110+
This interface is a drop-in replacement for the [`dask.fft`](https://dask.pydata.org/en/latest/array-api.html#fast-fourier-transforms) module and includes **all** the functions available there:
111+
112+
* complex-to-complex FFTs: `fft`, `ifft`, `fft2`, `ifft2`, `fftn`, `ifftn`.
113+
114+
* real-to-complex and complex-to-real FFTs: `rfft`, `irfft`, `rfft2`, `irfft2`, `rfftn`, `irfftn`.
115+
116+
* Hermitian FFTs: `hfft`, `ihfft`.
117+
118+
* Helper routines: `fft_wrap`, `fftfreq`, `rfftfreq`, `fftshift`, `ifftshift`. These routines serve as a fallback to the Dask implementation and are included for completeness.
119+
120+
The following example shows how to use this interface for calculating a 2D FFT.
121+
122+
```python
123+
import numpy, dask
124+
import mkl_fft.interfaces.dask_fft as dask_fft
125+
126+
a = numpy.random.randn(128, 64) + 1j*numpy.random.randn(128, 64)
127+
x = dask.array.from_array(a, chunks=(64, 64))
128+
lazy_res = dask_fft.fft(x)
129+
mkl_res = lazy_res.compute()
130+
np_res = numpy.fft.fft(a)
131+
numpy.allclose(mkl_res, np_res)
132+
# True
133+
134+
# There are two chunks in this example based on the size of input array (128, 64) and chunk size (64, 64)
135+
# to confirm that MKL FFT is called twice, turn on verbosity
136+
import mkl
137+
mkl.verbose(1)
138+
# True
139+
140+
mkl_res = lazy_res.compute() # MKL_VERBOSE FFT is shown twice below which means MKL FFT is called twice
141+
# MKL_VERBOSE oneMKL 2024.0 Update 2 Patch 2 Product build 20240823 for Intel(R) 64 architecture Intel(R) Advanced Vector Extensions 512 (Intel(R) AVX-512) with support for INT8, BF16, FP16 (limited) instructions, and Intel(R) Advanced Matrix Extensions (Intel(R) AMX) with INT8 and BF16, Lnx 3.80GHz intel_thread
142+
# MKL_VERBOSE FFT(dcfo64*64,input_strides:{0,1},output_strides:{0,1},input_distance:64,output_distance:64,bScale:0.015625,tLim:32,unaligned_input,desc:0x7fd000010e40) 432.84us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:112
143+
# MKL_VERBOSE FFT(dcfo64*64,input_strides:{0,1},output_strides:{0,1},input_distance:64,output_distance:64,bScale:0.015625,tLim:32,unaligned_input,desc:0x7fd480011300) 499.00us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:112
144+
```

mkl_fft/interfaces/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@
2323
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2424
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525

26-
from . import numpy_fft, scipy_fft
26+
from . import dask_fft, numpy_fft, scipy_fft

mkl_fft/interfaces/dask_fft.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2025, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
from dask.array.fft import fft_wrap, fftfreq, fftshift, ifftshift, rfftfreq
28+
29+
from . import numpy_fft as _numpy_fft
30+
31+
__all__ = [
32+
"fft",
33+
"ifft",
34+
"fft2",
35+
"ifft2",
36+
"fftn",
37+
"ifftn",
38+
"rfft",
39+
"irfft",
40+
"rfft2",
41+
"irfft2",
42+
"rfftn",
43+
"irfftn",
44+
"hfft",
45+
"ihfft",
46+
"fftshift",
47+
"ifftshift",
48+
"fftfreq",
49+
"rfftfreq",
50+
"fft_wrap",
51+
]
52+
53+
54+
fft = fft_wrap(_numpy_fft.fft)
55+
ifft = fft_wrap(_numpy_fft.ifft)
56+
fft2 = fft_wrap(_numpy_fft.fft2)
57+
ifft2 = fft_wrap(_numpy_fft.ifft2)
58+
fftn = fft_wrap(_numpy_fft.fftn)
59+
ifftn = fft_wrap(_numpy_fft.ifftn)
60+
rfft = fft_wrap(_numpy_fft.rfft)
61+
irfft = fft_wrap(_numpy_fft.irfft)
62+
rfft2 = fft_wrap(_numpy_fft.rfft2)
63+
irfft2 = fft_wrap(_numpy_fft.irfft2)
64+
rfftn = fft_wrap(_numpy_fft.rfftn)
65+
irfftn = fft_wrap(_numpy_fft.irfftn)
66+
hfft = fft_wrap(_numpy_fft.hfft)
67+
ihfft = fft_wrap(_numpy_fft.ihfft)

mkl_fft/tests/test_interfaces.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,16 @@ def test_axes(func):
167167

168168

169169
@pytest.mark.parametrize(
170-
"interface", [mfi.scipy_fft, mfi.numpy_fft], ids=["scipy", "numpy"]
170+
"interface",
171+
[mfi.scipy_fft, mfi.numpy_fft, mfi.dask_fft],
172+
ids=["scipy", "numpy", "dask"],
171173
)
172174
@pytest.mark.parametrize(
173175
"func", ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]
174176
)
175177
def test_interface_helper_functions(interface, func):
176178
assert hasattr(interface, func)
179+
180+
181+
def test_dask_fftwrap():
182+
assert hasattr(mfi.dask_fft, "fft_wrap")
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# This file includes tests from dask.fft module:
2+
# https://github.com/dask/dask/blob/main/dask/array/tests/test_fft.py
3+
4+
import contextlib
5+
from itertools import combinations_with_replacement
6+
7+
import dask
8+
import dask.array as da
9+
import numpy as np
10+
import pytest
11+
from dask.array.numpy_compat import NUMPY_GE_200
12+
from dask.array.utils import assert_eq, same_keys
13+
14+
import mkl_fft.interfaces.dask_fft as dask_fft
15+
16+
requires_dask_2024_8_2 = pytest.mark.skipif(
17+
dask.__version__ < "2024.8.2",
18+
reason="norm kwarg requires Dask >= 2024.8.2",
19+
)
20+
21+
all_1d_funcnames = ["fft", "ifft", "rfft", "irfft", "hfft", "ihfft"]
22+
23+
all_nd_funcnames = [
24+
"fft2",
25+
"ifft2",
26+
"fftn",
27+
"ifftn",
28+
"rfft2",
29+
"irfft2",
30+
"rfftn",
31+
"irfftn",
32+
]
33+
34+
if not da._array_expr_enabled():
35+
36+
nparr = np.arange(100).reshape(10, 10)
37+
darr = da.from_array(nparr, chunks=(1, 10))
38+
darr2 = da.from_array(nparr, chunks=(10, 1))
39+
darr3 = da.from_array(nparr, chunks=(10, 10))
40+
41+
42+
@pytest.mark.parametrize("funcname", all_1d_funcnames)
43+
def test_cant_fft_chunked_axis(funcname):
44+
da_fft = getattr(dask_fft, funcname)
45+
46+
bad_darr = da.from_array(nparr, chunks=(5, 5))
47+
for i in range(bad_darr.ndim):
48+
with pytest.raises(ValueError):
49+
da_fft(bad_darr, axis=i)
50+
51+
52+
@pytest.mark.parametrize("funcname", all_1d_funcnames)
53+
def test_fft(funcname):
54+
da_fft = getattr(dask_fft, funcname)
55+
np_fft = getattr(np.fft, funcname)
56+
57+
# pylint: disable=possibly-used-before-assignment
58+
assert_eq(da_fft(darr), np_fft(nparr))
59+
60+
61+
@pytest.mark.parametrize("funcname", all_nd_funcnames)
62+
def test_fft2n_shapes(funcname):
63+
da_fft = getattr(dask_fft, funcname)
64+
np_fft = getattr(np.fft, funcname)
65+
66+
# pylint: disable=possibly-used-before-assignment
67+
assert_eq(da_fft(darr3), np_fft(nparr))
68+
assert_eq(
69+
da_fft(darr3, (8, 9), axes=(1, 0)), np_fft(nparr, (8, 9), axes=(1, 0))
70+
)
71+
assert_eq(
72+
da_fft(darr3, (12, 11), axes=(1, 0)),
73+
np_fft(nparr, (12, 11), axes=(1, 0)),
74+
)
75+
76+
if NUMPY_GE_200 and funcname.endswith("fftn"):
77+
ctx = pytest.warns(
78+
DeprecationWarning,
79+
match="`axes` should not be `None` if `s` is not `None`",
80+
)
81+
else:
82+
ctx = contextlib.nullcontext()
83+
with ctx:
84+
expect = np_fft(nparr, (8, 9))
85+
with ctx:
86+
actual = da_fft(darr3, (8, 9))
87+
assert_eq(expect, actual)
88+
89+
90+
@requires_dask_2024_8_2
91+
@pytest.mark.parametrize("funcname", all_1d_funcnames)
92+
def test_fft_n_kwarg(funcname):
93+
da_fft = getattr(dask_fft, funcname)
94+
np_fft = getattr(np.fft, funcname)
95+
96+
assert_eq(da_fft(darr, 5), np_fft(nparr, 5))
97+
assert_eq(da_fft(darr, 13), np_fft(nparr, 13))
98+
assert_eq(
99+
da_fft(darr, 13, norm="backward"), np_fft(nparr, 13, norm="backward")
100+
)
101+
assert_eq(da_fft(darr, 13, norm="ortho"), np_fft(nparr, 13, norm="ortho"))
102+
assert_eq(
103+
da_fft(darr, 13, norm="forward"), np_fft(nparr, 13, norm="forward")
104+
)
105+
# pylint: disable=possibly-used-before-assignment
106+
assert_eq(da_fft(darr2, axis=0), np_fft(nparr, axis=0))
107+
assert_eq(da_fft(darr2, 5, axis=0), np_fft(nparr, 5, axis=0))
108+
assert_eq(
109+
da_fft(darr2, 13, axis=0, norm="backward"),
110+
np_fft(nparr, 13, axis=0, norm="backward"),
111+
)
112+
assert_eq(
113+
da_fft(darr2, 12, axis=0, norm="ortho"),
114+
np_fft(nparr, 12, axis=0, norm="ortho"),
115+
)
116+
assert_eq(
117+
da_fft(darr2, 12, axis=0, norm="forward"),
118+
np_fft(nparr, 12, axis=0, norm="forward"),
119+
)
120+
121+
122+
@pytest.mark.parametrize("funcname", all_1d_funcnames)
123+
def test_fft_consistent_names(funcname):
124+
da_fft = getattr(dask_fft, funcname)
125+
126+
assert same_keys(da_fft(darr, 5), da_fft(darr, 5))
127+
assert same_keys(da_fft(darr2, 5, axis=0), da_fft(darr2, 5, axis=0))
128+
assert not same_keys(da_fft(darr, 5), da_fft(darr, 13))
129+
130+
131+
@pytest.mark.parametrize("funcname", all_nd_funcnames)
132+
@pytest.mark.parametrize("dtype", ["float32", "float64"])
133+
def test_nd_ffts_axes(funcname, dtype):
134+
np_fft = getattr(np.fft, funcname)
135+
da_fft = getattr(dask_fft, funcname)
136+
137+
shape = (7, 8, 9)
138+
chunk_size = (3, 3, 3)
139+
a = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
140+
d = da.from_array(a, chunks=chunk_size)
141+
142+
for num_axes in range(1, d.ndim):
143+
for axes in combinations_with_replacement(range(d.ndim), num_axes):
144+
cs = list(chunk_size)
145+
for i in axes:
146+
cs[i] = shape[i]
147+
d2 = d.rechunk(cs)
148+
if len(set(axes)) < len(axes):
149+
with pytest.raises(ValueError):
150+
da_fft(d2, axes=axes)
151+
else:
152+
r = da_fft(d2, axes=axes)
153+
er = np_fft(a, axes=axes)
154+
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
155+
check_dtype = True
156+
assert r.dtype == er.dtype
157+
else:
158+
check_dtype = False
159+
assert r.shape == er.shape
160+
161+
assert_eq(r, er, check_dtype=check_dtype, rtol=1e-6, atol=1e-4)

0 commit comments

Comments
 (0)