fix: broken tool call after editing the file before saving
jrmi committed Jan 4, 2025
1 parent 3c0ba7f commit 871c76e
Showing 5 changed files with 72 additions and 8 deletions.
22 changes: 17 additions & 5 deletions gptme/llm/llm_anthropic.py
@@ -200,9 +200,8 @@ def stream(
def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
for message in message_dicts:
# Format tool result as expected by the model
if message["role"] == "system" and "call_id" in message:
if message["role"] == "user" and "call_id" in message:
modified_message = dict(message)
modified_message["role"] = "user"
modified_message["content"] = [
{
"type": "tool_result",
@@ -358,22 +357,35 @@ def _transform_system_messages(
# unless a `call_id` is present, indicating the tool_format is 'tool'.
# Tool responses are handled separately by _handle_tools.
for i, message in enumerate(messages):
if message.role == "system" and message.call_id is None:
if message.role == "system":
content = (
f"<system>{message.content}</system>"
if message.call_id is None
else message.content
)

messages[i] = Message(
"user",
content=f"<system>{message.content}</system>",
content=content,
files=message.files, # type: ignore
call_id=message.call_id,
)

# find consecutive user messages with the same call_id and merge them together
messages_new: list[Message] = []
while messages:
message = messages.pop(0)
if messages_new and messages_new[-1].role == "user" and message.role == "user":
if (
messages_new
and messages_new[-1].role == "user"
and message.role == "user"
and message.call_id == messages_new[-1].call_id
):
messages_new[-1] = Message(
"user",
content=f"{messages_new[-1].content}\n\n{message.content}",
files=messages_new[-1].files + message.files, # type: ignore
call_id=messages_new[-1].call_id,
)
else:
messages_new.append(message)
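Note on the llm_anthropic.py change above: system messages that carry a call_id are now turned into user-role tool results, and consecutive user messages are merged only when they share the same call_id, so several outputs from a single tool call collapse into one tool_result block. Below is a minimal standalone sketch of that merge rule, not the gptme implementation: plain dicts stand in for gptme's Message class, and all names and values are illustrative.

def merge_consecutive_user_messages(messages: list[dict]) -> list[dict]:
    # Sketch of the merge-by-call_id rule shown in the diff above.
    merged: list[dict] = []
    for msg in messages:
        if (
            merged
            and merged[-1]["role"] == "user"
            and msg["role"] == "user"
            and msg.get("call_id") == merged[-1].get("call_id")
        ):
            # Same tool call: append to the previous message instead of
            # emitting a second user message for the same tool_use.
            merged[-1]["content"] += f"\n\n{msg['content']}"
        else:
            merged.append(dict(msg))
    return merged

merge_consecutive_user_messages(
    [
        {"role": "user", "content": "Saved to toto.txt", "call_id": "tool_call_id"},
        {"role": "user", "content": "(Modified by user)", "call_id": "tool_call_id"},
    ]
)
# -> [{"role": "user", "content": "Saved to toto.txt\n\n(Modified by user)", "call_id": "tool_call_id"}]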
40 changes: 39 additions & 1 deletion gptme/llm/llm_openai.py
@@ -274,6 +274,7 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
modified_message["content"] = content

if tool_calls:
# Drop the content property when it is empty, otherwise the API call fails
if not content:
del modified_message["content"]
modified_message["tool_calls"] = tool_calls
@@ -283,6 +284,41 @@ def _handle_tools(message_dicts: Iterable[dict]) -> Generator[dict, None, None]:
yield message


def _merge_tool_results_with_same_call_id(
messages_dicts: Iterable[dict],
) -> list[dict]:
"""
A tool call can yield multiple messages, but the API expects exactly one
tool result per tool call. Merge consecutive tool results that share the
same call ID into the single result the API expects.
"""

messages_dicts = iter(messages_dicts)

messages_new: list[dict] = []
while message := next(messages_dicts, None):
if messages_new and (
message["role"] == "tool"
and messages_new[-1]["role"] == "tool"
and message["tool_call_id"] == messages_new[-1]["tool_call_id"]
):
prev_msg = messages_new[-1]
content = message["content"]
if not isinstance(content, list):
# Normalize plain-string content into the list-of-parts form so it can
# be concatenated with the previous result's content below.
content = [{"type": "text", "text": content}]

messages_new[-1] = {
"role": "tool",
"content": prev_msg["content"] + content,
"tool_call_id": prev_msg["tool_call_id"],
}
else:
messages_new.append(message)

return messages_new


def _process_file(msg: dict, model: ModelMeta) -> dict:
message_content = msg["content"]
if model.provider == "deepseek":
@@ -423,7 +459,9 @@ def _prepare_messages_for_api(
tools_dict = [_spec2tool(tool, model) for tool in tools] if tools else None

if tools_dict is not None:
messages_dicts = _handle_tools(messages_dicts)
messages_dicts = _merge_tool_results_with_same_call_id(
_handle_tools(messages_dicts)
)

messages_dicts = _transform_msgs_for_special_provider(messages_dicts, model)

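For reference, a hedged usage sketch of the new _merge_tool_results_with_same_call_id helper: two tool results produced by the same tool call (the shape mirrors what _handle_tools emits in the tests below) collapse into a single tool message. The import assumes a gptme install that includes this commit; the values are illustrative.

from gptme.llm.llm_openai import _merge_tool_results_with_same_call_id

messages = [
    {
        "role": "tool",
        "content": [{"type": "text", "text": "Saved to toto.txt"}],
        "tool_call_id": "tool_call_id",
    },
    {
        "role": "tool",
        "content": [{"type": "text", "text": "(Modified by user)"}],
        "tool_call_id": "tool_call_id",
    },
]
merged = _merge_tool_results_with_same_call_id(messages)
# -> a single "tool" entry whose content holds both text parts:
# [{"role": "tool",
#   "content": [{"type": "text", "text": "Saved to toto.txt"},
#               {"type": "text", "text": "(Modified by user)"}],
#   "tool_call_id": "tool_call_id"}]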
4 changes: 4 additions & 0 deletions gptme/tools/__init__.py
@@ -2,6 +2,8 @@
from collections.abc import Generator
from functools import lru_cache

from ..util.interrupt import clear_interruptible

from ..message import Message
from .base import (
ConfirmFunc,
@@ -124,11 +126,13 @@ def execute_msg(msg: Message, confirm: ConfirmFunc) -> Generator[Message, None,
for tool_response in tooluse.execute(confirm):
yield tool_response.replace(call_id=tooluse.call_id)
except KeyboardInterrupt:
clear_interruptible()
yield Message(
"system",
"User hit Ctrl-c to interrupt the process",
call_id=tooluse.call_id,
)
break


# Called often when checking streaming output for executable blocks,
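The tools/__init__.py change above tags the Ctrl-c notice with the tool's call_id and stops executing further tool uses, so the pending tool call still receives a well-formed result. A minimal sketch of that control flow, using a stand-in Message namedtuple and a fake tool generator (all names here are illustrative, not gptme's API):

from collections import namedtuple

Message = namedtuple("Message", ["role", "content", "call_id"])

def fake_tool():
    # Simulate a tool that emits one result, then gets interrupted.
    yield "Saved to toto.txt"
    raise KeyboardInterrupt

def execute(call_id="tool_call_id"):
    try:
        for output in fake_tool():
            yield Message("system", output, call_id)
    except KeyboardInterrupt:
        # Keep the call_id on the notice so downstream converters can still
        # render a tool result for the interrupted call, then stop.
        yield Message("system", "User hit Ctrl-c to interrupt the process", call_id)

list(execute())
# [Message(role='system', content='Saved to toto.txt', call_id='tool_call_id'),
#  Message(role='system', content='User hit Ctrl-c to interrupt the process', call_id='tool_call_id')]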
8 changes: 7 additions & 1 deletion tests/test_llm_anthropic.py
@@ -97,6 +97,7 @@ def test_message_conversion_with_tools():
content='<thinking>\nSomething\n</thinking>\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}',
),
Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"),
Message(role="system", content="(Modified by user)", call_id="tool_call_id"),
]

tool_save = get_tool("save")
@@ -152,7 +153,12 @@ def test_message_conversion_with_tools():
"content": [
{
"type": "tool_result",
"content": [{"type": "text", "text": "Saved to toto.txt"}],
"content": [
{
"type": "text",
"text": "Saved to toto.txt\n\n(Modified by user)",
}
],
"tool_use_id": "tool_call_id",
"cache_control": {"type": "ephemeral"},
}
6 changes: 5 additions & 1 deletion tests/test_llm_openai.py
@@ -116,6 +116,7 @@ def test_message_conversion_with_tools():
content='\n@save(tool_call_id): {"path": "path.txt", "content": "file_content"}',
),
Message(role="system", content="Saved to toto.txt", call_id="tool_call_id"),
Message(role="system", content="(Modified by user)", call_id="tool_call_id"),
]

set_default_model("openai/gpt-4o")
@@ -193,7 +194,10 @@ def test_message_conversion_with_tools():
},
{
"role": "tool",
"content": [{"type": "text", "text": "Saved to toto.txt"}],
"content": [
{"type": "text", "text": "Saved to toto.txt"},
{"type": "text", "text": "(Modified by user)"},
],
"tool_call_id": "tool_call_id",
},
]
