Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/six-roses-peel.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
'@ai-sdk/react': patch
'ai': patch
---

Added finishReason to the useChat onFinish callback
9 changes: 8 additions & 1 deletion examples/next-openai/app/use-chat-data-ui-parts/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import ChatInput from '@/components/chat-input';
import { useChat } from '@ai-sdk/react';
import { DefaultChatTransport, UIMessage } from 'ai';
import { DefaultChatTransport, UIMessage, type FinishReason } from 'ai';
import { useState } from 'react';

type MyMessage = UIMessage<
never,
Expand All @@ -16,6 +17,7 @@ type MyMessage = UIMessage<
>;

export default function Chat() {
const [lastFinishReason, setLastFinishReason] = useState<FinishReason | undefined>(undefined);
const { error, status, sendMessage, messages, regenerate, stop } =
useChat<MyMessage>({
transport: new DefaultChatTransport({
Expand All @@ -24,6 +26,9 @@ export default function Chat() {
onData: dataPart => {
console.log('dataPart', JSON.stringify(dataPart, null, 2));
},
onFinish: ({ finishReason }) => {
setLastFinishReason(finishReason);
},
});

return (
Expand Down Expand Up @@ -94,6 +99,8 @@ export default function Chat() {
</div>
)}

{messages.length > 0 && <div className="mt-4 text-gray-500">Finish reason: {String(lastFinishReason)}</div>}

<ChatInput status={status} onSubmit={text => sendMessage({ text })} />
</div>
);
Expand Down
1 change: 1 addition & 0 deletions packages/ai/src/generate-text/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2153,6 +2153,7 @@ However, the LLM results are expected to be small enough to not cause issues.
if (sendFinish) {
controller.enqueue({
type: 'finish',
finishReason: part.finishReason,
...(messageMetadataValue != null
? { messageMetadata: messageMetadataValue }
: {}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ export function handleUIMessageStreamFinish<UI_MESSAGE extends UIMessage>({
...(isContinuation ? originalMessages.slice(0, -1) : originalMessages),
state.message,
] as UI_MESSAGE[],
finishReason: state.finishReason,
});
};

Expand Down
13 changes: 13 additions & 0 deletions packages/ai/src/ui-message-stream/ui-message-chunks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
ProviderMetadata,
providerMetadataSchema,
} from '../types/provider-metadata';
import { FinishReason } from '../types/language-model';
import {
InferUIMessageData,
InferUIMessageMetadata,
Expand Down Expand Up @@ -153,6 +154,17 @@ export const uiMessageChunkSchema = lazySchema(() =>
}),
z.strictObject({
type: z.literal('finish'),
finishReason: z
.enum([
'stop',
'length',
'content-filter',
'tool-calls',
'error',
'other',
'unknown',
] as const satisfies readonly FinishReason[])
.optional(),
messageMetadata: z.unknown().optional(),
}),
z.strictObject({
Expand Down Expand Up @@ -308,6 +320,7 @@ export type UIMessageChunk<
}
| {
type: 'finish';
finishReason?: FinishReason;
messageMetadata?: METADATA;
}
| {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { FinishReason } from '../types/language-model';
import { UIMessage } from '../ui/ui-messages';

export type UIMessageStreamOnFinishCallback<UI_MESSAGE extends UIMessage> =
Expand All @@ -23,4 +24,9 @@ export type UIMessageStreamOnFinishCallback<UI_MESSAGE extends UIMessage> =
* (including the original message if it was extended).
*/
responseMessage: UI_MESSAGE;

/**
* The reason why the generation finished.
*/
finishReason?: FinishReason;
}) => PromiseLike<void> | void;
13 changes: 8 additions & 5 deletions packages/ai/src/ui/chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ describe('Chat', () => {
formatChunk({ type: 'text-delta', id: 'text-1', delta: '.' }),
formatChunk({ type: 'text-end', id: 'text-1' }),
formatChunk({ type: 'finish-step' }),
formatChunk({ type: 'finish' }),
formatChunk({ type: 'finish', finishReason: 'stop' }),
],
};

Expand Down Expand Up @@ -126,6 +126,7 @@ describe('Chat', () => {
expect(letOnFinishArgs).toMatchInlineSnapshot(`
[
{
"finishReason": "stop",
"isAbort": false,
"isDisconnect": false,
"isError": false,
Expand Down Expand Up @@ -484,6 +485,7 @@ describe('Chat', () => {
expect(letOnFinishArgs).toMatchInlineSnapshot(`
[
{
"finishReason": undefined,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't looked in detail yet, but is there a case where finishReason would be undefined? We might need to update our fixtures in the tests

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, finishReason can be undefined when there’s no finish event (abort, disconnect, server error). For normal completions it should be set (usually 'stop').

I’ve updated fixtures: normal paths include finishReason: 'stop'; abnormal paths keep it undefined.

"isAbort": false,
"isDisconnect": true,
"isError": true,
Expand Down Expand Up @@ -719,6 +721,7 @@ describe('Chat', () => {
expect(letOnFinishArgs).toMatchInlineSnapshot(`
[
{
"finishReason": undefined,
"isAbort": true,
"isDisconnect": false,
"isError": false,
Expand Down Expand Up @@ -900,7 +903,7 @@ describe('Chat', () => {
}),
formatChunk({ type: 'text-end', id: 'text-1' }),
formatChunk({ type: 'finish-step' }),
formatChunk({ type: 'finish' }),
formatChunk({ type: 'finish', finishReason: 'stop' }),
],
};

Expand Down Expand Up @@ -1553,7 +1556,7 @@ describe('Chat', () => {

// finish stream
controller1.write(formatChunk({ type: 'finish-step' }));
controller1.write(formatChunk({ type: 'finish' }));
controller1.write(formatChunk({ type: 'finish', finishReason: 'stop' }));

await controller1.close();

Expand Down Expand Up @@ -1920,7 +1923,7 @@ describe('Chat', () => {
}),
formatChunk({ type: 'text-end', id: 'id-1' }),
formatChunk({ type: 'finish-step' }),
formatChunk({ type: 'finish' }),
formatChunk({ type: 'finish', finishReason: 'stop' }),
],
},
];
Expand Down Expand Up @@ -2423,7 +2426,7 @@ describe('Chat', () => {
}),
formatChunk({ type: 'text-end', id: 'txt-1' }),
formatChunk({ type: 'finish-step' }),
formatChunk({ type: 'finish' }),
formatChunk({ type: 'finish', finishReason: 'stop' }),
],
},
];
Expand Down
4 changes: 4 additions & 0 deletions packages/ai/src/ui/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
IdGenerator,
InferSchema,
} from '@ai-sdk/provider-utils';
import { FinishReason } from '../types/language-model';
import { UIMessageChunk } from '../ui-message-stream/ui-message-chunks';
import { consumeStream } from '../util/consume-stream';
import { SerialJobExecutor } from '../util/serial-job-executor';
Expand Down Expand Up @@ -123,13 +124,15 @@ export type ChatOnDataCallback<UI_MESSAGE extends UIMessage> = (
* @param isAbort Indicates whether the request has been aborted.
* @param isDisconnect Indicates whether the request has been ended by a network error.
* @param isError Indicates whether the request has been ended by an error.
* @param finishReason The reason why the generation finished.
*/
export type ChatOnFinishCallback<UI_MESSAGE extends UIMessage> = (options: {
message: UI_MESSAGE;
messages: UI_MESSAGE[];
isAbort: boolean;
isDisconnect: boolean;
isError: boolean;
finishReason?: FinishReason;
}) => void;

export interface ChatInit<UI_MESSAGE extends UIMessage> {
Expand Down Expand Up @@ -685,6 +688,7 @@ export abstract class AbstractChat<UI_MESSAGE extends UIMessage> {
isAbort,
isDisconnect,
isError,
finishReason: this.activeResponse?.state.finishReason,
});
} catch (err) {
console.error(err);
Expand Down
5 changes: 5 additions & 0 deletions packages/ai/src/ui/process-ui-message-stream.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { FlexibleSchema, validateTypes } from '@ai-sdk/provider-utils';
import { ProviderMetadata } from '../types';
import { FinishReason } from '../types/language-model';
import {
DataUIMessageChunk,
InferUIMessageChunk,
Expand Down Expand Up @@ -41,6 +42,7 @@ export type StreamingUIMessageState<UI_MESSAGE extends UIMessage> = {
title?: string;
}
>;
finishReason?: FinishReason;
};

export function createStreamingUIMessageState<UI_MESSAGE extends UIMessage>({
Expand Down Expand Up @@ -635,6 +637,9 @@ export function processUIMessageStream<UI_MESSAGE extends UIMessage>({
}

case 'finish': {
if (chunk.finishReason != null) {
state.finishReason = chunk.finishReason;
}
await updateMessageMetadata(chunk.messageMetadata);
if (chunk.messageMetadata != null) {
write();
Expand Down
22 changes: 20 additions & 2 deletions packages/react/src/use-chat.ui.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { screen, waitFor } from '@testing-library/react';
import userEvent from '@testing-library/user-event';
import {
DefaultChatTransport,
FinishReason,
isToolUIPart,
TextStreamChatTransport,
UIMessage,
Expand Down Expand Up @@ -84,7 +85,14 @@ describe('initial messages', () => {
});

describe('data protocol stream', () => {
let onFinishCalls: Array<{ message: UIMessage }> = [];
let onFinishCalls: Array<{
message: UIMessage;
messages: UIMessage[];
isAbort: boolean;
isDisconnect: boolean;
isError: boolean;
finishReason?: FinishReason;
}> = [];

setupTestComponent(
({ id: idParam }: { id: string }) => {
Expand Down Expand Up @@ -304,6 +312,7 @@ describe('data protocol stream', () => {
controller.write(
formatChunk({
type: 'finish',
finishReason: 'stop',
messageMetadata: {
example: 'metadata',
},
Expand Down Expand Up @@ -346,6 +355,7 @@ describe('data protocol stream', () => {
expect(onFinishCalls).toMatchInlineSnapshot(`
[
{
"finishReason": "stop",
"isAbort": false,
"isDisconnect": false,
"isError": false,
Expand Down Expand Up @@ -436,7 +446,14 @@ describe('data protocol stream', () => {
});

describe('text stream', () => {
let onFinishCalls: Array<{ message: UIMessage }> = [];
let onFinishCalls: Array<{
message: UIMessage;
messages: UIMessage[];
isAbort: boolean;
isDisconnect: boolean;
isError: boolean;
finishReason?: FinishReason;
}> = [];

setupTestComponent(() => {
const { messages, sendMessage } = useChat({
Expand Down Expand Up @@ -537,6 +554,7 @@ describe('text stream', () => {
expect(onFinishCalls).toMatchInlineSnapshot(`
[
{
"finishReason": undefined,
"isAbort": false,
"isDisconnect": false,
"isError": false,
Expand Down