From 510a8808542064b16ab06ef2a7e973c91ed3c9dd Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Mon, 14 Oct 2024 17:20:08 -0400
Subject: [PATCH] Int8 refactoring: remove separate NO_CUBLASLT build; more
 cleanup

---
 .github/scripts/build-cuda.sh       | 30 ++++++++++++++---------------
 CMakeLists.txt                      | 14 +-------------
 bitsandbytes/autograd/_functions.py | 30 +++++++++++++++--------------
 bitsandbytes/cextension.py          |  6 +-----
 bitsandbytes/cuda_specs.py          |  2 +-
 bitsandbytes/diagnostics/cuda.py    |  4 ++--
 csrc/ops.cu                         | 15 ---------------
 tests/conftest.py                   |  4 ----
 tests/test_cuda_setup_evaluator.py  | 20 -------------------
 tests/test_linear8bitlt.py          |  8 +++++---
 tests/test_modules.py               | 14 +++++++-------
 11 files changed, 48 insertions(+), 99 deletions(-)

diff --git a/.github/scripts/build-cuda.sh b/.github/scripts/build-cuda.sh
index 0f9b8d726..26a7075b0 100644
--- a/.github/scripts/build-cuda.sh
+++ b/.github/scripts/build-cuda.sh
@@ -8,21 +8,21 @@ build_capability="50;52;60;61;70;75;80;86;89;90"
 [[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????}
 [[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???}
 [[ "${build_os}" = windows-* ]] && python3 -m pip install ninja
-for NO_CUBLASLT in ON OFF; do
-    if [ "${build_os:0:6}" == ubuntu ]; then
-        image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
-        echo "Using image $image"
-        docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
-            "apt-get update \
-            && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
-            && cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" -DNO_CUBLASLT=${NO_CUBLASLT} . \
-            && cmake --build ."
-    else
-        pip install cmake==3.28.3
-        cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S .
-        cmake --build . --config Release
-    fi
-done
+
+if [ "${build_os:0:6}" == ubuntu ]; then
+    image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
+    echo "Using image $image"
+    docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
+        "apt-get update \
+        && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
+        && cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \
+        && cmake --build ."
+else
+    pip install cmake==3.28.3
+    cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DCMAKE_BUILD_TYPE=Release -S .
+    cmake --build . --config Release
+fi
+
 output_dir="output/${build_os}/${build_arch}"
 mkdir -p "${output_dir}"
 
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d305e5a3e..ce3962ff7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -4,7 +4,6 @@
 #   For MSVC: `cmake -B build . && cmake --build build --config Release`
 # You can also use the following options and variables
 #  - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
-#  - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
 #  - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
 #    is whatever CMake finds on your path.
 #  - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC.
@@ -47,10 +46,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
     if(APPLE)
         message(FATAL_ERROR "CUDA is not supported on macOS" )
     endif()
-    option(NO_CUBLASLT "Disable CUBLAS" OFF)
     set(BUILD_CUDA ON)
     set(BUILD_MPS OFF)
-    message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
 elseif(${COMPUTE_BACKEND} STREQUAL "mps")
     if(NOT APPLE)
         message(FATAL_ERROR "MPS is only supported on macOS" )
@@ -166,9 +163,6 @@ if(BUILD_CUDA)
     list(APPEND SRC_FILES ${CUDA_FILES})
 
     string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
-    if(NO_CUBLASLT)
-        string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
-    endif()
     add_compile_definitions(BUILD_CUDA)
 elseif(BUILD_MPS)
     if(NOT APPLE)
@@ -212,13 +206,7 @@ target_include_directories(bitsandbytes PUBLIC csrc include)
 
 if(BUILD_CUDA)
     target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
-    target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse)
-    if(NO_CUBLASLT)
-        target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT)
-    else()
-        target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt)
-    endif()
-
+    target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse)
     set_target_properties(bitsandbytes
         PROPERTIES
             CUDA_SEPARABLE_COMPILATION ON
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index bc7a51113..03e3add4a 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -283,9 +283,9 @@ def forward(
         B: torch.Tensor,
         out=None,
         bias: Optional[torch.Tensor] = None,
-        state=MatmulLtState,
+        state: MatmulLtState = None,
     ):
-        # state = state or MatmulLtState()
+        state = state or MatmulLtState()
 
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
@@ -318,7 +318,7 @@ def forward(
             if is_transposed:
                 B = B.contiguous()
 
-            if (state.is_training and not has_grad) or state.CB is None:
+            if (state.is_training and not has_grad) or state.SCB is None:
                 state.reset_grads()
 
                 # 2. Quantize B
@@ -347,7 +347,7 @@ def forward(
            outliers = state.CB[:, state.idx].clone()
            state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
        else:
-            subA = state.subB = None
+            subA = None
 
        # 3. Int8 Matmul
        out32, Sout32 = F.igemmlt(CA, state.CB)
@@ -377,7 +377,11 @@ def forward(
             ctx.save_for_backward(None, None)
 
         output_shape = (*input_shape[:-1], state.CB.shape[0])
-        return output.reshape(output_shape).clone()
+
+        if len(input_shape) == 3:
+            return output.view(output_shape).clone()
+        else:
+            return output
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -400,18 +404,16 @@ def backward(ctx, grad_output):
             grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
 
         Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
-        if req_gradB:
-            # grad_output.T @ A
-            # grad_weight = grad_output.t().mm(A)
-            grad_B = torch.matmul(grad_output.t(), A)
-            if state.threshold > 0.0 and subA is not None:
-                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
         # if req_gradB:
-        #
-        #     gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t())
-        #     grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
+
+        #     grad_B = torch.matmul(grad_output.t(), A)
         #     if state.threshold > 0.0 and subA is not None:
         #         grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
+        if req_gradB:
+            gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t())
+            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
+            if state.threshold > 0.0 and subA is not None:
+                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
 
         if req_gradA:
             # grad_output @ B.T
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index b7522334c..5bed7fba4 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -37,11 +37,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
 
     The library is not guaranteed to exist at the returned path.
     """
-    library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}"
-    if not cuda_specs.has_cublaslt:
-        # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
-        library_name += "_nocublaslt"
-    library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}"
+    library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
 
     override_value = os.environ.get("BNB_CUDA_VERSION")
     if override_value:
diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py
index ed19795a0..e72d57590 100644
--- a/bitsandbytes/cuda_specs.py
+++ b/bitsandbytes/cuda_specs.py
@@ -11,7 +11,7 @@ class CUDASpecs:
     cuda_version_tuple: Tuple[int, int]
 
     @property
-    def has_cublaslt(self) -> bool:
+    def has_imma(self) -> bool:
         return self.highest_compute_capability >= (7, 5)
 
 
diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py
index 8974c6400..45dc98dea 100644
--- a/bitsandbytes/diagnostics/cuda.py
+++ b/bitsandbytes/diagnostics/cuda.py
@@ -134,8 +134,8 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
 
         print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
 
-    # 7.5 is the minimum CC for cublaslt
-    if not cuda_specs.has_cublaslt:
+    # 7.5 is the minimum CC for int8 tensor cores
+    if not cuda_specs.has_imma:
         print_dedented(
             """
             WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 089a30cc1..e2eddc7ab 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -314,8 +314,6 @@ int roundoff(int v, int d) {
 }
 
 
-#ifdef NO_CUBLASLT
-#else
 template <int ORDER> cublasLtOrder_t get_order()
 {
     switch(ORDER)
@@ -347,7 +345,6 @@ template cublasLtOrder_t get_order<COL>();
 template cublasLtOrder_t get_order<COL32>();
 template cublasLtOrder_t get_order<COL_TURING>();
 template cublasLtOrder_t get_order<COL_AMPERE>();
-#endif
 
 template<int ORDER> int get_leading_dim(int dim1, int dim2)
 {
@@ -379,8 +376,6 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
 
 template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
 {
-#ifdef NO_CUBLASLT
-#else
     cublasLtOrder_t orderA = get_order<SRC>();
     cublasLtOrder_t orderOut = get_order<TARGET>();
     int ldA = get_leading_dim<SRC>(dim1, dim2);
@@ -419,7 +414,6 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
     if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
     if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
     if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
-#endif
 }
 
 template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
@@ -513,9 +507,6 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
 
 template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
 {
-#ifdef NO_CUBLASLT
-    return ERR_NOT_IMPLEMENTED;
-#else
     int has_error = 0;
     cublasLtMatmulDesc_t matmulDesc = NULL;
     cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
@@ -570,7 +561,6 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
       printf("error detected");
 
     return has_error;
-#endif // NO_CUBLASLT
 }
 
 int fill_up_to_nearest_multiple(int value, int multiple)
@@ -681,10 +671,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
 
 void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
 {
-
-#ifdef NO_CUBLASLT
-#else
-
     cusparseSpMatDescr_t descA;
     cusparseDnMatDescr_t descB, descC;
 
@@ -731,7 +717,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val
     CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
     CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
     CUDA_CHECK_RETURN( cudaFree(dBuffer) );
-#endif
 }
 
 template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
diff --git a/tests/conftest.py b/tests/conftest.py
index 59146963d..c029c3cb5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,10 +7,6 @@
 def pytest_runtest_call(item):
     try:
         item.runtest()
-    except NotImplementedError as nie:
-        if "NO_CUBLASLT" in str(nie):
-            pytest.skip("CUBLASLT not available")
-        raise
     except AssertionError as ae:
         if str(ae) == "Torch not compiled with CUDA enabled":
             pytest.skip("Torch not compiled with CUDA enabled")
diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py
index b13f8b6c6..79406472e 100644
--- a/tests/test_cuda_setup_evaluator.py
+++ b/tests/test_cuda_setup_evaluator.py
@@ -13,15 +13,6 @@ def cuda120_spec() -> CUDASpecs:
     )
 
 
-@pytest.fixture
-def cuda111_noblas_spec() -> CUDASpecs:
-    return CUDASpecs(
-        cuda_version_string="111",
-        highest_compute_capability=(7, 2),
-        cuda_version_tuple=(11, 1),
-    )
-
-
 def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
     monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
     assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"
@@ -31,14 +22,3 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
     monkeypatch.setenv("BNB_CUDA_VERSION", "110")
     assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
     assert "BNB_CUDA_VERSION" in caplog.text  # did we get the warning?
-
-
-def test_get_cuda_bnb_library_path_override_nocublaslt(monkeypatch, cuda111_noblas_spec, caplog):
-    monkeypatch.setenv("BNB_CUDA_VERSION", "125")
-    assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda125_nocublaslt"
-    assert "BNB_CUDA_VERSION" in caplog.text  # did we get the warning?
-
-
-def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
-    monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
-    assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 149d9a93c..48c3a9ea8 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -69,11 +69,13 @@ def test_linear_no_igemmlt():
 
     fx_ours = linear_custom(x_ours).float()
     (fx_ours * grad_proj).mean().backward()
+
+    assert linear_custom.state.CB is not None
+    assert not linear_custom.state.has_fp16_weights
     assert torch.allclose(fx_ref, fx_ours, atol=0.02)
     assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
-    assert not linear_custom.state.has_fp16_weights
-    assert linear_custom.state.CB is not None
-    assert linear_custom.state.CxB is None
+
+    # assert linear_custom.state.CxB is None
 
 
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 1f1b17584..c84ffa42a 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -529,16 +529,16 @@ def test_linear_kbit_fp32_bias(module):
 @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
 def test_kbit_backprop(module):
     b = 16
-    dim1 = 32
-    dim2 = 48
+    dim1 = 36
+    dim2 = 56
     # dim1 = 37
     # dim2 = 83
-    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 16)])
-    ref[1].weight.requires_grad = False
+    ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 128)])
+    # ref[1].weight.requires_grad = False
     torch.nn.init.kaiming_normal_(ref[0].weight)
     torch.nn.init.kaiming_normal_(ref[1].weight)
-    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 16)])
+    kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
     kbit[0].weight.detach().copy_(ref[0].weight)
     kbit[1].weight.detach().copy_(ref[1].weight)
     kbit[0].bias.detach().copy_(ref[0].bias)
@@ -572,8 +572,8 @@ def test_kbit_backprop(module):
         relerrs1.append(relerr1.mean().item())
         relerrs2.append(relerr2.mean().item())
 
-        # if isinstance(module, bnb.nn.Linear8bitLt):
-        if module == bnb.nn.Linear8bitLt:
+        if isinstance(module, bnb.nn.Linear8bitLt):
+            # if module == bnb.nn.Linear8bitLt:
             assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
             torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
         else: