Skip to content

Commit 1b97b50

Browse files
committed
add orchestrator tests
1 parent 21c3e2e commit 1b97b50

File tree

2 files changed

+293
-0
lines changed

2 files changed

+293
-0
lines changed

src/plugins/codemode/codemode.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ type CodeModeUTCP struct {
4242
model interface {
4343
Generate(ctx context.Context, prompt string) (any, error)
4444
}
45+
// For testing purposes, to mock the Execute method.
46+
executeFunc func(ctx context.Context, args CodeModeArgs) (CodeModeResult, error)
4547
}
4648

4749
func NewCodeModeUTCP(client utcp.UtcpClientInterface, model interface {
@@ -243,6 +245,10 @@ func indent(s, prefix string) string {
243245
}
244246

245247
func (c *CodeModeUTCP) Execute(ctx context.Context, args CodeModeArgs) (CodeModeResult, error) {
248+
// Allow mocking for tests
249+
if c.executeFunc != nil {
250+
return c.executeFunc(ctx, args)
251+
}
246252

247253
i, stdout, stderr := newInterpreter()
248254

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
package codemode
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"strings"
9+
"testing"
10+
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
"github.com/universal-tool-calling-protocol/go-utcp/src/tools"
14+
)
15+
16+
// mockModel simulates the behavior of an LLM for testing purposes.
17+
type mockModel struct {
18+
GenerateFunc func(ctx context.Context, prompt string) (any, error)
19+
}
20+
21+
func (m *mockModel) Generate(ctx context.Context, prompt string) (any, error) {
22+
if m.GenerateFunc != nil {
23+
return m.GenerateFunc(ctx, prompt)
24+
}
25+
return nil, errors.New("GenerateFunc not implemented")
26+
}
27+
28+
func TestDecideIfToolsNeeded(t *testing.T) {
29+
ctx := context.Background()
30+
31+
tests := []struct {
32+
name string
33+
mockResponse any
34+
mockError error
35+
expectedNeeds bool
36+
expectedError bool
37+
responseIsJSON bool
38+
}{
39+
{
40+
name: "LLM decides tools are needed",
41+
mockResponse: `{"needs": true}`,
42+
expectedNeeds: true,
43+
expectedError: false,
44+
responseIsJSON: true,
45+
},
46+
{
47+
name: "LLM decides tools are not needed",
48+
mockResponse: `{"needs": false}`,
49+
expectedNeeds: false,
50+
expectedError: false,
51+
responseIsJSON: true,
52+
},
53+
{
54+
name: "LLM returns an error",
55+
mockError: errors.New("LLM error"),
56+
expectedNeeds: false,
57+
expectedError: true,
58+
},
59+
{
60+
name: "LLM returns invalid JSON",
61+
mockResponse: `{"needs": tru}`,
62+
expectedNeeds: false,
63+
expectedError: false,
64+
responseIsJSON: true,
65+
},
66+
{
67+
name: "LLM returns non-JSON string",
68+
mockResponse: "I don't know.",
69+
expectedNeeds: false,
70+
expectedError: false,
71+
responseIsJSON: false,
72+
},
73+
}
74+
75+
for _, tc := range tests {
76+
t.Run(tc.name, func(t *testing.T) {
77+
mock := &mockModel{
78+
GenerateFunc: func(ctx context.Context, prompt string) (any, error) {
79+
if tc.responseIsJSON {
80+
return tc.mockResponse, tc.mockError
81+
}
82+
return fmt.Sprintf("%v", tc.mockResponse), tc.mockError
83+
},
84+
}
85+
cm := CodeModeUTCP{model: mock}
86+
87+
needs, err := cm.decideIfToolsNeeded(ctx, "some query", "some tools")
88+
89+
if tc.expectedError {
90+
require.Error(t, err)
91+
} else {
92+
require.NoError(t, err)
93+
assert.Equal(t, tc.expectedNeeds, needs)
94+
}
95+
})
96+
}
97+
}
98+
99+
func TestSelectTools(t *testing.T) {
100+
ctx := context.Background()
101+
mock := &mockModel{
102+
GenerateFunc: func(ctx context.Context, prompt string) (any, error) {
103+
return `{"tools": ["tool1", "tool2"]}`, nil
104+
},
105+
}
106+
cm := &CodeModeUTCP{model: mock}
107+
108+
selected, err := cm.selectTools(ctx, "some query", "some tools")
109+
110+
require.NoError(t, err)
111+
assert.Equal(t, []string{"tool1", "tool2"}, selected)
112+
}
113+
114+
func TestGenerateSnippet(t *testing.T) {
115+
ctx := context.Background()
116+
mockResp := struct {
117+
Code string `json:"code"`
118+
Stream bool `json:"stream"`
119+
}{
120+
Code: `__out = "result"`,
121+
Stream: false,
122+
}
123+
respBytes, _ := json.Marshal(mockResp)
124+
125+
mock := &mockModel{
126+
GenerateFunc: func(ctx context.Context, prompt string) (any, error) {
127+
return string(respBytes), nil
128+
},
129+
}
130+
cm := &CodeModeUTCP{model: mock}
131+
132+
snippet, stream, err := cm.generateSnippet(ctx, "query", []string{"tool1"}, "specs")
133+
134+
require.NoError(t, err)
135+
assert.Equal(t, mockResp.Code, snippet)
136+
assert.Equal(t, mockResp.Stream, stream)
137+
}
138+
139+
func TestRenderUtcpToolsForPrompt(t *testing.T) {
140+
specs := []tools.Tool{
141+
{
142+
Name: "test.tool",
143+
Description: "A test tool.",
144+
Inputs: tools.ToolInputOutputSchema{
145+
Properties: map[string]any{
146+
"arg1": map[string]any{"type": "string"},
147+
},
148+
Required: []string{"arg1"},
149+
},
150+
Outputs: tools.ToolInputOutputSchema{
151+
Properties: map[string]any{
152+
"result": map[string]any{"type": "string"},
153+
},
154+
},
155+
},
156+
}
157+
158+
output := renderUtcpToolsForPrompt(specs)
159+
160+
assert.Contains(t, output, "TOOL: test.tool")
161+
assert.Contains(t, output, "DESCRIPTION: A test tool.")
162+
assert.Contains(t, output, "INPUT FIELDS (USE EXACTLY THESE KEYS):")
163+
assert.Contains(t, output, "- arg1: string")
164+
assert.Contains(t, output, "REQUIRED FIELDS:")
165+
assert.Contains(t, output, "FULL INPUT SCHEMA (JSON):")
166+
assert.Contains(t, output, "OUTPUT SCHEMA (EXACT SHAPE RETURNED BY TOOL):")
167+
}
168+
169+
func TestExtractJSON(t *testing.T) {
170+
tests := []struct {
171+
name string
172+
input string
173+
expected string
174+
}{
175+
{"pure json", `{"key": "value"}`, `{"key": "value"}`},
176+
{"json with markdown", "```json\n{\"key\": \"value\"}\n```", `{"key": "value"}`},
177+
{"json with markdown no lang", "```\n{\"key\": \"value\"}\n```", `{"key": "value"}`},
178+
{"json with trailing text", `{"key": "value"} | some other text`, `{"key": "value"}`},
179+
{"nested json", `{"key": {"nested_key": "nested_value"}}`, `{"key": {"nested_key": "nested_value"}}`},
180+
{"text before json", `Here is the JSON: {"key": "value"}`, `{"key": "value"}`},
181+
{"empty string", "", ""},
182+
{"not a json", "just a string", ""},
183+
{"incomplete json", `{"key":`, ""},
184+
{"json with escaped quotes", `{"key": "value with \"quotes\""}`, `{"key": "value with \"quotes\""}`},
185+
}
186+
187+
for _, tc := range tests {
188+
t.Run(tc.name, func(t *testing.T) {
189+
assert.Equal(t, tc.expected, extractJSON(tc.input))
190+
})
191+
}
192+
}
193+
194+
func TestIsValidSnippet(t *testing.T) {
195+
tests := []struct {
196+
name string
197+
code string
198+
expected bool
199+
}{
200+
{
201+
name: "valid snippet",
202+
code: `__out, err := codemode.CallTool("test", nil)`,
203+
expected: true,
204+
},
205+
{
206+
name: "valid snippet with assignment",
207+
code: `__out = "hello"`,
208+
expected: true,
209+
},
210+
{
211+
name: "invalid due to map[value:]",
212+
code: `__out = map[value:"hello"]`,
213+
expected: false,
214+
},
215+
{
216+
name: "invalid due to missing __out",
217+
code: `result, err := codemode.CallTool("test", nil)`,
218+
expected: false,
219+
},
220+
{
221+
name: "empty code",
222+
code: "",
223+
expected: false,
224+
},
225+
}
226+
227+
for _, tc := range tests {
228+
t.Run(tc.name, func(t *testing.T) {
229+
assert.Equal(t, tc.expected, isValidSnippet(tc.code))
230+
})
231+
}
232+
}
233+
234+
func TestCallTool_NoToolsNeeded(t *testing.T) {
235+
ctx := context.Background()
236+
mock := &mockModel{
237+
GenerateFunc: func(ctx context.Context, prompt string) (any, error) {
238+
// This is for decideIfToolsNeeded
239+
return `{"needs": false}`, nil
240+
},
241+
}
242+
cm := &CodeModeUTCP{model: mock}
243+
244+
needed, result, err := cm.CallTool(ctx, "a prompt that doesn't need tools")
245+
246+
require.NoError(t, err)
247+
assert.False(t, needed)
248+
assert.Equal(t, "", result)
249+
}
250+
251+
func TestCallTool_ToolsNeededAndExecuted(t *testing.T) {
252+
ctx := context.Background()
253+
254+
// 1. Mock LLM responses for each step of the orchestration
255+
mock := &mockModel{
256+
GenerateFunc: func(ctx context.Context, prompt string) (any, error) {
257+
switch {
258+
case strings.Contains(prompt, "Decide if the following user query requires using ANY UTCP tools"):
259+
return `{"needs": true}`, nil
260+
case strings.Contains(prompt, "Select ALL UTCP tools that match the user's intent"):
261+
return `{"tools": ["test.tool"]}`, nil
262+
case strings.Contains(prompt, "Generate a Go snippet"):
263+
return `{"code": "__out = \"success\""}`, nil
264+
default:
265+
return nil, fmt.Errorf("unexpected prompt: %s", prompt)
266+
}
267+
},
268+
}
269+
270+
// 2. Create a CodeModeUTCP instance with the mock model and a mock Execute function
271+
cm := &CodeModeUTCP{
272+
model: mock,
273+
// We override the Execute method for this test to avoid using the real interpreter.
274+
// This is a common testing pattern, but in a real-world scenario,
275+
// using an interface for the executor would be a cleaner approach.
276+
executeFunc: func(ctx context.Context, args CodeModeArgs) (CodeModeResult, error) {
277+
require.Equal(t, `__out = "success"`, args.Code, "Code passed to Execute should match the generated snippet")
278+
return CodeModeResult{Value: "execution result"}, nil
279+
},
280+
}
281+
282+
// 3. Call the function and assert the results
283+
needed, result, err := cm.CallTool(ctx, "a prompt that needs tools")
284+
require.NoError(t, err)
285+
assert.True(t, needed, "Should indicate that tools were needed")
286+
assert.Equal(t, "execution result", result.(CodeModeResult).Value, "Should return the result from the mocked Execute function")
287+
}

0 commit comments

Comments
 (0)