
Commit 4bc92ef

fix kv cache serialization, add tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 0674268 commit 4bc92ef

File tree

4 files changed: +82 -6 lines changed

src/compressed_tensors/quantization/quant_config.py

Lines changed: 18 additions & 6 deletions
@@ -171,6 +171,7 @@ def from_pretrained(
         :param model: model to calculate quantization scheme of
         :return: filled out QuantizationScheme for the input model
         """
+        from compressed_tensors.modeling import IMPL_ATTR, KV_CACHE_ATTR
         from compressed_tensors.quantization.lifecycle.initialize import (
             is_attention_module,
         )
@@ -196,24 +197,35 @@ def from_pretrained(
         for name, submodule in model.named_modules():
             layer_type: str = module_type(submodule)
 
-            if is_module_quantized(submodule):
+            # add config group if quantized non-attention or attention quant
+            has_config_group = is_module_quantized(submodule) and (
+                not is_attention_module(submodule) or hasattr(submodule, IMPL_ATTR)
+            )
+            # only add kvcache if quant attention (which always implies kvcache)
+            has_kv_cache = is_module_quantized(submodule) and is_attention_module(
+                submodule
+            )
+
+            if has_config_group:
                 # add to running set of schemes/layer_type_names
                 model_status = getattr(submodule, "quantization_status", model_status)
                 quantization_type_names.add(layer_type)
                 if submodule.quantization_scheme not in quantization_schemes:
                     quantization_schemes.append(submodule.quantization_scheme)
 
-                # attention quantization implies kv cache quantization
-                if is_attention_module(submodule):
-                    kv_cache_scheme = submodule.quantization_scheme.input_activations
+            if has_kv_cache:
+                model_status = getattr(submodule, "quantization_status", model_status)
+                kv_cache_scheme = submodule.quantization_scheme.input_activations
 
-            else:
+            if not has_config_group:
                 # add non-quantized layers to the ignore list
                 if layer_type not in ignore:
                     ignore[layer_type] = []
                 ignore[layer_type].append(name)
 
-        if len(quantization_schemes) == 0: # No quantized layers
+        if (
+            len(quantization_schemes) == 0 and kv_cache_scheme is None
+        ): # No quantized layers
             return None
 
         # create ignore list, only include layers whose class has ever been targeted
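For readers skimming the diff, the new branching can be summarized by a small standalone sketch (the helper name `classify` is hypothetical; the real code evaluates the same expressions inline in the `named_modules()` loop via `is_module_quantized`, `is_attention_module`, and `hasattr(submodule, IMPL_ATTR)`): a quantized attention module without a registered attention implementation no longer produces a config group, but it still contributes the kv cache scheme.

# Hypothetical standalone sketch of the new branching; not part of the library.
def classify(quantized: bool, is_attention: bool, has_impl_attr: bool):
    # a module contributes a config group unless it is quantized attention
    # without a registered attention implementation
    has_config_group = quantized and (not is_attention or has_impl_attr)
    # any quantized attention module contributes the kv cache scheme
    has_kv_cache = quantized and is_attention
    return has_config_group, has_kv_cache

assert classify(True, False, False) == (True, False)    # quantized Linear
assert classify(True, True, True) == (True, True)       # quantized attention impl
assert classify(True, True, False) == (False, True)     # kv-cache-only attention
assert classify(False, False, False) == (False, False)  # unquantized module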

tests/test_modeling/test_attention_and_cache.py

Lines changed: 3 additions & 0 deletions
@@ -56,7 +56,10 @@ def test_attention_cache():
     k_called = [False for _ in range(len(layers))]
     v_called = [False for _ in range(len(layers))]
 
+    # apply attention quantization
     _apply_attention(model, layers, q_called, k_called, v_called)
+
+    # check attention quantization
     outputs = model(**inputs)
     assert torch.equal(outputs.logits, true_outputs.logits)
     assert all(q_called) and all(k_called) and all(v_called)

tests/test_quantization/lifecycle/test_apply.py

Lines changed: 58 additions & 0 deletions
@@ -131,6 +131,64 @@ def test_apply_quantization_config_tinyllama():
     )
 
 
+@pytest.mark.parametrize(
+    "config",
+    [
+        QuantizationConfig(
+            config_groups={
+                "linear": QuantizationScheme(
+                    targets=["Linear"],
+                    input_activations=QuantizationArgs(
+                        num_bits=8, type="float", strategy="tensor"
+                    ),
+                )
+            }
+        ),
+        QuantizationConfig(
+            config_groups={
+                "linear": QuantizationScheme(
+                    targets=["Linear"],
+                    input_activations=QuantizationArgs(
+                        num_bits=8, type="float", strategy="tensor"
+                    ),
+                )
+            },
+            ignore=[
+                "model.layers.0.self_attn.q_proj",
+                "model.layers.1.self_attn.k_proj",
+                "model.layers.2.self_attn.v_proj",
+            ],
+        ),
+        QuantizationConfig(
+            config_groups={},
+            kv_cache_scheme=QuantizationArgs(
+                num_bits=8, type="float", strategy="tensor"
+            ),
+        ),
+        QuantizationConfig(
+            config_groups={
+                "attention": QuantizationScheme(
+                    targets=["LlamaAttention"],
+                    input_activations=QuantizationArgs(
+                        num_bits=8, type="float", strategy="tensor"
+                    ),
+                )
+            },
+            kv_cache_scheme=QuantizationArgs(
+                num_bits=8, type="float", strategy="tensor"
+            ),
+        ),
+    ],
+)
+def test_from_pretrained(config: QuantizationConfig):
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
+    apply_quantization_config(model, config)
+    _config = QuantizationConfig.from_pretrained(model)
+    assert list(_config.config_groups.values()) == list(config.config_groups.values())
+    assert _config.kv_cache_scheme == config.kv_cache_scheme
+    assert _config.ignore == config.ignore
+
+
 def test_serialize_config_tinyllama():
     quant_config = get_sample_tinyllama_quant_config()
     model = get_tinyllama_model()
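The kv-cache-only parametrization above is the case the commit message calls "kv cache serialization". A minimal round-trip sketch, reusing only the model name and APIs already used in these tests (not a new public interface):

# Hedged sketch of the kv-cache-only round-trip covered by the new test:
# with the quant_config.py change above, from_pretrained no longer bails out
# when config_groups is empty but a kv cache scheme is present.
from transformers import AutoModelForCausalLM

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    apply_quantization_config,
)

model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
config = QuantizationConfig(
    config_groups={},
    kv_cache_scheme=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
)
apply_quantization_config(model, config)

reconstructed = QuantizationConfig.from_pretrained(model)
assert reconstructed is not None
assert list(reconstructed.config_groups.values()) == list(config.config_groups.values())
assert reconstructed.kv_cache_scheme == config.kv_cache_scheme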

tests/test_quantization/test_quant_config.py

Lines changed: 3 additions & 0 deletions
@@ -16,11 +16,14 @@
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_FORMAT,
     DEFAULT_QUANTIZATION_METHOD,
+    QuantizationArgs,
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
+    apply_quantization_config,
 )
 from pydantic import ValidationError
+from transformers import AutoModelForCausalLM
 
 
 def test_basic_config():
