Fixed bugs in cuda setup.
TimDettmers committed Aug 4, 2022
1 parent 758c717 commit 8f84674
Showing 4 changed files with 19 additions and 12 deletions.
7 changes: 4 additions & 3 deletions bitsandbytes/cextension.py
@@ -17,12 +17,13 @@ def initialize(self):
         binary_path = package_dir / binary_name

         if not binary_path.exists():
-            print(f"TODO: compile library for specific version: {binary_name}")
+            print(f"CUDA_SETUP: TODO: compile library for specific version: {binary_name}")
             legacy_binary_name = "libbitsandbytes.so"
-            print(f"Defaulting to {legacy_binary_name}...")
+            print(f"CUDA_SETUP: Defaulting to {legacy_binary_name}...")
             self.lib = ct.cdll.LoadLibrary(package_dir / legacy_binary_name)
         else:
-            self.lib = ct.cdll.LoadLibrary(package_dir / binary_name)
+            print(f"CUDA_SETUP: Loading binary {binary_path}...")
+            self.lib = ct.cdll.LoadLibrary(binary_path)

     @classmethod
     def get_instance(cls):
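For context, here is a self-contained sketch of the load path as it reads after this change. The control flow is taken from the hunk above; the standalone function wrapper and the str() conversions are assumptions added so the snippet runs outside the class:

import ctypes as ct
from pathlib import Path

def load_bnb_library(package_dir: Path, binary_name: str):
    # Mirrors the fixed initialize(): try the version-specific binary,
    # fall back to the legacy name, and prefix every message with CUDA_SETUP.
    binary_path = package_dir / binary_name
    if not binary_path.exists():
        print(f"CUDA_SETUP: TODO: compile library for specific version: {binary_name}")
        legacy_binary_name = "libbitsandbytes.so"
        print(f"CUDA_SETUP: Defaulting to {legacy_binary_name}...")
        return ct.cdll.LoadLibrary(str(package_dir / legacy_binary_name))
    print(f"CUDA_SETUP: Loading binary {binary_path}...")
    return ct.cdll.LoadLibrary(str(binary_path))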
2 changes: 2 additions & 0 deletions bitsandbytes/cuda_setup/__init__.py
@@ -0,0 +1,2 @@
+from .paths import CUDA_RUNTIME_LIB, extract_candidate_paths, determine_cuda_runtime_lib_path
+from .main import evaluate_cuda_setup
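These re-exports make the setup helpers importable from the subpackage root. Illustrative usage, assuming this version of the package is installed:

from bitsandbytes.cuda_setup import evaluate_cuda_setup

# Returns the name of the shared library chosen for this machine,
# e.g. a CUDA-specific binary or the CPU-only fallback.
binary_name = evaluate_cuda_setup()
print(binary_name)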
14 changes: 9 additions & 5 deletions bitsandbytes/cuda_setup/main.py
@@ -47,6 +47,7 @@ def get_compute_capabilities():
         cuda = ctypes.CDLL("libcuda.so")
     except OSError:
         # TODO: shouldn't we error or at least warn here?
+        print('ERROR: libcuda.so not found!')
         return None

     nGpus = ctypes.c_int()
@@ -70,7 +71,7 @@ def get_compute_capabilities():
             )
         ccs.append(f"{cc_major.value}.{cc_minor.value}")

-    return ccs.sort()
+    return ccs


# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
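The return ccs.sort() fix matters because in Python list.sort() sorts in place and returns None, so the old code handed None to every caller. A minimal demonstration of the semantics involved:

ccs = ["7.5", "8.0", "7.0"]
result = ccs.sort()   # sorts the list in place
print(result)         # None -- what the old code returned
print(ccs)            # ['7.0', '7.5', '8.0'] -- the fix returns the list object instead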
@@ -80,7 +81,8 @@ def get_compute_capability():
     capabilities are downwards compatible. If no GPUs are detected, it returns
     None.
     """
-    if ccs := get_compute_capabilities() is not None:
+    ccs = get_compute_capabilities()
+    if ccs is not None:
         # TODO: handle different compute capabilities; for now, take the max
         return ccs[-1]
     return None
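The removed walrus line was a precedence bug: ccs := get_compute_capabilities() is not None parses as ccs := (get_compute_capabilities() is not None), so ccs was bound to a bool and the later ccs[-1] raised a TypeError. A small demonstration with a stand-in function:

def get_caps():            # stand-in for get_compute_capabilities()
    return ["7.5", "8.0"]

if ccs := get_caps() is not None:
    print(type(ccs))       # <class 'bool'> -- the list is lost

# The fixed form assigns first, then tests:
ccs = get_caps()
if ccs is not None:
    print(ccs[-1])         # '8.0'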
@@ -92,8 +94,7 @@ def evaluate_cuda_setup():
     cc = get_compute_capability()
     binary_name = "libbitsandbytes_cpu.so"

-    # FIXME: has_gpu is still unused
-    if not (has_gpu := bool(cc)):
+    if cc == '':
         print(
             "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
         )
@@ -115,13 +116,16 @@ def evaluate_cuda_setup():
         ls_output.split(" ")[-1].replace("libcudart.so.", "").split(".")
     )
     cuda_version_string = f"{major}{minor}"
+    print(f'CUDA_SETUP: Detected CUDA version {cuda_version_string}')

     def get_binary_name():
         "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
         bin_base_name = "libbitsandbytes_cuda"
         if has_cublaslt:
             return f"{bin_base_name}{cuda_version_string}.so"
         else:
-            return f"{bin_base_name}_nocublaslt.so"
+            return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"

+    binary_name = get_binary_name()

+    return binary_name
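The _nocublaslt fix restores the CUDA version that the fallback name previously dropped. For illustration, assuming a detected CUDA 11.4 (so cuda_version_string is "114"), the two branches now produce:

# Hypothetical values for illustration only.
cuda_version_string = "114"
bin_base_name = "libbitsandbytes_cuda"

for has_cublaslt in (True, False):
    if has_cublaslt:
        name = f"{bin_base_name}{cuda_version_string}.so"
    else:  # compute capability < 7.5: no cuBLASLt support
        name = f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
    print(name)
# libbitsandbytes_cuda114.so
# libbitsandbytes_cuda114_nocublaslt.so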
8 changes: 4 additions & 4 deletions tests/test_autograd.py
@@ -351,9 +351,9 @@ def test_matmullt(
         err = torch.abs(out_bnb - out_torch).mean().item()
         # print(f'abs error {err:.4f}')
         idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-        assert (idx == 0).sum().item() < n * 0.0175
+        assert (idx == 0).sum().item() <= n * 0.0175
         idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
-        assert (idx == 0).sum().item() < n * 0.001
+        assert (idx == 0).sum().item() <= n * 0.001

         if has_fp16_weights:
             if any(req_grad):
@@ -391,9 +391,9 @@ def test_matmullt(
                 assert torch.abs(gradB2).sum() == 0.0
             idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)

-            assert (idx == 0).sum().item() < n * 0.1
+            assert (idx == 0).sum().item() <= n * 0.1
             idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
-            assert (idx == 0).sum().item() < n * 0.02
+            assert (idx == 0).sum().item() <= n * 0.02
             torch.testing.assert_allclose(
                 gradB1, gradB2, atol=0.18, rtol=0.3
             )
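Relaxing < to <= in these assertions admits the exact-boundary case, where the mismatch count equals the allowed fraction of n. A toy illustration of the difference:

n = 1000
mismatches = 1                   # exactly n * 0.001 mismatched elements
print(mismatches < n * 0.001)    # False -- the old strict bound rejects this
print(mismatches <= n * 0.001)   # True  -- the relaxed bound accepts it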
