Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion session/vertexai/vertexai_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package vertexai

import (
"context"
"encoding/json"
"fmt"
"regexp"
"strconv"
Expand Down Expand Up @@ -271,6 +272,16 @@ func (c *vertexAiClient) appendEvent(ctx context.Context, appName, sessionID str
return fmt.Errorf("error creating metadata: %w", err)
}

// The legacy column-backed fields are still written below as a fallback
// for readers that ignore raw_event.
var rawEvent *structpb.Struct
if eventNeedsRawEvent(event) {
rawEvent, err = eventToRawEvent(event)
if err != nil {
return fmt.Errorf("error creating raw event: %w", err)
}
}

_, err = c.rpcClient.AppendEvent(ctx, &aiplatformpb.AppendEventRequest{
Name: sessionNameByID(sessionID, c, reasoningEngine),
Event: &aiplatformpb.SessionEvent{
Expand All @@ -285,6 +296,7 @@ func (c *vertexAiClient) appendEvent(ctx context.Context, appName, sessionID str
EventMetadata: metadata,
ErrorCode: event.ErrorCode,
ErrorMessage: event.ErrorMessage,
RawEvent: rawEvent,
},
})
if err != nil {
Expand All @@ -294,6 +306,59 @@ func (c *vertexAiClient) appendEvent(ctx context.Context, appName, sessionID str
return nil
}

// eventNeedsRawEvent reports whether the event carries state that has no
// dedicated SessionEvent column and would be lost without raw_event.
// Gating raw_event on this keeps plain events on their legacy wire format,
// so the recorded replay fixtures stay valid.
func eventNeedsRawEvent(event *session.Event) bool {
return event.Output != nil ||
event.NodeInfo != nil ||
event.IsolationScope != "" ||
event.RequestedInput != nil ||
len(event.Routes) > 0
}

// eventToRawEvent serializes a session.Event into a structpb.Struct for
// the SessionEvent.raw_event field. Uses Go's JSON encoding; not yet
// byte-compatible with adk-python's camelCase dump (cross-runtime parity
// is tracked separately).
//
// Integers in the any-typed Output and StateDelta come back as float64
// (structpb numbers and json.Unmarshal into any are both float64). This
// matches the SQL backend, so the lossiness is framework-wide; store
// values needing exact integer fidelity as strings.
func eventToRawEvent(event *session.Event) (*structpb.Struct, error) {
b, err := json.Marshal(event)
if err != nil {
return nil, fmt.Errorf("marshaling event: %w", err)
}
var m map[string]any
if err := json.Unmarshal(b, &m); err != nil {
return nil, fmt.Errorf("unmarshaling event to map: %w", err)
}
s, err := structpb.NewStruct(m)
if err != nil {
return nil, fmt.Errorf("converting event to structpb: %w", err)
}
return s, nil
}

// eventFromRawEvent reconstructs a session.Event from a raw_event struct
// written by eventToRawEvent. Identity fields (ID, Timestamp,
// InvocationID, Author) are authoritative on the SessionEvent envelope,
// so callers overwrite them after this returns.
func eventFromRawEvent(raw *structpb.Struct) (*session.Event, error) {
b, err := json.Marshal(raw.AsMap())
if err != nil {
return nil, fmt.Errorf("marshaling raw event map: %w", err)
}
event := &session.Event{}
if err := json.Unmarshal(b, event); err != nil {
return nil, fmt.Errorf("unmarshaling raw event: %w", err)
}
return event, nil
}

func (c *vertexAiClient) listSessionEvents(ctx context.Context, appName, sessionID string, after time.Time, numRecentEvents int) ([]*session.Event, error) {
reasoningEngine, err := c.getReasoningEngineID(appName)
if err != nil {
Expand All @@ -316,12 +381,28 @@ func (c *vertexAiClient) listSessionEvents(ctx context.Context, appName, session
return nil, fmt.Errorf("error fetching session events: %w", err)
}

content := aiplatformToGenaiContent(rpcResp)
id, err := sessionIdBySessionName(rpcResp.Name)
if err != nil {
return nil, fmt.Errorf("error fetching session events: %w", err)
}

// Prefer raw_event; fall back to legacy field reconstruction for
// events written before raw_event support.
if rpcResp.RawEvent != nil {
event, err := eventFromRawEvent(rpcResp.RawEvent)
if err != nil {
return nil, fmt.Errorf("error fetching session events: %w", err)
}
// Identity fields are authoritative on the envelope.
event.ID = id
event.Timestamp = rpcResp.Timestamp.AsTime()
event.InvocationID = rpcResp.InvocationId
event.Author = rpcResp.Author
events = append(events, event)
continue
}

content := aiplatformToGenaiContent(rpcResp)
event := &session.Event{
ID: id,
Timestamp: rpcResp.Timestamp.AsTime(),
Expand Down
144 changes: 144 additions & 0 deletions session/vertexai/vertexai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ package vertexai
import (
"testing"

"github.com/google/go-cmp/cmp"

"google.golang.org/adk/model"
"google.golang.org/adk/session"
"google.golang.org/adk/util/vertexai"

"google.golang.org/genai"
Expand Down Expand Up @@ -110,6 +114,146 @@ func TestGetReasoningEngineID(t *testing.T) {
}
}

// TestRawEventRoundTrip pins that fields lacking a dedicated SessionEvent
// column survive a raw_event write/read round-trip — NodeInfo in
// particular, which the legacy field-based path dropped — and that the
// fields already persisted via dedicated columns are not degraded.
func TestRawEventRoundTrip(t *testing.T) {
tests := []struct {
name string
event *session.Event
}{
{
name: "workflow fields",
event: &session.Event{
InvocationID: "inv-1",
Author: "agent-x",
Branch: "a.b",
IsolationScope: "scope-1",
Routes: []string{"approve"},
Output: "the-output",
NodeInfo: &session.NodeInfo{
Path: "wf/child@1",
MessageAsOutput: true,
OutputFor: []string{"wf/child@1", "wf"},
},
},
},
{
name: "content with text and function call",
event: &session.Event{
Author: "agent-x",
LLMResponse: model.LLMResponse{
Content: &genai.Content{
Role: string(genai.RoleModel),
Parts: []*genai.Part{
{Text: "hello"},
{FunctionCall: &genai.FunctionCall{
ID: "call-1",
Name: "get_weather",
Args: map[string]any{"city": "Stockholm"},
}},
},
},
},
},
},
{
name: "structured output",
event: &session.Event{
Author: "agent-x",
Output: map[string]any{"score": float64(42), "label": "ok"},
},
},
{
// Typed map[string]int64 keeps its int type (unlike any-typed
// Output/StateDelta; see TestRawEventNumericContract).
name: "typed int64 artifact delta preserved",
event: &session.Event{
Author: "agent-x",
Output: "x",
Actions: session.EventActions{
ArtifactDelta: map[string]int64{"file.png": 7},
},
},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
raw, err := eventToRawEvent(tc.event)
if err != nil {
t.Fatalf("eventToRawEvent() error = %v", err)
}
got, err := eventFromRawEvent(raw)
if err != nil {
t.Fatalf("eventFromRawEvent() error = %v", err)
}
if diff := cmp.Diff(tc.event, got); diff != "" {
t.Errorf("round-trip mismatch (-want +got):\n%s", diff)
}
})
}
}

// TestRawEventNumericContract pins the documented contract (see
// eventToRawEvent): integers in the any-typed Output and StateDelta come
// back as float64.
func TestRawEventNumericContract(t *testing.T) {
event := &session.Event{
Author: "agent-x",
Output: int64(9007199254740993), // 2^53 + 1
Actions: session.EventActions{
StateDelta: map[string]any{"count": 3},
},
}
raw, err := eventToRawEvent(event)
if err != nil {
t.Fatalf("eventToRawEvent() error = %v", err)
}
got, err := eventFromRawEvent(raw)
if err != nil {
t.Fatalf("eventFromRawEvent() error = %v", err)
}
if _, ok := got.Output.(float64); !ok {
t.Errorf("Output type = %T, want float64", got.Output)
}
if v, ok := got.Actions.StateDelta["count"].(float64); !ok || v != 3 {
t.Errorf("StateDelta[count] = %#v, want float64(3)", got.Actions.StateDelta["count"])
}
}

// TestEventNeedsRawEvent guards the invariant that plain events keep
// their legacy wire format (no raw_event) while events carrying state
// without a dedicated SessionEvent column opt into raw_event. Changing
// this for plain events would invalidate the recorded replay fixtures.
func TestEventNeedsRawEvent(t *testing.T) {
tests := []struct {
name string
event *session.Event
want bool
}{
{name: "plain event", event: &session.Event{Author: "user"}, want: false},
{name: "with content only", event: &session.Event{
LLMResponse: model.LLMResponse{Content: genai.NewContentFromText("hi", genai.RoleUser)},
}, want: false},
{name: "with output", event: &session.Event{Output: "x"}, want: true},
{name: "with node info", event: &session.Event{NodeInfo: &session.NodeInfo{Path: "wf"}}, want: true},
{name: "with isolation scope", event: &session.Event{IsolationScope: "s"}, want: true},
{name: "with routes", event: &session.Event{Routes: []string{"approve"}}, want: true},
{name: "with requested input", event: &session.Event{
RequestedInput: &session.RequestInput{InterruptID: "i"},
}, want: true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := eventNeedsRawEvent(tc.event); got != tc.want {
t.Errorf("eventNeedsRawEvent() = %v, want %v", got, tc.want)
}
})
}
}

func TestAiplatformToGenaiContentPreservesFunctionIDs(t *testing.T) {
args, err := structpb.NewStruct(map[string]any{"city": "Stockholm"})
if err != nil {
Expand Down
Loading