Skip to content

Commit a5d0f2b

Browse files
authored
Ensure tool call parts are emitted immediately when streaming (#5634)
1 parent b448792 commit a5d0f2b

File tree

3 files changed

+108
-8
lines changed

3 files changed

+108
-8
lines changed

.changeset/social-walls-share.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@effect/ai": patch
3+
---
4+
5+
Ensure that tool calls are emitted as soon as possible when streaming

packages/ai/ai/src/LanguageModel.ts

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,12 @@
5151
import * as Chunk from "effect/Chunk"
5252
import * as Context from "effect/Context"
5353
import * as Effect from "effect/Effect"
54+
import * as Mailbox from "effect/Mailbox"
5455
import * as Option from "effect/Option"
5556
import * as ParseResult from "effect/ParseResult"
5657
import * as Predicate from "effect/Predicate"
5758
import * as Schema from "effect/Schema"
59+
import type * as Scope from "effect/Scope"
5860
import * as Stream from "effect/Stream"
5961
import type { Span } from "effect/Tracer"
6062
import type { Concurrency, Mutable, NoExcessProperties } from "effect/Types"
@@ -797,7 +799,9 @@ export const make: (params: ConstructorParams) => Effect.Effect<Service> = Effec
797799
>(options: Options & GenerateTextOptions<Tools>, providerOptions: Mutable<ProviderOptions>) => Effect.Effect<
798800
Stream.Stream<Response.StreamPart<Tools>, AiError.AiError | ParseResult.ParseError, IdGenerator>,
799801
Options extends { readonly toolkit: Effect.Effect<Toolkit.WithHandler<Tools>, infer _E, infer _R> } ? _E : never,
800-
Options extends { readonly toolkit: Effect.Effect<Toolkit.WithHandler<Tools>, infer _E, infer _R> } ? _R : never
802+
| (Options extends { readonly toolkit: Effect.Effect<Toolkit.WithHandler<Tools>, infer _E, infer _R> } ? _R
803+
: never)
804+
| Scope.Scope
801805
> = Effect.fnUntraced(
802806
function*<
803807
Tools extends Record<string, Tool.Any>,
@@ -842,16 +846,21 @@ export const make: (params: ConstructorParams) => Effect.Effect<Service> = Effec
842846
) as Stream.Stream<Response.StreamPart<Tools>, AiError.AiError | ParseResult.ParseError, IdGenerator>
843847
}
844848

845-
const ResponseSchema = Schema.Chunk(Response.StreamPart(toolkit))
849+
const mailbox = yield* Mailbox.make<Response.StreamPart<Tools>, AiError.AiError | ParseResult.ParseError>()
850+
const ResponseSchema = Schema.Array(Response.StreamPart(toolkit))
846851
const decode = Schema.decode(ResponseSchema)
847-
return params.streamText(providerOptions).pipe(
848-
Stream.mapChunksEffect(Effect.fnUntraced(function*(chunk) {
852+
yield* params.streamText(providerOptions).pipe(
853+
Stream.runForEachChunk(Effect.fnUntraced(function*(chunk) {
849854
const rawContent = Chunk.toArray(chunk)
850-
const toolResults = yield* resolveToolCalls(rawContent, toolkit, options.concurrency)
851855
const content = yield* decode(rawContent)
852-
return Chunk.unsafeFromArray([...content, ...toolResults])
853-
}))
854-
) as Stream.Stream<Response.StreamPart<Tools>, AiError.AiError | ParseResult.ParseError, IdGenerator>
856+
yield* mailbox.offerAll(content)
857+
const toolResults = yield* resolveToolCalls(rawContent, toolkit, options.concurrency)
858+
yield* mailbox.offerAll(toolResults as any)
859+
})),
860+
Mailbox.into(mailbox),
861+
Effect.forkScoped
862+
)
863+
return Mailbox.toStream(mailbox)
855864
}
856865
)
857866

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import * as LanguageModel from "@effect/ai/LanguageModel"
2+
import * as Response from "@effect/ai/Response"
3+
import * as Tool from "@effect/ai/Tool"
4+
import * as Toolkit from "@effect/ai/Toolkit"
5+
import { assert, describe, it } from "@effect/vitest"
6+
import * as Effect from "effect/Effect"
7+
import * as Schema from "effect/Schema"
8+
import * as Stream from "effect/Stream"
9+
import * as TestClock from "effect/TestClock"
10+
import * as TestUtils from "./utilities.js"
11+
12+
const MyTool = Tool.make("MyTool", {
13+
parameters: { testParam: Schema.String },
14+
success: Schema.Struct({ testSuccess: Schema.String })
15+
})
16+
17+
const MyToolkit = Toolkit.make(MyTool)
18+
19+
const MyToolkitLayer = MyToolkit.toLayer({
20+
MyTool: () =>
21+
Effect.succeed({ testSuccess: "test-success" }).pipe(
22+
Effect.delay("10 seconds")
23+
)
24+
})
25+
26+
describe("LanguageModel", () => {
27+
describe("streamText", () => {
28+
it.effect("should emit tool calls before executing tool handlers", () =>
29+
Effect.gen(function*() {
30+
const parts: Array<Response.StreamPart<Toolkit.Tools<typeof MyToolkit>>> = []
31+
const latch = yield* Effect.makeLatch()
32+
33+
const toolCallId = "tool-abc123"
34+
const toolName = "MyTool"
35+
const toolParams = { testParam: "test-param" }
36+
const toolResult = { testSuccess: "test-success" }
37+
38+
yield* LanguageModel.streamText({
39+
prompt: [],
40+
toolkit: MyToolkit
41+
}).pipe(
42+
Stream.runForEach((part) =>
43+
Effect.andThen(latch.open, () => {
44+
parts.push(part)
45+
})
46+
),
47+
TestUtils.withLanguageModel({
48+
streamText: [
49+
{
50+
type: "tool-call",
51+
id: toolCallId,
52+
name: toolName,
53+
params: toolParams
54+
}
55+
]
56+
}),
57+
Effect.provide(MyToolkitLayer),
58+
Effect.fork
59+
)
60+
61+
yield* latch.await
62+
63+
const toolCallPart = Response.makePart("tool-call", {
64+
id: toolCallId,
65+
name: toolName,
66+
params: toolParams,
67+
providerExecuted: false
68+
})
69+
70+
const toolResultPart = Response.toolResultPart({
71+
id: toolCallId,
72+
name: toolName,
73+
result: toolResult,
74+
encodedResult: toolResult,
75+
isFailure: false,
76+
providerExecuted: false
77+
})
78+
79+
assert.deepStrictEqual(parts, [toolCallPart])
80+
81+
yield* TestClock.adjust("10 seconds")
82+
83+
assert.deepStrictEqual(parts, [toolCallPart, toolResultPart])
84+
}))
85+
})
86+
})

0 commit comments

Comments
 (0)