diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py
index 04442a2e1f..b335f3091b 100644
--- a/tensorrt_llm/sampling_params.py
+++ b/tensorrt_llm/sampling_params.py
@@ -2,7 +2,7 @@
 import os
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field, fields
-from typing import List, NamedTuple, Optional, Tuple, Union
+from typing import Dict, List, NamedTuple, Optional, Tuple, Union
 
 import torch
 from pydantic import BaseModel
@@ -106,7 +106,47 @@ def __call__(
             client_ids (List[Optional[int]]): A batch of optional client ids.
         """
         pass  # noqa
+
+
+class LogitBiasLogitsProcessor(LogitsProcessor):
+    """Logits processor that adds a constant bias to selected token logits.
+
+    ``logit_bias`` maps token ids (string keys, following the OpenAI API
+    convention) to float bias values; keys that cannot be parsed as
+    integers are skipped.
+    """
+
+    def __init__(self, logit_bias: Dict[str, float]) -> None:
+        super().__init__()
+        self.logit_bias = logit_bias
+        self.tokens_to_adjust: Dict[int, float] = {}
+        for k, v in logit_bias.items():
+            try:
+                self.tokens_to_adjust[int(k)] = v
+            except (ValueError, TypeError):
+                # Skip keys that are not valid token ids.
+                continue
+
+    def __call__(self, req_id: int, logits: torch.Tensor,
+                 token_ids: List[List[int]], stream_ptr: Optional[int],
+                 client_id: Optional[int]) -> None:
+        if not self.tokens_to_adjust:
+            return
+
+        token_ids_list = list(self.tokens_to_adjust.keys())
+        bias_values = torch.tensor(
+            [self.tokens_to_adjust[token] for token in token_ids_list],
+            device=logits.device,
+            dtype=logits.dtype)
+
+        # Apply the bias on the sampling stream (when one is provided) so
+        # the in-place update is ordered with the rest of generation.
+        stream = None if stream_ptr is None else torch.cuda.ExternalStream(
+            stream_ptr)
+        with torch.cuda.stream(stream):
+            logits[:, :, token_ids_list] += bias_values
+        if stream is not None:
+            stream.synchronize()
 
 
 @dataclass(slots=True, kw_only=True)
 class AdditionalModelOutput:
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 03ad7df27b..7be76acefd 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -15,6 +15,7 @@
 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
 from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams
+from tensorrt_llm.sampling_params import LogitBiasLogitsProcessor
 
 
 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields & allow to initialize by both alias and field name
@@ -242,6 +243,10 @@ def to_sampling_params(self) -> SamplingParams:
             guided_decoding=_response_format_to_guided_decoding_params(
                 self.response_format),
 
+            # logit_bias
+            logits_processor=LogitBiasLogitsProcessor(self.logit_bias)
+            if self.logit_bias else None,
+
             # completion-extra-params
             add_special_tokens=self.add_special_tokens,
 
@@ -532,6 +537,10 @@ def to_sampling_params(self) -> SamplingParams:
             guided_decoding=_response_format_to_guided_decoding_params(
                 self.response_format),
 
+            # logit_bias
+            logits_processor=LogitBiasLogitsProcessor(self.logit_bias)
+            if self.logit_bias else None,
+
             # chat-completion-extra-params
             add_special_tokens=self.add_special_tokens,
 
@@ -567,13 +576,6 @@ def check_logprobs(cls, data):
             raise ValueError("top_logprobs is not supported")
         return data
 
-    @model_validator(mode="before")
-    @classmethod
-    def verify_logit_processor(cls, data):
-        if data.get("logit_bias"):
-            raise ValueError("logit bias is not supported")
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_suffix(cls, data):
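For context on the end-to-end behavior, here is a minimal usage sketch (not part of the diff): with the `verify_logit_processor` validator removed and the processor wired into `to_sampling_params`, an OpenAI-style request carrying `logit_bias` should now be accepted instead of rejected with "logit bias is not supported". The server URL, API key, and model name below are placeholder assumptions.

```python
from openai import OpenAI

# Assumes a trtllm-serve instance listening on localhost:8000 that serves a
# hypothetical model named "model_name".
client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.completions.create(
    model="model_name",
    prompt="The capital of France is",
    max_tokens=8,
    # Keys are token-id strings, values are additive biases; with this
    # implementation, -100 strongly suppresses token id 42 during sampling.
    logit_bias={"42": -100.0},
)
print(response.choices[0].text)
```

Note that `LogitBiasLogitsProcessor` applies a purely additive shift, so extreme values such as -100/+100 strongly suppress or favor a token but do not hard-guarantee exclusion or selection the way a masking implementation would.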