5 changes: 5 additions & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -83,6 +83,7 @@
"w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
}

KV_QUANT_CFG_CHOICES = {
@@ -506,6 +507,10 @@ def export_quantized(
or args.sparsity_fmt != "dense"
or "int8_sq" in args.qformat
):
if (
args.inference_tensor_parallel != 1 or args.inference_pipeline_parallel != 1
) and args.qformat == "nvfp4_svdquant":
raise NotImplementedError("Svdquant does not support mulitple GPUs yet.")
Comment on lines +510 to +513 (Contributor):

⚠️ Potential issue | 🟡 Minor

Fix typo in error message.

"mulitple" should be "multiple".

✏️ Proposed fix
             if (
                 args.inference_tensor_parallel != 1 or args.inference_pipeline_parallel != 1
             ) and args.qformat == "nvfp4_svdquant":
-                raise NotImplementedError("Svdquant does not support mulitple GPUs yet.")
+                raise NotImplementedError("Svdquant does not support multiple GPUs yet.")

warnings.warn(
"Still exporting TensorRT-LLM checkpoints for models not supported by the TensorRT-LLM torch runtime."
)
4 changes: 2 additions & 2 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_svdquant) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_svdquant]" >&2
exit 1
;;
esac
13 changes: 11 additions & 2 deletions modelopt/torch/export/model_config.py
@@ -33,6 +33,7 @@
QUANTIZATION_INT4_AWQ = "int4_awq"
QUANTIZATION_W4A8_AWQ = "w4a8_awq"
QUANTIZATION_NVFP4 = "nvfp4"
QUANTIZATION_NVFP4_SVDQUANT = "nvfp4_svdquant"
QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8"
QUANTIZATION_MXFP4 = "mxfp4"
QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8"
@@ -507,12 +508,20 @@ def hidden_size(self):
"""Returns the hidden size of the transformer model."""
if isinstance(self.mlp, MOEConfig):
# fc.weight for MOE is stacked
if self.mlp.fc.quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
if self.mlp.fc.quantization in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
return self.mlp.fc.weight.shape[-1] * 2
return self.mlp.fc.weight.shape[-1]
else:
k = self.mlp.fc.weight.shape[1]
if self.mlp.fc.quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
if self.mlp.fc.quantization in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
return k * 2
return k

7 changes: 6 additions & 1 deletion modelopt/torch/export/postprocess.py
@@ -35,6 +35,7 @@
LINEAR_ROW,
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
ConvConfig,
EmbeddingConfig,
ExpertConfig,
@@ -398,7 +399,10 @@ def _merge_model_configs_to_first_tp(config, ranks: list[int], group=None):
group_size=config.awq_block_size,
quantization=config.quantization,
)
if config.quantization == QUANTIZATION_NVFP4_AWQ:
if config.quantization in [
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
# We have to update weight_scaling_factor and weight_scaling_factor_2
config.weights_scaling_factor, config.weights_scaling_factor_2 = (
NVFP4QTensor.get_weights_scaling_factor(
@@ -430,6 +434,7 @@ def _merge_model_configs_to_first_tp(config, ranks: list[int], group=None):
if config.quantization in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
(
config.weights_scaling_factor,
75 changes: 65 additions & 10 deletions modelopt/torch/export/quant_utils.py
@@ -25,7 +25,11 @@
import torch.nn as nn

from modelopt import __version__
from modelopt.torch.quantization.model_calib import enable_stats_collection, finish_stats_collection
from modelopt.torch.quantization.model_calib import (
enable_stats_collection,
finish_stats_collection,
svd,
)
from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear
from modelopt.torch.quantization.qtensor import (
FP8QTensor,
@@ -57,6 +61,7 @@
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_MXFP4_FP8,
QUANTIZATION_W4A8_NVFP4_FP8,
@@ -165,7 +170,7 @@ def resmooth_and_get_scale(
)
new_weights.append(weight)
# If NVFP4_AWQ then we view the scales as uint8 to allow for cat later
if quantization == QUANTIZATION_NVFP4_AWQ:
if quantization in [QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT]:
scale, _ = NVFP4QTensor.get_weights_scaling_factor(weight, group_size).view(torch.uint8)
else:
scale = get_scaling_factor_from_weight(weight, group_size)
@@ -176,7 +181,7 @@
return (
torch.cat(new_weights, dim=0),
resmoothed_scales.view(torch.float8_e4m3fn)
if quantization == QUANTIZATION_NVFP4_AWQ
if quantization in [QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT]
else resmoothed_scales, # if NVFP4_AWQ we view the scales back as float8_e4m3fn after cat
new_pre_quant_scale,
)
@@ -243,6 +248,7 @@ def get_activation_scaling_factor(
if get_quantization_format(module) in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
return NVFP4QTensor.get_activation_scaling_factor(input_quantizer)
return get_scaling_factor(input_quantizer)
@@ -270,6 +276,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
if quantization_format in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
@@ -303,6 +310,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
if get_quantization_format(module) in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
@@ -487,6 +495,8 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
block_sizes = getattr(weight_quantizer, "block_sizes")
scale_bits = block_sizes.get("scale_bits")

if input_quantizer is not None and hasattr(weight_quantizer, "svdquant_lora_a"):
return QUANTIZATION_NVFP4_SVDQUANT
if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
return QUANTIZATION_NVFP4_AWQ
if getattr(layer, "fused_with_prequant", False):
@@ -660,15 +670,18 @@ def process_layer_quant_config(layer_config_dict):
elif v == "w4a8_nvfp4_fp8":
layer_config = {
"quant_algo": "W4A8_NVFP4_FP8",
"group_size": layer_config_dict[prefix + ".awq_block_size"],
"has_zero_point": False,
"pre_quant_scale": True,
"group_size": block_size_value,
}
elif v == "w4a8_mxfp4_fp8":
layer_config = {
"quant_algo": "W4A8_MXFP4_FP8",
"group_size": block_size_value,
}
elif v == "nvfp4_svdquant":
layer_config = {
"quant_algo": "NVFP4_SVD",
"group_size": block_size_value,
}
else:
layer_config = {"quant_algo": v}

@@ -813,7 +826,12 @@ def to_quantized_weight(
if quantization in [QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ]:
return pack_int4_in_uint8(weight, weights_scaling_factor)

if quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_W4A8_NVFP4_FP8]:
if quantization in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
QUANTIZATION_NVFP4_SVDQUANT,
]:
assert block_size is not None, "Block size not passed. Unable to quantize to NVFP4 format."
assert weights_scaling_factor2 is not None, (
"Weights scaling factor 2 not passed. Unable to quantize to NVFP4 format"
@@ -1008,6 +1026,40 @@ def _update_pre_quant_scale(module, new_pre_quant_scale):
finish_stats_collection(module.weight_quantizer)


def _update_svdquant(modules, new_pre_quant_scale):
"""Updates the pre_quant_scale, svdquant_lora_a and svdquant_lora_b matrices when pre_quant_scale is changed."""
new_pre_quant_scale = new_pre_quant_scale.to(torch.float32)
lora_a = [m.weight_quantizer.svdquant_lora_a.to(torch.float32) for m in modules]
lora_b = [m.weight_quantizer.svdquant_lora_b.to(torch.float32) for m in modules]
weight = [m.weight.to(torch.float32) for m in modules]
old_pre_quant_scale = [m.input_quantizer._pre_quant_scale.to(torch.float32) for m in modules]
weight = [
(w + (lb @ la)) * (s / new_pre_quant_scale)
for w, la, lb, s in zip(weight, lora_a, lora_b, old_pre_quant_scale)
]
weight_concatenated = torch.cat(weight, dim=0)
lb, la = svd(weight_concatenated, rank=lora_a[0].shape[0])
weight_concatenated -= lb @ la
weight_concatenated = weight_concatenated.to(modules[0].weight.dtype)
la = la.to(modules[0].weight_quantizer.svdquant_lora_a.dtype)
lb = lb.to(modules[0].weight_quantizer.svdquant_lora_b.dtype)
new_pre_quant_scale = new_pre_quant_scale.to(modules[0].input_quantizer.pre_quant_scale.dtype)

index = 0
for i, module in enumerate(modules):
module.input_quantizer.pre_quant_scale = new_pre_quant_scale
module.weight_quantizer.svdquant_lora_a = la
assert lora_b[i].shape[0] == module.weight.shape[0]
module.weight_quantizer.svdquant_lora_b = lb[index : index + lora_b[i].shape[0], :]
module.weight = nn.Parameter(weight_concatenated[index : index + lora_b[i].shape[0], :])
index += lora_b[i].shape[0]
# Redo weights collection
module.weight_quantizer.reset_amax()
enable_stats_collection(module.weight_quantizer)
module.weight_quantizer(module.weight)
finish_stats_collection(module.weight_quantizer)


# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
PQS_FUSE_MODULE_MAPPING = [
# Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
@@ -1146,9 +1198,12 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
dim=0,
)

for module in modules:
if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale):
_update_pre_quant_scale(module, avg_prequant_scale)
if getattr(modules[0].weight_quantizer, "svdquant_lora_a", None) is not None:
_update_svdquant(modules, avg_prequant_scale)
else:
for module in modules:
if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale):
_update_pre_quant_scale(module, avg_prequant_scale)

if resmooth_only:
return
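The new _update_svdquant helper above rebalances a group of fused linears onto a shared pre_quant_scale while keeping each layer's effective weight unchanged. A minimal standalone sketch of the invariant it preserves (shapes and names are made up, and plain tensors stand in for the ModelOpt quantizer attributes):

    # Sketch only: changing the shared pre_quant_scale and redoing the SVD split
    # should leave (weight + lora_b @ lora_a) * pre_quant_scale unchanged.
    import torch

    torch.manual_seed(0)
    out_dim, in_dim, rank = 64, 128, 16

    w_old = torch.randn(out_dim, in_dim)
    lora_b_old = torch.randn(out_dim, rank)
    lora_a_old = torch.randn(rank, in_dim)
    s_old = torch.rand(in_dim) + 0.5  # old per-input-channel pre_quant_scale
    s_new = torch.rand(in_dim) + 0.5  # averaged scale shared by the fused group

    # Step 1 (mirrors _update_svdquant): rebuild the full weight under the new scale.
    w_full = (w_old + lora_b_old @ lora_a_old) * (s_old / s_new)

    # Step 2: re-extract rank-r factors; the residual becomes the new stored weight.
    u, s, vt = torch.linalg.svd(w_full, full_matrices=False)
    lora_b_new, lora_a_new = u[:, :rank] * s[:rank], vt[:rank]
    w_new = w_full - lora_b_new @ lora_a_new

    # The effective (unsmoothed) weight is preserved up to floating-point error.
    lhs = (w_new + lora_b_new @ lora_a_new) * s_new
    rhs = (w_old + lora_b_old @ lora_a_old) * s_old
    print(torch.allclose(lhs, rhs, atol=1e-3))  # expected: True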
15 changes: 13 additions & 2 deletions modelopt/torch/export/unified_export_hf.py
@@ -65,6 +65,7 @@
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
)
@@ -236,6 +237,10 @@ def requantize_resmooth_fused_llm_layers(model: torch.nn.Module):
model_type = type(model).__name__.lower()
module_names = set()

# NVFP4 SVDQuant does not need pre-quant scale fusion (either into previous linear or layernorm) because
# 1) its kernel handles pre-quant scale.
# 2) fusing into previous linear will need to change the lora_up in up_proj which may cause issue in
# the later gate up fusion.
# Fuse pre_quant_scale to the linear weights if possible
if quantization_format is not None and "nvfp4_awq" in quantization_format.lower():
fuse_prequant_to_linear(model)
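The comment above points out that the SVDQuant kernel applies pre_quant_scale to the activation itself, so the scale does not have to be folded away during export. The underlying algebra (scaling the input is the same as scaling the weight's input channels) is easy to check in isolation; a tiny sketch with made-up shapes, not the actual kernel:

    # Applying the per-channel scale to the activation at runtime is numerically
    # equivalent to folding it into the weight's input channels.
    import torch

    torch.manual_seed(0)
    x = torch.randn(4, 256)
    w = torch.randn(128, 256)
    s = torch.rand(256) + 0.5  # per-input-channel pre_quant_scale

    y_runtime = (x * s) @ w.T  # kernel scales the input on the fly
    y_folded = x @ (w * s).T   # scale folded into the weight columns
    print(torch.allclose(y_runtime, y_folded, atol=1e-4))  # expected: True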
@@ -246,7 +251,8 @@ def requantize_resmooth_fused_llm_layers(model: torch.nn.Module):

# For MoE models update pre_quant_scale to average pre_quant_scale amongst experts
if is_moe(module) and (
quantization_format is not QUANTIZATION_NONE and "awq" in quantization_format
quantization_format is not QUANTIZATION_NONE
and ("awq" in quantization_format or quantization_format == QUANTIZATION_NVFP4_SVDQUANT)
):
# update_experts_avg_prequant_scale(module)
grouped_experts = get_experts_list(module, model_type)
@@ -417,6 +423,7 @@ def _export_quantized_weight(

if quantization_format in [
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_NVFP4,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
@@ -437,7 +444,11 @@
for expert_type in ["Llama4TextExperts", "GptOssExperts"]
)

if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
if quantization_format in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
# Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
# for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
weight, _ = maybe_transpose_expert_weight_dimensions(
29 changes: 16 additions & 13 deletions modelopt/torch/quantization/model_calib.py
@@ -1075,6 +1075,18 @@ def _get_awq_quantizer_block_size(tensor: torch.Tensor, quantizer: TensorQuantiz
return blocksize


def svd(weight, rank):
original_device = weight.device
original_dtype = weight.dtype
weight_f64 = weight.to(dtype=torch.float64, device=original_device)
Collaborator: do we need f64?

Contributor (author): I am not sure. I kept what @jingyu-ml has originally. This part is just a refactoring so that I can reuse this code during qkv fusion.
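One way to answer the float64 question empirically (an illustrative aside, not part of the PR discussion): compare the rank-r factors computed in float32 against float64 on a weight of representative size, then judge the gap against the precision the factors are eventually stored in. The sizes below are assumptions.

    # Rough comparison of float32 vs. float64 SVD for the low-rank extraction.
    import torch

    torch.manual_seed(0)
    weight = torch.randn(1024, 4096, dtype=torch.float64)
    rank = 32

    def lowrank_approx(w: torch.Tensor, r: int, dtype: torch.dtype) -> torch.Tensor:
        u, s, vt = torch.linalg.svd(w.to(dtype), full_matrices=False)
        return ((u[:, :r] * s[:r]) @ vt[:r]).to(torch.float64)

    approx64 = lowrank_approx(weight, rank, torch.float64)
    approx32 = lowrank_approx(weight, rank, torch.float32)

    # Relative disagreement between the two precisions. If it sits well below the
    # epsilon of the dtype the factors are cast back to (roughly 1e-3 for fp16,
    # 8e-3 for bf16), the float64 pass is probably not buying much accuracy.
    print(((approx64 - approx32).norm() / approx64.norm()).item())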

u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
us = u[:, :rank] * s[:rank]
vt = vt[:rank]
return us.to(device=original_device, dtype=original_dtype), vt.to(
device=original_device, dtype=original_dtype
)


@torch.no_grad()
def svdquant(
model: nn.Module,
@@ -1096,25 +1108,16 @@
def postprocess(module, name):
print_rank_0(f"SVD {name}")
weight = module.weight.data
original_device = weight.device
original_dtype = weight.dtype
weight_f64 = weight.to(dtype=torch.float64, device=original_device)
u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
if u.shape[1] < lowrank or vt.shape[0] < lowrank:
us, vt = svd(weight, lowrank)
if us.shape[1] < lowrank or vt.shape[0] < lowrank:
warnings.warn(
"The low-rank dimensions do not match the layer dimensions. "
"Please verify your configuration and model settings. "
f"SVD will be skipped for this layer {name}."
)
return
us = u[:, :lowrank] * s[:lowrank]
vt = vt[:lowrank]
module.weight_quantizer.svdquant_lora_a = vt.to(
dtype=original_dtype, device=original_device
)
module.weight_quantizer.svdquant_lora_b = us.to(
dtype=original_dtype, device=original_device
)
module.weight_quantizer.svdquant_lora_a = vt
module.weight_quantizer.svdquant_lora_b = us
module.weight.data.sub_(
module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a
)
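Taken together, the postprocess above splits each weight into a higher-precision low-rank branch plus a residual, and only the residual is handed to the NVFP4 quantizer. A minimal sketch of that decomposition and of how a forward pass recombines the branches (illustrative only, with made-up shapes; the variable names merely mirror svdquant_lora_a and svdquant_lora_b):

    # The weight is approximated as lora_b @ lora_a (kept in higher precision)
    # plus a residual, and only the residual goes through NVFP4 quantization.
    import torch

    torch.manual_seed(0)
    weight = torch.randn(512, 1024)
    lowrank = 32

    u, s, vt = torch.linalg.svd(weight.double(), full_matrices=False)
    lora_b = (u[:, :lowrank] * s[:lowrank]).to(weight.dtype)  # role of svdquant_lora_b
    lora_a = vt[:lowrank].to(weight.dtype)                    # role of svdquant_lora_a
    residual = weight - lora_b @ lora_a                       # what the quantizer sees

    # A forward pass recombines the two branches:
    #   x @ residual.T + (x @ lora_a.T) @ lora_b.T == x @ weight.T (up to fp error)
    x = torch.randn(4, 1024)
    y_split = x @ residual.T + (x @ lora_a.T) @ lora_b.T
    y_ref = x @ weight.T
    print(torch.allclose(y_split, y_ref, atol=1e-3))  # expected: True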