Skip to content

Commit

Permalink
test improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdouglas committed Dec 3, 2024
1 parent d25ebb4 commit 3d595f1
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2656,7 +2656,7 @@ def double_quant(
threshold=threshold,
)

if threshold > 0.0:
if threshold > 0.0 and outlier_cols is not None:
# Build a COO tensor including all of the outlier columns.
outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32)
outliers = A[:, outlier_cols]
Expand Down
12 changes: 6 additions & 6 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,17 +703,16 @@ def test_coo_double_quant(dim1, dim2):
A = torch.randn(dim1, dim2, device="cuda").half()

idx = torch.abs(A) >= threshold
CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)

if coo_tensor is not None:
if outlier_cols is not None:
A1 = A * idx
A2 = torch.zeros_like(A)
A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
A2 = torch.zeros_like(A) + A1
torch.testing.assert_close(A1, A2)

A1 = A * (idx == 0)
A[:, outlier_cols] = 0
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2)


@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
Expand All @@ -728,6 +727,7 @@ def test_coo_int8_vectorwise_quant(dim1, dim2):

if outlier_cols is not None:
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
A[:, outlier_cols] = 0
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ def test_linear8bitlt_accumulated_gradient():
l1[0].bias.data.copy_(l2[0].bias.data)
l1[1].bias.data.copy_(l2[1].bias.data)
else:
torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, rtol=1.05, atol=0.04)
torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.00, atol=0.02)
assert_all_approx_close(l1[0].weight.grad, l2[0].weight.grad, rtol=1.05, atol=0.04, count=1)
assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1)


@pytest.mark.parametrize("threshold", [0.0, 2.0])
Expand Down

0 comments on commit 3d595f1

Please sign in to comment.