Skip to content

Commit 588c9a3

Browse files
implemented support for workers keyword
1 parent c0f0d05 commit 588c9a3

File tree

1 file changed

+51
-19
lines changed

1 file changed

+51
-19
lines changed

mkl_fft/_scipy_fft_backend.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,34 @@ def _tot_size(x, axes):
100100
return prod([s[ai] for ai in axes])
101101

102102

103+
def _workers_to_num_threads(w):
104+
if w is None:
105+
return mkl.domain_get_max_threads(domain='fft')
106+
return int(w)
107+
108+
109+
class Workers:
110+
def __init__(self, workers):
111+
self.workers = workers
112+
self.n_threads = _workers_to_num_threads(workers)
113+
114+
def __enter__(self):
115+
try:
116+
mkl.domain_set_num_threads(self.n_threads, domain='fft')
117+
except:
118+
raise ValueError("Class argument {} result in invalid number of threads {}".format(self.workers, self.n_threads))
119+
120+
def __exit__(self, *args):
121+
# restore default
122+
max_num_threads = mkl.domain_get_max_threads(domain='fft')
123+
mkl.domain_set_num_threads(max_num_threads, domain='fft')
124+
125+
103126
@_implements(_fft.fft)
104127
def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
105128
x = _float_utils.__upcast_float16_array(a)
106-
output = _pydfti.fft(x, n=n, axis=axis, overwrite_x=overwrite_x)
129+
with Workers(workers):
130+
output = _pydfti.fft(x, n=n, axis=axis, overwrite_x=overwrite_x)
107131
if _unitary(norm):
108132
output *= 1 / sqrt(output.shape[axis])
109133
return output
@@ -112,7 +136,8 @@ def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
112136
@_implements(_fft.ifft)
113137
def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
114138
x = _float_utils.__upcast_float16_array(a)
115-
output = _pydfti.ifft(x, n=n, axis=axis, overwrite_x=overwrite_x)
139+
with Workers(workers):
140+
output = _pydfti.ifft(x, n=n, axis=axis, overwrite_x=overwrite_x)
116141
if _unitary(norm):
117142
output *= sqrt(output.shape[axis])
118143
return output
@@ -121,7 +146,8 @@ def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
121146
@_implements(_fft.fft2)
122147
def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
123148
x = _float_utils.__upcast_float16_array(a)
124-
output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
149+
with Workers(workers):
150+
output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
125151
if _unitary(norm):
126152
factor = 1
127153
for axis in axes:
@@ -133,7 +159,8 @@ def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
133159
@_implements(_fft.ifft2)
134160
def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
135161
x = _float_utils.__upcast_float16_array(a)
136-
output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
162+
with Workers(workers):
163+
output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
137164
if _unitary(norm):
138165
factor = 1
139166
_axes = range(output.ndim) if axes is None else axes
@@ -146,7 +173,8 @@ def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
146173
@_implements(_fft.fftn)
147174
def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
148175
x = _float_utils.__upcast_float16_array(a)
149-
output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
176+
with Workers(workers):
177+
output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
150178
if _unitary(norm):
151179
factor = 1
152180
_axes = range(output.ndim) if axes is None else axes
@@ -159,7 +187,8 @@ def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
159187
@_implements(_fft.ifftn)
160188
def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
161189
x = _float_utils.__upcast_float16_array(a)
162-
output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
190+
with Workers(workers):
191+
output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
163192
if _unitary(norm):
164193
factor = 1
165194
_axes = range(output.ndim) if axes is None else axes
@@ -170,64 +199,67 @@ def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
170199

171200

172201
@_implements(_fft.rfft)
173-
def rfft(a, n=None, axis=-1, norm=None):
202+
def rfft(a, n=None, axis=-1, norm=None, workers=None):
174203
x = _float_utils.__upcast_float16_array(a)
175204
unitary = _unitary(norm)
176205
x = _float_utils.__downcast_float128_array(x)
177206
if unitary and n is None:
178207
x = asarray(x)
179208
n = x.shape[axis]
180-
output = _pydfti.rfft_numpy(x, n=n, axis=axis)
209+
with Workers(workers):
210+
output = _pydfti.rfft_numpy(x, n=n, axis=axis)
181211
if unitary:
182212
output *= 1 / sqrt(n)
183213
return output
184214

185215

186216
@_implements(_fft.irfft)
187-
def irfft(a, n=None, axis=-1, norm=None):
217+
def irfft(a, n=None, axis=-1, norm=None, workers=None):
188218
x = _float_utils.__upcast_float16_array(a)
189219
x = _float_utils.__downcast_float128_array(x)
190-
output = _pydfti.irfft_numpy(x, n=n, axis=axis)
220+
with Workers(workers):
221+
output = _pydfti.irfft_numpy(x, n=n, axis=axis)
191222
if _unitary(norm):
192223
output *= sqrt(output.shape[axis])
193224
return output
194225

195226

196227
@_implements(_fft.rfft2)
197-
def rfft2(a, s=None, axes=(-2, -1), norm=None):
228+
def rfft2(a, s=None, axes=(-2, -1), norm=None, workers=None):
198229
x = _float_utils.__upcast_float16_array(a)
199230
x = _float_utils.__downcast_float128_array(a)
200-
return rfftn(x, s, axes, norm)
231+
return rfftn(x, s, axes, norm, workers)
201232

202233

203234
@_implements(_fft.irfft2)
204-
def irfft2(a, s=None, axes=(-2, -1), norm=None):
235+
def irfft2(a, s=None, axes=(-2, -1), norm=None, workers=None):
205236
x = _float_utils.__upcast_float16_array(a)
206237
x = _float_utils.__downcast_float128_array(x)
207-
return irfftn(x, s, axes, norm)
238+
return irfftn(x, s, axes, norm, workers)
208239

209240

210241
@_implements(_fft.rfftn)
211-
def rfftn(a, s=None, axes=None, norm=None):
242+
def rfftn(a, s=None, axes=None, norm=None, workers=None):
212243
unitary = _unitary(norm)
213244
x = _float_utils.__upcast_float16_array(a)
214245
x = _float_utils.__downcast_float128_array(x)
215246
if unitary:
216247
x = asarray(x)
217248
s, axes = _cook_nd_args(x, s, axes)
218-
219-
output = _pydfti.rfftn_numpy(x, s, axes)
249+
with Workers(workers):
250+
output = _pydfti.rfftn_numpy(x, s, axes)
220251
if unitary:
221252
n_tot = prod(asarray(s, dtype=output.dtype))
222253
output *= 1 / sqrt(n_tot)
223254
return output
224255

225256

226257
@_implements(_fft.irfftn)
227-
def irfftn(a, s=None, axes=None, norm=None):
258+
def irfftn(a, s=None, axes=None, norm=None, workers=None):
228259
x = _float_utils.__upcast_float16_array(a)
229260
x = _float_utils.__downcast_float128_array(x)
230-
output = _pydfti.irfftn_numpy(x, s, axes)
261+
with Workers(workers):
262+
output = _pydfti.irfftn_numpy(x, s, axes)
231263
if _unitary(norm):
232264
output *= sqrt(_tot_size(output, axes))
233265
return output

0 commit comments

Comments
 (0)