save/load via state_dict now
poedator committed Nov 2, 2023
1 parent 965fd5d commit 76b40a5
Showing 4 changed files with 60 additions and 36 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -133,3 +133,4 @@ dmypy.json
 
 dependencies
 cuda_build
+.vscode/*
33 changes: 19 additions & 14 deletions bitsandbytes/functional.py
@@ -568,11 +568,17 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     return out
 
 class QuantState:
-    """container for quantizationstate components to work with Params4bit and similar clases"""
+    """container for quantization state components to work with Params4bit and similar classes"""
+    valid_quant_types = ('fp4', 'nf4')
+    valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
+    valid_qs_keys = ['absmax', 'code', 'nested_absmax', 'nested_code', 'quant_state',
+                     'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
+
+
     def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
         self.absmax = absmax
         self.shape = shape
-        self.code = code
+        self.code = code  # TODO: consider renaming to `buckets` / `centroids` / `scale`
         self.dtype = dtype
         self.blocksize = blocksize
         self.quant_type = quant_type
@@ -596,26 +602,26 @@ def __get_item__(self, idx):
     @classmethod
     def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState':
         """
-        unpacks dict of tensors into QuantState
+        unpacks components of state_dict into QuantState
         where necessary, convert into strings, torch.dtype, ints, etc.
 
-        quant_state_dict may contain item with non-tensor components with key like
-        `...weight.quant_state.bitsandbytes__[nf4/fp4]`
-        it is detected with key strored in qs_key, and then unpacked
+        qs_dict: based on state_dict, with only relevant keys, stripped of prefixes.
+        item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
         """
 
-        # unpacking tensor with non-tensor components
-        qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
-        assert len(qs_key) == 1 or not qs_key and 'quant_type' in qs_dict, \
-            f"`qs_dict` must contain packed quant_state items, or be unpacked. Found keys: {tuple(qs_dict.keys())}"
+        qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
+        if not len(qs_key) and 'quant_type' not in qs_dict:
+            raise ValueError("Expected packed or unpacked quant_state items, found neither")
+        elif len(qs_key) != 1:
+            raise ValueError(f"There should be exactly one quant_state item with key from {cls.valid_qs_type_keys}. Detected {len(qs_key)} such items")
 
+        # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
             qs_key = qs_key[0]
-            assert 'bitsandbytes__nf4' in qs_key or 'bitsandbytes__fp4' in qs_key, \
-                f"invalid qs_key value {qs_key}"
             qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))
 
+        qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
+
         if 'nested_absmax' in qs_dict:
             offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
             state2 = cls(
@@ -873,7 +879,6 @@ def get_4bit_type(typename, device=None, blocksize=64):
     return data.to(device)
 
 
-
 def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
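Example (not part of the commit): a minimal round-trip sketch of the packed format at the functional level. It assumes a CUDA device and that `quantize_4bit`/`dequantize_4bit` return and accept a `QuantState`, with `as_dict(packed=True)` as used by `_save_to_state_dict` in modules.py below; treat it as an illustration of the intent, not a definitive API reference.

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(64, 64, dtype=torch.float16, device="cuda")
    q, quant_state = F.quantize_4bit(A, quant_type="nf4", compress_statistics=True)

    # pack: tensor components stay as-is; minor / non-tensor components are
    # folded into a single `quant_state.bitsandbytes__nf4` tensor item
    packed = quant_state.as_dict(packed=True)
    assert any(k in F.QuantState.valid_qs_type_keys for k in packed)

    # unpack: from_dict detects the packed item, unpacks it, strips prefixes
    restored = F.QuantState.from_dict(qs_dict=packed, device=torch.device("cuda"))
    out = F.dequantize_4bit(q, restored)  # `out` approximates A; 4-bit is lossy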
41 changes: 33 additions & 8 deletions bitsandbytes/nn/modules.py
@@ -154,14 +154,25 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,
         return self
 
     @classmethod
-    def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='cuda', **kwargs):
-        self = torch.Tensor._make_subclass(cls, data.to(device))
-        self.requires_grad = requires_grad
-        self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
-        self.blocksize = self.quant_state.blocksize
-        self.compress_statistics = self.quant_state.nested
-        self.quant_type = self.quant_state.quant_type
-        return self
+    def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
+        data = state_dict.pop(prefix.rstrip('.'))
+
+        # extracting components for QuantState from state_dict
+        qs_dict = {}
+        for k, v in state_dict.items():
+            if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
+                qs_dict[k] = v
+        state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
+        qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
+
+        if data.device.type != "cuda":
+            raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
+
+        self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
+        self.requires_grad = requires_grad
+        self.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
+
+        return self, state_dict
 
     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
@@ -200,9 +211,11 @@ def to(self, *args, **kwargs):
         return new_param
 
 class Linear4bit(nn.Linear):
+
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4',device=None):
         super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
+        # self.persistent_buffers = []  # TODO: consider as a way to save quant state
         self.compute_dtype = compute_dtype
         self.compute_type_is_set = False
@@ -233,6 +246,18 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         for k, v in self.weight.quant_state.as_dict(packed=True).items():
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()
 
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        # Note: super()._load_from_state_dict() is not called here intentionally.
+        if self.bias is not None:
+            bias_data = state_dict.pop(prefix + "bias", None)
+            self.bias.data = bias_data.to(self.bias.data.device)
+
+        self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
+            state_dict, prefix=prefix + "weight" + ".", requires_grad=False
+        )
+        unexpected_keys.extend(state_dict.keys())
+
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
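Example (not part of the commit): taken together, the new `_save_to_state_dict` / `_load_from_state_dict` pair lets a quantized layer round-trip through the standard `nn.Module` checkpoint API. A minimal sketch of the intended flow, mirroring the test below; the sizes are arbitrary and a CUDA build is assumed.

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear4bit(128, 256, quant_type="nf4")
    layer = layer.cuda()  # weights are quantized on the move to CUDA
    sd = layer.state_dict()
    # besides "weight" and "bias", sd carries the quant state under prefixed
    # keys such as "weight.absmax" and "weight.quant_state.bitsandbytes__nf4"

    layer2 = bnb.nn.Linear4bit(128, 256, quant_type="nf4", device="cuda")
    layer2.load_state_dict(sd)  # dispatches to _load_from_state_dict above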
21 changes: 7 additions & 14 deletions tests/test_linear4bit.py
@@ -16,7 +16,7 @@
     "quant_type, compress_statistics, bias",
     list(product(["nf4", "fp4"], [False, True], [False, True])),
 )
-def test_linear4_state_dict(quant_type, compress_statistics, bias):
+def test_linear_serialization(quant_type, compress_statistics, bias):
     original_dtype = torch.float16
     compute_dtype = None
     device = "cuda"
@@ -39,30 +39,23 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias):
     if bias:
         linear_q.bias.data = linear.bias.data.to(device)
 
-    sd = linear_q.state_dict()
-    bias_data2 = sd.pop("bias", None)
-    weight_data2 = sd.pop("weight")
-
-    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
+    # saving to state_dict:
+    sd = linear_q.state_dict()
 
+    # restoring from state_dict:
     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
         linear.out_features,
         bias=bias,
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device='meta',
+        device=device,  # TODO: create on meta device to save loading time
     )
-    linear_q2.weight = weight2.to(device)
-    if bias:
-        linear_q2.bias = torch.nn.Parameter(bias_data2)
+    # loading weights from state_dict:
+    linear_q2.load_state_dict(sd)
 
-    # matching
+    # MATCHING
     a, b = linear_q.weight, linear_q2.weight
 
     assert a.device == b.device
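Example (not part of the commit): because the quant state now travels inside the state_dict, the usual PyTorch checkpointing idiom applies unchanged. A brief sketch continuing the test's names; the file name is hypothetical.

    torch.save(linear_q.state_dict(), "linear4bit_checkpoint.pt")
    linear_q2.load_state_dict(torch.load("linear4bit_checkpoint.pt"))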
