-
Notifications
You must be signed in to change notification settings - Fork 671
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding CLI tool for CUDA install debugging - intermediate commit
- Loading branch information
1 parent
bd51532
commit 5d90b38
Showing
7 changed files
with
203 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from bitsandbytes.debug_cli import cli

# Launch the debug CLI only when this module is executed as a program
# (e.g. via ``python -m``), so that merely importing it has no side effects.
if __name__ == "__main__":
    cli()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
""" | ||
build is dependent on | ||
- compute capability | ||
- dependent on GPU family | ||
- CUDA version | ||
- Software: | ||
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiplication) | ||
- CuBLAS-LT: full-build 8-bit optimizer | ||
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) | ||
package all binaries | ||
evaluation: | ||
- if paths faulty, return meaningful error | ||
- else: | ||
- determine CUDA version | ||
- determine capabilities | ||
- based on that set the default path | ||
""" | ||
|
||
from os import environ as env | ||
from pathlib import Path | ||
from typing import Set, Union | ||
from .utils import warn_of_missing_prerequisite, print_err | ||
|
||
|
||
# File name of the CUDA runtime shared library that is looked up inside each
# candidate library directory.
CUDA_RUNTIME_LIB: str = "libcudart.so"
|
||
def tokenize_paths(paths: str) -> Set[Path]:
    """Split a colon-separated list of paths into a set of ``Path`` objects.

    Empty entries (caused by leading, trailing or doubled colons) are dropped,
    and duplicate entries collapse into one because the result is a set.
    """
    non_empty_entries = filter(None, paths.split(':'))
    return set(map(Path, non_empty_entries))
|
||
def get_cuda_runtime_lib_path(
    # TODO: replace this with logic for all paths in env vars
    LD_LIBRARY_PATH: Union[str, None] = None
) -> Union[Path, None]:
    """Locate the single CUDA runtime library reachable via LD_LIBRARY_PATH.

    Args:
        LD_LIBRARY_PATH: Colon-separated list of directories to search. When
            ``None`` (the default) the current value of the ``LD_LIBRARY_PATH``
            environment variable is used, read at call time.

    Returns:
        The path of the one ``CUDA_RUNTIME_LIB`` file that was found, or
        ``None`` when there is no search path at all (in that case a warning
        is written to stderr).

    Raises:
        FileNotFoundError: If none of the directories contains
            ``CUDA_RUNTIME_LIB``, or if more than one of them does.
    """
    if LD_LIBRARY_PATH is None:
        # Resolve lazily: a default of ``env.get(...)`` in the signature is
        # evaluated once at import time and misses later environment changes.
        LD_LIBRARY_PATH = env.get("LD_LIBRARY_PATH")

    if not LD_LIBRARY_PATH:
        warn_of_missing_prerequisite(
            'LD_LIBRARY_PATH is completely missing from environment!'
        )
        return None

    ld_library_paths: Set[Path] = tokenize_paths(LD_LIBRARY_PATH)

    non_existent_directories: Set[Path] = {
        path for path in ld_library_paths
        if not path.exists()
    }

    # Missing directories are only reported; the search continues with the
    # remaining entries.
    if non_existent_directories:
        print_err(
            "WARNING: The following directories listed your path were found to "
            f"be non-existent: {non_existent_directories}"
        )

    # ``is_file()`` is False for anything below a non-existent directory, so
    # no extra filtering against ``non_existent_directories`` is needed.
    cuda_runtime_libs: Set[Path] = {
        path / CUDA_RUNTIME_LIB for path in ld_library_paths
        if (path / CUDA_RUNTIME_LIB).is_file()
    }

    if len(cuda_runtime_libs) > 1:
        # FileNotFoundError is kept for the ambiguous case as well so that
        # existing callers catching it keep working.
        err_msg = f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
        raise FileNotFoundError(err_msg)

    if not cuda_runtime_libs:
        # Report the directories that were searched; the set of found
        # libraries is necessarily empty here and carries no information.
        err_msg = (
            f"Did not find {CUDA_RUNTIME_LIB} in any of the searched "
            f"directories: {ld_library_paths}.."
        )
        raise FileNotFoundError(err_msg)

    # Exactly one candidate is left: return the library path itself, not the
    # whole set of search directories.
    return next(iter(cuda_runtime_libs))
|
||
def evaluate_cuda_setup():
    """Placeholder for the CUDA setup evaluation; currently does nothing.

    Planned steps (not implemented yet):
      - if paths are faulty, return a meaningful error
      - otherwise:
          - determine the CUDA version
          - determine the capabilities
          - based on that, set the default path
    """
    return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import typer | ||
|
||
|
||
# Typer application object; the callback and commands defined in this module
# are registered on it through its decorators.
cli = typer.Typer()
|
||
|
||
# Registered as the application-level callback of ``cli``; the body consists
# of the docstring only, so the function itself performs no work.
# NOTE(review): Typer presumably shows this docstring as CLI help text, and the
# wording looks like placeholder text unrelated to CUDA debugging -- confirm.
@cli.callback()
def callback():
    """
    Awesome Portal Gun
    """
|
||
|
||
# Registered as a command of ``cli``; it only echoes a fixed message.
# NOTE(review): the docstring is presumably used by Typer as the command's help
# text; the wording looks like placeholder text -- confirm before release.
@cli.command()
def shoot():
    """
    Shoot the portal gun
    """
    typer.echo("Shooting portal gun")
|
||
|
||
# Registered as a command of ``cli``; it only echoes a fixed message.
# NOTE(review): the docstring is presumably used by Typer as the command's help
# text; the wording looks like placeholder text -- confirm before release.
@cli.command()
def load():
    """
    Load the portal gun
    """
    typer.echo("Loading portal gun")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import sys | ||
|
||
def print_err(s: str) -> None:
    """Print ``s``, followed by a newline, to standard error."""
    # ``sys.stderr`` is looked up on every call so that a redirected stream
    # is honoured.
    error_stream = sys.stderr
    print(s, file=error_stream)
|
||
def warn_of_missing_prerequisite(s: str) -> None:
    """Report a missing prerequisite, described by ``s``, via ``print_err``."""
    prefix = 'WARNING, missing pre-requisite: '
    message = prefix + s
    print_err(message)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
name: 8-bit | ||
channels: | ||
- conda-forge | ||
dependencies: | ||
- python=3.9 | ||
- pytest | ||
- pytorch | ||
- torchaudio | ||
- torchvision | ||
- cudatoolkit=11.1 | ||
- typer | ||
- ca-certificates | ||
- certifi | ||
- openssl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import pytest | ||
|
||
from typing import List | ||
|
||
from bitsandbytes.cuda_setup import ( | ||
CUDA_RUNTIME_LIB, | ||
get_cuda_runtime_lib_path, | ||
evaluate_cuda_setup, | ||
tokenize_paths, | ||
) | ||
|
||
|
||
# Each case is ``(LD_LIBRARY_PATH-style input, directory expected to hold the
# runtime library)``. All directories are relative and are created below
# pytest's ``tmp_path`` by the test; empty entries exercise stray colons.
HAPPY_PATH__LD_LIB_TEST_PATHS: List[tuple[str,str]] = [
    (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
    (f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
    (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"),
    (f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
    (f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"),
]


@pytest.mark.parametrize(
    "test_input, expected",
    HAPPY_PATH__LD_LIB_TEST_PATHS
)
def test_get_cuda_runtime_lib_path__happy_path(
    tmp_path, test_input: str, expected: str
):
    """Exactly one listed directory holds the library, so its path is returned."""
    ld_entries = []
    for entry in test_input.split(':'):
        if not entry:
            # Keep empty entries so leading/trailing/doubled colons survive.
            ld_entries.append(entry)
            continue
        directory = tmp_path / entry
        directory.mkdir(parents=True, exist_ok=True)
        ld_entries.append(str(directory))

    # Only the expected directory receives the runtime library file.
    expected_lib = tmp_path / expected / CUDA_RUNTIME_LIB
    expected_lib.touch()

    assert get_cuda_runtime_lib_path(':'.join(ld_entries)) == expected_lib
|
||
|
||
# LD_LIBRARY_PATH-style inputs in which every listed directory is given its
# own copy of the runtime library, i.e. the lookup is ambiguous.
UNHAPPY_PATH__LD_LIB_TEST_PATHS = [
    f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}",
    f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}:g/h/j/{CUDA_RUNTIME_LIB}",
]


@pytest.mark.parametrize("test_input", UNHAPPY_PATH__LD_LIB_TEST_PATHS)
def test_get_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str):
    """Several directories holding the library must be reported as duplicates."""
    lib_dirs = [tmp_path / entry for entry in test_input.split(':')]
    for lib_dir in lib_dirs:
        lib_dir.mkdir(parents=True)
        (lib_dir / CUDA_RUNTIME_LIB).touch()
    ld_library_path = ':'.join(str(lib_dir) for lib_dir in lib_dirs)

    with pytest.raises(FileNotFoundError) as err_info:
        get_cuda_runtime_lib_path(ld_library_path)

    # Inspect the exception message outside the ``with`` block: statements
    # placed after the raising call inside it would never execute.
    err_msg = str(err_info.value)
    assert all(
        match in err_msg
        for match in {"duplicate", CUDA_RUNTIME_LIB}
    )
|
||
|
||
def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
    """A listed directory that does not exist triggers a warning on stderr."""
    existent_dir = tmp_path / 'a/b'
    existent_dir.mkdir(parents=True)  # nested path: the parent 'a' must be created too
    # Provide exactly one runtime library so that the lookup itself succeeds
    # and only the warning about the missing directory is under test.
    (existent_dir / CUDA_RUNTIME_LIB).touch()
    non_existent_dir = tmp_path / 'c/d'  # non-existent dir
    test_input = ":".join([str(existent_dir), str(non_existent_dir)])

    get_cuda_runtime_lib_path(test_input)
    std_err = capsys.readouterr().err

    assert all(
        match in std_err
        for match in {"WARNING", "non-existent"}
    )