
Commit bb6c338

AWQ support Modelopt ckpts. (NVIDIA#3258)

Authored by Tracin and QiJune
Signed-off-by: Tracin <[email protected]>
Co-authored-by: QI JUN <[email protected]>

1 parent b763051 · commit bb6c338

File tree: 3 files changed, +34 -16 lines changed

tensorrt_llm/models/modeling_utils.py

Lines changed: 7 additions & 5 deletions
@@ -1705,14 +1705,16 @@ def preprocess_perlayer_weights(weights,
                 dtype = torch.float16
                 if model_config.dtype == "bfloat16":
                     dtype = torch.bfloat16
-                weights[name] = preprocessor(param.T, torch.quint4x2,
+                weights[name] = preprocessor(param.transpose(-1, -2),
+                                             torch.quint4x2,
                                              activation_type).view(dtype)
-            if name.endswith('weights_scaling_factor'
-                             ) and param.shape[0] > param.shape[1]:
-                # TODO: refine on supporting ModelOpt HF-AWQ
-                weights[name] = param.T.contiguous().to(
+            if name.endswith('weights_scaling_factor'):
+                weights[name] = param.transpose(-1, -2).contiguous().to(
                     str_dtype_to_torch(model_config.dtype))
             if name.endswith('prequant_scaling_factor'):
+                if len(weights[name].shape) == 2:
+                    # MoE experts share the same scaling factor.
+                    param = param[0, :]
                 weights[name] = param.reshape(1, -1)
             if model_config.mapping.tp_rank > 0:
                 if name.endswith('attention.dense.bias') or name.endswith(
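
The key change above is replacing param.T with param.transpose(-1, -2), so the same preprocessing path handles both 2-D dense weights and the 3-D stacked per-expert weights found in ModelOpt MoE checkpoints. A minimal sketch of the difference (not part of the commit; shapes are hypothetical):

import torch

dense_w = torch.randn(128, 64)    # [out, in] weight of a dense layer
moe_w = torch.randn(8, 128, 64)   # [num_experts, out, in] stacked MoE weights

# For 2-D tensors the two spellings are identical:
assert torch.equal(dense_w.T, dense_w.transpose(-1, -2))

# For a 3-D tensor, .T would reverse *all* dimensions -> [64, 128, 8] (and is
# deprecated for >2-D tensors), while transpose(-1, -2) swaps only the last
# two, producing the per-expert transpose the preprocessor expects:
print(moe_w.transpose(-1, -2).shape)    # torch.Size([8, 64, 128])

# prequant_scaling_factor: a 2-D tensor carries one row per expert; the commit
# keeps only row 0 because all experts share the same scaling factor.
prequant = torch.randn(8, 64)           # hypothetical [num_experts, hidden]
shared = prequant[0, :].reshape(1, -1)  # shape [1, 64]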

tests/integration/defs/examples/test_mixtral.py

Lines changed: 26 additions & 11 deletions
@@ -888,24 +888,39 @@ def test_llm_mixtral_1gpu_fp4_llmapi(
     venv_check_call(llm_venv, mmlu_cmd)
 
 
-@pytest.mark.parametrize("model_name", ['mixtral-8x7b-v0.1-AWQ'])
+@pytest.mark.parametrize(
+    "model_name", ['mixtral-8x7b-v0.1-AWQ', 'Mixtral-8x7B-Instruct-v0.1'])
 def test_llm_mixtral_int4_awq_1gpu_summary(llama_example_root,
                                            llm_datasets_root, model_name,
                                            llm_rouge_root, llm_venv, cmodel_dir,
-                                           engine_dir):
+                                           engine_dir,
+                                           qcache_dir_without_install_package):
     models_root = llm_models_root()
     model_dir = os.path.join(models_root, model_name)
     ckpt_dir = os.path.join(cmodel_dir, model_name)
 
-    print("Convert checkpoint...")
-    convert_cmd = [
-        f"{llama_example_root}/convert_checkpoint.py",
-        "--model_dir",
-        model_dir,
-        "--output_dir",
-        ckpt_dir,
-    ]
-    venv_check_call(llm_venv, convert_cmd)
+    if 'AWQ' in model_name:
+        print("Convert checkpoint...")
+        convert_cmd = [
+            f"{llama_example_root}/convert_checkpoint.py",
+            "--model_dir",
+            model_dir,
+            "--output_dir",
+            ckpt_dir,
+        ]
+        venv_check_call(llm_venv, convert_cmd)
+    else:
+        print("Quantizing model...")
+        ckpt_dir = quantize_data(
+            llm_venv,
+            llama_example_root,
+            model_dir=model_dir,
+            calib_dataset=f"{llm_datasets_root}/cnn_dailymail",
+            dtype="float16",
+            qformat="int4_awq",
+            quantize_dir=qcache_dir_without_install_package,
+            tp_size=1,
+            calib_size=32)
 
     print("Build engines...")
     build_cmd = [
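
For the unquantized 'Mixtral-8x7B-Instruct-v0.1' checkpoint, the test now quantizes on the fly through the quantize_data helper rather than converting a pre-quantized AWQ checkpoint. A hedged sketch of roughly what that amounts to, assuming the helper wraps TensorRT-LLM's examples/quantization/quantize.py entry point (all paths below are placeholders):

import subprocess

# Placeholder paths; flag names follow examples/quantization/quantize.py.
quantize_cmd = [
    "python", "examples/quantization/quantize.py",
    "--model_dir", "/models/Mixtral-8x7B-Instruct-v0.1",  # HF checkpoint
    "--dtype", "float16",
    "--qformat", "int4_awq",    # weight-only INT4 AWQ, as in the test
    "--calib_size", "32",       # matches calib_size=32 above
    "--output_dir", "/tmp/mixtral-int4-awq-ckpt",
]
subprocess.run(quantize_cmd, check=True)  # then build engines from output_dir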

tests/integration/test_lists/qa/examples_test_list.txt

Lines changed: 1 addition & 0 deletions
@@ -180,6 +180,7 @@ examples/test_mixtral.py::test_llm_mixtral_wo_2gpus_summary[Mixtral-8x7B-v0.1-in
 examples/test_mixtral.py::test_llm_mixtral_wo_2gpus_summary[Mixtral-8x7B-v0.1-int8-nb:4]
 examples/test_mixtral.py::test_llm_mixtral_1gpu_fp4_llmapi[Mixtral-8x7B-Instruct-v0.1]
 examples/test_mixtral.py::test_llm_mixtral_int4_awq_1gpu_summary[mixtral-8x7b-v0.1-AWQ]
+examples/test_mixtral.py::test_llm_mixtral_int4_awq_1gpu_summary[Mixtral-8x7B-Instruct-v0.1]
 examples/test_multimodal.py::test_llm_multimodal_general[Phi-3-vision-128k-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
 examples/test_multimodal.py::test_llm_multimodal_general[Phi-3-vision-128k-instruct-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1]
 examples/test_multimodal.py::test_llm_multimodal_general[Phi-3.5-vision-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
