-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmock.go
More file actions
142 lines (122 loc) · 2.96 KB
/
mock.go
File metadata and controls
142 lines (122 loc) · 2.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
package iteragent
import (
"context"
"encoding/json"
"fmt"
"strings"
)
// MockProvider is a scripted provider for tests.
//
// If error is set, every Complete call fails with it. Otherwise each Complete
// call emits the next entry of toolCalls (one per call, in order). Once those
// are exhausted, it returns response on every call.
//
// Complete advances toolCallIndex without any locking, so a single
// MockProvider should not be shared across goroutines.
type MockProvider struct {
	model         string     // reported by Name as "mock(<model>)"
	response      string     // text returned once toolCalls are exhausted
	toolCalls     []ToolCall // scripted tool calls, replayed in order
	toolCallIndex int        // position of the next tool call to emit
	error         error      // when non-nil, returned by every Complete call
}
// NewMock returns a MockProvider with model name "mock" and no scripted tool
// calls. Every Complete call returns response.
func NewMock(response string) *MockProvider {
	p := new(MockProvider)
	p.model = "mock"
	p.response = response
	return p
}
// NewMockWithTools returns a MockProvider with model name "mock". It emits
// each element of toolCalls (one per Complete call, in order) before falling
// back to response. The toolCalls slice is stored as given, not copied.
func NewMockWithTools(response string, toolCalls []ToolCall) *MockProvider {
	p := NewMock(response)
	p.toolCalls = toolCalls
	return p
}
// NewMockWithError returns a MockProvider with model name "mock". Its
// Complete calls always fail with err and an empty result.
func NewMockWithError(err error) *MockProvider {
	p := NewMock("")
	p.error = err
	return p
}
// Name reports the provider name in the form "mock(<model>)".
func (p *MockProvider) Name() string {
	return "mock(" + p.model + ")"
}
// Complete implements Provider. It ignores ctx, messages and opts.
//
// If the provider was configured with an error, Complete returns that error.
// Otherwise each call emits the next scripted tool call as JSON inside a
// "```tool" fenced block. Once the tool calls are exhausted, every call
// returns the fixed response. The tool-call cursor is advanced without
// locking.
func (p *MockProvider) Complete(ctx context.Context, messages []Message, opts ...CompletionOptions) (string, error) {
	switch {
	case p.error != nil:
		return "", p.error
	case p.toolCallIndex >= len(p.toolCalls):
		return p.response, nil
	}
	next := p.toolCalls[p.toolCallIndex]
	p.toolCallIndex++
	return "```tool\n" + mustJson(next) + "\n```", nil
}
// CompleteStream implements Provider. It produces the same response as
// Complete. When onToken is non-nil, it delivers that response one
// space-separated word at a time so tests can observe incremental delivery.
// Concatenating all tokens reproduces the response exactly, and empty tokens
// are never emitted.
//
// ctx must be non-nil. It is checked before the response is produced, so a
// cancelled call does not consume a scripted tool call. It is checked again
// before each token. On cancellation, CompleteStream returns the text already
// delivered to onToken together with ctx.Err().
func (p *MockProvider) CompleteStream(ctx context.Context, messages []Message, opts CompletionOptions, onToken func(token string)) (string, error) {
	if err := ctx.Err(); err != nil {
		return "", err
	}
	full, err := p.Complete(ctx, messages, opts)
	if err != nil || onToken == nil {
		return full, err
	}
	words := strings.Split(full, " ")
	var sent strings.Builder
	for i, w := range words {
		select {
		case <-ctx.Done():
			return sent.String(), ctx.Err()
		default:
		}
		token := w
		if i < len(words)-1 {
			// Re-attach the separator Split removed so tokens rejoin losslessly.
			token += " "
		}
		if token == "" {
			// Only the last word of an empty or space-terminated response can be
			// empty; skip it rather than emit a spurious empty token.
			continue
		}
		onToken(token)
		sent.WriteString(token)
	}
	return full, nil
}
// mustJson renders v as text for embedding in a mock response.
//
// A plain string is returned unchanged. Any other value is JSON-encoded. If
// encoding fails, it falls back to the %+v formatting of v. Despite the
// "must" prefix, it never panics.
func mustJson(v interface{}) string {
	if s, ok := v.(string); ok {
		return s
	}
	encoded, err := json.Marshal(v)
	if err == nil {
		return string(encoded)
	}
	return fmt.Sprintf("%+v", v)
}
// MockProviderBuilder configures a MockProvider through chained setters, e.g.
// Mock().Text("done").WithTools(call).Build().
type MockProviderBuilder struct {
	mock *MockProvider
}

// Mock starts a builder for a MockProvider whose model name is "mock".
func Mock() *MockProviderBuilder {
	return &MockProviderBuilder{mock: &MockProvider{model: "mock"}}
}

// Text sets the response returned once all scripted tool calls are used up.
func (b *MockProviderBuilder) Text(text string) *MockProviderBuilder {
	b.mock.response = text
	return b
}

// Model sets the model name reported by Name.
func (b *MockProviderBuilder) Model(model string) *MockProviderBuilder {
	b.mock.model = model
	return b
}

// WithTools replaces (does not append to) the scripted tool calls, which are
// emitted one per Complete call in the order given.
func (b *MockProviderBuilder) WithTools(toolCalls ...ToolCall) *MockProviderBuilder {
	b.mock.toolCalls = toolCalls
	return b
}

// WithError makes every Complete call on the built provider fail with err.
func (b *MockProviderBuilder) WithError(err error) *MockProviderBuilder {
	b.mock.error = err
	return b
}

// Build returns a MockProvider configured from the builder's current state.
//
// Each call returns a fresh provider with its own tool-call cursor. As a
// result, providers built from the same builder do not share state, and later
// setter calls do not affect providers already built. The tool-call slice
// itself is shared, not deep-copied.
func (b *MockProviderBuilder) Build() Provider {
	snapshot := *b.mock
	return &snapshot
}
// NewMockStream returns a MockProvider whose model name is "mock-stream" and
// which always answers with response. Apart from the name reported by Name,
// it behaves exactly like NewMock. CompleteStream is available on every
// MockProvider, not just this one.
func NewMockStream(response string) *MockProvider {
	p := NewMock(response)
	p.model = "mock-stream"
	return p
}
// NewMockStreamWithTools returns a MockProvider whose model name is
// "mock-stream". It emits toolCalls one per Complete call, in order, before
// falling back to response. Apart from the name reported by Name, it behaves
// exactly like NewMockWithTools, and the slice is stored as given, not
// copied.
func NewMockStreamWithTools(response string, toolCalls []ToolCall) *MockProvider {
	p := NewMockWithTools(response, toolCalls)
	p.model = "mock-stream"
	return p
}