Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aisdk/ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ import (

func main() {
// Set up your model
model := openai.NewLanguageModel("gpt-4o")
provider := openai.NewProvider()
model := provider.NewLanguageModel("gpt-4o-mini")

// Generate text
response, err := ai.GenerateTextStr(
Expand Down
13 changes: 13 additions & 0 deletions aisdk/ai/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@ import (
"go.jetify.com/ai/api"
)

// EmbedMany uses an embedding model to generate embeddings for a batch of
// values in a single request.
//
// The variadic opts customize the call (headers, base URL, provider-specific
// metadata); they are folded into a single configuration via
// buildEmbeddingConfig before the request is made.
func EmbedMany[T any](
	ctx context.Context, model api.EmbeddingModel[T], values []T, opts ...EmbeddingOption[T],
) (api.EmbeddingResponse, error) {
	config := buildEmbeddingConfig(model, opts)
	return embed(ctx, values, config)
}

// embed executes the embedding request against the model bundled in opts,
// forwarding the per-call api.EmbeddingOptions unchanged.
func embed[T any](
	ctx context.Context, values []T, opts EmbeddingOptions[T],
) (api.EmbeddingResponse, error) {
	model := opts.Model
	return model.DoEmbed(ctx, values, opts.EmbeddingOptions)
}

// TODO: do we want to rename from GenerateText to Generate and from StreamText to Stream?

// GenerateText uses a language model to generate a text response from a given prompt.
Expand Down
5 changes: 3 additions & 2 deletions aisdk/ai/api/embedding_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type EmbeddingModel[T any] interface {
//
// Naming: "do" prefix to prevent accidental direct usage of the method
// by the user.
DoEmbed(ctx context.Context, values []T, opts ...EmbeddingOption) EmbeddingResponse
DoEmbed(ctx context.Context, values []T, opts EmbeddingOptions) (EmbeddingResponse, error)
}

// EmbeddingResponse represents the response from generating embeddings.
Expand All @@ -55,7 +55,8 @@ type EmbeddingResponse struct {

// EmbeddingUsage represents token usage information.
type EmbeddingUsage struct {
Tokens int
PromptTokens int64
TotalTokens int64
}

// EmbeddingRawResponse contains raw response information for debugging.
Expand Down
18 changes: 10 additions & 8 deletions aisdk/ai/api/embedding_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@ import "net/http"
// EmbeddingOption represents a functional option that mutates EmbeddingOptions
// while a request configuration is being built.
type EmbeddingOption func(*EmbeddingOptions)

// WithEmbeddingHeaders sets HTTP headers to be sent with the request.
// Only applicable for HTTP-based providers.
func WithEmbeddingHeaders(headers http.Header) EmbeddingOption {
	return func(o *EmbeddingOptions) {
		o.Headers = headers
	}
}

// EmbeddingOptions represents the options for generating embeddings.
type EmbeddingOptions struct {
	// Headers are additional HTTP headers to be sent with the request.
	// Only applicable for HTTP-based providers.
	Headers http.Header

	// BaseURL is the base URL for the API endpoint.
	// NOTE(review): presumably nil means "use the provider's default
	// endpoint" — confirm against the provider implementations.
	BaseURL *string

	// ProviderMetadata contains additional provider-specific metadata.
	// The metadata is passed through to the provider from the AI SDK and enables
	// provider-specific functionality that can be fully encapsulated in the provider.
	ProviderMetadata *ProviderMetadata
}

// GetProviderMetadata returns the provider-specific metadata attached to these
// options; it may be nil when no metadata was set.
func (o EmbeddingOptions) GetProviderMetadata() *ProviderMetadata { return o.ProviderMetadata }
3 changes: 2 additions & 1 deletion aisdk/ai/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ type modelWrapper struct {
var defaultLanguageModel atomic.Value

func init() {
model := openai.NewLanguageModel(openai.ChatModelGPT5)
provider := openai.NewProvider()
model := provider.NewLanguageModel(openai.ChatModelGPT5)
defaultLanguageModel.Store(&modelWrapper{model: model})
}

Expand Down
4 changes: 2 additions & 2 deletions aisdk/ai/default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
func TestDefaultLanguageModel(t *testing.T) {
// Get current model and verify provider and model ID match expected values
originalModel := DefaultLanguageModel()
assert.Equal(t, "openai", originalModel.ProviderName())
assert.Equal(t, "openai.responses", originalModel.ProviderName())
assert.Equal(t, openai.ChatModelGPT5, originalModel.ModelID())

// Change model to different provider (Anthropic)
Expand All @@ -26,6 +26,6 @@ func TestDefaultLanguageModel(t *testing.T) {
SetDefaultLanguageModel(originalModel)

restoredModel := DefaultLanguageModel()
assert.Equal(t, "openai", restoredModel.ProviderName())
assert.Equal(t, "openai.responses", restoredModel.ProviderName())
assert.Equal(t, openai.ChatModelGPT5, restoredModel.ModelID())
}
63 changes: 63 additions & 0 deletions aisdk/ai/embedding_options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package ai

import (
"net/http"

"go.jetify.com/ai/api"
)

// EmbeddingOptions bundles the model + per-call embedding options.
type EmbeddingOptions[T any] struct {
	// EmbeddingOptions holds the provider-agnostic per-call options
	// (headers, base URL, provider metadata) forwarded to the model.
	EmbeddingOptions api.EmbeddingOptions
	// Model is the embedding model that will perform the call.
	Model api.EmbeddingModel[T]
}

// EmbeddingOption is a functional option that mutates an EmbeddingOptions
// value while the request configuration is being assembled.
type EmbeddingOption[T any] func(*EmbeddingOptions[T])

// WithEmbeddingHeaders sets extra HTTP headers for this embedding call.
// Only applies to HTTP-backed providers.
func WithEmbeddingHeaders[T any](headers http.Header) EmbeddingOption[T] {
	return func(opts *EmbeddingOptions[T]) {
		opts.EmbeddingOptions.Headers = headers
	}
}

// WithEmbeddingProviderMetadata sets provider-specific metadata for the embedding call.
func WithEmbeddingProviderMetadata[T any](provider string, metadata any) EmbeddingOption[T] {
	return func(opts *EmbeddingOptions[T]) {
		// Lazily create the metadata container on first use.
		meta := opts.EmbeddingOptions.ProviderMetadata
		if meta == nil {
			meta = api.NewProviderMetadata(map[string]any{})
			opts.EmbeddingOptions.ProviderMetadata = meta
		}
		meta.Set(provider, metadata)
	}
}

// WithEmbeddingBaseURL sets the base URL for the embedding API endpoint.
func WithEmbeddingBaseURL[T any](baseURL string) EmbeddingOption[T] {
	return func(o *EmbeddingOptions[T]) {
		// Each invocation of WithEmbeddingBaseURL receives its own copy of
		// baseURL, so taking its address directly is safe — the extra local
		// copy in the original was redundant.
		o.EmbeddingOptions.BaseURL = &baseURL
	}
}

// WithEmbeddingEmbeddingOptions replaces the entire api.EmbeddingOptions
// struct in one call, overwriting any per-field options applied earlier.
func WithEmbeddingEmbeddingOptions[T any](embeddingOptions api.EmbeddingOptions) EmbeddingOption[T] {
	return func(opts *EmbeddingOptions[T]) {
		opts.EmbeddingOptions = embeddingOptions
	}
}

// buildEmbeddingConfig combines multiple options into a single EmbeddingOptions.
// Options are applied in order, so later options win on conflicting fields.
func buildEmbeddingConfig[T any](
	model api.EmbeddingModel[T], opts []EmbeddingOption[T],
) EmbeddingOptions[T] {
	cfg := EmbeddingOptions[T]{Model: model}
	for _, apply := range opts {
		apply(&cfg)
	}
	return cfg
}
49 changes: 49 additions & 0 deletions aisdk/ai/examples/basic/simple-embedding/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package main

import (
"context"
"log"

"github.com/k0kubun/pp/v3"
"go.jetify.com/ai"
"go.jetify.com/ai/api"
"go.jetify.com/ai/provider/openai"
)

// example creates an OpenAI embedding model, requests embeddings for two
// sentences in a single batch, and pretty-prints the response.
func example() error {
	// Initialize the OpenAI provider
	provider := openai.NewProvider()

	// Create a model
	model := provider.NewEmbeddingModel("text-embedding-3-small")

	// Generate embeddings for a batch of input strings
	response, err := ai.EmbedMany(
		context.Background(),
		model,
		[]string{
			"Artificial intelligence is the simulation of human intelligence in machines.",
			"Machine learning is a subset of AI that enables systems to learn from data.",
		},
	)
	if err != nil {
		return err
	}

	// Print the response:
	printResponse(response)

	return nil
}

// printResponse pretty-prints the embedding response, omitting empty fields
// to keep the output readable.
func printResponse(response api.EmbeddingResponse) {
	p := pp.New()
	p.SetOmitEmpty(true)
	p.Print(response)
}

// main runs the embedding example and aborts on any error.
func main() {
	err := example()
	if err != nil {
		log.Fatal(err)
	}
}
5 changes: 4 additions & 1 deletion aisdk/ai/examples/basic/simple-text/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ import (
)

func example() error {
// Initialize the OpenAI provider
provider := openai.NewProvider()

// Create a model
model := openai.NewLanguageModel("gpt-4o-mini")
model := provider.NewLanguageModel("gpt-4o-mini")

// Generate text
response, err := ai.GenerateTextStr(
Expand Down
3 changes: 2 additions & 1 deletion aisdk/ai/examples/basic/streaming-text/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import (

func example() error {
// Create a model
model := openai.NewLanguageModel("gpt-4o-mini")
provider := openai.NewProvider()
model := provider.NewLanguageModel("gpt-4o-mini")

// Stream text
response, err := ai.StreamTextStr(
Expand Down
77 changes: 77 additions & 0 deletions aisdk/ai/provider/openai/embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package openai

import (
"context"
"fmt"

"go.jetify.com/ai/api"
"go.jetify.com/ai/provider/openai/internal/codec"
)

// EmbeddingModel represents an OpenAI embedding model.
type EmbeddingModel struct {
	// modelID is the OpenAI model identifier (e.g. "text-embedding-3-small").
	modelID string
	// pc carries the derived provider name and the shared OpenAI API client.
	pc ProviderConfig
}

// Compile-time check that *EmbeddingModel satisfies api.EmbeddingModel[string].
var _ api.EmbeddingModel[string] = &EmbeddingModel{}

// NewEmbeddingModel creates a new OpenAI embedding model.
func (p *Provider) NewEmbeddingModel(modelID string) *EmbeddingModel {
	// The embedding model reuses the provider's API client; the provider
	// name gets an ".embedding" suffix so calls can be attributed to the
	// embeddings API.
	return &EmbeddingModel{
		modelID: modelID,
		pc: ProviderConfig{
			providerName: fmt.Sprintf("%s.embedding", p.name),
			client:       p.client,
		},
	}
}

// ProviderName returns the provider identifier for this model — the provider
// name suffixed with ".embedding" (set in NewEmbeddingModel).
func (m *EmbeddingModel) ProviderName() string {
	return m.pc.providerName
}

// SpecificationVersion reports the embedding-model specification version this
// implementation conforms to.
func (m *EmbeddingModel) SpecificationVersion() string {
	return "v2"
}

// ModelID returns the OpenAI model identifier this instance was created with.
func (m *EmbeddingModel) ModelID() string {
	return m.modelID
}

// SupportsParallelCalls implements api.EmbeddingModel. This model always
// reports that it may be called concurrently.
func (m *EmbeddingModel) SupportsParallelCalls() bool {
	return true
}

// MaxEmbeddingsPerCall implements api.EmbeddingModel.
//
// Returns 2048, matching OpenAI's documented cap on the number of inputs per
// embeddings request.
func (m *EmbeddingModel) MaxEmbeddingsPerCall() *int {
	// Named "limit" rather than "max" to avoid shadowing the Go 1.21+
	// built-in max function.
	limit := 2048
	return &limit
}

// DoEmbed implements api.EmbeddingModel.
//
// It encodes the generic request into OpenAI-specific parameters via the
// codec package, calls the OpenAI embeddings endpoint, and decodes the raw
// response into the provider-agnostic api.EmbeddingResponse.
func (m *EmbeddingModel) DoEmbed(
	ctx context.Context,
	values []string,
	opts api.EmbeddingOptions,
) (api.EmbeddingResponse, error) {
	// NOTE(review): EncodeEmbedding's third return value is discarded here —
	// confirm it is safe to ignore (its meaning is not visible in this file).
	embeddingParams, openaiOpts, _, err := codec.EncodeEmbedding(
		m.modelID,
		values,
		opts,
	)
	if err != nil {
		return api.EmbeddingResponse{}, err
	}

	resp, err := m.pc.client.Embeddings.New(ctx, embeddingParams, openaiOpts...)
	if err != nil {
		return api.EmbeddingResponse{}, err
	}

	// Translate the raw OpenAI response into the SDK's response type.
	return codec.DecodeEmbedding(resp)
}
Loading