From 76b40a5c9ae708db98e8b4a13249b2806601a387 Mon Sep 17 00:00:00 2001
From: Ruslan Svirschevski
Date: Wed, 25 Oct 2023 18:00:21 +0300
Subject: [PATCH] save/load via state_dict now

---
 .gitignore                 |  1 +
 bitsandbytes/functional.py | 33 +++++++++++++++++-------------
 bitsandbytes/nn/modules.py | 41 ++++++++++++++++++++++++++++--------
 tests/test_linear4bit.py   | 21 +++++++------------
 4 files changed, 60 insertions(+), 36 deletions(-)

diff --git a/.gitignore b/.gitignore
index f8ebf71af..2f929968b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -133,3 +133,4 @@ dmypy.json
 dependencies
 cuda_build
+.vscode/*

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index d8a542999..fbf004eea 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -568,11 +568,17 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     return out

 class QuantState:
-    """container for quantizationstate components to work with Params4bit and similar clases"""
+    """container for quantization state components to work with Params4bit and similar classes"""
+    valid_quant_types = ('fp4', 'nf4')
+    valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
+    valid_qs_keys = ['absmax', 'code', 'nested_absmax', 'nested_code', 'quant_state',
+                     'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
+
     def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
         self.absmax = absmax
         self.shape = shape
-        self.code = code
+        self.code = code  # TODO consider renaming to `buckets` / `centroids` / `scale`
         self.dtype = dtype
         self.blocksize = blocksize
         self.quant_type = quant_type
@@ -596,26 +602,26 @@ def __get_item__(self, idx):

     @classmethod
     def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState':
         """
-        unpacks dict of tensors into QuantState
+        unpacks components of state_dict into QuantState
         where necessary, convert into strings, torch.dtype, ints, etc.

-        quant_state_dict may contain item with non-tensor components with key like
-        `...weight.quant_state.bitsandbytes__[nf4/fp4]`
-        it is detected with key strored in qs_key, and then unpacked
+        qs_dict: based on state_dict, with only the relevant keys, stripped of prefixes.
+
+        an item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain the minor and
+        non-tensor quant state components packed into a single tensor.
         """

         # unpacking tensor with non-tensor components
-        qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
-        assert len(qs_key) == 1 or not qs_key and 'quant_type' in qs_dict, \
-            f"`qs_dict` must contain packed quant_state items, or be unpacked. Found keys: {tuple(qs_dict.keys())}"
+        qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
+        if not len(qs_key) and 'quant_type' not in qs_dict:
+            raise ValueError("Expected packed or unpacked quant_state items, found neither")
+        elif len(qs_key) > 1:
+            raise ValueError(f"There should be exactly one quant_state item with a key from {cls.valid_qs_type_keys}. Detected {len(qs_key)} such items.")
+
+        # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
             qs_key = qs_key[0]
-            assert 'bitsandbytes__nf4' in qs_key or 'bitsandbytes__fp4' in qs_key, \
-                f"invalid qs_key value {qs_key}"
             qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))

-        qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
-
         if 'nested_absmax' in qs_dict:
             offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
             state2 = cls(
@@ -873,7 +879,6 @@ def get_4bit_type(typename, device=None, blocksize=64):

     return data.to(device)

-
 def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index f405cee17..b04ea8d42 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -154,14 +154,25 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,
         return self

     @classmethod
-    def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='cuda', **kwargs):
-        self = torch.Tensor._make_subclass(cls, data.to(device))
-        self.requires_grad = requires_grad
-        self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
-        self.blocksize = self.quant_state.blocksize
-        self.compress_statistics = self.quant_state.nested
-        self.quant_type = self.quant_state.quant_type
-        return self
+    def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
+        data = state_dict.pop(prefix.rstrip('.'))
+
+        # extracting components for QuantState from state_dict
+        qs_dict = {}
+        for k, v in state_dict.items():
+            if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
+                qs_dict[k] = v
+        state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
+        qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
+
+        if data.device.type != "cuda":
+            raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
+
+        self = torch.Tensor._make_subclass(cls, data)
+        self.requires_grad = requires_grad
+        self.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
+        self.blocksize = self.quant_state.blocksize
+        self.compress_statistics = self.quant_state.nested
+        self.quant_type = self.quant_state.quant_type
+        return self, state_dict

     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
@@ -200,9 +211,11 @@ def to(self, *args, **kwargs):
         return new_param

 class Linear4bit(nn.Linear):
+
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
         super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
+        # self.persistent_buffers = []  # TODO consider as a way to save quant state
         self.compute_dtype = compute_dtype
         self.compute_type_is_set = False
@@ -233,6 +246,18 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         for k, v in self.weight.quant_state.as_dict(packed=True).items():
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()

+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        # Note: super()._load_from_state_dict() is intentionally not called here.
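+        # (The default implementation restores parameters in place via param.copy_(),
+        # which cannot rebuild a Params4bit together with its packed quant state keys.)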
+        if self.bias is not None:
+            bias_data = state_dict.pop(prefix + "bias", None)
+            if bias_data is None:
+                missing_keys.append(prefix + "bias")
+            else:
+                self.bias.data = bias_data.to(self.bias.data.device)
+
+        self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
+            state_dict, prefix=prefix + "weight" + ".", requires_grad=False
+        )
+        # only leftover keys under this module's prefix are unexpected;
+        # other modules' keys are handled by their own loaders
+        unexpected_keys.extend(k for k in state_dict if k.startswith(prefix))
+
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:

diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index 6fe037fc5..9f26bbeb2 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -16,7 +16,7 @@
     "quant_type, compress_statistics, bias",
     list(product(["nf4", "fp4"], [False, True], [False, True])),
 )
-def test_linear4_state_dict(quant_type, compress_statistics, bias):
+def test_linear_serialization(quant_type, compress_statistics, bias):
     original_dtype = torch.float16
     compute_dtype = None
     device = "cuda"
@@ -39,16 +39,10 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias):
     if bias:
         linear_q.bias.data = linear.bias.data.to(device)

+    # saving to state_dict:
     sd = linear_q.state_dict()

-    # restoring from state_dict:
-
-    sd = linear_q.state_dict()
-    bias_data2 = sd.pop("bias", None)
-    weight_data2 = sd.pop("weight")
-
-    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
-
+    # creating a new layer with the same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
         linear.out_features,
@@ -56,13 +50,12 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device='meta',
+        device=device,  # TODO create on the meta device to save loading time
     )
-    linear_q2.weight = weight2.to(device)
-    if bias:
-        linear_q2.bias = torch.nn.Parameter(bias_data2)
+    # loading weights from state_dict:
+    linear_q2.load_state_dict(sd)

-    # matching
+    # matching the two layers
     a, b = linear_q.weight, linear_q2.weight
     assert a.device == b.device
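
Reviewer note: a minimal round-trip sketch of the save/load path this patch enables,
assuming a CUDA device and the usual `import bitsandbytes as bnb`; the layer sizes and
the file name "linear4bit.pt" are illustrative, not part of the patch:

    import torch
    import bitsandbytes as bnb

    # quantization happens on .cuda(); _save_to_state_dict() then packs the quant
    # state into keys like "weight.quant_state.bitsandbytes__nf4" next to the weight
    linear_q = bnb.nn.Linear4bit(64, 64, bias=True, quant_type="nf4").cuda()
    torch.save(linear_q.state_dict(), "linear4bit.pt")

    # _load_from_state_dict() rebuilds Params4bit and its QuantState from the
    # packed entries, so a plain load_state_dict() call restores the layer
    linear_q2 = bnb.nn.Linear4bit(64, 64, bias=True, quant_type="nf4", device="cuda")
    linear_q2.load_state_dict(torch.load("linear4bit.pt"))

    # with identical quantized weights, the two layers should agree on a forward pass
    x = torch.randn(8, 64, dtype=torch.float16, device="cuda")
    assert torch.equal(linear_q(x), linear_q2(x))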