diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index 8974c6400..6679c49cc 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -6,6 +6,7 @@ import torch from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.cextension import HIP_ENVIRONMENT, BNB_BACKEND from bitsandbytes.consts import NONPYTORCH_DOC_URL from bitsandbytes.cuda_specs import CUDASpecs from bitsandbytes.diagnostics.utils import print_dedented @@ -38,6 +39,9 @@ "nvcuda*.dll", # Windows ) +if HIP_ENVIRONMENT: + CUDA_RUNTIME_LIB_PATTERNS = ("libamdhip64.so*") + logger = logging.getLogger(__name__) @@ -105,37 +109,63 @@ def find_cudart_libraries() -> Iterator[Path]: def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: - print( - f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, " - f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.", - ) + if not HIP_ENVIRONMENT: + print( + f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, " + f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.", + ) + else: + print( + f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}" + ) + binary_path = get_cuda_bnb_library_path(cuda_specs) if not binary_path.exists(): - print_dedented( - f""" - Library not found: {binary_path}. Maybe you need to compile it from source? - If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`, - for example, `make CUDA_VERSION=113`. + if not HIP_ENVIRONMENT: + print_dedented( + f""" + Library not found: {binary_path}. Maybe you need to compile it from source? + If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`, + for example, `make CUDA_VERSION=113`. + + The CUDA version for the compile might depend on your conda install, if using conda. + Inspect CUDA version via `conda list | grep cuda`. + """, + ) + else: + print_dedented( + f""" + Library not found: {binary_path}. + Maybe you need to compile it from source? If you compiled from source, check that ROCM_VERSION + in PyTorch Settings matches your ROCM install. If not, reinstall PyTorch for your ROCm version + and rebuild bitsandbytes. + """, + ) - The CUDA version for the compile might depend on your conda install, if using conda. - Inspect CUDA version via `conda list | grep cuda`. - """, - ) cuda_major, cuda_minor = cuda_specs.cuda_version_tuple - if cuda_major < 11: - print_dedented( - """ - WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8(). - You will be only to use 8-bit optimizers and quantization routines! - """, - ) + if not HIP_ENVIRONMENT: + if cuda_major < 11: + print_dedented( + """ + WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8(). + You will be only to use 8-bit optimizers and quantization routines! + """, + ) + + print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") + else: + if (cuda_major, cuda_minor) < (6, 1): + print_dedented( + """ + WARNING: bitandbytes is fully supported only from ROCm 6.1. + """, + ) - print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") # 7.5 is the minimum CC for cublaslt - if not cuda_specs.has_cublaslt: + if not cuda_specs.has_cublaslt and not HIP_ENVIRONMENT: print_dedented( """ WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! @@ -152,25 +182,41 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: def print_cuda_runtime_diagnostics() -> None: cudart_paths = list(find_cudart_libraries()) if not cudart_paths: - print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.") + print(f"{BNB_BACKEND} SETUP: WARNING! {BNB_BACKEND} runtime files not found in any environmental path.") elif len(cudart_paths) > 1: + backend_version = torch.version.cuda if not HIP_ENVIRONMENT else torch.version.hip print_dedented( f""" - Found duplicate CUDA runtime files (see below). + Found duplicate {BNB_BACKEND} runtime files (see below). - We select the PyTorch default CUDA runtime, which is {torch.version.cuda}, - but this might mismatch with the CUDA version that is needed for bitsandbytes. - To override this behavior set the `BNB_CUDA_VERSION=` environmental variable. - - For example, if you want to use the CUDA version 122, - BNB_CUDA_VERSION=122 python ... - - OR set the environmental variable in your .bashrc: - export BNB_CUDA_VERSION=122 - - In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g. - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2, + We select the PyTorch default {BNB_BACKEND} runtime, which is {backend_version}, + but this might mismatch with the {BNB_BACKEND} version that is needed for bitsandbytes. """, ) + if not HIP_ENVIRONMENT: + print_dedented( + f""" + To override this behavior set the `BNB_CUDA_VERSION=` environmental variable. + + For example, if you want to use the CUDA version 122, + BNB_CUDA_VERSION=122 python ... + + OR set the environmental variable in your .bashrc: + export BNB_CUDA_VERSION=122 + + In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g. + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2, + """, + ) + else: + print_dedented( + f""" + To resolve it, install PyTorch built for the ROCm version you want to use + + and set LD_LIBRARY_PATH to your ROCm install path, e.g. + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/rocm-6.1.2, + """, + ) + for pth in cudart_paths: - print(f"* Found CUDA runtime at: {pth}") + print(f"* Found {BNB_BACKEND} runtime at: {pth}") diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 1ce096f69..ff4d2fd2a 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -11,17 +11,19 @@ ) from bitsandbytes.diagnostics.utils import print_dedented, print_header +from bitsandbytes.cextension import HIP_ENVIRONMENT, BNB_BACKEND def sanity_check(): from bitsandbytes.cextension import lib if lib is None: + compute_backend = "cuda" if not HIP_ENVIRONMENT else "hip" print_dedented( - """ + f""" Couldn't load the bitsandbytes library, likely due to missing binaries. Please ensure bitsandbytes is properly installed. - For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND=cuda -S .`. + For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND={compute_backend} -S .`. See the documentation for more details if needed. Trying a simple check anyway, but this will likely fail... @@ -49,19 +51,24 @@ def main(): print_header("OTHER") cuda_specs = get_cuda_specs() - print("CUDA specs:", cuda_specs) + if HIP_ENVIRONMENT: + rocm_specs = f" rocm_version_string=\'{cuda_specs.cuda_version_string}\'," + rocm_specs+= f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" + print(f"{BNB_BACKEND} specs:{rocm_specs}") + else: + print(f"{BNB_BACKEND} specs:{cuda_specs}") if not torch.cuda.is_available(): - print("Torch says CUDA is not available. Possible reasons:") - print("1. CUDA driver not installed") - print("2. CUDA not installed") - print("3. You have multiple conflicting CUDA libraries") + print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") + print(f"1. {BNB_BACKEND} driver not installed") + print(f"2. {BNB_BACKEND} not installed") + print(f"3. You have multiple conflicting {BNB_BACKEND} libraries") if cuda_specs: print_cuda_diagnostics(cuda_specs) print_cuda_runtime_diagnostics() print_header("") print_header("DEBUG INFO END") print_header("") - print("Checking that the library is importable and CUDA is callable...") + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") try: sanity_check() print("SUCCESS!")