Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ eval/
*_ckpt*/
output/
outputs/
output*/
logs*/
outs/
wandb/
tools/results/
Expand Down
138 changes: 122 additions & 16 deletions angelslim/compressor/qat/modules/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import torch.nn as nn
import torch.nn.functional as F

from ....utils import is_deepspeed_zero3_enabled, is_zero3_param

FP8_E4M3_QMIN = -448
FP8_E4M3_QMAX = 448

Expand Down Expand Up @@ -49,10 +51,30 @@ def _parse_bits_and_dtype(qtype_str):


class Quantizer(nn.Module):
def __init__(self, config, quant_info, x=None, is_act=False, resume=False, num_heads=-1):
def __init__(
self,
config,
quant_info,
x=None,
is_act=False,
resume=False,
num_heads=-1,
weight_shape=None,
):
super().__init__()
self.is_act = is_act
self.num_heads = num_heads
# ``weight_shape`` lets the caller pre-declare the (out_features,
# in_features) of the parent Linear so we can size weight-side
# quantizer Parameters without ever touching the (possibly ZeRO-3
# sharded) weight tensor.
self.weight_shape = (
(int(weight_shape[0]), int(weight_shape[1])) if weight_shape is not None else None
)
# Configurable initial values used when ZeRO-3 is active and we
# cannot depend on the weight data.
self.weight_scale_init_value = float(config.get("weight_scale_init_value", 1.0))
self.activation_scale_init_value = float(config.get("activation_scale_init_value", 1.0))
info = quant_info.quant_algo_info["w"]
self.group_size = quant_info.quant_algo_info.get("w_group_size", -1)
rewrite_conf = config.get("weight", {})
Expand Down Expand Up @@ -117,8 +139,21 @@ def _init_quant_params(self, x):
self.scale = self.zero_point = None
if self.resume:
self.init = True
zp = torch.empty(1) if not self.is_sym else None
self._set_quant_parameters(torch.empty(1), zp)
init_val = self.activation_scale_init_value
scale = torch.full((1,), init_val, dtype=torch.float32)
zp = torch.zeros(1, dtype=torch.float32) if not self.is_sym else None
self._set_quant_parameters(scale, zp)
return

# Weight-side path. If we cannot use ``x`` (ZeRO-3 sharded,
# meta, or simply not provided), allocate Parameters by shape
# and ``weight_scale_init_value``.
if self._needs_external_weight_init(x):
shape = self._weight_scale_shape_from_meta()
init_val = self.weight_scale_init_value
scale = torch.full(shape, init_val, dtype=torch.float32)
zp = torch.zeros(shape, dtype=torch.float32) if not self.is_sym else None
self._set_quant_parameters(scale, zp)
return

if self.is_sym:
Expand All @@ -131,6 +166,52 @@ def _init_quant_params(self, x):
)
self._set_quant_parameters(scale, zp.round())

def _needs_external_weight_init(self, x):
"""True when weight-side init must skip data-dependent computation
and instead allocate Parameters from shape + init_value.

Triggered by:
* DeepSpeed ZeRO-3 active (HF integration registered)
* ``x`` is a ZeRO-3 sharded Parameter
* ``x`` is None / on meta device / empty
"""
if is_deepspeed_zero3_enabled():
return True
if x is None:
return True
if is_zero3_param(x):
return True
if hasattr(x, "device") and x.device.type == "meta":
return True
if hasattr(x, "numel") and x.numel() == 0:
return True
return False

def _weight_2d_shape(self):
"""Resolve (out_features, in_features) for the underlying Linear.
Callers must have passed ``weight_shape`` via ``QuantLinear``."""
if self.weight_shape is not None:
return self.weight_shape
raise RuntimeError(
"Quantizer needs ``weight_shape`` to size weight scale without a "
"concrete tensor (set in QuantLinear.__init__)."
)

def _weight_scale_shape_from_meta(self):
out_dim, in_dim = self._weight_2d_shape()
if self.granularity == "per-channel":
return (out_dim, 1)
if self.granularity == "per-group":
if not self.group_size or self.group_size <= 0:
raise ValueError("per-group quantization requires positive group_size.")
if in_dim % self.group_size != 0:
raise ValueError(
f"dim 1 ({in_dim}) not divisible by group_size ({self.group_size})"
)
return (out_dim, in_dim // self.group_size)
# per-tensor and any reduce-to-scalar variant
return (1,)

def _init_lwc_params(self, x, config):
lwc_cfg = config.get("lwc", {})
if isinstance(lwc_cfg, dict):
Expand All @@ -141,11 +222,18 @@ def _init_lwc_params(self, x, config):
self.lwc_init_value = 4.0

if self.lwc:
if x.dim() != 2:
x_for_shape = x.flatten(1)
# Resolve (out_dim, in_dim) without depending on ``x`` data.
if self._needs_external_weight_init(x):
out_dim, in_dim = self._weight_2d_shape()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
x_for_shape = x
out_dim, in_dim = x_for_shape.shape
if x.dim() != 2:
x_for_shape = x.flatten(1)
else:
x_for_shape = x
out_dim, in_dim = x_for_shape.shape
device = x.device

if self.granularity == "per-group":
if not self.group_size or self.group_size <= 0:
raise ValueError("per-group quantization requires positive group_size.")
Expand All @@ -157,9 +245,7 @@ def _init_lwc_params(self, x, config):
else:
dim1 = 1

init = (
torch.ones((dim1, 1), device=x.device, dtype=torch.float32) * self.lwc_init_value
)
init = torch.ones((dim1, 1), device=device, dtype=torch.float32) * self.lwc_init_value
self.clip_factor_w_max = nn.Parameter(init.clone(), requires_grad=True)
self.clip_factor_w_min = nn.Parameter(init.clone(), requires_grad=True)
self.sigmoid = nn.Sigmoid()
Expand Down Expand Up @@ -473,7 +559,14 @@ def fake_quant(self, x):
None if self.is_sym else clamp_ste(round_ste(self.zero_point), self.qmin, self.qmax)
)
scale, round_zero_point = self._expand_scale_zp(scale, round_zero_point, x)
return self._fake_quant_with_params(x, scale, round_zero_point)
out = self._fake_quant_with_params(x, scale, round_zero_point)
# Scale is kept in fp32 for numerical stability, but multiplying by
# a bf16/fp16 activation upcasts the result. Cast back to the input
# dtype so downstream F.linear / DeepSpeed autocast wrappers see a
# consistent dtype.
if out.dtype != x.dtype:
out = out.to(x.dtype)
return out

def forward(self, x: torch.Tensor):
if self.bits >= 16:
Expand Down Expand Up @@ -516,8 +609,18 @@ def __init__(
self.register_parameter("bias", org_module.bias)
self.use_weight_quant = use_weight_quant
self.use_act_quant = use_act_quant
# Under ZeRO-3 the weight Parameter ``org_module.weight`` may be a
# zero-numel shard. Pass an explicit (out, in) shape so the weight
# quantizer can size its scale Parameter from the Linear shape
# rather than inspecting the (possibly sharded) tensor.
weight_shape = (org_module.out_features, org_module.in_features)
if self.use_weight_quant:
self.weight_quantizer = Quantizer(config, quant_info, x=org_module.weight)
self.weight_quantizer = Quantizer(
config,
quant_info,
x=org_module.weight,
weight_shape=weight_shape,
)
if self.use_act_quant:
self.act_quantizer = Quantizer(config, quant_info, is_act=True, resume=resume)

Expand All @@ -531,13 +634,16 @@ def __init__(
)

def forward(self, input: torch.Tensor):
if input.shape[0] == 0:
return self.fwd_func(input, self.weight, self.bias)

weight = self.weight_quantizer(self.weight) if self.use_weight_quant else self.weight
if self.use_act_quant:
input = self.act_quantizer(input)
output = self.fwd_func(input, weight, self.bias)
# Defensive dtype alignment: upstream (DeepSpeed ZeRO-3 / HF
# autocast) may have cast ``input`` to fp16 even though we run in
# bf16. Align to the (fake-quantised) weight dtype so F.linear
# stays consistent.
output = self.fwd_func(
input.to(self.weight.dtype), weight.to(self.weight.dtype), self.bias
)
if self.use_qkv_quant:
output = self.qkv_quantizer(output)
return output
Expand Down
55 changes: 51 additions & 4 deletions angelslim/compressor/qat/plugins/learnable_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
import torch
from tqdm import tqdm

from ....utils import print_info, set_op_by_name
from ....utils import (
gathered_params_if_zero3,
is_deepspeed_zero3_enabled,
print_info,
set_op_by_name,
stream_load_scales,
)
from ..modules.quantizer import QuantLinear
from .base_plugin import BasePlugin
from .plugin_manager import PluginManager
Expand All @@ -29,11 +35,22 @@

@PluginManager.plugin("learnable_scale")
class LearnableScalePlugin(BasePlugin):
def __init__(self, quant_info=None, ignore_layers=None, resume_ckpt_dir=None, **kwargs):
def __init__(
self,
quant_info=None,
ignore_layers=None,
resume_ckpt_dir=None,
from_ptq_ckpt_dir=None,
**kwargs,
):
super().__init__(**kwargs)
self.quant_info = quant_info
self.ignore_layers = ignore_layers
self.resume_ckpt_dir = resume_ckpt_dir
# Optional warm-start from a PTQ "real" checkpoint (only scales are
# read; base weights stay as loaded by from_pretrained). Required
# under DeepSpeed ZeRO-3.
self.from_ptq_ckpt_dir = from_ptq_ckpt_dir
self.use_weight_quant = self.config.get("use_weight_quant", False)
self.use_activation_quant = self.config.get("use_activation_quant", False)
self.fp8_attn = self.config.get("fp8_attn", False)
Expand All @@ -47,9 +64,23 @@ def __init__(self, quant_info=None, ignore_layers=None, resume_ckpt_dir=None, **
self.learn_norm = learnable_cfg.get("norm", False)

def before_train(self, **kwargs):
zero3 = is_deepspeed_zero3_enabled()
if zero3 and not self.from_ptq_ckpt_dir:
raise ValueError(
"DeepSpeed ZeRO-3 QAT requires `compression.QAT.from_ptq_ckpt` "
"to warm-start scales (lazy_init via forward is impossible "
"on sharded weights)."
)

# Retrieve KV head count from model config for per-head quantization
model_config = getattr(self.quant_model.model, "config", None)
num_kv_heads = getattr(model_config, "num_key_value_heads", -1)
# Pre-allocate ``act_quantizer.scale`` as a Parameter whenever we
# plan to fill it from a checkpoint (full resume OR PTQ warm-start
# OR ZeRO-3 — where lazy_init is impossible).
act_preallocate = (
self.resume_ckpt_dir is not None or self.from_ptq_ckpt_dir is not None or zero3
)
for name, module in self.quant_model.model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(ig in name for ig in self.ignore_layers):
Expand All @@ -67,7 +98,7 @@ def before_train(self, **kwargs):
self.quant_info,
self.use_weight_quant,
self.use_activation_quant,
resume=self.resume_ckpt_dir is not None,
resume=act_preallocate,
qkv_config=qkv_cfg,
)
set_op_by_name(self.quant_model.model, name, q_linear)
Expand All @@ -78,10 +109,18 @@ def before_train(self, **kwargs):

print_info(self.quant_model.model)

# Warm-start scales from a previous PTQ "real" checkpoint. Only
# quantizer Parameters are touched; base Linear weights are NOT
# overwritten.
if self.from_ptq_ckpt_dir is not None:
stream_load_scales(self.quant_model.model, self.from_ptq_ckpt_dir)

if (
self.use_activation_quant
and not q_linear.act_quantizer.dynamic
and self.resume_ckpt_dir is None
and not zero3
and self.from_ptq_ckpt_dir is None
):
self._lazy_init(**kwargs)

Expand Down Expand Up @@ -284,5 +323,13 @@ def _get_qkv_config_for_layer(name, quant_config):
@torch.no_grad()
def quant_inplace(model):
for _, module in model.named_modules():
if isinstance(module, QuantLinear):
if not isinstance(module, QuantLinear):
continue
# Gather the weight together with all weight_quantizer Parameters
# (scale / zero_point / optional LWC clip factors) so the
# fake-quant runs on the full materialised tensor under ZeRO-3.
params = [module.weight]
if hasattr(module, "weight_quantizer"):
params.extend(module.weight_quantizer.parameters(recurse=True))
with gathered_params_if_zero3(params, modifier_rank=None):
module.weight.data = module.weight_quantizer(module.weight.data)
Loading
Loading