go/plugins/compat_oai/generate.go (228 changes: 110 additions, 118 deletions)
@@ -251,145 +251,110 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
stream := g.client.Chat.Completions.NewStreaming(ctx, *g.request)
defer stream.Close()

var fullResponse ai.ModelResponse
fullResponse.Message = &ai.Message{
Role: ai.RoleModel,
Content: make([]*ai.Part, 0),
}

// Initialize request and usage
fullResponse.Request = &ai.ModelRequest{}
fullResponse.Usage = &ai.GenerationUsage{
InputTokens: 0,
OutputTokens: 0,
TotalTokens: 0,
}

var currentToolCall *ai.ToolRequest
var currentArguments string
var toolCallCollects []struct {
toolCall *ai.ToolRequest
args string
}
// Use openai-go's accumulator to collect the complete response
acc := &openai.ChatCompletionAccumulator{}

for stream.Next() {
chunk := stream.Current()
if len(chunk.Choices) > 0 {
choice := chunk.Choices[0]
modelChunk := &ai.ModelResponseChunk{}

switch choice.FinishReason {
case "tool_calls", "stop":
fullResponse.FinishReason = ai.FinishReasonStop
case "length":
fullResponse.FinishReason = ai.FinishReasonLength
case "content_filter":
fullResponse.FinishReason = ai.FinishReasonBlocked
case "function_call":
fullResponse.FinishReason = ai.FinishReasonOther
default:
fullResponse.FinishReason = ai.FinishReasonUnknown
}
acc.AddChunk(chunk)

// handle tool calls
for _, toolCall := range choice.Delta.ToolCalls {
// first tool call (= current tool call is nil) contains the tool call name
if currentToolCall != nil && toolCall.ID != "" && currentToolCall.Ref != toolCall.ID {
toolCallCollects = append(toolCallCollects, struct {
toolCall *ai.ToolRequest
args string
}{
toolCall: currentToolCall,
args: currentArguments,
})
currentToolCall = nil
currentArguments = ""
}
if len(chunk.Choices) == 0 {
continue
}

if currentToolCall == nil {
currentToolCall = &ai.ToolRequest{
Name: toolCall.Function.Name,
Ref: toolCall.ID,
}
}
// Create chunk for callback
modelChunk := &ai.ModelResponseChunk{}

if toolCall.Function.Arguments != "" {
currentArguments += toolCall.Function.Arguments
}
// Handle content delta
if chunk.Choices[0].Delta.Content != "" {
modelChunk.Content = append(modelChunk.Content, ai.NewTextPart(chunk.Choices[0].Delta.Content))
}

// Handle tool call deltas
for _, toolCall := range chunk.Choices[0].Delta.ToolCalls {
// Send the incremental tool call part in the chunk
if toolCall.Function.Name != "" || toolCall.Function.Arguments != "" {
modelChunk.Content = append(modelChunk.Content, ai.NewToolRequestPart(&ai.ToolRequest{
Name: currentToolCall.Name,
Name: toolCall.Function.Name,
Input: toolCall.Function.Arguments,
Ref: currentToolCall.Ref,
Ref: toolCall.ID,
}))
}
}

// when tool call is complete
if choice.FinishReason == "tool_calls" && currentToolCall != nil {
// parse accumulated arguments string
for _, toolcall := range toolCallCollects {
args, err := jsonStringToMap(toolcall.args)
if err != nil {
return nil, fmt.Errorf("could not parse tool args: %w", err)
}
toolcall.toolCall.Input = args
fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(toolcall.toolCall))
}
if currentArguments != "" {
args, err := jsonStringToMap(currentArguments)
if err != nil {
return nil, fmt.Errorf("could not parse tool args: %w", err)
}
currentToolCall.Input = args
}
fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(currentToolCall))
}

content := chunk.Choices[0].Delta.Content
// when starting a tool call, the content is empty
if content != "" {
modelChunk.Content = append(modelChunk.Content, ai.NewTextPart(content))
fullResponse.Message.Content = append(fullResponse.Message.Content, modelChunk.Content...)
}

// Call the chunk handler with incremental data
if len(modelChunk.Content) > 0 {
if err := handleChunk(ctx, modelChunk); err != nil {
return nil, fmt.Errorf("callback error: %w", err)
}

fullResponse.Usage.InputTokens += int(chunk.Usage.PromptTokens)
fullResponse.Usage.OutputTokens += int(chunk.Usage.CompletionTokens)
fullResponse.Usage.TotalTokens += int(chunk.Usage.TotalTokens)
}
}

if err := stream.Err(); err != nil {
return nil, fmt.Errorf("stream error: %w", err)
}

return &fullResponse, nil
// Convert accumulated ChatCompletion to ai.ModelResponse
return convertChatCompletionToModelResponse(&acc.ChatCompletion)
}
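// Illustrative note (not part of this diff): openai-go's
// ChatCompletionAccumulator merges streamed deltas into a complete
// ChatCompletion, concatenating tool-call argument fragments per tool-call
// index. That is what lets this change drop the manual currentToolCall /
// toolCallCollects bookkeeping deleted above. A minimal sketch of the
// pattern, assuming AddChunk's bool return reports whether the chunk was
// merged cleanly:
//
//	acc := openai.ChatCompletionAccumulator{}
//	for stream.Next() {
//		if !acc.AddChunk(stream.Current()) {
//			break // malformed or out-of-order chunk; handle as needed
//		}
//	}
//	if err := stream.Err(); err != nil {
//		return nil, err
//	}
//	full := acc.ChatCompletion // complete message, tool calls, and usage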

// generateComplete generates a complete model response
func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequest) (*ai.ModelResponse, error) {
completion, err := g.client.Chat.Completions.New(ctx, *g.request)
if err != nil {
return nil, fmt.Errorf("failed to create completion: %w", err)
// convertChatCompletionToModelResponse converts openai.ChatCompletion to ai.ModelResponse
func convertChatCompletionToModelResponse(completion *openai.ChatCompletion) (*ai.ModelResponse, error) {
if len(completion.Choices) == 0 {
return nil, fmt.Errorf("no choices in completion")
}

choice := completion.Choices[0]

// Build usage information with detailed token breakdown
usage := &ai.GenerationUsage{
InputTokens: int(completion.Usage.PromptTokens),
OutputTokens: int(completion.Usage.CompletionTokens),
TotalTokens: int(completion.Usage.TotalTokens),
}

// Add reasoning tokens (thoughts tokens) if available
if completion.Usage.CompletionTokensDetails.ReasoningTokens > 0 {
usage.ThoughtsTokens = int(completion.Usage.CompletionTokensDetails.ReasoningTokens)
}

// Add cached tokens if available
if completion.Usage.PromptTokensDetails.CachedTokens > 0 {
usage.CachedContentTokens = int(completion.Usage.PromptTokensDetails.CachedTokens)
}

// Add audio tokens to custom field if available
if completion.Usage.CompletionTokensDetails.AudioTokens > 0 {
if usage.Custom == nil {
usage.Custom = make(map[string]float64)
}
usage.Custom["audioTokens"] = float64(completion.Usage.CompletionTokensDetails.AudioTokens)
}

// Add prediction tokens to custom field if available
if completion.Usage.CompletionTokensDetails.AcceptedPredictionTokens > 0 {
if usage.Custom == nil {
usage.Custom = make(map[string]float64)
}
usage.Custom["acceptedPredictionTokens"] = float64(completion.Usage.CompletionTokensDetails.AcceptedPredictionTokens)
}
if completion.Usage.CompletionTokensDetails.RejectedPredictionTokens > 0 {
if usage.Custom == nil {
usage.Custom = make(map[string]float64)
}
usage.Custom["rejectedPredictionTokens"] = float64(completion.Usage.CompletionTokensDetails.RejectedPredictionTokens)
}

resp := &ai.ModelResponse{
Request: req,
Usage: &ai.GenerationUsage{
InputTokens: int(completion.Usage.PromptTokens),
OutputTokens: int(completion.Usage.CompletionTokens),
TotalTokens: int(completion.Usage.TotalTokens),
},
Request: &ai.ModelRequest{},
Usage: usage,
Message: &ai.Message{
Role: ai.RoleModel,
Role: ai.RoleModel,
Content: make([]*ai.Part, 0),
},
}

choice := completion.Choices[0]

// Map finish reason
switch choice.FinishReason {
case "stop", "tool_calls":
resp.FinishReason = ai.FinishReasonStop
@@ -403,30 +368,57 @@ func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequ
resp.FinishReason = ai.FinishReasonUnknown
}

// handle tool calls
var toolRequestParts []*ai.Part
// Set finish message if there's a refusal
if choice.Message.Refusal != "" {
resp.FinishMessage = choice.Message.Refusal
resp.FinishReason = ai.FinishReasonBlocked
}

// Add text content
if choice.Message.Content != "" {
resp.Message.Content = append(resp.Message.Content, ai.NewTextPart(choice.Message.Content))
}

// Add tool calls
for _, toolCall := range choice.Message.ToolCalls {
args, err := jsonStringToMap(toolCall.Function.Arguments)
if err != nil {
return nil, err
return nil, fmt.Errorf("could not parse tool args: %w", err)
}
toolRequestParts = append(toolRequestParts, ai.NewToolRequestPart(&ai.ToolRequest{
resp.Message.Content = append(resp.Message.Content, ai.NewToolRequestPart(&ai.ToolRequest{
Ref: toolCall.ID,
Name: toolCall.Function.Name,
Input: args,
}))
}

// content and tool call may exist simultaneously
if completion.Choices[0].Message.Content != "" {
resp.Message.Content = append(resp.Message.Content, ai.NewTextPart(completion.Choices[0].Message.Content))
// Store additional metadata in custom field if needed
if completion.SystemFingerprint != "" {
resp.Custom = map[string]any{
"systemFingerprint": completion.SystemFingerprint,
"model": completion.Model,
"id": completion.ID,
}
}

return resp, nil
}
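// jsonStringToMap is defined elsewhere in this file and not shown in this
// diff. A minimal sketch of the assumed behavior, for context (hypothetical
// name and reconstruction; would require the encoding/json import):
//
//	func jsonStringToMapSketch(s string) (map[string]any, error) {
//		var result map[string]any
//		if err := json.Unmarshal([]byte(s), &result); err != nil {
//			return nil, fmt.Errorf("unmarshal tool args %q: %w", s, err)
//		}
//		return result, nil
//	}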

// generateComplete generates a complete model response
func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequest) (*ai.ModelResponse, error) {
completion, err := g.client.Chat.Completions.New(ctx, *g.request)
if err != nil {
return nil, fmt.Errorf("failed to create completion: %w", err)
}

if len(toolRequestParts) > 0 {
resp.Message.Content = append(resp.Message.Content, toolRequestParts...)
return resp, nil
resp, err := convertChatCompletionToModelResponse(completion)
if err != nil {
return nil, err
}

// Set the original request
resp.Request = req

return resp, nil
}
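// Sketch of the two call paths after this change (handleChunk's exact
// signature is truncated in the hunk header above and assumed here):
//
//	resp, err := g.generateComplete(ctx, req)       // non-streaming
//	resp, err := g.generateStream(ctx, handleChunk) // streaming
//
// Both paths now funnel through convertChatCompletionToModelResponse, so
// finish reasons, tool calls, refusals, and usage are mapped identically
// whether or not the response was streamed.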
