1515import re
1616import sys
1717import ctypes
18+ import itertools
1819import textwrap
1920from typing import final
2021import warnings
@@ -111,20 +112,19 @@ def __init__(self, *, filepath=None, prefix=None, parent=None):
111112 self .prefix = prefix
112113 self .filepath = filepath
113114 self .dynlib = ctypes .CDLL (filepath , mode = _RTLD_NOLOAD )
115+ self ._symbol_prefix , self ._symbol_suffix = self ._find_affixes ()
114116 self .version = self .get_version ()
115117 self .set_additional_attributes ()
116118
117119 def info (self ):
118120 """Return relevant info wrapped in a dict"""
119- exposed_attrs = {
121+ hidden_attrs = ("dynlib" , "parent" , "_symbol_prefix" , "_symbol_suffix" )
122+ return {
120123 "user_api" : self .user_api ,
121124 "internal_api" : self .internal_api ,
122125 "num_threads" : self .num_threads ,
123- ** vars (self ),
126+ ** { k : v for k , v in vars (self ). items () if k not in hidden_attrs } ,
124127 }
125- exposed_attrs .pop ("dynlib" )
126- exposed_attrs .pop ("parent" )
127- return exposed_attrs
128128
129129 def set_additional_attributes (self ):
130130 """Set additional attributes meant to be exposed in the info dict"""
@@ -149,96 +149,87 @@ def set_num_threads(self, num_threads):
149149 def get_version (self ):
150150 """Return the version of the shared library"""
151151
152+ def _find_affixes (self ):
153+ """Return the affixes for the symbols of the shared library"""
154+ return "" , ""
155+
156+ def _get_symbol (self , name ):
157+ """Return the symbol of the shared library accounding for the affixes"""
158+ return getattr (
159+ self .dynlib , f"{ self ._symbol_prefix } { name } { self ._symbol_suffix } " , None
160+ )
161+
152162
153163class OpenBLASController (LibController ):
154164 """Controller class for OpenBLAS"""
155165
156166 user_api = "blas"
157167 internal_api = "openblas"
158- filename_prefixes = ("libopenblas" , "libblas" )
159- check_symbols = (
160- "openblas_get_num_threads" ,
161- "openblas_get_num_threads64_" ,
162- "openblas_set_num_threads" ,
163- "openblas_set_num_threads64_" ,
164- "openblas_get_config" ,
165- "openblas_get_config64_" ,
166- "openblas_get_parallel" ,
167- "openblas_get_parallel64_" ,
168- "openblas_get_corename" ,
169- "openblas_get_corename64_" ,
168+ filename_prefixes = ("libopenblas" , "libblas" , "libscipy_openblas" )
169+
170+ _symbol_prefixes = ("" , "scipy_" )
171+ _symbol_suffixes = ("" , "64_" , "_64" )
172+
173+ # All variations of "openblas_get_num_threads", accounting for the affixes
174+ check_symbols = tuple (
175+ f"{ prefix } openblas_get_num_threads{ suffix } "
176+ for prefix , suffix in itertools .product (_symbol_prefixes , _symbol_suffixes )
170177 )
171178
179+ def _find_affixes (self ):
180+ for prefix , suffix in itertools .product (
181+ self ._symbol_prefixes , self ._symbol_suffixes
182+ ):
183+ if hasattr (self .dynlib , f"{ prefix } openblas_get_num_threads{ suffix } " ):
184+ return prefix , suffix
185+
172186 def set_additional_attributes (self ):
173187 self .threading_layer = self ._get_threading_layer ()
174188 self .architecture = self ._get_architecture ()
175189
176190 def get_num_threads (self ):
177- get_func = getattr (
178- self .dynlib ,
179- "openblas_get_num_threads" ,
180- # Symbols differ when built for 64bit integers in Fortran
181- getattr (self .dynlib , "openblas_get_num_threads64_" , lambda : None ),
182- )
183-
184- return get_func ()
191+ get_num_threads_func = self ._get_symbol ("openblas_get_num_threads" )
192+ if get_num_threads_func is not None :
193+ return get_num_threads_func ()
194+ return None
185195
186196 def set_num_threads (self , num_threads ):
187- set_func = getattr (
188- self .dynlib ,
189- "openblas_set_num_threads" ,
190- # Symbols differ when built for 64bit integers in Fortran
191- getattr (
192- self .dynlib , "openblas_set_num_threads64_" , lambda num_threads : None
193- ),
194- )
195- return set_func (num_threads )
197+ set_num_threads_func = self ._get_symbol ("openblas_set_num_threads" )
198+ if set_num_threads_func is not None :
199+ return set_num_threads_func (num_threads )
200+ return None
196201
197202 def get_version (self ):
198203 # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS
199204 # did not expose its version before that.
200- get_config = getattr (
201- self . dynlib ,
202- "openblas_get_config" ,
203- getattr ( self . dynlib , "openblas_get_config64_" , None ),
204- )
205- if get_config is None :
205+ get_version_func = self . _get_symbol ( "openblas_get_config" )
206+ if get_version_func is not None :
207+ get_version_func . restype = ctypes . c_char_p
208+ config = get_version_func (). split ()
209+ if config [ 0 ] == b"OpenBLAS" :
210+ return config [ 1 ]. decode ( "utf-8" )
206211 return None
207-
208- get_config .restype = ctypes .c_char_p
209- config = get_config ().split ()
210- if config [0 ] == b"OpenBLAS" :
211- return config [1 ].decode ("utf-8" )
212212 return None
213213
214214 def _get_threading_layer (self ):
215215 """Return the threading layer of OpenBLAS"""
216- openblas_get_parallel = getattr (
217- self .dynlib ,
218- "openblas_get_parallel" ,
219- getattr (self .dynlib , "openblas_get_parallel64_" , None ),
220- )
221- if openblas_get_parallel is None :
222- return "unknown"
223- threading_layer = openblas_get_parallel ()
224- if threading_layer == 2 :
225- return "openmp"
226- elif threading_layer == 1 :
227- return "pthreads"
228- return "disabled"
216+ get_threading_layer_func = self ._get_symbol ("openblas_get_parallel" )
217+ if get_threading_layer_func is not None :
218+ threading_layer = get_threading_layer_func ()
219+ if threading_layer == 2 :
220+ return "openmp"
221+ elif threading_layer == 1 :
222+ return "pthreads"
223+ return "disabled"
224+ return "unknown"
229225
230226 def _get_architecture (self ):
231227 """Return the architecture detected by OpenBLAS"""
232- get_corename = getattr (
233- self .dynlib ,
234- "openblas_get_corename" ,
235- getattr (self .dynlib , "openblas_get_corename64_" , None ),
236- )
237- if get_corename is None :
238- return None
239-
240- get_corename .restype = ctypes .c_char_p
241- return get_corename ().decode ("utf-8" )
228+ get_architecture_func = self ._get_symbol ("openblas_get_corename" )
229+ if get_architecture_func is not None :
230+ get_architecture_func .restype = ctypes .c_char_p
231+ return get_architecture_func ().decode ("utf-8" )
232+ return None
242233
243234
244235class BLISController (LibController ):
0 commit comments