Skip to content

Commit

Permalink
Merge pull request #503 from TimDettmers/efficient_8bit_serialize
Browse files Browse the repository at this point in the history
Make 8-bit serialization more memory-efficient (v2)
  • Loading branch information
TimDettmers authored Jun 19, 2023
2 parents ac5550a + b599fdb commit 2d321a7
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 36 deletions.
25 changes: 14 additions & 11 deletions bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
return True


def _get_tile_size(format):
assert format in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {format}"
return (8, 32) if format == "col_turing" else (32, 32)


def get_tile_inds(format, device):
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
with torch.no_grad():
return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)

@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -267,20 +280,10 @@ def reset_grads(self):
self.SBt = None
self.CBt = None

def get_tile_size(self):
assert self.formatB in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {self.formatB}"
return (8, 32) if self.formatB == "col_turing" else (32, 32)

@property
def tile_indices(self):
if self._tile_indices is None:
device = self.CxB.device
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
with torch.no_grad():
self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
return self._tile_indices


Expand Down
64 changes: 39 additions & 25 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import bitsandbytes as bnb
import bitsandbytes.functional
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer, find_outlier_dims

Expand Down Expand Up @@ -306,6 +306,17 @@ def to(self, *args, **kwargs):
return new_param


def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
weight = state_dict.get(f"{prefix}weight")
if weight is None:
# if the state dict has no weights for this layer (e.g., LoRA finetuning), do nothing
return
weight_format = state_dict.pop(f"{prefix}weight_format", "row")

if weight_format != "row":
tile_indices = get_tile_inds(weight_format, weight.device)
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)


class Linear8bitLt(nn.Linear):
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
Expand All @@ -322,52 +333,55 @@ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=
self.state.use_pool = True

self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
self._register_load_state_dict_pre_hook(maybe_rearrange_weight)

def _save_to_state_dict(self, destination, prefix, keep_vars):
if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
# reorder weight layout back from ampere/turing to row
reorder_layout = True
weight_clone = self.weight.data.clone()
else:
reorder_layout = False
super()._save_to_state_dict(destination, prefix, keep_vars)

try:
if reorder_layout:
self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
scb_name = "SCB"

super()._save_to_state_dict(destination, prefix, keep_vars)
# case 1: .cuda was called, SCB is in self.weight
param_from_weight = getattr(self.weight, scb_name)
# case 2: self.init_8bit_state was called, SCB is in self.state
param_from_state = getattr(self.state, scb_name)
# case 3: SCB is in self.state, weight layout reordered after first forward()
layout_reordered = self.state.CxB is not None

# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
weight_name = "SCB"
key_name = prefix + f"{scb_name}"
format_name = prefix + "weight_format"

# case 1: .cuda was called, SCB is in self.weight
param_from_weight = getattr(self.weight, weight_name)
# case 2: self.init_8bit_state was called, SCB is in self.state
param_from_state = getattr(self.state, weight_name)

key_name = prefix + f"{weight_name}"
if not self.state.has_fp16_weights:
if param_from_weight is not None:
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
elif not self.state.has_fp16_weights and param_from_state is not None:
destination[format_name] = "row"
elif param_from_state is not None and not layout_reordered:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
finally:
if reorder_layout:
self.weight.data = weight_clone
destination[format_name] = "row"
elif param_from_state is not None:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
destination[format_name] = self.state.formatB

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)
for key in unexpected_keys:
unexpected_copy = list(unexpected_keys)

for key in unexpected_copy:
input_name = key[len(prefix):]
if input_name == "SCB":
if self.weight.SCB is None:
# buffers not yet initialized, can't call them directly without
# buffers not yet initialized, can't access them directly without quantizing first
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
"not supported. Please call module.cuda() before module.load_state_dict()")

input_param = state_dict[key]
self.weight.SCB.copy_(input_param)

if self.state.SCB is not None:
self.state.SCB = self.weight.SCB

unexpected_keys.remove(key)

def init_8bit_state(self):
Expand Down

0 comments on commit 2d321a7

Please sign in to comment.