Skip to content

Commit

Permalink
fix: Azure AI Inference API tool response is slightly different from OAI
Browse files Browse the repository at this point in the history
  • Loading branch information
clemlesne committed Jan 15, 2025
1 parent dbf77c9 commit 7fe3cdf
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 28 deletions.
32 changes: 23 additions & 9 deletions app/helpers/call_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,24 +492,38 @@ async def _content_callback(buffer: str) -> None:
# logger.debug("Tools: %s", tools)

# Execute LLM inference
maximum_tokens_reached = False
content_buffer_pointer = 0
last_buffered_tool_id = None
maximum_tokens_reached = False
tool_calls_buffer: dict[str, MessageToolModel] = {}
try:
# Consume the completion stream
async for delta in completion_stream(
max_tokens=160, # Lowest possible value for 90% of the cases, if not sufficient, retry will be triggered, 100 tokens ~= 75 words, 20 words ~= 1 sentence, 6 sentences ~= 160 tokens
messages=call.messages,
system=system,
tools=tools,
):
if not delta.content:
for piece in delta.tool_calls or []:
tool_calls_buffer[piece.id] = tool_calls_buffer.get(
piece.id, MessageToolModel()
)
tool_calls_buffer[piece.id] += piece
else:
# Store whole content
# Complete tools
if delta.tool_calls:
for piece in delta.tool_calls:
# Azure AI Inference sometimes returns empty tool IDs, in that case, use the last one
if piece.id:
last_buffered_tool_id = piece.id
# No tool ID, alert and skip
if not last_buffered_tool_id:
logger.warning(
"Empty tool ID, cannot buffer tool call: %s", piece
)
continue
# New, init buffer
if last_buffered_tool_id not in tool_calls_buffer:
tool_calls_buffer[last_buffered_tool_id] = MessageToolModel()
# Append
tool_calls_buffer[last_buffered_tool_id].add_delta(piece)

# Complete content
if delta.content:
content_full += delta.content
for sentence, length in tts_sentence_split(
content_full[content_buffer_pointer:], False
Expand Down
63 changes: 44 additions & 19 deletions app/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,27 +59,33 @@ class ToolModel(BaseModel):
function_name: str = ""
tool_id: str = ""

def __add__(self, other: object) -> "ToolModel":
if not isinstance(other, StreamingChatResponseToolCallUpdate):
return NotImplemented
if other.id:
self.tool_id = other.id
if other.function:
if other.function.name:
self.function_name = other.function.name
if other.function.arguments:
self.function_arguments += other.function.arguments
return self
@property
def is_openai_valid(self) -> bool:
"""
Check if the tool model is valid for OpenAI.
def __hash__(self) -> int:
return self.tool_id.__hash__()
The model is valid if it has a tool ID and a function name.
"""
return bool(self.tool_id and self.function_name)

def __eq__(self, other: object) -> bool:
if not isinstance(other, ToolModel):
return False
return self.tool_id == other.tool_id
def add_delta(self, delta: StreamingChatResponseToolCallUpdate) -> "ToolModel":
    """
    Merge a streaming delta into this tool model.

    Streaming APIs emit tool calls in fragments: the tool ID and the
    function name usually arrive once, while the function arguments
    arrive as partial strings that must be concatenated in order. Empty
    or missing fields in the delta leave the current values untouched.

    Returns self so calls can be chained fluently.
    """
    if delta.id:
        self.tool_id = delta.id
    # Guard against a missing function payload — the previous __add__
    # implementation checked this, and some SDK deltas may omit it.
    if delta.function:
        if delta.function.name:
            self.function_name = delta.function.name
        if delta.function.arguments:
            self.function_arguments += delta.function.arguments
    return self

def to_openai(self) -> ChatCompletionsToolCall:
"""
Convert the tool model to an OpenAI tool call.
"""
return ChatCompletionsToolCall(
id=self.tool_id,
function=FunctionCall(
Expand All @@ -94,6 +100,14 @@ def to_openai(self) -> ChatCompletionsToolCall:
),
)

def __hash__(self) -> int:
    """Hash by tool ID so tool models dedupe by identity in sets and dicts."""
    return hash(self.tool_id)

def __eq__(self, other: object) -> bool:
    """Two tool models are equal when they carry the same tool ID."""
    return isinstance(other, ToolModel) and self.tool_id == other.tool_id


class MessageModel(BaseModel):
# Immutable fields
Expand All @@ -120,16 +134,23 @@ def _validate_created_at(cls, created_at: datetime) -> datetime:
def to_openai(
self,
) -> list[ChatRequestMessage]:
"""
Convert the message model to an OpenAI message.
Tools are validated before being added to the message, invalid ones are discarded.
"""
# Removing newlines from the content to avoid hallucinations issues with GPT-4 Turbo
content = " ".join([line.strip() for line in self.content.splitlines()])

# Init content for human persona
if self.persona == PersonaEnum.HUMAN:
return [
UserMessage(
content=f"action={self.action.value} {content}",
)
]

# Init content for assistant persona
if self.persona == PersonaEnum.ASSISTANT:
if not self.tool_calls:
return [
Expand All @@ -138,19 +159,23 @@ def to_openai(
)
]

# Add tools
valid_tools = [
tool_call for tool_call in self.tool_calls if tool_call.is_openai_valid
]
res = []
res.append(
AssistantMessage(
content=f"action={self.action.value} style={self.style.value} {content}",
tool_calls=[tool_call.to_openai() for tool_call in self.tool_calls],
tool_calls=[tool_call.to_openai() for tool_call in valid_tools],
)
)
res.extend(
ToolMessage(
content=tool_call.content,
tool_call_id=tool_call.tool_id,
)
for tool_call in self.tool_calls
for tool_call in valid_tools
if tool_call.content
)
return res
Expand Down

0 comments on commit 7fe3cdf

Please sign in to comment.