Commit 0268d72

feat(openai protocol): support logit_bias
1 parent 6c3210a commit 0268d72

File tree

2 files changed: +38 -8 lines changed


tensorrt_llm/sampling_params.py

Lines changed: 31 additions & 1 deletion
@@ -2,7 +2,7 @@
 import os
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field, fields
-from typing import List, NamedTuple, Optional, Tuple, Union
+from typing import Dict, List, NamedTuple, Optional, Tuple, Union
 
 import torch
 from pydantic import BaseModel
@@ -106,7 +106,37 @@ def __call__(
             client_ids (List[Optional[int]]): A batch of optional client ids.
         """
         pass  # noqa
+
+class LogitBiasLogitsProcessor(LogitsProcessor):
+    def __init__(self, logit_bias: Dict[str, float]) -> None:
+        super().__init__()
+        self.logit_bias = logit_bias
+        self.tokens_to_adjust = {}
+        for k, v in logit_bias.items():
+            try:
+                token_id = int(k)
+                self.tokens_to_adjust[token_id] = v
+            except (ValueError, TypeError):
+                continue
+
+    def __call__(self, req_id: int, logits: torch.Tensor,
+                 token_ids: List[List[int]], stream_ptr: Optional[int],
+                 client_id: Optional[int]) -> None:
+
+        if self.tokens_to_adjust:
+            token_ids_list = list(self.tokens_to_adjust.keys())
+            bias_values = torch.tensor(
+                [self.tokens_to_adjust[token] for token in token_ids_list],
+                device=logits.device,
+                dtype=logits.dtype
+            )
+
+            stream = None if stream_ptr is None else torch.cuda.ExternalStream(stream_ptr)
+            with torch.cuda.stream(stream):
+                logits[:, :, token_ids_list] += bias_values
 
+            if stream is not None:
+                stream.synchronize()
 
 @dataclass(slots=True, kw_only=True)
 class AdditionalModelOutput:
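
For illustration, a minimal sketch (not part of the commit) of the new processor invoked directly. The token id 42, the vocab size 100, and the malformed key "oops" are arbitrary example values; passing `stream_ptr=None` keeps the call on the CPU, since `torch.cuda.stream(None)` is a no-op.

```python
import torch

from tensorrt_llm.sampling_params import LogitBiasLogitsProcessor

# Keys are token ids serialized as strings (OpenAI convention); keys that
# do not parse as ints, like "oops", are silently dropped in __init__.
processor = LogitBiasLogitsProcessor({"42": 5.0, "oops": 1.0})

# Logits shaped (batch, beams, vocab), matching the [:, :, ids] indexing.
logits = torch.zeros(1, 1, 100)
processor(req_id=0, logits=logits, token_ids=[[1, 2, 3]],
          stream_ptr=None, client_id=None)

assert logits[0, 0, 42].item() == 5.0  # bias was added in place
```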

tensorrt_llm/serve/openai_protocol.py

Lines changed: 7 additions & 7 deletions
@@ -15,6 +15,7 @@
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
 from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
 
+from ..sampling_params import LogitBiasLogitsProcessor
 
 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields & allow to initialize by both alias and field name
@@ -242,6 +243,9 @@ def to_sampling_params(self) -> SamplingParams:
             guided_decoding=_response_format_to_guided_decoding_params(
                 self.response_format),
 
+            # logit_bias
+            logits_processor=None if not self.logit_bias else LogitBiasLogitsProcessor(self.logit_bias),
+
             # completion-extra-params
             add_special_tokens=self.add_special_tokens,
@@ -532,6 +536,9 @@ def to_sampling_params(self) -> SamplingParams:
             guided_decoding=_response_format_to_guided_decoding_params(
                 self.response_format),
 
+            # logit_bias
+            logits_processor=None if not self.logit_bias else LogitBiasLogitsProcessor(self.logit_bias),
+
             # chat-completion-extra-params
             add_special_tokens=self.add_special_tokens,
@@ -567,13 +574,6 @@ def check_logprobs(cls, data):
             raise ValueError("top_logprobs is not supported")
         return data
 
-    @model_validator(mode="before")
-    @classmethod
-    def verify_logit_processor(cls, data):
-        if data.get("logit_bias"):
-            raise ValueError("logit bias is not supported")
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_suffix(cls, data):
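
With the rejecting validator removed, `logit_bias` is accepted end to end. A hypothetical client call against a locally served model (base URL, API key, model name, and the token id 15339 are placeholder values for illustration):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Say hello"}],
    # OpenAI convention: keys are token ids as strings, values are the
    # additive biases forwarded to LogitBiasLogitsProcessor.
    logit_bias={"15339": 8.0},
)
print(response.choices[0].message.content)
```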
