62 changes: 61 additions & 1 deletion modelopt/torch/export/unified_export_megatron.py
@@ -50,6 +50,7 @@
from .plugins.mcore_custom import CustomModuleMapping, save_safetensors
from .plugins.megatron_importer import GPTModelImporter
from .quant_utils import (
    _prefix_wildcard_summarize_exclude_modules,
    get_activation_scaling_factor,
    get_kv_cache_dtype,
    get_quantization_format,
@@ -320,6 +321,9 @@ def save_pretrained(
                pass

        if is_last_stage_main_rank and quantization is not None:
            # Dynamically detect unquantized modules from state_dict
            exclude_modules = self._get_unquantized_layers(state_dict)
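            # Illustrative example (assumed names, not from this PR): if lm_head and all
            # modules in the last decoder layer lack weight_scale entries, this could
            # return something like ["lm_head", "model.layers.47.*"].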

            hf_quant_config = {
                "producer": {
                    "name": "modelopt",
@@ -328,7 +332,7 @@
"quantization": {
"quant_algo": quantization,
"kv_cache_quant_algo": kv_cache_quantization,
"exclude_modules": ["lm_head"],
"exclude_modules": exclude_modules,
},
}
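            # Illustrative serialized form (assumed values, e.g. an FP8 checkpoint):
            #   {"producer": {"name": "modelopt", ...},
            #    "quantization": {"quant_algo": "FP8", "kv_cache_quant_algo": null,
            #                     "exclude_modules": ["lm_head", "model.layers.47.*"]}}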
            with open(save_directory + "/hf_quant_config.json", "w") as f:
@@ -552,6 +556,62 @@ def _get_quantized_state(
    def _get_quantization_format(self, module: torch.nn.Module):
        return get_quantization_format(module)

    def _get_unquantized_layers(self, state_dict: dict[str, torch.Tensor]) -> list[str]:
Collaborator comment: this is a bit hacky though and we have to maintain the list. @ChenhanYu WDYT?

"""Detect unquantized modules from state_dict by checking for weight_scale.

A module is considered quantized if it has a corresponding weight_scale key.
Only considers linear layer modules (proj, gate, experts, etc.).
Excludes modules that are never quantized by default (embed_tokens, layernorms, lm_head).

Args:
state_dict: The exported state dict with HF-style keys.

Returns:
List of unquantized module patterns like ["model.layers.0.*", "model.layers.47.*"].
"""
        # Linear layer patterns that can be quantized
        quantizable_patterns = (
            "_proj",  # q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
            "experts",  # MoE experts
            "lm_head",
        )
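        # Illustrative matches (assumed HF-style names): "model.layers.0.self_attn.q_proj",
        # "model.layers.3.mlp.experts.7.up_proj", "lm_head".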

        # Patterns assumed to be always excluded (never quantized by default)
        never_quantized_patterns = (
            "embed_tokens",
            "layernorm",
            "input_layernorm",
            "post_attention_layernorm",
            "model.norm",
        )

        # Find all linear modules that have weight_scale (quantized)
        quantized_modules = set()
        for key in state_dict:
            if key.endswith(".weight_scale"):
                module_name = key[:-13]  # len(".weight_scale") = 13
                quantized_modules.add(module_name)
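        # e.g. (illustrative key) "model.layers.0.self_attn.q_proj.weight_scale"
        # -> "model.layers.0.self_attn.q_proj"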

        # Find all linear modules that have weights and could be quantized
        potentially_quantizable = set()
        for key in state_dict:
            if key.endswith(".weight"):
                module_name = key[:-7]  # len(".weight") = 7
                # Skip modules that are never quantized
                if any(p in module_name for p in never_quantized_patterns):
                    continue
                # Only consider quantizable linear layers
                if any(p in module_name for p in quantizable_patterns):
                    potentially_quantizable.add(module_name)
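        # e.g. (illustrative) "model.embed_tokens.weight" is skipped, while
        # "model.layers.1.mlp.down_proj.weight" is kept as potentially quantizable.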

        # Find unquantized linear modules (have weight but no weight_scale)
        unquantized_modules = potentially_quantizable - quantized_modules

        # Use the prefix wildcard summarization to create compact exclude patterns
        return sorted(
            _prefix_wildcard_summarize_exclude_modules(unquantized_modules, quantized_modules)
        )
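    # Illustrative walk-through of _get_unquantized_layers (toy inputs assumed; the exact
    # patterns returned depend on _prefix_wildcard_summarize_exclude_modules):
    #   state_dict keys: "model.layers.0.self_attn.q_proj.weight",
    #                    "model.layers.0.self_attn.q_proj.weight_scale",
    #                    "model.layers.1.mlp.gate_proj.weight",
    #                    "lm_head.weight"
    #   quantized_modules       = {"model.layers.0.self_attn.q_proj"}
    #   potentially_quantizable = {"model.layers.0.self_attn.q_proj",
    #                              "model.layers.1.mlp.gate_proj", "lm_head"}
    #   unquantized_modules     = {"model.layers.1.mlp.gate_proj", "lm_head"}
    #   -> exclude patterns roughly like ["lm_head", "model.layers.1.*"]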

    def _get_weight_scales(self, quantized_state: dict[str, Any], qformat: str):
        weight_scale = quantized_state.pop("weight_scale", None)
        weight_scale_2 = quantized_state.pop("weight_scale_2", None)