Int8 refactoring: remove separate NO_CUBLASLT build; more cleanup
matthewdouglas committed Oct 14, 2024
1 parent ca372f2 commit 510a880
Showing 11 changed files with 48 additions and 99 deletions.
30 changes: 15 additions & 15 deletions .github/scripts/build-cuda.sh
@@ -8,21 +8,21 @@ build_capability="50;52;60;61;70;75;80;86;89;90"
[[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????}
[[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???}
[[ "${build_os}" = windows-* ]] && python3 -m pip install ninja
for NO_CUBLASLT in ON OFF; do
if [ "${build_os:0:6}" == ubuntu ]; then
image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
echo "Using image $image"
docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
"apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" -DNO_CUBLASLT=${NO_CUBLASLT} . \
&& cmake --build ."
else
pip install cmake==3.28.3
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S .
cmake --build . --config Release
fi
done

if [ "${build_os:0:6}" == ubuntu ]; then
image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
echo "Using image $image"
docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
"apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \
&& cmake --build ."
else
pip install cmake==3.28.3
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DCMAKE_BUILD_TYPE=Release -S .
cmake --build . --config Release
fi


output_dir="output/${build_os}/${build_arch}"
mkdir -p "${output_dir}"
14 changes: 1 addition & 13 deletions CMakeLists.txt
@@ -4,7 +4,6 @@
# For MSVC: `cmake -B build . && cmake --build build --config Release`
# You can also use the following options and variables
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
# is whatever CMake finds on your path.
# - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC.
@@ -47,10 +46,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
if(APPLE)
message(FATAL_ERROR "CUDA is not supported on macOS" )
endif()
option(NO_CUBLASLT "Disable CUBLAS" OFF)
set(BUILD_CUDA ON)
set(BUILD_MPS OFF)
message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
elseif(${COMPUTE_BACKEND} STREQUAL "mps")
if(NOT APPLE)
message(FATAL_ERROR "MPS is only supported on macOS" )
@@ -166,9 +163,6 @@ if(BUILD_CUDA)
list(APPEND SRC_FILES ${CUDA_FILES})

string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
if(NO_CUBLASLT)
string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
endif()
add_compile_definitions(BUILD_CUDA)
elseif(BUILD_MPS)
if(NOT APPLE)
@@ -212,13 +206,7 @@ target_include_directories(bitsandbytes PUBLIC csrc include)

if(BUILD_CUDA)
target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse)
if(NO_CUBLASLT)
target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT)
else()
target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt)
endif()

target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse)
set_target_properties(bitsandbytes
PROPERTIES
CUDA_SEPARABLE_COMPILATION ON
30 changes: 16 additions & 14 deletions bitsandbytes/autograd/_functions.py
@@ -283,9 +283,9 @@ def forward(
B: torch.Tensor,
out=None,
bias: Optional[torch.Tensor] = None,
state=MatmulLtState,
state: MatmulLtState = None,
):
# state = state or MatmulLtState()
state = state or MatmulLtState()

# default of pytorch behavior if inputs are empty
ctx.is_empty = False
@@ -318,7 +318,7 @@ def forward(
if is_transposed:
B = B.contiguous()

if (state.is_training and not has_grad) or state.CB is None:
if (state.is_training and not has_grad) or state.SCB is None:
state.reset_grads()

# 2. Quantize B
@@ -347,7 +347,7 @@ def forward(
outliers = state.CB[:, state.idx].clone()
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
else:
subA = state.subB = None
subA = None

# 3. Int8 Matmul
out32, Sout32 = F.igemmlt(CA, state.CB)
@@ -377,7 +377,11 @@ def forward(
ctx.save_for_backward(None, None)

output_shape = (*input_shape[:-1], state.CB.shape[0])
return output.reshape(output_shape).clone()

if len(input_shape) == 3:
return output.view(output_shape).clone()
else:
return output

@staticmethod
def backward(ctx, grad_output):
@@ -400,18 +404,16 @@ def backward(ctx, grad_output):
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
# grad_output.T @ A
# grad_weight = grad_output.t().mm(A)
grad_B = torch.matmul(grad_output.t(), A)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
# if req_gradB:
#
# gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t())
# grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)

# grad_B = torch.matmul(grad_output.t(), A)
# if state.threshold > 0.0 and subA is not None:
# grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradB:
gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t())
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

if req_gradA:
# grad_output @ B.T
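The weight-gradient hunk above replaces the temporary fp16 matmul (grad_output.t() @ A) with the int8 route: quantize the incoming gradient, multiply with the int8 kernel, then dequantize. A minimal sketch of that flow, assuming a CUDA build of bitsandbytes and using only the function names that appear in this hunk (F.double_quant, F.igemmlt, F.mm_dequant); the shapes are illustrative:

    import torch
    import bitsandbytes.functional as F

    # Illustrative tensors: A is the saved forward input, grad_output the incoming gradient.
    A = torch.randn(32, 64, dtype=torch.float16, device="cuda")
    grad_output = torch.randn(32, 128, dtype=torch.float16, device="cuda")

    # Row- and column-wise int8 quantization, as in the forward/backward code above.
    CA, CAt, SCA, SCAt, _ = F.double_quant(A)
    Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output)

    # grad_B ~ grad_output.t() @ A, computed on the int8 path and dequantized.
    gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t())
    grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)

    # fp16 reference for a rough sanity check.
    grad_B_ref = grad_output.t().float() @ A.float()
    print((grad_B.float() - grad_B_ref).abs().max())

When state.threshold > 0, the outlier columns (subA) are still accumulated into grad_B in fp16, exactly as in the hunk.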
6 changes: 1 addition & 5 deletions bitsandbytes/cextension.py
@@ -37,11 +37,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
The library is not guaranteed to exist at the returned path.
"""
library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}"
if not cuda_specs.has_cublaslt:
# if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
library_name += "_nocublaslt"
library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}"
library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"

override_value = os.environ.get("BNB_CUDA_VERSION")
if override_value:
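With the _nocublaslt variant removed, the native library name depends only on the CUDA version string; compute capability no longer changes which binary is loaded. A small sketch, assuming a Linux install where DYNAMIC_LIBRARY_SUFFIX is ".so" and using the CUDASpecs fields seen elsewhere in this commit:

    from bitsandbytes.cextension import get_cuda_bnb_library_path
    from bitsandbytes.cuda_specs import CUDASpecs

    # A pre-7.5 GPU no longer selects a separate "_nocublaslt" build.
    specs = CUDASpecs(
        cuda_version_string="121",
        highest_compute_capability=(7, 0),
        cuda_version_tuple=(12, 1),
    )
    assert get_cuda_bnb_library_path(specs).stem == "libbitsandbytes_cuda121"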
2 changes: 1 addition & 1 deletion bitsandbytes/cuda_specs.py
@@ -11,7 +11,7 @@ class CUDASpecs:
cuda_version_tuple: Tuple[int, int]

@property
def has_cublaslt(self) -> bool:
def has_imma(self) -> bool:
return self.highest_compute_capability >= (7, 5)


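The rename from has_cublaslt to has_imma keeps the same 7.5 threshold but describes what is actually detected: int8 tensor-core (IMMA) support, since cuBLASLt itself is now linked unconditionally. A minimal usage sketch, assuming the dataclass fields shown in this file:

    from bitsandbytes.cuda_specs import CUDASpecs

    turing = CUDASpecs(
        cuda_version_string="121",
        highest_compute_capability=(7, 5),
        cuda_version_tuple=(12, 1),
    )
    pascal = CUDASpecs(
        cuda_version_string="121",
        highest_compute_capability=(6, 1),
        cuda_version_tuple=(12, 1),
    )

    assert turing.has_imma        # CC >= 7.5: fast int8 matmul available
    assert not pascal.has_imma    # older GPUs fall back to the slower path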
4 changes: 2 additions & 2 deletions bitsandbytes/diagnostics/cuda.py
@@ -134,8 +134,8 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:

print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")

# 7.5 is the minimum CC for cublaslt
if not cuda_specs.has_cublaslt:
# 7.5 is the minimum CC for int8 tensor cores
if not cuda_specs.has_imma:
print_dedented(
"""
WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!
15 changes: 0 additions & 15 deletions csrc/ops.cu
@@ -314,8 +314,6 @@ int roundoff(int v, int d) {
}


#ifdef NO_CUBLASLT
#else
template<int ORDER> cublasLtOrder_t get_order()
{
switch(ORDER)
@@ -347,7 +345,6 @@
template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>();
#endif


template<int ORDER> int get_leading_dim(int dim1, int dim2)
@@ -379,8 +376,6 @@

template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{
#ifdef NO_CUBLASLT
#else
cublasLtOrder_t orderA = get_order<SRC>();
cublasLtOrder_t orderOut = get_order<TARGET>();
int ldA = get_leading_dim<SRC>(dim1, dim2);
@@ -419,7 +414,6 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
#endif
}

template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
@@ -513,9 +507,6 @@

template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
#ifdef NO_CUBLASLT
return ERR_NOT_IMPLEMENTED;
#else
int has_error = 0;
cublasLtMatmulDesc_t matmulDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
@@ -570,7 +561,6 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
printf("error detected");

return has_error;
#endif // NO_CUBLASLT
}

int fill_up_to_nearest_multiple(int value, int multiple)
@@ -681,10 +671,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o

void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{

#ifdef NO_CUBLASLT
#else

cusparseSpMatDescr_t descA;
cusparseDnMatDescr_t descB, descC;

@@ -731,7 +717,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val
CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
CUDA_CHECK_RETURN( cudaFree(dBuffer) );
#endif
}

template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
4 changes: 0 additions & 4 deletions tests/conftest.py
@@ -7,10 +7,6 @@
def pytest_runtest_call(item):
try:
item.runtest()
except NotImplementedError as nie:
if "NO_CUBLASLT" in str(nie):
pytest.skip("CUBLASLT not available")
raise
except AssertionError as ae:
if str(ae) == "Torch not compiled with CUDA enabled":
pytest.skip("Torch not compiled with CUDA enabled")
20 changes: 0 additions & 20 deletions tests/test_cuda_setup_evaluator.py
@@ -13,15 +13,6 @@ def cuda120_spec() -> CUDASpecs:
)


@pytest.fixture
def cuda111_noblas_spec() -> CUDASpecs:
return CUDASpecs(
cuda_version_string="111",
highest_compute_capability=(7, 2),
cuda_version_tuple=(11, 1),
)


def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"
@@ -31,14 +22,3 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?


def test_get_cuda_bnb_library_path_override_nocublaslt(monkeypatch, cuda111_noblas_spec, caplog):
monkeypatch.setenv("BNB_CUDA_VERSION", "125")
assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda125_nocublaslt"
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?


def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"
8 changes: 5 additions & 3 deletions tests/test_linear8bitlt.py
@@ -69,11 +69,13 @@ def test_linear_no_igemmlt():

fx_ours = linear_custom(x_ours).float()
(fx_ours * grad_proj).mean().backward()

assert linear_custom.state.CB is not None
assert not linear_custom.state.has_fp16_weights
assert torch.allclose(fx_ref, fx_ours, atol=0.02)
assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
assert not linear_custom.state.has_fp16_weights
assert linear_custom.state.CB is not None
assert linear_custom.state.CxB is None

# assert linear_custom.state.CxB is None


@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
14 changes: 7 additions & 7 deletions tests/test_modules.py
@@ -529,16 +529,16 @@ def test_linear_kbit_fp32_bias(module):
@pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
def test_kbit_backprop(module):
b = 16
dim1 = 32
dim2 = 48
dim1 = 36
dim2 = 56
# dim1 = 37
# dim2 = 83

ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 16)])
ref[1].weight.requires_grad = False
ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 128)])
# ref[1].weight.requires_grad = False
torch.nn.init.kaiming_normal_(ref[0].weight)
torch.nn.init.kaiming_normal_(ref[1].weight)
kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 16)])
kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
kbit[0].weight.detach().copy_(ref[0].weight)
kbit[1].weight.detach().copy_(ref[1].weight)
kbit[0].bias.detach().copy_(ref[0].bias)
@@ -572,8 +572,8 @@ def test_kbit_backprop(module):
relerrs1.append(relerr1.mean().item())
relerrs2.append(relerr2.mean().item())

# if isinstance(module, bnb.nn.Linear8bitLt):
if module == bnb.nn.Linear8bitLt:
if isinstance(module, bnb.nn.Linear8bitLt):
# if module == bnb.nn.Linear8bitLt:
assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
else: