1 change: 1 addition & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -83,6 +83,7 @@
"w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
}

KV_QUANT_CFG_CHOICES = {
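For context, here is a minimal sketch of what the new table entry wires up when `--qformat nvfp4_svdquant` is selected. The model name and the calibration loop below are placeholder stand-ins; only `mtq.quantize` and `mtq.NVFP4_SVDQUANT_DEFAULT_CFG` come from the code above.

```python
# Hedged sketch: the checkpoint, tokenizer, and calibration data are hypothetical stand-ins.
import torch
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

def forward_loop(m):
    # Tiny stand-in for the real calibration loop in hf_ptq.py.
    batch = tokenizer("calibration text goes here", return_tensors="pt").to(m.device)
    m(**batch)

# QUANT_CFG_CHOICES["nvfp4_svdquant"] resolves to mtq.NVFP4_SVDQUANT_DEFAULT_CFG per the mapping above.
model = mtq.quantize(model, mtq.NVFP4_SVDQUANT_DEFAULT_CFG, forward_loop)
```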
4 changes: 2 additions & 2 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_svdquant) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_svdquant]" >&2
exit 1
;;
esac
13 changes: 11 additions & 2 deletions modelopt/torch/export/model_config.py
@@ -33,6 +33,7 @@
QUANTIZATION_INT4_AWQ = "int4_awq"
QUANTIZATION_W4A8_AWQ = "w4a8_awq"
QUANTIZATION_NVFP4 = "nvfp4"
QUANTIZATION_NVFP4_SVDQUANT = "nvfp4_svdquant"
QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8"
QUANTIZATION_MXFP4 = "mxfp4"
QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8"
@@ -507,12 +508,20 @@ def hidden_size(self):
"""Returns the hidden size of the transformer model."""
if isinstance(self.mlp, MOEConfig):
# fc.weight for MOE is stacked
if self.mlp.fc.quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
if self.mlp.fc.quantization in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
return self.mlp.fc.weight.shape[-1] * 2
return self.mlp.fc.weight.shape[-1]
else:
k = self.mlp.fc.weight.shape[1]
if self.mlp.fc.quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
if self.mlp.fc.quantization in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
return k * 2
return k

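The `* 2` above reflects NVFP4's packed storage: two 4-bit values share one uint8 element, so the stored weight's innermost dimension is half the logical hidden size. A small illustrative sketch (the dimensions are made up; only the packing convention comes from the code above):

```python
import torch

hidden_size, ffn_size = 4096, 14336  # hypothetical model dimensions
# NVFP4 packs two 4-bit weights per uint8, halving the last dimension of the stored tensor.
packed_fc_weight = torch.empty(ffn_size, hidden_size // 2, dtype=torch.uint8)

# Mirrors the hidden_size property above for the nvfp4 / nvfp4_awq / nvfp4_svdquant formats.
recovered_hidden_size = packed_fc_weight.shape[-1] * 2
assert recovered_hidden_size == hidden_size
```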
7 changes: 6 additions & 1 deletion modelopt/torch/export/postprocess.py
@@ -35,6 +35,7 @@
LINEAR_ROW,
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
ConvConfig,
EmbeddingConfig,
ExpertConfig,
@@ -398,7 +399,10 @@ def _merge_model_configs_to_first_tp(config, ranks: list[int], group=None):
group_size=config.awq_block_size,
quantization=config.quantization,
)
if config.quantization == QUANTIZATION_NVFP4_AWQ:
if config.quantization in [
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
# We have to update weight_scaling_factor and weight_scaling_factor_2
config.weights_scaling_factor, config.weights_scaling_factor_2 = (
NVFP4QTensor.get_weights_scaling_factor(
@@ -430,6 +434,7 @@ def _merge_model_configs_to_first_tp(config, ranks: list[int], group=None):
if config.quantization in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
(
config.weights_scaling_factor,
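After the tensor-parallel shards are re-smoothed and merged, the per-block and global NVFP4 scaling factors are recomputed from the merged weight rather than reused, which is what the branch above does for the nvfp4_awq and nvfp4_svdquant formats. A hedged sketch of that step, using only the `NVFP4QTensor.get_weights_scaling_factor(weight, group_size)` call visible in this diff (shard shapes and the group size are placeholders):

```python
import torch
from modelopt.torch.quantization.qtensor import NVFP4QTensor

group_size = 16                                        # placeholder NVFP4 block size
shards = [torch.randn(1024, 512) for _ in range(2)]    # two hypothetical TP shards of one linear

merged_weight = torch.cat(shards, dim=0)
# Recompute both scaling factors from the merged weight, as in _merge_model_configs_to_first_tp.
weights_scaling_factor, weights_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor(
    merged_weight, group_size
)
```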
75 changes: 65 additions & 10 deletions modelopt/torch/export/quant_utils.py
@@ -25,7 +25,11 @@
import torch.nn as nn

from modelopt import __version__
from modelopt.torch.quantization.model_calib import enable_stats_collection, finish_stats_collection
from modelopt.torch.quantization.model_calib import (
enable_stats_collection,
finish_stats_collection,
svd,
)
from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear
from modelopt.torch.quantization.qtensor import (
FP8QTensor,
@@ -57,6 +61,7 @@
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_MXFP4_FP8,
QUANTIZATION_W4A8_NVFP4_FP8,
@@ -165,7 +170,7 @@ def resmooth_and_get_scale(
)
new_weights.append(weight)
# If NVFP4_AWQ or NVFP4_SVDQUANT, we view the scales as uint8 to allow for cat later
if quantization == QUANTIZATION_NVFP4_AWQ:
if quantization in [QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT]:
scale, _ = NVFP4QTensor.get_weights_scaling_factor(weight, group_size)
scale = scale.view(torch.uint8)
else:
scale = get_scaling_factor_from_weight(weight, group_size)
@@ -176,7 +181,7 @@ def resmooth_and_get_scale(
return (
torch.cat(new_weights, dim=0),
resmoothed_scales.view(torch.float8_e4m3fn)
if quantization == QUANTIZATION_NVFP4_AWQ
if quantization in [QUANTIZATION_NVFP4_AWQ, QUANTIZATION_NVFP4_SVDQUANT]
else resmoothed_scales,  # if NVFP4_AWQ or NVFP4_SVDQUANT we view the scales back as float8_e4m3fn after cat
new_pre_quant_scale,
)
@@ -243,6 +248,7 @@ def get_activation_scaling_factor(
if get_quantization_format(module) in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
return NVFP4QTensor.get_activation_scaling_factor(input_quantizer)
return get_scaling_factor(input_quantizer)
@@ -270,6 +276,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
if quantization_format in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
@@ -303,6 +310,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
if get_quantization_format(module) in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
@@ -487,6 +495,8 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
block_sizes = getattr(weight_quantizer, "block_sizes")
scale_bits = block_sizes.get("scale_bits")

if input_quantizer is not None and hasattr(weight_quantizer, "svdquant_lora_a"):
return QUANTIZATION_NVFP4_SVDQUANT
if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
return QUANTIZATION_NVFP4_AWQ
if getattr(layer, "fused_with_prequant", False):
@@ -660,15 +670,18 @@ def process_layer_quant_config(layer_config_dict):
elif v == "w4a8_nvfp4_fp8":
layer_config = {
"quant_algo": "W4A8_NVFP4_FP8",
"group_size": layer_config_dict[prefix + ".awq_block_size"],
"has_zero_point": False,
"pre_quant_scale": True,
"group_size": block_size_value,
}
elif v == "w4a8_mxfp4_fp8":
layer_config = {
"quant_algo": "W4A8_MXFP4_FP8",
"group_size": block_size_value,
}
elif v == "nvfp4_svdquant":
layer_config = {
"quant_algo": "NVFP4_SVD",
"group_size": block_size_value,
}
else:
layer_config = {"quant_algo": v}

@@ -813,7 +826,12 @@ def to_quantized_weight(
if quantization in [QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ]:
return pack_int4_in_uint8(weight, weights_scaling_factor)

if quantization in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ, QUANTIZATION_W4A8_NVFP4_FP8]:
if quantization in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
QUANTIZATION_NVFP4_SVDQUANT,
]:
assert block_size is not None, "Block size not passed. Unable to quantize to NVFP4 format."
assert weights_scaling_factor2 is not None, (
"Weights scaling factor 2 not passed. Unable to quantize to NVFP4 format"
Expand Down Expand Up @@ -1008,6 +1026,40 @@ def _update_pre_quant_scale(module, new_pre_quant_scale):
finish_stats_collection(module.weight_quantizer)


def _update_svdquant(modules, new_pre_quant_scale):
"""Updates the pre_quant_scale, svdquant_lora_a and svdquant_lora_b matrices when pre_quant_scale is changed."""
new_pre_quant_scale = new_pre_quant_scale.to(torch.float32)
lora_a = [m.weight_quantizer.svdquant_lora_a.to(torch.float32) for m in modules]
lora_b = [m.weight_quantizer.svdquant_lora_b.to(torch.float32) for m in modules]
weight = [m.weight.to(torch.float32) for m in modules]
old_pre_quant_scale = [m.input_quantizer._pre_quant_scale.to(torch.float32) for m in modules]
weight = [
(w + (lb @ la)) * (s / new_pre_quant_scale)
for w, la, lb, s in zip(weight, lora_a, lora_b, old_pre_quant_scale)
]
weight_concatenated = torch.cat(weight, dim=0)
lb, la = svd(weight_concatenated, rank=lora_a[0].shape[0])
weight_concatenated -= lb @ la
weight_concatenated = weight_concatenated.to(modules[0].weight.dtype)
la = la.to(modules[0].weight_quantizer.svdquant_lora_a.dtype)
lb = lb.to(modules[0].weight_quantizer.svdquant_lora_b.dtype)
new_pre_quant_scale = new_pre_quant_scale.to(modules[0].input_quantizer.pre_quant_scale.dtype)

index = 0
for i, module in enumerate(modules):
module.input_quantizer.pre_quant_scale = new_pre_quant_scale
module.weight_quantizer.svdquant_lora_a = la
assert lora_b[i].shape[0] == module.weight.shape[0]
module.weight_quantizer.svdquant_lora_b = lb[index : index + lora_b[i].shape[0], :]
module.weight = nn.Parameter(weight_concatenated[index : index + lora_b[i].shape[0], :])
index += lora_b[i].shape[0]
# Redo weights collection
module.weight_quantizer.reset_amax()
enable_stats_collection(module.weight_quantizer)
module.weight_quantizer(module.weight)
finish_stats_collection(module.weight_quantizer)


# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
PQS_FUSE_MODULE_MAPPING = [
# Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
@@ -1146,9 +1198,12 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
dim=0,
)

for module in modules:
if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale):
_update_pre_quant_scale(module, avg_prequant_scale)
if hasattr(modules[0].weight_quantizer, "svdquant_lora_a"):
_update_svdquant(modules, avg_prequant_scale)
else:
for module in modules:
if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale):
_update_pre_quant_scale(module, avg_prequant_scale)

if resmooth_only:
return
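The column rescaling in `_update_svdquant` is what keeps each module's output unchanged when the averaged pre_quant_scale from `preprocess_linear_fusion` is swapped in: with the scale applied to the input and the low-rank pair added back onto the stored weight, `(x * s_old) @ (W_old + Lb_old @ La_old).T` must equal `(x * s_new) @ (W_new + Lb_new @ La_new).T`. A toy numerical check of that invariant (shapes, rank, and scales are made up; `svd` is the helper this PR adds in model_calib.py):

```python
import torch
from modelopt.torch.quantization.model_calib import svd  # helper added in this PR

torch.manual_seed(0)
out_dim, in_dim, rank = 64, 32, 8
w_full_old = torch.randn(out_dim, in_dim, dtype=torch.float64)  # stored weight + lora_b @ lora_a
s_old = torch.rand(in_dim, dtype=torch.float64) + 0.5           # old pre_quant_scale
s_new = torch.rand(in_dim, dtype=torch.float64) + 0.5           # averaged pre_quant_scale
x = torch.randn(4, in_dim, dtype=torch.float64)

# Rescale the full weight so the layer output is preserved under the new pre_quant_scale ...
w_full_new = w_full_old * (s_old / s_new)
# ... then re-split it into a low-rank pair plus the residual that actually gets quantized.
lora_b, lora_a = svd(w_full_new, rank)
residual = w_full_new - lora_b @ lora_a

y_old = (x * s_old) @ w_full_old.T
y_new = (x * s_new) @ (residual + lora_b @ lora_a).T
assert torch.allclose(y_old, y_new)
```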
16 changes: 14 additions & 2 deletions modelopt/torch/export/unified_export_hf.py
@@ -54,6 +54,7 @@
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
)
@@ -109,6 +110,10 @@ def _output_hook(module, input, output):
fused_linears = {}
module_names = set()

# NVFP4 SVDQuant does not need pre-quant scale fusion (into either the previous linear or the layernorm) because
# 1) its kernel handles the pre-quant scale, and
# 2) fusing into the previous linear would require changing the lora_up in up_proj, which may cause issues in
# the later gate-up fusion.
# Fuse pre_quant_scale to the linear weights if possible
if quantization_format is not None and "nvfp4_awq" in quantization_format.lower():
fuse_prequant_to_linear(model)
@@ -117,7 +122,9 @@ def _output_hook(module, input, output):
module_names.add(name)

# For MoE models update pre_quant_scale to average pre_quant_scale amongst experts
if is_moe(module) and ("awq" in quantization_format):
if is_moe(module) and (
("awq" in quantization_format) or (quantization_format == QUANTIZATION_NVFP4_SVDQUANT)
):
# update_experts_avg_prequant_scale(module)
grouped_experts = get_experts_list(module, model_type)
for modules in grouped_experts:
@@ -314,6 +321,7 @@ def _export_quantized_weight(

if quantization_format in [
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_NVFP4,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
@@ -334,7 +342,11 @@ def _export_quantized_weight(
for expert_type in ["Llama4TextExperts", "GptOssExperts"]
)

if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
if quantization_format in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
]:
# Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
# for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
weight, _ = maybe_transpose_expert_weight_dimensions(
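The transpose noted in the last hunk exists because the NVFP4 block-quantization helpers group values along the last dimension; stacked expert weights therefore go from `(num_experts, input_dim, output_dim)` to `(num_experts, output_dim, input_dim)` before per-block scales are computed. A hedged illustration of the layout (sizes and the per-block reduction below are placeholders, not the library's internals):

```python
import torch

num_experts, input_dim, output_dim, block = 4, 256, 512, 16  # hypothetical sizes
stacked = torch.randn(num_experts, input_dim, output_dim)

# Put input_dim last so each 16-element block lies along the final dimension.
transposed = stacked.transpose(-2, -1).contiguous()           # (4, 512, 256)
blocks = transposed.reshape(num_experts, output_dim, input_dim // block, block)
per_block_amax = blocks.abs().amax(dim=-1)                    # one value per block: (4, 512, 16)
```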
29 changes: 16 additions & 13 deletions modelopt/torch/quantization/model_calib.py
@@ -1013,6 +1013,18 @@ def _get_awq_quantizer_block_size(tensor: torch.Tensor, quantizer: TensorQuantiz
return blocksize


def svd(weight, rank):
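"""Return the top-`rank` SVD factors (u * s, vt) of `weight`, computed in float64 for numerical stability."""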
original_device = weight.device
original_dtype = weight.dtype
weight_f64 = weight.to(dtype=torch.float64, device=original_device)
u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
us = u[:, :rank] * s[:rank]
vt = vt[:rank]
return us.to(device=original_device, dtype=original_dtype), vt.to(
device=original_device, dtype=original_dtype
)


@torch.no_grad()
def svdquant(
model: nn.Module,
@@ -1034,25 +1046,16 @@ def svdquant(
def postprocess(module, name):
print_rank_0(f"SVD {name}")
weight = module.weight.data
original_device = weight.device
original_dtype = weight.dtype
weight_f64 = weight.to(dtype=torch.float64, device=original_device)
u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
if u.shape[1] < lowrank or vt.shape[0] < lowrank:
us, vt = svd(weight, lowrank)
if us.shape[1] < lowrank or vt.shape[0] < lowrank:
warnings.warn(
"The low-rank dimensions do not match the layer dimensions. "
"Please verify your configuration and model settings. "
f"SVD will be skipped for this layer {name}."
)
return
us = u[:, :lowrank] * s[:lowrank]
vt = vt[:lowrank]
module.weight_quantizer.svdquant_lora_a = vt.to(
dtype=original_dtype, device=original_device
)
module.weight_quantizer.svdquant_lora_b = us.to(
dtype=original_dtype, device=original_device
)
module.weight_quantizer.svdquant_lora_a = vt
module.weight_quantizer.svdquant_lora_b = us
module.weight.data.sub_(
module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a
)
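Putting the calibration pieces together: the `svdquant` postprocess splits each weight into a low-rank branch plus a residual, `W = lora_b @ lora_a + R`, and only the residual `R` is handed to the NVFP4 weight quantizer. A short sketch of that decomposition using the `svd` helper above (the layer shape and rank are placeholders):

```python
import torch
from modelopt.torch.quantization.model_calib import svd

torch.manual_seed(0)
weight = torch.randn(512, 512)          # hypothetical linear weight
lowrank = 32

lora_b, lora_a = svd(weight, lowrank)   # lora_b: (512, 32), lora_a: (32, 512)
residual = weight - lora_b @ lora_a     # mirrors module.weight.data.sub_(lora_b @ lora_a)

# The low-rank branch absorbs the dominant singular directions; the residual that remains
# is what the 4-bit quantizer has to represent.
assert torch.allclose(residual + lora_b @ lora_a, weight, atol=1e-5)
```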