Skip to content

Commit 9fb2b42

Browse files
committed
fix: Fixed google llm client
1 parent f47e5c0 commit 9fb2b42

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

internal/llm/google_client.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package llm
22

33
import (
44
"context"
5+
"encoding/base64"
56
"encoding/json"
67
"fmt"
78
"strings"
@@ -170,6 +171,12 @@ func convertToolCallsFromContent(content *genai.Content) []map[string]interface{
170171
if part.FunctionCall.ID != "" {
171172
toolCall["id"] = part.FunctionCall.ID
172173
}
174+
if part.Thought {
175+
toolCall["thought"] = true
176+
}
177+
if len(part.ThoughtSignature) > 0 {
178+
toolCall["thought_signature"] = base64.StdEncoding.EncodeToString(part.ThoughtSignature)
179+
}
173180

174181
toolCalls = append(toolCalls, toolCall)
175182
}
@@ -246,6 +253,16 @@ func convertAssistantMessage(msg *Message) (*genai.Content, error) {
246253
if id, _ := tc["id"].(string); id != "" {
247254
part.FunctionCall.ID = id
248255
}
256+
if thought, _ := tc["thought"].(bool); thought {
257+
part.Thought = true
258+
}
259+
if sigStr, _ := tc["thought_signature"].(string); sigStr != "" {
260+
if sig, err := base64.StdEncoding.DecodeString(sigStr); err == nil {
261+
part.ThoughtSignature = sig
262+
}
263+
} else if sigBytes, ok := tc["thought_signature"].([]byte); ok && len(sigBytes) > 0 {
264+
part.ThoughtSignature = sigBytes
265+
}
249266
parts = append(parts, part)
250267
}
251268

internal/llm/google_client_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package llm
2+
3+
import (
4+
"bytes"
5+
"encoding/base64"
6+
"testing"
7+
8+
genai "google.golang.org/genai"
9+
)
10+
11+
func TestGoogleClient_ToolCallThoughtSignatureRoundTrip(t *testing.T) {
12+
signature := []byte{0xde, 0xad, 0xbe, 0xef}
13+
14+
part := genai.NewPartFromFunctionCall("do_stuff", map[string]any{"value": "x"})
15+
part.Thought = true
16+
part.ThoughtSignature = signature
17+
18+
content := genai.NewContentFromParts([]*genai.Part{part}, genai.RoleModel)
19+
20+
toolCalls := convertToolCallsFromContent(content)
21+
if len(toolCalls) != 1 {
22+
t.Fatalf("expected 1 tool call, got %d", len(toolCalls))
23+
}
24+
25+
tc := toolCalls[0]
26+
sig, ok := tc["thought_signature"].(string)
27+
if !ok || sig == "" {
28+
t.Fatalf("expected thought_signature to be captured, got %#v", tc["thought_signature"])
29+
}
30+
if sig != base64.StdEncoding.EncodeToString(signature) {
31+
t.Fatalf("expected signature %q, got %q", base64.StdEncoding.EncodeToString(signature), sig)
32+
}
33+
34+
assistantMsg, err := convertAssistantMessage(&Message{Role: "assistant", ToolCalls: toolCalls})
35+
if err != nil {
36+
t.Fatalf("convertAssistantMessage returned error: %v", err)
37+
}
38+
39+
if len(assistantMsg.Parts) != 1 {
40+
t.Fatalf("expected 1 part after round-trip, got %d", len(assistantMsg.Parts))
41+
}
42+
43+
resultPart := assistantMsg.Parts[0]
44+
if !resultPart.Thought {
45+
t.Fatalf("expected Thought to be preserved")
46+
}
47+
if !bytes.Equal(resultPart.ThoughtSignature, signature) {
48+
t.Fatalf("expected signature %v, got %v", signature, resultPart.ThoughtSignature)
49+
}
50+
if resultPart.FunctionCall == nil || resultPart.FunctionCall.Name != "do_stuff" {
51+
t.Fatalf("expected function call to be preserved, got %+v", resultPart.FunctionCall)
52+
}
53+
}

0 commit comments

Comments
 (0)