Skip to content

Commit 2d79309

Browse files
Merge pull request #47 from IntelPython/adjust-workers-behavior-in-fft-backend
Adjusted workers-to-threads logic to agree with what is in scipy.fft
2 parents 357a0b7 + 8b34758 commit 2d79309

File tree

1 file changed

+44
-4
lines changed

1 file changed

+44
-4
lines changed

mkl_fft/_scipy_fft_backend.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,35 @@
3939
)
4040

4141
from numpy.core import (array, asarray, shape, conjugate, take, sqrt, prod)
42+
from os import cpu_count as os_cpu_count
43+
import warnings
44+
45+
class _cpu_max_threads_count:
    """Lazily query and cache the CPU count and MKL's maximum thread count.

    Each value is looked up the first time it is requested and then reused,
    so ``os.cpu_count()`` and ``mkl.get_max_threads()`` are each called at
    most once per instance.
    """

    def __init__(self):
        # None means "not queried yet"; both are filled in on first use.
        self.cpu_count = None
        self.max_threads_count = None

    def get_cpu_count(self):
        """Return the CPU count, detecting and caching it on the first call.

        On first detection a warning is issued if ``os.cpu_count()`` exceeds
        ``mkl.get_max_threads()``.  If ``os.cpu_count()`` cannot determine
        the count (it returns ``None``), MKL's maximum thread count is used
        instead so that callers always receive an integer.
        """
        max_threads = self.get_max_threads_count()
        if self.cpu_count is None:
            detected = os_cpu_count()
            if detected is None:
                # os.cpu_count() returns None when the count is undetermined;
                # comparing or adding None would raise TypeError, so fall
                # back to the thread count MKL itself reports.
                detected = max_threads
            elif detected > max_threads:
                warnings.warn(
                    ("os.cpu_count() returned value of {} greater than mkl.get_max_threads()'s value of {}. "
                     "Using negative values of worker option may amount to requesting more threads than "
                     "Intel(R) MKL can acommodate."
                     ).format(detected, max_threads))
            self.cpu_count = detected
        return self.cpu_count

    def get_max_threads_count(self):
        """Return ``mkl.get_max_threads()``, queried once and then cached."""
        if self.max_threads_count is None:
            self.max_threads_count = mkl.get_max_threads()

        return self.max_threads_count
67+
68+
69+
_hardware_counts = _cpu_max_threads_count()
70+
4271

4372
__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn',
4473
'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn',
@@ -101,9 +130,20 @@ def _tot_size(x, axes):
101130

102131

103132
def _workers_to_num_threads(w):
    """Translate a ``workers`` argument into a positive thread count.

    Mirrors the handling in scipy.fft.helpers._workers:

    * ``None``  -> the value returned by ``get_workers()``;
    * positive  -> used as given (after ``int`` conversion);
    * negative  -> counted back from the CPU count, so ``-1`` maps to the
      CPU count itself;
    * zero, or a negative value below ``-cpu_count`` -> ``ValueError``.
    """
    if w is None:
        return get_workers()

    n_threads = int(w)
    if n_threads == 0:
        raise ValueError("Number of workers must be nonzero")
    if n_threads > 0:
        return n_threads

    # Negative request: wrap around relative to the number of CPUs.
    n_threads += _hardware_counts.get_cpu_count() + 1
    if n_threads <= 0:
        raise ValueError("workers value out of range; got {}, must not be"
                         " less than {}".format(w, -_hardware_counts.get_cpu_count()))
    return n_threads
107147

108148

109149
class Workers:
@@ -119,8 +159,8 @@ def __enter__(self):
119159

120160
def __exit__(self, *args):
    """Reset the MKL 'fft' domain thread count when the context exits."""
    # restore default
    mkl.domain_set_num_threads(
        _hardware_counts.get_max_threads_count(), domain='fft')
124164

125165

126166
@_implements(_fft.fft)

0 commit comments

Comments
 (0)