Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions backend/internal/service/gateway_tool_rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,13 @@ func buildToolNameRewriteFromBody(body []byte) *ToolNameRewrite {
}

// applyToolNameRewriteToBody 把已构造的 ToolNameRewrite 应用到 body 上:
// - 改写 $.tools[*].name(仅对 shouldMimicToolName 通过的 tool)
// - 在 $.tools[last].cache_control 上打 ephemeral 缓存断点(Parrot 行为对齐,
// ttl 客户端已有则透传,否则默认 claude.DefaultCacheControlTTL)
// - 改写 $.tool_choice.name(仅当 $.tool_choice.type == "tool")
//
// 历史 $.messages[*].content[*].name(tool_use)不在请求侧改写——这与 Parrot 一致;
// 响应侧 bytes.Replace 会连带还原它们。
// - 改写 $.tools[*].name(仅对 shouldMimicToolName 通过的 tool)
// - 改写 $.tool_choice.name(仅当 $.tool_choice.type == "tool")
// - 改写 $.messages[*].content[*].name(仅当 type == "tool_use")
// - 在 $.tools[last].cache_control 上打 ephemeral 缓存断点
//
// 响应侧 bytes.Replace 会连带还原假名 → 真名。
func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte {
if rw == nil || len(rw.Forward) == 0 {
body = applyToolsLastCacheBreakpoint(body)
Expand Down Expand Up @@ -213,6 +213,38 @@ func applyToolNameRewriteToBody(body []byte, rw *ToolNameRewrite) []byte {
}
}

// Rewrite tool_use names in messages to match the renamed tools.
// Without this, Anthropic rejects requests where messages reference tools
// by their original name but tools[] declares the renamed (fake) name.
messages := gjson.GetBytes(body, "messages")
if messages.IsArray() {
messages.ForEach(func(msgKey, msg gjson.Result) bool {
msgIdx := int(msgKey.Num)
content := msg.Get("content")
if !content.IsArray() {
return true
}
content.ForEach(func(blkKey, blk gjson.Result) bool {
blkIdx := int(blkKey.Num)
if blk.Get("type").String() != "tool_use" {
return true
}
name := blk.Get("name").String()
if name == "" {
return true
}
if fake, ok := rw.Forward[name]; ok {
path := fmt.Sprintf("messages.%d.content.%d.name", msgIdx, blkIdx)
if next, err := sjson.SetBytes(body, path, fake); err == nil {
body = next
}
}
return true
})
return true
})
}

body = applyToolsLastCacheBreakpoint(body)
return body
}
Expand Down
24 changes: 24 additions & 0 deletions backend/internal/service/gateway_tool_rewrite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,30 @@ func TestApplyToolNameRewriteToBody_RenamesToolsAndToolChoice(t *testing.T) {
require.Equal(t, "tool", gjson.GetBytes(out, "tool_choice.type").String())
}


func TestApplyToolNameRewriteToBody_RenamesToolUseInMessages(t *testing.T) {
// sessions_list -> cc_sess_list (static prefix: sessions_ -> sessions_)
// web_search is a server tool (type != ""), not rewritten
// messages tool_use names must be rewritten to match tools[]
body := []byte(`{"tools":[{"name":"sessions_list","input_schema":{}},{"name":"web_search","type":"web_search_20250305"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]},{"role":"assistant","content":[{"type":"tool_use","id":"tu_01","name":"sessions_list","input":{}},{"type":"text","text":"thinking"}]},{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu_01","content":"ok"}]}]}`)
rw := buildToolNameRewriteFromBody(body)
require.NotNil(t, rw)
require.Equal(t, "cc_sess_list", rw.Forward["sessions_list"])

out := applyToolNameRewriteToBody(body, rw)

// tools[0].name rewritten
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "tools.0.name").String())
// tools[1].name untouched (server tool)
require.Equal(t, "web_search", gjson.GetBytes(out, "tools.1.name").String())
// messages[1].content[0].name (tool_use) also rewritten to match tools
require.Equal(t, "cc_sess_list", gjson.GetBytes(out, "messages.1.content.0.name").String())
// messages[1].content[1] (text) untouched
require.Equal(t, "thinking", gjson.GetBytes(out, "messages.1.content.1.text").String())
// messages[2].content[0] (tool_result) untouched — no name field in tool_result
require.Equal(t, "ok", gjson.GetBytes(out, "messages.2.content.0.content").String())
}

func TestApplyToolsLastCacheBreakpoint_InjectsDefault(t *testing.T) {
body := []byte(`{"tools":[{"name":"a","input_schema":{}},{"name":"b","input_schema":{}}]}`)
out := applyToolsLastCacheBreakpoint(body)
Expand Down
Loading