diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py
index f1bd67327..091a5b875 100644
--- a/modelopt/torch/export/unified_export_megatron.py
+++ b/modelopt/torch/export/unified_export_megatron.py
@@ -50,6 +50,7 @@
 from .plugins.mcore_custom import CustomModuleMapping, save_safetensors
 from .plugins.megatron_importer import GPTModelImporter
 from .quant_utils import (
+    _prefix_wildcard_summarize_exclude_modules,
     get_activation_scaling_factor,
     get_kv_cache_dtype,
     get_quantization_format,
@@ -320,6 +321,9 @@ def save_pretrained(
             pass

         if is_last_stage_main_rank and quantization is not None:
+            # Dynamically detect unquantized modules from state_dict
+            exclude_modules = self._get_unquantized_layers(state_dict)
+
             hf_quant_config = {
                 "producer": {
                     "name": "modelopt",
@@ -328,7 +332,7 @@ def save_pretrained(
                 "quantization": {
                     "quant_algo": quantization,
                     "kv_cache_quant_algo": kv_cache_quantization,
-                    "exclude_modules": ["lm_head"],
+                    "exclude_modules": exclude_modules,
                 },
             }
             with open(save_directory + "/hf_quant_config.json", "w") as f:
@@ -552,6 +556,62 @@ def _get_quantized_state(
     def _get_quantization_format(self, module: torch.nn.Module):
         return get_quantization_format(module)

+    def _get_unquantized_layers(self, state_dict: dict[str, torch.Tensor]) -> list[str]:
+        """Detect unquantized modules from state_dict by checking for weight_scale.
+
+        A module is considered quantized if it has a corresponding weight_scale key.
+        Only considers linear layer modules (proj, gate, experts, etc.).
+        Skips modules that are never quantized by default (embed_tokens, layernorms).
+
+        Args:
+            state_dict: The exported state dict with HF-style keys.
+
+        Returns:
+            List of unquantized module patterns like ["model.layers.0.*", "model.layers.47.*"].
+        """
+        # Linear layer patterns that can be quantized
+        quantizable_patterns = (
+            "_proj",  # q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
+            "experts",  # MoE experts
+            "lm_head",
+        )
+
+        # Patterns assumed to be always excluded (never quantized by default)
+        never_quantized_patterns = (
+            "embed_tokens",
+            "layernorm",
+            "input_layernorm",
+            "post_attention_layernorm",
+            "model.norm",
+        )
+
+        # Find all linear modules that have weight_scale (quantized)
+        quantized_modules = set()
+        for key in state_dict:
+            if key.endswith(".weight_scale"):
+                module_name = key[:-13]  # len(".weight_scale") = 13
+                quantized_modules.add(module_name)
+
+        # Find all linear modules that have weights and could be quantized
+        potentially_quantizable = set()
+        for key in state_dict:
+            if key.endswith(".weight"):
+                module_name = key[:-7]  # len(".weight") = 7
+                # Skip modules that are never quantized
+                if any(p in module_name for p in never_quantized_patterns):
+                    continue
+                # Only consider quantizable linear layers
+                if any(p in module_name for p in quantizable_patterns):
+                    potentially_quantizable.add(module_name)
+
+        # Find unquantized linear modules (have weight but no weight_scale)
+        unquantized_modules = potentially_quantizable - quantized_modules
+
+        # Use the prefix wildcard summarization to create compact exclude patterns
+        return sorted(
+            _prefix_wildcard_summarize_exclude_modules(unquantized_modules, quantized_modules)
+        )
+
     def _get_weight_scales(self, quantized_state: dict[str, Any], qformat: str):
         weight_scale = quantized_state.pop("weight_scale", None)
         weight_scale_2 = quantized_state.pop("weight_scale_2", None)
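
For reviewers, a minimal standalone sketch (not part of the patch) of what the new detection does on a toy HF-style `state_dict`. The key names, tensor shapes, and printed result are illustrative only; the final wildcard summarization is left to `_prefix_wildcard_summarize_exclude_modules` from `quant_utils`, whose output format is assumed from the docstring above.

```python
# Minimal sketch (not part of the patch): reproduce the weight_scale-based
# detection on a toy HF-style state_dict. The wildcard summarization step is
# only referenced in a comment here; the patch delegates it to
# _prefix_wildcard_summarize_exclude_modules from .quant_utils.
import torch

state_dict = {
    # layer 0 intentionally left unquantized: weights only, no weight_scale
    "model.layers.0.self_attn.q_proj.weight": torch.empty(8, 8),
    "model.layers.0.mlp.down_proj.weight": torch.empty(8, 8),
    # layer 1 quantized: weight + weight_scale
    "model.layers.1.self_attn.q_proj.weight": torch.empty(8, 8),
    "model.layers.1.self_attn.q_proj.weight_scale": torch.empty(1),
    "model.layers.1.mlp.down_proj.weight": torch.empty(8, 8),
    "model.layers.1.mlp.down_proj.weight_scale": torch.empty(1),
    # never considered for quantization
    "model.embed_tokens.weight": torch.empty(8, 8),
    "model.layers.0.input_layernorm.weight": torch.empty(8),
    # lm_head matches the quantizable patterns but has no weight_scale here
    "lm_head.weight": torch.empty(8, 8),
}

quantizable = ("_proj", "experts", "lm_head")
never_quantized = ("embed_tokens", "layernorm", "model.norm")

# Modules with a weight_scale entry are treated as quantized.
quantized = {k[: -len(".weight_scale")] for k in state_dict if k.endswith(".weight_scale")}

# Candidate linear modules: have a weight, match a quantizable pattern,
# and are not on the never-quantized list.
candidates = {
    k[: -len(".weight")]
    for k in state_dict
    if k.endswith(".weight")
    and not any(p in k for p in never_quantized)
    and any(p in k for p in quantizable)
}

print(sorted(candidates - quantized))
# ['lm_head', 'model.layers.0.mlp.down_proj', 'model.layers.0.self_attn.q_proj']
# The patch then compacts this set via _prefix_wildcard_summarize_exclude_modules,
# e.g. collapsing whole layers into patterns such as "model.layers.0.*".
```

This replaces the hard-coded `["lm_head"]` entry, so exports where some layers are intentionally left unquantized (the docstring's `model.layers.0.*` / `model.layers.47.*` example) are reflected accurately in `hf_quant_config.json`.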