 
 from packaging import version
 
+from tensorrt_llm import LLM as LLM_torch
+from tensorrt_llm.executor.request import LoRARequest
+from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.sampling_params import SamplingParams
+
 from .trt_test_alternative import check_call, check_output, exists, is_windows
 
 
@@ -739,12 +744,28 @@ def generate_dummy_loras(
     from transformers import AutoModelForCausalLM
 
     print("Creating pseudo LoRAs...")
-    model = AutoModelForCausalLM.from_pretrained(
-        hf_model_dir,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True,
-    )
+
+    # Avoid meta tensors by loading model to CPU first (ensures all parameters are materialized)
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            hf_model_dir,
+            torch_dtype=torch.float16,
+            device_map=None,  # Load everything to CPU first
+            trust_remote_code=True,
+            low_cpu_mem_usage=False,
+        )
+    except Exception:
+        # Fall back to auto device mapping if CPU loading fails
+        print(
+            "Warning: Loading model to CPU failed, falling back to auto device mapping"
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            hf_model_dir,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+
     lora_config = LoraConfig(r=lora_rank,
                              target_modules=target_modules,
                              bias="none",
@@ -755,12 +776,57 @@ def generate_dummy_loras(
         if zero_weights:
             for param in lora_model.parameters():
                 param.data.zero_()
+
         pseudo_lora_dir = f"{lora_output_dir}/pseudo_lora_{lora_idx}"
         lora_model.save_pretrained(pseudo_lora_dir)
         lora_output_paths.append(pseudo_lora_dir)
     return lora_output_paths
 
 
+def get_test_prompts(use_code_prompts: bool = False) -> list[str]:
+    """Get test prompts for LoRA testing.
+
+    Args:
+        use_code_prompts: If True, return code-related prompts. If False, return general prompts.
+
+    Returns:
+        List of test prompts.
+    """
+    if use_code_prompts:
+        return [
+            "Write a function that outputs the fibonacci sequence.",
+            "Convert the following C++ code to Python: x = 0;x++;",
+            "Find the largest prime factor of 42.",
+            "write a unit test for this function: $(cat fib.py)",
+            "# A simple python function to remove whitespace from a string:",
+            "How to load CodeLlama from HuggingFace?",
+        ]
+    else:
+        return [
+            "Hey how are you doing today?",
+            "How is the weather in Seattle, WA?",
+            "Is it ok to fill diesel in a petrol car?",
+            "Can you check the top 5 trending songs on spotify?",
+            "What is the capital of France?",
+            "How to load CodeLlama from HuggingFace?",
+        ]
+
+
+def get_test_prompts_for_torch() -> list[str]:
+    """Get test prompts for LoRA Torch testing.
+
+    Returns:
+        List of test prompts.
+    """
+    return [
+        "Hey how are you doing today?",
+        "How is the weather in Seattle, WA?",
+        "Is it ok to fill diesel in a petrol car?",
+        "Can you check the top 5 trending songs on spotify?",
+        "What is the capital of France?",
+    ]
+
+
 def test_multi_lora_support(
         hf_model_dir,
         tllm_ckpt_dir,
@@ -815,24 +881,7 @@ def test_multi_lora_support(
     print(
         f"Build engines completed in {(build_end - build_start):.2f} seconds.")
 
-    if use_code_prompts:
-        input_prompts = [
-            "Write a function that outputs the fibonacci sequence.",
-            "Convert the following C++ code to Python: x = 0;x++;",
-            "Find the largest prime factor of 42.",
-            "write a unit test for this function: $(cat fib.py)",
-            "# A simple python function to remove whitespace from a string:",
-            "How to load CodeLlama from HuggingFace?",
-        ]
-    else:
-        input_prompts = [
-            "Hey how are you doing today?",
-            "How is the weather in Seattle, WA?",
-            "Is it ok to fill diesel in a petrol car?",
-            "Can you check the top 5 trending songs on spotify?",
-            "What is the capital of France?",
-            "How to load CodeLlama from HuggingFace?",
-        ]
+    input_prompts = get_test_prompts(use_code_prompts)
 
     print("Run inference with C++ runtime with pybind...")
     inference_start = time.time()
@@ -867,6 +916,116 @@ def test_multi_lora_support(
     )
 
 
+def test_llm_torch_multi_lora_support(
+        hf_model_dir,
+        llm_venv,
+        num_loras=2,
+        lora_rank=8,
+        target_hf_modules=["q_proj", "k_proj", "v_proj"],
+        target_trtllm_modules=["attn_q", "attn_k", "attn_v"],
+        zero_lora_weights=True,
+        tensor_parallel_size=1,
+        pipeline_parallel_size=1,
+        expected_outputs=None):
+    """Test multi-LoRA support with LLM-API Torch backend."""
+
+    if expected_outputs is None:
+        raise ValueError(
+            "expected_outputs must be provided for exact validation")
+
+    start_time = time.time()
+    print("Creating dummy LoRAs...")
+    lora_start = time.time()
+
+    lora_paths = generate_dummy_loras(
+        hf_model_dir=hf_model_dir,
+        lora_output_dir=llm_venv.get_working_directory(),
+        num_loras=num_loras,
+        lora_rank=lora_rank,
+        target_modules=target_hf_modules,
+        zero_weights=zero_lora_weights)
+    lora_end = time.time()
+    print(
+        f"Creating dummy LoRAs completed in {(lora_end - lora_start):.2f} seconds."
+    )
+
+    print("Initializing LLM_torch with LoRA support...")
+    init_start = time.time()
+
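+    # Size both the GPU and host LoRA caches to hold all of the dummy adapters at once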
+    lora_config = LoraConfig(lora_dir=lora_paths,
+                             max_lora_rank=lora_rank,
+                             max_loras=num_loras,
+                             max_cpu_loras=num_loras,
+                             lora_target_modules=target_trtllm_modules)
+
+    input_prompts = get_test_prompts_for_torch()
+
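+    # Build the LLM on the PyTorch backend; the batch/sequence limits below mirror the original engine test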
+    with LLM_torch(
+            model=hf_model_dir,
+            lora_config=lora_config,
+            tensor_parallel_size=tensor_parallel_size,
+            pipeline_parallel_size=pipeline_parallel_size,
+            dtype="bfloat16",
+            max_batch_size=8,  # From original test
+            max_input_len=512,  # From original test
+            max_seq_len=562,  # From original test
+            max_beam_width=1  # From original test
+    ) as llm:
+
+        init_end = time.time()
+        print(
+            f"LLM_torch initialization completed in {(init_end - init_start):.2f} seconds."
+        )
+
+        print("Running inference with LLM-API Torch backend...")
+        inference_start = time.time()
+
+        # Create LoRA requests for different adapters
+        lora_requests = []
+        for i in range(len(input_prompts)):
+            if i % 2 == 1:  # Add some requests without LoRA
+                lora_requests.append(None)
+            else:  # With LoRA
+                lora_requests.append(
+                    LoRARequest(f"lora-{i}", i,
+                                lora_paths[i % len(lora_paths)]))
+
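+        # Greedy decoding (temperature=0.0) keeps outputs deterministic for the exact-match checks below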
+        sampling_params = SamplingParams(max_tokens=30,
+                                         top_p=0.5,
+                                         top_k=0,
+                                         temperature=0.0)
+
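+        # lora_request pairs element-wise with input_prompts; None entries run the base model without a LoRA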
+        outputs = llm.generate(input_prompts,
+                               sampling_params=sampling_params,
+                               lora_request=lora_requests)
+
+        inference_end = time.time()
+        print(
+            f"Inference completed in {(inference_end - inference_start):.2f} seconds."
+        )
+
+        # Validate exact outputs
+        print("Validating exact outputs...")
+        assert len(outputs) == len(expected_outputs), \
+            f"Expected {len(expected_outputs)} outputs, got {len(outputs)}"
+
+        for i, (output, expected) in enumerate(zip(outputs, expected_outputs)):
+            actual_text = output.outputs[0].text
+            print(f"Prompt {i + 1}: {input_prompts[i]}")
+            print(
+                f"LoRA: {lora_requests[i].lora_int_id if lora_requests[i] else 'None'}"
+            )
+            print(f"Expected: {expected}")
+            print(f"Actual: {actual_text}")
+            print("-" * 50)
+
+            # Exact string comparison
+            assert actual_text == expected, \
+                f"Output {i + 1} mismatch:\nExpected: {expected!r}\nActual: {actual_text!r}"
+
+    total_time = time.time() - start_time
+    print(f"Total test execution time: {total_time:.2f} seconds")
+
+
 def get_dummy_spec_decoding_heads(hf_model_dir,
                                   save_dir,
                                   mode='medusa',