diff --git a/core/transport/mcp/base.go b/core/transport/mcp/base.go new file mode 100644 index 00000000..6e2d3543 --- /dev/null +++ b/core/transport/mcp/base.go @@ -0,0 +1,200 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcp + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + + "github.com/googleapis/mcp-toolbox-sdk-go/core/transport" +) + +// BaseMcpTransport holds the common state and logic for MCP HTTP transports. +type BaseMcpTransport struct { + baseURL string + HTTPClient *http.Client + ServerVersion string + initOnce sync.Once + initErr error + + // HandshakeHook is the abstract method _initialize_session. + // The specific version implementation will assign this function. + HandshakeHook func(context.Context) error +} + +// BaseURL returns the base URL for the transport. +func (b *BaseMcpTransport) BaseURL() string { + return b.baseURL +} + +// NewBaseTransport creates a new base transport. +func NewBaseTransport(baseURL string, client *http.Client) (*BaseMcpTransport, error) { + if client == nil { + client = &http.Client{} + } + var fullURL string + var err error + // Normalize by removing trailing slash first + cleanBaseURL := strings.TrimRight(baseURL, "/") + + // Only append "/mcp/" if it is not already present + if strings.HasSuffix(cleanBaseURL, "/mcp") { + // It's already correct, just use it + fullURL = cleanBaseURL + } else { + // It's missing, so join it safely + // url.JoinPath handles the slash insertion automatically + fullURL, err = url.JoinPath(cleanBaseURL, "mcp") + if err != nil { + return nil, err + } + } + + // Ensure trailing slash + fullURL += "/" + + return &BaseMcpTransport{ + baseURL: fullURL, + HTTPClient: client, + }, nil +} + +// EnsureInitialized guarantees the session is ready before making requests. +func (b *BaseMcpTransport) EnsureInitialized(ctx context.Context) error { + b.initOnce.Do(func() { + if b.HandshakeHook != nil { + b.initErr = b.HandshakeHook(ctx) + } else { + b.initErr = fmt.Errorf("transport initialization logic (HandshakeHook) not defined") + } + }) + return b.initErr +} + +// ConvertToolDefinition converts the raw tool dictionary into a transport.ToolSchema. +func (b *BaseMcpTransport) ConvertToolDefinition(toolData map[string]any) (transport.ToolSchema, error) { + var paramAuth map[string]any + var invokeAuth []string + + if meta, ok := toolData["_meta"].(map[string]any); ok { + if pa, ok := meta["toolbox/authParam"].(map[string]any); ok { + paramAuth = pa + } + if ia, ok := meta["toolbox/authInvoke"].([]any); ok { + invokeAuth = make([]string, 0, len(ia)) + for _, v := range ia { + if s, ok := v.(string); ok { + invokeAuth = append(invokeAuth, s) + } + } + } + } + + description, _ := toolData["description"].(string) + inputSchema, _ := toolData["inputSchema"].(map[string]any) + properties, _ := inputSchema["properties"].(map[string]any) + + // Create lookup set for required fields + requiredSet := make(map[string]bool) + if reqList, ok := inputSchema["required"].([]any); ok { + for _, r := range reqList { + if s, ok := r.(string); ok { + requiredSet[s] = true + } + } + } + + // Build Parameter List + parameters := make([]transport.ParameterSchema, 0, len(properties)) + + for propertyName, definition := range properties { + definitionMap, ok := definition.(map[string]any) + if !ok { + continue + } + + // Handle Auth Sources for this specific parameter + var authSources []string + if paramAuth != nil { + if sourcesRaw, ok := paramAuth[propertyName]; ok { + if sourcesList, ok := sourcesRaw.([]any); ok { + authSources = make([]string, 0, len(sourcesList)) + for _, s := range sourcesList { + if str, ok := s.(string); ok { + authSources = append(authSources, str) + } + } + } + } + } + + // Recursively parse the property + param := parseProperty(propertyName, definitionMap, requiredSet[propertyName]) + param.AuthSources = authSources + + parameters = append(parameters, param) + } + + return transport.ToolSchema{ + Description: description, + Parameters: parameters, + AuthRequired: invokeAuth, + }, nil +} + +// parseProperty is the recursive helper to create ParameterSchema +func parseProperty(name string, definitionMap map[string]any, isRequired bool) transport.ParameterSchema { + param := transport.ParameterSchema{ + Name: name, + Type: getString(definitionMap, "type"), + Description: getString(definitionMap, "description"), + Required: isRequired, + } + + switch param.Type { + case "object": + if ap, ok := definitionMap["additionalProperties"]; ok { + switch v := ap.(type) { + case bool: + param.AdditionalProperties = v + case map[string]any: + schema := parseProperty("", v, false) + param.AdditionalProperties = &schema + } + } + + case "array": + if itemsMap, ok := definitionMap["items"].(map[string]any); ok { + itemSchema := parseProperty("", itemsMap, false) + param.Items = &itemSchema + } + } + + return param +} + +// Helper to safely extract string values from map +func getString(m map[string]any, key string) string { + if v, ok := m[key]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} diff --git a/core/transport/mcp/base_test.go b/core/transport/mcp/base_test.go new file mode 100644 index 00000000..77552bf5 --- /dev/null +++ b/core/transport/mcp/base_test.go @@ -0,0 +1,217 @@ +//go:build unit + +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcp + +import ( + "context" + "errors" + "testing" +) + +func TestNewBaseTransport(t *testing.T) { + tests := []struct { + name string + baseURL string + expected string + }{ + { + name: "Clean URL", + baseURL: "http://example.com", + expected: "http://example.com/mcp/", + }, + { + name: "Trailing Slash", + baseURL: "http://example.com/", + expected: "http://example.com/mcp/", + }, + { + name: "Already Has MCP Suffix", + baseURL: "http://example.com/mcp", + expected: "http://example.com/mcp/", + }, + { + name: "Already Has MCP Suffix with Slash", + baseURL: "http://example.com/mcp/", + expected: "http://example.com/mcp/", + }, + { + name: "Deep Path", + baseURL: "http://example.com/api/v1", + expected: "http://example.com/api/v1/mcp/", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tr, _ := NewBaseTransport(tc.baseURL, nil) + if tr.BaseURL() != tc.expected { + t.Errorf("Expected URL %s, got %s", tc.expected, tr.BaseURL()) + } + if tr.HTTPClient == nil { + t.Error("Expected HTTPClient to be initialized, got nil") + } + }) + } +} + +func TestEnsureInitialized(t *testing.T) { + t.Run("Success", func(t *testing.T) { + tr, _ := NewBaseTransport("http://example.com", nil) + called := 0 + tr.HandshakeHook = func(ctx context.Context) error { + called++ + return nil + } + + // First call should trigger hook + if err := tr.EnsureInitialized(context.Background()); err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Second call should NOT trigger hook + if err := tr.EnsureInitialized(context.Background()); err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if called != 1 { + t.Errorf("Expected hook to be called once, got %d", called) + } + }) + + t.Run("Failure", func(t *testing.T) { + tr, _ := NewBaseTransport("http://example.com", nil) + expectedErr := errors.New("handshake failed") + tr.HandshakeHook = func(ctx context.Context) error { + return expectedErr + } + + if err := tr.EnsureInitialized(context.Background()); err != expectedErr { + t.Errorf("Expected error %v, got %v", expectedErr, err) + } + + // verify error is cached + if err := tr.EnsureInitialized(context.Background()); err != expectedErr { + t.Errorf("Expected cached error %v, got %v", expectedErr, err) + } + }) + + t.Run("MissingHook", func(t *testing.T) { + tr, _ := NewBaseTransport("http://example.com", nil) + // No hook defined + err := tr.EnsureInitialized(context.Background()) + if err == nil { + t.Error("Expected error when HandshakeHook is missing, got nil") + } + }) +} + +func TestConvertToolDefinition(t *testing.T) { + tr, _ := NewBaseTransport("http://example.com", nil) + + rawTool := map[string]any{ + "name": "complex_tool", + "description": "A test tool", + "inputSchema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "simple_str": map[string]any{ + "type": "string", + "description": "Simple string param", + }, + "nested_obj": map[string]any{ + "type": "object", + "properties": map[string]any{ + "inner_int": map[string]any{"type": "integer"}, + }, + "additionalProperties": map[string]any{ + "type": "string", + }, + }, + "str_array": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "string", + }, + }, + }, + "required": []any{"simple_str"}, + }, + "_meta": map[string]any{ + "toolbox/authParam": map[string]any{ + "simple_str": []any{"header:x-api-key"}, + }, + "toolbox/authInvoke": []any{"oauth2"}, + }, + } + + schema, err := tr.ConvertToolDefinition(rawTool) + if err != nil { + t.Fatalf("ConvertToolDefinition failed: %v", err) + } + + // Check Top-Level Metadata + if schema.Description != "A test tool" { + t.Errorf("Expected description 'A test tool', got '%s'", schema.Description) + } + + // Check Auth Requirements + if len(schema.AuthRequired) != 1 || schema.AuthRequired[0] != "oauth2" { + t.Errorf("Expected AuthRequired=['oauth2'], got %v", schema.AuthRequired) + } + + // Check Parameters + if len(schema.Parameters) != 3 { + t.Fatalf("Expected 3 parameters, got %d", len(schema.Parameters)) + } + + // Helper map to find params by name easily + params := make(map[string]any) + for _, p := range schema.Parameters { + params[p.Name] = p + } + + foundSimple := false + for _, p := range schema.Parameters { + if p.Name == "simple_str" { + foundSimple = true + if !p.Required { + t.Error("Expected simple_str to be required") + } + if len(p.AuthSources) != 1 || p.AuthSources[0] != "header:x-api-key" { + t.Errorf("Expected AuthSources=['header:x-api-key'], got %v", p.AuthSources) + } + } else if p.Name == "nested_obj" { + if p.Type != "object" { + t.Errorf("Expected nested_obj type object, got %s", p.Type) + } + if p.AdditionalProperties == nil { + t.Error("Expected nested_obj to have AdditionalProperties schema") + } + } else if p.Name == "str_array" { + if p.Type != "array" { + t.Errorf("Expected str_array type array, got %s", p.Type) + } + if p.Items == nil || p.Items.Type != "string" { + t.Error("Expected str_array items to be type string") + } + } + } + + if !foundSimple { + t.Error("Parameter 'simple_str' not found in converted schema") + } +} diff --git a/core/transport/mcp/v20241105/mcp.go b/core/transport/mcp/v20241105/mcp.go new file mode 100644 index 00000000..971f94db --- /dev/null +++ b/core/transport/mcp/v20241105/mcp.go @@ -0,0 +1,285 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v20241105 + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/google/uuid" + "github.com/googleapis/mcp-toolbox-sdk-go/core/transport" + "github.com/googleapis/mcp-toolbox-sdk-go/core/transport/mcp" +) + +const ( + ProtocolVersion = "2024-11-05" + ClientName = "toolbox-go-sdk" + ClientVersion = mcp.SDKVersion +) + +// Ensure that McpTransport implements the Transport interface. +var _ transport.Transport = &McpTransport{} + +// McpTransport implements the MCP v2024-11-05 protocol. +type McpTransport struct { + *mcp.BaseMcpTransport + protocolVersion string +} + +// New creates a new version-specific transport instance. +func New(baseURL string, client *http.Client) (*McpTransport, error) { + baseTransport, err := mcp.NewBaseTransport(baseURL, client) + if err != nil { + return nil, err + } + + t := &McpTransport{ + BaseMcpTransport: baseTransport, + protocolVersion: ProtocolVersion, + } + t.BaseMcpTransport.HandshakeHook = t.initializeSession + + return t, nil +} + +// ListTools fetches available tools +func (t *McpTransport) ListTools(ctx context.Context, toolsetName string, headers map[string]string) (*transport.ManifestSchema, error) { + if err := t.EnsureInitialized(ctx); err != nil { + return nil, err + } + + requestURL := t.BaseURL() + if toolsetName != "" { + var err error + requestURL, err = url.JoinPath(requestURL, toolsetName) + if err != nil { + return nil, fmt.Errorf("failed to construct toolset URL: %w", err) + } + } + + var result listToolsResult + if err := t.sendRequest(ctx, requestURL, "tools/list", map[string]any{}, headers, &result); err != nil { + return nil, fmt.Errorf("failed to list tools: %w", err) + } + + manifest := &transport.ManifestSchema{ + ServerVersion: t.ServerVersion, + Tools: make(map[string]transport.ToolSchema), + } + + for i, tool := range result.Tools { + if tool.Name == "" { + return nil, fmt.Errorf("received invalid tool definition at index %d: missing 'name' field", i) + } + + rawTool := map[string]any{ + "name": tool.Name, + "description": tool.Description, + "inputSchema": tool.InputSchema, + } + + if tool.Meta != nil { + rawTool["_meta"] = tool.Meta + } + + toolSchema, err := t.ConvertToolDefinition(rawTool) + if err != nil { + return nil, fmt.Errorf("failed to convert schema for tool %s: %w", tool.Name, err) + } + + manifest.Tools[tool.Name] = toolSchema + } + + return manifest, nil +} + +// GetTool fetches a single tool +func (t *McpTransport) GetTool(ctx context.Context, toolName string, headers map[string]string) (*transport.ManifestSchema, error) { + manifest, err := t.ListTools(ctx, "", headers) + if err != nil { + return nil, err + } + + tool, exists := manifest.Tools[toolName] + if !exists { + return nil, fmt.Errorf("tool '%s' not found", toolName) + } + + return &transport.ManifestSchema{ + ServerVersion: manifest.ServerVersion, + Tools: map[string]transport.ToolSchema{toolName: tool}, + }, nil +} + +// InvokeTool executes a tool +func (t *McpTransport) InvokeTool(ctx context.Context, toolName string, payload map[string]any, headers map[string]string) (any, error) { + if err := t.EnsureInitialized(ctx); err != nil { + return "", err + } + + params := callToolRequestParams{ + Name: toolName, + Arguments: payload, + } + + var result callToolResult + if err := t.sendRequest(ctx, t.BaseURL(), "tools/call", params, headers, &result); err != nil { + return "", fmt.Errorf("failed to invoke tool '%s': %w", toolName, err) + } + + if result.IsError { + return "", fmt.Errorf("tool execution resulted in error") + } + + // Concatenate all text content blocks + var sb strings.Builder + for _, content := range result.Content { + if content.Type == "text" { + sb.WriteString(content.Text) + } + } + + output := sb.String() + if output == "" { + // Return null if no text content found but not an error + return "null", nil + } + return output, nil +} + +// initializeSession performs the initial handshake with the server. +func (t *McpTransport) initializeSession(ctx context.Context) error { + params := initializeRequestParams{ + ProtocolVersion: t.protocolVersion, + Capabilities: clientCapabilities{}, + ClientInfo: implementation{ + Name: ClientName, + Version: ClientVersion, + }, + } + + var result initializeResult + if err := t.sendRequest(ctx, t.BaseURL(), "initialize", params, nil, &result); err != nil { + return err + } + + // Protocol Version Check + if result.ProtocolVersion != t.protocolVersion { + return fmt.Errorf("MCP version mismatch: client (%s) != server (%s)", t.protocolVersion, result.ProtocolVersion) + } + + // Capabilities Check + if result.Capabilities.Tools == nil { + return fmt.Errorf("server does not support the 'tools' capability") + } + + t.ServerVersion = result.ServerInfo.Version + + // Confirm Handshake + return t.sendNotification(ctx, "notifications/initialized", map[string]any{}) +} + +// sendRequest sends a standard JSON-RPC request to the server. +func (t *McpTransport) sendRequest(ctx context.Context, url string, method string, params any, headers map[string]string, dest any) error { + req := jsonRPCRequest{ + JSONRPC: "2.0", + Method: method, + ID: uuid.New().String(), + Params: params, + } + return t.doRPC(ctx, url, req, headers, dest) +} + +// sendNotification sends a standard JSON-RPC notification (no response expected). +func (t *McpTransport) sendNotification(ctx context.Context, method string, params any) error { + req := jsonRPCNotification{ + JSONRPC: "2.0", + Method: method, + Params: params, + } + return t.doRPC(ctx, t.BaseURL(), req, nil, nil) +} + +// doRPC performs the low-level HTTP POST and handles JSON-RPC wrapping/unwrapping. +func (t *McpTransport) doRPC(ctx context.Context, url string, reqBody any, headers map[string]string, dest any) error { + payload, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("marshal failed: %w", err) + } + + // Create Request + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(payload)) + if err != nil { + return fmt.Errorf("create request failed: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + + // Apply resolved headers + for k, v := range headers { + httpReq.Header.Set(k, v) + } + + resp, err := t.HTTPClient.Do(httpReq) + if err != nil { + return fmt.Errorf("http request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + // Continue to body parsing + } else if (resp.StatusCode == http.StatusAccepted || resp.StatusCode == http.StatusNoContent) && dest == nil { + return nil // Valid notification success + } else { + // Any other code, OR a 202/204 when we expected a result, is a failure. + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + if dest == nil { + return nil + } + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read body failed: %w", err) + } + + // Decode RPC Envelope + var rpcResp jsonRPCResponse + if err := json.Unmarshal(bodyBytes, &rpcResp); err != nil { + return fmt.Errorf("response unmarshal failed: %w", err) + } + + // Check RPC Error + if rpcResp.Error != nil { + return fmt.Errorf("MCP request failed with code %d: %s", rpcResp.Error.Code, rpcResp.Error.Message) + } + + // Decode Result into specific struct + // We marshal the 'result' field back to bytes to unmarshal it into the specific 'dest' struct + resultBytes, _ := json.Marshal(rpcResp.Result) + if err := json.Unmarshal(resultBytes, dest); err != nil { + return fmt.Errorf("failed to parse result data: %w", err) + } + + return nil +} diff --git a/core/transport/mcp/v20241105/mcp_test.go b/core/transport/mcp/v20241105/mcp_test.go new file mode 100644 index 00000000..2a8c8269 --- /dev/null +++ b/core/transport/mcp/v20241105/mcp_test.go @@ -0,0 +1,473 @@ +//go:build unit + +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v20241105 + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockMCPServer is a helper to mock MCP JSON-RPC responses +type mockMCPServer struct { + *httptest.Server + handlers map[string]func(params json.RawMessage) (any, error) + requests []jsonRPCRequest // Log of received requests for verification +} + +func newMockMCPServer(t *testing.T) *mockMCPServer { + m := &mockMCPServer{ + handlers: make(map[string]func(json.RawMessage) (any, error)), + } + + m.Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + + var req jsonRPCRequest + err = json.Unmarshal(body, &req) + require.NoError(t, err) + + m.requests = append(m.requests, req) + + // Handle Notifications (no ID) - return 204 or 200 OK immediately + if req.ID == nil { + if handler, ok := m.handlers[req.Method]; ok { + _, _ = handler(asRawMessage(req.Params)) + } + w.WriteHeader(http.StatusOK) + return + } + + // Handle Requests + handler, ok := m.handlers[req.Method] + if !ok { + http.Error(w, "method not found", http.StatusNotFound) + return + } + + result, err := handler(asRawMessage(req.Params)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + resBytes, err := json.Marshal(result) + require.NoError(t, err) + + resp := jsonRPCResponse{ + JSONRPC: "2.0", + ID: req.ID, + Result: resBytes, + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(resp) + require.NoError(t, err) + })) + + // Register default handshake handlers + m.handlers["initialize"] = func(params json.RawMessage) (any, error) { + return initializeResult{ + ProtocolVersion: "2024-11-05", + Capabilities: serverCapabilities{ + Tools: map[string]any{"listChanged": true}, + }, + ServerInfo: implementation{ + Name: "mock-server", + Version: "1.0.0", + }, + }, nil + } + m.handlers["notifications/initialized"] = func(params json.RawMessage) (any, error) { + return nil, nil + } + + return m +} + +func asRawMessage(v any) json.RawMessage { + b, _ := json.Marshal(v) + return b +} + +func TestListTools(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + + // Mock tools/list response using strict mcpTool struct + server.handlers["tools/list"] = func(params json.RawMessage) (any, error) { + return listToolsResult{ + Tools: []mcpTool{ + { + Name: "get_weather", + Description: "Get weather for a location", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, + "required": []string{"location"}, + }, + }, + }, + }, nil + } + + client, _ := New(server.URL, server.Client()) + ctx := context.Background() + + t.Run("Success", func(t *testing.T) { + manifest, err := client.ListTools(ctx, "", nil) + require.NoError(t, err) + require.NotNil(t, manifest) + + assert.Equal(t, "1.0.0", manifest.ServerVersion) + assert.Contains(t, manifest.Tools, "get_weather") + tool := manifest.Tools["get_weather"] + assert.Equal(t, "Get weather for a location", tool.Description) + assert.Len(t, tool.Parameters, 1) + assert.Equal(t, "location", tool.Parameters[0].Name) + }) + + t.Run("Verify Handshake Sequence", func(t *testing.T) { + require.GreaterOrEqual(t, len(server.requests), 3) + assert.Equal(t, "initialize", server.requests[0].Method) + assert.Equal(t, "notifications/initialized", server.requests[1].Method) + assert.Equal(t, "tools/list", server.requests[2].Method) + }) +} + +func TestListTools_ErrorOnEmptyName(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + + server.handlers["tools/list"] = func(params json.RawMessage) (any, error) { + return listToolsResult{ + Tools: []mcpTool{ + {Name: "valid", InputSchema: map[string]any{}}, + {Name: "", InputSchema: map[string]any{}}, // Invalid tool + }, + }, nil + } + + client, _ := New(server.URL, server.Client()) + _, err := client.ListTools(context.Background(), "", nil) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing 'name' field") +} + +func TestGetTool_Success(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + + server.handlers["tools/list"] = func(params json.RawMessage) (any, error) { + return listToolsResult{ + Tools: []mcpTool{ + {Name: "tool_a", InputSchema: map[string]any{"type": "object"}}, + {Name: "tool_b", InputSchema: map[string]any{"type": "object"}}, + }, + }, nil + } + + client, _ := New(server.URL, server.Client()) + manifest, err := client.GetTool(context.Background(), "tool_a", nil) + require.NoError(t, err) + assert.Contains(t, manifest.Tools, "tool_a") + assert.NotContains(t, manifest.Tools, "tool_b") +} + +func TestGetTool_NotFound(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + + server.handlers["tools/list"] = func(params json.RawMessage) (any, error) { + return listToolsResult{Tools: []mcpTool{}}, nil + } + + client, _ := New(server.URL, server.Client()) + _, err := client.GetTool(context.Background(), "missing_tool", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestInvokeTool(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + + server.handlers["tools/call"] = func(params json.RawMessage) (any, error) { + // Verify arguments + var callParams callToolRequestParams + _ = json.Unmarshal(params, &callParams) + if callParams.Name != "echo" { + return nil, nil + } + + msg, _ := callParams.Arguments["message"].(string) + return callToolResult{ + Content: []textContent{ + {Type: "text", Text: "Echo: " + msg}, + }, + IsError: false, + }, nil + } + + client, _ := New(server.URL, server.Client()) + ctx := context.Background() + + t.Run("Success", func(t *testing.T) { + args := map[string]any{"message": "Hello MCP"} + result, err := client.InvokeTool(ctx, "echo", args, nil) + require.NoError(t, err) + + resStr, ok := result.(string) + require.True(t, ok) + assert.Equal(t, "Echo: Hello MCP", resStr) + }) +} + +func TestProtocolMismatch(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + + // Override initialize to return wrong version + server.handlers["initialize"] = func(params json.RawMessage) (any, error) { + return initializeResult{ + ProtocolVersion: "2099-01-01", // Future version + Capabilities: serverCapabilities{Tools: map[string]any{}}, + ServerInfo: implementation{Name: "mock", Version: "1.0"}, + }, nil + } + + client, _ := New(server.URL, server.Client()) + + _, err := client.ListTools(context.Background(), "", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "MCP version mismatch") +} + +func TestInitialize_MissingCapabilities(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + + server.handlers["initialize"] = func(params json.RawMessage) (any, error) { + return initializeResult{ + ProtocolVersion: "2024-11-05", + Capabilities: serverCapabilities{Tools: nil}, + ServerInfo: implementation{Name: "srv", Version: "1"}, + }, nil + } + + client, _ := New(server.URL, server.Client()) + _, err := client.ListTools(context.Background(), "", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "does not support the 'tools' capability") +} + +func TestConvertToolSchema(t *testing.T) { + // Use the transport's ConvertToolDefinition which delegates to the base/helper logic + tr, _ := New("http://example.com", nil) + + rawTool := map[string]any{ + "name": "complex_tool", + "description": "Complex tool", + "inputSchema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "tag": map[string]any{ + "type": "string", + "description": "A tag", + }, + "count": map[string]any{ + "type": "integer", + }, + }, + "required": []any{"tag"}, + }, + "_meta": map[string]any{ + "toolbox/authParam": map[string]any{ + "tag": []any{"serviceA"}, + }, + "toolbox/authInvoke": []any{"serviceB"}, + }, + } + + schema, err := tr.ConvertToolDefinition(rawTool) + require.NoError(t, err) + + assert.Equal(t, "Complex tool", schema.Description) + assert.Len(t, schema.Parameters, 2) + assert.Equal(t, []string{"serviceB"}, schema.AuthRequired) + + for _, p := range schema.Parameters { + if p.Name == "tag" { + assert.True(t, p.Required) + assert.Equal(t, []string{"serviceA"}, p.AuthSources) + } + } +} + +func TestListTools_WithToolset(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + + // We verify that the toolset name was appended to the URL in the POST request + server.handlers["tools/list"] = func(params json.RawMessage) (any, error) { + return listToolsResult{Tools: []mcpTool{}}, nil + } + + client, _ := New(server.URL, server.Client()) + toolsetName := "my-toolset" + + _, err := client.ListTools(context.Background(), toolsetName, nil) + require.NoError(t, err) +} + +func TestRequest_NetworkError(t *testing.T) { + // Close server immediately to simulate connection refused + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + url := server.URL + server.Close() + + client, _ := New(url, server.Client()) + _, err := client.ListTools(context.Background(), "", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "http request failed") +} + +func TestRequest_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal Error")) + })) + defer server.Close() + + client, _ := New(server.URL, server.Client()) + _, err := client.ListTools(context.Background(), "", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "API request failed with status 500") +} + +func TestRequest_BadJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ broken json `)) + })) + defer server.Close() + + client, _ := New(server.URL, server.Client()) + _, err := client.ListTools(context.Background(), "", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "response unmarshal failed") +} + +func TestRequest_NewRequestError(t *testing.T) { + // Bad URL triggers http.NewRequest error + _, err := New("http://bad\nurl.com", http.DefaultClient) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "invalid control character in URL") +} + +func TestRequest_MarshalError(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + client, _ := New(server.URL, server.Client()) + + // Force initialization first + _ = client.EnsureInitialized(context.Background()) + + // Pass a type that cannot be marshaled to JSON (e.g. channel) + badPayload := map[string]any{"bad": make(chan int)} + _, err := client.InvokeTool(context.Background(), "tool", badPayload, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "marshal failed") +} + +func TestInvokeTool_ErrorResult(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + + server.handlers["tools/call"] = func(params json.RawMessage) (any, error) { + return callToolResult{ + Content: []textContent{{Type: "text", Text: "Something went wrong"}}, + IsError: true, + }, nil + } + + client, _ := New(server.URL, server.Client()) + _, err := client.InvokeTool(context.Background(), "tool", nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "tool execution resulted in error") +} + +func TestInvokeTool_RPCError(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + + server.handlers["tools/call"] = func(params json.RawMessage) (any, error) { + return nil, errors.New("internal server error") + } + + client, _ := New(server.URL, server.Client()) + _, err := client.InvokeTool(context.Background(), "tool", nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "internal server error") +} + +func TestInvokeTool_ComplexContent(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + + server.handlers["tools/call"] = func(params json.RawMessage) (any, error) { + return callToolResult{ + Content: []textContent{ + {Type: "text", Text: "Part 1 "}, + {Type: "image", Text: "base64data"}, // Should be ignored + {Type: "text", Text: "Part 2"}, + }, + }, nil + } + + client, _ := New(server.URL, server.Client()) + res, err := client.InvokeTool(context.Background(), "t", nil, nil) + require.NoError(t, err) + assert.Equal(t, "Part 1 Part 2", res) +} + +func TestInvokeTool_EmptyResult(t *testing.T) { + server := newMockMCPServer(t) + defer server.Close() + + server.handlers["tools/call"] = func(params json.RawMessage) (any, error) { + return callToolResult{ + Content: []textContent{}, + }, nil + } + + client, _ := New(server.URL, server.Client()) + res, err := client.InvokeTool(context.Background(), "t", nil, nil) + require.NoError(t, err) + assert.Equal(t, "null", res) +} diff --git a/core/transport/mcp/v20241105/types.go b/core/transport/mcp/v20241105/types.go new file mode 100644 index 00000000..929517b7 --- /dev/null +++ b/core/transport/mcp/v20241105/types.go @@ -0,0 +1,108 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v20241105 + +import "encoding/json" + +// jsonRPCRequest represents a standard JSON-RPC 2.0 request. +type jsonRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + ID any `json:"id,omitempty"` // string or int + Params any `json:"params,omitempty"` // map or struct +} + +// jsonRPCNotification represents a standard JSON-RPC 2.0 notification (no ID). +type jsonRPCNotification struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +// jsonRPCResponse represents a standard JSON-RPC 2.0 response. +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID any `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *jsonRPCError `json:"error,omitempty"` +} + +// jsonRPCError represents the error object inside a JSON-RPC response. +type jsonRPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` +} + +// implementation describes the name and version of the client. +type implementation struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// clientCapabilities describes the features supported by the client. +type clientCapabilities map[string]any + +// serverCapabilities describes the features supported by the server. +type serverCapabilities struct { + Prompts map[string]any `json:"prompts,omitempty"` + Tools map[string]any `json:"tools,omitempty"` +} + +// initializeRequestParams holds the parameters for the 'initialize' handshake. +type initializeRequestParams struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities clientCapabilities `json:"capabilities"` + ClientInfo implementation `json:"clientInfo"` +} + +// initializeResult holds the response from the 'initialize' handshake. +type initializeResult struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities serverCapabilities `json:"capabilities"` + ServerInfo implementation `json:"serverInfo"` + Instructions string `json:"instructions,omitempty"` +} + +// mcpTool represents a single tool definition from the server. +type mcpTool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema map[string]any `json:"inputSchema"` + Meta map[string]any `json:"_meta,omitempty"` +} + +// listToolsResult holds the response from the 'tools/list' method. +type listToolsResult struct { + Tools []mcpTool `json:"tools"` +} + +// callToolRequestParams holds the parameters for the 'tools/call' method. +type callToolRequestParams struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` +} + +// textContent represents a single text block in a tool's output. +type textContent struct { + Type string `json:"type"` + Text string `json:"text"` +} + +// callToolResult holds the response from the 'tools/call' method. +type callToolResult struct { + Content []textContent `json:"content"` + IsError bool `json:"isError"` +} diff --git a/core/transport/mcp/version.go b/core/transport/mcp/version.go new file mode 100644 index 00000000..0a6fc427 --- /dev/null +++ b/core/transport/mcp/version.go @@ -0,0 +1,19 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcp + +// SDKVersion is the current version of the library. +// This is updated automatically by release-please. +const SDKVersion = "0.4.0" // x-release-please-version diff --git a/go.mod b/go.mod index bf1d226e..226b67ce 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( cloud.google.com/go/secretmanager v1.16.0 cloud.google.com/go/storage v1.58.0 github.com/firebase/genkit/go v1.2.0 + github.com/google/uuid v1.6.0 github.com/stretchr/testify v1.11.1 golang.org/x/oauth2 v0.34.0 google.golang.org/adk v0.3.0 @@ -39,7 +40,6 @@ require ( github.com/google/dotprompt/go v0.0.0-20251014011017-8d056e027254 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/s2a-go v0.1.9 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect github.com/googleapis/gax-go/v2 v2.15.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect