Commit 5c62cb8
feat: support zero-point decompression for asymmetric quantization (packed) (#463)
* feat: add zero-point decompression support for asymmetric quantization
  - Fix decompress_weight method in PackedQuantizationCompressor to support unpacking zero-points
  - Add comprehensive tests for zero-point packing/unpacking with GROUP and CHANNEL strategies
  - Add end-to-end integration tests for asymmetric quantization workflow
  - Ensure packed tensors are contiguous for safetensors compatibility
  Resolves issue referenced in vllm-project/llm-compressor#1704

* nit: assert zero_point exists for asymmetric strategies before unpacking

* tests: rely on apply_quantization_config to init scale/zero-point; remove manual creation

* tests: rename to test_packed_asym_decompression.py

* tests: use in-memory decompress_model; calibrate via fixtures; std-dev similarity; cleanup temp usage

* refactor: use in-memory compress/decompress methods

* stylefix

* style fixes

* style fixes

* style fixes

* style fixes

* style fixes

---------

Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
1 parent ed002e9 commit 5c62cb8
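For orientation, here is a minimal sketch of the round trip this commit enables, reusing pack_to_int32 / unpack_from_int32 with the signatures visible in the diffs below. The import path is an assumption based on the module being modified; treat this as illustrative, not canonical usage.

# Minimal sketch (assumptions noted above) of the packed zero-point round trip.
import torch
from compressed_tensors.compressors.quantized_compressors.pack_quantized import (
    pack_to_int32,
    unpack_from_int32,
)

num_bits = 4
# One zero-point per (row, group): a (512, 1024) weight with group_size=128
# yields shape (512, 8); the int4 values are held in an int8 tensor.
zero_point = torch.randint(-8, 8, (512, 8), dtype=torch.int8)

packed = pack_to_int32(zero_point, num_bits, packed_dim=0)
packed = packed.contiguous()  # safetensors requires contiguous tensors

restored = unpack_from_int32(packed, num_bits, zero_point.shape, packed_dim=0)
assert torch.equal(restored, zero_point)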

File tree

3 files changed: +271 −14 lines changed

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 4 additions & 9 deletions
@@ -134,16 +134,14 @@ def compress_weight(
         compressed_dict["weight_shape"] = weight_shape
         compressed_dict["weight_packed"] = packed_weight
 
-        # We typically don't compress zp; apart from when using the packed_compressor
-        # and when storing group/channel zp
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
             packed_zp = pack_to_int32(
                 zero_point, quantization_args.num_bits, packed_dim=0
             )
-            compressed_dict["weight_zero_point"] = packed_zp
+            compressed_dict["weight_zero_point"] = packed_zp.contiguous()
         return compressed_dict
 
     def decompress_weight(
@@ -166,16 +164,13 @@ def decompress_weight(
         num_bits = quantization_args.num_bits
         unpacked = unpack_from_int32(weight, num_bits, original_shape)
 
-        # NOTE: this will fail decompression as we don't currently handle packed zp on
-        # decompression
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
         ]:
-            raise ValueError(
-                "Decompression of packed zero points is currently not supported"
-            )
-            assert zero_point is not None
+            assert (
+                zero_point is not None
+            ), "Asymmetric quantization requires zero-point values"
             original_zp_shape = (original_shape[0], scale.shape[-1])
             zero_point = unpack_from_int32(
                 zero_point, num_bits, original_zp_shape, packed_dim=0
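The new packed_zp.contiguous() call matters because safetensors refuses to serialize non-contiguous tensors. A small illustration of the failure mode (the file name here is hypothetical):

# Why compress_weight now calls .contiguous(): safetensors rejects
# non-contiguous tensors at save time with a ValueError.
import torch
from safetensors.torch import save_file

packed_zp = torch.arange(16, dtype=torch.int32).reshape(4, 4).t()
assert not packed_zp.is_contiguous()  # transpose is a strided view

try:
    save_file({"weight_zero_point": packed_zp}, "zp.safetensors")
except ValueError:
    pass  # "You are trying to save a non contiguous tensor..."

# Packing the storage first makes the save succeed.
save_file({"weight_zero_point": packed_zp.contiguous()}, "zp.safetensors")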

tests/test_compressors/quantized_compressors/test_pack_quant.py

Lines changed: 95 additions & 5 deletions
@@ -15,6 +15,7 @@
 
 import math
 import shutil
+import tempfile
 from collections import OrderedDict
 
 import pytest
@@ -170,12 +171,13 @@ def test_reload_match(tmp_path, num_bits):
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
 
-    reconstructed_dense_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_scheme
-    )
     reconstructed_dense = {}
-    for name, value in reconstructed_dense_gen:
-        reconstructed_dense[name] = value
+    with tempfile.TemporaryDirectory():
+        reconstructed_dense_gen = compressor.decompress(
+            tmp_path, names_to_scheme=quantized_modules_to_scheme
+        )
+        for name, value in reconstructed_dense_gen:
+            reconstructed_dense[name] = value
 
     fake_quant_dummy = fake_quantize(
         dense_state_dict["dummy.weight"],
@@ -473,3 +475,91 @@ def test_unpack_from_int32(num_bits, values, expected_tensor):
     unpacked_tensor = unpack_from_int32(values, num_bits, expected_tensor.shape)
     assert torch.equal(unpacked_tensor, unpacked_tensor)
     assert unpacked_tensor.dtype == unpacked_tensor.dtype
+
+
+@pytest.mark.parametrize(
+    "strategy,group_size",
+    [
+        (QuantizationStrategy.GROUP, 128),
+        (QuantizationStrategy.CHANNEL, None),
+    ],
+)
+def test_asymmetric_zero_point_decompression(strategy, group_size, tmp_path):
+    """
+    Test that zero-point packing and unpacking works correctly for asymmetric
+    quantization with GROUP and CHANNEL strategies.
+    """
+    shape = (512, 1024)
+
+    if strategy == QuantizationStrategy.CHANNEL:
+        expected_zp_shape = (shape[0], 1)
+    elif strategy == QuantizationStrategy.GROUP:
+        num_groups = shape[1] // group_size
+        expected_zp_shape = (shape[0], max(num_groups, 1))
+
+    dense_state_dict = {
+        "dummy.weight": torch.randn(shape),
+        "dummy.weight_scale": torch.rand(expected_zp_shape).to(torch.float32),
+        "dummy.weight_zero_point": torch.randint(-8, 8, expected_zp_shape).to(
+            torch.int8
+        ),
+    }
+
+    quant_config = get_dummy_quant_config(
+        num_bits=4, strategy=strategy.value, symmetric=False, group_size=group_size
+    )
+
+    compressor = PackedQuantizationCompressor(config=quant_config)
+    quantized_modules_to_scheme = {"dummy": quant_config.config_groups["group_1"]}
+    compressed_state_dict = compressor.compress(
+        dense_state_dict.copy(), names_to_scheme=quantized_modules_to_scheme
+    )
+
+    assert "dummy.weight_zero_point" in compressed_state_dict
+    assert compressed_state_dict["dummy.weight_zero_point"].dtype == torch.int32
+
+    save_file(compressed_state_dict, tmp_path / "model.safetensors")
+
+    reconstructed_dense_gen = compressor.decompress(
+        tmp_path, names_to_scheme=quantized_modules_to_scheme
+    )
+    reconstructed_dense = {}
+    for name, value in reconstructed_dense_gen:
+        reconstructed_dense[name] = value
+
+    assert "dummy" in reconstructed_dense
+    assert "weight" in reconstructed_dense["dummy"]
+
+    assert reconstructed_dense["dummy"]["weight"].shape == shape
+
+    shutil.rmtree(tmp_path)
+
+
+@pytest.mark.parametrize(
+    "num_bits,strategy",
+    [
+        (4, QuantizationStrategy.GROUP),
+        (4, QuantizationStrategy.CHANNEL),
+        (8, QuantizationStrategy.GROUP),
+        (8, QuantizationStrategy.CHANNEL),
+    ],
+)
+def test_zero_point_pack_unpack_consistency(num_bits, strategy):
+    """
+    Test that packing and unpacking zero-points preserves values correctly.
+    """
+    if strategy == QuantizationStrategy.GROUP:
+        shape = (512, 8)
+    else:
+        shape = (512, 1)
+
+    max_val = (1 << (num_bits - 1)) - 1
+    min_val = -(1 << (num_bits - 1))
+    original_zp = torch.randint(min_val, max_val + 1, shape).to(torch.int8)
+
+    packed_zp = pack_to_int32(original_zp, num_bits, packed_dim=0)
+
+    unpacked_zp = unpack_from_int32(packed_zp, num_bits, shape, packed_dim=0)
+
+    assert torch.equal(original_zp, unpacked_zp)
+    assert unpacked_zp.dtype == torch.int8
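To make the shape arithmetic in these tests concrete: with packed_dim=0, 32 // num_bits values share one int32 word along dim 0, so a (512, 8) int4 zero-point tensor packs to (64, 8). The sketch below is a self-contained illustration of that idea, not the library's implementation.

# Illustrative int32 packing along dim 0 (2-D case), mirroring the shapes
# used by test_zero_point_pack_unpack_consistency above.
import torch


def pack_rows(x: torch.Tensor, num_bits: int) -> torch.Tensor:
    per_word = 32 // num_bits                 # e.g. 8 int4 values per int32
    offset = 1 << (num_bits - 1)              # shift [-8, 7] into [0, 15]
    unsigned = x.to(torch.int32) + offset
    pad = (-x.shape[0]) % per_word            # pad dim 0 to a whole word
    if pad:
        unsigned = torch.cat(
            [unsigned, torch.zeros(pad, x.shape[1], dtype=torch.int32)]
        )
    unsigned = unsigned.reshape(-1, per_word, x.shape[1])
    packed = torch.zeros(unsigned.shape[0], x.shape[1], dtype=torch.int32)
    for i in range(per_word):                 # OR each field into its bit slot
        packed |= unsigned[:, i] << (i * num_bits)
    return packed


def unpack_rows(packed: torch.Tensor, num_bits: int, shape) -> torch.Tensor:
    per_word = 32 // num_bits
    mask = (1 << num_bits) - 1
    fields = [(packed >> (i * num_bits)) & mask for i in range(per_word)]
    unsigned = torch.stack(fields, dim=1).reshape(-1, shape[1])[: shape[0]]
    return (unsigned - (1 << (num_bits - 1))).to(torch.int8)


zp = torch.randint(-8, 8, (512, 8), dtype=torch.int8)
packed = pack_rows(zp, num_bits=4)
assert packed.shape == (512 // 8, 8)          # 8 rows per int32 word -> (64, 8)
assert torch.equal(unpack_rows(packed, 4, zp.shape), zp)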
tests/test_compressors/quantized_compressors/test_packed_asym_decompression.py

Lines changed: 172 additions & 0 deletions

@@ -0,0 +1,172 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+End-to-end tests for asymmetric quantization with zero-point decompression.
+"""
+
+import pytest
+import torch
+from compressed_tensors.compressors.model_compressors.model_compressor import (
+    ModelCompressor,
+)
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationConfig,
+    QuantizationScheme,
+    QuantizationStrategy,
+    apply_quantization_config,
+)
+from torch.nn import Linear, Module
+
+
+class SimpleModel(Module):
+    """Simple model for testing"""
+
+    def __init__(self, input_dim=512, hidden_dim=256, output_dim=128):
+        super().__init__()
+        self.layer1 = Linear(input_dim, hidden_dim, bias=False)
+        self.layer2 = Linear(hidden_dim, output_dim, bias=False)
+
+    def forward(self, x):
+        x = self.layer1(x)
+        x = torch.relu(x)
+        x = self.layer2(x)
+        return x
+
+
+def create_asymmetric_quant_config(
+    num_bits=4, strategy=QuantizationStrategy.GROUP, group_size=128
+) -> QuantizationConfig:
+    """Create an asymmetric quantization config"""
+    config_groups = {
+        "group_1": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(
+                num_bits=num_bits,
+                strategy=strategy.value,
+                group_size=(
+                    group_size if strategy == QuantizationStrategy.GROUP else None
+                ),
+                symmetric=False,
+            ),
+        ),
+    }
+    return QuantizationConfig(config_groups=config_groups)
+
+
+@pytest.mark.parametrize(
+    "strategy,group_size",
+    [
+        (QuantizationStrategy.GROUP, 128),
+        (QuantizationStrategy.CHANNEL, None),
+    ],
+)
+def test_end_to_end_asymmetric_quantization(
+    strategy,
+    group_size,
+    mock_per_group_calibration,
+    mock_per_channel_calibration,
+):
+    """
+    Test end-to-end workflow: quantize -> compress -> decompress in memory
+    """
+    model = SimpleModel()
+    original_weights = {
+        "layer1": model.layer1.weight.detach().clone(),
+        "layer2": model.layer2.weight.detach().clone(),
+    }
+
+    quant_config = create_asymmetric_quant_config(
+        num_bits=4, strategy=strategy, group_size=group_size
+    )
+    # Set pack-quantized format for ModelCompressor usage
+    quant_config.format = CompressionFormat.pack_quantized.value
+    apply_quantization_config(model, quant_config)
+
+    if strategy == QuantizationStrategy.GROUP:
+        mock_per_group_calibration(
+            model.layer1, "weight", model.layer1.weight, group_size
+        )
+        mock_per_group_calibration(
+            model.layer2, "weight", model.layer2.weight, group_size
+        )
+    else:
+        mock_per_channel_calibration(model.layer1, "weight", model.layer1.weight)
+        mock_per_channel_calibration(model.layer2, "weight", model.layer2.weight)
+
+    # Compress and decompress in memory using ModelCompressor
+    mc = ModelCompressor(quantization_config=quant_config)
+    mc.compress_model(model)
+
+    # Verify compression created zero-point parameters
+    assert hasattr(model.layer1, "weight_zero_point")
+    assert hasattr(model.layer2, "weight_zero_point")
+    assert model.layer1.weight_zero_point.dtype == torch.int32
+    assert model.layer2.weight_zero_point.dtype == torch.int32
+
+    # Decompress in memory
+    mc.decompress_model(model)
+
+    # Verify decompression restored weights correctly
+    assert model.layer1.weight.shape == original_weights["layer1"].shape
+    assert model.layer2.weight.shape == original_weights["layer2"].shape
+    assert model.layer1.weight.dtype.is_floating_point
+    assert model.layer2.weight.dtype.is_floating_point
+    assert not torch.isnan(model.layer1.weight).any()
+    assert not torch.isnan(model.layer2.weight).any()
+    assert not torch.isinf(model.layer1.weight).any()
+    assert not torch.isinf(model.layer2.weight).any()
+
+
+@pytest.mark.parametrize("num_bits", [4, 8])
+def test_asymmetric_quantization_accuracy(num_bits, mock_per_group_calibration):
+    """
+    Test that asymmetric quantization with zero-point preserves accuracy better
+    than symmetric quantization for biased weight distributions.
+    """
+    shape = (256, 512)
+    biased_weights = torch.randn(shape) + 2.0
+
+    quant_config = create_asymmetric_quant_config(
+        num_bits=num_bits,
+        strategy=QuantizationStrategy.GROUP,
+        group_size=128,
+    )
+    quant_config.format = CompressionFormat.pack_quantized.value
+
+    class SingleLayer(Module):
+        def __init__(self):
+            super().__init__()
+            self.layer = Linear(shape[1], shape[0], bias=False)
+
+    model = SingleLayer()
+    apply_quantization_config(model, quant_config)
+
+    with torch.no_grad():
+        model.layer.weight.copy_(biased_weights)
+    mock_per_group_calibration(model.layer, "weight", model.layer.weight, 128)
+
+    # Compress and decompress in memory using ModelCompressor
+    mc = ModelCompressor(quantization_config=quant_config)
+    mc.compress_model(model)
+    mc.decompress_model(model)
+
+    decompressed_weights = model.layer.weight
+    assert decompressed_weights.shape == shape
+    assert not torch.isnan(decompressed_weights).any()
+    assert not torch.isinf(decompressed_weights).any()
+    threshold = torch.std(torch.rand(shape) - torch.rand(shape))
+    assert torch.std(biased_weights - decompressed_weights) < threshold
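A note on the final similarity check: threshold = torch.std(torch.rand(shape) - torch.rand(shape)) estimates the standard deviation of the difference of two independent Uniform(0, 1) samples, so the assertion is a loose sanity bound on reconstruction error rather than a tight accuracy guarantee. The expected value is easy to work out:

# For independent U1, U2 ~ Uniform(0, 1): Var(U1) = Var(U2) = 1/12,
# so Var(U1 - U2) = 1/12 + 1/12 = 1/6 and std = sqrt(1/6) ≈ 0.408.
import math
print(math.sqrt(1 / 6))  # ~0.408, the approximate threshold used above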
