
Commit 0acdecb

[https://nvbugs/5569713][fix] Disable fp8 deep gemm for EXAONE-4.0-32B-FP8 (#8429)
Signed-off-by: Junyi Xu <[email protected]>
1 parent f256eb9 commit 0acdecb

File tree: 3 files changed, +56 -3 lines changed


tensorrt_llm/_torch/models/modeling_exaone4.py
Lines changed: 13 additions & 1 deletion

@@ -5,6 +5,7 @@
 
 from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.quantization import QuantAlgo
 
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import (PositionalEmbeddingParams,
@@ -54,7 +55,8 @@ class Exaone4Attention(QKNormRoPEAttention):
     def __init__(self,
                  model_config: ModelConfig[Exaone4Config],
                  layer_idx: Optional[int] = None,
-                 fuse_qk_norm_rope: bool = False):
+                 fuse_qk_norm_rope: bool = False,
+                 disable_deep_gemm: bool = False):
         config = model_config.pretrained_config
 
         self.attention_window_size = None
@@ -88,6 +90,7 @@ def __init__(self,
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
+            disable_deep_gemm=disable_deep_gemm,
         )
 
     def forward(
@@ -128,9 +131,17 @@ def __init__(
         self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant(
         )
 
+        disable_deep_gemm = False
+        quant_config = getattr(model_config, "quant_config", None)
+        if quant_config is not None:
+            # EXAONE4 fp8 has an illegal memory access issue with deep_gemm.
+            disable_deep_gemm = getattr(quant_config, "quant_algo",
+                                        None) == QuantAlgo.FP8_BLOCK_SCALES
+
         self.self_attn = Exaone4Attention(
             model_config,
             layer_idx=layer_idx,
+            disable_deep_gemm=disable_deep_gemm,
         )
 
         self.mlp = GatedMLP(
@@ -140,6 +151,7 @@ def __init__(
             dtype=config.torch_dtype,
             config=model_config,
             layer_idx=layer_idx,
+            disable_deep_gemm=disable_deep_gemm,
        )
 
         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
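For reference, the detection logic added above can be read in isolation: the decoder layer's __init__ inspects the model's quant_config and only sets disable_deep_gemm when the quantization algorithm is FP8_BLOCK_SCALES, then forwards the flag to both the attention module and the MLP. A minimal, self-contained sketch of that decision, using SimpleNamespace stand-ins in place of the real ModelConfig (the stand-in objects are illustrative only):

import sys
from types import SimpleNamespace

from tensorrt_llm.quantization import QuantAlgo


def should_disable_deep_gemm(model_config) -> bool:
    """Mirror of the check added in the decoder layer __init__ above: deep_gemm
    is disabled only when the checkpoint uses FP8 block-scale quantization."""
    quant_config = getattr(model_config, "quant_config", None)
    if quant_config is None:
        return False
    # EXAONE4 fp8 has an illegal memory access issue with deep_gemm.
    return getattr(quant_config, "quant_algo", None) == QuantAlgo.FP8_BLOCK_SCALES


# Stand-in config objects; the real code receives a ModelConfig[Exaone4Config].
fp8_cfg = SimpleNamespace(
    quant_config=SimpleNamespace(quant_algo=QuantAlgo.FP8_BLOCK_SCALES))
bf16_cfg = SimpleNamespace(quant_config=None)

assert should_disable_deep_gemm(fp8_cfg)
assert not should_disable_deep_gemm(bf16_cfg)
print("deep_gemm disabled only for FP8_BLOCK_SCALES", file=sys.stderr)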

tests/integration/test_lists/test-db/l0_b200.yml
Lines changed: 1 addition & 0 deletions

@@ -73,6 +73,7 @@ l0_b200:
   - unittest/_torch/modeling -k "modeling_llama"
   - unittest/_torch/modeling -k "modeling_mixtral"
   - unittest/_torch/modeling -k "modeling_gpt_oss"
+  - unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8
   - unittest/_torch/auto_deploy/unit/singlegpu -k "not test_trtllm_bench_backend_comparison"
 - condition:
     ranges:
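The new l0_b200 entry points at the FP8 variant of the parameterized test added in the next file (parameterized.expand over [None, "FP8"] produces the node id test_llm_load_1_FP8 referenced here). A minimal way to run just that case locally via pytest's Python entry point, assuming the repo root as working directory and a GPU with SM >= 89 (the test skips itself on pre-Ada parts), could look like:

import pytest

# Run only the FP8 load test referenced by the l0_b200 list.
if __name__ == "__main__":
    raise SystemExit(
        pytest.main([
            "tests/unittest/_torch/modeling/test_modeling_exaone4.py"
            "::TestEXAONE4::test_llm_load_1_FP8",
            "-v",
        ]))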

tests/unittest/_torch/modeling/test_modeling_exaone4.py
Lines changed: 42 additions & 2 deletions

@@ -1,3 +1,6 @@
+import json
+import os
+import shutil
 import unittest
 from copy import deepcopy
 from dataclasses import dataclass
@@ -51,8 +54,9 @@ class Exaone4Config(PretrainedConfig):
     "max_position_embeddings": 131072,
     "model_type": "exaone4",
     "num_attention_heads": 40,
-    "num_hidden_layers":
-    4, #NOTE: For testing, we use 4 instead of 64(all layers)
+    # NOTE: For testing, we use 32 instead of 64(all layers)
+    # Increase from 4 to 32 to trigger the deep_gemm kernel issue
+    "num_hidden_layers": 32,
     "num_key_value_heads": 8,
     "pad_token_id": 0,
     "rms_norm_eps": 1e-05,
@@ -74,6 +78,15 @@ class Exaone4Config(PretrainedConfig):
     "attn_implementation": "flash_attention_2"
 }
 
+EXAONE4_FP8_QUANT_CONFIG = {
+    "quantization_config": {
+        "activation_scheme": "dynamic",
+        "modules_to_not_convert": None,
+        "quant_method": "fp8",
+        "weight_block_size": [128, 128]
+    },
+}
+
 
 @dataclass(repr=False)
 class Scenario:
@@ -390,3 +403,30 @@ def run_forward(input_ids, position_ids, attn_metadata):
             if graph_runner is not None:
                 graph_runner.clear()
             kv_cache_manager.shutdown()
+
+    @parameterized.expand([None, "FP8"])
+    def test_llm_load(self, quant_algo):
+
+        def dump_config_json(dst_dir, config):
+            if os.path.exists(dst_dir):
+                shutil.rmtree(dst_dir)
+            os.makedirs(dst_dir)
+
+            dst_path = os.path.join(dst_dir, 'config.json')
+            with open(dst_path, 'w', encoding='utf-8') as f:
+                json.dump(config, f, indent=2, ensure_ascii=False)
+
+        config_dict = deepcopy(EXAONE4_SINGLE_LAYER_CONFIG)
+        if quant_algo == "FP8":
+            if getSMVersion() < 89:
+                self.skipTest(
+                    "This test is not supported in pre-Ada architecture")
+
+            config_dict.update(EXAONE4_FP8_QUANT_CONFIG)
+
+        tmp_model_dir = f"/tmp/exaone4_llm_load_test_model"
+        dump_config_json(tmp_model_dir, config_dict)
+        try:
+            tensorrt_llm.LLM(model=tmp_model_dir, load_format="dummy")
+        except Exception:
+            raise RuntimeError("Failed to load model.")
