
Commit 4cc5ace

better shapes
Signed-off-by: Kyle Sayers <[email protected]>
1 parent bfa1cd1 commit 4cc5ace

File tree

2 files changed: +64, −7 lines

src/compressed_tensors/quantization/lifecycle/initialize.py
src/compressed_tensors/utils/helpers.py


src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 13 additions & 6 deletions
@@ -41,10 +41,11 @@
     disable_hf_hook,
     get_execution_device,
     get_head_dim,
+    get_num_attn_heads,
+    get_num_kv_heads,
     register_offload_parameter,
 )
 from torch.nn import Module, Parameter
-from transformers import PretrainedConfig


 __all__ = [
@@ -292,17 +293,23 @@ def initialize_attn_qparams(

     _validate_attention_scheme(scheme)

-    config: PretrainedConfig = getattr(kv_cache, "config")
+    # extract shapes from config
+    config = kv_cache.config
+    num_attn_heads = get_num_attn_heads(config)
+    num_kv_heads = get_num_kv_heads(config)
     head_dim = get_head_dim(config)
-    observed_shape = (head_dim,)  # (batch_size, num_attention_heads, slen, head_dim)
+
+    # (batch_size, num_heads, slen, head_dim)
+    q_observed_shape = (num_attn_heads, None, head_dim)
+    kv_observed_shape = (num_kv_heads, None, head_dim)
     observed_dtype = next(module.parameters()).dtype

     if impl is not None:
         initialize_qparams(
             module,
             "q",
             scheme.input_activations,
-            observed_shape=observed_shape,
+            observed_shape=q_observed_shape,
             observed_dtype=observed_dtype,
             force_zero_point=force_zero_point,
         )
@@ -312,15 +319,15 @@ def initialize_attn_qparams(
             module,
             "k",
             scheme.input_activations,
-            observed_shape=observed_shape,
+            observed_shape=kv_observed_shape,
             observed_dtype=observed_dtype,
             force_zero_point=force_zero_point,
         )
         initialize_qparams(
             module,
             "v",
             scheme.input_activations,
-            observed_shape=observed_shape,
+            observed_shape=kv_observed_shape,
             observed_dtype=observed_dtype,
             force_zero_point=force_zero_point,
         )
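
To make the shape change concrete, here is a small illustrative sketch. The config values below are hypothetical (a Llama-style grouped-query-attention model) and are not taken from this commit; they only show how the new per-projection observed shapes relate to the attention tensors the observers see.

# Illustrative only: hypothetical GQA-style config values, not from this commit.
hidden_size = 4096
num_attention_heads = 32   # query heads
num_key_value_heads = 8    # grouped-query attention uses fewer key/value heads
head_dim = hidden_size // num_attention_heads  # 128

# Attention states have shape (batch_size, num_heads, seq_len, head_dim), so the
# query observer sees 32 heads while the key/value observers see only 8.
q_observed_shape = (num_attention_heads, None, head_dim)    # (32, None, 128)
kv_observed_shape = (num_key_value_heads, None, head_dim)   # (8, None, 128)

# The previous single shape, (head_dim,), could not express this difference.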

src/compressed_tensors/utils/helpers.py

Lines changed: 51 additions & 1 deletion
@@ -45,6 +45,8 @@
     "unpack_bitmasks",
     "patch_attr",
     "ParameterizedDefaultDict",
+    "get_num_attn_heads",
+    "get_num_kv_heads",
     "get_head_dim",
 ]

@@ -399,12 +401,60 @@ def get(self, *args, factory_kwargs: Mapping = MappingProxyType({})) -> Any:
         return self[args]


+def get_num_attn_heads(config: PretrainedConfig) -> int:
+    """
+    Get the number of attention heads used by a model
+
+    :param config: model config
+    :return: num_attention_heads of model
+    """
+    if hasattr(config, "num_attention_heads"):
+        return config.num_attention_heads
+
+    elif hasattr(config, "hidden_size") and hasattr(config, "head_dim"):
+        return config.hidden_size // config.head_dim
+
+    else:
+        raise ValueError(
+            "Cannot determine num_attention_heads from config. Config must define "
+            "either `num_attention_heads` or both `hidden_size` and `head_dim`. "
+            f"{config}"
+        )
+
+
+def get_num_kv_heads(config: PretrainedConfig) -> int:
+    """
+    Get the number of key-value attention heads used by a model
+
+    :param config: model config
+    :return: num_key_value_heads of model
+    """
+    if hasattr(config, "num_key_value_heads"):
+        return config.num_key_value_heads
+
+    else:
+        raise ValueError(
+            "Cannot determine num_key_value_heads from config. Config must define "
+            f"`num_key_value_heads`. {config}"
+        )
+
+
 def get_head_dim(config: PretrainedConfig) -> int:
+    """
+    Get the number of dimensions used by the attention heads of a model
+
+    :param config: model config
+    :return: head_dim of model
+    """
     if hasattr(config, "head_dim"):
         return config.head_dim

     elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
         return config.hidden_size // config.num_attention_heads

     else:
-        raise ValueError()
+        raise ValueError(
+            "Cannot determine head_dim from config. Config must define "
+            "either `head_dim` or both `hidden_size` and `num_attention_heads`. "
+            f"{config}"
+        )
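
A rough usage sketch for the new helpers follows. The config object below is a stand-in with illustrative Llama-like values; since the helpers only rely on attribute access, a SimpleNamespace is enough here, though in practice a transformers PretrainedConfig would be passed.

from types import SimpleNamespace

from compressed_tensors.utils.helpers import (
    get_head_dim,
    get_num_attn_heads,
    get_num_kv_heads,
)

# Stand-in for a transformers PretrainedConfig; attribute values are illustrative.
config = SimpleNamespace(
    hidden_size=4096,
    num_attention_heads=32,
    num_key_value_heads=8,
)

assert get_num_attn_heads(config) == 32
assert get_num_kv_heads(config) == 8
# head_dim is absent, so get_head_dim falls back to hidden_size // num_attention_heads
assert get_head_dim(config) == 128

# A config missing the required attributes now raises a ValueError that names
# them, rather than the bare ValueError() previously raised by get_head_dim.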
