Skip to content

Commit b6c4411

Browse files
committed
refactor: extract mapToolChoice to shared ai-sdk utilities
- Move duplicated mapToolChoice function from all providers to ai-sdk.ts - Update fireworks, groq, deepseek, cerebras, openai-compatible providers - Consolidate tests for mapToolChoice in ai-sdk.spec.ts - Remove duplicate tests from individual provider test files
1 parent 8655db7 commit b6c4411

File tree

11 files changed

+122
-368
lines changed

11 files changed

+122
-368
lines changed

src/api/providers/__tests__/cerebras.spec.ts

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -452,51 +452,4 @@ describe("CerebrasHandler", () => {
452452
expect(toolCallChunks.length).toBe(0)
453453
})
454454
})
455-
456-
describe("mapToolChoice", () => {
457-
it("should handle string tool choices", () => {
458-
class TestCerebrasHandler extends CerebrasHandler {
459-
public testMapToolChoice(toolChoice: any) {
460-
return this.mapToolChoice(toolChoice)
461-
}
462-
}
463-
464-
const testHandler = new TestCerebrasHandler(mockOptions)
465-
466-
expect(testHandler.testMapToolChoice("auto")).toBe("auto")
467-
expect(testHandler.testMapToolChoice("none")).toBe("none")
468-
expect(testHandler.testMapToolChoice("required")).toBe("required")
469-
expect(testHandler.testMapToolChoice("unknown")).toBe("auto")
470-
})
471-
472-
it("should handle object tool choice with function name", () => {
473-
class TestCerebrasHandler extends CerebrasHandler {
474-
public testMapToolChoice(toolChoice: any) {
475-
return this.mapToolChoice(toolChoice)
476-
}
477-
}
478-
479-
const testHandler = new TestCerebrasHandler(mockOptions)
480-
481-
const result = testHandler.testMapToolChoice({
482-
type: "function",
483-
function: { name: "my_tool" },
484-
})
485-
486-
expect(result).toEqual({ type: "tool", toolName: "my_tool" })
487-
})
488-
489-
it("should return undefined for null or undefined", () => {
490-
class TestCerebrasHandler extends CerebrasHandler {
491-
public testMapToolChoice(toolChoice: any) {
492-
return this.mapToolChoice(toolChoice)
493-
}
494-
}
495-
496-
const testHandler = new TestCerebrasHandler(mockOptions)
497-
498-
expect(testHandler.testMapToolChoice(null)).toBeUndefined()
499-
expect(testHandler.testMapToolChoice(undefined)).toBeUndefined()
500-
})
501-
})
502455
})

src/api/providers/__tests__/deepseek.spec.ts

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -733,51 +733,4 @@ describe("DeepSeekHandler", () => {
733733
expect(result).toBe(8192)
734734
})
735735
})
736-
737-
describe("mapToolChoice", () => {
738-
it("should handle string tool choices", () => {
739-
class TestDeepSeekHandler extends DeepSeekHandler {
740-
public testMapToolChoice(toolChoice: any) {
741-
return this.mapToolChoice(toolChoice)
742-
}
743-
}
744-
745-
const testHandler = new TestDeepSeekHandler(mockOptions)
746-
747-
expect(testHandler.testMapToolChoice("auto")).toBe("auto")
748-
expect(testHandler.testMapToolChoice("none")).toBe("none")
749-
expect(testHandler.testMapToolChoice("required")).toBe("required")
750-
expect(testHandler.testMapToolChoice("unknown")).toBe("auto")
751-
})
752-
753-
it("should handle object tool choice with function name", () => {
754-
class TestDeepSeekHandler extends DeepSeekHandler {
755-
public testMapToolChoice(toolChoice: any) {
756-
return this.mapToolChoice(toolChoice)
757-
}
758-
}
759-
760-
const testHandler = new TestDeepSeekHandler(mockOptions)
761-
762-
const result = testHandler.testMapToolChoice({
763-
type: "function",
764-
function: { name: "my_tool" },
765-
})
766-
767-
expect(result).toEqual({ type: "tool", toolName: "my_tool" })
768-
})
769-
770-
it("should return undefined for null or undefined", () => {
771-
class TestDeepSeekHandler extends DeepSeekHandler {
772-
public testMapToolChoice(toolChoice: any) {
773-
return this.mapToolChoice(toolChoice)
774-
}
775-
}
776-
777-
const testHandler = new TestDeepSeekHandler(mockOptions)
778-
779-
expect(testHandler.testMapToolChoice(null)).toBeUndefined()
780-
expect(testHandler.testMapToolChoice(undefined)).toBeUndefined()
781-
})
782-
})
783736
})

src/api/providers/__tests__/fireworks.spec.ts

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -842,50 +842,4 @@ describe("FireworksHandler", () => {
842842
expect(toolCallChunks.length).toBe(0)
843843
})
844844
})
845-
846-
describe("mapToolChoice", () => {
847-
it("should map string tool choices correctly", () => {
848-
class TestFireworksHandler extends FireworksHandler {
849-
public testMapToolChoice(toolChoice: any) {
850-
return this.mapToolChoice(toolChoice)
851-
}
852-
}
853-
854-
const testHandler = new TestFireworksHandler(mockOptions)
855-
856-
expect(testHandler.testMapToolChoice("auto")).toBe("auto")
857-
expect(testHandler.testMapToolChoice("none")).toBe("none")
858-
expect(testHandler.testMapToolChoice("required")).toBe("required")
859-
expect(testHandler.testMapToolChoice("unknown")).toBe("auto")
860-
})
861-
862-
it("should map object tool choices correctly", () => {
863-
class TestFireworksHandler extends FireworksHandler {
864-
public testMapToolChoice(toolChoice: any) {
865-
return this.mapToolChoice(toolChoice)
866-
}
867-
}
868-
869-
const testHandler = new TestFireworksHandler(mockOptions)
870-
871-
const result = testHandler.testMapToolChoice({
872-
type: "function",
873-
function: { name: "read_file" },
874-
})
875-
expect(result).toEqual({ type: "tool", toolName: "read_file" })
876-
})
877-
878-
it("should return undefined for null/undefined", () => {
879-
class TestFireworksHandler extends FireworksHandler {
880-
public testMapToolChoice(toolChoice: any) {
881-
return this.mapToolChoice(toolChoice)
882-
}
883-
}
884-
885-
const testHandler = new TestFireworksHandler(mockOptions)
886-
887-
expect(testHandler.testMapToolChoice(null)).toBeUndefined()
888-
expect(testHandler.testMapToolChoice(undefined)).toBeUndefined()
889-
})
890-
})
891845
})

src/api/providers/__tests__/groq.spec.ts

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -575,51 +575,4 @@ describe("GroqHandler", () => {
575575
expect(result).toBe(customMaxTokens)
576576
})
577577
})
578-
579-
describe("mapToolChoice", () => {
580-
it("should handle string tool choices", () => {
581-
class TestGroqHandler extends GroqHandler {
582-
public testMapToolChoice(toolChoice: any) {
583-
return this.mapToolChoice(toolChoice)
584-
}
585-
}
586-
587-
const testHandler = new TestGroqHandler(mockOptions)
588-
589-
expect(testHandler.testMapToolChoice("auto")).toBe("auto")
590-
expect(testHandler.testMapToolChoice("none")).toBe("none")
591-
expect(testHandler.testMapToolChoice("required")).toBe("required")
592-
expect(testHandler.testMapToolChoice("unknown")).toBe("auto")
593-
})
594-
595-
it("should handle object tool choice with function name", () => {
596-
class TestGroqHandler extends GroqHandler {
597-
public testMapToolChoice(toolChoice: any) {
598-
return this.mapToolChoice(toolChoice)
599-
}
600-
}
601-
602-
const testHandler = new TestGroqHandler(mockOptions)
603-
604-
const result = testHandler.testMapToolChoice({
605-
type: "function",
606-
function: { name: "my_tool" },
607-
})
608-
609-
expect(result).toEqual({ type: "tool", toolName: "my_tool" })
610-
})
611-
612-
it("should return undefined for null or undefined", () => {
613-
class TestGroqHandler extends GroqHandler {
614-
public testMapToolChoice(toolChoice: any) {
615-
return this.mapToolChoice(toolChoice)
616-
}
617-
}
618-
619-
const testHandler = new TestGroqHandler(mockOptions)
620-
621-
expect(testHandler.testMapToolChoice(null)).toBeUndefined()
622-
expect(testHandler.testMapToolChoice(undefined)).toBeUndefined()
623-
})
624-
})
625578
})

src/api/providers/cerebras.ts

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ import { cerebrasModels, cerebrasDefaultModelId, type CerebrasModelId, type Mode
66

77
import type { ApiHandlerOptions } from "../../shared/api"
88

9-
import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk"
9+
import {
10+
convertToAiSdkMessages,
11+
convertToolsForAiSdk,
12+
processAiSdkStreamPart,
13+
mapToolChoice,
14+
} from "../transform/ai-sdk"
1015
import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
1116
import { getModelParams } from "../transform/model-params"
1217

@@ -75,40 +80,6 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
7580
}
7681
}
7782

78-
/**
79-
* Map OpenAI tool_choice to AI SDK toolChoice format.
80-
*/
81-
protected mapToolChoice(
82-
toolChoice: any,
83-
): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined {
84-
if (!toolChoice) {
85-
return undefined
86-
}
87-
88-
// Handle string values
89-
if (typeof toolChoice === "string") {
90-
switch (toolChoice) {
91-
case "auto":
92-
return "auto"
93-
case "none":
94-
return "none"
95-
case "required":
96-
return "required"
97-
default:
98-
return "auto"
99-
}
100-
}
101-
102-
// Handle object values (OpenAI ChatCompletionNamedToolChoice format)
103-
if (typeof toolChoice === "object" && "type" in toolChoice) {
104-
if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) {
105-
return { type: "tool", toolName: toolChoice.function.name }
106-
}
107-
}
108-
109-
return undefined
110-
}
111-
11283
/**
11384
* Get the max tokens parameter to include in the request.
11485
*/
@@ -143,7 +114,7 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
143114
temperature: this.options.modelTemperature ?? temperature ?? CEREBRAS_DEFAULT_TEMPERATURE,
144115
maxOutputTokens: this.getMaxOutputTokens(),
145116
tools: aiSdkTools,
146-
toolChoice: this.mapToolChoice(metadata?.tool_choice),
117+
toolChoice: mapToolChoice(metadata?.tool_choice),
147118
}
148119

149120
// Use streamText for streaming responses

src/api/providers/deepseek.ts

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ import { deepSeekModels, deepSeekDefaultModelId, DEEP_SEEK_DEFAULT_TEMPERATURE,
66

77
import type { ApiHandlerOptions } from "../../shared/api"
88

9-
import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk"
9+
import {
10+
convertToAiSdkMessages,
11+
convertToolsForAiSdk,
12+
processAiSdkStreamPart,
13+
mapToolChoice,
14+
} from "../transform/ai-sdk"
1015
import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
1116
import { getModelParams } from "../transform/model-params"
1217

@@ -83,40 +88,6 @@ export class DeepSeekHandler extends BaseProvider implements SingleCompletionHan
8388
}
8489
}
8590

86-
/**
87-
* Map OpenAI tool_choice to AI SDK toolChoice format.
88-
*/
89-
protected mapToolChoice(
90-
toolChoice: any,
91-
): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined {
92-
if (!toolChoice) {
93-
return undefined
94-
}
95-
96-
// Handle string values
97-
if (typeof toolChoice === "string") {
98-
switch (toolChoice) {
99-
case "auto":
100-
return "auto"
101-
case "none":
102-
return "none"
103-
case "required":
104-
return "required"
105-
default:
106-
return "auto"
107-
}
108-
}
109-
110-
// Handle object values (OpenAI ChatCompletionNamedToolChoice format)
111-
if (typeof toolChoice === "object" && "type" in toolChoice) {
112-
if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) {
113-
return { type: "tool", toolName: toolChoice.function.name }
114-
}
115-
}
116-
117-
return undefined
118-
}
119-
12091
/**
12192
* Get the max tokens parameter to include in the request.
12293
*/
@@ -152,7 +123,7 @@ export class DeepSeekHandler extends BaseProvider implements SingleCompletionHan
152123
temperature: this.options.modelTemperature ?? temperature ?? DEEP_SEEK_DEFAULT_TEMPERATURE,
153124
maxOutputTokens: this.getMaxOutputTokens(),
154125
tools: aiSdkTools,
155-
toolChoice: this.mapToolChoice(metadata?.tool_choice),
126+
toolChoice: mapToolChoice(metadata?.tool_choice),
156127
}
157128

158129
// Use streamText for streaming responses

src/api/providers/fireworks.ts

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ import { fireworksModels, fireworksDefaultModelId, type ModelInfo } from "@roo-c
66

77
import type { ApiHandlerOptions } from "../../shared/api"
88

9-
import { convertToAiSdkMessages, convertToolsForAiSdk, processAiSdkStreamPart } from "../transform/ai-sdk"
9+
import {
10+
convertToAiSdkMessages,
11+
convertToolsForAiSdk,
12+
processAiSdkStreamPart,
13+
mapToolChoice,
14+
} from "../transform/ai-sdk"
1015
import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
1116
import { getModelParams } from "../transform/model-params"
1217

@@ -90,40 +95,6 @@ export class FireworksHandler extends BaseProvider implements SingleCompletionHa
9095
}
9196
}
9297

93-
/**
94-
* Map OpenAI tool_choice to AI SDK toolChoice format.
95-
*/
96-
protected mapToolChoice(
97-
toolChoice: any,
98-
): "auto" | "none" | "required" | { type: "tool"; toolName: string } | undefined {
99-
if (!toolChoice) {
100-
return undefined
101-
}
102-
103-
// Handle string values
104-
if (typeof toolChoice === "string") {
105-
switch (toolChoice) {
106-
case "auto":
107-
return "auto"
108-
case "none":
109-
return "none"
110-
case "required":
111-
return "required"
112-
default:
113-
return "auto"
114-
}
115-
}
116-
117-
// Handle object values (OpenAI ChatCompletionNamedToolChoice format)
118-
if (typeof toolChoice === "object" && "type" in toolChoice) {
119-
if (toolChoice.type === "function" && "function" in toolChoice && toolChoice.function?.name) {
120-
return { type: "tool", toolName: toolChoice.function.name }
121-
}
122-
}
123-
124-
return undefined
125-
}
126-
12798
/**
12899
* Get the max tokens parameter to include in the request.
129100
*/
@@ -158,7 +129,7 @@ export class FireworksHandler extends BaseProvider implements SingleCompletionHa
158129
temperature: this.options.modelTemperature ?? temperature ?? FIREWORKS_DEFAULT_TEMPERATURE,
159130
maxOutputTokens: this.getMaxOutputTokens(),
160131
tools: aiSdkTools,
161-
toolChoice: this.mapToolChoice(metadata?.tool_choice),
132+
toolChoice: mapToolChoice(metadata?.tool_choice),
162133
}
163134

164135
// Use streamText for streaming responses

0 commit comments

Comments
 (0)