Merged (changes from 3 commits)
48 changes: 38 additions & 10 deletions src/api/providers/base-openai-compatible-provider.ts
@@ -5,13 +5,14 @@ import type { ModelInfo } from "@roo-code/types"

import { type ApiHandlerOptions, getModelMaxOutputTokens } from "../../shared/api"
import { XmlMatcher } from "../../utils/xml-matcher"
-import { ApiStream } from "../transform/stream"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"

import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { DEFAULT_HEADERS } from "./constants"
import { BaseProvider } from "./base-provider"
import { handleOpenAIError } from "./utils/openai-error-handler"
+import { calculateApiCostOpenAI } from "../../shared/cost"

type BaseOpenAiCompatibleProviderOptions<ModelName extends string> = ApiHandlerOptions & {
	providerName: string
@@ -119,6 +120,8 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>

		const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()

+		let lastUsage: OpenAI.CompletionUsage | undefined
+
		for await (const chunk of stream) {
			// Check for provider-specific error responses (e.g., MiniMax base_resp)
			const chunkAny = chunk as any
@@ -137,10 +140,15 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
				}
			}

-			if (delta && "reasoning_content" in delta) {
-				const reasoning_content = (delta.reasoning_content as string | undefined) || ""
-				if (reasoning_content?.trim()) {
-					yield { type: "reasoning", text: reasoning_content }
+			if (delta) {
+				for (const key of ["reasoning_content", "reasoning"] as const) {
+					if (key in delta) {
+						const reasoning_content = ((delta as any)[key] as string | undefined) || ""
+						if (reasoning_content?.trim()) {
+							yield { type: "reasoning", text: reasoning_content }
+						}
+						break
+					}
				}
			}

@@ -176,20 +184,40 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
			}

			if (chunk.usage) {
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-				}
+				lastUsage = chunk.usage
			}
		}

+		if (lastUsage) {
+			yield this.processUsageMetrics(lastUsage, this.getModel().info)
+		}
+
		// Process any remaining content
		for (const processedChunk of matcher.final()) {
			yield processedChunk
		}
	}

+	protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk {
+		const inputTokens = usage?.prompt_tokens || 0
+		const outputTokens = usage?.completion_tokens || 0
+		const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+
+		const { totalCost } = modelInfo
+			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+			: { totalCost: 0 }
+
+		return {
+			type: "usage",
+			inputTokens,
+			outputTokens,
+			cacheWriteTokens: cacheWriteTokens || undefined,
+			cacheReadTokens: cacheReadTokens || undefined,
+			totalCost,
+		}
Contributor:
The new processUsageMetrics method in the base class now includes cost calculation via calculateApiCostOpenAI. However, FeatherlessHandler has a custom createMessage implementation for DeepSeek-R1 models that yields usage without cost calculation (lines 79-85). This creates an inconsistency: non-R1 models (which use super.createMessage()) will get cost calculation, but R1 models won't. Consider updating the R1 path to also use processUsageMetrics or call calculateApiCostOpenAI directly to maintain consistent cost tracking across all Featherless models.

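A minimal sketch of that fix, assuming featherless.ts keeps its existing R1 streaming loop (plus the OpenAI/Anthropic type imports the old groq.ts used) and only changes how usage is reported; apart from processUsageMetrics, getModel, and the ApiStream chunk shapes shown in the base class above, everything here is an assumption about code not visible in this diff:

	// Sketch only: defer usage reporting to the shared helper so DeepSeek-R1
	// responses get the same cacheReadTokens/cacheWriteTokens/totalCost fields
	// as every other model served through BaseOpenAiCompatibleProvider.
	override async *createMessage(
		systemPrompt: string,
		messages: Anthropic.Messages.MessageParam[],
		metadata?: ApiHandlerCreateMessageMetadata,
	): ApiStream {
		const stream = await this.createStream(systemPrompt, messages, metadata)

		let lastUsage: OpenAI.CompletionUsage | undefined

		for await (const chunk of stream) {
			const delta = chunk.choices[0]?.delta

			if (delta?.content) {
				yield { type: "text", text: delta.content }
			}

			// The R1-specific reasoning handling would stay here, unchanged.

			if (chunk.usage) {
				lastUsage = chunk.usage
			}
		}

		// Same pattern the base class now uses: one usage chunk, with cost attached.
		if (lastUsage) {
			yield this.processUsageMetrics(lastUsage, this.getModel().info)
		}
	}

Non-R1 Featherless models already pick this up for free, since they delegate to super.createMessage().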

+	}
+
	async completePrompt(prompt: string): Promise<string> {
		const { id: modelId } = this.getModel()

59 changes: 0 additions & 59 deletions src/api/providers/groq.ts
@@ -1,22 +1,9 @@
import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types"
-import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"

import type { ApiHandlerOptions } from "../../shared/api"
-import type { ApiHandlerCreateMessageMetadata } from "../index"
-import { ApiStream } from "../transform/stream"
-import { convertToOpenAiMessages } from "../transform/openai-format"
-import { calculateApiCostOpenAI } from "../../shared/cost"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

-// Enhanced usage interface to support Groq's cached token fields
-interface GroqUsage extends OpenAI.CompletionUsage {
-	prompt_tokens_details?: {
-		cached_tokens?: number
-	}
-}
-
export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
	constructor(options: ApiHandlerOptions) {
		super({
@@ -29,50 +16,4 @@ export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
			defaultTemperature: 0.5,
		})
	}
-
-	override async *createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		metadata?: ApiHandlerCreateMessageMetadata,
-	): ApiStream {
-		const stream = await this.createStream(systemPrompt, messages, metadata)
-
-		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
-
-			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
-				}
-			}
-
-			if (chunk.usage) {
-				yield* this.yieldUsage(chunk.usage as GroqUsage)
-			}
-		}
-	}
-
-	private async *yieldUsage(usage: GroqUsage | undefined): ApiStream {
-		const { info } = this.getModel()
-		const inputTokens = usage?.prompt_tokens || 0
-		const outputTokens = usage?.completion_tokens || 0
-
-		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
-
-		// Groq does not track cache writes
-		const cacheWriteTokens = 0
-
-		// Calculate cost using OpenAI-compatible cost calculation
-		const { totalCost } = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
-
-		yield {
-			type: "usage",
-			inputTokens,
-			outputTokens,
-			cacheWriteTokens,
-			cacheReadTokens,
-			totalCost,
-		}
-	}
}
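With its createMessage override and GroqUsage interface deleted, GroqHandler falls back to the base class streaming path, so Groq usage now flows through processUsageMetrics above. A rough illustration of what that yields for a hypothetical Groq usage payload; the token counts are invented and the cost depends on whatever pricing groqModels carries:

	// Hypothetical final usage chunk from a Groq stream:
	const usage = {
		prompt_tokens: 1200,
		completion_tokens: 300,
		prompt_tokens_details: { cached_tokens: 800 }, // the field GroqUsage used to model
	}

	// processUsageMetrics(usage, groqModels["<some-model-id>"]) would return roughly:
	// {
	//   type: "usage",
	//   inputTokens: 1200,
	//   outputTokens: 300,
	//   cacheWriteTokens: undefined, // Groq reports no cache_write_tokens
	//   cacheReadTokens: 800,
	//   totalCost: ...,              // from calculateApiCostOpenAI and groqModels pricing
	// }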