From a17eb071f4d77ccd14be23e621d153e1bc814530 Mon Sep 17 00:00:00 2001 From: Nathan Bombaci Date: Mon, 28 Apr 2025 17:51:52 -0700 Subject: [PATCH] Allow scipy_fft module to be used as a scipy fft backend --- README.md | 17 +++++++++++++++++ mkl_fft/_scipy_fft.py | 3 +++ mkl_fft/tests/test_interfaces.py | 6 ++++++ 3 files changed, 26 insertions(+) diff --git a/README.md b/README.md index c7356ab..0a8e066 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,23 @@ and similar inverse c2r FFT (`irfft*`) functions. The package also provides `mkl_fft.interfaces.numpy_fft` and `mkl_fft.interfaces.scipy_fft` interfaces which provide drop-in replacements for equivalent functions in NumPy and SciPy, respectively. +`mkl_fft.interfaces.scipy_fft` can also be used as a backend for `scipy.fft.set_backend()` + +```python +>>> import numpy as np, mkl_fft, mkl_fft.interfaces.scipy_fft as mkl_be, scipy, scipy.fft, mkl + +>>> mkl.verbose(1) +# True + +>>> x = np.random.randn(8*7).reshape((7, 8)) +>>> with scipy.fft.set_backend(mkl_be, only=True): +>>> ff = scipy.fft.fft2(x, workers=4) +>>> ff2 = scipy.fft.fft2(x) +# MKL_VERBOSE oneMKL 2025 Update 1 Product build 20250306 for Intel(R) 64 architecture Intel(R) Advanced Vector Extensions 2 (Intel(R) AVX2) enabled processors, Lnx 2.70GHz intel_thread +# MKL_VERBOSE FFT(drfo7:8:8x8:1:1,input_strides:{0,8,1},output_strides:{0,8,1},bScale:0.0178571,tLim:1,unaligned_input,unaligned_output,desc:0x561750094440) 15.56us CNR:OFF Dyn:1 FastMM:1 TID:0 NThr:4 + +>>> np.allclose(ff, ff2) +``` --- To build `mkl_fft` from sources on Linux with IntelĀ® OneMKL: diff --git a/mkl_fft/_scipy_fft.py b/mkl_fft/_scipy_fft.py index 7a5f658..da98df7 100644 --- a/mkl_fft/_scipy_fft.py +++ b/mkl_fft/_scipy_fft.py @@ -139,6 +139,9 @@ def set_workers(n_workers): "get_workers", "set_workers", "DftiBackend", + # Following needed for module to be used as scipy fft backend + "__ua_domain__", + "__ua_function__", ] __ua_domain__ = "numpy.scipy.fft" diff --git a/mkl_fft/tests/test_interfaces.py b/mkl_fft/tests/test_interfaces.py index c91affa..917bde7 100644 --- a/mkl_fft/tests/test_interfaces.py +++ b/mkl_fft/tests/test_interfaces.py @@ -166,3 +166,9 @@ def test_axes(func): exp = np.fft.rfft2(x, axes=(1, 2)) tol = 64 * np.finfo(np.float64).eps assert np.allclose(res, exp, atol=tol, rtol=tol) + + +def test_scipy_fft_backend(): + """scipy_fft exposes properties necessary for use as a scipy fft backend""" + assert hasattr(mfi.scipy_fft, "__ua_domain__") + assert hasattr(mfi.scipy_fft, "__ua_function__")