Skip to content

Commit 8b34758

Browse files
use os.cpu_count to translate negative value of worker keyword, but issue a warning if cpu_count(0 ends up being higher than MKL's max_threads. The warning is only issued once
1 parent 26670a4 commit 8b34758

File tree

1 file changed

+31
-4
lines changed

1 file changed

+31
-4
lines changed

mkl_fft/_scipy_fft_backend.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,35 @@
3939
)
4040

4141
from numpy.core import (array, asarray, shape, conjugate, take, sqrt, prod)
42+
from os import cpu_count as os_cpu_count
43+
import warnings
4244

43-
_max_threads_count = mkl.get_max_threads()
45+
class _cpu_max_threads_count:
46+
def __init__(self):
47+
self.cpu_count = None
48+
self.max_threads_count = None
4449

50+
def get_cpu_count(self):
51+
max_threads = self.get_max_threads_count()
52+
if self.cpu_count is None:
53+
self.cpu_count = os_cpu_count()
54+
if self.cpu_count > max_threads:
55+
warnings.warn(
56+
("os.cpu_count() returned value of {} greater than mkl.get_max_threads()'s value of {}. "
57+
"Using negative values of worker option may amount to requesting more threads than "
58+
"Intel(R) MKL can acommodate."
59+
).format(self.cpu_count, max_threads))
60+
return self.cpu_count
61+
62+
def get_max_threads_count(self):
63+
if self.max_threads_count is None:
64+
self.max_threads_count = mkl.get_max_threads()
65+
66+
return self.max_threads_count
67+
68+
69+
_hardware_counts = _cpu_max_threads_count()
70+
4571

4672
__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn',
4773
'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn',
@@ -113,10 +139,10 @@ def _workers_to_num_threads(w):
113139
if (_w == 0):
114140
raise ValueError("Number of workers must be nonzero")
115141
if (_w < 0):
116-
_w += _max_threads_count + 1
142+
_w += _hardware_counts.get_cpu_count() + 1
117143
if _w <= 0:
118144
raise ValueError("workers value out of range; got {}, must not be"
119-
" less than {}".format(w, -_max_threads_count))
145+
" less than {}".format(w, -_hardware_counts.get_cpu_count()))
120146
return _w
121147

122148

@@ -133,7 +159,8 @@ def __enter__(self):
133159

134160
def __exit__(self, *args):
135161
# restore default
136-
mkl.domain_set_num_threads(_max_threads_count, domain='fft')
162+
n_threads = _hardware_counts.get_max_threads_count()
163+
mkl.domain_set_num_threads(n_threads, domain='fft')
137164

138165

139166
@_implements(_fft.fft)

0 commit comments

Comments
 (0)