diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 03d2cbd61..16f0c4445 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -100,11 +100,18 @@ def get_native_library() -> BNBNativeLibrary: binary_path = cuda_binary_path else: logger.warning("Could not find the bitsandbytes CUDA binary at %r", cuda_binary_path) + elif torch.backends.mps.is_built(): + binary_path = PACKAGE_DIR / f"libbitsandbytes_mps{DYNAMIC_LIBRARY_SUFFIX}" logger.debug(f"Loading bitsandbytes native library from: {binary_path}") dll = ct.cdll.LoadLibrary(str(binary_path)) if hasattr(dll, "get_context"): # only a CUDA-built library exposes this return CudaBNBNativeLibrary(dll) + + if "_mps" in str(binary_path): + logger.warning("The installed version of bitsandbytes was compiled with alpha MPS support. " + "This version may become unstable unexpectedly.") + return BNBNativeLibrary(dll) logger.warning( "The installed version of bitsandbytes was compiled without GPU support. "