|
| 1 | +package codemode |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "encoding/json" |
| 6 | + "errors" |
| 7 | + "fmt" |
| 8 | + "strings" |
| 9 | + "testing" |
| 10 | + |
| 11 | + "github.com/stretchr/testify/assert" |
| 12 | + "github.com/stretchr/testify/require" |
| 13 | + "github.com/universal-tool-calling-protocol/go-utcp/src/tools" |
| 14 | +) |
| 15 | + |
// mockModel is a test double for the LLM dependency. Every call to Generate is
// forwarded to the GenerateFunc hook supplied by the individual test.
type mockModel struct {
	// GenerateFunc is the stubbed behaviour; Generate fails when it is nil.
	GenerateFunc func(ctx context.Context, prompt string) (any, error)
}

// Generate delegates to GenerateFunc and returns whatever the hook returns.
// When no hook has been installed it returns a nil value and a fixed error.
func (m *mockModel) Generate(ctx context.Context, prompt string) (any, error) {
	hook := m.GenerateFunc
	if hook == nil {
		return nil, errors.New("GenerateFunc not implemented")
	}
	return hook(ctx, prompt)
}
| 27 | + |
| 28 | +func TestDecideIfToolsNeeded(t *testing.T) { |
| 29 | + ctx := context.Background() |
| 30 | + |
| 31 | + tests := []struct { |
| 32 | + name string |
| 33 | + mockResponse any |
| 34 | + mockError error |
| 35 | + expectedNeeds bool |
| 36 | + expectedError bool |
| 37 | + responseIsJSON bool |
| 38 | + }{ |
| 39 | + { |
| 40 | + name: "LLM decides tools are needed", |
| 41 | + mockResponse: `{"needs": true}`, |
| 42 | + expectedNeeds: true, |
| 43 | + expectedError: false, |
| 44 | + responseIsJSON: true, |
| 45 | + }, |
| 46 | + { |
| 47 | + name: "LLM decides tools are not needed", |
| 48 | + mockResponse: `{"needs": false}`, |
| 49 | + expectedNeeds: false, |
| 50 | + expectedError: false, |
| 51 | + responseIsJSON: true, |
| 52 | + }, |
| 53 | + { |
| 54 | + name: "LLM returns an error", |
| 55 | + mockError: errors.New("LLM error"), |
| 56 | + expectedNeeds: false, |
| 57 | + expectedError: true, |
| 58 | + }, |
| 59 | + { |
| 60 | + name: "LLM returns invalid JSON", |
| 61 | + mockResponse: `{"needs": tru}`, |
| 62 | + expectedNeeds: false, |
| 63 | + expectedError: false, |
| 64 | + responseIsJSON: true, |
| 65 | + }, |
| 66 | + { |
| 67 | + name: "LLM returns non-JSON string", |
| 68 | + mockResponse: "I don't know.", |
| 69 | + expectedNeeds: false, |
| 70 | + expectedError: false, |
| 71 | + responseIsJSON: false, |
| 72 | + }, |
| 73 | + } |
| 74 | + |
| 75 | + for _, tc := range tests { |
| 76 | + t.Run(tc.name, func(t *testing.T) { |
| 77 | + mock := &mockModel{ |
| 78 | + GenerateFunc: func(ctx context.Context, prompt string) (any, error) { |
| 79 | + if tc.responseIsJSON { |
| 80 | + return tc.mockResponse, tc.mockError |
| 81 | + } |
| 82 | + return fmt.Sprintf("%v", tc.mockResponse), tc.mockError |
| 83 | + }, |
| 84 | + } |
| 85 | + cm := CodeModeUTCP{model: mock} |
| 86 | + |
| 87 | + needs, err := cm.decideIfToolsNeeded(ctx, "some query", "some tools") |
| 88 | + |
| 89 | + if tc.expectedError { |
| 90 | + require.Error(t, err) |
| 91 | + } else { |
| 92 | + require.NoError(t, err) |
| 93 | + assert.Equal(t, tc.expectedNeeds, needs) |
| 94 | + } |
| 95 | + }) |
| 96 | + } |
| 97 | +} |
| 98 | + |
| 99 | +func TestSelectTools(t *testing.T) { |
| 100 | + ctx := context.Background() |
| 101 | + mock := &mockModel{ |
| 102 | + GenerateFunc: func(ctx context.Context, prompt string) (any, error) { |
| 103 | + return `{"tools": ["tool1", "tool2"]}`, nil |
| 104 | + }, |
| 105 | + } |
| 106 | + cm := &CodeModeUTCP{model: mock} |
| 107 | + |
| 108 | + selected, err := cm.selectTools(ctx, "some query", "some tools") |
| 109 | + |
| 110 | + require.NoError(t, err) |
| 111 | + assert.Equal(t, []string{"tool1", "tool2"}, selected) |
| 112 | +} |
| 113 | + |
| 114 | +func TestGenerateSnippet(t *testing.T) { |
| 115 | + ctx := context.Background() |
| 116 | + mockResp := struct { |
| 117 | + Code string `json:"code"` |
| 118 | + Stream bool `json:"stream"` |
| 119 | + }{ |
| 120 | + Code: `__out = "result"`, |
| 121 | + Stream: false, |
| 122 | + } |
| 123 | + respBytes, _ := json.Marshal(mockResp) |
| 124 | + |
| 125 | + mock := &mockModel{ |
| 126 | + GenerateFunc: func(ctx context.Context, prompt string) (any, error) { |
| 127 | + return string(respBytes), nil |
| 128 | + }, |
| 129 | + } |
| 130 | + cm := &CodeModeUTCP{model: mock} |
| 131 | + |
| 132 | + snippet, stream, err := cm.generateSnippet(ctx, "query", []string{"tool1"}, "specs") |
| 133 | + |
| 134 | + require.NoError(t, err) |
| 135 | + assert.Equal(t, mockResp.Code, snippet) |
| 136 | + assert.Equal(t, mockResp.Stream, stream) |
| 137 | +} |
| 138 | + |
| 139 | +func TestRenderUtcpToolsForPrompt(t *testing.T) { |
| 140 | + specs := []tools.Tool{ |
| 141 | + { |
| 142 | + Name: "test.tool", |
| 143 | + Description: "A test tool.", |
| 144 | + Inputs: tools.ToolInputOutputSchema{ |
| 145 | + Properties: map[string]any{ |
| 146 | + "arg1": map[string]any{"type": "string"}, |
| 147 | + }, |
| 148 | + Required: []string{"arg1"}, |
| 149 | + }, |
| 150 | + Outputs: tools.ToolInputOutputSchema{ |
| 151 | + Properties: map[string]any{ |
| 152 | + "result": map[string]any{"type": "string"}, |
| 153 | + }, |
| 154 | + }, |
| 155 | + }, |
| 156 | + } |
| 157 | + |
| 158 | + output := renderUtcpToolsForPrompt(specs) |
| 159 | + |
| 160 | + assert.Contains(t, output, "TOOL: test.tool") |
| 161 | + assert.Contains(t, output, "DESCRIPTION: A test tool.") |
| 162 | + assert.Contains(t, output, "INPUT FIELDS (USE EXACTLY THESE KEYS):") |
| 163 | + assert.Contains(t, output, "- arg1: string") |
| 164 | + assert.Contains(t, output, "REQUIRED FIELDS:") |
| 165 | + assert.Contains(t, output, "FULL INPUT SCHEMA (JSON):") |
| 166 | + assert.Contains(t, output, "OUTPUT SCHEMA (EXACT SHAPE RETURNED BY TOOL):") |
| 167 | +} |
| 168 | + |
| 169 | +func TestExtractJSON(t *testing.T) { |
| 170 | + tests := []struct { |
| 171 | + name string |
| 172 | + input string |
| 173 | + expected string |
| 174 | + }{ |
| 175 | + {"pure json", `{"key": "value"}`, `{"key": "value"}`}, |
| 176 | + {"json with markdown", "```json\n{\"key\": \"value\"}\n```", `{"key": "value"}`}, |
| 177 | + {"json with markdown no lang", "```\n{\"key\": \"value\"}\n```", `{"key": "value"}`}, |
| 178 | + {"json with trailing text", `{"key": "value"} | some other text`, `{"key": "value"}`}, |
| 179 | + {"nested json", `{"key": {"nested_key": "nested_value"}}`, `{"key": {"nested_key": "nested_value"}}`}, |
| 180 | + {"text before json", `Here is the JSON: {"key": "value"}`, `{"key": "value"}`}, |
| 181 | + {"empty string", "", ""}, |
| 182 | + {"not a json", "just a string", ""}, |
| 183 | + {"incomplete json", `{"key":`, ""}, |
| 184 | + {"json with escaped quotes", `{"key": "value with \"quotes\""}`, `{"key": "value with \"quotes\""}`}, |
| 185 | + } |
| 186 | + |
| 187 | + for _, tc := range tests { |
| 188 | + t.Run(tc.name, func(t *testing.T) { |
| 189 | + assert.Equal(t, tc.expected, extractJSON(tc.input)) |
| 190 | + }) |
| 191 | + } |
| 192 | +} |
| 193 | + |
| 194 | +func TestIsValidSnippet(t *testing.T) { |
| 195 | + tests := []struct { |
| 196 | + name string |
| 197 | + code string |
| 198 | + expected bool |
| 199 | + }{ |
| 200 | + { |
| 201 | + name: "valid snippet", |
| 202 | + code: `__out, err := codemode.CallTool("test", nil)`, |
| 203 | + expected: true, |
| 204 | + }, |
| 205 | + { |
| 206 | + name: "valid snippet with assignment", |
| 207 | + code: `__out = "hello"`, |
| 208 | + expected: true, |
| 209 | + }, |
| 210 | + { |
| 211 | + name: "invalid due to map[value:]", |
| 212 | + code: `__out = map[value:"hello"]`, |
| 213 | + expected: false, |
| 214 | + }, |
| 215 | + { |
| 216 | + name: "invalid due to missing __out", |
| 217 | + code: `result, err := codemode.CallTool("test", nil)`, |
| 218 | + expected: false, |
| 219 | + }, |
| 220 | + { |
| 221 | + name: "empty code", |
| 222 | + code: "", |
| 223 | + expected: false, |
| 224 | + }, |
| 225 | + } |
| 226 | + |
| 227 | + for _, tc := range tests { |
| 228 | + t.Run(tc.name, func(t *testing.T) { |
| 229 | + assert.Equal(t, tc.expected, isValidSnippet(tc.code)) |
| 230 | + }) |
| 231 | + } |
| 232 | +} |
| 233 | + |
| 234 | +func TestCallTool_NoToolsNeeded(t *testing.T) { |
| 235 | + ctx := context.Background() |
| 236 | + mock := &mockModel{ |
| 237 | + GenerateFunc: func(ctx context.Context, prompt string) (any, error) { |
| 238 | + // This is for decideIfToolsNeeded |
| 239 | + return `{"needs": false}`, nil |
| 240 | + }, |
| 241 | + } |
| 242 | + cm := &CodeModeUTCP{model: mock} |
| 243 | + |
| 244 | + needed, result, err := cm.CallTool(ctx, "a prompt that doesn't need tools") |
| 245 | + |
| 246 | + require.NoError(t, err) |
| 247 | + assert.False(t, needed) |
| 248 | + assert.Equal(t, "", result) |
| 249 | +} |
| 250 | + |
| 251 | +func TestCallTool_ToolsNeededAndExecuted(t *testing.T) { |
| 252 | + ctx := context.Background() |
| 253 | + |
| 254 | + // 1. Mock LLM responses for each step of the orchestration |
| 255 | + mock := &mockModel{ |
| 256 | + GenerateFunc: func(ctx context.Context, prompt string) (any, error) { |
| 257 | + switch { |
| 258 | + case strings.Contains(prompt, "Decide if the following user query requires using ANY UTCP tools"): |
| 259 | + return `{"needs": true}`, nil |
| 260 | + case strings.Contains(prompt, "Select ALL UTCP tools that match the user's intent"): |
| 261 | + return `{"tools": ["test.tool"]}`, nil |
| 262 | + case strings.Contains(prompt, "Generate a Go snippet"): |
| 263 | + return `{"code": "__out = \"success\""}`, nil |
| 264 | + default: |
| 265 | + return nil, fmt.Errorf("unexpected prompt: %s", prompt) |
| 266 | + } |
| 267 | + }, |
| 268 | + } |
| 269 | + |
| 270 | + // 2. Create a CodeModeUTCP instance with the mock model and a mock Execute function |
| 271 | + cm := &CodeModeUTCP{ |
| 272 | + model: mock, |
| 273 | + // We override the Execute method for this test to avoid using the real interpreter. |
| 274 | + // This is a common testing pattern, but in a real-world scenario, |
| 275 | + // using an interface for the executor would be a cleaner approach. |
| 276 | + executeFunc: func(ctx context.Context, args CodeModeArgs) (CodeModeResult, error) { |
| 277 | + require.Equal(t, `__out = "success"`, args.Code, "Code passed to Execute should match the generated snippet") |
| 278 | + return CodeModeResult{Value: "execution result"}, nil |
| 279 | + }, |
| 280 | + } |
| 281 | + |
| 282 | + // 3. Call the function and assert the results |
| 283 | + needed, result, err := cm.CallTool(ctx, "a prompt that needs tools") |
| 284 | + require.NoError(t, err) |
| 285 | + assert.True(t, needed, "Should indicate that tools were needed") |
| 286 | + assert.Equal(t, "execution result", result.(CodeModeResult).Value, "Should return the result from the mocked Execute function") |
| 287 | +} |
0 commit comments