diff --git a/examples/serve/openai_completion_client_for_lora.py b/examples/serve/openai_completion_client_for_lora.py
new file mode 100644
index 0000000000..be4857e5cb
--- /dev/null
+++ b/examples/serve/openai_completion_client_for_lora.py
@@ -0,0 +1,26 @@
+### OpenAI Completion Client with LoRA request
+
+from openai import OpenAI
+
+client = OpenAI(
+    base_url="http://localhost:8000/v1",
+    api_key="tensorrt_llm",
+)
+
+response = client.completions.create(
+    model="llama-7b-hf",
+    prompt="美国的首都在哪里? \n答案:",
+    max_tokens=20,
+    extra_body={
+        "lora_request": {
+            "lora_name":
+            "luotuo-lora-7b-0.1",
+            "lora_int_id":
+            0,
+            "lora_path":
+            "/home/scratch.trt_llm_data/llm-models/llama-models/luotuo-lora-7b-0.1"
+        }
+    },
+)
+
+print(response)
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 03ad7df27b..052ba62b21 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -12,6 +12,8 @@
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Annotated, Required, TypedDict
 
+from tensorrt_llm.executor.request import LoRARequest
+from tensorrt_llm.executor.serialization import register_approved_ipc_class
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
 from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
@@ -170,6 +172,7 @@ class CompletionRequest(OpenAIBaseModel):
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
     user: Optional[str] = None
+    lora_request: Optional[LoRARequest] = None
 
     # doc: begin-completion-sampling-params
     use_beam_search: bool = False
@@ -447,6 +450,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    lora_request: Optional[LoRARequest] = None
     # doc: end-chat-completion-sampling-params
 
     # doc: begin-chat-completion-extra-params
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index edc5b5f6f6..63f7e82c73 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -294,6 +294,7 @@ async def create_chat_response(
                 sampling_params=sampling_params,
                 _postproc_params=postproc_params if self.postproc_worker_enabled else None,
                 streaming=request.stream,
+                lora_request=request.lora_request,
                 disaggregated_params=disaggregated_params
             )
             asyncio.create_task(self.await_disconnected(raw_request, promise))
@@ -414,6 +415,7 @@ async def create_completion_response(
                 sampling_params=sampling_params,
                 _postproc_params=postproc_params,
                 streaming=request.stream,
+                lora_request=request.lora_request,
                 disaggregated_params=disaggregated_params
             )
             asyncio.create_task(self.await_disconnected(raw_request, promise))
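Note on the wire format: the new lora_request field added to CompletionRequest above travels in the JSON body of the standard /v1/completions request, so the route can also be exercised without the OpenAI SDK. Below is a minimal sketch using the requests library; it assumes a trtllm-serve instance is already listening on localhost:8000 without auth, and the adapter name, id, and path are illustrative placeholders rather than values taken from this change.

    import requests

    # Plain-HTTP view of what CompletionRequest.lora_request receives on the wire.
    # Adapter name, id, and path are illustrative placeholders.
    payload = {
        "model": "llama-7b-hf",
        "prompt": "Where is the capital of the US? \nAnswer:",
        "max_tokens": 20,
        "lora_request": {
            "lora_name": "luotuo-lora-7b-0.1",
            "lora_int_id": 0,
            "lora_path": "/path/to/luotuo-lora-7b-0.1",
        },
    }
    resp = requests.post("http://localhost:8000/v1/completions", json=payload)
    print(resp.json())

The OpenAI SDK's extra_body argument, as used in the example client above, merges its keys into this same top-level JSON body.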
diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py
index 22f6035d95..24e6e4c393 100644
--- a/tests/integration/defs/test_e2e.py
+++ b/tests/integration/defs/test_e2e.py
@@ -1371,6 +1371,18 @@ def test_trtllm_serve_multimodal_example(llm_root, llm_venv):
     ])
 
 
+def test_trtllm_serve_lora_example(llm_root, llm_venv):
+    example_root = Path(os.path.join(llm_root, "examples", "serve"))
+    test_root = unittest_path() / "llmapi" / "apps"
+    llm_venv.run_cmd([
+        "-m", "pip", "install", "-r",
+        os.path.join(example_root, "requirements.txt")
+    ])
+    llm_venv.run_cmd(
+        ["-m", "pytest",
+         str(test_root / "_test_trtllm_serve_lora.py")])
+
+
 def test_openai_misc_example(llm_root, llm_venv):
     test_root = unittest_path() / "llmapi" / "apps"
     llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_misc.py")])
@@ -1401,6 +1413,11 @@ def test_openai_reasoning(llm_root, llm_venv):
             str(test_root / "_test_openai_reasoning.py")])
 
 
+def test_openai_lora(llm_root, llm_venv):
+    test_root = unittest_path() / "llmapi" / "apps"
+    llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_lora.py")])
+
+
 def test_openai_chat_multimodal_example(llm_root, llm_venv):
     example_root = Path(os.path.join(llm_root, "examples", "apps"))
     test_root = unittest_path() / "llmapi" / "apps"
diff --git a/tests/unittest/llmapi/apps/_test_openai_lora.py b/tests/unittest/llmapi/apps/_test_openai_lora.py
new file mode 100644
index 0000000000..ba65547d0b
--- /dev/null
+++ b/tests/unittest/llmapi/apps/_test_openai_lora.py
@@ -0,0 +1,104 @@
+import os
+import tempfile
+from dataclasses import asdict
+from typing import List, Optional
+
+import openai
+import pytest
+import yaml
+
+from tensorrt_llm.executor.request import LoRARequest
+from tests.unittest.utils.util import similar
+
+from ..test_llm import get_model_path
+from .openai_server import RemoteOpenAIServer
+
+pytestmark = pytest.mark.threadleak(enabled=False)
+
+
+@pytest.fixture(scope="module", ids=["llama-models/llama-7b-hf"])
+def model_name() -> str:
+    return "llama-models/llama-7b-hf"
+
+
+@pytest.fixture(scope="module")
+def lora_adapter_names() -> List[Optional[str]]:
+    return [
+        None, "llama-models/luotuo-lora-7b-0.1",
+        "llama-models/Japanese-Alpaca-LoRA-7b-v0"
+    ]
+
+
+@pytest.fixture(scope="module")
+def temp_extra_llm_api_options_file():
+    temp_dir = tempfile.gettempdir()
+    temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
+    try:
+        extra_llm_api_options_dict = {
+            "lora_config": {
+                "lora_target_modules": ['attn_q', 'attn_k', 'attn_v'],
+                "max_lora_rank": 8
+            }
+        }
+
+        with open(temp_file_path, 'w') as f:
+            yaml.dump(extra_llm_api_options_dict, f)
+
+        yield temp_file_path
+    finally:
+        if os.path.exists(temp_file_path):
+            os.remove(temp_file_path)
+
+
+@pytest.fixture(scope="module")
+def server(model_name: str,
+           temp_extra_llm_api_options_file: str) -> RemoteOpenAIServer:
+    model_path = get_model_path(model_name)
+    args = []
+    args.extend(["--backend", "pytorch"])
+    args.extend(["--extra_llm_api_options", temp_extra_llm_api_options_file])
+    with RemoteOpenAIServer(model_path, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server: RemoteOpenAIServer) -> openai.OpenAI:
+    return server.get_client()
+
+
+def test_lora(client: openai.OpenAI, model_name: str,
+              lora_adapter_names: List[str]):
+    prompts = [
+        "美国的首都在哪里? \n答案:",
+        "美国的首都在哪里? \n答案:",
+        "美国的首都在哪里? \n答案:",
+        "アメリカ合衆国の首都はどこですか? \n答え:",
+        "アメリカ合衆国の首都はどこですか? \n答え:",
+        "アメリカ合衆国の首都はどこですか? \n答え:",
+    ]
+    references = [
+        "沃尔玛\n\n## 新闻\n\n* ",
+        "美国的首都是华盛顿。\n\n美国的",
+        "纽约\n\n### カンファレンスの",
+        "Washington, D.C.\nWashington, D.C. is the capital of the United",
+        "华盛顿。\n\n英国の首都是什",
+        "ワシントン\nQ1. アメリカ合衆国",
+    ]
+
+    for prompt, reference, lora_adapter_name in zip(prompts, references,
+                                                    lora_adapter_names * 2):
+        extra_body = {}
+        if lora_adapter_name is not None:
+            lora_req = LoRARequest(lora_adapter_name,
+                                   lora_adapter_names.index(lora_adapter_name),
+                                   get_model_path(lora_adapter_name))
+            extra_body["lora_request"] = asdict(lora_req)
+
+        response = client.completions.create(
+            model=model_name,
+            prompt=prompt,
+            max_tokens=20,
+            extra_body=extra_body,
+        )
+
+        assert similar(response.choices[0].text, reference)
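The test above builds the same payload as the example client, but from the typed LoRARequest dataclass rather than a hand-written dict: dataclasses.asdict produces the lora_name / lora_int_id / lora_path keys the client spells out. A minimal sketch of that equivalence, assuming LoRARequest is a plain dataclass whose leading positional fields are name, id, and path as the test usage suggests (asdict may also emit additional optional fields); the adapter path is illustrative:

    from dataclasses import asdict

    from tensorrt_llm.executor.request import LoRARequest

    # Positional arguments follow the usage in the test above:
    # (lora_name, lora_int_id, lora_path). The path is a placeholder.
    lora_req = LoRARequest("luotuo-lora-7b-0.1", 1, "/path/to/luotuo-lora-7b-0.1")

    # asdict() yields the same dict the example client writes by hand,
    # ready to be passed as extra_body to the OpenAI client.
    extra_body = {"lora_request": asdict(lora_req)}
    print(extra_body)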
diff --git a/tests/unittest/llmapi/apps/_test_trtllm_serve_lora.py b/tests/unittest/llmapi/apps/_test_trtllm_serve_lora.py
new file mode 100644
index 0000000000..2248250b83
--- /dev/null
+++ b/tests/unittest/llmapi/apps/_test_trtllm_serve_lora.py
@@ -0,0 +1,69 @@
+import os
+import subprocess
+import sys
+import tempfile
+
+import pytest
+import yaml
+
+from .openai_server import RemoteOpenAIServer
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+from test_llm import get_model_path
+
+
+@pytest.fixture(scope="module", ids=["llama-models/llama-7b-hf"])
+def model_name() -> str:
+    return "llama-models/llama-7b-hf"
+
+
+@pytest.fixture(scope="module")
+def temp_extra_llm_api_options_file():
+    temp_dir = tempfile.gettempdir()
+    temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
+    try:
+        extra_llm_api_options_dict = {
+            "lora_config": {
+                "lora_target_modules": ['attn_q', 'attn_k', 'attn_v'],
+                "max_lora_rank": 8
+            }
+        }
+
+        with open(temp_file_path, 'w') as f:
+            yaml.dump(extra_llm_api_options_dict, f)
+
+        yield temp_file_path
+    finally:
+        if os.path.exists(temp_file_path):
+            os.remove(temp_file_path)
+
+
+@pytest.fixture(scope="module")
+def server(model_name: str, temp_extra_llm_api_options_file: str):
+    model_path = get_model_path(model_name)
+    args = [
+        "--backend", "pytorch", "--extra_llm_api_options",
+        temp_extra_llm_api_options_file
+    ]
+    with RemoteOpenAIServer(model_path, port=8000,
+                            cli_args=args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def example_root():
+    llm_root = os.getenv("LLM_ROOT")
+    return os.path.join(llm_root, "examples", "serve")
+
+
+@pytest.mark.parametrize("exe, script",
+                         [("python3", "openai_completion_client_for_lora.py")])
+def test_trtllm_serve_examples(exe: str, script: str,
+                               server: RemoteOpenAIServer, example_root: str):
+    client_script = os.path.join(example_root, script)
+    # CalledProcessError will be raised if any errors occur
+    subprocess.run([exe, client_script],
+                   stdout=subprocess.PIPE,
+                   stderr=subprocess.PIPE,
+                   text=True,
+                   check=True)
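For completeness: ChatCompletionRequest gains the same lora_request field and create_chat_response forwards it, but the tests above only exercise the completions path. Below is a hedged sketch of the chat-endpoint equivalent using the OpenAI SDK; the server address, adapter name, id, and path are illustrative placeholders reused from the completions example, not values verified by this change.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="tensorrt_llm")

    # Same lora_request shape as the completions example; extra_body keys are
    # merged into the JSON body, so ChatCompletionRequest.lora_request sees them.
    response = client.chat.completions.create(
        model="llama-7b-hf",
        messages=[{"role": "user", "content": "Where is the capital of the US?"}],
        max_tokens=20,
        extra_body={
            "lora_request": {
                "lora_name": "luotuo-lora-7b-0.1",
                "lora_int_id": 0,
                "lora_path": "/path/to/luotuo-lora-7b-0.1",
            }
        },
    )
    print(response.choices[0].message.content)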