
Commit 38da871

[TRTLLM-6496][feat] Add LoRa Torch tests for the latest NIM model list (#6806)
Signed-off-by: Michal Guzek <[email protected]>
1 parent ca82911 commit 38da871

File tree: 8 files changed, +562 additions, -55 deletions


tensorrt_llm/_torch/model_config.py

Lines changed: 14 additions & 6 deletions
@@ -525,15 +525,23 @@ def get_bindings_model_config(self,
 
         # For kv cache size calculation: set size_per_head
         head_dim_names = ["head_size", "head_dim"]
+        head_size = None
         for head_dim_name in head_dim_names:
-            if head_dim_name in self.pretrained_config:
-                head_size = getattr(self.pretrained_config, head_dim_name)
-                break
-        else:
+            if hasattr(self.pretrained_config, head_dim_name):
+                value = getattr(self.pretrained_config, head_dim_name)
+                if value is not None:
+                    head_size = value
+                    break
+
+        if head_size is None:
+            assert hidden_size % num_heads == 0, (
+                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
+            )
+            calculated_head_size = hidden_size // num_heads
             logger.warning(
-                f"head_size/head_dim is not set, using default value {hidden_size // num_heads}"
+                f"head_size/head_dim is not set or None, using default value {calculated_head_size}"
             )
-            head_size = hidden_size // num_heads
+            head_size = calculated_head_size
 
         model_config_cpp.mlp_hidden_size = mlp_hidden_size
         model_config_cpp.size_per_head = head_size
tests/integration/defs/common.py

Lines changed: 183 additions & 24 deletions
@@ -22,6 +22,11 @@
 
 from packaging import version
 
+from tensorrt_llm import LLM as LLM_torch
+from tensorrt_llm.executor.request import LoRARequest
+from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.sampling_params import SamplingParams
+
 from .trt_test_alternative import check_call, check_output, exists, is_windows
 
 
@@ -739,12 +744,28 @@ def generate_dummy_loras(
     from transformers import AutoModelForCausalLM
 
     print("Creating pseudo LoRAs...")
-    model = AutoModelForCausalLM.from_pretrained(
-        hf_model_dir,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True,
-    )
+
+    # Avoid meta tensors by loading model to CPU first (ensures all parameters are materialized)
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            hf_model_dir,
+            torch_dtype=torch.float16,
+            device_map=None,  # Load everything to CPU first
+            trust_remote_code=True,
+            low_cpu_mem_usage=False,
+        )
+    except Exception:
+        # Fallback to auto device mapping if CPU loading fails
+        print(
+            "Warning: Loading model to CPU failed, falling back to auto device mapping"
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            hf_model_dir,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+
     lora_config = LoraConfig(r=lora_rank,
                              target_modules=target_modules,
                              bias="none",
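The comment in the hunk gives the rationale: loading with device_map=None and low_cpu_mem_usage=False materializes every weight on CPU instead of leaving meta tensors behind, which the dummy-LoRA export needs. A small illustrative check of that assumption (not part of the commit), run right after the load:

# Sketch: verify no parameter was left on the meta device after loading.
# Assumes `model` is the AutoModelForCausalLM instance created above.
meta_params = [
    name for name, p in model.named_parameters() if p.device.type == "meta"
]
assert not meta_params, f"Unmaterialized (meta) parameters: {meta_params[:5]}"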
@@ -755,12 +776,57 @@ def generate_dummy_loras(
         if zero_weights:
             for param in lora_model.parameters():
                 param.data.zero_()
+
         pseudo_lora_dir = f"{lora_output_dir}/pseudo_lora_{lora_idx}"
         lora_model.save_pretrained(pseudo_lora_dir)
         lora_output_paths.append(pseudo_lora_dir)
     return lora_output_paths
 
 
+def get_test_prompts(use_code_prompts: bool = False) -> list[str]:
+    """Get test prompts for LoRA testing.
+
+    Args:
+        use_code_prompts: If True, return code-related prompts. If False, return general prompts.
+
+    Returns:
+        List of test prompts.
+    """
+    if use_code_prompts:
+        return [
+            "Write a function that outputs the fibonacci sequence.",
+            "Convert the following C++ code to Python: x = 0;x++;",
+            "Find the largest prime factor of 42.",
+            "write a unit test for this function: $(cat fib.py)",
+            "# A simple python function to remove whitespace from a string:",
+            "How to load CodeLlama from HuggingFace?",
+        ]
+    else:
+        return [
+            "Hey how are you doing today?",
+            "How is the weather in Seattle, WA?",
+            "Is it ok to fill diesel in a petrol car?",
+            "Can you check the top 5 trending songs on spotify?",
+            "What is the capital of France?",
+            "How to load CodeLlama from HuggingFace?",
+        ]
+
+
+def get_test_prompts_for_torch() -> list[str]:
+    """Get test prompts for LoRA Torch testing.
+
+    Returns:
+        List of test prompts.
+    """
+    return [
+        "Hey how are you doing today?",
+        "How is the weather in Seattle, WA?",
+        "Is it ok to fill diesel in a petrol car?",
+        "Can you check the top 5 trending songs on spotify?",
+        "What is the capital of France?",
+    ]
+
+
 def test_multi_lora_support(
         hf_model_dir,
         tllm_ckpt_dir,
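Factoring the prompt lists into get_test_prompts and get_test_prompts_for_torch lets the existing engine test and the new Torch test share fixed inputs. A quick illustrative check (the import path is assumed from the file location):

# Assumed import path for the integration-test helpers.
from defs.common import get_test_prompts, get_test_prompts_for_torch

assert len(get_test_prompts(use_code_prompts=True)) == 6   # code-oriented prompts
assert len(get_test_prompts(use_code_prompts=False)) == 6  # general prompts
assert len(get_test_prompts_for_torch()) == 5              # Torch-backend prompts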
@@ -815,24 +881,7 @@ def test_multi_lora_support(
     print(
         f"Build engines completed in {(build_end - build_start):.2f} seconds.")
 
-    if use_code_prompts:
-        input_prompts = [
-            "Write a function that outputs the fibonacci sequence.",
-            "Convert the following C++ code to Python: x = 0;x++;",
-            "Find the largest prime factor of 42.",
-            "write a unit test for this function: $(cat fib.py)",
-            "# A simple python function to remove whitespace from a string:",
-            "How to load CodeLlama from HuggingFace?",
-        ]
-    else:
-        input_prompts = [
-            "Hey how are you doing today?",
-            "How is the weather in Seattle, WA?",
-            "Is it ok to fill diesel in a petrol car?",
-            "Can you check the top 5 trending songs on spotify?",
-            "What is the capital of France?",
-            "How to load CodeLlama from HuggingFace?",
-        ]
+    input_prompts = get_test_prompts(use_code_prompts)
 
     print("Run inference with C++ runtime with pybind...")
     inference_start = time.time()
@@ -867,6 +916,116 @@ def test_multi_lora_support(
     )
 
 
+def test_llm_torch_multi_lora_support(
+        hf_model_dir,
+        llm_venv,
+        num_loras=2,
+        lora_rank=8,
+        target_hf_modules=["q_proj", "k_proj", "v_proj"],
+        target_trtllm_modules=["attn_q", "attn_k", "attn_v"],
+        zero_lora_weights=True,
+        tensor_parallel_size=1,
+        pipeline_parallel_size=1,
+        expected_outputs=None):
+    """Test multi-LoRA support with LLM-API Torch backend."""
+
+    # if expected_outputs is None:
+    #     raise ValueError("expected_outputs must be provided for exact validation")
+
+    start_time = time.time()
+    print("Creating dummy LoRAs...")
+    lora_start = time.time()
+
+    lora_paths = generate_dummy_loras(
+        hf_model_dir=hf_model_dir,
+        lora_output_dir=llm_venv.get_working_directory(),
+        num_loras=num_loras,
+        lora_rank=lora_rank,
+        target_modules=target_hf_modules,
+        zero_weights=zero_lora_weights)
+    lora_end = time.time()
+    print(
+        f"Creating dummy LoRAs completed in {(lora_end - lora_start):.2f} seconds."
+    )
+
+    print("Initializing LLM_torch with LoRA support...")
+    init_start = time.time()
+
+    lora_config = LoraConfig(lora_dir=lora_paths,
+                             max_lora_rank=lora_rank,
+                             max_loras=num_loras,
+                             max_cpu_loras=num_loras,
+                             lora_target_modules=target_trtllm_modules)
+
+    input_prompts = get_test_prompts_for_torch()
+
+    with LLM_torch(
+            model=hf_model_dir,
+            lora_config=lora_config,
+            tensor_parallel_size=tensor_parallel_size,
+            pipeline_parallel_size=pipeline_parallel_size,
+            dtype="bfloat16",
+            max_batch_size=8,  # From original test
+            max_input_len=512,  # From original test
+            max_seq_len=562,  # From original test
+            max_beam_width=1  # From original test
+    ) as llm:
+
+        init_end = time.time()
+        print(
+            f"LLM_torch initialization completed in {(init_end - init_start):.2f} seconds."
+        )
+
+        print("Running inference with LLM-API Torch backend...")
+        inference_start = time.time()
+
+        # Create LoRA requests for different adapters
+        lora_requests = []
+        for i in range(len(input_prompts)):
+            if i % 2 == 1:  # Add some requests without LoRA
+                lora_requests.append(None)
+            else:  # With LoRA
+                lora_requests.append(
+                    LoRARequest(f"lora-{i}", i,
+                                lora_paths[i % len(lora_paths)]))
+
+        sampling_params = SamplingParams(max_tokens=30,
+                                         top_p=0.5,
+                                         top_k=0,
+                                         temperature=0.0)
+
+        outputs = llm.generate(input_prompts,
+                               sampling_params=sampling_params,
+                               lora_request=lora_requests)
+
+        inference_end = time.time()
+        print(
+            f"Inference completed in {(inference_end - inference_start):.2f} seconds."
+        )
+
+        # Validate exact outputs
+        print("Validating exact outputs...")
+        assert len(outputs) == len(expected_outputs), \
+            f"Expected {len(expected_outputs)} outputs, got {len(outputs)}"
+
+        for i, (output, expected) in enumerate(zip(outputs, expected_outputs)):
+            actual_text = output.outputs[0].text
+            print(f"Prompt {i+1}: {input_prompts[i]}")
+            print(
+                f"LoRA: {lora_requests[i].lora_int_id if lora_requests[i] else 'None'}"
+            )
+            print(f"Expected: {expected}")
+            print(f"Actual: {actual_text}")
+            print("-" * 50)
+
+            # Exact string comparison
+            assert actual_text == expected, \
+                f"Output {i+1} mismatch:\nExpected: {expected!r}\nActual: {actual_text!r}"
+
+    total_time = time.time() - start_time
+    print(f"Total test execution time: {total_time:.2f} seconds")
+
+
 def get_dummy_spec_decoding_heads(hf_model_dir,
                                   save_dir,
                                   mode='medusa',
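A sketch of how an integration test might drive the new helper, assuming a Llama HF checkpoint fixture and pre-pinned generations; the test name, fixture wiring, and expected strings below are placeholders, not values from this commit:

# Hypothetical caller of the new helper; expected_outputs would normally pin
# the exact generations observed with the deterministic sampling params above.
from defs.common import test_llm_torch_multi_lora_support  # assumed import path

def test_llama_3_2_1b_multi_lora_torch(llama_model_root, llm_venv):
    placeholder_expected = ["..."] * 5  # one entry per prompt from get_test_prompts_for_torch()
    test_llm_torch_multi_lora_support(
        hf_model_dir=llama_model_root,
        llm_venv=llm_venv,
        num_loras=2,
        lora_rank=8,
        expected_outputs=placeholder_expected)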

tests/integration/defs/conftest.py

Lines changed: 19 additions & 0 deletions
@@ -1015,6 +1015,9 @@ def llama_model_root(request):
     elif request.param == "llama-3.1-8b-instruct-hf-fp8":
         llama_model_root = os.path.join(models_root, "llama-3.1-model",
                                         "Llama-3.1-8B-Instruct-FP8")
+    elif request.param == "llama-3.1-8b-instruct":
+        llama_model_root = os.path.join(models_root, "llama-3.1-model",
+                                        "Llama-3.1-8B-Instruct")
     elif request.param == "llama-3.1-8b-hf-nvfp4":
         llama_model_root = os.path.join(models_root, "nvfp4-quantized",
                                         "Meta-Llama-3.1-8B")
@@ -1024,9 +1027,18 @@ def llama_model_root(request):
     elif request.param == "llama-3.2-1b":
         llama_model_root = os.path.join(models_root, "llama-3.2-models",
                                         "Llama-3.2-1B")
+    elif request.param == "llama-3.2-1b-instruct":
+        llama_model_root = os.path.join(models_root, "llama-3.2-models",
+                                        "Llama-3.2-1B-Instruct")
     elif request.param == "llama-3.2-3b":
         llama_model_root = os.path.join(models_root, "llama-3.2-models",
                                         "Llama-3.2-3B")
+    elif request.param == "llama-3.2-3b-instruct":
+        llama_model_root = os.path.join(models_root, "llama-3.2-models",
+                                        "Llama-3.2-3B-Instruct")
+    elif request.param == "llama-3.3-70b-instruct":
+        llama_model_root = os.path.join(models_root, "llama-3.3-models",
+                                        "Llama-3.3-70B-Instruct")
     assert os.path.exists(
         llama_model_root
     ), f"{llama_model_root} does not exist under NFS LLM_MODELS_ROOT dir"
@@ -1323,6 +1335,11 @@ def llm_lora_model_root(request):
         elif item == "komt-mistral-7b-v1-lora":
             model_root_list.append(
                 os.path.join(models_root, "komt-mistral-7b-v1-lora"))
+        elif item == "Llama-3_3-Nemotron-Super-49B-v1-lora-adapter_NIM_r32":
+            model_root_list.append(
+                os.path.join(
+                    models_root, "nemotron-nas",
+                    "Llama-3_3-Nemotron-Super-49B-v1-lora-adapter_NIM_r32"))
 
     return ",".join(model_root_list)
 
@@ -1363,6 +1380,8 @@ def llm_mistral_model_root(request):
     model_root = os.path.join(models_root, "mistral-7b-v0.1")
     if request.param == "mistral-7b-v0.1":
         model_root = os.path.join(models_root, "mistral-7b-v0.1")
+    if request.param == "mistral-nemo-instruct-2407":
+        model_root = os.path.join(models_root, "Mistral-Nemo-Instruct-2407")
     if request.param == "komt-mistral-7b-v1":
         model_root = os.path.join(models_root, "komt-mistral-7b-v1")
     if request.param == "mistral-7b-v0.3":
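The new request.param branches are reached through pytest indirect parametrization of these fixtures; a minimal sketch, assuming the fixture names above (the test itself is illustrative):

import pytest

@pytest.mark.parametrize("llama_model_root", ["llama-3.2-1b-instruct"],
                         indirect=True)
def test_uses_new_llama_root(llama_model_root):
    # The fixture maps the param to a path under LLM_MODELS_ROOT and asserts
    # that the directory exists before the test body runs.
    assert llama_model_root.endswith("Llama-3.2-1B-Instruct")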
