diff --git a/core/__mocks__/@continuedev/fetch/index.ts b/core/__mocks__/@continuedev/fetch/index.ts new file mode 100644 index 0000000000..526cbe2eaa --- /dev/null +++ b/core/__mocks__/@continuedev/fetch/index.ts @@ -0,0 +1,16 @@ +import { vi } from "vitest"; + +export const fetchwithRequestOptions = vi.fn( + async (url, options, requestOptions) => { + console.log("Mocked fetch called with:", url, options, requestOptions); + return { + ok: true, + status: 200, + statusText: "OK", + }; + }, +); + +export const streamSse = vi.fn(function* () { + yield ""; +}); diff --git a/core/context/providers/GoogleContextProvider.ts b/core/context/providers/GoogleContextProvider.ts index fb8a04d40f..b27b9f5656 100644 --- a/core/context/providers/GoogleContextProvider.ts +++ b/core/context/providers/GoogleContextProvider.ts @@ -38,27 +38,35 @@ class GoogleContextProvider extends BaseContextProvider { body: payload, }); + if (!response.ok) { + throw new Error( + `Failed to fetch Google search results: ${response.statusText}`, + ); + } const results = await response.text(); + try { + const parsed = JSON.parse(results); + let content = `Google Search: ${query}\n\n`; + const answerBox = parsed.answerBox; - const jsonResults = JSON.parse(results); - let content = `Google Search: ${query}\n\n`; - const answerBox = jsonResults.answerBox; + if (answerBox) { + content += `Answer Box (${answerBox.title}): ${answerBox.answer}\n\n`; + } - if (answerBox) { - content += `Answer Box (${answerBox.title}): ${answerBox.answer}\n\n`; - } + for (const result of parsed.organic) { + content += `${result.title}\n${result.link}\n${result.snippet}\n\n`; + } - for (const result of jsonResults.organic) { - content += `${result.title}\n${result.link}\n${result.snippet}\n\n`; + return [ + { + content, + name: "Google Search", + description: "Google Search", + }, + ]; + } catch (e) { + throw new Error(`Failed to parse Google search results: ${results}`); } - - return [ - { - content, - name: "Google Search", - description: "Google Search", - }, - ]; } } diff --git a/core/context/providers/GreptileContextProvider.ts b/core/context/providers/GreptileContextProvider.ts index a13e4376f6..23a730bc9e 100644 --- a/core/context/providers/GreptileContextProvider.ts +++ b/core/context/providers/GreptileContextProvider.ts @@ -81,13 +81,17 @@ class GreptileContextProvider extends BaseContextProvider { } // Parse the response as JSON - const json = JSON.parse(rawText); - - return json.sources.map((source: any) => ({ - description: source.filepath, - content: `File: ${source.filepath}\nLines: ${source.linestart}-${source.lineend}\n\n${source.summary}`, - name: (source.filepath.split("/").pop() ?? "").split("\\").pop() ?? "", - })); + try { + const json = JSON.parse(rawText); + return json.sources.map((source: any) => ({ + description: source.filepath, + content: `File: ${source.filepath}\nLines: ${source.linestart}-${source.lineend}\n\n${source.summary}`, + name: + (source.filepath.split("/").pop() ?? "").split("\\").pop() ?? "", + })); + } catch (jsonError) { + throw new Error(`Failed to parse Greptile response:\n${rawText}`); + } } catch (error) { console.error("Error getting context items from Greptile:", error); throw new Error("Error getting context items from Greptile"); diff --git a/core/core.ts b/core/core.ts index 822d430eed..c90ca9ceb9 100644 --- a/core/core.ts +++ b/core/core.ts @@ -701,7 +701,7 @@ export class Core { this.messenger.send("toolCallPartialOutput", params); }; - return await callTool(tool, toolCall.function.arguments, { + return await callTool(tool, toolCall, { config, ide: this.ide, llm: config.selectedModelByRole.chat, diff --git a/core/indexing/LanceDbIndex.ts b/core/indexing/LanceDbIndex.ts index cf41beaca4..0bf6c94ca9 100644 --- a/core/indexing/LanceDbIndex.ts +++ b/core/indexing/LanceDbIndex.ts @@ -294,17 +294,27 @@ export class LanceDbIndex implements CodebaseIndex { ); const cachedItems = await stmt.all(); - const lanceRows: LanceDbRow[] = cachedItems.map( - ({ uuid, vector, startLine, endLine, contents }) => ({ - path, - uuid, - startLine, - endLine, - contents, - cachekey: cacheKey, - vector: JSON.parse(vector), - }), - ); + const lanceRows: LanceDbRow[] = []; + for (const item of cachedItems) { + try { + const vector = JSON.parse(item.vector); + const { uuid, startLine, endLine, contents } = item; + + cachedItems.push({ + path, + uuid, + startLine, + endLine, + contents, + cachekey: cacheKey, + vector, + }); + } catch (err) { + console.warn( + `LanceDBIndex, skipping ${item.path} due to invalid vector JSON:\n${item.vector}\n\nError: ${err}`, + ); + } + } if (lanceRows.length > 0) { if (needToCreateLanceTable) { diff --git a/core/llm/llm-pre-fetch.vitest.ts b/core/llm/llm-pre-fetch.vitest.ts new file mode 100644 index 0000000000..3db39fc3ef --- /dev/null +++ b/core/llm/llm-pre-fetch.vitest.ts @@ -0,0 +1,113 @@ +import { fetchwithRequestOptions } from "@continuedev/fetch"; +import * as openAiAdapters from "@continuedev/openai-adapters"; +import * as dotenv from "dotenv"; +import { beforeEach, describe, expect, test, vi } from "vitest"; +import { ChatMessage, ILLM } from ".."; +import Anthropic from "./llms/Anthropic"; +import Gemini from "./llms/Gemini"; +import OpenAI from "./llms/OpenAI"; + +dotenv.config(); + +vi.mock("@continuedev/fetch"); +vi.mock("@continuedev/openai-adapters"); + +async function dudLLMCall(llm: ILLM, messages: ChatMessage[]) { + try { + const abortController = new AbortController(); + const gen = llm.streamChat(messages, abortController.signal, {}); + await gen.next(); + await gen.return({ + completion: "", + completionOptions: { + model: "", + }, + modelTitle: "", + prompt: "", + }); + abortController.abort(); + } catch (e) { + console.error("Expected error", e); + } +} + +const invalidToolCallArg = '{"name": "Ali'; +const messagesWithInvalidToolCallArgs: ChatMessage[] = [ + { + role: "user", + content: "Call the say_hello tool", + }, + { + role: "assistant", + content: "", + toolCalls: [ + { + id: "tool_call_1", + type: "function", + function: { + name: "say_name", + arguments: invalidToolCallArg, + }, + }, + ], + }, + { + role: "user", + content: "This is my response", + }, +]; + +describe("LLM Pre-fetch", () => { + beforeEach(() => { + vi.resetAllMocks(); + // Log to verify the mock is properly set up + console.log("Mock setup:", openAiAdapters); + }); + + test("Invalid tool call args are ignored", async () => { + const anthropic = new Anthropic({ + model: "not-important", + apiKey: "invalid", + }); + await dudLLMCall(anthropic, messagesWithInvalidToolCallArgs); + expect(fetchwithRequestOptions).toHaveBeenCalledWith( + expect.any(URL), + { + method: "POST", + headers: expect.any(Object), + signal: expect.any(AbortSignal), + body: expect.stringContaining('"name":"say_name","input":{}'), + }, + expect.any(Object), + ); + + vi.clearAllMocks(); + const gemini = new Gemini({ model: "gemini-something", apiKey: "invalid" }); + await dudLLMCall(gemini, messagesWithInvalidToolCallArgs); + expect(fetchwithRequestOptions).toHaveBeenCalledWith( + expect.any(URL), + { + method: "POST", + // headers: expect.any(Object), + signal: expect.any(AbortSignal), + body: expect.stringContaining('"name":"say_name","args":{}'), + }, + expect.any(Object), + ); + + // OPENAI DOES NOT NEED TO CLEAR INVALID TOOL CALL ARGS BECAUSE IT STORES THEM IN STRINGS + vi.clearAllMocks(); + const openai = new OpenAI({ model: "gpt-something", apiKey: "invalid" }); + await dudLLMCall(openai, messagesWithInvalidToolCallArgs); + expect(fetchwithRequestOptions).toHaveBeenCalledWith( + expect.any(URL), + { + method: "POST", + headers: expect.any(Object), + signal: expect.any(AbortSignal), + body: expect.stringContaining(JSON.stringify(invalidToolCallArg)), + }, + expect.any(Object), + ); + }); +}); diff --git a/core/llm/llms/Anthropic.ts b/core/llm/llms/Anthropic.ts index b5906590ab..8d1d3a67f3 100644 --- a/core/llm/llms/Anthropic.ts +++ b/core/llm/llms/Anthropic.ts @@ -1,5 +1,6 @@ import { streamSse } from "@continuedev/fetch"; import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js"; +import { safeParseToolCallArgs } from "../../tools/parseArgs.js"; import { renderChatMessage, stripImages } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; @@ -66,7 +67,7 @@ class Anthropic extends BaseLLM { type: "tool_use", id: toolCall.id, name: toolCall.function?.name, - input: JSON.parse(toolCall.function?.arguments || "{}"), + input: safeParseToolCallArgs(toolCall), })), }; } else if (message.role === "thinking" && !message.redactedThinking) { diff --git a/core/llm/llms/Bedrock.ts b/core/llm/llms/Bedrock.ts index 83b9a98c13..ae595f3f33 100644 --- a/core/llm/llms/Bedrock.ts +++ b/core/llm/llms/Bedrock.ts @@ -15,6 +15,7 @@ import { CompletionOptions, LLMOptions, } from "../../index.js"; +import { safeParseToolCallArgs } from "../../tools/parseArgs.js"; import { renderChatMessage, stripImages } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; import { PROVIDER_TOOL_SUPPORT } from "../toolSupport.js"; @@ -408,7 +409,7 @@ class Bedrock extends BaseLLM { toolUse: { toolUseId: toolCall.id, name: toolCall.function?.name, - input: JSON.parse(toolCall.function?.arguments || "{}"), + input: safeParseToolCallArgs(toolCall), }, })), }; @@ -564,10 +565,14 @@ class Bedrock extends BaseLLM { const command = new InvokeModelCommand(input); const response = await client.send(command); if (response.body) { - const responseBody = JSON.parse( - new TextDecoder().decode(response.body), - ); - return this._extractEmbeddings(responseBody); + const decoder = new TextDecoder(); + const decoded = decoder.decode(response.body); + try { + const responseBody = JSON.parse(decoded); + return this._extractEmbeddings(responseBody); + } catch (e) { + console.error(`Error parsing response body from:\n${decoded}`, e); + } } return []; }), @@ -662,12 +667,19 @@ class Bedrock extends BaseLLM { throw new Error("Empty response received from Bedrock"); } - const responseBody = JSON.parse(new TextDecoder().decode(response.body)); - - // Sort results by index to maintain original order - return responseBody.results - .sort((a: any, b: any) => a.index - b.index) - .map((result: any) => result.relevance_score); + const decoder = new TextDecoder(); + const decoded = decoder.decode(response.body); + try { + const responseBody = JSON.parse(decoded); + // Sort results by index to maintain original order + return responseBody.results + .sort((a: any, b: any) => a.index - b.index) + .map((result: any) => result.relevance_score); + } catch (e) { + throw new Error( + `Error parsing JSON from Bedrock response body:\n${decoded}, ${JSON.stringify(e)}`, + ); + } } catch (error: unknown) { if (error instanceof Error) { if ("code" in error) { diff --git a/core/llm/llms/BedrockImport.ts b/core/llm/llms/BedrockImport.ts index 78701ca704..3ab45f373f 100644 --- a/core/llm/llms/BedrockImport.ts +++ b/core/llm/llms/BedrockImport.ts @@ -51,9 +51,15 @@ class BedrockImport extends BaseLLM { if (response.body) { for await (const item of response.body) { - const chunk = JSON.parse(new TextDecoder().decode(item.chunk?.bytes)); - if (chunk.outputs[0].text) { - yield chunk.outputs[0].text; + const decoder = new TextDecoder(); + const decoded = decoder.decode(item.chunk?.bytes); + try { + const chunk = JSON.parse(decoded); + if (chunk.outputs[0].text) { + yield chunk.outputs[0].text; + } + } catch (e) { + throw new Error(`Malformed JSON received from Bedrock: ${decoded}`); } } } diff --git a/core/llm/llms/Gemini.ts b/core/llm/llms/Gemini.ts index 63dd9cc3ac..c6d32b501f 100644 --- a/core/llm/llms/Gemini.ts +++ b/core/llm/llms/Gemini.ts @@ -9,6 +9,7 @@ import { TextMessagePart, ToolCallDelta, } from "../../index.js"; +import { safeParseToolCallArgs } from "../../tools/parseArgs.js"; import { renderChatMessage, stripImages } from "../../util/messageContent.js"; import { BaseLLM } from "../index.js"; import { @@ -250,11 +251,11 @@ class Gemini extends BaseLLM { }; if (msg.toolCalls) { msg.toolCalls.forEach((toolCall) => { - if (toolCall.function?.name && toolCall.function?.arguments) { + if (toolCall.function?.name) { assistantMsg.parts.push({ functionCall: { name: toolCall.function.name, - args: JSON.parse(toolCall.function.arguments), + args: safeParseToolCallArgs(toolCall), }, }); } diff --git a/core/llm/llms/HuggingFaceTEI.ts b/core/llm/llms/HuggingFaceTEI.ts index a96840a422..175aedde02 100644 --- a/core/llm/llms/HuggingFaceTEI.ts +++ b/core/llm/llms/HuggingFaceTEI.ts @@ -52,11 +52,16 @@ class HuggingFaceTEIEmbeddingsProvider extends BaseLLM { }); if (!resp.ok) { const text = await resp.text(); - const embedError = JSON.parse(text) as TEIEmbedErrorResponse; - if (!embedError.error_type || !embedError.error) { - throw new Error(text); + let teiError: TEIEmbedErrorResponse | null = null; + try { + teiError = JSON.parse(text); + } catch (e) { + console.log(`Failed to parse TEI embed error response:\n${text}`, e); } - throw new TEIEmbedError(embedError); + if (teiError && (teiError.error_type || teiError.error)) { + throw new TEIEmbedError(teiError); + } + throw new Error(text); } return (await resp.json()) as number[][]; } diff --git a/core/llm/llms/SageMaker.ts b/core/llm/llms/SageMaker.ts index 056fed2ccf..bbb87c1310 100644 --- a/core/llm/llms/SageMaker.ts +++ b/core/llm/llms/SageMaker.ts @@ -157,14 +157,23 @@ class SageMaker extends BaseLLM { const response = await client.send(command); if (response.Body) { - const responseBody = JSON.parse(new TextDecoder().decode(response.Body)); - // If the body contains a key called "embedding" or "embeddings", return the value, otherwise return the whole body - if (responseBody.embedding) { - return responseBody.embedding; - } else if (responseBody.embeddings) { - return responseBody.embeddings; - } else { - return responseBody; + const decoder = new TextDecoder(); + const decoded = decoder.decode(response.Body); + try { + const responseBody = JSON.parse(decoded); + // If the body contains a key called "embedding" or "embeddings", return the value, otherwise return the whole body + if (responseBody.embedding) { + return responseBody.embedding; + } else if (responseBody.embeddings) { + return responseBody.embeddings; + } else { + return responseBody; + } + } catch (e) { + let message = e instanceof Error ? e.message : String(e); + throw new Error( + `Failed to parse response from SageMaker:\n${decoded}\nError: ${message}`, + ); } } } diff --git a/core/llm/openaiTypeConverters.ts b/core/llm/openaiTypeConverters.ts index a3a6ad8951..a24af02f7f 100644 --- a/core/llm/openaiTypeConverters.ts +++ b/core/llm/openaiTypeConverters.ts @@ -1,21 +1,14 @@ import { FimCreateParamsStreaming } from "@continuedev/openai-adapters/dist/apis/base"; import { - Chat, ChatCompletion, ChatCompletionAssistantMessageParam, ChatCompletionChunk, ChatCompletionCreateParams, ChatCompletionMessageParam, - ChatCompletionUserMessageParam, CompletionCreateParams, } from "openai/resources/index"; -import { - ChatMessage, - CompletionOptions, - MessageContent, - TextMessagePart, -} from ".."; +import { ChatMessage, CompletionOptions, TextMessagePart } from ".."; export function toChatMessage( message: ChatMessage, @@ -51,7 +44,7 @@ export function toChatMessage( type: toolCall.type!, function: { name: toolCall.function?.name!, - arguments: toolCall.function?.arguments! || "{}", + arguments: toolCall.function?.arguments || "{}", }, })); } @@ -189,12 +182,12 @@ export function fromChatCompletionChunk( return { role: "assistant", content: "", - toolCalls: delta?.tool_calls.map((tool_call: any) => ({ + toolCalls: delta?.tool_calls.map((tool_call) => ({ id: tool_call.id, type: tool_call.type, function: { - name: tool_call.function.name, - arguments: tool_call.function.arguments, + name: tool_call.function?.name, + arguments: tool_call.function?.arguments, }, })), }; diff --git a/core/tools/callTool.ts b/core/tools/callTool.ts index 243c162bb1..6ea95610b3 100644 --- a/core/tools/callTool.ts +++ b/core/tools/callTool.ts @@ -1,4 +1,4 @@ -import { ContextItem, Tool, ToolExtras } from ".."; +import { ContextItem, Tool, ToolCall, ToolExtras } from ".."; import { MCPManagerSingleton } from "../context/mcp/MCPManagerSingleton"; import { canParseUrl } from "../util/url"; import { BuiltInToolNames } from "./builtIn"; @@ -14,6 +14,7 @@ import { requestRuleImpl } from "./implementations/requestRule"; import { runTerminalCommandImpl } from "./implementations/runTerminalCommand"; import { searchWebImpl } from "./implementations/searchWeb"; import { viewDiffImpl } from "./implementations/viewDiff"; +import { safeParseToolCallArgs } from "./parseArgs"; async function callHttpTool( url: string, @@ -170,14 +171,14 @@ async function callBuiltInTool( // Note: Edit tool is handled on client export async function callTool( tool: Tool, - callArgs: string, + toolCall: ToolCall, extras: ToolExtras, ): Promise<{ contextItems: ContextItem[]; errorMessage: string | undefined; }> { try { - const args = JSON.parse(callArgs || "{}"); + const args = safeParseToolCallArgs(toolCall); const contextItems = tool.uri ? await callToolFromUri(tool.uri, args, extras) : await callBuiltInTool(tool.function.name, args, extras); diff --git a/core/tools/parseArgs.ts b/core/tools/parseArgs.ts new file mode 100644 index 0000000000..36d75e0207 --- /dev/null +++ b/core/tools/parseArgs.ts @@ -0,0 +1,14 @@ +import { ToolCallDelta } from ".."; + +export function safeParseToolCallArgs( + toolCall: ToolCallDelta, +): Record { + try { + return JSON.parse(toolCall.function?.arguments ?? "{}"); + } catch (e) { + console.error( + `Failed to parse tool call arguments:\nTool call: ${toolCall.function?.name + " " + toolCall.id}\nArgs:${toolCall.function?.arguments}\n`, + ); + return {}; + } +} diff --git a/gui/src/redux/thunks/callCurrentTool.ts b/gui/src/redux/thunks/callCurrentTool.ts index 84270fdd8a..af8b4d4d9d 100644 --- a/gui/src/redux/thunks/callCurrentTool.ts +++ b/gui/src/redux/thunks/callCurrentTool.ts @@ -60,7 +60,7 @@ export const callCurrentTool = createAsyncThunk( output: clientToolOuput, respondImmediately, errorMessage: clientToolError, - } = await callClientTool(toolCallState.toolCall, { + } = await callClientTool(toolCallState, { dispatch, ideMessenger: extra.ideMessenger, streamId: state.session.codeBlockApplyStates.states.find( diff --git a/gui/src/util/clientTools/callClientTool.ts b/gui/src/util/clientTools/callClientTool.ts index 65f11fbf6e..cdf6a960f5 100644 --- a/gui/src/util/clientTools/callClientTool.ts +++ b/gui/src/util/clientTools/callClientTool.ts @@ -1,4 +1,4 @@ -import { ContextItem, ToolCall } from "core"; +import { ContextItem, ToolCallState } from "core"; import { BuiltInToolNames } from "core/tools/builtIn"; import { IIdeMessenger } from "../../context/IdeMessenger"; import { AppThunkDispatch, RootState } from "../../redux/store"; @@ -27,15 +27,15 @@ export type ClientToolImpl = ( ) => Promise; export async function callClientTool( - toolCall: ToolCall, + toolCallState: ToolCallState, extras: ClientToolExtras, ): Promise { - const args = JSON.parse(toolCall.function.arguments || "{}"); + const { toolCall, parsedArgs } = toolCallState; try { let output: ClientToolOutput; switch (toolCall.function.name) { case BuiltInToolNames.EditExistingFile: - output = await editToolImpl(args, toolCall.id, extras); + output = await editToolImpl(parsedArgs, toolCall.id, extras); break; default: throw new Error(`Invalid client tool name ${toolCall.function.name}`); diff --git a/packages/fetch/src/stream.ts b/packages/fetch/src/stream.ts index b1f5ab9061..fd84b76604 100644 --- a/packages/fetch/src/stream.ts +++ b/packages/fetch/src/stream.ts @@ -126,8 +126,12 @@ export async function* streamJSON(response: Response): AsyncGenerator { let position; while ((position = buffer.indexOf("\n")) >= 0) { const line = buffer.slice(0, position); - const data = JSON.parse(line); - yield data; + try { + const data = JSON.parse(line); + yield data; + } catch (e) { + throw new Error(`Malformed JSON sent from server: ${line}`); + } buffer = buffer.slice(position + 1); } } diff --git a/packages/openai-adapters/src/apis/Anthropic.ts b/packages/openai-adapters/src/apis/Anthropic.ts index 1bab370b0b..299f895af9 100644 --- a/packages/openai-adapters/src/apis/Anthropic.ts +++ b/packages/openai-adapters/src/apis/Anthropic.ts @@ -13,6 +13,7 @@ import { ChatCompletionCreateParams } from "openai/src/resources/index.js"; import { AnthropicConfig } from "../types.js"; import { chatChunk, chatChunkFromDelta, customFetch } from "../util.js"; import { EMPTY_CHAT_COMPLETION } from "../util/emptyChatCompletion.js"; +import { safeParseArgs } from "../util/parseArgs.js"; import { BaseLlmApi, CreateRerankResponse, @@ -105,7 +106,10 @@ export class AnthropicApi implements BaseLlmApi { type: "tool_use", id: toolCall.id, name: toolCall.function?.name, - input: JSON.parse(toolCall.function?.arguments || "{}"), + input: safeParseArgs( + toolCall.function?.arguments, + `${toolCall.function?.name} ${toolCall.id}`, + ), })), }; } diff --git a/packages/openai-adapters/src/apis/Gemini.ts b/packages/openai-adapters/src/apis/Gemini.ts index 6b6f58870a..9df0c2e858 100644 --- a/packages/openai-adapters/src/apis/Gemini.ts +++ b/packages/openai-adapters/src/apis/Gemini.ts @@ -28,6 +28,7 @@ import { GeminiChatContentPart, GeminiToolFunctionDeclaration, } from "../util/gemini-types.js"; +import { safeParseArgs } from "../util/parseArgs.js"; import { BaseLlmApi, CreateRerankResponse, @@ -125,7 +126,10 @@ export class GeminiApi implements BaseLlmApi { functionCall: { id: toolCall.id, name: toolCall.function.name, - args: JSON.parse(toolCall.function.arguments || "{}"), + args: safeParseArgs( + toolCall.function.arguments, + `Call: ${toolCall.function.name} ${toolCall.id}`, + ), }, })), }; diff --git a/packages/openai-adapters/src/util/parseArgs.ts b/packages/openai-adapters/src/util/parseArgs.ts new file mode 100644 index 0000000000..8e7bb7380e --- /dev/null +++ b/packages/openai-adapters/src/util/parseArgs.ts @@ -0,0 +1,15 @@ +export function safeParseArgs( + args: string | undefined, + errorId?: string, +): Record { + try { + return JSON.parse(args ?? "{}"); + } catch (e) { + const identifier = errorId ? `Call: ${errorId}\nArgs:${args}\n` : ""; + console.error( + `Failed to parse tool call arguments\n${identifier}Error:`, + e, + ); + return {}; + } +}