diff --git a/aisdk/ai/README.md b/aisdk/ai/README.md index c5d9a8c2..f038a300 100644 --- a/aisdk/ai/README.md +++ b/aisdk/ai/README.md @@ -91,7 +91,8 @@ import ( func main() { // Set up your model - model := openai.NewLanguageModel("gpt-4o") + provider := openai.NewProvider() + model := provider.NewLanguageModel("gpt-4o-mini") // Generate text response, err := ai.GenerateTextStr( diff --git a/aisdk/ai/ai.go b/aisdk/ai/ai.go index ff57dc2c..5c36bfcc 100644 --- a/aisdk/ai/ai.go +++ b/aisdk/ai/ai.go @@ -6,6 +6,19 @@ import ( "go.jetify.com/ai/api" ) +func EmbedMany[T any]( + ctx context.Context, model api.EmbeddingModel[T], values []T, opts ...EmbeddingOption[T], +) (api.EmbeddingResponse, error) { + config := buildEmbeddingConfig(model, opts) + return embed(ctx, values, config) +} + +func embed[T any]( + ctx context.Context, values []T, opts EmbeddingOptions[T], +) (api.EmbeddingResponse, error) { + return opts.Model.DoEmbed(ctx, values, opts.EmbeddingOptions) +} + // TODO: do we want to rename from GenerateText to Generate and from StreamText to Stream? // GenerateText uses a language model to generate a text response from a given prompt. diff --git a/aisdk/ai/api/embedding_model.go b/aisdk/ai/api/embedding_model.go index 2c88aa19..643fa5b8 100644 --- a/aisdk/ai/api/embedding_model.go +++ b/aisdk/ai/api/embedding_model.go @@ -38,7 +38,7 @@ type EmbeddingModel[T any] interface { // // Naming: "do" prefix to prevent accidental direct usage of the method // by the user. - DoEmbed(ctx context.Context, values []T, opts ...EmbeddingOption) EmbeddingResponse + DoEmbed(ctx context.Context, values []T, opts EmbeddingOptions) (EmbeddingResponse, error) } // EmbeddingResponse represents the response from generating embeddings. @@ -55,7 +55,8 @@ type EmbeddingResponse struct { // EmbeddingUsage represents token usage information. 
type EmbeddingUsage struct { - Tokens int + PromptTokens int64 + TotalTokens int64 } // EmbeddingRawResponse contains raw response information for debugging. diff --git a/aisdk/ai/api/embedding_options.go b/aisdk/ai/api/embedding_options.go index f21e6346..5dcb4f4e 100644 --- a/aisdk/ai/api/embedding_options.go +++ b/aisdk/ai/api/embedding_options.go @@ -5,17 +5,19 @@ import "net/http" // EmbeddingOption represent the options for generating embeddings. type EmbeddingOption func(*EmbeddingOptions) -// WithEmbeddingHeaders sets HTTP headers to be sent with the request. -// Only applicable for HTTP-based providers. -func WithEmbeddingHeaders(headers http.Header) EmbeddingOption { - return func(o *EmbeddingOptions) { - o.Headers = headers - } -} - // EmbeddingOptions represents the options for generating embeddings. type EmbeddingOptions struct { // Headers are additional HTTP headers to be sent with the request. // Only applicable for HTTP-based providers. Headers http.Header + + // BaseURL is the base URL for the API endpoint. + BaseURL *string + + // ProviderMetadata contains additional provider-specific metadata. + // The metadata is passed through to the provider from the AI SDK and enables + // provider-specific functionality that can be fully encapsulated in the provider. 
+ ProviderMetadata *ProviderMetadata } + +func (o EmbeddingOptions) GetProviderMetadata() *ProviderMetadata { return o.ProviderMetadata } diff --git a/aisdk/ai/default.go b/aisdk/ai/default.go index 7fe60393..1b1195bd 100644 --- a/aisdk/ai/default.go +++ b/aisdk/ai/default.go @@ -15,7 +15,8 @@ type modelWrapper struct { var defaultLanguageModel atomic.Value func init() { - model := openai.NewLanguageModel(openai.ChatModelGPT5) + provider := openai.NewProvider() + model := provider.NewLanguageModel(openai.ChatModelGPT5) defaultLanguageModel.Store(&modelWrapper{model: model}) } diff --git a/aisdk/ai/default_test.go b/aisdk/ai/default_test.go index 4d1c108d..7cbbe594 100644 --- a/aisdk/ai/default_test.go +++ b/aisdk/ai/default_test.go @@ -11,7 +11,7 @@ import ( func TestDefaultLanguageModel(t *testing.T) { // Get current model and verify provider and model ID match expected values originalModel := DefaultLanguageModel() - assert.Equal(t, "openai", originalModel.ProviderName()) + assert.Equal(t, "openai.responses", originalModel.ProviderName()) assert.Equal(t, openai.ChatModelGPT5, originalModel.ModelID()) // Change model to different provider (Anthropic) @@ -26,6 +26,6 @@ func TestDefaultLanguageModel(t *testing.T) { SetDefaultLanguageModel(originalModel) restoredModel := DefaultLanguageModel() - assert.Equal(t, "openai", restoredModel.ProviderName()) + assert.Equal(t, "openai.responses", restoredModel.ProviderName()) assert.Equal(t, openai.ChatModelGPT5, restoredModel.ModelID()) } diff --git a/aisdk/ai/embedding_options.go b/aisdk/ai/embedding_options.go new file mode 100644 index 00000000..d44c75c6 --- /dev/null +++ b/aisdk/ai/embedding_options.go @@ -0,0 +1,63 @@ +package ai + +import ( + "net/http" + + "go.jetify.com/ai/api" +) + +// EmbeddingOptions bundles the model + per-call embedding options. +type EmbeddingOptions[T any] struct { + EmbeddingOptions api.EmbeddingOptions + Model api.EmbeddingModel[T] +} + +// EmbeddingOption mutates EmbeddingOptions. 
+type EmbeddingOption[T any] func(*EmbeddingOptions[T]) + +// WithEmbeddingHeaders sets extra HTTP headers for this embedding call. +// Only applies to HTTP-backed providers. +func WithEmbeddingHeaders[T any](headers http.Header) EmbeddingOption[T] { + return func(o *EmbeddingOptions[T]) { + o.EmbeddingOptions.Headers = headers + } +} + +// WithEmbeddingProviderMetadata sets provider-specific metadata for the embedding call. +func WithEmbeddingProviderMetadata[T any](provider string, metadata any) EmbeddingOption[T] { + return func(o *EmbeddingOptions[T]) { + if o.EmbeddingOptions.ProviderMetadata == nil { + o.EmbeddingOptions.ProviderMetadata = api.NewProviderMetadata(map[string]any{}) + } + o.EmbeddingOptions.ProviderMetadata.Set(provider, metadata) + } +} + +// WithEmbeddingBaseURL sets the base URL for the embedding API endpoint. +func WithEmbeddingBaseURL[T any](baseURL string) EmbeddingOption[T] { + url := baseURL + return func(o *EmbeddingOptions[T]) { + o.EmbeddingOptions.BaseURL = &url + } +} + +// WithEmbeddingEmbeddingOptions sets the entire api.EmbeddingOptions struct. +func WithEmbeddingEmbeddingOptions[T any](embeddingOptions api.EmbeddingOptions) EmbeddingOption[T] { + return func(o *EmbeddingOptions[T]) { + o.EmbeddingOptions = embeddingOptions + } +} + +// buildEmbeddingConfig combines multiple options into a single EmbeddingOptions. 
+func buildEmbeddingConfig[T any]( + model api.EmbeddingModel[T], opts []EmbeddingOption[T], +) EmbeddingOptions[T] { + config := EmbeddingOptions[T]{ + EmbeddingOptions: api.EmbeddingOptions{}, + Model: model, + } + for _, opt := range opts { + opt(&config) + } + return config +} diff --git a/aisdk/ai/examples/basic/simple-embedding/main.go b/aisdk/ai/examples/basic/simple-embedding/main.go new file mode 100644 index 00000000..ceed5691 --- /dev/null +++ b/aisdk/ai/examples/basic/simple-embedding/main.go @@ -0,0 +1,49 @@ +package main + +import ( + "context" + "log" + + "github.com/k0kubun/pp/v3" + "go.jetify.com/ai" + "go.jetify.com/ai/api" + "go.jetify.com/ai/provider/openai" +) + +func example() error { + // Initialize the OpenAI provider + provider := openai.NewProvider() + + // Create a model + model := provider.NewEmbeddingModel("text-embedding-3-small") + + // Generate embeddings + response, err := ai.EmbedMany( + context.Background(), + model, + []string{ + "Artificial intelligence is the simulation of human intelligence in machines.", + "Machine learning is a subset of AI that enables systems to learn from data.", + }, + ) + if err != nil { + return err + } + + // Print the response: + printResponse(response) + + return nil +} + +func printResponse(response api.EmbeddingResponse) { + printer := pp.New() + printer.SetOmitEmpty(true) + printer.Print(response) +} + +func main() { + if err := example(); err != nil { + log.Fatal(err) + } +} diff --git a/aisdk/ai/examples/basic/simple-text/main.go b/aisdk/ai/examples/basic/simple-text/main.go index 5dd2f341..93920421 --- a/aisdk/ai/examples/basic/simple-text/main.go +++ b/aisdk/ai/examples/basic/simple-text/main.go @@ -11,8 +11,11 @@ import ( ) func example() error { + // Initialize the OpenAI provider + provider := openai.NewProvider() + // Create a model - model := openai.NewLanguageModel("gpt-4o-mini") + model := provider.NewLanguageModel("gpt-4o-mini") // Generate text response, err := ai.GenerateTextStr(
diff --git a/aisdk/ai/examples/basic/streaming-text/main.go b/aisdk/ai/examples/basic/streaming-text/main.go index ea55e007..892b44b9 100644 --- a/aisdk/ai/examples/basic/streaming-text/main.go +++ b/aisdk/ai/examples/basic/streaming-text/main.go @@ -12,7 +12,8 @@ import ( func example() error { // Create a model - model := openai.NewLanguageModel("gpt-4o-mini") + provider := openai.NewProvider() + model := provider.NewLanguageModel("gpt-4o-mini") // Stream text response, err := ai.StreamTextStr( diff --git a/aisdk/ai/provider/openai/embedding.go b/aisdk/ai/provider/openai/embedding.go new file mode 100644 index 00000000..fd023229 --- /dev/null +++ b/aisdk/ai/provider/openai/embedding.go @@ -0,0 +1,77 @@ +package openai + +import ( + "context" + "fmt" + + "go.jetify.com/ai/api" + "go.jetify.com/ai/provider/openai/internal/codec" +) + +// EmbeddingModel represents an OpenAI embedding model. +type EmbeddingModel struct { + modelID string + pc ProviderConfig +} + +var _ api.EmbeddingModel[string] = &EmbeddingModel{} + +// NewEmbeddingModel creates a new OpenAI embedding model. +func (p *Provider) NewEmbeddingModel(modelID string) *EmbeddingModel { + // Create model with provider's client + model := &EmbeddingModel{ + modelID: modelID, + pc: ProviderConfig{ + providerName: fmt.Sprintf("%s.embedding", p.name), + client: p.client, + }, + } + + return model +} + +func (m *EmbeddingModel) ProviderName() string { + return m.pc.providerName +} + +func (m *EmbeddingModel) SpecificationVersion() string { + return "v2" +} + +func (m *EmbeddingModel) ModelID() string { + return m.modelID +} + +// SupportsParallelCalls implements api.EmbeddingModel. +func (m *EmbeddingModel) SupportsParallelCalls() bool { + return true +} + +// MaxEmbeddingsPerCall implements api.EmbeddingModel. +func (m *EmbeddingModel) MaxEmbeddingsPerCall() *int { + max := 2048 + return &max +} + +// DoEmbed implements api.EmbeddingModel. 
+func (m *EmbeddingModel) DoEmbed( + ctx context.Context, + values []string, + opts api.EmbeddingOptions, +) (api.EmbeddingResponse, error) { + embeddingParams, openaiOpts, _, err := codec.EncodeEmbedding( + m.modelID, + values, + opts, + ) + if err != nil { + return api.EmbeddingResponse{}, err + } + + resp, err := m.pc.client.Embeddings.New(ctx, embeddingParams, openaiOpts...) + if err != nil { + return api.EmbeddingResponse{}, err + } + + return codec.DecodeEmbedding(resp) +} diff --git a/aisdk/ai/provider/openai/embedding_test.go b/aisdk/ai/provider/openai/embedding_test.go new file mode 100644 index 00000000..e63f68a5 --- /dev/null +++ b/aisdk/ai/provider/openai/embedding_test.go @@ -0,0 +1,188 @@ +package openai + +import ( + "net/http" + "testing" + + "github.com/openai/openai-go/v2" + "github.com/openai/openai-go/v2/option" + "github.com/stretchr/testify/require" + "go.jetify.com/ai/api" + "go.jetify.com/pkg/httpmock" +) + +func TestDoEmbed(t *testing.T) { + standardInput := []string{"Hello", "World"} + + // Standard OpenAI response body used across tests + standardResponseBody := `{ + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [0.0023064255, -0.009327292, 0.015797527], + "index": 0 + }, + { + "object": "embedding", + "embedding": [0.0072664247, -0.008545238, 0.017125098], + "index": 1 + } + ], + "model": "text-embedding-ada-002", + "usage": { + "prompt_tokens": 2, + "total_tokens": 2 + } + }` + + standardExchange := []httpmock.Exchange{ + { + Request: httpmock.Request{ + Method: http.MethodPost, + Path: "/embeddings", + Body: `{ + "input": ["Hello", "World"], + "model": "text-embedding-ada-002", + "encoding_format": "float" + }`, + }, + Response: httpmock.Response{ + StatusCode: http.StatusOK, + Body: standardResponseBody, + }, + }, + } + + tests := []struct { + name string + modelID string + input []string + options api.EmbeddingOptions + exchanges []httpmock.Exchange + wantErr bool + expectedResp api.EmbeddingResponse + 
skip bool + }{ + { + name: "successful embedding", + modelID: "text-embedding-ada-002", + input: standardInput, + exchanges: standardExchange, + expectedResp: api.EmbeddingResponse{ + Embeddings: []api.Embedding{ + {0.0023064255, -0.009327292, 0.015797527}, + {0.0072664247, -0.008545238, 0.017125098}, + }, + Usage: &api.EmbeddingUsage{ + PromptTokens: 2, + TotalTokens: 2, + }, + RawResponse: &api.EmbeddingRawResponse{ + Headers: http.Header{}, + }, + }, + }, + { + name: "with custom headers", + modelID: "text-embedding-ada-002", + input: standardInput, + options: api.EmbeddingOptions{ + Headers: http.Header{ + "Custom-Header": []string{"test-value"}, + }, + }, + exchanges: []httpmock.Exchange{ + { + Request: httpmock.Request{ + Method: http.MethodPost, + Path: "/embeddings", + Headers: map[string]string{ + "Custom-Header": "test-value", + }, + Body: `{ + "input": ["Hello", "World"], + "model": "text-embedding-ada-002", + "encoding_format": "float" + }`, + }, + Response: httpmock.Response{ + StatusCode: http.StatusOK, + Body: standardResponseBody, + }, + }, + }, + expectedResp: api.EmbeddingResponse{ + Embeddings: []api.Embedding{ + {0.0023064255, -0.009327292, 0.015797527}, + {0.0072664247, -0.008545238, 0.017125098}, + }, + Usage: &api.EmbeddingUsage{ + PromptTokens: 2, + TotalTokens: 2, + }, + RawResponse: &api.EmbeddingRawResponse{ + Headers: http.Header{}, + }, + }, + }, + } + + runDoEmbedTests(t, tests) +} + +func runDoEmbedTests(t *testing.T, tests []struct { + name string + modelID string + input []string + options api.EmbeddingOptions + exchanges []httpmock.Exchange + wantErr bool + expectedResp api.EmbeddingResponse + skip bool +}, +) { + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + if testCase.skip { + t.Skipf("Skipping test: %s", testCase.name) + } + + server := httpmock.NewServer(t, testCase.exchanges) + defer server.Close() + + // Set up client options for the OpenAI client + clientOptions := []option.RequestOption{ + 
option.WithBaseURL(server.BaseURL()), + option.WithAPIKey("test-key"), + option.WithMaxRetries(0), // Disable retries + } + + // Create client with options + client := openai.NewClient(clientOptions...) + + // Use custom model ID + modelID := testCase.modelID + + // instantiate the provider with the mocked client + provider := NewProvider(WithClient(client)) + + // Create model with the provider + model := provider.NewEmbeddingModel(modelID) + + // Build embedding options + resp, err := model.DoEmbed(t.Context(), testCase.input, testCase.options) + + if testCase.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.NotNil(t, resp) + + // Compare response fields + require.Equal(t, testCase.expectedResp, resp) + }) + } +} diff --git a/aisdk/ai/provider/openai/internal/codec/decode_embedding.go b/aisdk/ai/provider/openai/internal/codec/decode_embedding.go new file mode 100644 index 00000000..5c66586c --- /dev/null +++ b/aisdk/ai/provider/openai/internal/codec/decode_embedding.go @@ -0,0 +1,35 @@ +package codec + +import ( + "net/http" + + "github.com/openai/openai-go/v2" + "go.jetify.com/ai/api" +) + +// DecodeEmbedding maps the OpenAI embedding API response to the unified api.EmbeddingResponse. 
+func DecodeEmbedding(resp *openai.CreateEmbeddingResponse) (api.EmbeddingResponse, error) { + if resp == nil { + return api.EmbeddingResponse{}, api.NewEmptyResponseBodyError("response from OpenAI embeddings API is nil") + } + + embs := make([]api.Embedding, len(resp.Data)) + for i, d := range resp.Data { + vec := make([]float64, len(d.Embedding)) + copy(vec, d.Embedding) + embs[i] = vec + } + + usage := &api.EmbeddingUsage{ + PromptTokens: resp.Usage.PromptTokens, + TotalTokens: resp.Usage.TotalTokens, + } + + return api.EmbeddingResponse{ + Embeddings: embs, + Usage: usage, + RawResponse: &api.EmbeddingRawResponse{ + Headers: http.Header{}, + }, + }, nil +} diff --git a/aisdk/ai/provider/openai/internal/codec/decode_embedding_test.go b/aisdk/ai/provider/openai/internal/codec/decode_embedding_test.go new file mode 100644 index 00000000..a661bdd3 --- /dev/null +++ b/aisdk/ai/provider/openai/internal/codec/decode_embedding_test.go @@ -0,0 +1,113 @@ +package codec + +import ( + "net/http" + "testing" + + "github.com/openai/openai-go/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.jetify.com/ai/api" +) + +func TestDecodeEmbedding(t *testing.T) { + type tc struct { + name string + in *openai.CreateEmbeddingResponse + want api.EmbeddingResponse + wantErrSub string + } + + tests := []tc{ + { + name: "nil response -> error", + in: nil, + wantErrSub: "response from OpenAI embeddings API is nil", + }, + { + name: "maps data and usage; copies vectors; empty headers", + in: &openai.CreateEmbeddingResponse{ + Data: []openai.Embedding{ + {Embedding: []float64{0.1, 0.2, 0.3}}, + {Embedding: []float64{0.4, 0.5}}, + }, + Usage: openai.CreateEmbeddingResponseUsage{ + PromptTokens: 27, + TotalTokens: 27, + }, + }, + want: api.EmbeddingResponse{ + Embeddings: []api.Embedding{ + []float64{0.1, 0.2, 0.3}, + []float64{0.4, 0.5}, + }, + Usage: &api.EmbeddingUsage{ + PromptTokens: 27, + TotalTokens: 27, + }, + RawResponse: 
&api.EmbeddingRawResponse{ + Headers: http.Header{}, + }, + }, + }, + { + name: "empty data yields empty embeddings and zero usage", + in: &openai.CreateEmbeddingResponse{ + Data: []openai.Embedding{}, + Usage: openai.CreateEmbeddingResponseUsage{ + PromptTokens: 0, + TotalTokens: 0, + }, + }, + want: api.EmbeddingResponse{ + Embeddings: []api.Embedding{}, + Usage: &api.EmbeddingUsage{ + PromptTokens: 0, + TotalTokens: 0, + }, + RawResponse: &api.EmbeddingRawResponse{ + Headers: http.Header{}, + }, + }, + }, + { + name: "single long vector", + in: &openai.CreateEmbeddingResponse{ + Data: []openai.Embedding{ + {Embedding: []float64{1, 2, 3, 4, 5, 6}}, + }, + Usage: openai.CreateEmbeddingResponseUsage{ + PromptTokens: 12, + TotalTokens: 12, + }, + }, + want: api.EmbeddingResponse{ + Embeddings: []api.Embedding{ + []float64{1, 2, 3, 4, 5, 6}, + }, + Usage: &api.EmbeddingUsage{ + PromptTokens: 12, + TotalTokens: 12, + }, + RawResponse: &api.EmbeddingRawResponse{ + Headers: http.Header{}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := DecodeEmbedding(tt.in) + + if tt.wantErrSub != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErrSub) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/aisdk/ai/provider/openai/internal/codec/encode_embedding.go b/aisdk/ai/provider/openai/internal/codec/encode_embedding.go new file mode 100644 index 00000000..b002a9da --- /dev/null +++ b/aisdk/ai/provider/openai/internal/codec/encode_embedding.go @@ -0,0 +1,48 @@ +package codec + +import ( + "net/http" + + "github.com/openai/openai-go/v2" + "github.com/openai/openai-go/v2/option" + "go.jetify.com/ai/api" +) + +// EncodeEmbedding builds OpenAI params + request options from the unified API options. 
+func EncodeEmbedding( + modelID string, + values []string, + opts api.EmbeddingOptions, +) (openai.EmbeddingNewParams, []option.RequestOption, []api.CallWarning, error) { + var reqOpts []option.RequestOption + if opts.Headers != nil { + reqOpts = append(reqOpts, applyHeaders(opts.Headers)...) + } + + if opts.BaseURL != nil { + reqOpts = append(reqOpts, option.WithBaseURL(*opts.BaseURL)) + } + + params := openai.EmbeddingNewParams{ + Model: openai.EmbeddingModel(modelID), + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: values, + }, + EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat, + } + + var warnings []api.CallWarning + + return params, reqOpts, warnings, nil +} + +// applyHeaders applies the provided HTTP headers to the request options. +func applyHeaders(headers http.Header) []option.RequestOption { + var reqOpts []option.RequestOption + for k, vs := range headers { + for _, v := range vs { + reqOpts = append(reqOpts, option.WithHeaderAdd(k, v)) + } + } + return reqOpts +} diff --git a/aisdk/ai/provider/openai/internal/codec/encode_embedding_test.go b/aisdk/ai/provider/openai/internal/codec/encode_embedding_test.go new file mode 100644 index 00000000..8d6e8759 --- /dev/null +++ b/aisdk/ai/provider/openai/internal/codec/encode_embedding_test.go @@ -0,0 +1,112 @@ +package codec + +import ( + "net/http" + "testing" + + "github.com/openai/openai-go/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.jetify.com/ai/api" +) + +func TestEncodeEmbedding(t *testing.T) { + tests := []struct { + name string + modelID string + values []string + headers http.Header + wantReqOptsLen int + wantWarningsLen int + expectedParams openai.EmbeddingNewParams + }{ + { + name: "no headers, two values", + modelID: "text-embedding-3-small", + values: []string{"hello", "world"}, + headers: nil, + wantReqOptsLen: 0, + wantWarningsLen: 0, + expectedParams: openai.EmbeddingNewParams{ + Model: 
openai.EmbeddingModel("text-embedding-3-small"), + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"hello", "world"}, + }, + EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat, + }, + }, + { + name: "with single and multi-value headers", + modelID: "text-embedding-3-small", + values: []string{"a", "b", "c"}, + headers: func() http.Header { + h := http.Header{} + h.Add("X-One", "1") + h.Add("X-Multi", "A") + h.Add("X-Multi", "B") + return h + }(), + // 1 option for X-One + 2 options for X-Multi + wantReqOptsLen: 3, + wantWarningsLen: 0, + expectedParams: openai.EmbeddingNewParams{ + Model: openai.EmbeddingModel("text-embedding-3-small"), + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"a", "b", "c"}, + }, + EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat, + }, + }, + { + name: "empty input slice", + modelID: "text-embedding-3-large", + values: []string{}, + headers: nil, + wantReqOptsLen: 0, + wantWarningsLen: 0, + expectedParams: openai.EmbeddingNewParams{ + Model: openai.EmbeddingModel("text-embedding-3-large"), + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{}, + }, + EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat, + }, + }, + { + name: "different model id", + modelID: "text-embedding-3-small", + values: []string{"only one"}, + headers: http.Header{}, + wantReqOptsLen: 0, + wantWarningsLen: 0, + expectedParams: openai.EmbeddingNewParams{ + Model: openai.EmbeddingModel("text-embedding-3-small"), + Input: openai.EmbeddingNewParamsInputUnion{ + OfArrayOfStrings: []string{"only one"}, + }, + EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := api.EmbeddingOptions{Headers: tt.headers} + + params, reqOpts, warnings, err := EncodeEmbedding(tt.modelID, tt.values, opts) + require.NoError(t, err) + + // Request options (opaque type): assert count derived 
from headers + assert.Len(t, reqOpts, tt.wantReqOptsLen) + + // Warnings (currently none expected) + assert.Len(t, warnings, tt.wantWarningsLen) + + // Params: model id + assert.Equal(t, openai.EmbeddingModel(tt.modelID), params.Model) + + // Params: input union mirrors provided values + assert.Equal(t, tt.values, params.Input.OfArrayOfStrings) + }) + } +} diff --git a/aisdk/ai/provider/openai/llm.go b/aisdk/ai/provider/openai/llm.go index 11c8247c..b1078a0b 100644 --- a/aisdk/ai/provider/openai/llm.go +++ b/aisdk/ai/provider/openai/llm.go @@ -2,50 +2,36 @@ package openai import ( "context" + "fmt" - "github.com/openai/openai-go/v2" "go.jetify.com/ai/api" "go.jetify.com/ai/provider/openai/internal/codec" ) -// ModelOption is a function type that modifies a LanguageModel. -type ModelOption func(*LanguageModel) - -// WithClient returns a ModelOption that sets the client. -func WithClient(client openai.Client) ModelOption { - // TODO: Instead of only supporting a single client, we can "flatten" - // the options supported by the OpenAI SDK. - return func(m *LanguageModel) { - m.client = client - } -} - // LanguageModel represents an OpenAI language model. type LanguageModel struct { modelID string - client openai.Client + pc ProviderConfig } var _ api.LanguageModel = &LanguageModel{} // NewLanguageModel creates a new OpenAI language model. 
-func NewLanguageModel(modelID string, opts ...ModelOption) *LanguageModel { - // Create model with default settings +func (p *Provider) NewLanguageModel(modelID string) *LanguageModel { + // Create model with provider's client model := &LanguageModel{ modelID: modelID, - client: openai.NewClient(), // Default client - } - - // Apply options - for _, opt := range opts { - opt(model) + pc: ProviderConfig{ + providerName: fmt.Sprintf("%s.responses", p.name), + client: p.client, + }, } return model } func (m *LanguageModel) ProviderName() string { - return "openai" + return m.pc.providerName } func (m *LanguageModel) ModelID() string { @@ -72,7 +58,7 @@ func (m *LanguageModel) Generate( return nil, err } - openaiResponse, err := m.client.Responses.New(ctx, params) + openaiResponse, err := m.pc.client.Responses.New(ctx, params) if err != nil { return nil, err } @@ -96,7 +82,7 @@ func (m *LanguageModel) Stream( return nil, err } - stream := m.client.Responses.NewStreaming(ctx, params) + stream := m.pc.client.Responses.NewStreaming(ctx, params) response, err := codec.DecodeStream(stream) if err != nil { return nil, err diff --git a/aisdk/ai/provider/openai/llm_test.go b/aisdk/ai/provider/openai/llm_test.go index 1e2a1e53..361c39b0 100644 --- a/aisdk/ai/provider/openai/llm_test.go +++ b/aisdk/ai/provider/openai/llm_test.go @@ -2990,8 +2990,11 @@ func runGenerateTests(t *testing.T, tests []struct { // Use custom model ID modelID := testCase.modelID - // Create model with mocked client - model := NewLanguageModel(modelID, WithClient(client)) + // instantiate the provider with the mocked client + provider := NewProvider(WithClient(client)) + + // Create model with the provider + model := provider.NewLanguageModel(modelID) // Call Generate with the test's options (or empty if not specified) resp, err := model.Generate(t.Context(), testCase.prompt, testCase.options) @@ -3043,8 +3046,11 @@ func runStreamTests(t *testing.T, tests []struct { // Use custom model ID modelID := 
testCase.modelID - // Create model with mocked client - model := NewLanguageModel(modelID, WithClient(client)) + // instantiate the provider with the mocked client + provider := NewProvider(WithClient(client)) + + // Create model with the provider + model := provider.NewLanguageModel(modelID) // Call Stream with the test's options (or empty if not specified) resp, err := model.Stream(t.Context(), testCase.prompt, testCase.options) diff --git a/aisdk/ai/provider/openai/provider.go b/aisdk/ai/provider/openai/provider.go new file mode 100644 index 00000000..2135760e --- /dev/null +++ b/aisdk/ai/provider/openai/provider.go @@ -0,0 +1,34 @@ +package openai + +import ( + "github.com/openai/openai-go/v2" +) + +type Provider struct { + // client is the OpenAI client used to make API calls. + client openai.Client + // name is the name of the provider, overrides the default "openai". + name string +} + +type ProviderOption func(*Provider) + +func WithClient(c openai.Client) ProviderOption { + return func(p *Provider) { p.client = c } +} + +func WithName(name string) ProviderOption { + return func(p *Provider) { p.name = name } +} + +func NewProvider(opts ...ProviderOption) *Provider { + p := &Provider{client: openai.NewClient()} + for _, opt := range opts { + opt(p) + } + if p.name == "" { + p.name = "openai" + } + + return p +} diff --git a/aisdk/ai/provider/openai/provider_config.go b/aisdk/ai/provider/openai/provider_config.go new file mode 100644 index 00000000..d74cadb8 --- /dev/null +++ b/aisdk/ai/provider/openai/provider_config.go @@ -0,0 +1,8 @@ +package openai + +import "github.com/openai/openai-go/v2" + +type ProviderConfig struct { + providerName string + client openai.Client +}