
Commit 9a3e631

Remove static token quantization (#487)
* remove static per token quantization
* earlier validation
* fix tests
* dummy commit
* dummy commit 2
* dummy commit 3
* dummy commit 4
* fix test

Signed-off-by: Kyle Sayers <[email protected]>
1 parent ef898c4 commit 9a3e631

6 files changed: +8, -61 lines changed

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -199,7 +199,7 @@ def initialize_qparams(
         expected_shape = (1,)
 
     elif strategy == QuantizationStrategy.TOKEN:
-        expected_shape = (1, 1)
+        raise ValueError("Cannot perform static token quantization")
 
     elif strategy == QuantizationStrategy.CHANNEL:
         if len(observed_shape) < 2:
```
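Why a hard error instead of a `(1, 1)` placeholder: per-token scales depend on how many tokens arrive in each batch, so they cannot be materialized as a fixed-shape parameter at initialization time. The sketch below illustrates dynamic per-token scale computation; it is illustrative only, and the function name and symmetric int8 assumption are not from this repo:

```python
import torch

def per_token_scale(x: torch.Tensor, qmax: float = 127.0) -> torch.Tensor:
    # x: activations of shape (num_tokens, hidden_dim); num_tokens varies
    # per batch, which is why no static scale parameter can be allocated.
    # Returns one symmetric scale per token, shape (num_tokens, 1).
    return x.abs().amax(dim=-1, keepdim=True) / qmax
```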

src/compressed_tensors/quantization/quant_args.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -264,6 +264,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         actorder = model.actorder
         dynamic = model.dynamic
         observer = model.observer
+        dynamic = model.dynamic
 
         # infer strategy
         if strategy is None:
@@ -279,6 +280,12 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
                 "strategy='group' and group_size = -1 for 'channel'"
             )
 
+        # validate token strategy
+        if strategy == QuantizationStrategy.TOKEN and not dynamic:
+            raise ValueError(
+                "Cannot perform static token quantization, please use `dynamic=True`"
+            )
+
         # validate group strategy
         if strategy == QuantizationStrategy.GROUP:
             if group_size is None or group_size <= 0:
```
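With this check, an unsupported configuration fails at `QuantizationArgs` construction rather than later during parameter initialization. A sketch of the expected behavior, assuming the package re-exports `QuantizationArgs` (pydantic's `ValidationError` subclasses `ValueError`, so the raised error is catchable as below):

```python
from compressed_tensors.quantization import QuantizationArgs

# Dynamic per-token activation quantization is still supported:
QuantizationArgs(strategy="token", dynamic=True)

# Static per-token quantization is now rejected up front:
try:
    QuantizationArgs(strategy="token", dynamic=False)
except ValueError as err:
    print(err)  # message includes "Cannot perform static token quantization"
```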

tests/conftest.py

Lines changed: 0 additions & 21 deletions
```diff
@@ -29,27 +29,6 @@ def _get_dim(dim: int, value: torch.Tensor):
     return reduce_dims
 
 
-@pytest.fixture
-def mock_per_token_calibration():
-    def update_scale_zp(module: torch.nn.Module, base_name: str, value: torch.Tensor):
-        quantization_scheme = getattr(module, "quantization_scheme", None)
-        if not quantization_scheme:
-            # no quantization scheme nothing to do
-            return
-
-        arg_name = "weights" if base_name == "weight" else f"{base_name}_activations"
-        args = getattr(quantization_scheme, arg_name, None)
-
-        dim = _get_dim({0, 1}, value)
-        min_val = torch.amin(value, dim=dim, keepdims=True)
-        max_val = torch.amax(value, dim=dim, keepdims=True)
-        scale, zp = calculate_qparams(min_val, max_val, args)
-        update_parameter_data(module, scale, f"{base_name}_scale")
-        update_parameter_data(module, zp, f"{base_name}_zero_point")
-
-    return update_scale_zp
-
-
 @pytest.fixture
 def mock_per_group_calibration():
     def update_scale_zp(
```
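Note what the removed fixture actually computed: reducing over both dimensions of a 2-D activation yields a single `(1, 1)` min/max pair, so the "static token" scale was one value for the whole tensor, indistinguishable from per-tensor quantization. A quick illustration in plain PyTorch (not repo code):

```python
import torch

x = torch.randn(300, 200)  # (num_tokens, hidden_dim)
min_val = torch.amin(x, dim=(0, 1), keepdim=True)
max_val = torch.amax(x, dim=(0, 1), keepdim=True)
print(min_val.shape, max_val.shape)  # torch.Size([1, 1]) twice: per-tensor, not per-token
```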

tests/test_quantization/lifecycle/test_initialize.py

Lines changed: 0 additions & 7 deletions
```diff
@@ -176,10 +176,6 @@ def test_initialize_module_for_quantization_offloaded(
             QuantizationArgs(strategy="block", block_structure=[2, 4]),
             None,
         ),
-        (
-            QuantizationArgs(strategy="token"),
-            QuantizationArgs(strategy="token"),
-        ),
     ],
 )
 def test_initialize_quantization_parameters(weights, input_activations):
@@ -238,9 +234,6 @@ def test_initialize_quantization_parameters(weights, input_activations):
             # For activations or when block_structure is None
            expected_shape = (1,)
 
-        elif args.strategy == QuantizationStrategy.TOKEN:
-            expected_shape = (1, 1)
-
         if not args.dynamic:
             assert getattr(layer, f"{q_param_name}_scale").shape == expected_shape
             assert getattr(layer, f"{q_param_name}_zero_point").shape == expected_shape
```

tests/test_quantization/test_configs/test_strategies.py

Lines changed: 0 additions & 31 deletions
```diff
@@ -105,34 +105,3 @@ def test_group(
         model_shape[1],
         int(model_shape[0] / group_size),
     )
-
-
-@torch.no_grad
-@pytest.mark.parametrize("input_symmetry", [True, False])
-@pytest.mark.parametrize("weight_symmetry", [True, False])
-@pytest.mark.parametrize("input_shape", [(32, 256), (300, 200), (400, 400)])
-def test_token(
-    mock_per_channel_calibration,
-    mock_per_token_calibration,
-    input_symmetry,
-    weight_symmetry,
-    input_shape,
-):
-    model = Linear(input_shape[1], 256)
-    quant_config = create_config(
-        input_symmetry,
-        weight_symmetry,
-        w_strategy=QuantizationStrategy.CHANNEL,
-        i_strategy=QuantizationStrategy.TOKEN,
-    )
-    apply_quantization_config(model, quant_config)
-
-    inputs = torch.randn(input_shape)
-    mock_per_channel_calibration(model, base_name="weight", value=model.weight)
-    mock_per_token_calibration(model, base_name="input", value=inputs)
-
-    assert model.input_scale.shape == (1, 1)
-    assert model.input_zero_point.shape == (1, 1)
-
-    assert model.weight_scale.shape == (256, 1)
-    assert model.weight_zero_point.shape == (256, 1)
```

tests/test_quantization/test_utils/test_helpers.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -50,7 +50,6 @@
                 ]
             ),
         ),
-        (True, "token", torch.Size([1, 1])),
     ],
 )
 def test_calculate_qparams(keepdims, strategy, exp_shape):
```
