From cfd6ac75eff48e8c06b03cd8e721302a713c77a8 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Wed, 21 Feb 2024 15:46:02 -0500
Subject: [PATCH] add deepcopy and copy for Params4bit (#1060)

* fix deepcopy and copy

* add tests

* remove line

* ruff fix

* ruff

* Update tests/test_linear4bit.py

Co-authored-by: Aarni Koskela

* add missing state

* ruff format

* ignore formatting commit for git blame

* Params4bit should be initialized as frozen by default

* add test for serialization round-tripping

* add comparison capability for QuantState

* add back accidentally removed line

---------

Co-authored-by: Aarni Koskela
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
---
 .git-blame-ignore-revs     |  3 ++
 bitsandbytes/functional.py | 15 ++++++++
 bitsandbytes/nn/modules.py | 43 +++++++++++++++++++--
 tests/test_linear4bit.py   | 77 +++++++++++++++++++++++++++++++-------
 4 files changed, 121 insertions(+), 17 deletions(-)

diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
index f7dd01bdf..c0386dc9f 100644
--- a/.git-blame-ignore-revs
+++ b/.git-blame-ignore-revs
@@ -6,3 +6,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848
 
 # Remove f-prefix from strings that don't use formatting
 7727fa4c8c6c1ef2b109120aff4196a0a6bf3ed6
+
+# format tests/test_linear4bit.py
+34735ba89de8235ea9da6ef409f814dcea9e2038
\ No newline at end of file
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 9fc5e08f0..f0de962e1 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -706,6 +706,21 @@ def to(self, device):
             self.state2.absmax = self.state2.absmax.to(device)
             self.state2.code = self.state2.code.to(device)
 
+    def __eq__(self, other):
+        if not isinstance(other, QuantState):
+            return False
+
+        return (
+            torch.allclose(self.absmax, other.absmax, atol=1e-6) and
+            self.shape == other.shape and
+            torch.allclose(self.code, other.code, atol=1e-6) and
+            self.dtype == other.dtype and
+            self.blocksize == other.blocksize and
+            self.quant_type == other.quant_type and
+            (self.offset == other.offset if self.offset is not None and other.offset is not None else self.offset is other.offset) and
+            (self.state2 == other.state2 if self.state2 is not None and other.state2 is not None else self.state2 is other.state2)
+        )
+
 
 def quantize_blockwise(
     A: Tensor,
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 2b7e1f067..bd2bd5832 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import copy
 from typing import Any, Dict, Optional, TypeVar, Union, overload
 import warnings
 
@@ -191,7 +192,7 @@ class Params4bit(torch.nn.Parameter):
     def __new__(
         cls,
         data: Optional[torch.Tensor] = None,
-        requires_grad=True,
+        requires_grad=False,  # quantized weights should be frozen by default
         quant_state: Optional[QuantState] = None,
         blocksize: int = 64,
         compress_statistics: bool = True,
@@ -214,6 +215,37 @@ def __new__(
         self.module = module
         return self
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["data"] = self.data
+        state["requires_grad"] = self.requires_grad
+        return state
+
+    def __setstate__(self, state):
+        self.requires_grad = state["requires_grad"]
+        self.blocksize = state["blocksize"]
+        self.compress_statistics = state["compress_statistics"]
+        self.quant_type = state["quant_type"]
+        self.quant_state = state["quant_state"]
+        self.data = state["data"]
+        self.quant_storage = state["quant_storage"]
+        self.bnb_quantized = state["bnb_quantized"]
+        self.module = state["module"]
+
+    def __deepcopy__(self, memo):
+        new_instance = type(self).__new__(type(self))
+        state = self.__getstate__()
+        new_instance.__setstate__(state)
+        new_instance.quant_state = copy.deepcopy(state["quant_state"])
+        new_instance.data = copy.deepcopy(state["data"])
+        return new_instance
+
+    def __copy__(self):
+        new_instance = type(self).__new__(type(self))
+        state = self.__getstate__()
+        new_instance.__setstate__(state)
+        return new_instance
+
     @classmethod
     def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
         self = torch.Tensor._make_subclass(cls, data.to(device))
@@ -227,8 +259,13 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any],
 
     def _quantize(self, device):
         w = self.data.contiguous().cuda(device)
-        w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
-                                                           quant_type=self.quant_type, quant_storage=self.quant_storage)
+        w_4bit, quant_state = bnb.functional.quantize_4bit(
+            w,
+            blocksize=self.blocksize,
+            compress_statistics=self.compress_statistics,
+            quant_type=self.quant_type,
+            quant_storage=self.quant_storage,
+        )
         self.data = w_4bit
         self.quant_state = quant_state
         if self.module is not None:
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index 13db28ed4..3e62bdf3b 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -1,4 +1,6 @@
+import copy
 import os
+import pickle
 from tempfile import TemporaryDirectory
 
 import pytest
@@ -8,13 +10,14 @@
 from tests.helpers import TRUE_FALSE
 
 storage = {
-    'uint8': torch.uint8,
-    'float16': torch.float16,
-    'bfloat16': torch.bfloat16,
-    'float32': torch.float32
+    "uint8": torch.uint8,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
 }
 
-@pytest.mark.parametrize("quant_storage", ['uint8', 'float16', 'bfloat16', 'float32'])
+
+@pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
 @pytest.mark.parametrize("bias", TRUE_FALSE)
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE)
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@@ -24,7 +27,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     device = "cuda"
     layer_shape = (300, 400)
 
-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer
+    linear = torch.nn.Linear(
+        *layer_shape, dtype=original_dtype, device="cpu"
+    )  # original layer
 
     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -36,7 +41,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         quant_type=quant_type,
         device="meta",
     )
-    new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
+    new_weight = bnb.nn.Params4bit(
+        data=linear.weight, quant_type=quant_type, requires_grad=False
+    )
     linear_q.weight = new_weight
     if bias:
         linear_q.bias = torch.nn.Parameter(linear.bias)
@@ -80,7 +87,12 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         quant_storage=storage[quant_storage],
         device="meta",
     )
-    linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
+    linear_qs.weight = bnb.nn.Params4bit(
+        data=linear.weight,
+        requires_grad=False,
+        quant_type=quant_type,
+        quant_storage=storage[quant_storage],
+    )
     if bias:
         linear_qs.bias = torch.nn.Parameter(linear.bias)
     linear_qs = linear_qs.to(device)
@@ -91,7 +103,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
     q0 = a.quant_state
     q1 = b.quant_state
 
-    for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+    for attr in ("code", "dtype", "blocksize", "absmax"):
         c, d = getattr(q0, attr), getattr(q1, attr)
         if isinstance(c, torch.Tensor):
             assert torch.equal(c, d)
@@ -99,7 +111,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
             assert c == d, f"{c} != {d}"
 
     if q0.state2 is not None:
-        for attr in ('code', 'dtype', 'blocksize', 'absmax'):
+        for attr in ("code", "dtype", "blocksize", "absmax"):
             c, d = getattr(q0.state2, attr), getattr(q1.state2, attr)
             if isinstance(c, torch.Tensor):
                 assert torch.equal(c, d)
@@ -125,7 +137,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         assert torch.equal(a, c)
 
     # Test moving to CPU and back to GPU
-    linear_q2.to('cpu')
+    linear_q2.to("cpu")
     linear_q2.to(device)
     d = linear_qs(x)
     assert c.dtype == d.dtype
@@ -139,10 +151,47 @@ def test_linear_serialization(quant_type, compress_statistics, bias, quant_stora
         torch.save(linear.state_dict(), state_path)
         torch.save(linear_q.state_dict(), state_path_4bit)
-        size_orig, size_4 = os.path.getsize(state_path), os.path.getsize(
-            state_path_4bit
+        size_orig, size_4 = (
+            os.path.getsize(state_path),
+            os.path.getsize(state_path_4bit),
         )
 
     size_ratio = size_4 / size_orig
-    target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
+    target_compression = (
+        0.143 if original_dtype == torch.float32 else 0.29
+    )  # these numbers get lower as weight shape increases
     ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
     assert size_ratio < target_compression, ratio_error_msg
+
+
+def test_copy_param():
+    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+
+    shallow_copy_param = copy.copy(param)
+    assert param.quant_state is shallow_copy_param.quant_state
+    assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
+
+
+def test_deepcopy_param():
+    tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+    copy_param = copy.deepcopy(param)
+    assert param.quant_state is not copy_param.quant_state
+    assert param.data.data_ptr() != copy_param.data.data_ptr()
+
+
+def test_params4bit_real_serialization():
+    original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
+    original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
+
+    original_param.cuda(0)  # move to CUDA to trigger quantization
+
+    serialized_param = pickle.dumps(original_param)
+    deserialized_param = pickle.loads(serialized_param)
+
+    assert torch.equal(original_param.data, deserialized_param.data)
+    assert original_param.requires_grad == deserialized_param.requires_grad == False
+    assert original_param.quant_type == deserialized_param.quant_type
+    assert original_param.blocksize == deserialized_param.blocksize
+    assert original_param.compress_statistics == deserialized_param.compress_statistics
+    assert original_param.quant_state == deserialized_param.quant_state
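
The snippet below is a brief usage sketch, not part of the patch itself: it exercises the copy, deepcopy, and pickle round-trip behavior that the changes above enable, mirroring the new tests. It assumes a bitsandbytes build that includes this patch and an available CUDA device; the tensor shape and variable names are illustrative only.

import copy
import pickle

import torch

import bitsandbytes as bnb

# Quantize a small weight to 4-bit; moving the parameter to CUDA triggers quantization.
param = bnb.nn.Params4bit(
    data=torch.randn(64, 64), quant_type="nf4", requires_grad=False
).cuda(0)

# Shallow copy shares the packed storage and the quantization state.
shallow = copy.copy(param)
assert shallow.quant_state is param.quant_state

# Deep copy duplicates both the packed data and the quantization state.
deep = copy.deepcopy(param)
assert deep.quant_state is not param.quant_state
assert torch.equal(deep.data, param.data)

# Pickle round-trip preserves the quantization metadata; the new QuantState.__eq__ makes it checkable.
restored = pickle.loads(pickle.dumps(param))
assert restored.quant_state == param.quant_state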