39
39
)
40
40
41
41
from numpy .core import (array , asarray , shape , conjugate , take , sqrt , prod )
42
+ from os import cpu_count as os_cpu_count
43
+ import warnings
44
+
45
+ class _cpu_max_threads_count :
46
+ def __init__ (self ):
47
+ self .cpu_count = None
48
+ self .max_threads_count = None
49
+
50
+ def get_cpu_count (self ):
51
+ max_threads = self .get_max_threads_count ()
52
+ if self .cpu_count is None :
53
+ self .cpu_count = os_cpu_count ()
54
+ if self .cpu_count > max_threads :
55
+ warnings .warn (
56
+ ("os.cpu_count() returned value of {} greater than mkl.get_max_threads()'s value of {}. "
57
+ "Using negative values of worker option may amount to requesting more threads than "
58
+ "Intel(R) MKL can acommodate."
59
+ ).format (self .cpu_count , max_threads ))
60
+ return self .cpu_count
61
+
62
+ def get_max_threads_count (self ):
63
+ if self .max_threads_count is None :
64
+ self .max_threads_count = mkl .get_max_threads ()
65
+
66
+ return self .max_threads_count
67
+
68
+
69
+ _hardware_counts = _cpu_max_threads_count ()
70
+
42
71
43
72
__all__ = ['fft' , 'ifft' , 'fft2' , 'ifft2' , 'fftn' , 'ifftn' ,
44
73
'rfft' , 'irfft' , 'rfft2' , 'irfft2' , 'rfftn' , 'irfftn' ,
@@ -101,9 +130,20 @@ def _tot_size(x, axes):
101
130
102
131
103
132
def _workers_to_num_threads (w ):
133
+ """Handle conversion of workers to a positive number of threads in the
134
+ same way as scipy.fft.helpers._workers.
135
+ """
104
136
if w is None :
105
- return mkl .domain_get_max_threads (domain = 'fft' )
106
- return int (w )
137
+ return get_workers ()
138
+ _w = int (w )
139
+ if (_w == 0 ):
140
+ raise ValueError ("Number of workers must be nonzero" )
141
+ if (_w < 0 ):
142
+ _w += _hardware_counts .get_cpu_count () + 1
143
+ if _w <= 0 :
144
+ raise ValueError ("workers value out of range; got {}, must not be"
145
+ " less than {}" .format (w , - _hardware_counts .get_cpu_count ()))
146
+ return _w
107
147
108
148
109
149
class Workers :
@@ -119,8 +159,8 @@ def __enter__(self):
119
159
120
160
def __exit__ (self , * args ):
121
161
# restore default
122
- max_num_threads = mkl . domain_get_max_threads ( domain = 'fft' )
123
- mkl .domain_set_num_threads (max_num_threads , domain = 'fft' )
162
+ n_threads = _hardware_counts . get_max_threads_count ( )
163
+ mkl .domain_set_num_threads (n_threads , domain = 'fft' )
124
164
125
165
126
166
@_implements (_fft .fft )
0 commit comments