Skip to content

Commit

Permalink
begin setting up tool calling
Browse files Browse the repository at this point in the history
  • Loading branch information
rgbkrk committed Feb 27, 2024
1 parent f700287 commit bd26ad7
Showing 1 changed file with 37 additions and 9 deletions.
46 changes: 37 additions & 9 deletions chatlab/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,13 @@ async def __call__(self, *messages: Union[ChatCompletionMessageParam, str], stre
"""Send messages to the chat model and display the response."""
return await self.submit(*messages, stream=stream, **kwargs)

async def __process_stream(
self, resp: AsyncStream[ChatCompletionChunk]
) -> Tuple[str, Optional[ToolArguments]]:
async def __process_stream(self, resp: AsyncStream[ChatCompletionChunk]) -> Tuple[str, Optional[ToolArguments], List[ToolArguments]]:
assistant_view: AssistantMessageView = AssistantMessageView()
function_view: Optional[ToolArguments] = None
finish_reason = None

tool_calls: list[ToolArguments] = []

async for result in resp: # Go through the results of the stream
choices = result.choices

Expand All @@ -161,6 +161,22 @@ async def __process_stream(
if choice.delta.content is not None:
assistant_view.display_once()
assistant_view.append(choice.delta.content)
elif choice.delta.tool_calls is not None:
for tool_call in choice.delta.tool_calls:
if (
tool_call.function is not None
and tool_call.function.name is not None
and tool_call.function.arguments is not None
and tool_call.id is not None
):
# Must build up
tool_argument = ToolArguments(
id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments
)
tool_calls.append(tool_argument)

# TODO: self.append the big tools payload

elif choice.delta.function_call is not None:
function_call = choice.delta.function_call
if function_call.name is not None:
Expand Down Expand Up @@ -189,7 +205,7 @@ async def __process_stream(
if finish_reason is None:
raise ValueError("No finish reason provided by OpenAI")

return (finish_reason, function_view)
return (finish_reason, function_view, tool_calls)

async def __process_full_completion(self, resp: ChatCompletion) -> Tuple[str, Optional[ToolArguments]]:
assistant_view: AssistantMessageView = AssistantMessageView()
Expand Down Expand Up @@ -231,6 +247,10 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
"""
full_messages: List[ChatCompletionMessageParam] = []
full_messages.extend(self.messages)

# TODO: Just keeping this aside while working on both stream and non-stream
tool_arguments: List[ToolArguments] = []

for message in messages:
if isinstance(message, str):
full_messages.append(human(message))
Expand All @@ -243,27 +263,26 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
base_url=self.base_url,
)

api_manifest = self.function_registry.api_manifest()

# Due to the strict response typing based on `Literal` typing on `stream`, we have to process these
# two cases separately
if stream:
streaming_response = await client.chat.completions.create(
model=self.model,
messages=full_messages,
**api_manifest,
tools=self.function_registry.tools,
stream=True,
temperature=kwargs.get("temperature", 0),
)

self.append(*messages)

finish_reason, function_call_request = await self.__process_stream(streaming_response)
finish_reason, function_call_request, tool_arguments = await self.__process_stream(streaming_response)
else:
# TODO: Process tools for non stream
full_response = await client.chat.completions.create(
model=self.model,
messages=full_messages,
**api_manifest,
tools=self.function_registry.tools,
stream=False,
temperature=kwargs.get("temperature", 0),
)
Expand Down Expand Up @@ -298,6 +317,15 @@ async def submit(self, *messages: Union[ChatCompletionMessageParam, str], stream
# Reply back to the LLM with the result of the function call, allow it to continue
await self.submit(stream=stream, **kwargs)
return

if finish_reason == "tool_calls":
for tool_argument in tool_arguments:
# Oh crap I need to append the big assistant call of it too. May have to assume we've done it by here.
function_called = await tool_argument.call(self.function_registry)
# TODO: Format the tool message
self.append(function_called.get_tool_called_message())

await self.submit(stream=stream, **kwargs)

# All other finish reasons are valid for regular assistant messages
if finish_reason == "stop":
Expand Down

0 comments on commit bd26ad7

Please sign in to comment.