Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix]add int8 cache dtype when using attention quantization #128

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

Angazenn
Copy link

@Angazenn Angazenn commented Feb 21, 2025

  1. Ascend attention requires int8 kvcache dtype. It is used in initialization of CacheConfig:
if cache_config.cache_dtype == "auto":
    self.dtype = model_config.dtype
else:
    self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

STR_DTYPE_TO_TORCH_DTYPE is defined in vllm.utils:

STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}

Hence we need to update both cache_dtype and STR_DTYPE_TO_TORCH_DTYPE.

  1. In vLLM, attention quant methods are initialized by passing self of layer. When initializing ascend attention quant method, num_heads, head_size, num_kv_heads are required. However, in the codes of Attention.init from vLLM, these member variables are set after initialization of quant method. Hence we have to move these codes ahead of quant method.
# move ahead of quant method.
self.num_heads = num_heads
self.head_size = head_size 
self.num_kv_heads = num_kv_heads

quant_method = quant_config.get_quant_method(
    self, prefix=prefix) if quant_config else None

angazenn added 3 commits February 21, 2025 10:05
# dtype string to torch.dtype. Hence we have to move these codes to here.
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
self.cache_config.cache_dtype = 'int8'
STR_DTYPE_TO_TORCH_DTYPE['int8'] = torch.int8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does vllm originally has entrypoints to resgister custom quant types into STR_DTYPE_TO_TORCH_DTYPE? If not, maybe we need to make this an issue to vllm? @wangxiyuan please check this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we can create a pr to vllm later.

Currenlty. for this kind of change(Monkey patch), please move it to the patch module.

@MengqingCao
Copy link
Contributor

CI failed due to the test path update in vLLM, will fix in #124

Signed-off-by: angazenn <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants