diff --git a/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts b/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
index 667083ee5f..57ee649d6e 100644
--- a/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
+++ b/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
@@ -380,7 +380,7 @@ describe("BaseOpenAiCompatibleProvider", () => {
 			const firstChunk = await stream.next()

 			expect(firstChunk.done).toBe(false)
-			expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 100, outputTokens: 50 })
+			expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 100, outputTokens: 50 })
 		})
 	})
 })
diff --git a/src/api/providers/__tests__/featherless.spec.ts b/src/api/providers/__tests__/featherless.spec.ts
index b0b4c01b86..0051882885 100644
--- a/src/api/providers/__tests__/featherless.spec.ts
+++ b/src/api/providers/__tests__/featherless.spec.ts
@@ -123,11 +123,9 @@ describe("FeatherlessHandler", () => {
 			chunks.push(chunk)
 		}

-		expect(chunks).toEqual([
-			{ type: "reasoning", text: "Thinking..." },
-			{ type: "text", text: "Hello" },
-			{ type: "usage", inputTokens: 10, outputTokens: 5 },
-		])
+		expect(chunks[0]).toEqual({ type: "reasoning", text: "Thinking..." })
+		expect(chunks[1]).toEqual({ type: "text", text: "Hello" })
+		expect(chunks[2]).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 5 })
 	})

 	it("should fall back to base provider for non-DeepSeek models", async () => {
@@ -145,10 +143,8 @@ describe("FeatherlessHandler", () => {
 			chunks.push(chunk)
 		}

-		expect(chunks).toEqual([
-			{ type: "text", text: "Test response" },
-			{ type: "usage", inputTokens: 10, outputTokens: 5 },
-		])
+		expect(chunks[0]).toEqual({ type: "text", text: "Test response" })
+		expect(chunks[1]).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 5 })
 	})

 	it("should return default model when no model is specified", () => {
@@ -226,7 +222,7 @@ describe("FeatherlessHandler", () => {
 		const firstChunk = await stream.next()

 		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+		expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})

 	it("createMessage should pass correct parameters to Featherless client for DeepSeek R1", async () => {
diff --git a/src/api/providers/__tests__/fireworks.spec.ts b/src/api/providers/__tests__/fireworks.spec.ts
index da0a8cf9a4..9b837fef60 100644
--- a/src/api/providers/__tests__/fireworks.spec.ts
+++ b/src/api/providers/__tests__/fireworks.spec.ts
@@ -384,7 +384,7 @@ describe("FireworksHandler", () => {
 		const firstChunk = await stream.next()

 		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+		expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})

 	it("createMessage should pass correct parameters to Fireworks client", async () => {
@@ -494,10 +494,8 @@ describe("FireworksHandler", () => {
 			chunks.push(chunk)
 		}

-		expect(chunks).toEqual([
-			{ type: "text", text: "Hello" },
-			{ type: "text", text: " world" },
-			{ type: "usage", inputTokens: 5, outputTokens: 10 },
-		])
+		expect(chunks[0]).toEqual({ type: "text", text: "Hello" })
+		expect(chunks[1]).toEqual({ type: "text", text: " world" })
+		expect(chunks[2]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 })
 	})
 })
diff --git a/src/api/providers/__tests__/groq.spec.ts b/src/api/providers/__tests__/groq.spec.ts
index 2aee4ea052..3b6239fd76 100644
--- a/src/api/providers/__tests__/groq.spec.ts
+++ b/src/api/providers/__tests__/groq.spec.ts
@@ -112,9 +112,10 @@ describe("GroqHandler", () => {
 				type: "usage",
 				inputTokens: 10,
 				outputTokens: 20,
-				cacheWriteTokens: 0,
-				cacheReadTokens: 0,
 			})
+			// cacheWriteTokens and cacheReadTokens will be undefined when 0
+			expect(firstChunk.value.cacheWriteTokens).toBeUndefined()
+			expect(firstChunk.value.cacheReadTokens).toBeUndefined()
 			// Check that totalCost is a number (we don't need to test the exact value as that's tested in cost.spec.ts)
 			expect(typeof firstChunk.value.totalCost).toBe("number")
 		})
@@ -151,9 +152,10 @@ describe("GroqHandler", () => {
 				type: "usage",
 				inputTokens: 100,
 				outputTokens: 50,
-				cacheWriteTokens: 0,
 				cacheReadTokens: 30,
 			})
+			// cacheWriteTokens will be undefined when 0
+			expect(firstChunk.value.cacheWriteTokens).toBeUndefined()

 			expect(typeof firstChunk.value.totalCost).toBe("number")
 		})
diff --git a/src/api/providers/__tests__/io-intelligence.spec.ts b/src/api/providers/__tests__/io-intelligence.spec.ts
index 3b46b79ee2..99dfcefea4 100644
--- a/src/api/providers/__tests__/io-intelligence.spec.ts
+++ b/src/api/providers/__tests__/io-intelligence.spec.ts
@@ -178,7 +178,7 @@ describe("IOIntelligenceHandler", () => {
 			expect(results).toHaveLength(3)
 			expect(results[0]).toEqual({ type: "text", text: "Hello" })
 			expect(results[1]).toEqual({ type: "text", text: " world" })
-			expect(results[2]).toEqual({
+			expect(results[2]).toMatchObject({
 				type: "usage",
 				inputTokens: 10,
 				outputTokens: 5,
@@ -243,7 +243,7 @@ describe("IOIntelligenceHandler", () => {
 			const firstChunk = await stream.next()

 			expect(firstChunk.done).toBe(false)
-			expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+			expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
 		})

 		it("should return model info from cache when available", () => {
diff --git a/src/api/providers/__tests__/sambanova.spec.ts b/src/api/providers/__tests__/sambanova.spec.ts
index d8cae8bf80..f03f6720bb 100644
--- a/src/api/providers/__tests__/sambanova.spec.ts
+++ b/src/api/providers/__tests__/sambanova.spec.ts
@@ -113,7 +113,7 @@ describe("SambaNovaHandler", () => {
 		const firstChunk = await stream.next()

 		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+		expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})

 	it("createMessage should pass correct parameters to SambaNova client", async () => {
diff --git a/src/api/providers/__tests__/zai.spec.ts b/src/api/providers/__tests__/zai.spec.ts
index 9db5350080..95b263ecac 100644
--- a/src/api/providers/__tests__/zai.spec.ts
+++ b/src/api/providers/__tests__/zai.spec.ts
@@ -252,7 +252,7 @@ describe("ZAiHandler", () => {
 		const firstChunk = await stream.next()

 		expect(firstChunk.done).toBe(false)
-		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
+		expect(firstChunk.value).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})

 	it("createMessage should pass correct parameters to Z AI client", async () => {
diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts
index ea0dc1b2e8..3d78ef75d1 100644
--- a/src/api/providers/base-openai-compatible-provider.ts
+++ b/src/api/providers/base-openai-compatible-provider.ts
@@ -5,13 +5,14 @@ import type { ModelInfo } from "@roo-code/types"
 import { type ApiHandlerOptions, getModelMaxOutputTokens } from "../../shared/api"
 import { XmlMatcher } from "../../utils/xml-matcher"

-import { ApiStream } from "../transform/stream"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { convertToOpenAiMessages } from "../transform/openai-format"

 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import { DEFAULT_HEADERS } from "./constants"
 import { BaseProvider } from "./base-provider"
 import { handleOpenAIError } from "./utils/openai-error-handler"
+import { calculateApiCostOpenAI } from "../../shared/cost"

 type BaseOpenAiCompatibleProviderOptions<ModelName extends string> = ApiHandlerOptions & {
 	providerName: string
@@ -94,6 +95,11 @@ export abstract class BaseOpenAiCompatibleProvider
 			...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 		}

+		// Add thinking parameter if reasoning is enabled and model supports it
+		if (this.options.enableReasoningEffort && info.supportsReasoningBinary) {
+			;(params as any).thinking = { type: "enabled" }
+		}
+
 		try {
 			return this.client.chat.completions.create(params, requestOptions)
 		} catch (error) {
@@ -119,6 +125,8 @@ export abstract class BaseOpenAiCompatibleProvider

 		const toolCallAccumulator = new Map()

+		let lastUsage: OpenAI.CompletionUsage | undefined
+
 		for await (const chunk of stream) {
 			// Check for provider-specific error responses (e.g., MiniMax base_resp)
 			const chunkAny = chunk as any
@@ -137,10 +145,15 @@ export abstract class BaseOpenAiCompatibleProvider
 				}
 			}

-			if (delta && "reasoning_content" in delta) {
-				const reasoning_content = (delta.reasoning_content as string | undefined) || ""
-				if (reasoning_content?.trim()) {
-					yield { type: "reasoning", text: reasoning_content }
+			if (delta) {
+				for (const key of ["reasoning_content", "reasoning"] as const) {
+					if (key in delta) {
+						const reasoning_content = ((delta as any)[key] as string | undefined) || ""
+						if (reasoning_content?.trim()) {
+							yield { type: "reasoning", text: reasoning_content }
+						}
+						break
+					}
 				}
 			}

@@ -176,11 +189,7 @@ export abstract class BaseOpenAiCompatibleProvider
 			}

 			if (chunk.usage) {
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-				}
+				lastUsage = chunk.usage
 			}
 		}

@@ -198,20 +207,51 @@ export abstract class BaseOpenAiCompatibleProvider
 			toolCallAccumulator.clear()
 		}

+		if (lastUsage) {
+			yield this.processUsageMetrics(lastUsage, this.getModel().info)
+		}
+
 		// Process any remaining content
 		for (const processedChunk of matcher.final()) {
 			yield processedChunk
 		}
 	}

+	protected processUsageMetrics(usage: any, modelInfo?: any): ApiStreamUsageChunk {
+		const inputTokens = usage?.prompt_tokens || 0
+		const outputTokens = usage?.completion_tokens || 0
+		const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+
+		const { totalCost } = modelInfo
+			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+			: { totalCost: 0 }
+
+		return {
+			type: "usage",
+			inputTokens,
+			outputTokens,
+			cacheWriteTokens: cacheWriteTokens || undefined,
+			cacheReadTokens: cacheReadTokens || undefined,
+			totalCost,
+		}
+	}
+
 	async completePrompt(prompt: string): Promise<string> {
-		const { id: modelId } = this.getModel()
+		const { id: modelId, info: modelInfo } = this.getModel()
+
+		const params: OpenAI.Chat.Completions.ChatCompletionCreateParams = {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+		}
+
+		// Add thinking parameter if reasoning is enabled and model supports it
+		if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
+			;(params as any).thinking = { type: "enabled" }
+		}

 		try {
-			const response = await this.client.chat.completions.create({
-				model: modelId,
-				messages: [{ role: "user", content: prompt }],
-			})
+			const response = await this.client.chat.completions.create(params)

 			// Check for provider-specific error responses (e.g., MiniMax base_resp)
 			const responseAny = response as any
diff --git a/src/api/providers/groq.ts b/src/api/providers/groq.ts
index c2f2dd19db..7583edc51c 100644
--- a/src/api/providers/groq.ts
+++ b/src/api/providers/groq.ts
@@ -1,22 +1,9 @@
 import { type GroqModelId, groqDefaultModelId, groqModels } from "@roo-code/types"
-import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"

 import type { ApiHandlerOptions } from "../../shared/api"
-import type { ApiHandlerCreateMessageMetadata } from "../index"
-import { ApiStream } from "../transform/stream"
-import { convertToOpenAiMessages } from "../transform/openai-format"
-import { calculateApiCostOpenAI } from "../../shared/cost"

 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

-// Enhanced usage interface to support Groq's cached token fields
-interface GroqUsage extends OpenAI.CompletionUsage {
-	prompt_tokens_details?: {
-		cached_tokens?: number
-	}
-}
-
 export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
 	constructor(options: ApiHandlerOptions) {
 		super({
@@ -29,50 +16,4 @@ export class GroqHandler extends BaseOpenAiCompatibleProvider {
 			defaultTemperature: 0.5,
 		})
 	}
-
-	override async *createMessage(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		metadata?: ApiHandlerCreateMessageMetadata,
-	): ApiStream {
-		const stream = await this.createStream(systemPrompt, messages, metadata)
-
-		for await (const chunk of stream) {
-			const delta = chunk.choices[0]?.delta
-
-			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
-				}
-			}
-
-			if (chunk.usage) {
-				yield* this.yieldUsage(chunk.usage as GroqUsage)
-			}
-		}
-	}
-
-	private async *yieldUsage(usage: GroqUsage | undefined): ApiStream {
-		const { info } = this.getModel()
-		const inputTokens = usage?.prompt_tokens || 0
-		const outputTokens = usage?.completion_tokens || 0
-
-		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
-
-		// Groq does not track cache writes
-		const cacheWriteTokens = 0
-
-		// Calculate cost using OpenAI-compatible cost calculation
-		const { totalCost } = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
-
-		yield {
-			type: "usage",
-			inputTokens,
-			outputTokens,
-			cacheWriteTokens,
-			cacheReadTokens,
-			totalCost,
-		}
-	}
 }
diff --git a/src/api/providers/zai.ts b/src/api/providers/zai.ts
index cc83945e48..76b6b87d5a 100644
--- a/src/api/providers/zai.ts
+++ b/src/api/providers/zai.ts
@@ -3,21 +3,12 @@ import {
 	mainlandZAiModels,
 	internationalZAiDefaultModelId,
 	mainlandZAiDefaultModelId,
-	type InternationalZAiModelId,
-	type MainlandZAiModelId,
 	type ModelInfo,
 	ZAI_DEFAULT_TEMPERATURE,
 	zaiApiLineConfigs,
 } from "@roo-code/types"

-import { Anthropic } from "@anthropic-ai/sdk"
-import OpenAI from "openai"
-
 import type { ApiHandlerOptions } from "../../shared/api"
-import { getModelMaxOutputTokens } from "../../shared/api"
-import { convertToOpenAiMessages } from "../transform/openai-format"
-import type { ApiHandlerCreateMessageMetadata } from "../index"
-import { handleOpenAIError } from "./utils/openai-error-handler"

 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

@@ -37,67 +28,4 @@ export class ZAiHandler extends BaseOpenAiCompatibleProvider {
 			defaultTemperature: ZAI_DEFAULT_TEMPERATURE,
 		})
 	}
-
-	protected override createStream(
-		systemPrompt: string,
-		messages: Anthropic.Messages.MessageParam[],
-		metadata?: ApiHandlerCreateMessageMetadata,
-		requestOptions?: OpenAI.RequestOptions,
-	) {
-		const { id: model, info } = this.getModel()
-
-		// Centralized cap: clamp to 20% of the context window (unless provider-specific exceptions apply)
-		const max_tokens =
-			getModelMaxOutputTokens({
-				modelId: model,
-				model: info,
-				settings: this.options,
-				format: "openai",
-			}) ?? undefined
-
-		const temperature = this.options.modelTemperature ?? this.defaultTemperature
-
-		const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
-			model,
-			max_tokens,
-			temperature,
-			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
-			stream: true,
-			stream_options: { include_usage: true },
-		}
-
-		// Add thinking parameter if reasoning is enabled and model supports it
-		const { id: modelId, info: modelInfo } = this.getModel()
-		if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
-			;(params as any).thinking = { type: "enabled" }
-		}
-
-		try {
-			return this.client.chat.completions.create(params, requestOptions)
-		} catch (error) {
-			throw handleOpenAIError(error, this.providerName)
-		}
-	}
-
-	override async completePrompt(prompt: string): Promise<string> {
-		const { id: modelId } = this.getModel()
-
-		const params: OpenAI.Chat.Completions.ChatCompletionCreateParams = {
-			model: modelId,
-			messages: [{ role: "user", content: prompt }],
-		}
-
-		// Add thinking parameter if reasoning is enabled and model supports it
-		const { info: modelInfo } = this.getModel()
-		if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
-			;(params as any).thinking = { type: "enabled" }
-		}
-
-		try {
-			const response = await this.client.chat.completions.create(params)
-			return response.choices[0]?.message.content || ""
-		} catch (error) {
-			throw handleOpenAIError(error, this.providerName)
-		}
-	}
 }
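
Note: the spec changes above switch usage assertions from toEqual to toMatchObject because the consolidated usage chunk now carries a totalCost field and collapses zero-valued cache counters to undefined. The standalone sketch below only illustrates that normalization under stated assumptions; it is not the repo's implementation. UsageChunk, RawUsage, toUsageChunk, and the flat costPerToken rate are hypothetical stand-ins for ApiStreamUsageChunk, OpenAI.CompletionUsage, processUsageMetrics, and calculateApiCostOpenAI.

// Standalone TypeScript sketch (hypothetical names) of the usage normalization
// added to the base OpenAI-compatible provider in the diff above.

interface UsageChunk {
	type: "usage"
	inputTokens: number
	outputTokens: number
	cacheWriteTokens?: number
	cacheReadTokens?: number
	totalCost: number
}

// Assumed shape of an OpenAI-style usage payload; mirrors the fields read in the diff.
interface RawUsage {
	prompt_tokens?: number
	completion_tokens?: number
	prompt_tokens_details?: { cache_write_tokens?: number; cached_tokens?: number }
}

// costPerToken is a hypothetical flat rate standing in for the real cost calculation.
function toUsageChunk(usage: RawUsage, costPerToken = 0): UsageChunk {
	const inputTokens = usage.prompt_tokens || 0
	const outputTokens = usage.completion_tokens || 0
	const cacheWriteTokens = usage.prompt_tokens_details?.cache_write_tokens || 0
	const cacheReadTokens = usage.prompt_tokens_details?.cached_tokens || 0

	return {
		type: "usage",
		inputTokens,
		outputTokens,
		// Zero cache counters become undefined, which is why the specs compare
		// usage chunks with toMatchObject / toBeUndefined rather than toEqual.
		cacheWriteTokens: cacheWriteTokens || undefined,
		cacheReadTokens: cacheReadTokens || undefined,
		totalCost: (inputTokens + outputTokens) * costPerToken,
	}
}

// Example: a usage payload with no cache activity.
console.log(toUsageChunk({ prompt_tokens: 10, completion_tokens: 20 }))
// -> { type: "usage", inputTokens: 10, outputTokens: 20, totalCost: 0 } with both cache fields undefined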