[Bug]: KV cache quantization failure in calibrate_kv_cache_output_hook k_scale = kv_cache.k_scales[module.layer_idx] #1881

@mratsim

Description

โš™๏ธ Your current environment

The output of python collect_env.py
### Environment Information ###
Operating System: `Linux-6.16.8-arch3-1-x86_64-with-glibc2.42`
Python Version: `3.12.9 (main, Mar 17 2025, 21:01:58) [Clang 20.1.0 ]`
llm-compressor Version: `0.7.2a20250825`
compressed-tensors Version: `0.11.1a20250912`
transformers Version: `4.56.2`
torch Version: `2.8.0+cu129`
CUDA Devices: `['NVIDIA RTX PRO 6000 Blackwell Workstation Edition', 'NVIDIA RTX PRO 6000 Blackwell Workstation Edition']`
AMD Devices: `None`

๐Ÿ› Describe the bug

I am trying to quantize the KV cache of ByteDance-Seed/Seed-OSS-36B-Instruct, which can have up to a 512K context (128 GB of KV cache in FP16).
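
For context, the 128 GB figure lines up with the usual KV cache size formula. A quick sanity check (the config values here — 64 layers, 8 KV heads, head dim 128 — are my assumptions about Seed-OSS-36B, not something verified in this report):

# Rough KV cache footprint at full context in FP16.
num_layers = 64        # assumed
num_kv_heads = 8       # assumed
head_dim = 128         # assumed
seq_len = 512 * 1024   # 512K context
bytes_per_elem = 2     # FP16

# K and V each store num_kv_heads * head_dim values per token per layer.
total_bytes = 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem
print(f"{total_bytes / 2**30:.0f} GiB")  # -> 128 GiB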

However, I get the following error:

Full error text:

Traceback (most recent call last):
  File "[...]/.venv/lib/python3.12/site-packages/llmcompressor/pipelines/sequential/helpers.py", line 73, in forward
    outputs = forward_fn(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 19, in forward
  File "[...]/.venv/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/transformers/models/seed_oss/modeling_seed_oss.py", line 259, in forward
    hidden_states, _ = self.self_attn(
                       ^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
           ^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1840, in inner
    hook_result = hook(self, args, result)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/llmcompressor/modifiers/utils/hooks.py", line 93, in wrapped_hook
    return hook(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.12/site-packages/llmcompressor/modifiers/quantization/calibration.py", line 260, in calibrate_kv_cache_output_hook
    k_scale = kv_cache.k_scales[module.layer_idx]
              ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
IndexError: list index out of range

Source:

def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
    """
    Hook to update k_scale and v_scale parameters when running kv_cache quantization.
    """
    kv_cache = getattr(module, "kv_cache")
    k_scale = kv_cache.k_scales[module.layer_idx]
    v_scale = kv_cache.v_scales[module.layer_idx]
    update_offload_parameter(module, KVCacheScaleType.KEY.value, k_scale)
    update_offload_parameter(module, KVCacheScaleType.VALUE.value, v_scale)

The error is quite similar to #1295.
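
Judging from the traceback, kv_cache.k_scales contains fewer entries than module.layer_idx by the time the hook fires on a later layer. To confirm the mismatch while debugging, the hook can be wrapped before oneshot() is called — a hypothetical sketch only (patched_hook and the error message are mine, not part of llm-compressor):

import llmcompressor.modifiers.quantization.calibration as calib

_original_hook = calib.calibrate_kv_cache_output_hook

def patched_hook(module, _args, _output):
    # Surface the index/length mismatch instead of a bare IndexError.
    kv_cache = getattr(module, "kv_cache")
    if module.layer_idx >= len(kv_cache.k_scales):
        raise RuntimeError(
            f"layer_idx={module.layer_idx}, but only "
            f"{len(kv_cache.k_scales)} k_scale(s) recorded so far"
        )
    return _original_hook(module, _args, _output)

calib.calibrate_kv_cache_output_hook = patched_hook

Note that this only takes effect if it runs before the modifier registers its hooks.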

This is the quantization script I'm using:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)

CALIBRATION_DATASET = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
SHUFFLE_SEED = 42
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 4096

MODEL_ID = "ByteDance-Seed/Seed-OSS-36B-Instruct"
MODEL_OUT = MODEL_ID.split("/")[1] + "-FP8-KV"

# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

model.generation_config.do_sample = True

# Dataset processing
ds = load_dataset(CALIBRATION_DATASET, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=SHUFFLE_SEED)

def process_and_tokenize(example):
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return tokenizer(text, padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)

ds = ds.map(process_and_tokenize, remove_columns=ds.column_names)

recipe = [
    QuantizationModifier(
        ignore=["lm_head"],
        # DeepSeek V3-style block quantization for weights + dynamic per-token-group activation quantization
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=QuantizationArgs(
                    num_bits=8,
                    type=QuantizationType.FLOAT,
                    dynamic=False,
                    symmetric=True,
                    strategy=QuantizationStrategy.BLOCK,
                    block_structure=[128, 128],
                ),
                input_activations=QuantizationArgs(
                    num_bits=8,
                    type=QuantizationType.FLOAT,
                    strategy=QuantizationStrategy.GROUP,
                    symmetric=True,
                    dynamic=True,
                    observer=None,
                    group_size=128,
                ),
            ),
        },
        kv_cache_scheme=QuantizationArgs(
            num_bits=8,
            type=QuantizationType.FLOAT,
            dynamic=False,
            symmetric=True,
            strategy=QuantizationStrategy.TENSOR,
        ),
    )
]

oneshot(
    # pipeline="basic",
    model=model,
    recipe=recipe,
    dataset=ds,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Save to disk in compressed-tensors format.
model.save_pretrained(MODEL_OUT, save_compressed=True)
tokenizer.save_pretrained(MODEL_OUT)
print(f'SUCCESS: files saved in {MODEL_OUT}')

# Testing
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
    model.device
)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

And my pyproject.toml for uv run:

[project]
name = "quantizers"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "compressed-tensors>=0.11.0",
    "llmcompressor >= 0.7.1",
    "torch >= 2.8.0",
    "transformers >= 4.56.2",
    "zstandard>=0.23.0",
    "mistral-common >= 1.6.2",
    "huggingface_hub",
    "accelerate",
    "fla-core",
    "flash-linear-attention",
    "causal-conv1d",
]

[tool.uv.sources]
torch = [{ index = "pytorch-cu129"}]
torchaudio = [{ index = "pytorch-cu129"}]
torchvision = [{ index = "pytorch-cu129"}]

# uv pip install -U --prerelease allow torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129
[[tool.uv.index]]
name = "pytorch-cu129"
url = "https://download.pytorch.org/whl/cu129"
explicit = true

๐Ÿ› ๏ธ Steps to reproduce

No response
