diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index bb6a04892..f915223ca 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1087,11 +1087,12 @@ def get_4bit_type(typename, device=None, blocksize=64):
     if data is None:
         raise NotImplementedError(f"Typename {typename} not supported")
 
-    data = Tensor(data)
-    data /= data.abs().max()
+    data = torch.tensor(data, device=device)
+    data.div_(data.abs().max())
+
+    assert data.numel() == 16
 
-    return data.to(device)
+    return data
 
 
 def quantize_fp4(
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 796211fed..3a6ffdda8 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -58,7 +58,7 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
   num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
 
   if(blocksize == 4096)
-    kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, 0><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 2048)
     kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 1024)
diff --git a/install_cuda.py b/install_cuda.py
index a5d09356d..cf7c8ee71 100644
--- a/install_cuda.py
+++ b/install_cuda.py
@@ -77,9 +77,7 @@ def main():
     download_path = "/tmp"  # default download path
 
     if len(sys.argv) < 2:
-        print(
-            "Usage: python install_cuda.py <version/all> [user/system] [download_path]"
-        )
+        print("Usage: python install_cuda.py <version/all> [user/system] [download_path]")
         sys.exit(1)
 
     version = sys.argv[1]
@@ -100,9 +98,7 @@ def main():
     elif version in cuda_versions:
         install_cuda(version, base_path, download_path)
     else:
-        print(
-            f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}"
-        )
+        print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}")
         sys.exit(1)
 
 
diff --git a/tests/test_functional.py b/tests/test_functional.py
index b9f1a6ead..1cca04511 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1928,7 +1928,9 @@ def test_bench_dequantization():
 
 
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
-def test_fp4_quant(dtype):
+@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
+@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
+def test_4bit_quant(dtype, quant_type, blocksize):
     vals = list(product([0, 1], repeat=4))
 
     code = {}
@@ -1953,8 +1955,8 @@
         code[idx] = result
 
     A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
-    qa, SA = F.quantize_fp4(A1, blocksize=64)
-    A2 = F.dequantize_fp4(qa, SA)
+    qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
+    A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
 
     err = (A1 - A2).abs().float()
     relerr = (err / (A1.abs().float() + 1e-8)).mean()
@@ -1962,8 +1964,24 @@
     err = err.mean()
 
     assert A2.dtype == dtype
-    assert err.item() < 0.1
-    assert relerr.item() < 0.28
+
+    # With larger block sizes, we can expect this to blow up.
+    # At blocksize>=1024, don't even bother looking at relerr.
+    if blocksize <= 64:
+        assert err.item() < 0.1
+        assert relerr.item() < 0.28
+    elif blocksize <= 256:
+        assert err.item() < 0.11
+        assert relerr.item() < 0.30
+    elif blocksize <= 512:
+        assert err.item() < 0.12
+        assert relerr.item() < 0.31
+    elif quant_type == "fp4":
+        # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
+        assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
+    else:
+        # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
+        assert err.item() < math.log2(blocksize) * 8e-2
 
 
 @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
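
Aside: the get_4bit_type() change above builds the 16-entry codebook directly on the target device and normalizes it in place, instead of constructing on CPU, dividing out of place, and copying afterwards. A minimal sketch of the two paths, not part of the diff; the codebook values here are placeholders and a CUDA device is assumed:

import torch

codebook = [i / 15.0 for i in range(16)]  # hypothetical 16-entry 4-bit code

# Old path: CPU tensor, divide, then a separate device copy.
old = torch.Tensor(codebook)
old /= old.abs().max()
old = old.to("cuda")

# New path: allocate on the device up front, normalize in place.
new = torch.tensor(codebook, device="cuda")
new.div_(new.abs().max())

assert torch.allclose(old, new)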
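For context on the new test, a rough sketch of the round trip it parametrizes, using the quantize_4bit/dequantize_4bit API shown in the diff. Assumes a CUDA build of bitsandbytes; this just prints the error trend rather than asserting thresholds:

import torch
import bitsandbytes.functional as F

A1 = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
for quant_type in ("fp4", "nf4"):
    for blocksize in (64, 256, 1024, 4096):
        # One absmax scale is stored per block, so error grows with blocksize.
        qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
        A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
        err = (A1 - A2).abs().float().mean().item()
        print(f"{quant_type} blocksize={blocksize}: mean abs err {err:.3f}")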