Skip to content

Commit

Permalink
Removed rowscale (segfaults on ampere).
Browse files Browse the repository at this point in the history
  • Loading branch information
TimDettmers committed Jul 26, 2022
1 parent 8b1fd32 commit 1e88edd
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 10 deletions.
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta

# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
CC_CUDA92 := -gencode arch=compute_30,code=sm_30
Expand Down
13 changes: 4 additions & 9 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
return out

def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32):
def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeA = SA[0]
shapeB = SB[0]
dimsA = len(shapeA)
Expand All @@ -917,7 +917,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
elif dimsA == 3 and out is None:
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')

if row_scale is not None: assert row_scale.numel() == out.shape[0]
assert dimsB != 3, 'len(B.shape)==3 not supported'
assert A.device.type == 'cuda'
assert B.device.type == 'cuda'
Expand All @@ -936,7 +935,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
ptrA = get_ptr(A)
ptrB = get_ptr(B)
ptrC = get_ptr(out)
ptrRowScale = get_ptr(row_scale)

k = shapeA[-1]
lda = ct.c_int32(m*32)
Expand All @@ -955,20 +953,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
k = ct.c_int32(k)

has_error = 0
ptrRowScale = get_ptr(None)
if formatB == 'col_turing':
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
elif row_scale is None:
has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
elif formatB == 'col_ampere':
if dtype == torch.int32:
has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
elif row_scale is None:
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)

if has_error == 1:
raise Exception('cublasLt ran into an error!')
Expand Down
2 changes: 2 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,7 @@ def test_integrated_igemmlt(dim1, dim4, inner):
values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
formatB = F.get_special_format_str()
err1, err2, err3 = [], [], []
Expand Down Expand Up @@ -1064,6 +1065,7 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
Expand Down

0 comments on commit 1e88edd

Please sign in to comment.