fix nf4 memory issue by init op_context in forward (#1349)
* fix nf4 memory issue by init op_context in forward

* disable repack in init

* fix code style
jiqing-feng authored Sep 13, 2024
1 parent 39097a6 commit 2784653
Showing 3 changed files with 47 additions and 23 deletions.
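
At a high level, the change defers the IPEX weight prepack for nf4 from quantization time to the first forward pass on CPU. A rough, hedged sketch of the intended user-visible behavior follows; it assumes the multi-backend CPU build of bitsandbytes together with intel_extension_for_pytorch >= 2.3, and the layer size and the .to("cpu") quantization step are illustrative assumptions rather than part of this commit.

import torch
import bitsandbytes as bnb

# Quantize an nf4 Linear4bit layer on CPU (illustrative sizes).
layer = bnb.nn.Linear4bit(4096, 4096, bias=False, compute_dtype=torch.bfloat16, quant_type="nf4")
layer = layer.to("cpu")  # after this commit, no ipex op_context is created here

x = torch.randn(2, 4096, dtype=torch.bfloat16)
_ = layer(x)  # the first CPU forward hits the new check and calls enable_ipex_fusion()

print(hasattr(layer.weight.quant_state, "op_context"))  # True once fusion has run

The diffs below implement this by deleting the eager prepack in quantize_4bit_impl, adding the lazy check to Linear4bit.forward, and moving the prepack itself into a new enable_ipex_fusion helper in bitsandbytes/utils.py.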
19 changes: 0 additions & 19 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -370,25 +370,6 @@ def quantize_4bit_impl(
        quant_type=quant_type,
    )

    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4":
        # lowp_mode: lowest precision for computation
        lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16
        state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
            out.reshape([input_shape[0], input_shape[1] // 2]),
            ipex_cpu.quantization.WoqWeightDtype.NF4,
            input_shape,  # weight shape
            absmax.view(input_shape[0], input_shape[1] // blocksize),  # scales
            None,  # zero_points
            None,  # bias
            None,  # g_idx
            None,  # batch_size
            blocksize,
            int(lowp_mode),
            -1,  # act_quant_mode. -1 means don't quant activation
        )
        state.absmax = torch.Tensor()
        return torch.empty([1, 0], dtype=torch.uint8), state

    return out.unsqueeze(0), state


27 changes: 23 additions & 4 deletions bitsandbytes/nn/modules.py
@@ -19,6 +19,7 @@
    INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
    LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
    OutlierTracer,
    enable_ipex_fusion,
)

T = TypeVar("T", bound="torch.nn.Module")
@@ -444,17 +445,35 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
        save weight and bias,
        then fill state_dict with components of quant_state
        """
        if (
            getattr(self.weight, "quant_state", None) is not None
            and getattr(self.weight.quant_state, "op_context", None) is not None
        ):
            context = self.weight.quant_state.op_context
            self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])

        super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias

        if getattr(self.weight, "quant_state", None) is not None:
            if (
                self.weight.quant_state.absmax.shape.numel() == 0
                and getattr(self.weight.quant_state, "op_context", None) is not None
            ):
                self.weight.quant_state.absmax = context.get_scales().reshape(-1)
                delattr(self.weight.quant_state, "op_context")
            for k, v in self.weight.quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()
            if getattr(self.weight.quant_state, "op_context", None) is not None:
                context = self.weight.quant_state.op_context
                destination[prefix + "weight." + "absmax"] = context.get_scales().reshape(-1)
                self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])

    def forward(self, x: torch.Tensor):
        # Check if ipex fusion can be used
        if (
            x.device.type == "cpu"
            and not hasattr(self.weight.quant_state, "op_context")
            and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
            and self.weight.quant_state.quant_type == "nf4"
        ):
            enable_ipex_fusion(self.weight, self.weight.quant_state)

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)
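
Because the op_context now only exists after the first CPU forward, the _save_to_state_dict changes above unpack everything back into plain tensors at save time. A hedged sketch of that round trip, continuing the illustrative layer from the sketch near the top (key names follow the prefix + "weight." + k convention in the diff):

# Assumes layer has already run at least one CPU forward, so
# layer.weight.quant_state.op_context exists when the checkpoint is written.
sd = layer.state_dict()

# _save_to_state_dict pulls the packed weight back out of the ipex context via
# context.to_public(context.get_weight()), restores absmax from context.get_scales(),
# and then drops op_context, so the saved dict contains only ordinary tensors.
print(sd["weight"].shape)         # packed nf4 payload, reshaped to [1, -1]
print(sd["weight.absmax"].shape)  # per-block scales recovered from the op_context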
24 changes: 24 additions & 0 deletions bitsandbytes/utils.py
@@ -200,6 +200,30 @@ def unpack_tensor_to_dict(tensor_data):
    return unpacked_dict


def enable_ipex_fusion(weight, quant_state):
    from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq

    if _ipex_cpu_version_prereq(2, 3):
        import intel_extension_for_pytorch as ipex

        lowp_mode = ipex.quantization.WoqLowpMode.BF16
        quant_state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
            weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
            ipex.quantization.WoqWeightDtype.NF4,
            quant_state.shape,  # weight shape
            quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
            None,  # zero_points
            None,  # bias
            None,  # g_idx
            None,  # batch_size
            quant_state.blocksize,
            int(lowp_mode),
            -1,  # act_quant_mode. -1 means don't quant activation
        )
        quant_state.absmax = torch.Tensor()
        weight.data = torch.empty([1, 0], dtype=torch.uint8)


class QuantState:
    """container for quantization state components to work with Params4bit and similar classes"""

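
Concretely, once the forward path has called enable_ipex_fusion(weight, quant_state), the helper above leaves the module in the following state. This continues the same illustrative layer; the printed values simply restate the assignments at the end of the function body:

qs = layer.weight.quant_state
print(qs.op_context is not None)   # the ipex prepacked weight-only qlinear context now exists
print(qs.absmax.numel())           # 0 -- the scales were handed over to the op_context
print(layer.weight.data.shape)     # torch.Size([1, 0]) -- placeholder; the packed weight is held by ipex now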
