Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add semantic similarity score to TestEvalLLMs #37

Merged
merged 2 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
run:
timeout: 5m
141 changes: 140 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ Does your company depend on this project? [Contact me at [email protected]](mailt

## Usage

This test will only run with `go test -run TestEval ./...` and otherwise be skipped:
Evals will only run with `go test -run TestEval ./...` and otherwise be skipped.

### Simple example

Eval a mocked LLM, construct a sample, score it with a lexical similarity scorer, and log the result.

```go
package examples_test
Expand Down Expand Up @@ -66,3 +70,138 @@ func (l *powerfulLLM) Prompt(request string) string {
return l.response
}
```

### Advanced example

This eval uses real LLMs (OpenAI GPT-4o mini, Google Gemini 1.5 Flash, Anthropic Claude 3.5 Haiku)
and compares the response to an expected response using both lexical similarity (with Levenshtein distance)
and semantic similarity (with an OpenAI embedding model and cosine similarity comparison).

```go
package examples_test

import (
"context"
"errors"
"fmt"
"strings"
"testing"

"github.com/anthropics/anthropic-sdk-go"
"github.com/google/generative-ai-go/genai"
"github.com/openai/openai-go"
"github.com/openai/openai-go/shared"
"maragu.dev/env"

"maragu.dev/llm"
"maragu.dev/llm/eval"
)

// TestEvalLLMs evaluates different LLMs with the same prompts.
func TestEvalLLMs(t *testing.T) {
	_ = env.Load("../../.env.test.local")

	cases := []struct {
		name     string
		prompt   func(prompt string) string
		expected string
	}{
		{
			name:     "gpt-4o-mini",
			prompt:   gpt4oMini,
			expected: "Hello! How can I assist you today?",
		},
		{
			name:     "gemini-1.5-flash",
			prompt:   gemini15Flash,
			expected: "Hi there! How can I help you today?",
		},
		{
			name:     "claude-3.5-haiku",
			prompt:   claude35Haiku,
			expected: "Hello! How are you doing today? Is there anything I can help you with?",
		},
	}

	for _, c := range cases {
		eval.Run(t, c.name, func(e *eval.E) {
			// Prompt the model once and build a sample for scoring.
			input := "Hi!"
			sample := eval.Sample{
				Input:    input,
				Output:   c.prompt(input),
				Expected: c.expected,
			}

			// Score the same sample with both scorers, logging each result.
			lexical := e.Score(sample, eval.LexicalSimilarityScorer(eval.LevenshteinDistance))
			e.Log(sample, lexical)

			semantic := e.Score(sample, eval.SemanticSimilarityScorer(&embeddingGetter{}, eval.CosineSimilarity))
			e.Log(sample, semantic)
		})
	}
}

func gpt4oMini(prompt string) string {
client := llm.NewOpenAIClient(llm.NewOpenAIClientOptions{Key: env.GetStringOrDefault("OPENAI_KEY", "")})
res, err := client.Client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{
openai.UserMessage(prompt),
}),
Model: openai.F(openai.ChatModelGPT4oMini),
Temperature: openai.F(0.0),
})
if err != nil {
panic(err)
}
return res.Choices[0].Message.Content
}

func gemini15Flash(prompt string) string {
client := llm.NewGoogleClient(llm.NewGoogleClientOptions{Key: env.GetStringOrDefault("GOOGLE_KEY", "")})
model := client.Client.GenerativeModel("models/gemini-1.5-flash-latest")
var temperature float32 = 0
model.Temperature = &temperature
res, err := model.GenerateContent(context.Background(), genai.Text(prompt))
if err != nil {
panic(err)
}
return strings.TrimSpace(fmt.Sprint(res.Candidates[0].Content.Parts[0]))
}

// claude35Haiku prompts Anthropic's Claude 3.5 Haiku model and returns its text response.
// It panics on request errors; this is example code, not production error handling.
func claude35Haiku(prompt string) string {
	client := llm.NewAnthropicClient(llm.NewAnthropicClientOptions{Key: env.GetStringOrDefault("ANTHROPIC_KEY", "")})
	res, err := client.Client.Messages.New(context.Background(), anthropic.MessageNewParams{
		Messages: anthropic.F([]anthropic.MessageParam{
			anthropic.NewUserMessage(anthropic.NewTextBlock(prompt)),
		}),
		Model:     anthropic.F(anthropic.ModelClaude3_5HaikuLatest),
		MaxTokens: anthropic.F(int64(1024)),
		// Zero temperature keeps responses as stable as possible across eval runs.
		Temperature: anthropic.F(0.0),
	})
	if err != nil {
		panic(err)
	}
	// Text is already a string, so the previous fmt.Sprint wrapper was redundant.
	return res.Content[0].Text
}

type embeddingGetter struct{}

func (e *embeddingGetter) GetEmbedding(v string) ([]float64, error) {
client := llm.NewOpenAIClient(llm.NewOpenAIClientOptions{Key: env.GetStringOrDefault("OPENAI_KEY", "")})
res, err := client.Client.Embeddings.New(context.Background(), openai.EmbeddingNewParams{
Input: openai.F[openai.EmbeddingNewParamsInputUnion](shared.UnionString(v)),
Model: openai.F(openai.EmbeddingModelTextEmbedding3Small),
EncodingFormat: openai.F(openai.EmbeddingNewParamsEncodingFormatFloat),
Dimensions: openai.F(int64(128)),
})
if err != nil {
return nil, err
}
if len(res.Data) == 0 {
return nil, errors.New("no embeddings returned")
}
return res.Data[0].Embedding, nil
}
```
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/openai/openai-go v0.1.0-alpha.46
google.golang.org/api v0.217.0
maragu.dev/env v0.2.0
maragu.dev/evals v0.0.0-20250114114008-6c73fea1551c
maragu.dev/evals v0.0.0-20250121095818-455e49387b21
maragu.dev/is v0.2.0
)

Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ maragu.dev/env v0.2.0 h1:nQKitDEB65ArZsh6E7vxzodOqY9bxEVFdBg+tskS1ys=
maragu.dev/env v0.2.0/go.mod h1:t5CCbaEnjCM5mewiAVVzTS4N+oXTus2+SRnzKQbQVME=
maragu.dev/errors v0.3.0 h1:huI+n+ddMfVgQFD+cEqIPaozUlfz3TkfgpkssNip5G0=
maragu.dev/errors v0.3.0/go.mod h1:cygLiyNnq4ofF3whYscilo2ecUADCaUQXwvwFrMOhmM=
maragu.dev/evals v0.0.0-20250114114008-6c73fea1551c h1:huPj1S5RhqgpbBAd3aCLfdVie3ZsU8Du7kepL2ZtDUQ=
maragu.dev/evals v0.0.0-20250114114008-6c73fea1551c/go.mod h1:+2Y3dYZ6oANM+cL88kFxaPD1H7rq3FXOrI3NOeNKaZ8=
maragu.dev/evals v0.0.0-20250121095818-455e49387b21 h1:Eg2DvonBz4eOPIhN+/aL1BXQAlNla4o+aBY+03e/6mA=
maragu.dev/evals v0.0.0-20250121095818-455e49387b21/go.mod h1:uLfBl7/FhUJULS4PjmpMdNG+joMRYAxgMJbzGWhQhWE=
maragu.dev/is v0.2.0 h1:poeuVEA5GG3vrDpGmzo2KjWtIMZmqUyvGnOB0/pemig=
maragu.dev/is v0.2.0/go.mod h1:bviaM5S0fBshCw7wuumFGTju/izopZ/Yvq4g7Klc7y8=
maragu.dev/migrate v0.6.0 h1:gJLAIVaRh9z9sN55Q2sWwScpEH+JsT6N0L1DnzedXFE=
Expand Down
23 changes: 23 additions & 0 deletions internal/examples/hi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package examples_test

import (
"context"
"errors"
"fmt"
"strings"
"testing"

"github.com/anthropics/anthropic-sdk-go"
"github.com/google/generative-ai-go/genai"
"github.com/openai/openai-go"
"github.com/openai/openai-go/shared"
"maragu.dev/env"

"maragu.dev/llm"
Expand Down Expand Up @@ -53,7 +55,9 @@ func TestEvalLLMs(t *testing.T) {
}

result := e.Score(sample, eval.LexicalSimilarityScorer(eval.LevenshteinDistance))
e.Log(sample, result)

result = e.Score(sample, eval.SemanticSimilarityScorer(&embeddingGetter{}, eval.CosineSimilarity))
e.Log(sample, result)
})
}
Expand Down Expand Up @@ -101,3 +105,22 @@ func claude35Haiku(prompt string) string {
}
return fmt.Sprint(res.Content[0].Text)
}

// embeddingGetter fetches text embeddings from the OpenAI API; it is passed
// to eval.SemanticSimilarityScorer above as the embedding source.
type embeddingGetter struct{}

// GetEmbedding embeds v with OpenAI's text-embedding-3-small model,
// requesting 128 dimensions, and returns the embedding vector.
func (e *embeddingGetter) GetEmbedding(v string) ([]float64, error) {
client := llm.NewOpenAIClient(llm.NewOpenAIClientOptions{Key: env.GetStringOrDefault("OPENAI_KEY", "")})
res, err := client.Client.Embeddings.New(context.Background(), openai.EmbeddingNewParams{
Input: openai.F[openai.EmbeddingNewParamsInputUnion](shared.UnionString(v)),
Model: openai.F(openai.EmbeddingModelTextEmbedding3Small),
EncodingFormat: openai.F(openai.EmbeddingNewParamsEncodingFormatFloat),
Dimensions: openai.F(int64(128)),
})
if err != nil {
return nil, err
}
// Guard against an empty response before indexing into Data.
if len(res.Data) == 0 {
return nil, errors.New("no embeddings returned")
}
return res.Data[0].Embedding, nil
}
Loading