
[Bug]: k_scale and v_scale are zero after KV cache FP8 quantization #1928

@Sekri0

Description


โš™๏ธ Your current environment

The output of `python collect_env.py`:

```
### Environment Information ###
Operating System: `Linux-5.10.134-14.zncgsl6.x86_64-x86_64-with-glibc2.35`
Python Version: `3.10.0 | packaged by conda-forge | (default, Nov 20 2021, 02:24:10) [GCC 9.4.0]`
llm-compressor Version: `0.8.1`
compressed-tensors Version: `0.12.2`
transformers Version: `4.56.2`
torch Version: `2.8.0`
CUDA Devices: `['NVIDIA H20', 'NVIDIA H20', 'NVIDIA H20', 'NVIDIA H20']`
AMD Devices: `None`
```

๐Ÿ› Describe the bug

I used the code below to quantize the model's KV cache to FP8, intending to serve the resulting model in vLLM (a usage sketch follows the script), but the saved k_scale and v_scale values are all ZERO.

My code:

```python

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot

MODEL_ID = "/mnt/home/model/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

NUM_CALIBRATION_SAMPLES = 512 # 512 samples is a good starting point
MAX_SEQUENCE_LENGTH = 2048

ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))

def process_and_tokenize(example):
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return tokenizer(
        text,
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )

ds = ds.map(process_and_tokenize, remove_columns=ds.column_names)

recipe = """
quant_stage:
quant_modifiers:
QuantizationModifier:
kv_cache_scheme:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
"""

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))

SAVE_DIR = MODEL_ID.split("/")[-1] + "-FP8-KV"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
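For reference, this is roughly how the saved model is meant to be used; a minimal sketch, assuming vLLM's `kv_cache_dtype="fp8"` engine argument, which is what makes vLLM pick up the checkpoint's k_scale/v_scale values (the model path is just the SAVE_DIR from above):

```python
# Sketch of the intended vLLM usage (not run successfully here, since the
# saved scales are zero). kv_cache_dtype="fp8" enables the FP8 KV cache and
# uses the k_scale/v_scale values stored in the checkpoint.
from vllm import LLM, SamplingParams

llm = LLM(model="Llama-3.1-8B-Instruct-FP8-KV", kv_cache_dtype="fp8")
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=100))
print(outputs[0].outputs[0].text)
```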

[Screenshot: the k_scale and v_scale tensors in the saved checkpoint are all zero]
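The same check can be done directly on the saved safetensors shards; a minimal sketch, assuming the scales are stored as tensors whose names end in `k_scale`/`v_scale` (the compressed-tensors naming convention):

```python
# Inspect the saved scales (sketch; the "k_scale"/"v_scale" name suffixes are
# an assumption based on the compressed-tensors checkpoint format).
import glob
from safetensors import safe_open

SAVE_DIR = "Llama-3.1-8B-Instruct-FP8-KV"
for shard in sorted(glob.glob(f"{SAVE_DIR}/*.safetensors")):
    with safe_open(shard, framework="pt") as f:
        for name in f.keys():
            if name.endswith(("k_scale", "v_scale")):
                print(name, f.get_tensor(name).item())  # every value prints as 0.0
```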

๐Ÿ› ๏ธ Steps to reproduce

No response

Metadata

Labels

bug (Something isn't working)
