diff --git a/.changeset/six-roses-peel.md b/.changeset/six-roses-peel.md
new file mode 100644
index 000000000000..54fe7786e85c
--- /dev/null
+++ b/.changeset/six-roses-peel.md
@@ -0,0 +1,6 @@
+---
+'@ai-sdk/react': patch
+'ai': patch
+---
+
+Added finishReason to useChat onFinish callback
diff --git a/examples/next-openai/app/use-chat-data-ui-parts/page.tsx b/examples/next-openai/app/use-chat-data-ui-parts/page.tsx
index d7e1d6f96a1d..73fbba5a7537 100644
--- a/examples/next-openai/app/use-chat-data-ui-parts/page.tsx
+++ b/examples/next-openai/app/use-chat-data-ui-parts/page.tsx
@@ -2,7 +2,8 @@
 
 import ChatInput from '@/components/chat-input';
 import { useChat } from '@ai-sdk/react';
-import { DefaultChatTransport, UIMessage } from 'ai';
+import { DefaultChatTransport, UIMessage, type FinishReason } from 'ai';
+import { useState } from 'react';
 
 type MyMessage = UIMessage<
   never,
@@ -16,6 +17,7 @@
 >;
 
 export default function Chat() {
+  const [lastFinishReason, setLastFinishReason] = useState<FinishReason | undefined>(undefined);
   const { error, status, sendMessage, messages, regenerate, stop } =
     useChat<MyMessage>({
       transport: new DefaultChatTransport({
@@ -24,6 +26,9 @@
       onData: dataPart => {
         console.log('dataPart', JSON.stringify(dataPart, null, 2));
       },
+      onFinish: ({ finishReason }) => {
+        setLastFinishReason(finishReason);
+      },
     });
 
   return (
@@ -94,6 +99,8 @@
       )}
+      {messages.length > 0 && <div>Finish reason: {String(lastFinishReason)}</div>}
+
       <ChatInput status={status} onSubmit={text => sendMessage({ text })} />
     </div>
   );
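A note on the example above: it renders the raw value via `String(lastFinishReason)`. If friendlier copy is wanted, a small mapping helper is enough. The sketch below is not part of this diff; `describeFinish` is a hypothetical name, and the union members are taken from the enum this PR adds to `ui-message-chunks.ts`:

```tsx
import type { FinishReason } from 'ai';

// Hypothetical helper, not part of this PR: turn the optional
// finishReason from onFinish into user-facing copy.
function describeFinish(reason: FinishReason | undefined): string {
  switch (reason) {
    case 'stop':
      return 'The model finished its answer.';
    case 'length':
      return 'The answer was truncated at the output-token limit.';
    case 'content-filter':
      return 'The answer was stopped by a content filter.';
    case 'tool-calls':
      return 'The model stopped to call tools.';
    case 'error':
      return 'The stream ended with an error.';
    case 'other':
    case 'unknown':
      return 'The stream ended for an unspecified reason.';
    default:
      // The field is optional: aborted or disconnected streams, and
      // servers that predate this change, never deliver it.
      return 'The stream ended without reporting a finish reason.';
  }
}
```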
diff --git a/packages/ai/src/generate-text/stream-text.ts b/packages/ai/src/generate-text/stream-text.ts
index b0790911d3df..1b2a2c86b362 100644
--- a/packages/ai/src/generate-text/stream-text.ts
+++ b/packages/ai/src/generate-text/stream-text.ts
@@ -2153,6 +2153,7 @@ However, the LLM results are expected to be small enough to not cause issues.
           if (sendFinish) {
             controller.enqueue({
               type: 'finish',
+              finishReason: part.finishReason,
               ...(messageMetadataValue != null
                 ? { messageMetadata: messageMetadataValue }
                 : {}),
diff --git a/packages/ai/src/ui-message-stream/handle-ui-message-stream-finish.ts b/packages/ai/src/ui-message-stream/handle-ui-message-stream-finish.ts
index 3ed973ab0f4c..e05ba84c2406 100644
--- a/packages/ai/src/ui-message-stream/handle-ui-message-stream-finish.ts
+++ b/packages/ai/src/ui-message-stream/handle-ui-message-stream-finish.ts
@@ -106,6 +106,7 @@ export function handleUIMessageStreamFinish<UI_MESSAGE extends UIMessage>({
         ...(isContinuation ? originalMessages.slice(0, -1) : originalMessages),
         state.message,
       ] as UI_MESSAGE[],
+      finishReason: state.finishReason,
     });
   };
diff --git a/packages/ai/src/ui-message-stream/ui-message-chunks.ts b/packages/ai/src/ui-message-stream/ui-message-chunks.ts
index 0cda93fca633..d5bb34d3f221 100644
--- a/packages/ai/src/ui-message-stream/ui-message-chunks.ts
+++ b/packages/ai/src/ui-message-stream/ui-message-chunks.ts
@@ -3,6 +3,7 @@ import {
   ProviderMetadata,
   providerMetadataSchema,
 } from '../types/provider-metadata';
+import { FinishReason } from '../types/language-model';
 import {
   InferUIMessageData,
   InferUIMessageMetadata,
@@ -153,6 +154,17 @@ export const uiMessageChunkSchema = lazySchema(() =>
       }),
       z.strictObject({
         type: z.literal('finish'),
+        finishReason: z
+          .enum([
+            'stop',
+            'length',
+            'content-filter',
+            'tool-calls',
+            'error',
+            'other',
+            'unknown',
+          ] as const satisfies readonly FinishReason[])
+          .optional(),
         messageMetadata: z.unknown().optional(),
       }),
       z.strictObject({
@@ -308,6 +320,7 @@ export type UIMessageChunk<
     }
   | {
       type: 'finish';
+      finishReason?: FinishReason;
       messageMetadata?: METADATA;
     }
   | {
diff --git a/packages/ai/src/ui-message-stream/ui-message-stream-on-finish-callback.ts b/packages/ai/src/ui-message-stream/ui-message-stream-on-finish-callback.ts
index b301d5bb3cdc..c65f95b34be6 100644
--- a/packages/ai/src/ui-message-stream/ui-message-stream-on-finish-callback.ts
+++ b/packages/ai/src/ui-message-stream/ui-message-stream-on-finish-callback.ts
@@ -1,3 +1,4 @@
+import { FinishReason } from '../types/language-model';
 import { UIMessage } from '../ui/ui-messages';
 
 export type UIMessageStreamOnFinishCallback<UI_MESSAGE extends UIMessage> =
@@ -23,4 +24,9 @@
      * (including the original message if it was extended).
      */
     responseMessage: UI_MESSAGE;
+
+    /**
+     * The reason why the generation finished.
+     */
+    finishReason?: FinishReason;
   }) => PromiseLike<void> | void;
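With `finishReason` now flowing from `stream-text` through `handleUIMessageStreamFinish` into `UIMessageStreamOnFinishCallback`, server-side `onFinish` handlers can see how a generation ended. Here is a minimal sketch of a route handler using the extended callback; the route shape, model id, and logging are illustrative assumptions, not part of this diff:

```tsx
import { openai } from '@ai-sdk/openai';
import { convertToModelMessages, streamText, type UIMessage } from 'ai';

export async function POST(req: Request) {
  const { messages }: { messages: UIMessage[] } = await req.json();

  const result = streamText({
    model: openai('gpt-4o'), // illustrative model choice
    messages: convertToModelMessages(messages),
  });

  return result.toUIMessageStreamResponse({
    originalMessages: messages,
    onFinish: ({ responseMessage, finishReason }) => {
      // finishReason is forwarded from the `finish` chunk that
      // stream-text now enqueues; it stays undefined if the stream
      // never emitted one.
      console.log('generation finished:', finishReason, responseMessage.id);
    },
  });
}
```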
diff --git a/packages/ai/src/ui/chat.test.ts b/packages/ai/src/ui/chat.test.ts
index 7b2fd90d1ed0..e0c1a58faf7c 100644
--- a/packages/ai/src/ui/chat.test.ts
+++ b/packages/ai/src/ui/chat.test.ts
@@ -96,7 +96,7 @@ describe('Chat', () => {
         formatChunk({ type: 'text-delta', id: 'text-1', delta: '.' }),
         formatChunk({ type: 'text-end', id: 'text-1' }),
         formatChunk({ type: 'finish-step' }),
-        formatChunk({ type: 'finish' }),
+        formatChunk({ type: 'finish', finishReason: 'stop' }),
       ],
     };
@@ -126,6 +126,7 @@
       expect(letOnFinishArgs).toMatchInlineSnapshot(`
         [
           {
+            "finishReason": "stop",
             "isAbort": false,
             "isDisconnect": false,
             "isError": false,
@@ -484,6 +485,7 @@
       expect(letOnFinishArgs).toMatchInlineSnapshot(`
         [
           {
+            "finishReason": undefined,
             "isAbort": false,
             "isDisconnect": true,
             "isError": true,
@@ -719,6 +721,7 @@
       expect(letOnFinishArgs).toMatchInlineSnapshot(`
         [
           {
+            "finishReason": undefined,
             "isAbort": true,
             "isDisconnect": false,
             "isError": false,
@@ -900,7 +903,7 @@ describe('Chat', () => {
         }),
         formatChunk({ type: 'text-end', id: 'text-1' }),
         formatChunk({ type: 'finish-step' }),
-        formatChunk({ type: 'finish' }),
+        formatChunk({ type: 'finish', finishReason: 'stop' }),
       ],
     };
@@ -1553,7 +1556,7 @@ describe('Chat', () => {
      // finish stream
      controller1.write(formatChunk({ type: 'finish-step' }));
-     controller1.write(formatChunk({ type: 'finish' }));
+     controller1.write(formatChunk({ type: 'finish', finishReason: 'stop' }));
 
      await controller1.close();
@@ -1920,7 +1923,7 @@ describe('Chat', () => {
         }),
         formatChunk({ type: 'text-end', id: 'id-1' }),
         formatChunk({ type: 'finish-step' }),
-        formatChunk({ type: 'finish' }),
+        formatChunk({ type: 'finish', finishReason: 'stop' }),
       ],
     },
   ];
@@ -2423,7 +2426,7 @@ describe('Chat', () => {
         }),
         formatChunk({ type: 'text-end', id: 'txt-1' }),
         formatChunk({ type: 'finish-step' }),
-        formatChunk({ type: 'finish' }),
+        formatChunk({ type: 'finish', finishReason: 'stop' }),
       ],
     },
   ];
diff --git a/packages/ai/src/ui/chat.ts b/packages/ai/src/ui/chat.ts
index c0f5e537eb41..1b505741289d 100644
--- a/packages/ai/src/ui/chat.ts
+++ b/packages/ai/src/ui/chat.ts
@@ -4,6 +4,7 @@ import {
   IdGenerator,
   InferSchema,
 } from '@ai-sdk/provider-utils';
+import { FinishReason } from '../types/language-model';
 import { UIMessageChunk } from '../ui-message-stream/ui-message-chunks';
 import { consumeStream } from '../util/consume-stream';
 import { SerialJobExecutor } from '../util/serial-job-executor';
@@ -122,6 +123,7 @@ export type ChatOnDataCallback<UI_MESSAGE extends UIMessage> = (
  * @param isAbort Indicates whether the request has been aborted.
  * @param isDisconnect Indicates whether the request has been ended by a network error.
  * @param isError Indicates whether the request has been ended by an error.
+ * @param finishReason The reason why the generation finished.
  */
 export type ChatOnFinishCallback<UI_MESSAGE extends UIMessage> = (options: {
   message: UI_MESSAGE;
@@ -129,6 +131,7 @@
   isAbort: boolean;
   isDisconnect: boolean;
   isError: boolean;
+  finishReason?: FinishReason;
 }) => void;
 
 export interface ChatInit<UI_MESSAGE extends UIMessage> {
@@ -687,6 +690,7 @@ export abstract class AbstractChat<UI_MESSAGE extends UIMessage> {
           isAbort,
           isDisconnect,
           isError,
+          finishReason: this.activeResponse?.state.finishReason,
         });
       } catch (err) {
         console.error(err);
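On the client, `ChatOnFinishCallback` gains the same optional field, populated from `this.activeResponse?.state.finishReason`. Below is a sketch of a handler that distinguishes normal completion from abnormal termination, assuming `ChatOnFinishCallback` is re-exported from `ai` like the other UI types; as the snapshots above assert, no `finish` chunk arrives on abort or disconnect, so `finishReason` stays `undefined` there:

```tsx
import type { ChatOnFinishCallback, UIMessage } from 'ai';

const onFinish: ChatOnFinishCallback<UIMessage> = ({
  message,
  isAbort,
  isDisconnect,
  isError,
  finishReason,
}) => {
  if (isAbort || isDisconnect || isError) {
    // No `finish` chunk arrives on abnormal termination, so
    // finishReason is undefined here.
    console.warn('stream ended abnormally for message', message.id);
    return;
  }
  console.log('message', message.id, 'finished with reason:', finishReason);
};
```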
diff --git a/packages/ai/src/ui/process-ui-message-stream.ts b/packages/ai/src/ui/process-ui-message-stream.ts
index d6aa3d29ada8..bde25d2eacd1 100644
--- a/packages/ai/src/ui/process-ui-message-stream.ts
+++ b/packages/ai/src/ui/process-ui-message-stream.ts
@@ -1,5 +1,6 @@
 import { FlexibleSchema, validateTypes } from '@ai-sdk/provider-utils';
 import { ProviderMetadata } from '../types';
+import { FinishReason } from '../types/language-model';
 import {
   DataUIMessageChunk,
   InferUIMessageChunk,
@@ -41,6 +42,7 @@ export type StreamingUIMessageState<UI_MESSAGE extends UIMessage> = {
       title?: string;
     }
   >;
+  finishReason?: FinishReason;
 };
 
 export function createStreamingUIMessageState<UI_MESSAGE extends UIMessage>({
@@ -635,6 +637,9 @@ export function processUIMessageStream<UI_MESSAGE extends UIMessage>({
       }
 
       case 'finish': {
+        if (chunk.finishReason != null) {
+          state.finishReason = chunk.finishReason;
+        }
         await updateMessageMetadata(chunk.messageMetadata);
         if (chunk.messageMetadata != null) {
           write();
diff --git a/packages/react/src/use-chat.ui.test.tsx b/packages/react/src/use-chat.ui.test.tsx
index 310e0f9169d3..9a9214b216d9 100644
--- a/packages/react/src/use-chat.ui.test.tsx
+++ b/packages/react/src/use-chat.ui.test.tsx
@@ -10,6 +10,7 @@
 import { screen, waitFor } from '@testing-library/react';
 import userEvent from '@testing-library/user-event';
 import {
   DefaultChatTransport,
+  FinishReason,
   isToolUIPart,
   TextStreamChatTransport,
   UIMessage,
@@ -84,7 +85,14 @@ describe('initial messages', () => {
 });
 
 describe('data protocol stream', () => {
-  let onFinishCalls: Array<{ message: UIMessage }> = [];
+  let onFinishCalls: Array<{
+    message: UIMessage;
+    messages: UIMessage[];
+    isAbort: boolean;
+    isDisconnect: boolean;
+    isError: boolean;
+    finishReason?: FinishReason;
+  }> = [];
 
   setupTestComponent(
     ({ id: idParam }: { id: string }) => {
@@ -304,6 +312,7 @@
       controller.write(
         formatChunk({
           type: 'finish',
+          finishReason: 'stop',
           messageMetadata: {
             example: 'metadata',
           },
@@ -346,6 +355,7 @@
     expect(onFinishCalls).toMatchInlineSnapshot(`
       [
         {
+          "finishReason": "stop",
          "isAbort": false,
          "isDisconnect": false,
          "isError": false,
@@ -436,7 +446,14 @@
 });
 
 describe('text stream', () => {
-  let onFinishCalls: Array<{ message: UIMessage }> = [];
+  let onFinishCalls: Array<{
+    message: UIMessage;
+    messages: UIMessage[];
+    isAbort: boolean;
+    isDisconnect: boolean;
+    isError: boolean;
+    finishReason?: FinishReason;
+  }> = [];
 
   setupTestComponent(() => {
     const { messages, sendMessage } = useChat({
@@ -537,6 +554,7 @@
     expect(onFinishCalls).toMatchInlineSnapshot(`
       [
         {
+          "finishReason": undefined,
          "isAbort": false,
          "isDisconnect": false,
          "isError": false,
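For completeness, a hand-rolled UI message stream can exercise the new chunk field directly. This sketch mirrors the chunk sequence used by the test fixtures above; the ids, text, and the exported `response` are placeholders:

```tsx
import { createUIMessageStream, createUIMessageStreamResponse } from 'ai';

const stream = createUIMessageStream({
  execute: ({ writer }) => {
    writer.write({ type: 'start' });
    writer.write({ type: 'start-step' });
    writer.write({ type: 'text-start', id: 'text-1' });
    writer.write({ type: 'text-delta', id: 'text-1', delta: 'Hello!' });
    writer.write({ type: 'text-end', id: 'text-1' });
    writer.write({ type: 'finish-step' });
    // The finish chunk now optionally carries finishReason.
    writer.write({ type: 'finish', finishReason: 'stop' });
  },
});

export const response = createUIMessageStreamResponse({ stream });
```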