diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index 760d557a4..8a1ac2d92 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -21,6 +21,7 @@ import os from pathlib import Path import platform +import re from typing import Set, Union from warnings import warn @@ -112,9 +113,13 @@ def manual_override(self): if not override_value: return - binary_name = self.binary_name.rsplit(".", 1)[0] - # TODO: what's the magic value `-3` here? - self.binary_name = binary_name[:-3] + f'{override_value}{DYNAMIC_LIBRARY_SUFFIX}' + binary_name_stem, _, binary_name_ext = self.binary_name.rpartition(".") + # `binary_name_stem` will now be e.g. `/foo/bar/libbitsandbytes_cuda118`; + # let's remove any trailing numbers: + binary_name_stem = re.sub(r"\d+$", "", binary_name_stem) + # `binary_name_stem` will now be e.g. `/foo/bar/libbitsandbytes_cuda`; + # let's tack the new version number and the original extension back on. + self.binary_name = f"{binary_name_stem}{override_value}.{binary_name_ext}" warn( f'\n\n{"=" * 80}\n'