feat(openai protocol): support logit_bias #5354

Open · wants to merge 1 commit into main
32 changes: 31 additions & 1 deletion tensorrt_llm/sampling_params.py
@@ -2,7 +2,7 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
-from typing import List, NamedTuple, Optional, Tuple, Union
+from typing import Dict, List, NamedTuple, Optional, Tuple, Union

import torch
from pydantic import BaseModel
@@ -106,7 +106,37 @@ def __call__(
            client_ids (List[Optional[int]]): A batch of optional client ids.
        """
        pass  # noqa

class LogitBiasLogitsProcessor(LogitsProcessor):

    def __init__(self, logit_bias: Dict[str, float]) -> None:
        super().__init__()
        self.logit_bias = logit_bias
        # OpenAI-style logit_bias maps token ids (encoded as strings) to bias
        # values; keep only the keys that parse as integer token ids.
        self.tokens_to_adjust = {}
        for k, v in logit_bias.items():
            try:
                token_id = int(k)
                self.tokens_to_adjust[token_id] = v
            except (ValueError, TypeError):
                continue

    def __call__(self, req_id: int, logits: torch.Tensor,
                 token_ids: List[List[int]], stream_ptr: Optional[int],
                 client_id: Optional[int]) -> None:
        if self.tokens_to_adjust:
            token_ids_list = list(self.tokens_to_adjust.keys())
            bias_values = torch.tensor(
                [self.tokens_to_adjust[token] for token in token_ids_list],
                device=logits.device,
                dtype=logits.dtype)

            # Apply the bias on the runtime's CUDA stream when one is handed
            # in; logits are laid out as (batch, beams, vocab).
            stream = None if stream_ptr is None else torch.cuda.ExternalStream(
                stream_ptr)
            with torch.cuda.stream(stream):
                logits[:, :, token_ids_list] += bias_values

            if stream is not None:
                stream.synchronize()

@dataclass(slots=True, kw_only=True)
class AdditionalModelOutput:
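For reviewers, a minimal sketch (not part of this PR) of the new processor in isolation. It assumes a CUDA device, since `__call__` always enters a `torch.cuda.stream` context; the token ids and vocabulary size are made up:

```python
import torch

from tensorrt_llm.sampling_params import LogitBiasLogitsProcessor

# Boost token 3, effectively ban token 7; the non-numeric key is dropped.
proc = LogitBiasLogitsProcessor({"3": 2.0, "7": -100.0, "oops": 1.0})

logits = torch.zeros(1, 1, 10, device="cuda")  # (batch, beams, vocab)
proc(req_id=0, logits=logits, token_ids=[[0]], stream_ptr=None, client_id=None)
print(logits[0, 0, 3].item(), logits[0, 0, 7].item())  # 2.0 -100.0
```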
14 changes: 7 additions & 7 deletions tensorrt_llm/serve/openai_protocol.py
@@ -15,6 +15,7 @@
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams

from ..sampling_params import LogitBiasLogitsProcessor

class OpenAIBaseModel(BaseModel):
    # The OpenAI API does not allow extra fields, and allows initialization by
    # both alias and field name.
@@ -242,6 +243,9 @@ def to_sampling_params(self) -> SamplingParams:
            guided_decoding=_response_format_to_guided_decoding_params(
                self.response_format),

            # logit_bias
            logits_processor=None if not self.logit_bias else
            LogitBiasLogitsProcessor(self.logit_bias),

            # completion-extra-params
            add_special_tokens=self.add_special_tokens,

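In other words, the wiring above for the completions endpoint boils down to the following (a sketch under assumed values; the token id and `max_tokens` are arbitrary). The chat endpoint in the next hunk gets the identical treatment:

```python
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.sampling_params import LogitBiasLogitsProcessor

logit_bias = {"1234": -100.0}  # as it would arrive in a request body
params = SamplingParams(
    max_tokens=16,
    logits_processor=None
    if not logit_bias else LogitBiasLogitsProcessor(logit_bias),
)
```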
@@ -532,6 +536,9 @@ def to_sampling_params(self) -> SamplingParams:
            guided_decoding=_response_format_to_guided_decoding_params(
                self.response_format),

            # logit_bias
            logits_processor=None if not self.logit_bias else
            LogitBiasLogitsProcessor(self.logit_bias),

            # chat-completion-extra-params
            add_special_tokens=self.add_special_tokens,

@@ -567,13 +574,6 @@ def check_logprobs(cls, data):
            raise ValueError("top_logprobs is not supported")
        return data

-    @model_validator(mode="before")
-    @classmethod
-    def verify_logit_processor(cls, data):
-        if data.get("logit_bias"):
-            raise ValueError("logit bias is not supported")
-        return data

    @model_validator(mode="before")
    @classmethod
    def check_suffix(cls, data):
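With the old `verify_logit_processor` guard deleted, `logit_bias` is no longer rejected at validation time. A hedged end-to-end sketch against a local `trtllm-serve` instance (URL, model name, and token id are placeholders):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
resp = client.chat.completions.create(
    model="served-model",  # placeholder for whatever the server hosts
    messages=[{"role": "user", "content": "Say hello"}],
    logit_bias={"1234": -100},  # placeholder token id to suppress
)
print(resp.choices[0].message.content)
```

Per the OpenAI API reference, bias values range from -100 to 100: values near -100 effectively ban a token, and values near 100 effectively force its selection.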