feat: extend the num head terms for fp8 calibration #972
base: main
Conversation
Summary of Changes
Hello @blueswhen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly improves the flexibility of FP8 KV cache calibration by allowing calibration configurations with different, but proportionally related, numbers of attention heads to be used. It introduces logic to adapt calibration scales on-the-fly and relaxes validation constraints, thereby removing previous operational limitations related to Tensor Parallelism.
Highlights
- Flexible FP8 Calibration: The FP8 KV cache calibration mechanism has been enhanced to support more flexible configurations regarding the number of attention heads. Calibration scales can now be adapted (repeated or subsampled) when the configured number of heads (`cfg["num_head"]`) is a multiple or divisor of the model's total heads (`self.total_head_num`).
- Scale Adaptation Logic: New logic was introduced in `OfflineFP8QuantKVManager` to automatically adjust the calibration scales. This ensures that scales from a calibration config with a different, but proportionally related, number of heads can be correctly applied to the model (see the sketch after this list).
- Relaxed Configuration Validation: The strict equality check for `num_head` during FP8 configuration loading has been relaxed. The system now validates whether `cfg["num_head"]` is a multiple or divisor of `self.total_head_num`, providing greater flexibility.
- Removed Tensor Parallelism Restriction: A previous documentation note in `api_cli.py` that restricted Tensor Parallelism (TP) size to be no more than the head number during calibration has been removed, indicating that this limitation no longer applies thanks to the improved calibration flexibility.
- Updated Test Configurations: Existing test configurations for both per-head and per-tensor FP8 calibration have been updated to reflect the new capability, specifically by doubling the `num_head` in the calibration JSONs and updating the corresponding scale data.
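To make the adaptation concrete, here is a minimal sketch of the repeat/subsample idea and the relaxed multiple-or-divisor validation described above. It is illustrative only: the function name, tensor shape, and example values are assumptions, not lightllm's actual implementation.

```python
import torch


def adapt_scales(scales: torch.Tensor, cfg_num_head: int, total_head_num: int) -> torch.Tensor:
    """Illustrative sketch: adapt calibration scales when the calibration config
    was produced with a different (but proportionally related) head count than
    the running model."""
    if cfg_num_head == total_head_num:
        return scales
    if cfg_num_head % total_head_num == 0:
        # Config has more heads: keep every `factor`-th scale along the head dim.
        factor = cfg_num_head // total_head_num
        return scales[..., ::factor].contiguous()
    if total_head_num % cfg_num_head == 0:
        # Config has fewer heads: repeat each scale `factor` times along the head dim.
        factor = total_head_num // cfg_num_head
        return torch.repeat_interleave(scales, factor, dim=-1).contiguous()
    raise ValueError("num_head in the calibration config must be a multiple or divisor of the model's total head num")


# Example: calibration done with 8 heads, model running with 16 total heads.
calib_scales = torch.rand(2, 8)  # hypothetical shape: [layer_num, num_head]
adapted = adapt_scales(calib_scales, cfg_num_head=8, total_head_num=16)
print(adapted.shape)  # torch.Size([2, 16])
```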
Code Review
This pull request extends the fp8 calibration to support a different number of attention heads between the calibration data and the model configuration. The changes are mostly in `lightllm/common/offline_fp8_quant_mem_manager.py` to handle up-sampling and down-sampling of quantization scales. While the intent is clear and the new scaling logic seems correct, I've found a critical issue in the implementation that could lead to incorrect behavior. The conditional logic for applying scale adjustments improperly conflates the `enable_fa3` flag with the quantization type, which could cause incorrect scales to be used under certain configurations. I've provided a detailed comment and a suggested fix for this.
elif cfg["num_head"] > self.total_head_num: | ||
factor = cfg["num_head"] // self.total_head_num | ||
self.scales = self.scales[..., ::factor].contiguous() | ||
elif cfg["num_head"] < self.total_head_num: | ||
factor = self.total_head_num // cfg["num_head"] | ||
self.scales = torch.repeat_interleave(self.scales, factor, dim=-1).contiguous() |
The current `if/elif` structure seems to conflate the `enable_fa3` flag with the quantization type, which can lead to incorrect behavior. For instance, if `enable_fa3` is `False` but the quantization type from the config is `per_head`, the code will incorrectly execute `torch.repeat_interleave`, which is likely intended only for `per_tensor` quantization. This would erroneously expand the per-head scales.
elif cfg["num_head"] > self.total_head_num: | |
factor = cfg["num_head"] // self.total_head_num | |
self.scales = self.scales[..., ::factor].contiguous() | |
elif cfg["num_head"] < self.total_head_num: | |
factor = self.total_head_num // cfg["num_head"] | |
self.scales = torch.repeat_interleave(self.scales, factor, dim=-1).contiguous() | |
quant_type = cfg.get("quant_type", "per_tensor") | |
if quant_type == "per_tensor": | |
if not get_env_start_args().enable_fa3: | |
self.scales = torch.repeat_interleave(self.scales, head_num, dim=-1) | |
elif quant_type == "per_head": | |
if get_env_start_args().enable_fa3: | |
if cfg["num_head"] > self.total_head_num: | |
factor = cfg["num_head"] // self.total_head_num | |
self.scales = self.scales[..., ::factor].contiguous() | |
elif cfg["num_head"] < self.total_head_num: | |
factor = self.total_head_num // cfg["num_head"] | |
self.scales = torch.repeat_interleave(self.scales, factor, dim=-1).contiguous() |
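To illustrate the reviewer's concern, here is a toy reproduction on made-up shapes (not lightllm's real scale layout): running per-head scales through the per-tensor expansion path blows up the head dimension by a factor of `head_num`.

```python
import torch

head_num = 4
# Hypothetical per-head scales: one scale per layer per head.
per_head_scales = torch.rand(2, head_num)  # shape [layer_num, head_num]

# What the per_tensor branch does (intended for scales with a singleton head dim):
wrongly_expanded = torch.repeat_interleave(per_head_scales, head_num, dim=-1)
print(wrongly_expanded.shape)  # torch.Size([2, 16]) -- no longer one scale per head

# Dispatching on quant_type first avoids this: per-head scales are only repeated
# or subsampled when the calibration head count differs from the model's total.
```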
```diff
@@ -25,7 +25,6 @@ def __init__(

         self.qmax = torch.finfo(torch.float8_e4m3fn).max
         self.qmin = torch.finfo(torch.float8_e4m3fn).min
         self.layer_num = layer_num
         self.total_head_num = head_num * dist.get_world_size() if dist.is_initialized() else head_num
```
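For context, a small sketch of what this line computes under tensor parallelism: the calibration config's `num_head` is compared against the model-wide head count, so the per-rank `head_num` is multiplied by the TP world size. The wrapper function below is purely illustrative.

```python
import torch.distributed as dist


def compute_total_head_num(head_num: int) -> int:
    """Illustrative: `head_num` is the per-rank KV head count; the model-wide
    total is per-rank heads times the tensor-parallel world size."""
    if dist.is_initialized():
        return head_num * dist.get_world_size()
    return head_num

# e.g. 8 KV heads per rank with TP=4 -> total_head_num == 32 once dist is initialized
```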
```diff
@@ -175,7 +175,6 @@ def make_argument_parser() -> argparse.ArgumentParser:
     export_fp8kv_calibration record and export kv cache quant calibration results to a json file.
     It can be used for llama and qwen model.
     Calibration need to disable cudagraph and use fa3 or flashinfer backend.
-    Tp size must no more than head num when calibration.
     ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel;
```
No description provided.