Skip to content

Commit e203ecd

Browse files
committed
#153: Fixed ollama broken tests
1 parent 1d82943 commit e203ecd

File tree

4 files changed

+38
-28
lines changed

4 files changed

+38
-28
lines changed

pkg/providers/ollama/client_test.go

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,19 @@ func TestOllamaClient_ChatRequest_Non200Response(t *testing.T) {
7777

7878
defer mockServer.Close()
7979

80-
// Create a new client with the mock server URL
81-
client := &Client{
82-
httpClient: http.DefaultClient,
83-
chatURL: mockServer.URL,
84-
config: DefaultConfig(),
85-
telemetry: telemetry.NewTelemetryMock(),
86-
}
80+
providerCfg := DefaultConfig()
81+
clientCfg := clients.DefaultClientConfig()
82+
providerCfg.BaseURL = mockServer.URL
83+
84+
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
85+
require.NoError(t, err)
8786

8887
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
8988
Role: "user",
9089
Content: "What's the capital of the United Kingdom?",
9190
}}}
9291

93-
_, err := client.Chat(context.Background(), &chatParams)
92+
_, err = client.Chat(context.Background(), &chatParams)
9493

9594
require.Error(t, err)
9695
require.Contains(t, err.Error(), "provider is not available")
@@ -99,19 +98,28 @@ func TestOllamaClient_ChatRequest_Non200Response(t *testing.T) {
9998
func TestOllamaClient_ChatRequest_SuccessfulResponse(t *testing.T) {
10099
// Create a mock HTTP server that returns an OK status code and a sample response
101100
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
101+
chatResponse, err := os.ReadFile(filepath.Clean("./testdata/chat.success.json"))
102+
if err != nil {
103+
t.Errorf("error reading cohere chat mock response: %v", err)
104+
}
105+
102106
w.WriteHeader(http.StatusOK)
103-
_, _ = w.Write([]byte(`{"response": "OK"}`))
107+
w.Header().Set("Content-Type", "application/json")
108+
109+
_, err = w.Write(chatResponse)
110+
if err != nil {
111+
t.Errorf("error on sending chat response: %v", err)
112+
}
104113
}))
105114

106115
defer mockServer.Close()
107116

108-
// Create a new client with the mock server URL
109-
client := &Client{
110-
httpClient: http.DefaultClient,
111-
chatURL: mockServer.URL,
112-
config: DefaultConfig(),
113-
telemetry: telemetry.NewTelemetryMock(),
114-
}
117+
providerCfg := DefaultConfig()
118+
clientCfg := clients.DefaultClientConfig()
119+
providerCfg.BaseURL = mockServer.URL
120+
121+
client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock())
122+
require.NoError(t, err)
115123

116124
chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{
117125
Role: "user",
@@ -122,5 +130,6 @@ func TestOllamaClient_ChatRequest_SuccessfulResponse(t *testing.T) {
122130

123131
require.NoError(t, err)
124132
require.NotNil(t, response)
125-
require.Equal(t, "", response.ModelResponse.Message.Role)
133+
require.Equal(t, "assistant", response.ModelResponse.Message.Role)
134+
require.Equal(t, "London", response.ModelResponse.Message.Content)
126135
}

pkg/providers/ollama/testdata/chat.success.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"created_at": "2023-12-12T14:13:43.416799Z",
44
"message": {
55
"role": "assistant",
6-
"content": "Hello! How are you today?"
6+
"content": "London"
77
},
88
"done": true,
99
"total_duration": 5191566416,
@@ -12,4 +12,4 @@
1212
"prompt_eval_duration": 383809000,
1313
"eval_count": 298,
1414
"eval_duration": 4799921000
15-
}
15+
}

pkg/providers/testing/lang.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,16 @@ type ProviderMock struct {
103103
func NewProviderMock(modelName *string, responses []RespMock) *ProviderMock {
104104
return &ProviderMock{
105105
idx: 0,
106-
modelName: modelName,
107106
chatResps: &responses,
108107
supportStreaming: false,
108+
modelName: modelName,
109109
}
110110
}
111111

112-
func NewStreamProviderMock(chatStreams []RespStreamMock) *ProviderMock {
112+
func NewStreamProviderMock(modelName *string, chatStreams []RespStreamMock) *ProviderMock {
113113
return &ProviderMock{
114114
idx: 0,
115+
modelName: modelName,
115116
chatStreams: &chatStreams,
116117
supportStreaming: true,
117118
}
@@ -156,7 +157,7 @@ func (c *ProviderMock) Provider() string {
156157
}
157158

158159
func (c *ProviderMock) ModelName() string {
159-
if c.modelName != nil {
160+
if c.modelName == nil {
160161
return "model_mock"
161162
}
162163

pkg/routers/router_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ func TestLangRouter_ChatStream(t *testing.T) {
262262
langModels := []*providers.LanguageModel{
263263
providers.NewLangModel(
264264
"first",
265-
ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
265+
ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{
266266
ptesting.NewRespStreamMock(&[]ptesting.RespMock{
267267
{Msg: "Bill"},
268268
{Msg: "Gates"},
@@ -277,7 +277,7 @@ func TestLangRouter_ChatStream(t *testing.T) {
277277
),
278278
providers.NewLangModel(
279279
"second",
280-
ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
280+
ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{
281281
ptesting.NewRespStreamMock(&[]ptesting.RespMock{
282282
{Msg: "Knock"},
283283
{Msg: "Knock"},
@@ -338,14 +338,14 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) {
338338
langModels := []*providers.LanguageModel{
339339
providers.NewLangModel(
340340
"first",
341-
ptesting.NewStreamProviderMock(nil),
341+
ptesting.NewStreamProviderMock(nil, nil),
342342
budget,
343343
*latConfig,
344344
1,
345345
),
346346
providers.NewLangModel(
347347
"second",
348-
ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
348+
ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{
349349
ptesting.NewRespStreamMock(
350350
&[]ptesting.RespMock{
351351
{Msg: "Knock"},
@@ -408,7 +408,7 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) {
408408
langModels := []*providers.LanguageModel{
409409
providers.NewLangModel(
410410
"first",
411-
ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
411+
ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{
412412
ptesting.NewRespStreamMock(&[]ptesting.RespMock{
413413
{Err: clients.ErrProviderUnavailable},
414414
}),
@@ -419,7 +419,7 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) {
419419
),
420420
providers.NewLangModel(
421421
"second",
422-
ptesting.NewStreamProviderMock([]ptesting.RespStreamMock{
422+
ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{
423423
ptesting.NewRespStreamMock(&[]ptesting.RespMock{
424424
{Err: clients.ErrProviderUnavailable},
425425
}),

0 commit comments

Comments
 (0)