Update diagnostic functions for ROCm
pnunna93 committed Aug 24, 2024
1 parent 52ba52e commit 755dfbe
Showing 2 changed files with 98 additions and 45 deletions.
120 changes: 83 additions & 37 deletions bitsandbytes/diagnostics/cuda.py
@@ -6,6 +6,7 @@
 import torch
 
 from bitsandbytes.cextension import get_cuda_bnb_library_path
+from bitsandbytes.cextension import HIP_ENVIRONMENT, BNB_BACKEND
 from bitsandbytes.consts import NONPYTORCH_DOC_URL
 from bitsandbytes.cuda_specs import CUDASpecs
 from bitsandbytes.diagnostics.utils import print_dedented
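
HIP_ENVIRONMENT and BNB_BACKEND are provided by bitsandbytes.cextension, which this diff does not show. A minimal sketch of what the diagnostics code assumes about them, based on PyTorch's version flags (illustrative, not the actual cextension code):

    import torch

    # torch.version.hip is None on CUDA builds of PyTorch and a version
    # string such as "6.1.40091-..." on ROCm builds.
    HIP_ENVIRONMENT = torch.version.hip is not None
    BNB_BACKEND = "ROCm" if HIP_ENVIRONMENT else "CUDA"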
@@ -38,6 +39,9 @@
"nvcuda*.dll", # Windows
)

if HIP_ENVIRONMENT:
CUDA_RUNTIME_LIB_PATTERNS = ("libamdhip64.so*")

logger = logging.getLogger(__name__)
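
The patterns above feed find_cudart_libraries (see the next hunk). A hedged sketch of the matching idea, using fnmatch-style globbing over directories named in common search-path variables; find_runtime_libs and the variable list are illustrative, not the actual bitsandbytes implementation:

    import os
    from fnmatch import fnmatch
    from pathlib import Path
    from typing import Iterator

    def find_runtime_libs(patterns=("libamdhip64.so*",)) -> Iterator[Path]:
        # Walk each directory listed in the library search paths and
        # yield files whose names match any of the glob patterns.
        for env_var in ("LD_LIBRARY_PATH", "PATH"):
            for dir_str in os.environ.get(env_var, "").split(os.pathsep):
                dir_path = Path(dir_str)
                if not dir_path.is_dir():
                    continue
                for entry in dir_path.iterdir():
                    if any(fnmatch(entry.name, pat) for pat in patterns):
                        yield entry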


@@ -105,37 +109,63 @@ def find_cudart_libraries() -> Iterator[Path]:


 def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
-    print(
-        f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, "
-        f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.",
-    )
+    if not HIP_ENVIRONMENT:
+        print(
+            f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, "
+            f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.",
+        )
+    else:
+        print(
+            f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}"
+        )


     binary_path = get_cuda_bnb_library_path(cuda_specs)
     if not binary_path.exists():
-        print_dedented(
-            f"""
-            Library not found: {binary_path}. Maybe you need to compile it from source?
-            If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`,
-            for example, `make CUDA_VERSION=113`.
-
-            The CUDA version for the compile might depend on your conda install, if using conda.
-            Inspect CUDA version via `conda list | grep cuda`.
-            """,
-        )
+        if not HIP_ENVIRONMENT:
+            print_dedented(
+                f"""
+                Library not found: {binary_path}. Maybe you need to compile it from source?
+                If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`,
+                for example, `make CUDA_VERSION=113`.
+
+                The CUDA version for the compile might depend on your conda install, if using conda.
+                Inspect CUDA version via `conda list | grep cuda`.
+                """,
+            )
+        else:
+            print_dedented(
+                f"""
+                Library not found: {binary_path}.
+                Maybe you need to compile it from source? If you compiled from source, check that ROCM_VERSION
+                in the PyTorch settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version
+                and rebuild bitsandbytes.
+                """,
+            )

     cuda_major, cuda_minor = cuda_specs.cuda_version_tuple
-    if cuda_major < 11:
-        print_dedented(
-            """
-            WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8().
-            You will only be able to use 8-bit optimizers and quantization routines!
-            """,
-        )
-
-    print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
+    if not HIP_ENVIRONMENT:
+        if cuda_major < 11:
+            print_dedented(
+                """
+                WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8().
+                You will only be able to use 8-bit optimizers and quantization routines!
+                """,
+            )
+    else:
+        if (cuda_major, cuda_minor) < (6, 1):
+            print_dedented(
+                """
+                WARNING: bitsandbytes is fully supported only from ROCm 6.1.
+                """,
+            )
+
+    print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")

     # 7.5 is the minimum CC for cublaslt
-    if not cuda_specs.has_cublaslt:
+    if not cuda_specs.has_cublaslt and not HIP_ENVIRONMENT:
         print_dedented(
             """
             WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!
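
On ROCm, cuda_version_tuple carries the HIP version, so the (6, 1) check above is a plain tuple comparison. A hedged sketch of how such a tuple can be derived from torch.version.hip; parse_hip_version is a hypothetical helper, not bitsandbytes API:

    def parse_hip_version(version: str) -> tuple[int, int]:
        # "6.1.40091-a8dbc0c19" -> (6, 1)
        major, minor = version.split(".")[:2]
        return int(major), int(minor)

    assert parse_hip_version("6.1.40091-a8dbc0c19") == (6, 1)
    assert parse_hip_version("5.7.0") < (6, 1)  # older than ROCm 6.1: warning case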
@@ -152,25 +182,41 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
 def print_cuda_runtime_diagnostics() -> None:
     cudart_paths = list(find_cudart_libraries())
     if not cudart_paths:
-        print("CUDA SETUP: WARNING! CUDA runtime files not found in any environment path.")
+        print(f"{BNB_BACKEND} SETUP: WARNING! {BNB_BACKEND} runtime files not found in any environment path.")
     elif len(cudart_paths) > 1:
+        backend_version = torch.version.cuda if not HIP_ENVIRONMENT else torch.version.hip
         print_dedented(
             f"""
-            Found duplicate CUDA runtime files (see below).
-            We select the PyTorch default CUDA runtime, which is {torch.version.cuda},
-            but this might mismatch with the CUDA version that is needed for bitsandbytes.
-            To override this behavior set the `BNB_CUDA_VERSION=<version string, e.g. 122>` environment variable.
-            For example, if you want to use the CUDA version 122,
-                BNB_CUDA_VERSION=122 python ...
-            OR set the environment variable in your .bashrc:
-                export BNB_CUDA_VERSION=122
-            In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
-            export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2
+            Found duplicate {BNB_BACKEND} runtime files (see below).
+            We select the PyTorch default {BNB_BACKEND} runtime, which is {backend_version},
+            but this might mismatch with the {BNB_BACKEND} version that is needed for bitsandbytes.
             """,
         )
+        if not HIP_ENVIRONMENT:
+            print_dedented(
+                f"""
+                To override this behavior set the `BNB_CUDA_VERSION=<version string, e.g. 122>` environment variable.
+                For example, if you want to use the CUDA version 122,
+                    BNB_CUDA_VERSION=122 python ...
+                OR set the environment variable in your .bashrc:
+                    export BNB_CUDA_VERSION=122
+                In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
+                export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2
+                """,
+            )
+        else:
+            print_dedented(
+                f"""
+                To resolve it, install PyTorch built for the ROCm version you want to use
+                and set LD_LIBRARY_PATH to your ROCm install path, e.g.
+                    export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/rocm-6.1.2
+                """,
+            )
 
     for pth in cudart_paths:
-        print(f"* Found CUDA runtime at: {pth}")
+        print(f"* Found {BNB_BACKEND} runtime at: {pth}")
23 changes: 15 additions & 8 deletions bitsandbytes/diagnostics/main.py
@@ -11,17 +11,19 @@
 )
 from bitsandbytes.diagnostics.utils import print_dedented, print_header
 
+from bitsandbytes.cextension import HIP_ENVIRONMENT, BNB_BACKEND
+
 
 def sanity_check():
     from bitsandbytes.cextension import lib
 
     if lib is None:
+        compute_backend = "cuda" if not HIP_ENVIRONMENT else "hip"
         print_dedented(
-            """
+            f"""
             Couldn't load the bitsandbytes library, likely due to missing binaries.
             Please ensure bitsandbytes is properly installed.
-            For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND=cuda -S .`.
+            For source installations, compile the binaries with `cmake -DCOMPUTE_BACKEND={compute_backend} -S .`.
             See the documentation for more details if needed.
             Trying a simple check anyway, but this will likely fail...
@@ -49,19 +51,24 @@ def main():

print_header("OTHER")
cuda_specs = get_cuda_specs()
print("CUDA specs:", cuda_specs)
if HIP_ENVIRONMENT:
rocm_specs = f" rocm_version_string=\'{cuda_specs.cuda_version_string}\',"
rocm_specs+= f" rocm_version_tuple={cuda_specs.cuda_version_tuple}"
print(f"{BNB_BACKEND} specs:{rocm_specs}")
else:
print(f"{BNB_BACKEND} specs:{cuda_specs}")
if not torch.cuda.is_available():
print("Torch says CUDA is not available. Possible reasons:")
print("1. CUDA driver not installed")
print("2. CUDA not installed")
print("3. You have multiple conflicting CUDA libraries")
print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:")
print(f"1. {BNB_BACKEND} driver not installed")
print(f"2. {BNB_BACKEND} not installed")
print(f"3. You have multiple conflicting {BNB_BACKEND} libraries")
if cuda_specs:
print_cuda_diagnostics(cuda_specs)
print_cuda_runtime_diagnostics()
print_header("")
print_header("DEBUG INFO END")
print_header("")
print("Checking that the library is importable and CUDA is callable...")
print(f"Checking that the library is importable and {BNB_BACKEND} is callable...")
try:
sanity_check()
print("SUCCESS!")
