diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index a9862a742..d463b651c 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -83,6 +83,7 @@
     "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
     "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
     "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
+    "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
 }
 
 KV_QUANT_CFG_CHOICES = {
diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh
index 043b690e5..cb0a9a192 100755
--- a/examples/llm_ptq/scripts/huggingface_example.sh
+++ b/examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
 IFS=","
 for qformat in $QFORMAT; do
     case $qformat in
-        fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+        fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_svdquant) ;;
         *)
-            echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
+            echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_svdquant]" >&2
             exit 1
             ;;
     esac
diff --git a/modelopt/torch/export/model_config.py b/modelopt/torch/export/model_config.py
index 306348f2c..73ed85e3a 100755
--- a/modelopt/torch/export/model_config.py
+++ b/modelopt/torch/export/model_config.py
@@ -33,6 +33,7 @@
 QUANTIZATION_INT4_AWQ = "int4_awq"
 QUANTIZATION_W4A8_AWQ = "w4a8_awq"
 QUANTIZATION_NVFP4 = "nvfp4"
+QUANTIZATION_NVFP4_SVDQUANT = "nvfp4_svdquant"
 QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8"
 QUANTIZATION_MXFP4 = "mxfp4"
 QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8"
@@ -507,12 +508,20 @@ def hidden_size(self):
         """Returns the hidden size of the transformer model."""
         if isinstance(self.mlp, MOEConfig):
             # fc.weight for MOE is stacked
-            if self.mlp.fc.quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
+            if self.mlp.fc.quantization in [
+                QUANTIZATION_NVFP4,
+                QUANTIZATION_NVFP4_AWQ,
+                QUANTIZATION_NVFP4_SVDQUANT,
+            ]:
                 return self.mlp.fc.weight.shape[-1] * 2
             return self.mlp.fc.weight.shape[-1]
         else:
             k = self.mlp.fc.weight.shape[1]
-            if self.mlp.fc.quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
+            if self.mlp.fc.quantization in [
+                QUANTIZATION_NVFP4,
+                QUANTIZATION_NVFP4_AWQ,
+                QUANTIZATION_NVFP4_SVDQUANT,
+            ]:
                 return k * 2
             return k
 
diff --git a/modelopt/torch/export/postprocess.py b/modelopt/torch/export/postprocess.py
index 5c3d0fcf3..376a52a41 100644
--- a/modelopt/torch/export/postprocess.py
+++ b/modelopt/torch/export/postprocess.py
@@ -35,6 +35,7 @@
     LINEAR_ROW,
     QUANTIZATION_NVFP4,
     QUANTIZATION_NVFP4_AWQ,
+    QUANTIZATION_NVFP4_SVDQUANT,
     ConvConfig,
     EmbeddingConfig,
     ExpertConfig,
@@ -398,7 +399,10 @@ def _merge_model_configs_to_first_tp(config, ranks: list[int], group=None):
                 group_size=config.awq_block_size,
                 quantization=config.quantization,
             )
-            if config.quantization == QUANTIZATION_NVFP4_AWQ:
+            if config.quantization in [
+                QUANTIZATION_NVFP4_AWQ,
+                QUANTIZATION_NVFP4_SVDQUANT,
+            ]:
                 # We have to update weight_scaling_factor and weight_scaling_factor_2
                 config.weights_scaling_factor, config.weights_scaling_factor_2 = (
                     NVFP4QTensor.get_weights_scaling_factor(
@@ -430,6 +434,7 @@ def _merge_model_configs_to_first_tp(config, ranks: list[int], group=None):
             if config.quantization in [
                 QUANTIZATION_NVFP4,
                 QUANTIZATION_NVFP4_AWQ,
+                QUANTIZATION_NVFP4_SVDQUANT,
             ]:
                 (
                     config.weights_scaling_factor,
diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index eee13dc51..e0ef00156 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -25,7 +25,11 @@
 import torch.nn as nn
 
 from modelopt import __version__
-from modelopt.torch.quantization.model_calib import enable_stats_collection, finish_stats_collection
+from modelopt.torch.quantization.model_calib import (
+    enable_stats_collection,
+    finish_stats_collection,
+    svd,
+)
 from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear
 from modelopt.torch.quantization.qtensor import (
     FP8QTensor,
@@ -57,6 +61,7 @@
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
     QUANTIZATION_NVFP4_AWQ,
+    QUANTIZATION_NVFP4_SVDQUANT,
     QUANTIZATION_W4A8_AWQ,
     QUANTIZATION_W4A8_MXFP4_FP8,
     QUANTIZATION_W4A8_NVFP4_FP8,
@@ -165,7 +170,7 @@ def resmooth_and_get_scale(
         )
         new_weights.append(weight)
         # If NVFP4_AWQ then we view the scales as uint8 to allow for cat later
-        if quantization == QUANTIZATION_NVFP4_AWQ:
+        if quantization in [QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT]:
             scale, _ = NVFP4QTensor.get_weights_scaling_factor(weight, group_size).view(torch.uint8)
         else:
             scale = get_scaling_factor_from_weight(weight, group_size)
@@ -176,7 +181,7 @@
     return (
         torch.cat(new_weights, dim=0),
         resmoothed_scales.view(torch.float8_e4m3fn)
-        if quantization == QUANTIZATION_NVFP4_AWQ
+        if quantization in [QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT]
         else resmoothed_scales,  # if NVFP4_AWQ we view the scales back as float8_e4m3fn after cat
         new_pre_quant_scale,
     )
@@ -243,6 +248,7 @@ def get_activation_scaling_factor(
     if get_quantization_format(module) in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
     ]:
         return NVFP4QTensor.get_activation_scaling_factor(input_quantizer)
     return get_scaling_factor(input_quantizer)
@@ -270,6 +276,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
     if quantization_format in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
         if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
@@ -303,6 +310,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
     if get_quantization_format(module) in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
     ]:
         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
     elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
@@ -487,6 +495,8 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
         block_sizes = getattr(weight_quantizer, "block_sizes")
         scale_bits = block_sizes.get("scale_bits")
 
+        if input_quantizer is not None and hasattr(weight_quantizer, "svdquant_lora_a"):
+            return QUANTIZATION_NVFP4_SVDQUANT
         if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
             return QUANTIZATION_NVFP4_AWQ
         if getattr(layer, "fused_with_prequant", False):
@@ -660,15 +670,18 @@ def process_layer_quant_config(layer_config_dict):
         elif v == "w4a8_nvfp4_fp8":
             layer_config = {
                 "quant_algo": "W4A8_NVFP4_FP8",
-                "group_size": layer_config_dict[prefix + ".awq_block_size"],
-                "has_zero_point": False,
-                "pre_quant_scale": True,
+                "group_size": block_size_value,
             }
         elif v == "w4a8_mxfp4_fp8":
             layer_config = {
                 "quant_algo": "W4A8_MXFP4_FP8",
                 "group_size": block_size_value,
             }
+        elif v == "nvfp4_svdquant":
+            layer_config = {
+                "quant_algo": "NVFP4_SVD",
+                "group_size": block_size_value,
+            }
         else:
             layer_config = {"quant_algo": v}
 
@@ -813,7 +826,12 @@ def to_quantized_weight(
     if quantization in [QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ]:
         return pack_int4_in_uint8(weight, weights_scaling_factor)
-    if quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_W4A8_NVFP4_FP8]:
+    if quantization in [
+        QUANTIZATION_NVFP4,
+        QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_W4A8_NVFP4_FP8,
+        QUANTIZATION_NVFP4_SVDQUANT,
+    ]:
         assert block_size is not None, "Block size not passed. Unable to quantize to NVFP4 format."
         assert weights_scaling_factor2 is not None, (
             "Weights scaling factor 2 not passed. Unable to quantize to NVFP4 format"
         )
@@ -1008,6 +1026,40 @@ def _update_pre_quant_scale(module, new_pre_quant_scale):
     finish_stats_collection(module.weight_quantizer)
 
 
+def _update_svdquant(modules, new_pre_quant_scale):
+    """Updates the pre_quant_scale, svdquant_lora_a and svdquant_lora_b matrices when pre_quant_scale is changed."""
+    new_pre_quant_scale = new_pre_quant_scale.to(torch.float32)
+    lora_a = [m.weight_quantizer.svdquant_lora_a.to(torch.float32) for m in modules]
+    lora_b = [m.weight_quantizer.svdquant_lora_b.to(torch.float32) for m in modules]
+    weight = [m.weight.to(torch.float32) for m in modules]
+    old_pre_quant_scale = [m.input_quantizer._pre_quant_scale.to(torch.float32) for m in modules]
+    weight = [
+        (w + (lb @ la)) * (s / new_pre_quant_scale)
+        for w, la, lb, s in zip(weight, lora_a, lora_b, old_pre_quant_scale)
+    ]
+    weight_concatenated = torch.cat(weight, dim=0)
+    lb, la = svd(weight_concatenated, rank=lora_a[0].shape[0])
+    weight_concatenated -= lb @ la
+    weight_concatenated = weight_concatenated.to(modules[0].weight.dtype)
+    la = la.to(modules[0].weight_quantizer.svdquant_lora_a.dtype)
+    lb = lb.to(modules[0].weight_quantizer.svdquant_lora_b.dtype)
+    new_pre_quant_scale = new_pre_quant_scale.to(modules[0].input_quantizer.pre_quant_scale.dtype)
+
+    index = 0
+    for i, module in enumerate(modules):
+        module.input_quantizer.pre_quant_scale = new_pre_quant_scale
+        module.weight_quantizer.svdquant_lora_a = la
+        assert lora_b[i].shape[0] == module.weight.shape[0]
+        module.weight_quantizer.svdquant_lora_b = lb[index : index + lora_b[i].shape[0], :]
+        module.weight = nn.Parameter(weight_concatenated[index : index + lora_b[i].shape[0], :])
+        index += lora_b[i].shape[0]
+        # Redo weights collection
+        module.weight_quantizer.reset_amax()
+        enable_stats_collection(module.weight_quantizer)
+        module.weight_quantizer(module.weight)
+        finish_stats_collection(module.weight_quantizer)
+
+
 # Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
 PQS_FUSE_MODULE_MAPPING = [
     # Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
@@ -1146,9 +1198,12 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
         dim=0,
     )
 
-    for module in modules:
-        if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale):
-            _update_pre_quant_scale(module, avg_prequant_scale)
+    if hasattr(modules[0].weight_quantizer, "svdquant_lora_a"):
+        _update_svdquant(modules, avg_prequant_scale)
+    else:
+        for module in modules:
+            if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale):
+                _update_pre_quant_scale(module, avg_prequant_scale)
 
     if resmooth_only:
         return
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index ccfc01200..bed35c3ca 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -54,6 +54,7 @@
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
     QUANTIZATION_NVFP4_AWQ,
+    QUANTIZATION_NVFP4_SVDQUANT,
     QUANTIZATION_W4A8_AWQ,
     QUANTIZATION_W4A8_NVFP4_FP8,
 )
@@ -109,6 +110,10 @@ def _output_hook(module, input, output):
     fused_linears = {}
     module_names = set()
 
+    # NVFP4 SVDQuant does not need pre-quant scale fusion (either into previous linear or layernorm) because
+    # 1) its kernel handles pre-quant scale.
+    # 2) fusing into previous linear will need to change the lora_up in up_proj which may cause issue in
+    # the later gate up fusion.
     # Fuse pre_quant_scale to the linear weights if possible
     if quantization_format is not None and "nvfp4_awq" in quantization_format.lower():
         fuse_prequant_to_linear(model)
@@ -117,7 +122,9 @@ def _output_hook(module, input, output):
         module_names.add(name)
 
         # For MoE models update pre_quant_scale to average pre_quant_scale amongst experts
-        if is_moe(module) and ("awq" in quantization_format):
+        if is_moe(module) and (
+            ("awq" in quantization_format) or (quantization_format == QUANTIZATION_NVFP4_SVDQUANT)
+        ):
             # update_experts_avg_prequant_scale(module)
             grouped_experts = get_experts_list(module, model_type)
             for modules in grouped_experts:
@@ -314,6 +321,7 @@ def _export_quantized_weight(
 
     if quantization_format in [
         QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_NVFP4,
         QUANTIZATION_W4A8_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
@@ -334,7 +342,11 @@ def _export_quantized_weight(
         for expert_type in ["Llama4TextExperts", "GptOssExperts"]
     )
 
-    if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
+    if quantization_format in [
+        QUANTIZATION_NVFP4,
+        QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
+    ]:
         # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
         # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
         weight, _ = maybe_transpose_expert_weight_dimensions(
diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index d875d9c5b..33ada0c8c 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -1013,6 +1013,18 @@ def _get_awq_quantizer_block_size(tensor: torch.Tensor, quantizer: TensorQuantiz
     return blocksize
 
 
+def svd(weight, rank):
+    original_device = weight.device
+    original_dtype = weight.dtype
+    weight_f64 = weight.to(dtype=torch.float64, device=original_device)
+    u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
+    us = u[:, :rank] * s[:rank]
+    vt = vt[:rank]
+    return us.to(device=original_device, dtype=original_dtype), vt.to(
+        device=original_device, dtype=original_dtype
+    )
+
+
 @torch.no_grad()
 def svdquant(
     model: nn.Module,
@@ -1034,25 +1046,16 @@ def postprocess(module, name):
         print_rank_0(f"SVD {name}")
         weight = module.weight.data
-        original_device = weight.device
-        original_dtype = weight.dtype
-        weight_f64 = weight.to(dtype=torch.float64, device=original_device)
-        u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
-        if u.shape[1] < lowrank or vt.shape[0] < lowrank:
+        us, vt = svd(weight, lowrank)
+        if us.shape[1] < lowrank or vt.shape[0] < lowrank:
             warnings.warn(
                 "The low-rank dimensions do not match the layer dimensions. "
                 "Please verify your configuration and model settings. "
                 f"SVD will be skipped for this layer {name}."
             )
             return
 
-        us = u[:, :lowrank] * s[:lowrank]
-        vt = vt[:lowrank]
-        module.weight_quantizer.svdquant_lora_a = vt.to(
-            dtype=original_dtype, device=original_device
-        )
-        module.weight_quantizer.svdquant_lora_b = us.to(
-            dtype=original_dtype, device=original_device
-        )
+        module.weight_quantizer.svdquant_lora_a = vt
+        module.weight_quantizer.svdquant_lora_b = us
         module.weight.data.sub_(
             module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a
        )