diff --git a/core/transport/interface.go b/core/transport/interface.go new file mode 100644 index 00000000..55106011 --- /dev/null +++ b/core/transport/interface.go @@ -0,0 +1,32 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "context" +) + +type Transport interface { + BaseURL() string + + // GetTool fetches a single tool manifest. + GetTool(ctx context.Context, toolName string, headers map[string]string) (*ManifestSchema, error) + + // ListTools fetches available tools. + ListTools(ctx context.Context, toolsetName string, headers map[string]string) (*ManifestSchema, error) + + // InvokeTool executes a tool. + InvokeTool(ctx context.Context, toolName string, payload map[string]any, headers map[string]string) (any, error) +} diff --git a/core/transport/toolboxtransport/http.go b/core/transport/toolboxtransport/http.go new file mode 100644 index 00000000..2bf3ab66 --- /dev/null +++ b/core/transport/toolboxtransport/http.go @@ -0,0 +1,177 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolboxtransport + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + + "github.com/googleapis/mcp-toolbox-sdk-go/core/transport" +) + +type ToolboxTransport struct { + baseURL string + httpClient *http.Client +} + +// Ensure that ToolboxTransport implements the Transport interface. +var _ transport.Transport = &ToolboxTransport{} + +func New(baseURL string, client *http.Client) transport.Transport { + return &ToolboxTransport{baseURL: baseURL, httpClient: client} +} + +func (t *ToolboxTransport) BaseURL() string { return t.baseURL } + +func (t *ToolboxTransport) GetTool(ctx context.Context, toolName string, headers map[string]string) (*transport.ManifestSchema, error) { + fullURL, err := url.JoinPath(t.baseURL, "api", "tool", toolName) + if err != nil { + return nil, fmt.Errorf("failed to construct URL: %w", err) + } + return t.LoadManifest(ctx, fullURL, headers) +} + +func (t *ToolboxTransport) ListTools(ctx context.Context, toolsetName string, headers map[string]string) (*transport.ManifestSchema, error) { + fullURL, err := url.JoinPath(t.baseURL, "api", "toolset", toolsetName) + if err != nil { + return nil, fmt.Errorf("failed to construct URL: %w", err) + } + if toolsetName == "" && !strings.HasSuffix(fullURL, "/") { + fullURL += "/" + } + return t.LoadManifest(ctx, fullURL, headers) +} + +// LoadManifest is an internal helper for fetching manifests from the Toolbox server. +// Inputs: +// - ctx: The context to control the lifecycle of the HTTP request, including +// cancellation. +// - url: The specific URL from which to fetch the manifest. +// - headers: A map of token sources to be resolved and applied as +// headers to the request. +// +// Returns: +// +// A pointer to the successfully parsed ManifestSchema and a nil error, or a +// nil ManifestSchema and a descriptive error if any part of the process fails. +func (t *ToolboxTransport) LoadManifest(ctx context.Context, url string, headers map[string]string) (*transport.ManifestSchema, error) { + // Create a new GET request with a context for cancellation. + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request : %w", err) + } + + // Add all headers to the request + for k, v := range headers { + req.Header.Set(k, v) + } + + // Execute the HTTP request. + resp, err := t.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to make HTTP request: %w", err) + } + defer resp.Body.Close() + + // Check for non-successful status codes and include the response body + // for better debugging. + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("server returned non-OK status: %d %s, body: %s", resp.StatusCode, resp.Status, string(bodyBytes)) + } + + // Read the response body. + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Unmarshal the JSON body into the ManifestSchema struct. + var manifest transport.ManifestSchema + if err = json.Unmarshal(body, &manifest); err != nil { + return nil, fmt.Errorf("unable to parse manifest correctly: %w", err) + } + return &manifest, nil +} + +func (t *ToolboxTransport) InvokeTool(ctx context.Context, toolName string, payload map[string]any, headers map[string]string) (any, error) { + if !strings.HasPrefix(t.baseURL, "https://") { + log.Println("WARNING: Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication.") + } + + if t.httpClient == nil { + return nil, fmt.Errorf("http client is not set for toolbox tool '%s'", toolName) + } + + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal tool payload for API call: %w", err) + } + invocationURL, err := url.JoinPath(t.baseURL, "api", "tool", toolName, "invoke") + if err != nil { + return nil, fmt.Errorf("failed to construct URL: %w", err) + } + + // Assemble the API request + req, err := http.NewRequestWithContext(ctx, "POST", invocationURL, bytes.NewBuffer(payloadBytes)) + if err != nil { + return nil, fmt.Errorf("failed to create API request for tool '%s': %w", toolName, err) + } + req.Header.Set("Content-Type", "application/json") + + // Add all headers to the request + for k, v := range headers { + req.Header.Set(k, v) + } + + // API call execution + resp, err := t.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP call to tool '%s' failed: %w", toolName, err) + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body for tool '%s': %w", toolName, err) + } + + // Handle non-successful status codes + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + var errorResponse map[string]any + if jsonErr := json.Unmarshal(responseBody, &errorResponse); jsonErr == nil { + if errMsg, ok := errorResponse["error"].(string); ok { + return nil, fmt.Errorf("tool '%s' API returned error status %d: %s", toolName, resp.StatusCode, errMsg) + } + } + return nil, fmt.Errorf("tool '%s' API returned unexpected status: %d %s, body: %s", toolName, resp.StatusCode, resp.Status, string(responseBody)) + } + + // For successful responses, attempt to extract the 'result' field. + var apiResult map[string]any + if err := json.Unmarshal(responseBody, &apiResult); err == nil { + if result, ok := apiResult["result"]; ok { + return result, nil + } + } + return string(responseBody), nil +} diff --git a/core/transport/toolboxtransport/http_test.go b/core/transport/toolboxtransport/http_test.go new file mode 100644 index 00000000..7e6d78d8 --- /dev/null +++ b/core/transport/toolboxtransport/http_test.go @@ -0,0 +1,493 @@ +//go:build unit + +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package toolboxtransport_test + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "github.com/googleapis/mcp-toolbox-sdk-go/core/transport" + "github.com/googleapis/mcp-toolbox-sdk-go/core/transport/toolboxtransport" +) + +const ( + testBaseURL = "http://fake-toolbox-server.com" + testToolName = "test_tool" +) + +func TestBaseURL(t *testing.T) { + tr := toolboxtransport.New(testBaseURL, http.DefaultClient) + if tr.BaseURL() != testBaseURL { + t.Errorf("expected BaseURL %q, got %q", testBaseURL, tr.BaseURL()) + } +} + +func TestGetTool_Success(t *testing.T) { + // Mock Manifest Response + mockManifest := transport.ManifestSchema{ + ServerVersion: "1.0.0", + Tools: map[string]transport.ToolSchema{ + testToolName: { + Description: "A test tool", + Parameters: []transport.ParameterSchema{ + {Name: "param1", Type: "string", Description: "The first parameter.", Required: true}, + }, + }, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify URL + if r.URL.Path != "/api/tool/"+testToolName { + t.Errorf("unexpected path: %s", r.URL.Path) + } + // Verify Headers + if r.Header.Get("X-Test-Header") != "value" { + t.Errorf("missing or incorrect header X-Test-Header") + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(mockManifest) + })) + defer server.Close() + + tr := toolboxtransport.New(server.URL, server.Client()) + headers := map[string]string{"X-Test-Header": "value"} + + result, err := tr.GetTool(context.Background(), testToolName, headers) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.ServerVersion != "1.0.0" { + t.Errorf("expected version 1.0.0, got %s", result.ServerVersion) + } + if tool, ok := result.Tools[testToolName]; !ok { + t.Errorf("tool %s not found in result", testToolName) + } else if tool.Description != "A test tool" { + t.Errorf("expected description 'A test tool', got %q", tool.Description) + } +} + +func TestGetTool_Failure(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("Internal Server Error")) + })) + defer server.Close() + + tr := toolboxtransport.New(server.URL, server.Client()) + _, err := tr.GetTool(context.Background(), testToolName, nil) + + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "500") || !strings.Contains(err.Error(), "Internal Server Error") { + t.Errorf("expected error message to contain 500 and Internal Server Error, got: %v", err) + } +} + +func TestListTools_Success(t *testing.T) { + mockManifest := transport.ManifestSchema{ServerVersion: "1.0.0", Tools: map[string]transport.ToolSchema{}} + + testCases := []struct { + name string + toolsetName string + expectedPath string + }{ + {"With Toolset", "my_toolset", "/api/toolset/my_toolset"}, + {"Without Toolset", "", "/api/toolset/"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != tc.expectedPath { + t.Errorf("expected path %q, got %q", tc.expectedPath, r.URL.Path) + } + _ = json.NewEncoder(w).Encode(mockManifest) + })) + defer server.Close() + + tr := toolboxtransport.New(server.URL, server.Client()) + _, err := tr.ListTools(context.Background(), tc.toolsetName, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestInvokeTool_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify Path & Method + if r.URL.Path != "/api/tool/"+testToolName+"/invoke" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.Method != "POST" { + t.Errorf("unexpected method: %s", r.Method) + } + // Verify Headers + if r.Header.Get("Authorization") != "Bearer token" { + t.Errorf("missing or incorrect Authorization header") + } + // Verify Body + var body map[string]any + _ = json.NewDecoder(r.Body).Decode(&body) + if body["param1"] != "value1" { + t.Errorf("unexpected body param1: %v", body["param1"]) + } + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("success")) + })) + defer server.Close() + + tr := toolboxtransport.New(server.URL, server.Client()) + payload := map[string]any{"param1": "value1"} + headers := map[string]string{"Authorization": "Bearer token"} + + result, err := tr.InvokeTool(context.Background(), testToolName, payload, headers) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := "success" + + if result != expected { + t.Errorf("expected result %s, got %s", expected, result) + } +} + +func TestInvokeTool_Failure(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error": "Invalid arguments"}`)) + })) + defer server.Close() + + tr := toolboxtransport.New(server.URL, server.Client()) + _, err := tr.InvokeTool(context.Background(), testToolName, map[string]any{}, nil) + + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "Invalid arguments") { + t.Errorf("expected error to contain 'Invalid arguments', got: %v", err) + } +} + +type mockTransport struct { + RoundTripFunc func(req *http.Request) (*http.Response, error) +} + +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return m.RoundTripFunc(req) +} + +func TestInvokeTool_HTTPWarning(t *testing.T) { + // Capture logs to verify the warning + var buf bytes.Buffer + log.SetOutput(&buf) + defer log.SetOutput(os.Stderr) + + // Mock a successful response so InvokeTool completes (or fails gracefully after the log) + dummyResponse := func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBufferString("ok")), + Header: make(http.Header), + }, nil + } + + testCases := []struct { + name string + baseURL string + shouldWarn bool + }{ + {"HTTP", "http://insecure.com", true}, + {"HTTPS", "https://secure.com", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf.Reset() + + client := &http.Client{ + Transport: &mockTransport{RoundTripFunc: dummyResponse}, + } + + tr := toolboxtransport.New(tc.baseURL, client) + + payload := map[string]any{"param": "val"} + + _, _ = tr.InvokeTool(context.Background(), testToolName, payload, nil) + + logOutput := buf.String() + hasWarning := strings.Contains(logOutput, "Sending ID token over HTTP") + + if tc.shouldWarn && !hasWarning { + t.Errorf("expected warning for %s, but got none", tc.name) + } + if !tc.shouldWarn && hasWarning { + t.Errorf("unexpected warning for %s", tc.name) + } + }) + } +} + +func TestLoadManifest(t *testing.T) { + mockJSON := `{"serverVersion":"1.0.0","tools":{"test":{"description":"foo"}}}` + + t.Run("Success", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer token" { + t.Errorf("Missing Authorization header") + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(mockJSON)) + })) + defer server.Close() + + transportConcrete := toolboxtransport.New(server.URL, server.Client()).(*toolboxtransport.ToolboxTransport) + + sources := map[string]string{"Authorization": "Bearer token"} + + manifest, err := transportConcrete.LoadManifest(context.Background(), server.URL+"/some/path", sources) + if err != nil { + t.Fatalf("LoadManifest failed: %v", err) + } + + if manifest.ServerVersion != "1.0.0" { + t.Errorf("unexpected version: %s", manifest.ServerVersion) + } + }) + + t.Run("Failure_Status500", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("oops")) + })) + defer server.Close() + + transportConcrete := toolboxtransport.New(server.URL, server.Client()).(*toolboxtransport.ToolboxTransport) + + _, err := transportConcrete.LoadManifest(context.Background(), server.URL, nil) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "500") { + t.Errorf("expected 500 error, got: %v", err) + } + }) + + t.Run("Failure_BadJSON", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{bad json`)) + })) + defer server.Close() + + transportConcrete := toolboxtransport.New(server.URL, server.Client()).(*toolboxtransport.ToolboxTransport) + + _, err := transportConcrete.LoadManifest(context.Background(), server.URL, nil) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "unable to parse manifest") { + t.Errorf("expected parse error, got: %v", err) + } + }) +} + +func TestLoadManifest_EdgeCases(t *testing.T) { + t.Run("Unreadable JSON Response", func(t *testing.T) { + // Server returns 200 OK but invalid JSON body + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{broken-manifest`)) + })) + defer server.Close() + + tr := toolboxtransport.New(server.URL, server.Client()) + _, err := tr.GetTool(context.Background(), testToolName, nil) + + if err == nil { + t.Fatal("expected error, got nil") + } + // Matches: "unable to parse manifest correctly" + if !strings.Contains(err.Error(), "unable to parse manifest") { + t.Errorf("expected parse error, got: %v", err) + } + }) + + t.Run("Network Error (Server Down)", func(t *testing.T) { + // Start a server to get a valid URL, then immediately close it + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + url := server.URL + server.Close() // Kill the server + + tr := toolboxtransport.New(url, server.Client()) + _, err := tr.GetTool(context.Background(), testToolName, nil) + + if err == nil { + t.Fatal("expected network error, got nil") + } + // Matches: "failed to make HTTP request" + if !strings.Contains(err.Error(), "failed to make HTTP request") { + t.Errorf("expected request error, got: %v", err) + } + }) + + t.Run("HTTP 500 with Non-JSON Body", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("Fatal Server Error")) + })) + defer server.Close() + + tr := toolboxtransport.New(server.URL, server.Client()) + _, err := tr.GetTool(context.Background(), testToolName, nil) + + if err == nil { + t.Fatal("expected error, got nil") + } + // Matches: "server returned non-OK status: 500 ... body: Fatal Server Error" + if !strings.Contains(err.Error(), "500") || !strings.Contains(err.Error(), "Fatal Server Error") { + t.Errorf("expected error to contain status and raw body, got: %v", err) + } + }) + + t.Run("NewRequest Error (Bad URL)", func(t *testing.T) { + // Pass a URL with control characters to trigger http.NewRequest failure + tr := toolboxtransport.New("http://bad\nurl.com", http.DefaultClient) + + _, err := tr.GetTool(context.Background(), testToolName, nil) + if err == nil { + t.Fatal("expected error, got nil") + } + // Matches: "failed to create HTTP request" + if !strings.Contains(err.Error(), "invalid control character in URL") { + t.Errorf("unexpected error: %v", err) + } + }) +} + +func TestInvokeTool_EdgeCases(t *testing.T) { + ctx := context.Background() + t.Run("Nil_HTTP_Client", func(t *testing.T) { + // Create transport with nil http client + tr := toolboxtransport.New(testBaseURL, nil) + _, err := tr.InvokeTool(ctx, "tool", map[string]any{}, nil) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "http client is not set") { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("Unreadable JSON Response", func(t *testing.T) { + // Server returns 200 OK but invalid JSON body + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{broken-json`)) + })) + defer server.Close() + + tr := toolboxtransport.New(server.URL, server.Client()) + result, err := tr.InvokeTool(context.Background(), testToolName, map[string]any{}, nil) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // If JSON parsing fails, the transport is designed to return the raw body as string + // This matches the logic: "Fallback for non-enveloped responses" or malformed result envelopes + if resStr, ok := result.(string); !ok || resStr != `{broken-json` { + t.Errorf("expected raw string '{broken-json', got %v", result) + } + }) + + t.Run("Network Error (Server Down)", func(t *testing.T) { + // Start a server to get a valid URL, then immediately close it + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + url := server.URL + server.Close() // Kill the server + + tr := toolboxtransport.New(url, server.Client()) + _, err := tr.InvokeTool(context.Background(), testToolName, map[string]any{}, nil) + + if err == nil { + t.Fatal("expected network error, got nil") + } + // Error should come from http.Client.Do + if !strings.Contains(err.Error(), "connection refused") && !strings.Contains(err.Error(), "HTTP call to tool") { + t.Errorf("expected connection error, got: %v", err) + } + }) + + t.Run("HTTP 500 with Non-JSON Body", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("Fatal Database Error")) + })) + defer server.Close() + + tr := toolboxtransport.New(server.URL, server.Client()) + _, err := tr.InvokeTool(context.Background(), testToolName, map[string]any{}, nil) + + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "500") || !strings.Contains(err.Error(), "Fatal Database Error") { + t.Errorf("expected error to contain status and raw body, got: %v", err) + } + }) + + t.Run("Marshal_Error", func(t *testing.T) { + tr := toolboxtransport.New(testBaseURL, http.DefaultClient) + // Pass a channel which cannot be marshaled to JSON + payload := map[string]any{"bad": make(chan int)} + _, err := tr.InvokeTool(ctx, "tool", payload, nil) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to marshal tool payload") { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("NewRequest_Error", func(t *testing.T) { + tr := toolboxtransport.New("http://bad\nurl.com", http.DefaultClient) + _, err := tr.InvokeTool(ctx, "tool", map[string]any{}, nil) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "invalid control character in URL") { + t.Errorf("unexpected error: %v", err) + } + }) +} diff --git a/core/transport/types.go b/core/transport/types.go new file mode 100644 index 00000000..ad68e7a9 --- /dev/null +++ b/core/transport/types.go @@ -0,0 +1,167 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "fmt" + "reflect" +) + +// Schema for a tool parameter. +type ParameterSchema struct { + Name string `json:"name"` + Type string `json:"type"` + Required bool `json:"required,omitempty"` + Description string `json:"description"` + AuthSources []string `json:"authSources,omitempty"` + Items *ParameterSchema `json:"items,omitempty"` + AdditionalProperties any `json:"additionalProperties,omitempty"` +} + +// ValidateType is a helper for manual type checking. +func (p *ParameterSchema) ValidateType(value any) error { + if value == nil { + if p.Required { + return fmt.Errorf("parameter '%s' is required but received a nil value", p.Name) + } + return nil + } + + switch p.Type { + case "string": + if _, ok := value.(string); !ok { + return fmt.Errorf("parameter '%s' expects a string, but got %T", p.Name, value) + } + case "integer": + switch value.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + default: + return fmt.Errorf("parameter '%s' expects an integer, but got %T", p.Name, value) + } + case "float": + switch value.(type) { + case float32, float64: + default: + return fmt.Errorf("parameter '%s' expects an float, but got %T", p.Name, value) + } + case "boolean": + if _, ok := value.(bool); !ok { + return fmt.Errorf("parameter '%s' expects a boolean, but got %T", p.Name, value) + } + case "array": + v := reflect.ValueOf(value) + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + return fmt.Errorf("parameter '%s' expects an array/slice, but got %T", p.Name, value) + } + for i := range v.Len() { + item := v.Index(i).Interface() + + if err := p.Items.ValidateType(item); err != nil { + return fmt.Errorf("error in array '%s' at index %d: %w", p.Name, i, err) + } + } + case "object": + // First, check that the value is a map with string keys. + valMap, ok := value.(map[string]any) + if !ok { + return fmt.Errorf("parameter '%s' expects a map, but got %T", p.Name, value) + } + + switch ap := p.AdditionalProperties.(type) { + // No validation required, allows any type + case bool: + + // Validate type for each value in map + case *ParameterSchema: + // Raise error if the input is a nested map / array + if ap.Type == "object" || ap.Type == "array" { + return fmt.Errorf("invalid schema for object '%s': values cannot be of type '%s'", p.Name, ap.Type) + } + for key, val := range valMap { + if err := ap.ValidateType(val); err != nil { + return fmt.Errorf("error in object '%s' for key '%s': %w", p.Name, key, err) + } + } + + default: + // This is a schema / manifest error. + return fmt.Errorf( + "invalid schema for parameter '%s': AdditionalProperties must be a boolean or a map[string]any, but got %T", + p.Name, + ap, + ) + } + default: + return fmt.Errorf("unknown type '%s' in schema for parameter '%s'", p.Type, p.Name) + } + return nil +} + +// ValidateDefinition checks if the schema itself is well-formed. +func (p *ParameterSchema) ValidateDefinition() error { + if p.Type == "" { + return fmt.Errorf("schema validation failed for '%s': type is missing", p.Name) + } + + switch p.Type { + case "array": + if p.Items == nil { + return fmt.Errorf("parameter '%s' is an array but is missing item type definition", p.Name) + } + // Recursively validate the nested schema's definition. + if err := p.Items.ValidateDefinition(); err != nil { + return err + } + + case "object": + switch ap := p.AdditionalProperties.(type) { + case bool: + // Valid scenario + case *ParameterSchema: + if err := ap.ValidateDefinition(); err != nil { + return err + } + default: + // Any other type is an invalid schema definition. + return fmt.Errorf( + "invalid schema for parameter '%s': AdditionalProperties must be a boolean or a schema, but got %T", + p.Name, + ap, + ) + } + + case "string", "integer", "float", "boolean": + // No type-specific rules for these. + break + + default: + return fmt.Errorf("unknown schema type '%s' for parameter '%s'", p.Type, p.Name) + } + + return nil +} + +// Schema for a tool. +type ToolSchema struct { + Description string `json:"description"` + Parameters []ParameterSchema `json:"parameters"` + AuthRequired []string `json:"authRequired,omitempty"` +} + +// Schema for the Toolbox manifest. +type ManifestSchema struct { + ServerVersion string `json:"serverVersion"` + Tools map[string]ToolSchema `json:"tools"` +} diff --git a/core/transport/types_test.go b/core/transport/types_test.go new file mode 100644 index 00000000..7b50d42e --- /dev/null +++ b/core/transport/types_test.go @@ -0,0 +1,581 @@ +//go:build unit + +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "fmt" + "strings" + "testing" + "time" +) + +// Tests ParameterSchema with type 'int'. +func TestParameterSchemaInteger(t *testing.T) { + + schema := ParameterSchema{ + Name: "param_name", + Type: "integer", + Description: "integer parameter", + } + + t.Run("Test int param", func(t *testing.T) { + value := 1 + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + }) + t.Run("Test int8 param", func(t *testing.T) { + var value int8 = 1 + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + }) + t.Run("Test int16 param", func(t *testing.T) { + var value int16 = 1 + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + }) + t.Run("Test int32 param", func(t *testing.T) { + var value int32 = 1 + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + }) + t.Run("Test int64 param", func(t *testing.T) { + var value int64 = 1 + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + }) + t.Run("Test uint param", func(t *testing.T) { + var value uint = 1 + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + }) + t.Run("Test uint8 param", func(t *testing.T) { + var value uint8 = 1 + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + }) + t.Run("Test uint16 param", func(t *testing.T) { + var value uint16 = 1 + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + }) + t.Run("Test uint32 param", func(t *testing.T) { + var value uint32 = 1 + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + }) + t.Run("Test uint64 param", func(t *testing.T) { + var value uint64 = 1 + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + }) + +} + +// Tests ParameterSchema with type 'string'. +func TestParameterSchemaString(t *testing.T) { + + schema := ParameterSchema{ + Name: "param_name", + Type: "string", + Description: "string parameter", + } + + value := "abc" + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + +} + +// Tests ParameterSchema with type 'boolean'. +func TestParameterSchemaBoolean(t *testing.T) { + + schema := ParameterSchema{ + Name: "param_name", + Type: "boolean", + Description: "boolean parameter", + } + + value := true + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + +} + +// Tests ParameterSchema with type 'float'. +func TestParameterSchemaFloat(t *testing.T) { + + schema := ParameterSchema{ + Name: "param_name", + Type: "float", + Description: "float parameter", + } + + t.Run("Test float32 param", func(t *testing.T) { + var value float32 = 3.14 + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + }) + t.Run("Test float64 param", func(t *testing.T) { + value := 3.14 + + err := schema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + }) + +} + +// Tests ParameterSchema with type 'array'. +func TestParameterSchemaStringArray(t *testing.T) { + + itemSchema := ParameterSchema{ + Name: "item", + Type: "string", + Description: "item of the array", + } + + paramSchema := ParameterSchema{ + Name: "param_name", + Type: "array", + Description: "array parameter", + Items: &itemSchema, + } + + value := []string{"abc", "def"} + + err := paramSchema.ValidateType(value) + + if err != nil { + t.Fatal(err.Error()) + } + +} + +// Tests ParameterSchema with an undefined type. +func TestParameterSchemaUndefinedType(t *testing.T) { + + paramSchema := ParameterSchema{ + Name: "param_name", + Type: "time", + Description: "time parameter", + } + + value := time.Now() + + err := paramSchema.ValidateType(value) + + if err == nil { + t.Fatal("Expected an error, but got nil") + } + +} + +func TestOptionalStringParameter(t *testing.T) { + schema := ParameterSchema{ + Name: "nickname", + Type: "string", + Description: "An optional nickname", + Required: false, // Explicitly optional + } + + t.Run("allows nil value for optional parameter", func(t *testing.T) { + err := schema.ValidateType(nil) + if err != nil { + t.Errorf("ValidateType() with nil should not return an error for an optional parameter, but got: %v", err) + } + }) + + t.Run("allows valid string value", func(t *testing.T) { + err := schema.ValidateType("my-name") + if err != nil { + t.Errorf("ValidateType() should not return an error for a valid string, but got: %v", err) + } + }) +} + +func TestRequiredParameter(t *testing.T) { + schema := ParameterSchema{ + Name: "id", + Type: "integer", + Description: "A required ID", + Required: true, // Explicitly required + } + + t.Run("rejects nil value for required parameter", func(t *testing.T) { + err := schema.ValidateType(nil) + if err == nil { + t.Errorf("ValidateType() with nil should return an error for a required parameter, but it didn't") + } + }) + + t.Run("allows valid integer value", func(t *testing.T) { + err := schema.ValidateType(12345) + if err != nil { + t.Errorf("ValidateType() should not return an error for a valid integer, but got: %v", err) + } + }) +} + +func TestOptionalArrayParameter(t *testing.T) { + schema := ParameterSchema{ + Name: "optional_scores", + Type: "array", + Description: "An optional list of scores", + Required: false, + Items: &ParameterSchema{ + Type: "integer", + }, + } + + t.Run("allows nil value for optional array", func(t *testing.T) { + err := schema.ValidateType(nil) + if err != nil { + t.Errorf("ValidateType() with nil should not return an error for an optional array, but got: %v", err) + } + }) + + t.Run("allows valid integer slice", func(t *testing.T) { + err := schema.ValidateType([]int{95, 100}) + if err != nil { + t.Errorf("ValidateType() should not return an error for a valid slice, but got: %v", err) + } + }) + + t.Run("rejects slice with wrong item type", func(t *testing.T) { + err := schema.ValidateType([]string{"not", "an", "int"}) + if err == nil { + t.Errorf("ValidateType() should have returned an error for a slice with incorrect item types, but it didn't") + } + }) +} + +func TestValidateTypeObject(t *testing.T) { + t.Run("generic object allows any value types", func(t *testing.T) { + schema := ParameterSchema{ + Name: "metadata", + Type: "object", + AdditionalProperties: true, // or nil + } + + // A map with mixed value types should be valid. + validInput := map[string]any{ + "key_string": "a string", + "key_int": 123, + "key_bool": true, + "key_map": map[string]any{"id": 1}, + "key_array": []string{"id", "number"}, + } + if err := schema.ValidateType(validInput); err != nil { + t.Errorf("Expected no error for generic object, but got: %v", err) + } + + // A value that is not a map should be invalid. + invalidInput := "I am a string, not an object" + if err := schema.ValidateType(invalidInput); err == nil { + t.Errorf("Expected an error for non-map input, but got nil") + } + }) + + t.Run("typed object validation", func(t *testing.T) { + testCases := []struct { + name string + valueType string + validInput map[string]any + invalidInput map[string]any + }{ + { + name: "string values", + valueType: "string", + validInput: map[string]any{"header": "application/json"}, + invalidInput: map[string]any{"bad_header": 123}, + }, + { + name: "integer values", + valueType: "integer", + validInput: map[string]any{"user_score": 100}, + invalidInput: map[string]any{"bad_score": "100"}, + }, + { + name: "float values", + valueType: "float", + validInput: map[string]any{"item_price": 99.99}, + invalidInput: map[string]any{"bad_price": 99}, + }, + { + name: "boolean values", + valueType: "boolean", + validInput: map[string]any{"feature_flag": true}, + invalidInput: map[string]any{"bad_flag": "true"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + schema := ParameterSchema{ + Name: "test_map", + Type: "object", + AdditionalProperties: &ParameterSchema{Type: tc.valueType}, + } + + // Test that valid input passes + if err := schema.ValidateType(tc.validInput); err != nil { + t.Errorf("Expected no error for valid input, got: %v", err) + } + + // Test that invalid input fails + if err := schema.ValidateType(tc.invalidInput); err == nil { + t.Errorf("Expected an error for invalid input, but got nil") + } + }) + } + }) + + t.Run("Fail for object valueType maps", func(t *testing.T) { + + // This schema itself is invalid so there is no valid test case + schema := ParameterSchema{ + Name: "test_map", + Type: "object", + AdditionalProperties: &ParameterSchema{Type: "object"}, + } + + invalidInput := map[string]any{"feature_flag": map[string]any{"id": "123"}} + // Test that invalid input fails + if err := schema.ValidateType(invalidInput); err == nil { + t.Errorf("Expected an error for invalid input, but got nil") + } + }) + + t.Run("Fail for array valueType maps", func(t *testing.T) { + + // This schema itself is invalid so there is no valid test case + schema := ParameterSchema{ + Name: "test_map", + Type: "object", + AdditionalProperties: &ParameterSchema{Type: "array"}, + } + + invalidInput := map[string]any{"feature_flag": []string{"id", "number"}} + // Test that invalid input fails + if err := schema.ValidateType(invalidInput); err == nil { + t.Errorf("Expected an error for invalid input, but got nil") + } + }) + + t.Run("optional and required objects", func(t *testing.T) { + // An optional object can be nil + optionalSchema := ParameterSchema{Name: "optional_metadata", Type: "object", Required: false} + if err := optionalSchema.ValidateType(nil); err != nil { + t.Errorf("Expected no error for nil on optional object, but got: %v", err) + } + + // A required object cannot be nil + requiredSchema := ParameterSchema{Name: "required_metadata", Type: "object", Required: true} + if err := requiredSchema.ValidateType(nil); err == nil { + t.Error("Expected an error for nil on required object, but got nil") + } + }) + + t.Run("object with unsupported value type in schema", func(t *testing.T) { + unsupportedType := "custom_object" + schema := ParameterSchema{ + Name: "custom_data", + Type: "object", + AdditionalProperties: &ParameterSchema{Type: unsupportedType}, + } + + input := map[string]any{"key": "some value"} + err := schema.ValidateType(input) + + if err == nil { + t.Fatal("Expected an error for unsupported sub-schema type, but got nil") + } + + // Check if the error message contains the expected text. + expectedErrorMsg := fmt.Sprintf("unknown type '%s'", unsupportedType) + if !strings.Contains(err.Error(), expectedErrorMsg) { + t.Errorf("Expected error to contain '%s', but got '%v'", expectedErrorMsg, err) + } + }) +} + +func TestParameterSchema_ValidateDefinition(t *testing.T) { + t.Run("should succeed for simple valid types", func(t *testing.T) { + testCases := []struct { + name string + schema *ParameterSchema + }{ + {"String", &ParameterSchema{Name: "p_string", Type: "string"}}, + {"Integer", &ParameterSchema{Name: "p_int", Type: "integer"}}, + {"Float", &ParameterSchema{Name: "p_float", Type: "float"}}, + {"Boolean", &ParameterSchema{Name: "p_bool", Type: "boolean"}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := tc.schema.ValidateDefinition(); err != nil { + t.Errorf("expected no error, but got: %v", err) + } + }) + } + }) + + t.Run("should succeed for a valid array schema", func(t *testing.T) { + schema := &ParameterSchema{ + Name: "p_array", + Type: "array", + Items: &ParameterSchema{Type: "string"}, + } + if err := schema.ValidateDefinition(); err != nil { + t.Errorf("expected no error for valid array, but got: %v", err) + } + }) + + t.Run("should succeed for valid object schemas", func(t *testing.T) { + testCases := []struct { + name string + schema *ParameterSchema + }{ + { + "Typed Object", + &ParameterSchema{ + Name: "p_obj_typed", + Type: "object", + AdditionalProperties: &ParameterSchema{Type: "integer"}, + }, + }, + { + "Generic Object (bool)", + &ParameterSchema{ + Name: "p_obj_bool", + Type: "object", + AdditionalProperties: true, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := tc.schema.ValidateDefinition(); err != nil { + t.Errorf("expected no error, but got: %v", err) + } + }) + } + }) + + t.Run("should fail when type is missing", func(t *testing.T) { + schema := &ParameterSchema{Name: "p_missing_type", Type: "object", AdditionalProperties: &ParameterSchema{Type: ""}} + err := schema.ValidateDefinition() + if err == nil { + t.Fatal("expected an error for missing type, but got nil") + } + if !strings.Contains(err.Error(), "type is missing") { + t.Errorf("error message should mention 'type is missing', but was: %s", err) + } + }) + + t.Run("should fail when type is unknown", func(t *testing.T) { + schema := &ParameterSchema{Name: "p_unknown", Type: "object", AdditionalProperties: &ParameterSchema{Type: "some-custom-type"}} + err := schema.ValidateDefinition() + if err == nil { + t.Fatal("expected an error for unknown type, but got nil") + } + if !strings.Contains(err.Error(), "unknown schema type") { + t.Errorf("error message should mention 'unknown schema type', but was: %s", err) + } + }) + + t.Run("should fail for array with missing items property", func(t *testing.T) { + schema := &ParameterSchema{Name: "p_bad_array", Type: "array", Items: nil} + err := schema.ValidateDefinition() + if err == nil { + t.Fatal("expected an error for array with nil items, but got nil") + } + if !strings.Contains(err.Error(), "missing item type definition") { + t.Errorf("error message should mention 'missing item type definition', but was: %s", err) + } + }) + + t.Run("should fail for object with invalid AdditionalProperties type", func(t *testing.T) { + schema := &ParameterSchema{ + Name: "p_bad_object", + Type: "object", + AdditionalProperties: "a-string-is-not-valid", + } + err := schema.ValidateDefinition() + if err == nil { + t.Fatal("expected an error for invalid AdditionalProperties, but got nil") + } + if !strings.Contains(err.Error(), "must be a boolean or a schema") { + t.Errorf("error message should mention 'must be a boolean or a schema', but was: %s", err) + } + }) +}