diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index a25f53f46..8fa7076de 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -122,20 +122,18 @@ jobs:
          build_os=${{ matrix.os }}
          build_arch=${{ matrix.arch }}
          [[ "${{ matrix.os }}" = windows-* ]] && python3 -m pip install ninja
-          for NO_CUBLASLT in ON OFF; do
-            if [ ${build_os:0:6} == ubuntu ]; then
-              image=nvidia/cuda:${{ matrix.cuda_version }}-devel-ubuntu22.04
-              echo "Using image $image"
-              docker run --platform linux/$build_arch -i -w /src -v $PWD:/src $image sh -c \
-                "apt-get update \
-                && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
-                && cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"50;52;60;61;70;75;80;86;89;90\" -DNO_CUBLASLT=${NO_CUBLASLT} . \
-                && cmake --build ."
-            else
-              cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S .
-              cmake --build . --config Release
-            fi
-          done
+          if [ ${build_os:0:6} == ubuntu ]; then
+            image=nvidia/cuda:${{ matrix.cuda_version }}-devel-ubuntu22.04
+            echo "Using image $image"
+            docker run --platform linux/$build_arch -i -w /src -v $PWD:/src $image sh -c \
+              "apt-get update \
+              && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
+              && cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"50;52;60;61;70;75;80;86;89;90\" . \
+              && cmake --build ."
+          else
+            cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCMAKE_BUILD_TYPE=Release -S .
+            cmake --build . --config Release
+          fi
          mkdir -p output/${{ matrix.os }}/${{ matrix.arch }}
          ( shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} output/${{ matrix.os }}/${{ matrix.arch }}/ )
      - name: Upload build artifact
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 62ff4e535..06898d08c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -4,7 +4,6 @@
 # For MSVC: `cmake -B build . && cmake --build build --config Release`
 # You can also use the following options and variables
 #   - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
-#   - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
 #   - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
 #                   is whatever CMake finds on your path.
 #   - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC.
@@ -39,10 +38,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
     if(APPLE)
         message(FATAL_ERROR "CUDA is not supported on macOS" )
     endif()
-    option(NO_CUBLASLT "Disable CUBLAS" OFF)
     set(BUILD_CUDA ON)
     set(BUILD_MPS OFF)
-    message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
 elseif(${COMPUTE_BACKEND} STREQUAL "mps")
     if(NOT APPLE)
         message(FATAL_ERROR "MPS is only supported on macOS" )
@@ -145,9 +142,7 @@ if(BUILD_CUDA)
     list(APPEND SRC_FILES ${CUDA_FILES})
 
     string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
-    if(NO_CUBLASLT)
-        string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
-    endif()
+
     add_compile_definitions(BUILD_CUDA)
 elseif(BUILD_MPS)
     if(NOT APPLE)
@@ -173,13 +168,11 @@ else()
     set(GPU_SOURCES)
 endif()
 
-
 if(WIN32)
     # Export all symbols
     set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
 endif()
 
-# Weird MSVC hacks
 if(MSVC)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast")
 endif()
@@ -192,12 +185,11 @@ target_include_directories(bitsandbytes PUBLIC csrc include)
 
 if(BUILD_CUDA)
     target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
-    target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse)
-    if(NO_CUBLASLT)
-        target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT)
-    else()
-        target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt)
-    endif()
+
+    # Note: as of CUDA 11.0, cublas depends on cublasLt.
+    # See: https://gitlab.kitware.com/cmake/cmake/-/merge_requests/6857/diffs
+    # It is listed here for assurance; with CMake >= 3.23.0 it is linked implicitly via CUDA::cublas.
+    target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse)
 
     set_target_properties(bitsandbytes
         PROPERTIES
@@ -220,4 +212,4 @@ if(MSVC)
     set_target_properties(bitsandbytes PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG "${PROJECT_SOURCE_DIR}/bitsandbytes")
 endif()
 
-set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY bitsandbytes)
+set_target_properties(bitsandbytes PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/bitsandbytes")
diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py
index b351f7f03..669246c9a 100644
--- a/bitsandbytes/cuda_setup/main.py
+++ b/bitsandbytes/cuda_setup/main.py
@@ -6,7 +6,6 @@
 - Software:
   - CPU-only: only CPU quantization functions (no optimizer, no matrix multiply)
   - CuBLAS-LT: full-build 8-bit optimizer
-  - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
 
 evaluation:
     - if paths faulty, return meaningful error
@@ -86,11 +85,6 @@ def generate_instructions(self):
             self.add_log_entry('CUDA SETUP: Before you try again running bitsandbytes, make sure old CUDA 10.0 versions are uninstalled and removed from $LD_LIBRARY_PATH variables.')
             return
 
-
-        has_cublaslt = is_cublasLt_compatible(self.cc)
-        if not has_cublaslt:
-            make_cmd += '_nomatmul'
-
         self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:')
         self.add_log_entry('git clone https://github.com/TimDettmers/bitsandbytes.git')
         self.add_log_entry('cd bitsandbytes')
@@ -372,10 +366,6 @@ def evaluate_cuda_setup():
             "https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md"
         )
 
-
-    # 7.5 is the minimum CC vor cublaslt
-    has_cublaslt = is_cublasLt_compatible(cc)
-
     # TODO:
     # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
     # (2) Multiple CUDA versions installed
@@ -383,11 +373,6 @@ def evaluate_cuda_setup():
     # we use ls -l instead of nvcc to determine the cuda version
     # since most installations will have the libcudart.so installed, but not the compiler
 
-    binary_name = f"libbitsandbytes_cuda{cuda_version_string}"
-    if not has_cublaslt:
-        # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
-        binary_name += "_nocublaslt"
-
-    binary_name = f"{binary_name}{DYNAMIC_LIBRARY_SUFFIX}"
+    binary_name = f"libbitsandbytes_cuda{cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
 
     return binary_name, cudart_path, cc, cuda_version_string
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index f0de962e1..3fc8cfbed 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1960,7 +1960,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
     )
 
     if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
-        raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")
+        raise NotImplementedError("igemmlt not available (probably CC < 7.5)")
 
     if has_error:
         print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 796211fed..8741870ae 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -298,9 +298,6 @@ int roundoff(int v, int d) {
     return (v + d - 1) / d * d;
 }
 
-
-#ifdef NO_CUBLASLT
-#else
 template <int ORDER> cublasLtOrder_t get_order()
 {
     switch(ORDER)
@@ -332,8 +329,6 @@ template cublasLtOrder_t get_order<COL>();
 template cublasLtOrder_t get_order<COL32>();
 template cublasLtOrder_t get_order<COL_TURING>();
 template cublasLtOrder_t get_order<COL_AMPERE>();
-#endif
-
 
 template <int ORDER> int get_leading_dim(int dim1, int dim2)
 {
@@ -366,10 +361,33 @@ template int get_leading_dim<ROW>(int dim1, int dim2);
 template int get_leading_dim<COL_TURING>(int dim1, int dim2);
 template int get_leading_dim<COL_AMPERE>(int dim1, int dim2);
 
+// TODO: Check overhead. Maybe not worth it; just check in the Python lib once,
+// and avoid calling lib functions without support for them.
+// TODO: Address GTX 1660 and any other CC 7.5 devices that may not be supported.
+inline bool igemmlt_supported() {
+    int device;
+    int ccMajor;
+
+    CUDA_CHECK_RETURN(cudaGetDevice(&device));
+    CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&ccMajor, cudaDevAttrComputeCapabilityMajor, device));
+
+    if (ccMajor >= 8)
+        return true;
+
+    if (ccMajor < 7)
+        return false;
+
+    int ccMinor;
+    CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&ccMinor, cudaDevAttrComputeCapabilityMinor, device));
+
+    return ccMinor >= 5;
+}
+
 template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
 {
-#ifdef NO_CUBLASLT
-#else
+    if (!igemmlt_supported())
+        return;
+
     cublasLtOrder_t orderA = get_order<SRC>();
     cublasLtOrder_t orderOut = get_order<TARGET>();
     int ldA = get_leading_dim<SRC>(dim1, dim2);
@@ -408,7 +426,6 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
     if (A_desc)     checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
     if (out_desc)   checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
     if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
-#endif
 }
 
 template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
@@ -422,9 +439,9 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
 
 template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
 {
-#ifdef NO_CUBLASLT
-    return ERR_NOT_IMPLEMENTED;
-#else
+    if (!igemmlt_supported())
+        return ERR_NOT_IMPLEMENTED;
+
     int has_error = 0;
     cublasLtMatmulDesc_t matmulDesc = NULL;
     cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
@@ -479,7 +496,6 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
         printf("error detected");
 
     return has_error;
-#endif // NO_CUBLASLT
 }
 
 int fill_up_to_nearest_multiple(int value, int multiple)
@@ -595,8 +611,8 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
 
 void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
 {
-#ifdef NO_CUBLASLT
-#else
+    if (!igemmlt_supported())
+        return;
 
     cusparseSpMatDescr_t descA;
     cusparseDnMatDescr_t descB, descC;
@@ -644,7 +660,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
     CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
     CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
     CUDA_CHECK_RETURN( cudaFree(dBuffer) );
-#endif
 }
 
 template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index f701f08d0..4b503f177 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -13,7 +13,7 @@ For Linux systems, make sure your hardware meets the following requirements to use bitsandbytes features.
 | 8-bit optimizers/quantization | NVIDIA Kepler (GTX 780 or newer) |
 
 > [!WARNING]
-> bitsandbytes >= 0.39.1 no longer includes Kepler binaries in pip installations. This requires manual compilation, and you should follow the general steps and use `cuda11x_nomatmul_kepler` for Kepler-targeted compilation.
+> bitsandbytes >= 0.39.1 no longer includes Kepler binaries in pip installations. This requires manual compilation, and you should follow the general steps and use CUDA 11.x for Kepler-targeted compilation.
 
 To install from PyPI.
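[Reviewer note, not part of the patch] Since support is now decided at runtime from the device's compute capability rather than at build time, callers can predict up front whether `igemmlt` will work. A minimal Python-side sketch of the same gate that `igemmlt_supported()` applies in `csrc/ops.cu`, assuming a PyTorch runtime (the helper name is illustrative, not part of the bitsandbytes API):

```python
import torch

def int8_matmul_supported(device: int = 0) -> bool:
    # Mirrors igemmlt_supported() in csrc/ops.cu: cublasLt int8
    # tensor-core matmul requires compute capability >= 7.5 (Turing+).
    major, minor = torch.cuda.get_device_capability(device)
    return (major, minor) >= (7, 5)
```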
diff --git a/tests/conftest.py b/tests/conftest.py
index 7aee8c922..17d1fafa0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,8 +6,8 @@ def pytest_runtest_call(item):
     try:
         item.runtest()
     except NotImplementedError as nie:
-        if "NO_CUBLASLT" in str(nie):
-            pytest.skip("CUBLASLT not available")
+        if "CC < 7.5" in str(nie):
+            pytest.skip("INT8 tensor cores not available")
         raise
     except AssertionError as ae:
         if str(ae) == "Torch not compiled with CUDA enabled":
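[Reviewer note, not part of the patch] With the `_nocublaslt` build variant gone, unsupported hardware only surfaces through the runtime `NotImplementedError`, which `conftest.py` now matches on the `"CC < 7.5"` substring. A test that wants to skip up front instead of relying on that string match could use a guard like the following sketch (marker name is hypothetical; assumes a PyTorch runtime):

```python
import pytest
import torch

# Skip before calling into the library on pre-Turing GPUs, where the
# runtime check now makes igemmlt return ERR_NOT_IMPLEMENTED.
requires_int8_tensor_cores = pytest.mark.skipif(
    not torch.cuda.is_available()
    or torch.cuda.get_device_capability() < (7, 5),
    reason="INT8 tensor cores not available",
)
```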