From 3d595f138e467f6b9ce935dd58a1fa04ca93531a Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Tue, 3 Dec 2024 16:55:25 -0500
Subject: [PATCH] test improvement

---
 bitsandbytes/functional.py |  2 +-
 tests/test_functional.py   | 12 ++++++------
 tests/test_modules.py      |  4 ++--
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 674fee142..d1c5d1d2e 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -2656,7 +2656,7 @@ def double_quant(
         threshold=threshold,
     )
 
-    if threshold > 0.0:
+    if threshold > 0.0 and outlier_cols is not None:
         # Build a COO tensor including all of the outlier columns.
         outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32)
         outliers = A[:, outlier_cols]
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 20375a02e..c8ac20896 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -703,17 +703,16 @@ def test_coo_double_quant(dim1, dim2):
         A = torch.randn(dim1, dim2, device="cuda").half()
 
         idx = torch.abs(A) >= threshold
-        CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
+        CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
 
-        if coo_tensor is not None:
+        if outlier_cols is not None:
             A1 = A * idx
-            A2 = torch.zeros_like(A)
-            A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
+            A2 = torch.zeros_like(A) + A1
             torch.testing.assert_close(A1, A2)
 
-            A1 = A * (idx == 0)
+            A[:, outlier_cols] = 0
             A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
-            torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
+            torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2)
 
 
 @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
@@ -728,6 +727,7 @@ def test_coo_int8_vectorwise_quant(dim1, dim2):
 
         if outlier_cols is not None:
             A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
+            A[:, outlier_cols] = 0
             torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
 
 
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 278add87f..c2583550d 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -349,8 +349,8 @@ def test_linear8bitlt_accumulated_gradient():
             l1[0].bias.data.copy_(l2[0].bias.data)
             l1[1].bias.data.copy_(l2[1].bias.data)
         else:
-            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, rtol=1.05, atol=0.04)
-            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.00, atol=0.02)
+            assert_all_approx_close(l1[0].weight.grad, l2[0].weight.grad, rtol=1.05, atol=0.04, count=1)
+            assert_all_approx_close(l1[1].weight.grad, l2[1].weight.grad, rtol=1.05, atol=0.04, count=1)
 
 
 @pytest.mark.parametrize("threshold", [0.0, 2.0])