@@ -100,10 +100,34 @@ def _tot_size(x, axes):
100
100
return prod ([s [ai ] for ai in axes ])
101
101
102
102
103
+ def _workers_to_num_threads (w ):
104
+ if w is None :
105
+ return mkl .domain_get_max_threads (domain = 'fft' )
106
+ return int (w )
107
+
108
+
109
+ class Workers :
110
+ def __init__ (self , workers ):
111
+ self .workers = workers
112
+ self .n_threads = _workers_to_num_threads (workers )
113
+
114
+ def __enter__ (self ):
115
+ try :
116
+ mkl .domain_set_num_threads (self .n_threads , domain = 'fft' )
117
+ except :
118
+ raise ValueError ("Class argument {} result in invalid number of threads {}" .format (self .workers , self .n_threads ))
119
+
120
+ def __exit__ (self , * args ):
121
+ # restore default
122
+ max_num_threads = mkl .domain_get_max_threads (domain = 'fft' )
123
+ mkl .domain_set_num_threads (max_num_threads , domain = 'fft' )
124
+
125
+
103
126
@_implements (_fft .fft )
104
127
def fft (a , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
105
128
x = _float_utils .__upcast_float16_array (a )
106
- output = _pydfti .fft (x , n = n , axis = axis , overwrite_x = overwrite_x )
129
+ with Workers (workers ):
130
+ output = _pydfti .fft (x , n = n , axis = axis , overwrite_x = overwrite_x )
107
131
if _unitary (norm ):
108
132
output *= 1 / sqrt (output .shape [axis ])
109
133
return output
@@ -112,7 +136,8 @@ def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
112
136
@_implements (_fft .ifft )
113
137
def ifft (a , n = None , axis = - 1 , norm = None , overwrite_x = False , workers = None ):
114
138
x = _float_utils .__upcast_float16_array (a )
115
- output = _pydfti .ifft (x , n = n , axis = axis , overwrite_x = overwrite_x )
139
+ with Workers (workers ):
140
+ output = _pydfti .ifft (x , n = n , axis = axis , overwrite_x = overwrite_x )
116
141
if _unitary (norm ):
117
142
output *= sqrt (output .shape [axis ])
118
143
return output
@@ -121,7 +146,8 @@ def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
121
146
@_implements (_fft .fft2 )
122
147
def fft2 (a , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
123
148
x = _float_utils .__upcast_float16_array (a )
124
- output = _pydfti .fftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
149
+ with Workers (workers ):
150
+ output = _pydfti .fftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
125
151
if _unitary (norm ):
126
152
factor = 1
127
153
for axis in axes :
@@ -133,7 +159,8 @@ def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
133
159
@_implements (_fft .ifft2 )
134
160
def ifft2 (a , s = None , axes = (- 2 ,- 1 ), norm = None , overwrite_x = False , workers = None ):
135
161
x = _float_utils .__upcast_float16_array (a )
136
- output = _pydfti .ifftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
162
+ with Workers (workers ):
163
+ output = _pydfti .ifftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
137
164
if _unitary (norm ):
138
165
factor = 1
139
166
_axes = range (output .ndim ) if axes is None else axes
@@ -146,7 +173,8 @@ def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
146
173
@_implements (_fft .fftn )
147
174
def fftn (a , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
148
175
x = _float_utils .__upcast_float16_array (a )
149
- output = _pydfti .fftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
176
+ with Workers (workers ):
177
+ output = _pydfti .fftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
150
178
if _unitary (norm ):
151
179
factor = 1
152
180
_axes = range (output .ndim ) if axes is None else axes
@@ -159,7 +187,8 @@ def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
159
187
@_implements (_fft .ifftn )
160
188
def ifftn (a , s = None , axes = None , norm = None , overwrite_x = False , workers = None ):
161
189
x = _float_utils .__upcast_float16_array (a )
162
- output = _pydfti .ifftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
190
+ with Workers (workers ):
191
+ output = _pydfti .ifftn (x , shape = s , axes = axes , overwrite_x = overwrite_x )
163
192
if _unitary (norm ):
164
193
factor = 1
165
194
_axes = range (output .ndim ) if axes is None else axes
@@ -170,64 +199,67 @@ def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
170
199
171
200
172
201
@_implements (_fft .rfft )
173
- def rfft (a , n = None , axis = - 1 , norm = None ):
202
+ def rfft (a , n = None , axis = - 1 , norm = None , workers = None ):
174
203
x = _float_utils .__upcast_float16_array (a )
175
204
unitary = _unitary (norm )
176
205
x = _float_utils .__downcast_float128_array (x )
177
206
if unitary and n is None :
178
207
x = asarray (x )
179
208
n = x .shape [axis ]
180
- output = _pydfti .rfft_numpy (x , n = n , axis = axis )
209
+ with Workers (workers ):
210
+ output = _pydfti .rfft_numpy (x , n = n , axis = axis )
181
211
if unitary :
182
212
output *= 1 / sqrt (n )
183
213
return output
184
214
185
215
186
216
@_implements (_fft .irfft )
187
- def irfft (a , n = None , axis = - 1 , norm = None ):
217
+ def irfft (a , n = None , axis = - 1 , norm = None , workers = None ):
188
218
x = _float_utils .__upcast_float16_array (a )
189
219
x = _float_utils .__downcast_float128_array (x )
190
- output = _pydfti .irfft_numpy (x , n = n , axis = axis )
220
+ with Workers (workers ):
221
+ output = _pydfti .irfft_numpy (x , n = n , axis = axis )
191
222
if _unitary (norm ):
192
223
output *= sqrt (output .shape [axis ])
193
224
return output
194
225
195
226
196
227
@_implements (_fft .rfft2 )
197
- def rfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None ):
228
+ def rfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None , workers = None ):
198
229
x = _float_utils .__upcast_float16_array (a )
199
230
x = _float_utils .__downcast_float128_array (a )
200
- return rfftn (x , s , axes , norm )
231
+ return rfftn (x , s , axes , norm , workers )
201
232
202
233
203
234
@_implements (_fft .irfft2 )
204
- def irfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None ):
235
+ def irfft2 (a , s = None , axes = (- 2 , - 1 ), norm = None , workers = None ):
205
236
x = _float_utils .__upcast_float16_array (a )
206
237
x = _float_utils .__downcast_float128_array (x )
207
- return irfftn (x , s , axes , norm )
238
+ return irfftn (x , s , axes , norm , workers )
208
239
209
240
210
241
@_implements (_fft .rfftn )
211
- def rfftn (a , s = None , axes = None , norm = None ):
242
+ def rfftn (a , s = None , axes = None , norm = None , workers = None ):
212
243
unitary = _unitary (norm )
213
244
x = _float_utils .__upcast_float16_array (a )
214
245
x = _float_utils .__downcast_float128_array (x )
215
246
if unitary :
216
247
x = asarray (x )
217
248
s , axes = _cook_nd_args (x , s , axes )
218
-
219
- output = _pydfti .rfftn_numpy (x , s , axes )
249
+ with Workers ( workers ):
250
+ output = _pydfti .rfftn_numpy (x , s , axes )
220
251
if unitary :
221
252
n_tot = prod (asarray (s , dtype = output .dtype ))
222
253
output *= 1 / sqrt (n_tot )
223
254
return output
224
255
225
256
226
257
@_implements (_fft .irfftn )
227
- def irfftn (a , s = None , axes = None , norm = None ):
258
+ def irfftn (a , s = None , axes = None , norm = None , workers = None ):
228
259
x = _float_utils .__upcast_float16_array (a )
229
260
x = _float_utils .__downcast_float128_array (x )
230
- output = _pydfti .irfftn_numpy (x , s , axes )
261
+ with Workers (workers ):
262
+ output = _pydfti .irfftn_numpy (x , s , axes )
231
263
if _unitary (norm ):
232
264
output *= sqrt (_tot_size (output , axes ))
233
265
return output
0 commit comments