
[TRTLLM-5831][feat] Add LoRA support for pytorch backend in trtllm-serve #5376

Open · wants to merge 10 commits into main
26 changes: 26 additions & 0 deletions examples/serve/openai_completion_client_for_lora.py
@@ -0,0 +1,26 @@
### OpenAI Completion Client

from openai import OpenAI

client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="tensorrt_llm",
)

response = client.completions.create(
    model="llama-7b-hf",
    # Prompt means "Where is the capital of the United States? \nAnswer:" (Chinese).
    prompt="美国的首都在哪里? \n答案:",
    max_tokens=20,
    extra_body={
        "lora_request": {
            "lora_name": "luotuo-lora-7b-0.1",
            "lora_int_id": 0,
            "lora_path":
            "/home/scratch.trt_llm_data/llm-models/llama-models/luotuo-lora-7b-0.1",
        }
    },
)

print(response)
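
The client above expects a trtllm-serve instance already listening on localhost:8000 with LoRA enabled. Below is a minimal launch sketch, assuming `trtllm-serve` is installed on PATH and that "/path/to/llama-7b-hf" is replaced with a real base-model checkout; the flags and the lora_config contents mirror the test fixtures further down in this PR.

import subprocess

import yaml

# Write the extra LLM API options used by the tests: enable LoRA on the attention
# projections with a small maximum adapter rank.
with open("extra_llm_api_options.yaml", "w") as f:
    yaml.dump(
        {
            "lora_config": {
                "lora_target_modules": ["attn_q", "attn_k", "attn_v"],
                "max_lora_rank": 8,
            }
        }, f)

# Start the OpenAI-compatible server on the PyTorch backend; the example client
# targets port 8000. The model path is a placeholder.
subprocess.run([
    "trtllm-serve", "/path/to/llama-7b-hf", "--backend", "pytorch",
    "--extra_llm_api_options", "extra_llm_api_options.yaml"
],
               check=True)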
4 changes: 4 additions & 0 deletions tensorrt_llm/serve/openai_protocol.py
@@ -12,6 +12,8 @@
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict

from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.executor.serialization import register_approved_ipc_class
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams

@@ -170,6 +172,7 @@ class CompletionRequest(OpenAIBaseModel):
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
user: Optional[str] = None
lora_request: Optional[LoRARequest] = None

# doc: begin-completion-sampling-params
use_beam_search: bool = False
@@ -447,6 +450,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
lora_request: Optional[LoRARequest] = None
# doc: end-chat-completion-sampling-params

# doc: begin-chat-completion-extra-params
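
With `lora_request` declared on both request models, a JSON body carrying a LoRA dict should validate straight into the `LoRARequest` dataclass. A minimal parsing sketch, assuming `CompletionRequest` keeps the usual required `model`/`prompt` fields and that pydantic coerces the dict into the dataclass; the adapter path is a placeholder.

from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.serve.openai_protocol import CompletionRequest

req = CompletionRequest(
    model="llama-7b-hf",
    prompt="Where is the capital of the United States?",
    lora_request={
        "lora_name": "luotuo-lora-7b-0.1",
        "lora_int_id": 0,
        # Hypothetical adapter location; substitute a real checkout of the adapter.
        "lora_path": "/path/to/luotuo-lora-7b-0.1",
    },
)
assert isinstance(req.lora_request, LoRARequest)
print(req.lora_request.lora_name)  # -> "luotuo-lora-7b-0.1"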
2 changes: 2 additions & 0 deletions tensorrt_llm/serve/openai_server.py
@@ -294,6 +294,7 @@ async def create_chat_response(
sampling_params=sampling_params,
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
streaming=request.stream,
lora_request=request.lora_request,
disaggregated_params=disaggregated_params
)
asyncio.create_task(self.await_disconnected(raw_request, promise))
@@ -414,6 +415,7 @@ async def create_completion_response(
sampling_params=sampling_params,
_postproc_params=postproc_params,
streaming=request.stream,
lora_request=request.lora_request,
disaggregated_params=disaggregated_params
)
asyncio.create_task(self.await_disconnected(raw_request, promise))
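
Because the server only threads `request.lora_request` through to the generation call, the wire format is simply the usual completion JSON with one extra top-level field; the OpenAI client's `extra_body` is merged into that body. An equivalent raw HTTP request, sketched with `requests` against the example server above (paths are placeholders):

import requests

payload = {
    "model": "llama-7b-hf",
    "prompt": "Where is the capital of the United States?\nAnswer:",
    "max_tokens": 20,
    "lora_request": {
        "lora_name": "luotuo-lora-7b-0.1",
        "lora_int_id": 0,
        "lora_path": "/path/to/luotuo-lora-7b-0.1",  # placeholder adapter path
    },
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])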
17 changes: 17 additions & 0 deletions tests/integration/defs/test_e2e.py
@@ -1371,6 +1371,18 @@ def test_trtllm_serve_multimodal_example(llm_root, llm_venv):
])


def test_trtllm_serve_lora_example(llm_root, llm_venv):
example_root = Path(os.path.join(llm_root, "examples", "serve"))
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd([
"-m", "pip", "install", "-r",
os.path.join(example_root, "requirements.txt")
])
llm_venv.run_cmd(
["-m", "pytest",
str(test_root / "_test_trtllm_serve_lora.py")])


def test_openai_misc_example(llm_root, llm_venv):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_misc.py")])
@@ -1401,6 +1413,11 @@ def test_openai_reasoning(llm_root, llm_venv):
str(test_root / "_test_openai_reasoning.py")])


def test_openai_lora(llm_root, llm_venv):
test_root = unittest_path() / "llmapi" / "apps"
llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_lora.py")])


def test_openai_chat_multimodal_example(llm_root, llm_venv):
example_root = Path(os.path.join(llm_root, "examples", "apps"))
test_root = unittest_path() / "llmapi" / "apps"
104 changes: 104 additions & 0 deletions tests/unittest/llmapi/apps/_test_openai_lora.py
@@ -0,0 +1,104 @@
import os
import tempfile
from dataclasses import asdict
from typing import List, Optional

import openai
import pytest
import yaml

from tensorrt_llm.executor.request import LoRARequest
from tests.unittest.utils.util import similar

from ..test_llm import get_model_path
from .openai_server import RemoteOpenAIServer

pytestmark = pytest.mark.threadleak(enabled=False)


@pytest.fixture(scope="module", ids=["llama-models/llama-7b-hf"])
def model_name() -> str:
return "llama-models/llama-7b-hf"


@pytest.fixture(scope="module")
def lora_adapter_names() -> List[Optional[str]]:
return [
None, "llama-models/luotuo-lora-7b-0.1",
"llama-models/Japanese-Alpaca-LoRA-7b-v0"
]


@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file():
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
extra_llm_api_options_dict = {
"lora_config": {
"lora_target_modules": ['attn_q', 'attn_k', 'attn_v'],
"max_lora_rank": 8
}
}

with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)

yield temp_file_path
finally:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)


@pytest.fixture(scope="module")
def server(model_name: str,
temp_extra_llm_api_options_file: str) -> RemoteOpenAIServer:
model_path = get_model_path(model_name)
args = []
args.extend(["--backend", "pytorch"])
args.extend(["--extra_llm_api_options", temp_extra_llm_api_options_file])
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server


@pytest.fixture(scope="module")
def client(server: RemoteOpenAIServer) -> openai.OpenAI:
return server.get_client()


def test_lora(client: openai.OpenAI, model_name: str,
lora_adapter_names: List[str]):
    # The same question, "Where is the capital of the United States? \nAnswer:",
    # asked three times in Chinese and then three times in Japanese.
    prompts = [
        "美国的首都在哪里? \n答案:",
        "美国的首都在哪里? \n答案:",
        "美国的首都在哪里? \n答案:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
    ]
references = [
"沃尔玛\n\n## 新闻\n\n* ",
"美国的首都是华盛顿。\n\n美国的",
"纽约\n\n### カンファレンスの",
"Washington, D.C.\nWashington, D.C. is the capital of the United",
"华盛顿。\n\n英国の首都是什",
"ワシントン\nQ1. アメリカ合衆国",
]

    # Cycle the three adapter choices (no LoRA, Chinese LoRA, Japanese LoRA) first over
    # the three Chinese prompts and then over the three Japanese prompts.
    for prompt, reference, lora_adapter_name in zip(prompts, references,
                                                    lora_adapter_names * 2):
extra_body = {}
if lora_adapter_name is not None:
lora_req = LoRARequest(lora_adapter_name,
lora_adapter_names.index(lora_adapter_name),
get_model_path(lora_adapter_name))
extra_body["lora_request"] = asdict(lora_req)

response = client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=20,
extra_body=extra_body,
)

assert similar(response.choices[0].text, reference)
69 changes: 69 additions & 0 deletions tests/unittest/llmapi/apps/_test_trtllm_serve_lora.py
@@ -0,0 +1,69 @@
import os
import subprocess
import sys
import tempfile

import pytest
import yaml

from .openai_server import RemoteOpenAIServer

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from test_llm import get_model_path


@pytest.fixture(scope="module", ids=["llama-models/llama-7b-hf"])
def model_name() -> str:
return "llama-models/llama-7b-hf"


@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file():
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
extra_llm_api_options_dict = {
"lora_config": {
"lora_target_modules": ['attn_q', 'attn_k', 'attn_v'],
"max_lora_rank": 8
}
}

with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)

yield temp_file_path
finally:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)


@pytest.fixture(scope="module")
def server(model_name: str, temp_extra_llm_api_options_file: str):
model_path = get_model_path(model_name)
args = [
"--backend", "pytorch", "--extra_llm_api_options",
temp_extra_llm_api_options_file
]
with RemoteOpenAIServer(model_path, port=8000,
cli_args=args) as remote_server:
yield remote_server


@pytest.fixture(scope="module")
def example_root():
    # Assumes LLM_ROOT points at the repository root; os.path.join below fails if it is unset.
    llm_root = os.getenv("LLM_ROOT")
return os.path.join(llm_root, "examples", "serve")


@pytest.mark.parametrize("exe, script",
[("python3", "openai_completion_client_for_lora.py")])
def test_trtllm_serve_examples(exe: str, script: str,
server: RemoteOpenAIServer, example_root: str):
client_script = os.path.join(example_root, script)
# CalledProcessError will be raised if any errors occur
subprocess.run([exe, client_script],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
check=True)
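
A quick way to exercise this client-script test locally, sketched as a subprocess call and assuming a repository checkout at the placeholder path plus models resolvable by get_model_path:

import os
import subprocess

# LLM_ROOT is read by the example_root fixture above; the path here is a placeholder.
env = dict(os.environ, LLM_ROOT="/path/to/TensorRT-LLM")
subprocess.run(
    ["python3", "-m", "pytest",
     "tests/unittest/llmapi/apps/_test_trtllm_serve_lora.py"],
    env=env,
    check=True)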