|
28 | 28 | from . import _float_utils
|
29 | 29 | import mkl
|
30 | 30 |
|
31 |
| -from numpy.core import (array, asarray, shape, conjugate, take, sqrt, prod) |
32 |
| -from os import cpu_count as os_cpu_count |
33 |
| -import warnings |
| 31 | +from numpy.core import (take, sqrt, prod) |
| 32 | +import contextvars |
| 33 | +import operator |
34 | 34 |
|
35 | 35 |
|
36 | 36 | __doc__ = """
|
@@ -64,15 +64,44 @@ def get_max_threads_count(self):
|
64 | 64 | return self.max_threads_count
|
65 | 65 |
|
66 | 66 |
|
67 |
| -_hardware_counts = _cpu_max_threads_count() |
| 67 | +class _workers_data: |
| 68 | + def __init__(self, workers=None): |
| 69 | + if workers: |
| 70 | + self.workers_ = workers |
| 71 | + else: |
| 72 | + self.workers_ = _cpu_max_threads_count().get_cpu_count() |
| 73 | + self.workers_ = operator.index(self.workers_) |
| 74 | + |
| 75 | + @property |
| 76 | + def workers(self): |
| 77 | + return self.workers_ |
| 78 | + |
| 79 | + @workers.setter |
| 80 | + def workers(self, workers_val): |
| 81 | + self.workerks_ = operator.index(workers_val) |
| 82 | + |
| 83 | + |
| 84 | +_workers_global_settings = contextvars.ContextVar('scipy_backend_workers', default=_workers_data()) |
| 85 | + |
| 86 | + |
| 87 | +def get_workers(): |
| 88 | + "Gets the number of workers used by mkl_fft by default" |
| 89 | + return _workers_global_settings.get().workers |
| 90 | + |
| 91 | + |
| 92 | +def set_workers(n_workers): |
| 93 | + "Set the value of workers used by default, returns the previous value" |
| 94 | + nw = operator.index(n_workers) |
| 95 | + wd = _workers_global_settings.get() |
| 96 | + saved_nw = wd.workers |
| 97 | + wd.workers = nw |
| 98 | + _workers_global_settings.set(wd) |
| 99 | + return saved_nw |
68 | 100 |
|
69 | 101 |
|
70 | 102 | __all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn',
|
71 | 103 | 'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn',
|
72 |
| - 'hfft', 'ihfft', 'hfft2', 'ihfft2', 'hfftn', 'ihfftn', |
73 |
| - 'dct', 'idct', 'dst', 'idst', 'dctn', 'idctn', 'dstn', 'idstn', |
74 |
| - 'fftshift', 'ifftshift', 'fftfreq', 'rfftfreq', 'get_workers', |
75 |
| - 'set_workers', 'next_fast_len', 'DftiBackend'] |
| 104 | + 'get_workers', 'set_workers', 'DftiBackend'] |
76 | 105 |
|
77 | 106 | __ua_domain__ = "numpy.scipy.fft"
|
78 | 107 |
|
@@ -114,27 +143,21 @@ def _cook_nd_args(a, s=None, axes=None, invreal=0):
|
114 | 143 | return s, axes
|
115 | 144 |
|
116 | 145 |
|
117 |
| -def _tot_size(x, axes): |
118 |
| - s = x.shape |
119 |
| - if axes is None: |
120 |
| - return x.size |
121 |
| - return prod([s[ai] for ai in axes]) |
122 |
| - |
123 |
| - |
124 | 146 | def _workers_to_num_threads(w):
|
125 | 147 | """Handle conversion of workers to a positive number of threads in the
|
126 | 148 | same way as scipy.fft.helpers._workers.
|
127 | 149 | """
|
128 | 150 | if w is None:
|
129 |
| - return _hardware_counts.get_cpu_count() |
130 |
| - _w = int(w) |
| 151 | + return _workers_global_settings.get().workers |
| 152 | + _w = operator.index(w) |
131 | 153 | if (_w == 0):
|
132 | 154 | raise ValueError("Number of workers must be nonzero")
|
133 | 155 | if (_w < 0):
|
134 |
| - _w += _hardware_counts.get_cpu_count() + 1 |
| 156 | + ub = _cpu_max_threads_count().get_cpu_count() |
| 157 | + _w += ub + 1 |
135 | 158 | if _w <= 0:
|
136 | 159 | raise ValueError("workers value out of range; got {}, must not be"
|
137 |
| - " less than {}".format(w, -_hardware_counts.get_cpu_count())) |
| 160 | + " less than {}".format(w, -ub)) |
138 | 161 | return _w
|
139 | 162 |
|
140 | 163 |
|
|
0 commit comments