From 4337ade2dcf3df779739fe082657e88440c5a993 Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Mon, 26 Feb 2024 19:13:15 -0800 Subject: [PATCH] produce a tool call result that works for all known models --- chatlab/messaging.py | 20 ++++++++++++++++++-- chatlab/views/tools.py | 7 ++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/chatlab/messaging.py b/chatlab/messaging.py index 547e658..1833007 100644 --- a/chatlab/messaging.py +++ b/chatlab/messaging.py @@ -10,7 +10,7 @@ """ -from typing import Optional +from typing import Literal, Optional, Required, TypedDict from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam @@ -99,8 +99,23 @@ def function_result(name: str, content: str) -> ChatCompletionMessageParam: "name": name, } +class ChatCompletionToolMessageParamWithName(TypedDict, total=False): + content: Required[str] + """The contents of the tool message.""" -def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessageParam: + role: Required[Literal["tool"]] + """The role of the messages author, in this case `tool`.""" + + tool_call_id: Required[str] + """Tool call that this message is responding to.""" + + name: Optional[str] + """The name of the tool.""" + + + + +def tool_result(tool_call_id: str, content: str, name: str) -> ChatCompletionToolMessageParamWithName: """Create a tool result message. Args: @@ -112,6 +127,7 @@ def tool_result(tool_call_id: str, content: str) -> ChatCompletionToolMessagePar """ return { "role": "tool", + "name": name, "content": content, "tool_call_id": tool_call_id, } diff --git a/chatlab/views/tools.py b/chatlab/views/tools.py index 2be94d2..dcd3f2e 100644 --- a/chatlab/views/tools.py +++ b/chatlab/views/tools.py @@ -4,7 +4,7 @@ from ..registry import FunctionRegistry, FunctionArgumentError, UnknownFunctionError -from ..messaging import assistant_function_call, function_result +from ..messaging import assistant_function_call, function_result, tool_result class ToolCalled(AutoUpdate): """Once a tool has finished up, this is the view.""" @@ -27,6 +27,11 @@ def get_function_called_message(self): return function_result(self.name, self.result) + def get_tool_called_message(self): + # NOTE: OpenAI has mismatched types where it doesn't include the `name` + # xref: https://github.com/openai/openai-python/issues/1078 + return tool_result(tool_call_id=self.id, content=self.result, name=self.name) + class ToolArguments(AutoUpdate): id: str