Commit 8d47349

fix zero points initialize
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 6d0a140 commit 8d47349


tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 10 additions & 9 deletions
@@ -174,23 +174,21 @@ def create_quantization_config(bits=8, type="int", strategy="tensor"):
     ],
 )
 def test_composability(sparsity_config, quantization_config):
-    model_compressor = ModelCompressor(
-        sparsity_config=sparsity_config, quantization_config=quantization_config
-    )
+    model_compressor = ModelCompressor(sparsity_config, quantization_config)
     model: DummyLinearModel = _get_fake_oneshot_sparse_quantized_model(
-        sparsity_config=sparsity_config, quantization_config=quantization_config
+        quantization_config,
+        sparsity_config,
     )
     model = model.to(torch.float32)

     # does both sparse and quantization compression
+    original_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
     model_compressor.compress_model(model)
-    compressed_state_dict = {key: value.clone() for key, value in model.state_dict()}
-
     model_compressor.decompress_model(model)
-    decompressed_state_dict = {key: value.clone() for key, value in model.state_dict()}
+    decompressed_state_dict = {k: v.clone() for k, v in model.state_dict().items()}

     # check that the decompressed model is the same as the original model
-    _check_state_dicts(compressed_state_dict, decompressed_state_dict)
+    _check_state_dicts(original_state_dict, decompressed_state_dict)


 class TwoLayerModel(nn.Module):
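
Two things change in this hunk: the test now captures the model's state dict before compression and asserts that the compress/decompress round trip restores it, and the dict comprehensions now iterate state_dict().items(). The latter matters because iterating a dict directly yields only its keys, so the old two-target unpacking fails at runtime. A minimal sketch of that failure mode, using a plain dict as a stand-in for a real PyTorch state dict:

state_dict = {"linear.weight": "w-tensor", "linear.bias": "b-tensor"}

# Old form: iterating a dict yields keys (strings), so the two-target
# unpacking tries to split each key and raises ValueError.
try:
    cloned = {k: v for k, v in state_dict}
except ValueError as err:
    print("fails like the old test code:", err)

# Fixed form: .items() yields (key, value) pairs, as intended.
cloned = {k: v for k, v in state_dict.items()}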
@@ -250,6 +248,9 @@ def _get_fake_oneshot_sparse_quantized_model(quantization_config, sparsity_config
         args=quantization_args,
     )

+    if quantization_args.symmetric:
+        zero_point = None  # do not include in model
+
     fake_oneshot_model = DummyLinearModel(quantized_weights, scale, zero_point)
     fake_oneshot_model.linear.quantization_scheme = quantization_config.config_groups[
         "group_0"
@@ -302,7 +303,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
     )
     # Only stores dtype because meta model does not store values
     reference_compressor.compress_model(cpu_model)
-    expected = {k: v.dtype for k, v in cpu_model.state_dict()}
+    expected = {k: v.dtype for k, v in cpu_model.state_dict().items()}

     # Load model on meta device
     meta_model = AutoModelForCausalLM.from_pretrained(
