Skip to content

Commit 8dd980b

Browse files
authored
FEA Extend OpenBLAS controller to support scipy_openblas (#175)
1 parent 5282c0b commit 8dd980b

File tree

6 files changed

+84
-70
lines changed

6 files changed

+84
-70
lines changed

.azure_pipeline.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ stages:
5858
name: Linux
5959
vmImage: ubuntu-20.04
6060
matrix:
61+
# Linux environment with development versions of numpy and scipy
62+
pylatest_pip_dev:
63+
PACKAGER: 'pip-dev'
64+
PYTHON_VERSION: '*'
65+
CC_OUTER_LOOP: 'gcc'
66+
CC_INNER_LOOP: 'gcc'
6167
# Linux environment to test that packages that comes with Ubuntu 20.04
6268
# are correctly handled.
6369
py38_ubuntu_atlas_gcc_gcc:

CHANGES.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
3.5.0 (TDB)
2+
===========
3+
4+
- Added support for the Scientific Python version of OpenBLAS
5+
(https://github.com/MacPython/openblas-libs), which exposes symbols with different
6+
names than the ones of the original OpenBLAS library.
7+
https://github.com/joblib/threadpoolctl/pull/175
8+
19
3.4.0 (2024-03-20)
210
==================
311

continuous_integration/install.sh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,15 @@ elif [[ "$PACKAGER" == "pip" ]]; then
6565
pip install numpy scipy
6666
fi
6767

68+
elif [[ "$PACKAGER" == "pip-dev" ]]; then
69+
# Use conda to build an empty python env and then use pip to install
70+
# numpy and scipy dev versions
71+
TO_INSTALL="python=$PYTHON_VERSION pip"
72+
make_conda $TO_INSTALL
73+
74+
dev_anaconda_url=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
75+
pip install --pre --upgrade --timeout=60 --extra-index $dev_anaconda_url numpy scipy
76+
6877
elif [[ "$PACKAGER" == "ubuntu" ]]; then
6978
# Remove the ubuntu toolchain PPA that seems to be invalid:
7079
# https://github.com/scikit-learn/scikit-learn/pull/13934

continuous_integration/posix.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
steps:
1515
- bash: echo "##vso[task.prependpath]$CONDA/bin"
1616
displayName: Add conda to PATH
17-
condition: or(startsWith(variables['PACKAGER'], 'conda'), eq(variables['PACKAGER'], 'pip'))
17+
condition: or(startsWith(variables['PACKAGER'], 'conda'), startsWith(variables['PACKAGER'], 'pip'))
1818
- bash: sudo chown -R $USER $CONDA
1919
# On Hosted macOS, the agent user doesn't have ownership of Miniconda's installation directory/
2020
# We need to take ownership if we want to update conda or install packages globally

continuous_integration/test_script.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ set -e
55
if [[ "$PACKAGER" == conda* ]]; then
66
source activate $VIRTUALENV
77
conda list
8-
elif [[ "$PACKAGER" == "pip" ]]; then
8+
elif [[ "$PACKAGER" == pip* ]]; then
99
# we actually use conda to install the base environment:
1010
source activate $VIRTUALENV
1111
pip list

threadpoolctl.py

Lines changed: 59 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import re
1616
import sys
1717
import ctypes
18+
import itertools
1819
import textwrap
1920
from typing import final
2021
import warnings
@@ -111,20 +112,19 @@ def __init__(self, *, filepath=None, prefix=None, parent=None):
111112
self.prefix = prefix
112113
self.filepath = filepath
113114
self.dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
115+
self._symbol_prefix, self._symbol_suffix = self._find_affixes()
114116
self.version = self.get_version()
115117
self.set_additional_attributes()
116118

117119
def info(self):
118120
"""Return relevant info wrapped in a dict"""
119-
exposed_attrs = {
121+
hidden_attrs = ("dynlib", "parent", "_symbol_prefix", "_symbol_suffix")
122+
return {
120123
"user_api": self.user_api,
121124
"internal_api": self.internal_api,
122125
"num_threads": self.num_threads,
123-
**vars(self),
126+
**{k: v for k, v in vars(self).items() if k not in hidden_attrs},
124127
}
125-
exposed_attrs.pop("dynlib")
126-
exposed_attrs.pop("parent")
127-
return exposed_attrs
128128

129129
def set_additional_attributes(self):
130130
"""Set additional attributes meant to be exposed in the info dict"""
@@ -149,96 +149,87 @@ def set_num_threads(self, num_threads):
149149
def get_version(self):
150150
"""Return the version of the shared library"""
151151

152+
def _find_affixes(self):
153+
"""Return the affixes for the symbols of the shared library"""
154+
return "", ""
155+
156+
def _get_symbol(self, name):
157+
"""Return the symbol of the shared library accounding for the affixes"""
158+
return getattr(
159+
self.dynlib, f"{self._symbol_prefix}{name}{self._symbol_suffix}", None
160+
)
161+
152162

153163
class OpenBLASController(LibController):
154164
"""Controller class for OpenBLAS"""
155165

156166
user_api = "blas"
157167
internal_api = "openblas"
158-
filename_prefixes = ("libopenblas", "libblas")
159-
check_symbols = (
160-
"openblas_get_num_threads",
161-
"openblas_get_num_threads64_",
162-
"openblas_set_num_threads",
163-
"openblas_set_num_threads64_",
164-
"openblas_get_config",
165-
"openblas_get_config64_",
166-
"openblas_get_parallel",
167-
"openblas_get_parallel64_",
168-
"openblas_get_corename",
169-
"openblas_get_corename64_",
168+
filename_prefixes = ("libopenblas", "libblas", "libscipy_openblas")
169+
170+
_symbol_prefixes = ("", "scipy_")
171+
_symbol_suffixes = ("", "64_", "_64")
172+
173+
# All variations of "openblas_get_num_threads", accounting for the affixes
174+
check_symbols = tuple(
175+
f"{prefix}openblas_get_num_threads{suffix}"
176+
for prefix, suffix in itertools.product(_symbol_prefixes, _symbol_suffixes)
170177
)
171178

179+
def _find_affixes(self):
180+
for prefix, suffix in itertools.product(
181+
self._symbol_prefixes, self._symbol_suffixes
182+
):
183+
if hasattr(self.dynlib, f"{prefix}openblas_get_num_threads{suffix}"):
184+
return prefix, suffix
185+
172186
def set_additional_attributes(self):
173187
self.threading_layer = self._get_threading_layer()
174188
self.architecture = self._get_architecture()
175189

176190
def get_num_threads(self):
177-
get_func = getattr(
178-
self.dynlib,
179-
"openblas_get_num_threads",
180-
# Symbols differ when built for 64bit integers in Fortran
181-
getattr(self.dynlib, "openblas_get_num_threads64_", lambda: None),
182-
)
183-
184-
return get_func()
191+
get_num_threads_func = self._get_symbol("openblas_get_num_threads")
192+
if get_num_threads_func is not None:
193+
return get_num_threads_func()
194+
return None
185195

186196
def set_num_threads(self, num_threads):
187-
set_func = getattr(
188-
self.dynlib,
189-
"openblas_set_num_threads",
190-
# Symbols differ when built for 64bit integers in Fortran
191-
getattr(
192-
self.dynlib, "openblas_set_num_threads64_", lambda num_threads: None
193-
),
194-
)
195-
return set_func(num_threads)
197+
set_num_threads_func = self._get_symbol("openblas_set_num_threads")
198+
if set_num_threads_func is not None:
199+
return set_num_threads_func(num_threads)
200+
return None
196201

197202
def get_version(self):
198203
# None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS
199204
# did not expose its version before that.
200-
get_config = getattr(
201-
self.dynlib,
202-
"openblas_get_config",
203-
getattr(self.dynlib, "openblas_get_config64_", None),
204-
)
205-
if get_config is None:
205+
get_version_func = self._get_symbol("openblas_get_config")
206+
if get_version_func is not None:
207+
get_version_func.restype = ctypes.c_char_p
208+
config = get_version_func().split()
209+
if config[0] == b"OpenBLAS":
210+
return config[1].decode("utf-8")
206211
return None
207-
208-
get_config.restype = ctypes.c_char_p
209-
config = get_config().split()
210-
if config[0] == b"OpenBLAS":
211-
return config[1].decode("utf-8")
212212
return None
213213

214214
def _get_threading_layer(self):
215215
"""Return the threading layer of OpenBLAS"""
216-
openblas_get_parallel = getattr(
217-
self.dynlib,
218-
"openblas_get_parallel",
219-
getattr(self.dynlib, "openblas_get_parallel64_", None),
220-
)
221-
if openblas_get_parallel is None:
222-
return "unknown"
223-
threading_layer = openblas_get_parallel()
224-
if threading_layer == 2:
225-
return "openmp"
226-
elif threading_layer == 1:
227-
return "pthreads"
228-
return "disabled"
216+
get_threading_layer_func = self._get_symbol("openblas_get_parallel")
217+
if get_threading_layer_func is not None:
218+
threading_layer = get_threading_layer_func()
219+
if threading_layer == 2:
220+
return "openmp"
221+
elif threading_layer == 1:
222+
return "pthreads"
223+
return "disabled"
224+
return "unknown"
229225

230226
def _get_architecture(self):
231227
"""Return the architecture detected by OpenBLAS"""
232-
get_corename = getattr(
233-
self.dynlib,
234-
"openblas_get_corename",
235-
getattr(self.dynlib, "openblas_get_corename64_", None),
236-
)
237-
if get_corename is None:
238-
return None
239-
240-
get_corename.restype = ctypes.c_char_p
241-
return get_corename().decode("utf-8")
228+
get_architecture_func = self._get_symbol("openblas_get_corename")
229+
if get_architecture_func is not None:
230+
get_architecture_func.restype = ctypes.c_char_p
231+
return get_architecture_func().decode("utf-8")
232+
return None
242233

243234

244235
class BLISController(LibController):

0 commit comments

Comments
 (0)