From 3d595f138e467f6b9ce935dd58a1fa04ca93531a Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Dec 2024 16:55:25 -0500 Subject: [PATCH] test improvement --- bitsandbytes/functional.py | 2 +- tests/test_functional.py | 12 ++++++------ tests/test_modules.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 674fee142..d1c5d1d2e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2656,7 +2656,7 @@ def double_quant( threshold=threshold, ) - if threshold > 0.0: + if threshold > 0.0 and outlier_cols is not None: # Build a COO tensor including all of the outlier columns. outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32) outliers = A[:, outlier_cols] diff --git a/tests/test_functional.py b/tests/test_functional.py index 20375a02e..c8ac20896 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -703,17 +703,16 @@ def test_coo_double_quant(dim1, dim2): A = torch.randn(dim1, dim2, device="cuda").half() idx = torch.abs(A) >= threshold - CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold) + CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) - if coo_tensor is not None: + if outlier_cols is not None: A1 = A * idx - A2 = torch.zeros_like(A) - A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values + A2 = torch.zeros_like(A) + A1 torch.testing.assert_close(A1, A2) - A1 = A * (idx == 0) + A[:, outlier_cols] = 0 A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() - torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @@ -728,6 +727,7 @@ def test_coo_int8_vectorwise_quant(dim1, dim2): if outlier_cols is not None: A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() + A[:, outlier_cols] = 0 torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) diff --git a/tests/test_modules.py b/tests/test_modules.py index 278add87f..c2583550d 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -349,8 +349,8 @@ def test_linear8bitlt_accumulated_gradient(): l1[0].bias.data.copy_(l2[0].bias.data) l1[1].bias.data.copy_(l2[1].bias.data) else: - torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, rtol=1.05, atol=0.04) - torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.00, atol=0.02) + assert_all_approx_close(l1[0].weight.grad, l2[0].weight.grad, rtol=1.05, atol=0.04, count=1) + assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1) @pytest.mark.parametrize("threshold", [0.0, 2.0])