39
39
)
40
40
41
41
from numpy .core import (array , asarray , shape , conjugate , take , sqrt , prod )
42
+ from os import cpu_count as os_cpu_count
43
+ import warnings
42
44
43
- _max_threads_count = mkl .get_max_threads ()
45
class _cpu_max_threads_count:
    """Lazily query and cache the OS CPU count and the MKL thread limit.

    Each value is looked up the first time its getter is called and is
    reused afterwards.
    """

    def __init__(self):
        # None means "not queried yet"; the getters fill these in on demand.
        self.cpu_count = None
        self.max_threads_count = None

    def get_cpu_count(self):
        """Return the number of CPUs, querying ``os.cpu_count()`` on first use.

        If ``os.cpu_count()`` cannot determine the count (returns ``None``),
        the value of ``get_max_threads_count()`` is used instead.  If the
        reported CPU count exceeds the MKL thread limit, a warning is issued
        once, at the time the value is first computed.

        Returns
        -------
        int
            The cached CPU count.
        """
        if self.cpu_count is None:
            max_threads = self.get_max_threads_count()
            detected = os_cpu_count()
            if detected is None:
                # os.cpu_count() is documented to return None when the count
                # is indeterminable; comparing None with an int would raise
                # TypeError, so fall back to the MKL thread limit.
                detected = max_threads
            elif detected > max_threads:
                warnings.warn(
                    ("os.cpu_count() returned value of {} greater than mkl.get_max_threads()'s value of {}. "
                     "Using negative values of worker option may amount to requesting more threads than "
                     "Intel(R) MKL can accommodate."
                     ).format(detected, max_threads))
            self.cpu_count = detected
        return self.cpu_count

    def get_max_threads_count(self):
        """Return ``mkl.get_max_threads()``, calling it only on first use.

        Returns
        -------
        int
            The cached value returned by ``mkl.get_max_threads()``.
        """
        if self.max_threads_count is None:
            self.max_threads_count = mkl.get_max_threads()

        return self.max_threads_count
67
+
68
+
69
+ _hardware_counts = _cpu_max_threads_count ()
70
+
45
71
46
72
__all__ = ['fft' , 'ifft' , 'fft2' , 'ifft2' , 'fftn' , 'ifftn' ,
47
73
'rfft' , 'irfft' , 'rfft2' , 'irfft2' , 'rfftn' , 'irfftn' ,
@@ -113,10 +139,10 @@ def _workers_to_num_threads(w):
113
139
if (_w == 0 ):
114
140
raise ValueError ("Number of workers must be nonzero" )
115
141
if (_w < 0 ):
116
- _w += _max_threads_count + 1
142
+ _w += _hardware_counts . get_cpu_count () + 1
117
143
if _w <= 0 :
118
144
raise ValueError ("workers value out of range; got {}, must not be"
119
- " less than {}" .format (w , - _max_threads_count ))
145
+ " less than {}" .format (w , - _hardware_counts . get_cpu_count () ))
120
146
return _w
121
147
122
148
@@ -133,7 +159,8 @@ def __enter__(self):
133
159
134
160
def __exit__(self, *args):
    """Restore the default thread count for the MKL 'fft' domain.

    The value passed to ``mkl.domain_set_num_threads`` is whatever
    ``_hardware_counts.get_max_threads_count()`` returns; the exception
    details received in ``*args`` are ignored.
    """
    mkl.domain_set_num_threads(
        _hardware_counts.get_max_threads_count(), domain='fft'
    )
137
164
138
165
139
166
@_implements (_fft .fft )
0 commit comments