Skip to content

Commit f2d2a5a

Browse files
authored
refactor: Simplify tool call handling in streaming chat completions (#6)
Signed-off-by: Eden Reich <[email protected]>
1 parent d033b2d commit f2d2a5a

File tree

2 files changed

+79
-10
lines changed

2 files changed

+79
-10
lines changed

src/client.ts

+70-9
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,18 @@ export class InferenceGatewayClient {
213213
const decoder = new TextDecoder();
214214
let buffer = '';
215215

216+
const incompleteToolCalls = new Map<
217+
number,
218+
{
219+
id: string;
220+
type: ChatCompletionToolType;
221+
function: {
222+
name: string;
223+
arguments: string;
224+
};
225+
}
226+
>();
227+
216228
while (true) {
217229
const { done, value } = await reader.read();
218230
if (done) break;
@@ -226,6 +238,16 @@ export class InferenceGatewayClient {
226238
const data = line.slice(5).trim();
227239

228240
if (data === '[DONE]') {
241+
for (const [, toolCall] of incompleteToolCalls.entries()) {
242+
callbacks.onTool?.({
243+
id: toolCall.id,
244+
type: toolCall.type,
245+
function: {
246+
name: toolCall.function.name,
247+
arguments: toolCall.function.arguments,
248+
},
249+
});
250+
}
229251
callbacks.onFinish?.(null);
230252
return;
231253
}
@@ -242,15 +264,54 @@ export class InferenceGatewayClient {
242264

243265
const toolCalls = chunk.choices[0]?.delta?.tool_calls;
244266
if (toolCalls && toolCalls.length > 0) {
245-
const toolCall: SchemaChatCompletionMessageToolCall = {
246-
id: toolCalls[0].id || '',
247-
type: ChatCompletionToolType.function,
248-
function: {
249-
name: toolCalls[0].function?.name || '',
250-
arguments: toolCalls[0].function?.arguments || '',
251-
},
252-
};
253-
callbacks.onTool?.(toolCall);
267+
for (const toolCallChunk of toolCalls) {
268+
const index = toolCallChunk.index;
269+
270+
if (!incompleteToolCalls.has(index)) {
271+
incompleteToolCalls.set(index, {
272+
id: toolCallChunk.id || '',
273+
type: ChatCompletionToolType.function,
274+
function: {
275+
name: toolCallChunk.function?.name || '',
276+
arguments: toolCallChunk.function?.arguments || '',
277+
},
278+
});
279+
} else {
280+
const existingToolCall = incompleteToolCalls.get(index)!;
281+
282+
if (toolCallChunk.id) {
283+
existingToolCall.id = toolCallChunk.id;
284+
}
285+
286+
if (toolCallChunk.function?.name) {
287+
existingToolCall.function.name =
288+
toolCallChunk.function.name;
289+
}
290+
291+
if (toolCallChunk.function?.arguments) {
292+
existingToolCall.function.arguments +=
293+
toolCallChunk.function.arguments;
294+
}
295+
}
296+
}
297+
}
298+
299+
const finishReason = chunk.choices[0]?.finish_reason;
300+
if (
301+
finishReason === 'tool_calls' &&
302+
incompleteToolCalls.size > 0
303+
) {
304+
for (const [, toolCall] of incompleteToolCalls.entries()) {
305+
callbacks.onTool?.({
306+
id: toolCall.id,
307+
type: toolCall.type,
308+
function: {
309+
name: toolCall.function.name,
310+
arguments: toolCall.function.arguments,
311+
},
312+
});
313+
}
314+
incompleteToolCalls.clear();
254315
}
255316
} catch (e) {
256317
globalThis.console.error('Error parsing SSE data:', e);

tests/client.test.ts

+9-1
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,15 @@ describe('InferenceGatewayClient', () => {
320320

321321
expect(callbacks.onOpen).toHaveBeenCalledTimes(1);
322322
expect(callbacks.onChunk).toHaveBeenCalledTimes(6);
323-
expect(callbacks.onTool).toHaveBeenCalledTimes(4);
323+
expect(callbacks.onTool).toHaveBeenCalledTimes(1);
324+
expect(callbacks.onTool).toHaveBeenCalledWith({
325+
id: 'call_123',
326+
type: 'function',
327+
function: {
328+
name: 'get_weather',
329+
arguments: '{"location":"San Francisco, CA"}'
330+
}
331+
});
324332
expect(callbacks.onFinish).toHaveBeenCalledTimes(1);
325333
});
326334

0 commit comments

Comments
 (0)