Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 54 additions & 70 deletions agent/a2aagent/a2a_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@ package a2aagent

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"

Expand All @@ -24,20 +21,16 @@ import (
"trpc.group/trpc-go/trpc-a2a-go/server"
"trpc.group/trpc-go/trpc-agent-go/agent"
"trpc.group/trpc-go/trpc-agent-go/event"
ia2a "trpc.group/trpc-go/trpc-agent-go/internal/a2a"
"trpc.group/trpc-go/trpc-agent-go/log"
"trpc.group/trpc-go/trpc-agent-go/model"
"trpc.group/trpc-go/trpc-agent-go/tool"
)

var defaultStreamingChannelSize = 1024
var defaultNonStreamingChannelSize = 10

const (
// AgentCardWellKnownPath is the standard path for agent card discovery
AgentCardWellKnownPath = "/.well-known/agent.json"
// defaultFetchTimeout is the default timeout for fetching agent card
defaultFetchTimeout = 30 * time.Second
// defaultUserIDHeader is the default HTTP header name to send UserID to A2A server
defaultUserIDHeader = "X-User-ID"
defaultStreamingChannelSize = 1024
defaultNonStreamingChannelSize = 10
defaultUserIDHeader = "X-User-ID"
)

// A2AAgent is an agent that communicates with a remote A2A agent via A2A protocol.
Expand All @@ -54,9 +47,9 @@ type A2AAgent struct {
streamingRespHandler StreamingRespHandler // Handler for streaming responses
transferStateKey []string // Keys in session state to transfer to the A2A agent message by metadata
userIDHeader string // HTTP header name to send UserID to A2A server
enableStreaming *bool // Explicitly set streaming mode; nil means use agent card capability

httpClient *http.Client
a2aClient *client.A2AClient
a2aClient *client.A2AClient
}

// New creates a new A2AAgent.
Expand All @@ -71,74 +64,60 @@ func New(opts ...Option) (*A2AAgent, error) {
opt(agent)
}

if agent.agentURL != "" && agent.agentCard == nil {
agentCard, err := agent.resolveAgentCardFromURL()
if err != nil {
return nil, fmt.Errorf("failed to resolve agent card: %w", err)
}
agent.agentCard = agentCard
var agentURL string
if agent.agentCard != nil {
agentURL = agent.agentCard.URL
} else if agent.agentURL != "" {
agentURL = agent.agentURL
} else {
log.Info("agent card or agent card url not set")
}

if agent.agentCard == nil {
return nil, fmt.Errorf("agent card not set")
}
// Normalize the URL to ensure it has a proper scheme
agentURL = ia2a.NormalizeURL(agentURL)

a2aClient, err := client.NewA2AClient(agent.agentCard.URL, agent.extraA2AOptions...)
// Create A2A client first
a2aClient, err := client.NewA2AClient(agentURL, agent.extraA2AOptions...)
if err != nil {
return nil, fmt.Errorf("failed to create A2A client for %s: %w", agent.agentCard.URL, err)
return nil, fmt.Errorf("failed to create A2A client for %s: %w", agentURL, err)
}
agent.a2aClient = a2aClient
return agent, nil
}

// resolveAgentCardFromURL fetches agent card from the well-known path
func (r *A2AAgent) resolveAgentCardFromURL() (*server.AgentCard, error) {
agentURL := r.agentURL

// Construct the agent card discovery URL
agentCardURL := strings.TrimSuffix(agentURL, "/") + AgentCardWellKnownPath

// Create HTTP client if not set
httpClient := r.httpClient
if httpClient == nil {
httpClient = &http.Client{Timeout: defaultFetchTimeout}
}

// Fetch agent card from well-known path
resp, err := httpClient.Get(agentCardURL)
if err != nil {
return nil, fmt.Errorf("failed to fetch agent card from %s: %w", agentCardURL, err)
}
defer resp.Body.Close()
// If agent card is not set, fetch it using A2A client's GetAgentCard method
if agent.agentCard == nil {
agentCard, err := a2aClient.GetAgentCard(context.Background(), "")
if err != nil {
return nil, fmt.Errorf("failed to fetch agent card from %s: %w", agentURL, err)
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch agent card from %s: HTTP %d", agentCardURL, resp.StatusCode)
}
// Set name and description from agent card if not already set
if agent.name == "" {
agent.name = agentCard.Name
}
if agent.description == "" {
agent.description = agentCard.Description
}

// Read response body
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read agent card response: %w", err)
}
if agentCard.URL == "" {
agentCard.URL = agentURL
} else {
// Normalize the agent card URL to ensure it has a proper scheme
agentCard.URL = ia2a.NormalizeURL(agentCard.URL)
}

// Parse agent card JSON
var agentCard server.AgentCard
if err := json.Unmarshal(body, &agentCard); err != nil {
return nil, fmt.Errorf("failed to parse agent card JSON: %w", err)
}
// Rebuild a2a client if URL changed
if agentCard.URL != agentURL {
a2aClient, err := client.NewA2AClient(agentCard.URL, agent.extraA2AOptions...)
if err != nil {
return nil, fmt.Errorf("failed to create A2A client for %s: %w", agentCard.URL, err)
}
agent.a2aClient = a2aClient
}

if r.name == "" {
r.name = agentCard.Name
agent.agentCard = agentCard
}

if r.description == "" {
r.description = agentCard.Description
}
// If URL is not set in the agent card, use the provided agent URL.
if agentCard.URL == "" {
agentCard.URL = agentURL
}
return &agentCard, nil
return agent, nil
}

// sendErrorEvent sends an error event to the event channel
Expand Down Expand Up @@ -189,7 +168,12 @@ func (r *A2AAgent) Run(ctx context.Context, invocation *agent.Invocation) (<-cha

// shouldUseStreaming determines whether to use streaming protocol
func (r *A2AAgent) shouldUseStreaming() bool {
// Check if agent card supports streaming
// If explicitly set via option, use that value
if r.enableStreaming != nil {
return *r.enableStreaming
}

// Otherwise check if agent card supports streaming
if r.agentCard != nil && r.agentCard.Capabilities.Streaming != nil {
return *r.agentCard.Capabilities.Streaming
}
Expand Down
9 changes: 9 additions & 0 deletions agent/a2aagent/a2a_agent_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,12 @@ func WithUserIDHeader(header string) Option {
}
}
}

// WithEnableStreaming explicitly controls whether to use streaming protocol.
// If not set (nil), the agent will use the streaming capability from the agent card.
// This option overrides the agent card's capability setting.
func WithEnableStreaming(enable bool) Option {
return func(a *A2AAgent) {
a.enableStreaming = &enable
}
}
144 changes: 4 additions & 140 deletions agent/a2aagent/a2a_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestNew(t *testing.T) {
opts: []Option{},
setupFunc: func(tc *testCase) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == AgentCardWellKnownPath {
if r.URL.Path == "/.well-known/agent-card.json" {
agentCard := server.AgentCard{
Name: "test-agent",
Description: "A test agent",
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestNew(t *testing.T) {
},
setupFunc: func(tc *testCase) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == AgentCardWellKnownPath {
if r.URL.Path == "/.well-known/agent-card.json" {
agentCard := server.AgentCard{
Name: "test-agent",
Description: "Test agent",
Expand Down Expand Up @@ -531,142 +531,6 @@ func TestA2AAgent_buildA2AMessage(t *testing.T) {
}
}

func TestA2AAgent_resolveAgentCardFromURL(t *testing.T) {
type testCase struct {
name string
agent *A2AAgent
setupFunc func(tc *testCase) *httptest.Server
validateFunc func(t *testing.T, agentCard *server.AgentCard, err error)
}

tests := []testCase{
{
name: "success with valid agent card",
agent: &A2AAgent{},
setupFunc: func(tc *testCase) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == AgentCardWellKnownPath {
agentCard := server.AgentCard{
Name: "resolved-agent",
Description: "Resolved from URL",
URL: "http://resolved.com",
}
json.NewEncoder(w).Encode(agentCard)
return
}
w.WriteHeader(http.StatusNotFound)
}))
tc.agent.agentURL = server.URL
return server
},
validateFunc: func(t *testing.T, agentCard *server.AgentCard, err error) {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if agentCard == nil {
t.Fatal("expected agent card, got nil")
}
if agentCard.Name != "resolved-agent" {
t.Errorf("expected name 'resolved-agent', got %s", agentCard.Name)
}
if agentCard.Description != "Resolved from URL" {
t.Errorf("expected description 'Resolved from URL', got %s", agentCard.Description)
}
},
},
{
name: "fills agent name and description when empty",
agent: &A2AAgent{
name: "",
description: "",
},
setupFunc: func(tc *testCase) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
agentCard := server.AgentCard{
Name: "auto-filled",
Description: "Auto-filled description",
}
json.NewEncoder(w).Encode(agentCard)
}))
tc.agent.agentURL = server.URL
return server
},
validateFunc: func(t *testing.T, agentCard *server.AgentCard, err error) {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if agentCard.Name != "auto-filled" {
t.Errorf("expected name 'auto-filled', got %s", agentCard.Name)
}
},
},
{
name: "error when HTTP request fails",
agent: &A2AAgent{agentURL: "http://nonexistent.local"},
setupFunc: func(tc *testCase) *httptest.Server {
return nil
},
validateFunc: func(t *testing.T, agentCard *server.AgentCard, err error) {
if err == nil {
t.Error("expected error when HTTP request fails")
}
if agentCard != nil {
t.Error("expected agent card to be nil on error")
}
},
},
{
name: "error when HTTP status not OK",
agent: &A2AAgent{},
setupFunc: func(tc *testCase) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
tc.agent.agentURL = server.URL
return server
},
validateFunc: func(t *testing.T, agentCard *server.AgentCard, err error) {
if err == nil {
t.Error("expected error when HTTP status not OK")
}
if agentCard != nil {
t.Error("expected agent card to be nil on error")
}
},
},
{
name: "error when invalid JSON",
agent: &A2AAgent{},
setupFunc: func(tc *testCase) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("invalid json"))
}))
tc.agent.agentURL = server.URL
return server
},
validateFunc: func(t *testing.T, agentCard *server.AgentCard, err error) {
if err == nil {
t.Error("expected error when JSON is invalid")
}
if agentCard != nil {
t.Error("expected agent card to be nil on error")
}
},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := tc.setupFunc(&tc)
if server != nil {
defer server.Close()
}
agentCard, err := tc.agent.resolveAgentCardFromURL()
tc.validateFunc(t, agentCard, err)
})
}
}

func TestA2AAgent_Run_ErrorCases(t *testing.T) {
type testCase struct {
name string
Expand Down Expand Up @@ -880,7 +744,7 @@ func TestA2ARequestOptions(t *testing.T) {
t.Run("validates option types and returns error for invalid types", func(t *testing.T) {
// Create test server
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == AgentCardWellKnownPath {
if r.URL.Path == "/.well-known/agent-card.json" {
agentCard := server.AgentCard{
Name: "test-agent",
Description: "A test agent",
Expand Down Expand Up @@ -1023,7 +887,7 @@ func TestUserIDHeaderInRequest(t *testing.T) {

// Create mock A2A server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == AgentCardWellKnownPath {
if r.URL.Path == "/.well-known/agent-card.json" {
// Return agent card with the mock server's URL
agentCard := server.AgentCard{
Name: "test-agent",
Expand Down
2 changes: 1 addition & 1 deletion codeexecutor/container/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ require (
golang.org/x/sys v0.35.0 // indirect
golang.org/x/time v0.12.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
trpc.group/trpc-go/trpc-a2a-go v0.2.5-0.20251020094851-6ab922c9dab1 // indirect
trpc.group/trpc-go/trpc-a2a-go v0.2.5-0.20251023030722-7f02b57fd14a // indirect
)
4 changes: 2 additions & 2 deletions codeexecutor/container/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
trpc.group/trpc-go/trpc-a2a-go v0.2.5-0.20251020094851-6ab922c9dab1 h1:P+OyPh+QCNuO8u+M2UPTYZCGKnH9YAcijC8ULokAdTw=
trpc.group/trpc-go/trpc-a2a-go v0.2.5-0.20251020094851-6ab922c9dab1/go.mod h1:Gtytau9Uoc3oPo/dpHvKit+tQn9Qlk5XFG1RiZTGqfk=
trpc.group/trpc-go/trpc-a2a-go v0.2.5-0.20251023030722-7f02b57fd14a h1:dOon6HF2sPRFnhCLEiAeKPc21JHL2eX7UBWjIR8PLaY=
trpc.group/trpc-go/trpc-a2a-go v0.2.5-0.20251023030722-7f02b57fd14a/go.mod h1:Gtytau9Uoc3oPo/dpHvKit+tQn9Qlk5XFG1RiZTGqfk=
Loading
Loading