diff --git a/.github/release-please.yml b/.github/release-please.yml index a197a37..5f09123 100644 --- a/.github/release-please.yml +++ b/.github/release-please.yml @@ -15,3 +15,6 @@ handleGHRelease: true packageName: mcp-toolbox-sdk-go releaseType: go +extraFiles: [ + "core/transport/mcp/version.go", +] \ No newline at end of file diff --git a/core/client.go b/core/client.go index 985efdd..907751e 100644 --- a/core/client.go +++ b/core/client.go @@ -17,10 +17,14 @@ package core import ( "context" "fmt" - "log" "net/http" "strings" + "github.com/googleapis/mcp-toolbox-sdk-go/core/transport" + mcp20241105 "github.com/googleapis/mcp-toolbox-sdk-go/core/transport/mcp/v20241105" + mcp20250326 "github.com/googleapis/mcp-toolbox-sdk-go/core/transport/mcp/v20250326" + mcp20250618 "github.com/googleapis/mcp-toolbox-sdk-go/core/transport/mcp/v20250618" + "github.com/googleapis/mcp-toolbox-sdk-go/core/transport/toolboxtransport" "golang.org/x/oauth2" ) @@ -28,6 +32,9 @@ import ( type ToolboxClient struct { baseURL string httpClient *http.Client + protocol Protocol + protocolSet bool + transport transport.Transport clientHeaderSources map[string]oauth2.TokenSource defaultToolOptions []ToolOption defaultOptionsSet bool @@ -39,7 +46,7 @@ type ToolboxClient struct { // Inputs: // - url: The base URL of the Toolbox server. // - opts: A variadic list of ClientOption functions to configure the client, -// such as setting a custom http.Client or default headers. +// such as setting a custom http.Client, default headers, or the underlying protocol. // // Returns: // @@ -47,9 +54,11 @@ type ToolboxClient struct { // and an error if configuration fails. func NewToolboxClient(url string, opts ...ClientOption) (*ToolboxClient, error) { // Initialize the client with default values. + // We default to ProtocolMCP (the newest version alias) if not overridden. tc := &ToolboxClient{ baseURL: url, httpClient: &http.Client{}, + protocol: MCP, // Default clientHeaderSources: make(map[string]oauth2.TokenSource), defaultToolOptions: []ToolOption{}, } @@ -64,7 +73,22 @@ func NewToolboxClient(url string, opts ...ClientOption) (*ToolboxClient, error) } } - return tc, nil + // Initialize the Transport based on the selected Protocol. + var transportErr error = nil + switch tc.protocol { + case MCPv20250618: + tc.transport, transportErr = mcp20250618.New(tc.baseURL, tc.httpClient) + case MCPv20250326: + tc.transport, transportErr = mcp20250326.New(tc.baseURL, tc.httpClient) + case MCPv20241105: + tc.transport, transportErr = mcp20241105.New(tc.baseURL, tc.httpClient) + case Toolbox: + tc.transport = toolboxtransport.New(tc.baseURL, tc.httpClient) + default: + return nil, fmt.Errorf("unsupported protocol version: %s", tc.protocol) + } + + return tc, transportErr } // newToolboxTool is an internal factory method that constructs a @@ -87,6 +111,7 @@ func (tc *ToolboxClient) newToolboxTool( schema ToolSchema, finalConfig *ToolConfig, isStrict bool, + tr transport.Transport, ) (*ToolboxTool, []string, []string, error) { // These will be the parameters that the end-user must provide at invocation time. @@ -151,17 +176,12 @@ func (tc *ToolboxClient) newToolboxTool( finalConfig.AuthTokenSources, ) - if (len(remainingAuthnParams) > 0 || len(remainingAuthzTokens) > 0 || len(tc.clientHeaderSources) > 0) && !strings.HasPrefix(tc.baseURL, "https://") { - log.Println("WARNING: Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.") - } - // Construct the final tool object. tt := &ToolboxTool{ name: name, description: schema.Description, parameters: finalParameters, - invocationURL: fmt.Sprintf("%s/api/tool/%s%s", tc.baseURL, name, toolInvokeSuffix), - httpClient: tc.httpClient, + transport: tr, authTokenSources: finalConfig.AuthTokenSources, boundParams: localBoundParams, requiredAuthnParams: remainingAuthnParams, @@ -204,9 +224,14 @@ func (tc *ToolboxClient) LoadTool(name string, ctx context.Context, opts ...Tool } } + resolvedHeaders, err := resolveClientHeaders(tc.clientHeaderSources) + if err != nil { + return nil, err + } + // Fetch the manifest for the specified tool. - url := fmt.Sprintf("%s/api/tool/%s", tc.baseURL, name) - manifest, err := loadManifest(ctx, url, tc.httpClient, tc.clientHeaderSources) + manifest, err := tc.transport.GetTool(ctx, name, resolvedHeaders) + if err != nil { return nil, fmt.Errorf("failed to load tool manifest for '%s': %w", name, err) } @@ -219,7 +244,7 @@ func (tc *ToolboxClient) LoadTool(name string, ctx context.Context, opts ...Tool } // Construct the tool from its schema and the final configuration. - tool, usedAuthKeys, usedBoundKeys, err := tc.newToolboxTool(name, schema, finalConfig, true) + tool, usedAuthKeys, usedBoundKeys, err := tc.newToolboxTool(name, schema, finalConfig, true, tc.transport) if err != nil { return nil, fmt.Errorf("failed to create toolbox tool from schema for '%s': %w", name, err) } @@ -291,15 +316,14 @@ func (tc *ToolboxClient) LoadToolset(name string, ctx context.Context, opts ...T } } - // Determine the manifest URL based on whether a specific toolset name was provided. - var url string - if name == "" { - url = fmt.Sprintf("%s/api/toolset/", tc.baseURL) - } else { - url = fmt.Sprintf("%s/api/toolset/%s", tc.baseURL, name) - } // Fetch the manifest for the toolset. - manifest, err := loadManifest(ctx, url, tc.httpClient, tc.clientHeaderSources) + resolvedHeaders, err := resolveClientHeaders(tc.clientHeaderSources) + if err != nil { + return nil, err + } + + // Fetch Manifest via Transport + manifest, err := tc.transport.ListTools(ctx, name, resolvedHeaders) if err != nil { return nil, fmt.Errorf("failed to load toolset manifest for '%s': %w", name, err) } @@ -322,7 +346,7 @@ func (tc *ToolboxClient) LoadToolset(name string, ctx context.Context, opts ...T for toolName, schema := range manifest.Tools { // Construct each tool from its schema and the shared configuration. - tool, usedAuthKeys, usedBoundKeys, err := tc.newToolboxTool(toolName, schema, finalConfig, finalConfig.Strict) + tool, usedAuthKeys, usedBoundKeys, err := tc.newToolboxTool(toolName, schema, finalConfig, finalConfig.Strict, tc.transport) if err != nil { return nil, fmt.Errorf("failed to create tool '%s': %w", toolName, err) } diff --git a/core/client_test.go b/core/client_test.go index f291dc1..e002dca 100644 --- a/core/client_test.go +++ b/core/client_test.go @@ -60,6 +60,11 @@ func TestNewToolboxClient(t *testing.T) { if client.httpClient.Timeout != 0 { t.Errorf("expected no timeout, got %v", client.httpClient.Timeout) } + + if client.protocol != ProtocolMCP { + t.Errorf("expected default protocol to be ProtocolMCP, got %v", client.protocol) + } + }) t.Run("Returns error when a nil option is provided", func(t *testing.T) { @@ -259,7 +264,7 @@ func TestLoadToolAndLoadToolset(t *testing.T) { defer server.Close() t.Run("LoadTool - Success", func(t *testing.T) { - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) + client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client()), WithProtocol(Toolbox)) tool, err := client.LoadTool("toolA", context.Background(), WithBindParamString("param1", "value1"), @@ -274,7 +279,7 @@ func TestLoadToolAndLoadToolset(t *testing.T) { }) t.Run("LoadTool - Negative Test - Unused bound parameter", func(t *testing.T) { - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) + client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client()), WithProtocol(Toolbox)) _, err := client.LoadTool("toolA", context.Background(), WithBindParamString("param1", "value1"), @@ -289,7 +294,7 @@ func TestLoadToolAndLoadToolset(t *testing.T) { }) t.Run("LoadToolset - Success with non-strict mode", func(t *testing.T) { - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) + client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client()), WithProtocol(Toolbox)) tools, err := client.LoadToolset( "", context.Background(), @@ -306,7 +311,7 @@ func TestLoadToolAndLoadToolset(t *testing.T) { }) t.Run("LoadToolset - Negative Test - Unused parameter in non-strict mode", func(t *testing.T) { - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) + client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client()), WithProtocol(Toolbox)) _, err := client.LoadToolset( "", context.Background(), @@ -322,7 +327,7 @@ func TestLoadToolAndLoadToolset(t *testing.T) { }) t.Run("LoadToolset - Negative Test - Unused parameter in strict mode", func(t *testing.T) { - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) + client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client()), WithProtocol(Toolbox)) _, err := client.LoadToolset( "", context.Background(), @@ -434,7 +439,7 @@ func TestNegativeAndEdgeCases(t *testing.T) { t.Run("LoadTool fails when a nil ToolOption is provided", func(t *testing.T) { - client, _ := NewToolboxClient(server.URL) + client, _ := NewToolboxClient(server.URL, WithProtocol(Toolbox)) _, err := client.LoadTool("any-tool", context.Background(), nil) if err == nil { t.Fatal("Expected an error when a nil option is passed to LoadTool, but got nil") @@ -474,7 +479,7 @@ func TestNegativeAndEdgeCases(t *testing.T) { })) defer serverWithNoTools.Close() - client, _ := NewToolboxClient(serverWithNoTools.URL, WithHTTPClient(serverWithNoTools.Client())) + client, _ := NewToolboxClient(serverWithNoTools.URL, WithHTTPClient(serverWithNoTools.Client()), WithProtocol(Toolbox)) // This call would panic if the code doesn't check for a nil map. _, err := client.LoadTool("any-tool", context.Background()) @@ -567,25 +572,11 @@ func TestLoadToolAndLoadToolset_ErrorPaths(t *testing.T) { log.SetOutput(&buf) defer log.SetOutput(originalOutput) - t.Run("logs warning for HTTP with headers", func(t *testing.T) { - buf.Reset() - - client, _ := NewToolboxClient(server.URL, - WithHTTPClient(server.Client()), - ) - - _, _ = client.LoadTool("toolA", context.Background()) - - expectedLog := "WARNING: Sending ID token over HTTP" - if !strings.Contains(buf.String(), expectedLog) { - t.Errorf("expected log message '%s' not found in output: '%s'", expectedLog, buf.String()) - } - }) - t.Run("LoadTool fails when a default option is invalid", func(t *testing.T) { // Setup client with duplicate default options client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client()), + WithProtocol(Toolbox), WithDefaultToolOptions( WithStrict(true), WithStrict(false), @@ -605,7 +596,7 @@ func TestLoadToolAndLoadToolset_ErrorPaths(t *testing.T) { }) t.Run("LoadTool fails when tool is not in the manifest", func(t *testing.T) { - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) + client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client()), WithProtocol(Toolbox)) _, err := client.LoadTool("tool-that-does-not-exist", context.Background()) if err == nil { @@ -621,7 +612,7 @@ func TestLoadToolAndLoadToolset_ErrorPaths(t *testing.T) { errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) errorServer.Close() - client, _ := NewToolboxClient(errorServer.URL, WithHTTPClient(errorServer.Client())) + client, _ := NewToolboxClient(errorServer.URL, WithHTTPClient(errorServer.Client()), WithProtocol(Toolbox)) _, err := client.LoadTool("any-tool", context.Background()) if err == nil { @@ -633,7 +624,7 @@ func TestLoadToolAndLoadToolset_ErrorPaths(t *testing.T) { }) t.Run("LoadTool fails with unused auth tokens", func(t *testing.T) { - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) + client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client()), WithProtocol(Toolbox)) _, err := client.LoadTool("toolA", context.Background(), WithAuthTokenString("unused-auth", "token"), // This auth is not needed by toolA ) @@ -646,7 +637,7 @@ func TestLoadToolAndLoadToolset_ErrorPaths(t *testing.T) { }) t.Run("LoadTool fails with unused bound parameters", func(t *testing.T) { - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) + client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client()), WithProtocol(Toolbox)) _, err := client.LoadTool("toolA", context.Background(), WithBindParamString("unused-param", "value"), // This param is not defined on toolA ) @@ -661,7 +652,7 @@ func TestLoadToolAndLoadToolset_ErrorPaths(t *testing.T) { }) t.Run("LoadToolset fails with unused parameters in strict mode", func(t *testing.T) { - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) + client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client()), WithProtocol(Toolbox)) _, err := client.LoadToolset( "", context.Background(), @@ -679,7 +670,7 @@ func TestLoadToolAndLoadToolset_ErrorPaths(t *testing.T) { }) t.Run("LoadToolset fails with unused parameters in non-strict mode", func(t *testing.T) { - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) + client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client()), WithProtocol(Toolbox)) _, err := client.LoadToolset( "", context.Background(), diff --git a/core/e2e_mcp_test.go b/core/e2e_mcp_test.go new file mode 100644 index 0000000..01ab06d --- /dev/null +++ b/core/e2e_mcp_test.go @@ -0,0 +1,756 @@ +//go:build e2e + +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core_test + +import ( + "context" + "net/http" + "sync" + "testing" + "time" + + "github.com/googleapis/mcp-toolbox-sdk-go/core" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +type protocolTestCase struct { + name string + protocol core.Protocol + isDefault bool // If true, we do NOT pass WithProtocol() to verify default behavior +} + +// protocolsToTest defines the matrix of MCP protocols we want to verify. +var protocolsToTest = []protocolTestCase{ + // The Default Case (User passes nothing, expects latest) + {name: "Default (Latest)", isDefault: true}, + + // Explicit Versions + {name: "v20241105", protocol: core.MCPv20241105}, + {name: "v20250326", protocol: core.MCPv20250326}, + {name: "v20250618", protocol: core.MCPv20250618}, + {name: "MCP Alias (Latest)", protocol: core.MCP}, +} + +// CapturingTransport wraps http.RoundTripper to capture headers from the latest request. +type CapturingTransport struct { + base http.RoundTripper + lastHeaders http.Header + mu sync.Mutex +} + +func (c *CapturingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + c.mu.Lock() + c.lastHeaders = req.Header.Clone() + c.mu.Unlock() + + // Delegate to the actual network transport + base := c.base + if base == nil { + base = http.DefaultTransport + } + return base.RoundTrip(req) +} + +func (c *CapturingTransport) CapturedHeaders() http.Header { + c.mu.Lock() + defer c.mu.Unlock() + return c.lastHeaders +} + +// helper factory to create a client with a specific protocol +func getNewMCPToolboxClient(t *testing.T, tc protocolTestCase) *core.ToolboxClient { + opts := []core.ClientOption{} + + // Only add WithProtocol if it's NOT the default test case + if !tc.isDefault { + opts = append(opts, core.WithProtocol(tc.protocol)) + } + + client, err := core.NewToolboxClient("http://localhost:5000", opts...) + require.NoError(t, err, "Failed to create MCP ToolboxClient for %s", tc.name) + return client +} + +func TestMCP_Basic(t *testing.T) { + for _, proto := range protocolsToTest { + t.Run(proto.name, func(t *testing.T) { + // Helper to create a new client for each sub-test + newClient := func(t *testing.T) *core.ToolboxClient { + return getNewMCPToolboxClient(t, proto) + } + + // Helper to load the get-n-rows tool + getNRowsTool := func(t *testing.T, client *core.ToolboxClient) *core.ToolboxTool { + tool, err := client.LoadTool("get-n-rows", context.Background()) + require.NoError(t, err, "Failed to load tool 'get-n-rows'") + require.Equal(t, "get-n-rows", tool.Name()) + return tool + } + + t.Run("test_mcp_client_headers", func(t *testing.T) { + // Setup the Transport to capture headers + capturer := &CapturingTransport{} + httpClient := &http.Client{ + Transport: capturer, + Timeout: 30 * time.Second, + } + + // Build options manually to inject HTTP client + opts := []core.ClientOption{ + core.WithHTTPClient(httpClient), + } + // Logic to handle Default vs Explicit protocol + if !proto.isDefault { + opts = append(opts, core.WithProtocol(proto.protocol)) + } + + // Inject Transport into Client + client, err := core.NewToolboxClient("http://localhost:5000", opts...) + require.NoError(t, err) + + // Trigger a request + _, err = client.LoadTool("get-n-rows", context.Background()) + require.NoError(t, err) + + // Verify Transport Compliance + headers := capturer.CapturedHeaders() + + // Determine which protocol to check against + protocolToCheck := proto.protocol + if proto.isDefault { + protocolToCheck = core.MCPv20250618 // Default should match latest + } + + switch protocolToCheck { + case core.MCPv20241105: + // Should NOT have new headers + assert.Empty(t, headers.Get("MCP-Protocol-Version"), "v20241105 should not send protocol version header") + assert.Empty(t, headers.Get("Mcp-Session-Id"), "v20241105 should not include Mcp-Session-Id") + + case core.MCPv20250326: + // v2025-03-26: Must send Accept: application/json + assert.Equal(t, "application/json", headers.Get("Accept"), "v20250326 must request JSON only") + assert.NotEmpty(t, headers.Get("Mcp-Session-Id"), "v20250326 should include Mcp-Session-Id") + assert.Empty(t, headers.Get("MCP-Protocol-Version"), "v20250326 should not send protocol version header") + + case core.MCPv20250618: + // v2025-06-18: Must send Accept AND Protocol Version + assert.Equal(t, "application/json", headers.Get("Accept"), "v20250618 must request JSON only") + assert.Empty(t, headers.Get("Mcp-Session-Id"), "v20250618 should not include Mcp-Session-Id") + assert.Equal(t, "2025-06-18", headers.Get("MCP-Protocol-Version"), "v20250618 must send correct protocol version header") + } + }) + + t.Run("test_load_toolset_specific", func(t *testing.T) { + testCases := []struct { + name string + toolsetName string + expectedLength int + expectedTools []string + }{ + {"my-toolset", "my-toolset", 1, []string{"get-row-by-id"}}, + {"my-toolset-2", "my-toolset-2", 2, []string{"get-n-rows", "get-row-by-id"}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + client := newClient(t) + toolset, err := client.LoadToolset(tc.toolsetName, context.Background()) + + require.NoError(t, err) + assert.Len(t, toolset, tc.expectedLength) + + toolNames := make(map[string]struct{}) + for _, tool := range toolset { + toolNames[tool.Name()] = struct{}{} + } + expectedToolsSet := make(map[string]struct{}) + for _, name := range tc.expectedTools { + expectedToolsSet[name] = struct{}{} + } + assert.Equal(t, expectedToolsSet, toolNames) + }) + } + }) + + t.Run("test_load_toolset_default", func(t *testing.T) { + client := newClient(t) + toolset, err := client.LoadToolset("", context.Background()) + require.NoError(t, err) + + assert.Len(t, toolset, 7) + toolNames := make(map[string]struct{}) + for _, tool := range toolset { + toolNames[tool.Name()] = struct{}{} + } + expectedTools := map[string]struct{}{ + "get-row-by-content-auth": {}, + "get-row-by-email-auth": {}, + "get-row-by-id-auth": {}, + "get-row-by-id": {}, + "get-n-rows": {}, + "search-rows": {}, + "process-data": {}, + } + assert.Equal(t, expectedTools, toolNames) + }) + + t.Run("test_run_tool", func(t *testing.T) { + client := newClient(t) + tool := getNRowsTool(t, client) + + response, err := tool.Invoke(context.Background(), map[string]any{"num_rows": "2"}) + require.NoError(t, err) + + respStr, ok := response.(string) + require.True(t, ok, "Response should be a string") + assert.Contains(t, respStr, "row1") + assert.Contains(t, respStr, "row2") + assert.NotContains(t, respStr, "row3") + }) + + t.Run("test_run_tool_missing_params", func(t *testing.T) { + client := newClient(t) + tool := getNRowsTool(t, client) + + _, err := tool.Invoke(context.Background(), map[string]any{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing required parameter 'num_rows'") + }) + + t.Run("test_run_tool_wrong_param_type", func(t *testing.T) { + client := newClient(t) + tool := getNRowsTool(t, client) + + _, err := tool.Invoke(context.Background(), map[string]any{"num_rows": 2}) + require.Error(t, err) + assert.Contains(t, err.Error(), "parameter 'num_rows' expects a string, but got int") + }) + }) + } +} + +func TestMCP_LoadErrors(t *testing.T) { + for _, proto := range protocolsToTest { + t.Run(proto.name, func(t *testing.T) { + newClient := func(t *testing.T) *core.ToolboxClient { + return getNewMCPToolboxClient(t, proto) + } + + t.Run("test_load_non_existent_tool", func(t *testing.T) { + client := newClient(t) + _, err := client.LoadTool("non-existent-tool", context.Background()) + require.Error(t, err) + assert.True(t, err != nil) + }) + + t.Run("test_load_non_existent_toolset", func(t *testing.T) { + client := newClient(t) + _, err := client.LoadToolset("non-existent-toolset", context.Background()) + require.Error(t, err) + }) + }) + } + + t.Run("test_new_client_with_nil_option", func(t *testing.T) { + _, err := core.NewToolboxClient("http://localhost:5000", nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "received a nil ClientOption") + }) + + t.Run("test_load_tool_with_nil_option", func(t *testing.T) { + client := getNewMCPToolboxClient(t, protocolsToTest[0]) + _, err := client.LoadTool("get-n-rows", context.Background(), nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "received a nil ToolOption") + }) +} + +func TestMCP_BindParams(t *testing.T) { + for _, proto := range protocolsToTest { + t.Run(proto.name, func(t *testing.T) { + newClient := func(t *testing.T) *core.ToolboxClient { + return getNewMCPToolboxClient(t, proto) + } + getNRowsTool := func(t *testing.T, client *core.ToolboxClient) *core.ToolboxTool { + tool, err := client.LoadTool("get-n-rows", context.Background()) + require.NoError(t, err) + return tool + } + + t.Run("test_bind_params", func(t *testing.T) { + client := newClient(t) + tool := getNRowsTool(t, client) + + newTool, err := tool.ToolFrom(core.WithBindParamString("num_rows", "3")) + require.NoError(t, err) + + response, err := newTool.Invoke(context.Background(), map[string]any{}) + require.NoError(t, err) + + respStr, ok := response.(string) + require.True(t, ok) + assert.Contains(t, respStr, "row1") + assert.Contains(t, respStr, "row2") + assert.Contains(t, respStr, "row3") + assert.NotContains(t, respStr, "row4") + }) + + t.Run("test_bind_params_callable", func(t *testing.T) { + client := newClient(t) + tool := getNRowsTool(t, client) + + callable := func() (string, error) { + return "3", nil + } + + newTool, err := tool.ToolFrom(core.WithBindParamStringFunc("num_rows", callable)) + require.NoError(t, err) + + response, err := newTool.Invoke(context.Background(), map[string]any{}) + require.NoError(t, err) + + respStr, ok := response.(string) + require.True(t, ok) + assert.Contains(t, respStr, "row1") + assert.Contains(t, respStr, "row2") + assert.Contains(t, respStr, "row3") + assert.NotContains(t, respStr, "row4") + }) + }) + } +} + +func TestMCP_BindParamErrors(t *testing.T) { + for _, proto := range protocolsToTest { + t.Run(proto.name, func(t *testing.T) { + client := getNewMCPToolboxClient(t, proto) + tool, err := client.LoadTool("get-n-rows", context.Background()) + require.NoError(t, err) + + t.Run("test_bind_non_existent_param", func(t *testing.T) { + _, err := tool.ToolFrom(core.WithBindParamString("non-existent-param", "3")) + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to bind parameter: no parameter named 'non-existent-param' on the tool") + }) + + t.Run("test_override_bound_param", func(t *testing.T) { + newTool, err := tool.ToolFrom(core.WithBindParamString("num_rows", "2")) + require.NoError(t, err) + + _, err = newTool.ToolFrom(core.WithBindParamString("num_rows", "3")) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot override existing bound parameter: 'num_rows'") + }) + }) + } +} + +func TestMCP_Auth(t *testing.T) { + // Helper to create a static token source from a string token + staticTokenSource := func(token string) oauth2.TokenSource { + return oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token}) + } + + for _, proto := range protocolsToTest { + t.Run(proto.name, func(t *testing.T) { + newClient := func(t *testing.T) *core.ToolboxClient { + return getNewMCPToolboxClient(t, proto) + } + + t.Run("test_run_tool_unauth_with_auth", func(t *testing.T) { + client := newClient(t) + _, err := client.LoadTool("get-row-by-id", context.Background(), + core.WithAuthTokenSource("my-test-auth", staticTokenSource(authToken2)), + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth") + }) + + t.Run("test_run_tool_no_auth", func(t *testing.T) { + client := newClient(t) + tool, err := client.LoadTool("get-row-by-id-auth", context.Background()) + require.NoError(t, err) + + _, err = tool.Invoke(context.Background(), map[string]any{"id": "2"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "permission error: auth service 'my-test-auth' is required") + }) + + t.Run("test_run_tool_wrong_auth", func(t *testing.T) { + client := newClient(t) + tool, err := client.LoadTool("get-row-by-id-auth", context.Background()) + require.NoError(t, err) + + authedTool, err := tool.ToolFrom( + core.WithAuthTokenSource("my-test-auth", staticTokenSource(authToken2)), + ) + require.NoError(t, err) + + _, err = authedTool.Invoke(context.Background(), map[string]any{"id": "2"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "unauthorized Tool call") + }) + + t.Run("test_run_tool_auth", func(t *testing.T) { + client := newClient(t) + tool, err := client.LoadTool("get-row-by-id-auth", context.Background(), + core.WithAuthTokenSource("my-test-auth", staticTokenSource(authToken1)), + ) + require.NoError(t, err) + + response, err := tool.Invoke(context.Background(), map[string]any{"id": "2"}) + require.NoError(t, err) + + respStr, ok := response.(string) + require.True(t, ok) + assert.Contains(t, respStr, "row2") + }) + + t.Run("test_run_tool_param_auth_no_auth", func(t *testing.T) { + client := newClient(t) + tool, err := client.LoadTool("get-row-by-email-auth", context.Background()) + require.NoError(t, err) + + _, err = tool.Invoke(context.Background(), map[string]any{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "permission error: auth service 'my-test-auth' is required") + }) + + t.Run("test_run_tool_param_auth", func(t *testing.T) { + client := newClient(t) + tool, err := client.LoadTool("get-row-by-email-auth", context.Background(), + core.WithAuthTokenSource("my-test-auth", staticTokenSource(authToken1)), + ) + require.NoError(t, err) + + response, err := tool.Invoke(context.Background(), map[string]any{}) + require.NoError(t, err) + + respStr, ok := response.(string) + require.True(t, ok) + assert.Contains(t, respStr, "row4") + assert.Contains(t, respStr, "row5") + assert.Contains(t, respStr, "row6") + }) + + t.Run("test_run_tool_param_auth_no_field", func(t *testing.T) { + client := newClient(t) + tool, err := client.LoadTool("get-row-by-content-auth", context.Background(), + core.WithAuthTokenSource("my-test-auth", staticTokenSource(authToken1)), + ) + require.NoError(t, err) + + _, err = tool.Invoke(context.Background(), map[string]any{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no field named row_data in claims") + }) + + t.Run("test_run_tool_with_failing_token_source", func(t *testing.T) { + client := newClient(t) + tool, err := client.LoadTool("get-row-by-id-auth", context.Background(), + core.WithAuthTokenSource("my-test-auth", &failingTokenSource{}), + ) + require.NoError(t, err) + + _, err = tool.Invoke(context.Background(), map[string]any{"id": "2"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to get token for header my-test-auth_token") + assert.Contains(t, err.Error(), "token source failed as designed") + }) + }) + } +} + +func TestMCP_OptionalParams(t *testing.T) { + for _, proto := range protocolsToTest { + t.Run(proto.name, func(t *testing.T) { + newClient := func(t *testing.T) *core.ToolboxClient { + return getNewMCPToolboxClient(t, proto) + } + searchRowsTool := func(t *testing.T, client *core.ToolboxClient) *core.ToolboxTool { + tool, err := client.LoadTool("search-rows", context.Background()) + require.NoError(t, err, "Failed to load tool 'search-rows'") + return tool + } + + t.Run("test_tool_schema_is_correct", func(t *testing.T) { + client := newClient(t) + tool := searchRowsTool(t, client) + params := tool.Parameters() + + // Convert slice to map for easy lookup + paramMap := make(map[string]core.ParameterSchema) + for _, p := range params { + paramMap[p.Name] = p + } + + emailParam, ok := paramMap["email"] + require.True(t, ok) + assert.True(t, emailParam.Required) + assert.Equal(t, "string", emailParam.Type) + + dataParam, ok := paramMap["data"] + require.True(t, ok) + assert.False(t, dataParam.Required) + assert.Equal(t, "string", dataParam.Type) + + idParam, ok := paramMap["id"] + require.True(t, ok) + assert.False(t, idParam.Required) + assert.Equal(t, "integer", idParam.Type) + }) + + t.Run("test_run_tool_omitting_optionals", func(t *testing.T) { + client := newClient(t) + tool := searchRowsTool(t, client) + + // Test case 1: Optional params are completely omitted + response1, err1 := tool.Invoke(context.Background(), map[string]any{ + "email": "twishabansal@google.com", + }) + require.NoError(t, err1) + respStr1, ok1 := response1.(string) + require.True(t, ok1) + assert.Contains(t, respStr1, `"email":"twishabansal@google.com"`) + assert.Contains(t, respStr1, "row2") + assert.NotContains(t, respStr1, "row3") + + // Test case 2: Optional params are explicitly nil + response2, err2 := tool.Invoke(context.Background(), map[string]any{ + "email": "twishabansal@google.com", + "data": nil, + "id": nil, + }) + require.NoError(t, err2) + respStr2, ok2 := response2.(string) + require.True(t, ok2) + assert.Equal(t, respStr1, respStr2) + }) + + t.Run("test_run_tool_with_all_params_provided", func(t *testing.T) { + client := newClient(t) + tool := searchRowsTool(t, client) + response, err := tool.Invoke(context.Background(), map[string]any{ + "email": "twishabansal@google.com", + "data": "row3", + "id": 3, + }) + require.NoError(t, err) + respStr, ok := response.(string) + require.True(t, ok) + assert.Contains(t, respStr, `"email":"twishabansal@google.com"`) + assert.Contains(t, respStr, `"id":3`) + assert.Contains(t, respStr, "row3") + assert.NotContains(t, respStr, "row2") + }) + + t.Run("test_run_tool_missing_required_param", func(t *testing.T) { + client := newClient(t) + tool := searchRowsTool(t, client) + _, err := tool.Invoke(context.Background(), map[string]any{ + "data": "row5", + "id": 5, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "missing required parameter 'email'") + }) + + t.Run("test_run_tool_required_param_is_nil", func(t *testing.T) { + client := newClient(t) + tool := searchRowsTool(t, client) + _, err := tool.Invoke(context.Background(), map[string]any{ + "email": nil, + "id": 5, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "parameter 'email' is required but received a nil value") + }) + + t.Run("test_run_tool_with_non_matching_data", func(t *testing.T) { + client := newClient(t) + tool := searchRowsTool(t, client) + + // Test with a different email + response, err := tool.Invoke(context.Background(), map[string]any{ + "email": "anubhavdhawan@google.com", + "id": 3, + "data": "row3", + }) + require.NoError(t, err) + assert.Equal(t, "null", response, "Response should be null for non-matching email") + + // Test with different data + response, err = tool.Invoke(context.Background(), map[string]any{ + "email": "twishabansal@google.com", + "id": 3, + "data": "row4", + }) + require.NoError(t, err) + assert.Equal(t, "null", response, "Response should be null for non-matching data") + }) + + t.Run("test_run_tool_wrong_type_for_integer", func(t *testing.T) { + client := newClient(t) + tool := searchRowsTool(t, client) + + _, err := tool.Invoke(context.Background(), map[string]any{ + "email": "twishabansal@google.com", + "id": "not-an-integer", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "parameter 'id' expects an integer, but got string") + }) + }) + } +} + +func TestMCP_MapParams(t *testing.T) { + for _, proto := range protocolsToTest { + t.Run(proto.name, func(t *testing.T) { + newClient := func(t *testing.T) *core.ToolboxClient { + return getNewMCPToolboxClient(t, proto) + } + processDataTool := func(t *testing.T, client *core.ToolboxClient) *core.ToolboxTool { + tool, err := client.LoadTool("process-data", context.Background()) + require.NoError(t, err, "Failed to load tool 'process-data'") + return tool + } + + t.Run("test_tool_schema_is_correct", func(t *testing.T) { + client := newClient(t) + tool := processDataTool(t, client) + params := tool.Parameters() + + paramMap := make(map[string]core.ParameterSchema) + for _, p := range params { + paramMap[p.Name] = p + } + + execCtxParam, ok := paramMap["execution_context"] + require.True(t, ok) + assert.True(t, execCtxParam.Required) + assert.Equal(t, "object", execCtxParam.Type) + + userScoresParam, ok := paramMap["user_scores"] + require.True(t, ok) + assert.True(t, userScoresParam.Required) + assert.Equal(t, "object", userScoresParam.Type) + + featureFlagsParam, ok := paramMap["feature_flags"] + require.True(t, ok) + assert.False(t, featureFlagsParam.Required) + assert.Equal(t, "object", featureFlagsParam.Type) + }) + + t.Run("test_run_tool_with_all_map_params", func(t *testing.T) { + client := newClient(t) + tool := processDataTool(t, client) + + response, err := tool.Invoke(context.Background(), map[string]any{ + "execution_context": map[string]any{ + "env": "prod", + "id": 1234, + "user": 1234.5, + }, + "user_scores": map[string]any{ + "user1": 100, + "user2": 200, + }, + "feature_flags": map[string]any{ + "new_feature": true, + }, + }) + require.NoError(t, err) + respStr, ok := response.(string) + require.True(t, ok) + + assert.Contains(t, respStr, `"execution_context":{"env":"prod","id":1234,"user":1234.5}`) + assert.Contains(t, respStr, `"user_scores":{"user1":100,"user2":200}`) + assert.Contains(t, respStr, `"feature_flags":{"new_feature":true}`) + }) + + t.Run("test_run_tool_omitting_optional_map", func(t *testing.T) { + client := newClient(t) + tool := processDataTool(t, client) + + response, err := tool.Invoke(context.Background(), map[string]any{ + "execution_context": map[string]any{"env": "dev"}, + "user_scores": map[string]any{"user3": 300}, + }) + require.NoError(t, err) + respStr, ok := response.(string) + require.True(t, ok) + + assert.Contains(t, respStr, `"execution_context":{"env":"dev"}`) + assert.Contains(t, respStr, `"user_scores":{"user3":300}`) + assert.Contains(t, respStr, `"feature_flags":null`) + }) + + t.Run("test_run_tool_with_wrong_map_value_type", func(t *testing.T) { + client := newClient(t) + tool := processDataTool(t, client) + + _, err := tool.Invoke(context.Background(), map[string]any{ + "execution_context": map[string]any{"env": "staging"}, + "user_scores": map[string]any{ + "user4": "not-an-integer", + }, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "expects an integer, but got string") + }) + }) + } +} + +func TestMCP_ContextHandling(t *testing.T) { + for _, proto := range protocolsToTest { + t.Run(proto.name, func(t *testing.T) { + newClient := func(t *testing.T) *core.ToolboxClient { + return getNewMCPToolboxClient(t, proto) + } + + t.Run("test_load_tool_with_cancelled_context", func(t *testing.T) { + client := newClient(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := client.LoadTool("get-n-rows", ctx) + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) + }) + + t.Run("test_invoke_tool_with_timed_out_context", func(t *testing.T) { + client := newClient(t) + tool, err := client.LoadTool("get-n-rows", context.Background()) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + time.Sleep(1 * time.Millisecond) + + _, err = tool.Invoke(ctx, map[string]any{"num_rows": "1"}) + require.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) + }) + } +} diff --git a/core/e2e_test.go b/core/e2e_test.go index f82e6d1..245502f 100644 --- a/core/e2e_test.go +++ b/core/e2e_test.go @@ -67,11 +67,11 @@ func TestMain(m *testing.M) { // Download and start the toolbox server cmd := setupAndStartToolboxServer(ctx, toolboxVersion, toolsFilePath) - // --- 2. Run Tests --- + // Run Tests log.Println("Setup complete. Running tests...") exitCode := m.Run() - // --- 3. Teardown Phase --- + // 3. Teardown Phase log.Println("Tearing down toolbox server...") if err := cmd.Process.Kill(); err != nil { log.Printf("Failed to kill toolbox server process: %v", err) @@ -81,10 +81,17 @@ func TestMain(m *testing.M) { os.Exit(exitCode) } +// helper factory to create a client with a specific protocol +func getNewToolboxClient() (*core.ToolboxClient, error) { + client, err := core.NewToolboxClient("http://localhost:5000", + core.WithProtocol(core.Toolbox)) + return client, err +} + func TestE2E_Basic(t *testing.T) { // Helper to create a new client for each sub-test, like a function-scoped fixture newClient := func(t *testing.T) *core.ToolboxClient { - client, err := core.NewToolboxClient("http://localhost:5000") + client, err := getNewToolboxClient() require.NoError(t, err, "Failed to create ToolboxClient") return client } @@ -187,7 +194,7 @@ func TestE2E_Basic(t *testing.T) { func TestE2E_LoadErrors(t *testing.T) { newClient := func(t *testing.T) *core.ToolboxClient { - client, err := core.NewToolboxClient("http://localhost:5000") + client, err := getNewToolboxClient() require.NoError(t, err, "Failed to create ToolboxClient") return client } @@ -222,7 +229,7 @@ func TestE2E_LoadErrors(t *testing.T) { func TestE2E_BindParams(t *testing.T) { newClient := func(t *testing.T) *core.ToolboxClient { - client, err := core.NewToolboxClient("http://localhost:5000") + client, err := getNewToolboxClient() require.NoError(t, err) return client } @@ -274,7 +281,7 @@ func TestE2E_BindParams(t *testing.T) { } func TestE2E_BindParamErrors(t *testing.T) { - client, err := core.NewToolboxClient("http://localhost:5000") + client, err := getNewToolboxClient() require.NoError(t, err) tool, err := client.LoadTool("get-n-rows", context.Background()) require.NoError(t, err) @@ -297,7 +304,7 @@ func TestE2E_BindParamErrors(t *testing.T) { func TestE2E_Auth(t *testing.T) { newClient := func(t *testing.T) *core.ToolboxClient { - client, err := core.NewToolboxClient("http://localhost:5000") + client, err := getNewToolboxClient() require.NoError(t, err) return client } @@ -404,7 +411,7 @@ func TestE2E_Auth(t *testing.T) { _, err = tool.Invoke(context.Background(), map[string]any{"id": "2"}) require.Error(t, err) - assert.Contains(t, err.Error(), "failed to get token for service 'my-test-auth'") + assert.Contains(t, err.Error(), "failed to resolve token for header 'my-test-auth_token'") assert.Contains(t, err.Error(), "token source failed as designed") }) } @@ -412,7 +419,7 @@ func TestE2E_Auth(t *testing.T) { func TestE2E_OptionalParams(t *testing.T) { // Helper to create a new client newClient := func(t *testing.T) *core.ToolboxClient { - client, err := core.NewToolboxClient("http://localhost:5000") + client, err := getNewToolboxClient() require.NoError(t, err, "Failed to create ToolboxClient") return client } @@ -561,7 +568,7 @@ func TestE2E_OptionalParams(t *testing.T) { func TestE2E_MapParams(t *testing.T) { // Helper to create a new client newClient := func(t *testing.T) *core.ToolboxClient { - client, err := core.NewToolboxClient("http://localhost:5000") + client, err := getNewToolboxClient() require.NoError(t, err, "Failed to create ToolboxClient") return client } @@ -669,7 +676,7 @@ func TestE2E_MapParams(t *testing.T) { func TestE2E_ContextHandling(t *testing.T) { newClient := func(t *testing.T) *core.ToolboxClient { - client, err := core.NewToolboxClient("http://localhost:5000") + client, err := getNewToolboxClient() require.NoError(t, err, "Failed to create ToolboxClient") return client } diff --git a/core/options.go b/core/options.go index f6023b6..726f828 100644 --- a/core/options.go +++ b/core/options.go @@ -34,6 +34,18 @@ func newToolConfig() *ToolConfig { } } +// WithProtocol provides a the underlying transport protocol to the ToolboxClient.. +func WithProtocol(p Protocol) ClientOption { + return func(tc *ToolboxClient) error { + if tc.protocolSet { + return fmt.Errorf("protocol is already set and cannot be overridden") + } + tc.protocol = p + tc.protocolSet = true + return nil + } +} + // WithHTTPClient provides a custom http.Client to the ToolboxClient. func WithHTTPClient(client *http.Client) ClientOption { return func(tc *ToolboxClient) error { diff --git a/core/options_test.go b/core/options_test.go index 4c15343..c660d55 100644 --- a/core/options_test.go +++ b/core/options_test.go @@ -70,6 +70,62 @@ func TestWithHTTPClient(t *testing.T) { }) } +func TestWithProtocol(t *testing.T) { + // Verify all protocols can be set individually + tests := []struct { + name string + protocol Protocol + }{ + {"Sets Toolbox Protocol", Toolbox}, + {"Sets MCP v2025-06-18", MCPv20250618}, + {"Sets MCP v2025-03-26", MCPv20250326}, + {"Sets MCP v2024-11-05", MCPv20241105}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := newTestClient() + opt := WithProtocol(tc.protocol) + + err := opt(client) + + if err != nil { + t.Errorf("Expected no error, but got: %v", err) + } + if client.protocol != tc.protocol { + t.Errorf("Expected protocol to be %s, got %s", tc.protocol, client.protocol) + } + if !client.protocolSet { + t.Error("Expected protocolSet flag to be true") + } + }) + } + + // Verify error on duplicate setting + t.Run("Error when setting protocol twice", func(t *testing.T) { + client := newTestClient() + + // First call (Should succeed) + firstOpt := WithProtocol(MCPv20241105) + if err := firstOpt(client); err != nil { + t.Fatalf("Unexpected error on first set: %v", err) + } + + // Second call (Should fail) + secondOpt := WithProtocol(MCPv20250326) + err := secondOpt(client) + + if err == nil { + t.Error("Expected error when setting protocol twice, but got nil") + } + + // Verify the protocol wasn't overwritten + if client.protocol != MCPv20241105 { + t.Errorf("Expected protocol to remain %s, but changed to %s", MCPv20241105, client.protocol) + } + }) +} + func TestWithClientHeaderString(t *testing.T) { t.Run("Success case", func(t *testing.T) { client := newTestClient() diff --git a/core/protocol.go b/core/protocol.go index debfc85..f2ec8aa 100644 --- a/core/protocol.go +++ b/core/protocol.go @@ -14,154 +14,37 @@ package core -import ( - "fmt" - "reflect" -) - -// Schema for a tool parameter. -type ParameterSchema struct { - Name string `json:"name"` - Type string `json:"type"` - Required bool `json:"required,omitempty"` - Description string `json:"description"` - AuthSources []string `json:"authSources,omitempty"` - Items *ParameterSchema `json:"items,omitempty"` - AdditionalProperties any `json:"additionalProperties,omitempty"` -} - -// validateType is a helper for manual type checking. -func (p *ParameterSchema) validateType(value any) error { - if value == nil { - if p.Required { - return fmt.Errorf("parameter '%s' is required but received a nil value", p.Name) - } - return nil - } +import "github.com/googleapis/mcp-toolbox-sdk-go/core/transport" - switch p.Type { - case "string": - if _, ok := value.(string); !ok { - return fmt.Errorf("parameter '%s' expects a string, but got %T", p.Name, value) - } - case "integer": - switch value.(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - default: - return fmt.Errorf("parameter '%s' expects an integer, but got %T", p.Name, value) - } - case "float": - switch value.(type) { - case float32, float64: - default: - return fmt.Errorf("parameter '%s' expects an float, but got %T", p.Name, value) - } - case "boolean": - if _, ok := value.(bool); !ok { - return fmt.Errorf("parameter '%s' expects a boolean, but got %T", p.Name, value) - } - case "array": - v := reflect.ValueOf(value) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return fmt.Errorf("parameter '%s' expects an array/slice, but got %T", p.Name, value) - } - for i := range v.Len() { - item := v.Index(i).Interface() +// Protocol defines underlying transport protocols. +type Protocol string - if err := p.Items.validateType(item); err != nil { - return fmt.Errorf("error in array '%s' at index %d: %w", p.Name, i, err) - } - } - case "object": - // First, check that the value is a map with string keys. - valMap, ok := value.(map[string]any) - if !ok { - return fmt.Errorf("parameter '%s' expects a map, but got %T", p.Name, value) - } +const ( + // Toolbox represents the Native Toolbox protocol. + Toolbox Protocol = "toolbox" - switch ap := p.AdditionalProperties.(type) { - // No validation required, allows any type - case bool: + // MCP Version Constants + MCPv20250618 Protocol = "2025-06-18" + MCPv20250326 Protocol = "2025-03-26" + MCPv20241105 Protocol = "2024-11-05" - // Validate type for each value in map - case *ParameterSchema: - // Raise error if the input is a nested map / array - if ap.Type == "object" || ap.Type == "array" { - return fmt.Errorf("invalid schema for object '%s': values cannot be of type '%s'", p.Name, ap.Type) - } - for key, val := range valMap { - if err := ap.validateType(val); err != nil { - return fmt.Errorf("error in object '%s' for key '%s': %w", p.Name, key, err) - } - } + // MCP is the default alias pointing to the newest supported version. + MCP = MCPv20250618 +) - default: - // This is a schema / manifest error. - return fmt.Errorf( - "invalid schema for parameter '%s': AdditionalProperties must be a boolean or a map[string]any, but got %T", - p.Name, - ap, - ) - } - default: - return fmt.Errorf("unknown type '%s' in schema for parameter '%s'", p.Type, p.Name) +// GetSupportedMcpVersions returns a list of supported MCP protocol versions. +func GetSupportedMcpVersions() []string { + return []string{ + string(MCPv20250618), + string(MCPv20250326), + string(MCPv20241105), } - return nil } -// ValidateDefinition checks if the schema itself is well-formed. -func (p *ParameterSchema) ValidateDefinition() error { - if p.Type == "" { - return fmt.Errorf("schema validation failed for '%s': type is missing", p.Name) - } - - switch p.Type { - case "array": - if p.Items == nil { - return fmt.Errorf("parameter '%s' is an array but is missing item type definition", p.Name) - } - // Recursively validate the nested schema's definition. - if err := p.Items.ValidateDefinition(); err != nil { - return err - } +type ManifestSchema = transport.ManifestSchema - case "object": - switch ap := p.AdditionalProperties.(type) { - case bool: - // Valid scenario - case *ParameterSchema: - if err := ap.ValidateDefinition(); err != nil { - return err - } - default: - // Any other type is an invalid schema definition. - return fmt.Errorf( - "invalid schema for parameter '%s': AdditionalProperties must be a boolean or a schema, but got %T", - p.Name, - ap, - ) - } +// ToolSchema defines a single tool in the manifest. +type ToolSchema = transport.ToolSchema - case "string", "integer", "float", "boolean": - // No type-specific rules for these. - break - - default: - return fmt.Errorf("unknown schema type '%s' for parameter '%s'", p.Type, p.Name) - } - - return nil -} - -// Schema for a tool. -type ToolSchema struct { - Description string `json:"description"` - Parameters []ParameterSchema `json:"parameters"` - AuthRequired []string `json:"authRequired,omitempty"` -} - -// Schema for the Toolbox manifest. -type ManifestSchema struct { - ServerVersion string `json:"serverVersion"` - Tools map[string]ToolSchema `json:"tools"` -} +// ParameterSchema defines the structure and validation logic for tool parameters. +type ParameterSchema = transport.ParameterSchema diff --git a/core/protocol_test.go b/core/protocol_test.go index 9d8bc6b..a3e2756 100644 --- a/core/protocol_test.go +++ b/core/protocol_test.go @@ -6,7 +6,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -16,566 +16,26 @@ package core -import ( - "fmt" - "strings" - "testing" - "time" -) +import "testing" -// Tests ParameterSchema with type 'int'. -func TestParameterSchemaInteger(t *testing.T) { +func TestGetSupportedMcpVersions(t *testing.T) { + versions := GetSupportedMcpVersions() - schema := ParameterSchema{ - Name: "param_name", - Type: "integer", - Description: "integer parameter", + // Verify we get exactly 3 versions + if len(versions) != 3 { + t.Errorf("Expected 3 supported versions, got %d", len(versions)) } - t.Run("Test int param", func(t *testing.T) { - value := 1 - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - }) - t.Run("Test int8 param", func(t *testing.T) { - var value int8 = 1 - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - }) - t.Run("Test int16 param", func(t *testing.T) { - var value int16 = 1 - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - }) - t.Run("Test int32 param", func(t *testing.T) { - var value int32 = 1 - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - }) - t.Run("Test int64 param", func(t *testing.T) { - var value int64 = 1 - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - }) - t.Run("Test uint param", func(t *testing.T) { - var value uint = 1 - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - }) - t.Run("Test uint8 param", func(t *testing.T) { - var value uint8 = 1 - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - }) - t.Run("Test uint16 param", func(t *testing.T) { - var value uint16 = 1 - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - }) - t.Run("Test uint32 param", func(t *testing.T) { - var value uint32 = 1 - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - }) - t.Run("Test uint64 param", func(t *testing.T) { - var value uint64 = 1 - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - }) - -} - -// Tests ParameterSchema with type 'string'. -func TestParameterSchemaString(t *testing.T) { - - schema := ParameterSchema{ - Name: "param_name", - Type: "string", - Description: "string parameter", - } - - value := "abc" - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - -} - -// Tests ParameterSchema with type 'boolean'. -func TestParameterSchemaBoolean(t *testing.T) { - - schema := ParameterSchema{ - Name: "param_name", - Type: "boolean", - Description: "boolean parameter", - } - - value := true - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - -} - -// Tests ParameterSchema with type 'float'. -func TestParameterSchemaFloat(t *testing.T) { - - schema := ParameterSchema{ - Name: "param_name", - Type: "float", - Description: "float parameter", - } - - t.Run("Test float32 param", func(t *testing.T) { - var value float32 = 3.14 - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - }) - t.Run("Test float64 param", func(t *testing.T) { - value := 3.14 - - err := schema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - }) - -} - -// Tests ParameterSchema with type 'array'. -func TestParameterSchemaStringArray(t *testing.T) { - - itemSchema := ParameterSchema{ - Name: "item", - Type: "string", - Description: "item of the array", - } - - paramSchema := ParameterSchema{ - Name: "param_name", - Type: "array", - Description: "array parameter", - Items: &itemSchema, - } - - value := []string{"abc", "def"} - - err := paramSchema.validateType(value) - - if err != nil { - t.Fatal(err.Error()) - } - -} - -// Tests ParameterSchema with an undefined type. -func TestParameterSchemaUndefinedType(t *testing.T) { - - paramSchema := ParameterSchema{ - Name: "param_name", - Type: "time", - Description: "time parameter", - } - - value := time.Now() - - err := paramSchema.validateType(value) - - if err == nil { - t.Fatal("Expected an error, but got nil") - } - -} - -func TestOptionalStringParameter(t *testing.T) { - schema := ParameterSchema{ - Name: "nickname", - Type: "string", - Description: "An optional nickname", - Required: false, // Explicitly optional - } - - t.Run("allows nil value for optional parameter", func(t *testing.T) { - err := schema.validateType(nil) - if err != nil { - t.Errorf("validateType() with nil should not return an error for an optional parameter, but got: %v", err) - } - }) - - t.Run("allows valid string value", func(t *testing.T) { - err := schema.validateType("my-name") - if err != nil { - t.Errorf("validateType() should not return an error for a valid string, but got: %v", err) - } - }) -} - -func TestRequiredParameter(t *testing.T) { - schema := ParameterSchema{ - Name: "id", - Type: "integer", - Description: "A required ID", - Required: true, // Explicitly required + // Verify the content matches our constants + expected := []string{ + string(MCPv20250618), + string(MCPv20250326), + string(MCPv20241105), } - t.Run("rejects nil value for required parameter", func(t *testing.T) { - err := schema.validateType(nil) - if err == nil { - t.Errorf("validateType() with nil should return an error for a required parameter, but it didn't") + for i, v := range versions { + if v != expected[i] { + t.Errorf("Index %d: expected version %s, got %s", i, expected[i], v) } - }) - - t.Run("allows valid integer value", func(t *testing.T) { - err := schema.validateType(12345) - if err != nil { - t.Errorf("validateType() should not return an error for a valid integer, but got: %v", err) - } - }) -} - -func TestOptionalArrayParameter(t *testing.T) { - schema := ParameterSchema{ - Name: "optional_scores", - Type: "array", - Description: "An optional list of scores", - Required: false, - Items: &ParameterSchema{ - Type: "integer", - }, } - - t.Run("allows nil value for optional array", func(t *testing.T) { - err := schema.validateType(nil) - if err != nil { - t.Errorf("validateType() with nil should not return an error for an optional array, but got: %v", err) - } - }) - - t.Run("allows valid integer slice", func(t *testing.T) { - err := schema.validateType([]int{95, 100}) - if err != nil { - t.Errorf("validateType() should not return an error for a valid slice, but got: %v", err) - } - }) - - t.Run("rejects slice with wrong item type", func(t *testing.T) { - err := schema.validateType([]string{"not", "an", "int"}) - if err == nil { - t.Errorf("validateType() should have returned an error for a slice with incorrect item types, but it didn't") - } - }) -} - -func TestValidateTypeObject(t *testing.T) { - t.Run("generic object allows any value types", func(t *testing.T) { - schema := ParameterSchema{ - Name: "metadata", - Type: "object", - AdditionalProperties: true, // or nil - } - - // A map with mixed value types should be valid. - validInput := map[string]any{ - "key_string": "a string", - "key_int": 123, - "key_bool": true, - "key_map": map[string]any{"id": 1}, - "key_array": []string{"id", "number"}, - } - if err := schema.validateType(validInput); err != nil { - t.Errorf("Expected no error for generic object, but got: %v", err) - } - - // A value that is not a map should be invalid. - invalidInput := "I am a string, not an object" - if err := schema.validateType(invalidInput); err == nil { - t.Errorf("Expected an error for non-map input, but got nil") - } - }) - - t.Run("typed object validation", func(t *testing.T) { - testCases := []struct { - name string - valueType string - validInput map[string]any - invalidInput map[string]any - }{ - { - name: "string values", - valueType: "string", - validInput: map[string]any{"header": "application/json"}, - invalidInput: map[string]any{"bad_header": 123}, - }, - { - name: "integer values", - valueType: "integer", - validInput: map[string]any{"user_score": 100}, - invalidInput: map[string]any{"bad_score": "100"}, - }, - { - name: "float values", - valueType: "float", - validInput: map[string]any{"item_price": 99.99}, - invalidInput: map[string]any{"bad_price": 99}, - }, - { - name: "boolean values", - valueType: "boolean", - validInput: map[string]any{"feature_flag": true}, - invalidInput: map[string]any{"bad_flag": "true"}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - schema := ParameterSchema{ - Name: "test_map", - Type: "object", - AdditionalProperties: &ParameterSchema{Type: tc.valueType}, - } - - // Test that valid input passes - if err := schema.validateType(tc.validInput); err != nil { - t.Errorf("Expected no error for valid input, got: %v", err) - } - - // Test that invalid input fails - if err := schema.validateType(tc.invalidInput); err == nil { - t.Errorf("Expected an error for invalid input, but got nil") - } - }) - } - }) - - t.Run("Fail for object valueType maps", func(t *testing.T) { - - // This schema itself is invalid so there is no valid test case - schema := ParameterSchema{ - Name: "test_map", - Type: "object", - AdditionalProperties: &ParameterSchema{Type: "object"}, - } - - invalidInput := map[string]any{"feature_flag": map[string]any{"id": "123"},} - // Test that invalid input fails - if err := schema.validateType(invalidInput); err == nil { - t.Errorf("Expected an error for invalid input, but got nil") - } - }) - - t.Run("Fail for array valueType maps", func(t *testing.T) { - - // This schema itself is invalid so there is no valid test case - schema := ParameterSchema{ - Name: "test_map", - Type: "object", - AdditionalProperties: &ParameterSchema{Type: "array"}, - } - - invalidInput := map[string]any{"feature_flag": []string{"id", "number"},} - // Test that invalid input fails - if err := schema.validateType(invalidInput); err == nil { - t.Errorf("Expected an error for invalid input, but got nil") - } - }) - - t.Run("optional and required objects", func(t *testing.T) { - // An optional object can be nil - optionalSchema := ParameterSchema{Name: "optional_metadata", Type: "object", Required: false} - if err := optionalSchema.validateType(nil); err != nil { - t.Errorf("Expected no error for nil on optional object, but got: %v", err) - } - - // A required object cannot be nil - requiredSchema := ParameterSchema{Name: "required_metadata", Type: "object", Required: true} - if err := requiredSchema.validateType(nil); err == nil { - t.Error("Expected an error for nil on required object, but got nil") - } - }) - - t.Run("object with unsupported value type in schema", func(t *testing.T) { - unsupportedType := "custom_object" - schema := ParameterSchema{ - Name: "custom_data", - Type: "object", - AdditionalProperties: &ParameterSchema{Type: unsupportedType}, - } - - input := map[string]any{"key": "some value"} - err := schema.validateType(input) - - if err == nil { - t.Fatal("Expected an error for unsupported sub-schema type, but got nil") - } - - // Check if the error message contains the expected text. - expectedErrorMsg := fmt.Sprintf("unknown type '%s'", unsupportedType) - if !strings.Contains(err.Error(), expectedErrorMsg) { - t.Errorf("Expected error to contain '%s', but got '%v'", expectedErrorMsg, err) - } - }) -} - -func TestParameterSchema_ValidateDefinition(t *testing.T) { - t.Run("should succeed for simple valid types", func(t *testing.T) { - testCases := []struct { - name string - schema *ParameterSchema - }{ - {"String", &ParameterSchema{Name: "p_string", Type: "string"}}, - {"Integer", &ParameterSchema{Name: "p_int", Type: "integer"}}, - {"Float", &ParameterSchema{Name: "p_float", Type: "float"}}, - {"Boolean", &ParameterSchema{Name: "p_bool", Type: "boolean"}}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if err := tc.schema.ValidateDefinition(); err != nil { - t.Errorf("expected no error, but got: %v", err) - } - }) - } - }) - - t.Run("should succeed for a valid array schema", func(t *testing.T) { - schema := &ParameterSchema{ - Name: "p_array", - Type: "array", - Items: &ParameterSchema{Type: "string"}, - } - if err := schema.ValidateDefinition(); err != nil { - t.Errorf("expected no error for valid array, but got: %v", err) - } - }) - - t.Run("should succeed for valid object schemas", func(t *testing.T) { - testCases := []struct { - name string - schema *ParameterSchema - }{ - { - "Typed Object", - &ParameterSchema{ - Name: "p_obj_typed", - Type: "object", - AdditionalProperties: &ParameterSchema{Type: "integer"}, - }, - }, - { - "Generic Object (bool)", - &ParameterSchema{ - Name: "p_obj_bool", - Type: "object", - AdditionalProperties: true, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if err := tc.schema.ValidateDefinition(); err != nil { - t.Errorf("expected no error, but got: %v", err) - } - }) - } - }) - - t.Run("should fail when type is missing", func(t *testing.T) { - schema := &ParameterSchema{Name: "p_missing_type", Type: "object", AdditionalProperties: &ParameterSchema{Type: ""}} - err := schema.ValidateDefinition() - if err == nil { - t.Fatal("expected an error for missing type, but got nil") - } - if !strings.Contains(err.Error(), "type is missing") { - t.Errorf("error message should mention 'type is missing', but was: %s", err) - } - }) - - t.Run("should fail when type is unknown", func(t *testing.T) { - schema := &ParameterSchema{Name: "p_unknown", Type: "object", AdditionalProperties: &ParameterSchema{Type: "some-custom-type"}} - err := schema.ValidateDefinition() - if err == nil { - t.Fatal("expected an error for unknown type, but got nil") - } - if !strings.Contains(err.Error(), "unknown schema type") { - t.Errorf("error message should mention 'unknown schema type', but was: %s", err) - } - }) - - t.Run("should fail for array with missing items property", func(t *testing.T) { - schema := &ParameterSchema{Name: "p_bad_array", Type: "array", Items: nil} - err := schema.ValidateDefinition() - if err == nil { - t.Fatal("expected an error for array with nil items, but got nil") - } - if !strings.Contains(err.Error(), "missing item type definition") { - t.Errorf("error message should mention 'missing item type definition', but was: %s", err) - } - }) - - t.Run("should fail for object with invalid AdditionalProperties type", func(t *testing.T) { - schema := &ParameterSchema{ - Name: "p_bad_object", - Type: "object", - AdditionalProperties: "a-string-is-not-valid", - } - err := schema.ValidateDefinition() - if err == nil { - t.Fatal("expected an error for invalid AdditionalProperties, but got nil") - } - if !strings.Contains(err.Error(), "must be a boolean or a schema") { - t.Errorf("error message should mention 'must be a boolean or a schema', but was: %s", err) - } - }) } diff --git a/core/tool.go b/core/tool.go index 63cd668..4c70917 100644 --- a/core/tool.go +++ b/core/tool.go @@ -15,17 +15,15 @@ package core import ( - "bytes" "context" "encoding/json" "fmt" - "io" - "net/http" "reflect" "strings" "maps" + "github.com/googleapis/mcp-toolbox-sdk-go/core/transport" "golang.org/x/oauth2" ) @@ -34,8 +32,7 @@ type ToolboxTool struct { name string description string parameters []ParameterSchema - invocationURL string - httpClient *http.Client + transport transport.Transport authTokenSources map[string]oauth2.TokenSource boundParams map[string]any requiredAuthnParams map[string][]string @@ -43,8 +40,6 @@ type ToolboxTool struct { clientHeaderSources map[string]oauth2.TokenSource } -const toolInvokeSuffix = "/invoke" - // Name returns the tool's name. func (tt *ToolboxTool) Name() string { return tt.name @@ -189,8 +184,7 @@ func (tt *ToolboxTool) cloneToolboxTool() *ToolboxTool { newTt := &ToolboxTool{ name: tt.name, description: tt.description, - invocationURL: tt.invocationURL, - httpClient: tt.httpClient, + transport: tt.transport, parameters: make([]ParameterSchema, len(tt.parameters)), authTokenSources: make(map[string]oauth2.TokenSource, len(tt.authTokenSources)), boundParams: make(map[string]any, len(tt.boundParams)), @@ -244,11 +238,8 @@ func (tt *ToolboxTool) cloneToolboxTool() *ToolboxTool { // 'result' field) or a raw string. Returns an error if any step of the // process fails. func (tt *ToolboxTool) Invoke(ctx context.Context, input map[string]any) (any, error) { - if tt.httpClient == nil { - return nil, fmt.Errorf("http client is not set for toolbox tool '%s'", tt.name) - } - // Before proceeding, ensure all authentication tokens required by the tool are available. + // Ensure all authentication tokens required by the tool are available. if len(tt.requiredAuthnParams) > 0 || len(tt.requiredAuthzTokens) > 0 { reqAuthServices := make(map[string]struct{}) for _, services := range tt.requiredAuthnParams { @@ -274,67 +265,34 @@ func (tt *ToolboxTool) Invoke(ctx context.Context, input map[string]any) (any, e return nil, fmt.Errorf("tool payload processing failed: %w", err) } - payloadBytes, err := json.Marshal(finalPayload) - if err != nil { - return nil, fmt.Errorf("failed to marshal tool payload for API call: %w", err) - } + resolvedHeaders := make(map[string]string) - // Assemble the API request - req, err := http.NewRequestWithContext(ctx, "POST", tt.invocationURL, bytes.NewBuffer(payloadBytes)) - if err != nil { - return nil, fmt.Errorf("failed to create API request for tool '%s': %w", tt.name, err) - } - req.Header.Set("Content-Type", "application/json") - - // Apply client-wide headers. - for name, source := range tt.clientHeaderSources { - token, tokenErr := source.Token() - if tokenErr != nil { - return nil, fmt.Errorf("failed to resolve client header '%s': %w", name, tokenErr) - } - req.Header.Set(name, token.AccessToken) - } - // Apply tool-specific authentication headers. - for authService, source := range tt.authTokenSources { - token, tokenErr := source.Token() - if tokenErr != nil { - return nil, fmt.Errorf("failed to get token for service '%s' for tool '%s': %w", authService, tt.name, tokenErr) + // Resolve Client Headers + for k, source := range tt.clientHeaderSources { + token, err := source.Token() + if err != nil { + return nil, fmt.Errorf("failed to resolve client header %s: %w", k, err) } - headerName := fmt.Sprintf("%s_token", authService) - req.Header.Set(headerName, token.AccessToken) + resolvedHeaders[k] = token.AccessToken } - // API call execution - resp, err := tt.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("API call to tool '%s' failed: %w", tt.name, err) + // Resolve Auth Headers + for name, source := range tt.authTokenSources { + token, err := source.Token() + if err != nil { + return nil, fmt.Errorf("failed to resolve auth token %s: %w", name, err) + } + // Toolbox HTTP protocol expects the suffix "_token" + headerName := fmt.Sprintf("%s_token", name) + resolvedHeaders[headerName] = token.AccessToken } - defer resp.Body.Close() - responseBody, err := io.ReadAll(resp.Body) + response, err := tt.transport.InvokeTool(ctx, tt.name, finalPayload, resolvedHeaders) if err != nil { - return nil, fmt.Errorf("failed to read API response body for tool '%s': %w", tt.name, err) - } - - // Handle non-successful status codes - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - var errorResponse map[string]any - if jsonErr := json.Unmarshal(responseBody, &errorResponse); jsonErr == nil { - if errMsg, ok := errorResponse["error"].(string); ok { - return nil, fmt.Errorf("tool '%s' API returned error status %d: %s", tt.name, resp.StatusCode, errMsg) - } - } - return nil, fmt.Errorf("tool '%s' API returned unexpected status: %d %s, body: %s", tt.name, resp.StatusCode, resp.Status, string(responseBody)) + return nil, err } - // For successful responses, attempt to extract the 'result' field. - var apiResult map[string]any - if err := json.Unmarshal(responseBody, &apiResult); err == nil { - if result, ok := apiResult["result"]; ok { - return result, nil - } - } - return string(responseBody), nil + return response, nil } // validateAndBuildPayload performs manual type validation and applies bound parameters. @@ -366,7 +324,7 @@ func (tt *ToolboxTool) validateAndBuildPayload(input map[string]any) (map[string // If the parameter is a valid unbound parameter, validate its type. if isUnbound { - if err := param.validateType(value); err != nil { + if err := param.ValidateType(value); err != nil { return nil, err } } diff --git a/core/tool_test.go b/core/tool_test.go index dc9ebbb..b222ec0 100644 --- a/core/tool_test.go +++ b/core/tool_test.go @@ -27,9 +27,25 @@ import ( "strings" "testing" + "github.com/googleapis/mcp-toolbox-sdk-go/core/transport" + "github.com/googleapis/mcp-toolbox-sdk-go/core/transport/toolboxtransport" "golang.org/x/oauth2" ) +// Dummy transport for tests +type dummyTransport struct{} + +func (d *dummyTransport) BaseURL() string { return "" } +func (d *dummyTransport) GetTool(ctx context.Context, name string, h map[string]string) (*transport.ManifestSchema, error) { + return nil, nil +} +func (d *dummyTransport) ListTools(ctx context.Context, set string, h map[string]string) (*transport.ManifestSchema, error) { + return nil, nil +} +func (d *dummyTransport) InvokeTool(ctx context.Context, name string, p map[string]any, h map[string]string) (any, error) { + return nil, nil +} + func TestToolboxTool_Getters(t *testing.T) { sampleParams := []ParameterSchema{ {Name: "param_one", Type: "string"}, @@ -40,6 +56,7 @@ func TestToolboxTool_Getters(t *testing.T) { name: "my-test-tool", description: "A tool specifically for testing purposes.", parameters: sampleParams, + transport: &dummyTransport{}, } t.Run("Name Method Returns Correct Value", func(t *testing.T) { @@ -78,6 +95,7 @@ func TestToolboxTool_Getters(t *testing.T) { t.Run("Handles Case With No Parameters", func(t *testing.T) { emptyTool := &ToolboxTool{ parameters: []ParameterSchema{}, + transport: &dummyTransport{}, } params := emptyTool.Parameters() @@ -160,6 +178,7 @@ func TestToolFrom(t *testing.T) { authTokenSources: map[string]oauth2.TokenSource{ "google": &mockTokenSource{}, // Auth source already set on parent }, + transport: &dummyTransport{}, } getTestTool := func() *ToolboxTool { @@ -241,9 +260,11 @@ func TestToolFrom(t *testing.T) { func TestCloneToolboxTool(t *testing.T) { // 1. Setup an original tool with populated maps and slices to test deep copying. + originalTransport := &dummyTransport{} originalTool := &ToolboxTool{ name: "original_tool", description: "An original tool to be cloned.", + transport: originalTransport, parameters: []ParameterSchema{ {Name: "p1", Type: "string"}, }, @@ -272,6 +293,10 @@ func TestCloneToolboxTool(t *testing.T) { t.Fatal("Initial clone is not deeply equal to the original") } + if clone.transport != originalTool.transport { + t.Error("Clone should share the same transport reference") + } + t.Run("Negative Test - modifying clone's boundParams map", func(t *testing.T) { clone.boundParams["b2"] = "newValue" delete(clone.boundParams, "b1") @@ -469,12 +494,13 @@ func (ft *failingTransport) RoundTrip(req *http.Request) (*http.Response, error) func TestToolboxTool_Invoke(t *testing.T) { // A base tool for successful invocations - createBaseTool := func(httpClient *http.Client, invocationURL string) *ToolboxTool { + createBaseTool := func(httpClient *http.Client, baseURL string) *ToolboxTool { + tr := toolboxtransport.New(baseURL, httpClient) + return &ToolboxTool{ - name: "weather", - description: "Get the weather", - invocationURL: invocationURL, - httpClient: httpClient, + name: "weather", + description: "Get the weather", + transport: tr, parameters: []ParameterSchema{ {Name: "city", Type: "string"}, }, @@ -717,7 +743,7 @@ func TestToolboxTool_Invoke(t *testing.T) { if err == nil { t.Fatal("Expected an error from a failed API call, but got nil") } - if !strings.Contains(err.Error(), "API call to tool 'weather' failed") { + if !strings.Contains(err.Error(), "HTTP call to tool 'weather' failed") { t.Errorf("Incorrect error message for failed API call. Got: %v", err) } }) @@ -733,7 +759,7 @@ func TestToolboxTool_Invoke(t *testing.T) { if err == nil { t.Fatal("Expected an error from a failed API call, but got nil") } - if !strings.Contains(err.Error(), "API call to tool 'weather' failed") { + if !strings.Contains(err.Error(), "HTTP call to tool 'weather' failed") { t.Errorf("Incorrect error message for failed API call. Got: %v", err) } }) @@ -751,7 +777,7 @@ func TestToolboxTool_Invoke(t *testing.T) { if err == nil { t.Fatal("Expected an error from a failing response body read, but got nil") } - if !strings.Contains(err.Error(), "failed to read API response body") { + if !strings.Contains(err.Error(), "failed to read response body") { t.Errorf("Incorrect error message for failed body read. Got: %v", err) } }) diff --git a/core/utils.go b/core/utils.go index c0d9bf2..fecdae3 100644 --- a/core/utils.go +++ b/core/utils.go @@ -15,11 +15,8 @@ package core import ( - "context" "encoding/json" "fmt" - "io" - "net/http" "golang.org/x/oauth2" ) @@ -147,86 +144,17 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) { }, nil } -// resolveAndApplyHeaders iterates through a map of token sources, retrieves a -// token from each, and applies it as a header to the given HTTP request. -// -// Inputs: -// - clientHeaderSources: A map where the key is the HTTP header name and the -// value is the TokenSource that provides the header's value. -// - req: The HTTP request to which the headers will be added. This request is -// modified in place. -// -// Returns: -// -// An error if retrieving a token from any source fails, otherwise nil. -func resolveAndApplyHeaders( - clientHeaderSources map[string]oauth2.TokenSource, - req *http.Request, -) error { - for name, source := range clientHeaderSources { - // Retrieve the token +// Helper to resolve client-level headers +func resolveClientHeaders(clientHeaderSources map[string]oauth2.TokenSource) (map[string]string, error) { + resolved := make(map[string]string) + for k, source := range clientHeaderSources { token, err := source.Token() if err != nil { - return fmt.Errorf("failed to resolve header '%s': %w", name, err) + return nil, fmt.Errorf("failed to resolve client header '%s': %w", k, err) } - // Set the header on the request object. - req.Header.Set(name, token.AccessToken) - } - return nil -} - -// loadManifest is an internal helper for fetching manifests from the Toolbox server. -// Inputs: -// - ctx: The context to control the lifecycle of the HTTP request, including -// cancellation. -// - url: The specific URL from which to fetch the manifest. -// - httpClient: The http.Client used to execute the request. -// - clientHeaderSources: A map of token sources to be resolved and applied as -// headers to the request. -// -// Returns: -// -// A pointer to the successfully parsed ManifestSchema and a nil error, or a -// nil ManifestSchema and a descriptive error if any part of the process fails. -func loadManifest(ctx context.Context, url string, httpClient *http.Client, - clientHeaderSources map[string]oauth2.TokenSource) (*ManifestSchema, error) { - // Create a new GET request with a context for cancellation. - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP request to %s: %w", url, err) - } - - // Add all client-level headers to the request - if err := resolveAndApplyHeaders(clientHeaderSources, req); err != nil { - return nil, fmt.Errorf("failed to apply client headers: %w", err) - } - - // Execute the HTTP request. - resp, err := httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to make HTTP request to %s: %w", url, err) - } - defer resp.Body.Close() - - // Check for non-successful status codes and include the response body - // for better debugging. - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("server returned non-OK status: %d %s, body: %s", resp.StatusCode, resp.Status, string(bodyBytes)) - } - - // Read the response body. - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - - // Unmarshal the JSON body into the ManifestSchema struct. - var manifest ManifestSchema - if err = json.Unmarshal(body, &manifest); err != nil { - return nil, fmt.Errorf("unable to parse manifest correctly: %w", err) + resolved[k] = token.AccessToken } - return &manifest, nil + return resolved, nil } // schemaToMap recursively converts a ParameterSchema to a map with it's type and description. diff --git a/core/utils_test.go b/core/utils_test.go index 6bb38fb..5371dbd 100644 --- a/core/utils_test.go +++ b/core/utils_test.go @@ -17,17 +17,13 @@ package core import ( - "context" - "encoding/json" "errors" - "net/http" - "net/http/httptest" "reflect" "sort" - "strings" "testing" - "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/oauth2" ) @@ -128,170 +124,64 @@ func TestIdentifyAuthRequirements(t *testing.T) { }) } -func TestResolveAndApplyHeaders(t *testing.T) { - t.Run("Successfully applies headers", func(t *testing.T) { - // Setup - client, _ := NewToolboxClient("test-url") - client.clientHeaderSources["Authorization"] = &mockTokenSource{token: &oauth2.Token{AccessToken: "token123"}} - client.clientHeaderSources["X-Api-Key"] = &mockTokenSource{token: &oauth2.Token{AccessToken: "key456"}} - - req, _ := http.NewRequest("GET", "https://toolbox.example.com", nil) - - // Action - err := resolveAndApplyHeaders(client.clientHeaderSources, req) - - // Assert - if err != nil { - t.Fatalf("Expected no error, but got: %v", err) - } - if auth := req.Header.Get("Authorization"); auth != "token123" { - t.Errorf("Expected Authorization header 'token123', got %q", auth) - } - if key := req.Header.Get("X-Api-Key"); key != "key456" { - t.Errorf("Expected X-Api-Key header 'key456', got %q", key) - } - }) - - t.Run("Returns error when a token source fails", func(t *testing.T) { - client, _ := NewToolboxClient("test-url") - client.clientHeaderSources["Authorization"] = &failingTokenSource{} - - req, _ := http.NewRequest("GET", "https://toolbox.example.com", nil) - - err := resolveAndApplyHeaders(client.clientHeaderSources, req) - - if err == nil { - t.Fatal("Expected an error, but got nil") - } - if !strings.Contains(err.Error(), "failed to resolve header 'Authorization'") { - t.Errorf("Error message missing expected text. Got: %s", err.Error()) - } - if !strings.Contains(err.Error(), "token source failed as designed") { - t.Errorf("Error message did not wrap the underlying error. Got: %s", err.Error()) - } - }) +// mockTokenSource is a helper to simulate token generation behavior. +type mockTokenSource struct { + token *oauth2.Token + err error } -func TestLoadManifest(t *testing.T) { - validManifest := ManifestSchema{ - ServerVersion: "v1", - Tools: map[string]ToolSchema{ - "toolA": {Description: "Does a thing"}, - }, +func (m *mockTokenSource) Token() (*oauth2.Token, error) { + if m.err != nil { + return nil, m.err } - validManifestJSON, _ := json.Marshal(validManifest) - - t.Run("Successfully loads and unmarshals manifest", func(t *testing.T) { - // Setup mock server - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") != "Bearer test-token" { - t.Errorf("Server did not receive expected Authorization header") - w.WriteHeader(http.StatusUnauthorized) - return - } - w.WriteHeader(http.StatusOK) - if _, err := w.Write(validManifestJSON); err != nil { - t.Fatalf("Mock server failed to write response: %v", err) - } - })) - defer server.Close() - - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) - client.clientHeaderSources["Authorization"] = oauth2.StaticTokenSource(&oauth2.Token{ - AccessToken: "Bearer test-token", - }) - - manifest, err := loadManifest(context.Background(), server.URL, client.httpClient, client.clientHeaderSources) + return m.token, nil +} - if err != nil { - t.Fatalf("Expected no error, but got: %v", err) - } - if !reflect.DeepEqual(*manifest, validManifest) { - t.Errorf("Returned manifest does not match expected value") +func TestResolveClientHeaders(t *testing.T) { + t.Run("Success_MultipleHeaders", func(t *testing.T) { + // Setup input map directly + sources := map[string]oauth2.TokenSource{ + "Authorization": &mockTokenSource{token: &oauth2.Token{AccessToken: "bearer-token"}}, + "X-Custom-Header": &mockTokenSource{token: &oauth2.Token{AccessToken: "custom-value"}}, } - }) - - t.Run("Fails when header resolution fails", func(t *testing.T) { - // Setup client with a failing token source - client, _ := NewToolboxClient("any-url") - client.clientHeaderSources["Authorization"] = &failingTokenSource{} // Use the failing mock - // Action - _, err := loadManifest(context.Background(), "http://example.com", client.httpClient, client.clientHeaderSources) + // Execute function directly + headers, err := resolveClientHeaders(sources) - // Assert - if err == nil { - t.Fatal("Expected an error due to header resolution failure, but got nil") - } - if !strings.Contains(err.Error(), "failed to apply client headers") { - t.Errorf("Error message missing expected text. Got: %s", err.Error()) - } + // Verify + require.NoError(t, err) + assert.Len(t, headers, 2) + assert.Equal(t, "bearer-token", headers["Authorization"]) + assert.Equal(t, "custom-value", headers["X-Custom-Header"]) }) - t.Run("Fails when server returns non-200 status", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - if _, err := w.Write([]byte("internal server error")); err != nil { - t.Fatalf("Mock server failed to write response: %v", err) - } - })) - defer server.Close() - - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) + t.Run("Success_Empty", func(t *testing.T) { + sources := make(map[string]oauth2.TokenSource) - _, err := loadManifest(context.Background(), server.URL, client.httpClient, client.clientHeaderSources) + headers, err := resolveClientHeaders(sources) - if err == nil { - t.Fatal("Expected an error due to non-OK status, but got nil") - } - if !strings.Contains(err.Error(), "server returned non-OK status: 500") { - t.Errorf("Error message missing expected status code. Got: %s", err.Error()) - } + require.NoError(t, err) + assert.Empty(t, headers) + assert.NotNil(t, headers) // Ensure we get a map, not nil }) - t.Run("Fails when response body is invalid JSON", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte(`{"serverVersion": "bad-json",`)); err != nil { - t.Fatalf("Mock server failed to write response: %v", err) - } - })) - defer server.Close() - - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) - - _, err := loadManifest(context.Background(), server.URL, client.httpClient, client.clientHeaderSources) - - if err == nil { - t.Fatal("Expected an error due to JSON unmarshal failure, but got nil") - } - if !strings.Contains(err.Error(), "unable to parse manifest correctly") { - t.Errorf("Error message missing expected text. Got: %s", err.Error()) + t.Run("Failure_SingleSourceError", func(t *testing.T) { + // Setup: One valid source, one failing source + sources := map[string]oauth2.TokenSource{ + "Valid-Header": &mockTokenSource{token: &oauth2.Token{AccessToken: "ok"}}, + "Broken-Header": &mockTokenSource{err: errors.New("network timeout")}, } - }) - - t.Run("Fails when context is canceled", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(100 * time.Millisecond) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - client, _ := NewToolboxClient(server.URL, WithHTTPClient(server.Client())) - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) - defer cancel() + // Execute + headers, err := resolveClientHeaders(sources) - // Action - _, err := loadManifest(ctx, server.URL, client.httpClient, client.clientHeaderSources) + // Verify + require.Error(t, err) + assert.Nil(t, headers, "Should return nil map on error") - // Assert - if err == nil { - t.Fatal("Expected an error due to context cancellation, but got nil") - } - if !errors.Is(err, context.DeadlineExceeded) { - t.Errorf("Expected context.DeadlineExceeded error, but got a different error: %v", err) - } + // Check error wrapping + assert.Contains(t, err.Error(), "failed to resolve client header 'Broken-Header'") + assert.Contains(t, err.Error(), "network timeout") }) } diff --git a/tbadk/e2e_test.go b/tbadk/e2e_test.go index 5fc44ef..42c0c84 100644 --- a/tbadk/e2e_test.go +++ b/tbadk/e2e_test.go @@ -247,14 +247,14 @@ func TestE2E_LoadErrors(t *testing.T) { client := newClient(t) _, err := client.LoadTool("non-existent-tool", context.Background()) require.Error(t, err) - assert.Contains(t, err.Error(), "server returned non-OK status: 404") + assert.Contains(t, err.Error(), "tool 'non-existent-tool' not found") }) t.Run("test_load_non_existent_toolset", func(t *testing.T) { client := newClient(t) _, err := client.LoadToolset("non-existent-toolset", context.Background()) require.Error(t, err) - assert.Contains(t, err.Error(), "server returned non-OK status: 404") + assert.Contains(t, err.Error(), "toolset does not exist") }) t.Run("test_new_client_with_nil_option", func(t *testing.T) { @@ -406,7 +406,7 @@ func TestE2E_Auth(t *testing.T) { _, err = authedTool.Run(testToolCtx, map[string]any{"id": "2"}) require.Error(t, err) - assert.Contains(t, err.Error(), "tool invocation not authorized") + assert.Contains(t, err.Error(), "unauthorized Tool call") }) t.Run("test_run_tool_auth", func(t *testing.T) { @@ -472,7 +472,7 @@ func TestE2E_Auth(t *testing.T) { _, err = tool.Run(testToolCtx, map[string]any{"id": "2"}) require.Error(t, err) - assert.Contains(t, err.Error(), "failed to get token for service 'my-test-auth'") + assert.Contains(t, err.Error(), "failed to get token for header my-test-auth_token") assert.Contains(t, err.Error(), "token source failed as designed") }) } diff --git a/tbadk/utils_test.go b/tbadk/utils_test.go index c71a30a..d6418c2 100644 --- a/tbadk/utils_test.go +++ b/tbadk/utils_test.go @@ -20,47 +20,102 @@ import ( "net/http" "net/http/httptest" "reflect" - "strings" "testing" "github.com/googleapis/mcp-toolbox-sdk-go/core" ) +// convertParamsToJSONSchema reconstructs a raw JSON schema from the SDK's internal ParameterSchema. +// This is needed because the Mock Server must send "raw" JSON, which the Client then parses back into structs. +func convertParamsToJSONSchema(params []core.ParameterSchema) map[string]any { + properties := make(map[string]any) + required := []string{} + + for _, p := range params { + prop := map[string]any{ + "type": p.Type, + "description": p.Description, + } + properties[p.Name] = prop + if p.Required { + required = append(required, p.Name) + } + } + + return map[string]any{ + "type": "object", + "properties": properties, + "required": required, + } +} + func createCoreTool(t *testing.T, toolName string, schema core.ToolSchema) (*core.ToolboxTool, *httptest.Server) { t.Helper() - // Create a mock manifest - manifest := core.ManifestSchema{ - ServerVersion: "v1", - Tools: map[string]core.ToolSchema{ - toolName: schema, - }, - } - manifestJSON, err := json.Marshal(manifest) - if err != nil { - t.Fatalf("Failed to marshal mock manifest: %v", err) + // Prepare the Tool definition in MCP JSON format + mcpToolDef := map[string]any{ + "name": toolName, + "description": schema.Description, + "inputSchema": convertParamsToJSONSchema(schema.Parameters), } - // Setup a mock server to serve this manifest. + // Setup a Mock MCP Server (JSON-RPC 2.0) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Handle the specific tool manifest request from LoadTool - if strings.HasSuffix(r.URL.Path, "/api/tool/"+toolName) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(manifestJSON) + var req struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + ID any `json:"id"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + var result any + + // Handle MCP Protocol Lifecycle + switch req.Method { + case "initialize": + // Handshake + result = map[string]any{ + "protocolVersion": "2025-06-18", // Matches latest default + "capabilities": map[string]any{"tools": map[string]any{}}, + "serverInfo": map[string]any{ + "name": "mock-server", + "version": "1.0.0", + }, + } + case "notifications/initialized": + // Confirmation (No response needed) + return + case "tools/list": + // List available tools + result = map[string]any{ + "tools": []any{mcpToolDef}, + } + default: + // Ignore other methods for this test return } - http.NotFound(w, r) + + // Send JSON-RPC Response + resp := map[string]any{ + "jsonrpc": "2.0", + "id": req.ID, + "result": result, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) })) - // Create a real client pointing to the mock server. + // Create Client, defaults to Latest MCP (v2025-06-18) client, err := core.NewToolboxClient(server.URL, core.WithHTTPClient(server.Client())) if err != nil { server.Close() t.Fatalf("Failed to create ToolboxClient: %v", err) } - // Load the tool, which returns the real *core.ToolboxTool instance. + // 4. Load the tool (Triggers initialize -> tools/list) tool, err := client.LoadTool(toolName, context.Background()) if err != nil { server.Close() @@ -69,6 +124,7 @@ func createCoreTool(t *testing.T, toolName string, schema core.ToolSchema) (*cor return tool, server } + func TestToADKTool(t *testing.T) { t.Run("Success - Happy Path with parameters", func(t *testing.T) { @@ -80,11 +136,11 @@ func TestToADKTool(t *testing.T) { }, } - // Create Core Tool + // Create Core Tool via MCP Mock coreTool, server := createCoreTool(t, "getWeather", toolSchema) - defer server.Close() // Ensure server is closed after the test + defer server.Close() - // Convert the Core tool to ADK Tool + // Convert to ADK Tool adkTool, err := toADKTool(coreTool) if err != nil { @@ -93,11 +149,8 @@ func TestToADKTool(t *testing.T) { if adkTool.funcDeclaration == nil { t.Fatal("adkTool.funcDeclaration is nil") } - if adkTool.ToolboxTool != coreTool { - t.Error("adkTool.ToolboxTool does not point to the original tool") - } - // Verify the FunctionDeclaration fields + // Verify Basic Fields if got, want := adkTool.funcDeclaration.Name, "getWeather"; got != want { t.Errorf("funcDeclaration.Name = %q, want %q", got, want) } @@ -105,17 +158,17 @@ func TestToADKTool(t *testing.T) { t.Errorf("funcDeclaration.Description = %q, want %q", got, want) } - // Verify the parameters schema + // Verify Schema Conversion var params map[string]any schema, err := adkTool.InputSchema() if err != nil { t.Error("Failed to fetch input schema", err) } - err = json.Unmarshal(schema, ¶ms) - if err != nil { + if err := json.Unmarshal(schema, ¶ms); err != nil { t.Fatalf("Failed to unmarshal generated parameters schema: %v", err) } + // Expected JSON Structure expectedParamsJSON := ` { "type": "object", @@ -134,10 +187,9 @@ func TestToADKTool(t *testing.T) { }) t.Run("Success - No Parameters", func(t *testing.T) { - // Define schema with no parameters toolSchema := core.ToolSchema{ Description: "A tool with no params", - Parameters: nil, // Test nil slice + Parameters: nil, } coreTool, server := createCoreTool(t, "noParams", toolSchema) @@ -157,12 +209,8 @@ func TestToADKTool(t *testing.T) { if err != nil { t.Error("Failed to fetch input schema", err) } - err = json.Unmarshal(schema, ¶ms) - if err != nil { - t.Fatalf("Failed to unmarshal generated parameters schema: %v", err) - } + _ = json.Unmarshal(schema, ¶ms) - // core.ToolboxTool.InputSchema() correctly returns an empty properties map expectedParamsJSON := `{"type": "object", "properties": {}}` var expectedParams map[string]any _ = json.Unmarshal([]byte(expectedParamsJSON), &expectedParams) diff --git a/tbgenkit/tbgenkit_test.go b/tbgenkit/tbgenkit_test.go index d14dac9..fd35d80 100644 --- a/tbgenkit/tbgenkit_test.go +++ b/tbgenkit/tbgenkit_test.go @@ -324,7 +324,7 @@ func TestToGenkitTool_Auth(t *testing.T) { _, err = genkitTool.RunRaw(ctx, map[string]any{"id": "2"}) require.Error(t, err) - assert.Contains(t, err.Error(), "tool invocation not authorized") + assert.Contains(t, err.Error(), "unauthorized Tool call") }) t.Run("test_run_tool_auth", func(t *testing.T) { @@ -422,7 +422,7 @@ func TestToGenkitTool_Auth(t *testing.T) { _, err = genkitTool.RunRaw(ctx, map[string]any{"id": "2"}) require.Error(t, err) - assert.Contains(t, err.Error(), "failed to get token for service 'my-test-auth'") + assert.Contains(t, err.Error(), "failed to get token for header my-test-auth_token") assert.Contains(t, err.Error(), "token source failed as designed") }) }