update
dsikka committed Mar 6, 2025
1 parent 9435d81 commit a7bf319
Showing 2 changed files with 28 additions and 1 deletion.
src/llmcompressor/modifiers/quantization/calibration.py (23 additions, 0 deletions)
@@ -120,8 +120,31 @@ def update_weight_zp_scale(module: Module):

     if module.quantization_scheme.weights is not None:
         # set weight scale and zero_point up front, calibration data doesn't affect it
+
+        transform_data = getattr(module, "transform_data", None)
+        if transform_data is not None:
+            # transforms are applied in the order they were added, which is the order they should be applied
+            untransformed_weight = module.weight.data.clone()
+            for transform_name, transform_values in transform_data.data.items():
+                transform = getattr(module, transform_name)
+                apply = transform_values.get("apply")
+                call_args = transform_values.get("call_args")
+                if call_args:
+                    transformed_weight = apply(
+                        input_tensor=module.weight, transform=transform, **call_args
+                    )
+                else:
+                    transformed_weight = apply(
+                        input_tensor=module.weight, transform=transform
+                    )
+                module.weight.data.copy_(transformed_weight)
+
         call_observer(module=module, base_name="weight")

+        # TODO: what do we do here?
+        if transform_data is not None:
+            module.weight.data.copy_(untransformed_weight)
+


 def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     """
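For readers of this diff, a minimal, self-contained sketch of the module state the new loop in update_weight_zp_scale expects. The ".data" mapping of transform_name -> {"apply": callable, "call_args": dict-or-None} is inferred from the loop above; the SimpleNamespace container, the "rotation_0" name, and the QR-based rotation are illustrative assumptions, not the compressed-tensors transforms API.

# Illustrative only: mimics the attributes update_weight_zp_scale reads.
from types import SimpleNamespace

import torch
from torch.nn import Linear


def apply_rotation(input_tensor: torch.Tensor, transform: torch.Tensor) -> torch.Tensor:
    # toy transform: right-multiply the weight by an orthogonal matrix
    return input_tensor @ transform


layer = Linear(4, 4, bias=False)

# register the transform tensor under the name the loop resolves via getattr(module, name)
rotation, _ = torch.linalg.qr(torch.randn(4, 4))
layer.register_buffer("rotation_0", rotation)

# transform_data.data maps transform_name -> {"apply": callable, "call_args": optional kwargs}
layer.transform_data = SimpleNamespace(
    data={"rotation_0": {"apply": apply_rotation, "call_args": None}}
)

# what the calibration path does per transform before observing the weight
original = layer.weight.data.clone()
for name, values in layer.transform_data.data.items():
    transform = getattr(layer, name)
    transformed = values["apply"](input_tensor=layer.weight, transform=transform)
    layer.weight.data.copy_(transformed)
# ... the observer runs on the transformed weight, then the original is restored
layer.weight.data.copy_(original)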
Second changed file (5 additions, 1 deletion)

@@ -10,6 +10,7 @@
     is_preset_scheme,
     preset_name_to_scheme,
 )
+from compressed_tensors.transforms.transform_config import TransformationConfig
 from loguru import logger
 from pydantic import Field, field_validator
 from torch.nn import Module
@@ -74,6 +75,7 @@ class QuantizationModifier(Modifier):
     """

     config_groups: Optional[Dict[str, QuantizationScheme]] = None
+    transforms_config: Optional[TransformationConfig] = None
     ignore: List[str] = Field(default_factory=list)
     targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
     scheme: Optional[Union[str, Dict[str, Any]]] = None
@@ -210,7 +212,9 @@ def _check_calibration_data(self, config: QuantizationConfig):
     def _apply_modifier_to_model(self, model: Module):
         modifier_as_config = self.create_init_config()
         # Add step to attach kv_cache to the model, if present within the config
-        apply_quantization_config(model, modifier_as_config)
+        apply_quantization_config(
+            model, modifier_as_config, transforms_config=self.transforms_config
+        )
         model.apply(set_unset_kv_cache)
         return modifier_as_config
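Usage sketch for the new transforms_config field (hedged: constructing a TransformationConfig is not shown in this diff, so a None placeholder stands in for it; the preset name and ignore list below are examples, not taken from this commit):

# Hypothetical wiring of the new transforms_config field on QuantizationModifier.
from llmcompressor.modifiers.quantization import QuantizationModifier

# Placeholder: a compressed_tensors TransformationConfig instance would go here.
transforms_config = None

modifier = QuantizationModifier(
    targets=["Linear"],
    scheme="W8A8",
    ignore=["lm_head"],
    transforms_config=transforms_config,
)

# During initialization the modifier now forwards this config:
#   apply_quantization_config(model, config, transforms_config=self.transforms_config)
# so the transform metadata is attached to the model alongside the quantization scheme.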
