Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
360ef92
first pass at langgraph streaming
brianstrauch May 1, 2026
0b3f2bd
Trim obvious comments in langgraph activity/config
brianstrauch May 1, 2026
0d1a536
Remove unrelated runtime tests
brianstrauch May 1, 2026
59be7fa
Tidy langgraph streaming tests
brianstrauch May 1, 2026
a4c5b9b
don't store workflowstream
brianstrauch May 1, 2026
2eb07f3
remove timeout
brianstrauch May 1, 2026
9ce7978
fix lint
brianstrauch May 1, 2026
642b4b9
Make langgraph streaming opt-in via streaming_topic
brianstrauch May 1, 2026
04ea665
add streaming support disclaimer
brianstrauch May 1, 2026
5e7881a
mention streaming in readme
brianstrauch May 4, 2026
32c2c2e
Validate WorkflowStream registration when streaming_topic is set
brianstrauch May 5, 2026
a62677a
Stream from workflow-side LangGraph nodes via in-workflow WorkflowStream
brianstrauch May 5, 2026
5802d88
Document streaming feature in README and plugin docstring
brianstrauch May 5, 2026
627ee3f
Drop compose-mechanisms paragraph from streaming README
brianstrauch May 5, 2026
bda1357
Support sync nodes for streaming and execute_in='workflow'
brianstrauch May 5, 2026
32818b1
Fix astream-publish test race with subscriber ack
brianstrauch May 6, 2026
42a226d
Add CODEOWNERS entries for langgraph contrib
brianstrauch May 6, 2026
86bf1bb
Drop blank line after wrap_activity docstring (D202)
brianstrauch May 6, 2026
f4f1d37
Skip workflow-side streaming tests on Python 3.10
brianstrauch May 6, 2026
7155776
Move 3.10 skip onto the parametrize value
brianstrauch May 6, 2026
88b873e
Fix streaming-ws test race with subscriber ack
brianstrauch May 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion temporalio/contrib/langgraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

__all__ = [
"LangGraphPlugin",
"entrypoint",
"cache",
"entrypoint",
"graph",
]
Comment thread
brianstrauch marked this conversation as resolved.
44 changes: 30 additions & 14 deletions temporalio/contrib/langgraph/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Awaitable
from dataclasses import dataclass
from datetime import timedelta
from inspect import iscoroutinefunction, signature
from typing import Any, Callable

Expand All @@ -19,6 +20,7 @@
cache_lookup,
cache_put,
)
from temporalio.contrib.workflow_streams import WorkflowStreamClient

# Per-run dedupe so we only warn once when a user passes a Store via
# graph.compile(store=...) / @entrypoint(store=...). Cleared by
Expand Down Expand Up @@ -51,6 +53,9 @@ class ActivityOutput:

def wrap_activity(
func: Callable,
*,
streaming_topic: str | None = None,
Comment thread
brianstrauch marked this conversation as resolved.
streaming_batch_interval: timedelta = timedelta(milliseconds=100),
) -> Callable[[ActivityInput], Awaitable[ActivityOutput]]:
"""Wrap a function as a Temporal activity that handles LangGraph config and interrupts."""
# Graph nodes declare `runtime: Runtime[Ctx]` in their signature; tasks
Expand All @@ -59,20 +64,31 @@ def wrap_activity(
accepts_runtime = "runtime" in signature(func).parameters

async def wrapper(input: ActivityInput) -> ActivityOutput:
runtime = set_langgraph_config(input.langgraph_config)
kwargs = dict(input.kwargs)
if accepts_runtime:
kwargs["runtime"] = runtime
try:
if iscoroutinefunction(func):
result = await func(*input.args, **kwargs)
else:
result = func(*input.args, **kwargs)
if isinstance(result, Command):
return ActivityOutput(langgraph_command=result)
return ActivityOutput(result=result)
except GraphInterrupt as e:
return ActivityOutput(langgraph_interrupts=e.args[0])
async def run(stream_writer: Callable[[Any], None] | None) -> ActivityOutput:
runtime = set_langgraph_config(
input.langgraph_config, stream_writer=stream_writer
)
kwargs = dict(input.kwargs)
if accepts_runtime:
kwargs["runtime"] = runtime
Comment thread
brianstrauch marked this conversation as resolved.
try:
if iscoroutinefunction(func):
result = await func(*input.args, **kwargs)
else:
result = func(*input.args, **kwargs)
if isinstance(result, Command):
return ActivityOutput(langgraph_command=result)
return ActivityOutput(result=result)
except GraphInterrupt as e:
return ActivityOutput(langgraph_interrupts=e.args[0])

if streaming_topic is None:
return await run(stream_writer=None)
async with WorkflowStreamClient.from_within_activity(
Comment thread
brianstrauch marked this conversation as resolved.
batch_interval=streaming_batch_interval,
) as client:
topic = client.topic(streaming_topic)
return await run(stream_writer=topic.publish)
Comment thread
brianstrauch marked this conversation as resolved.

return wrapper

Expand Down
10 changes: 7 additions & 3 deletions temporalio/contrib/langgraph/_langgraph_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# pyright: reportMissingTypeStubs=false

import dataclasses
from typing import Any
from typing import Any, Callable

from langchain_core.runnables.config import var_child_runnable_config
from langgraph._internal._constants import (
Expand Down Expand Up @@ -93,7 +93,11 @@ def get_langgraph_config() -> dict[str, Any]:
}


def set_langgraph_config(config: dict[str, Any]) -> Runtime:
def set_langgraph_config(
config: dict[str, Any],
*,
stream_writer: Callable[[Any], None] | None = None,
) -> Runtime:
"""Restore a LangGraph runnable config from a serialized dict.

Returns the reconstructed Runtime so callers can re-inject it into the
Expand All @@ -112,7 +116,7 @@ def get_null_resume(consume: bool = False) -> Any:
execution_info_dict = config.get("execution_info")
runtime = Runtime(
context=config.get("context"),
stream_writer=lambda _: None,
stream_writer=stream_writer or (lambda _: None),
previous=config.get("previous"),
execution_info=(
ExecutionInfo(**execution_info_dict) if execution_info_dict else None
Expand Down
19 changes: 17 additions & 2 deletions temporalio/contrib/langgraph/_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
import warnings
from dataclasses import replace
from datetime import timedelta
from typing import Any, Callable

from langgraph._internal._runnable import RunnableCallable
Expand Down Expand Up @@ -58,8 +59,15 @@ def __init__(
# TODO: Remove activity_options when we have support for @task(metadata=...)
activity_options: dict[str, dict[str, Any]] | None = None,
default_activity_options: dict[str, Any] | None = None,
streaming_topic: str | None = None,
Comment thread
brianstrauch marked this conversation as resolved.
streaming_batch_interval: timedelta = timedelta(milliseconds=100),
):
"""Initialize the LangGraph plugin with graphs, entrypoints, and tasks."""
"""Initialize the LangGraph plugin with graphs, entrypoints, and tasks.

.. warning::
Streaming support is experimental and may change in
future versions.
"""
Comment thread
brianstrauch marked this conversation as resolved.
if sys.version_info < (3, 11):
warnings.warn( # type: ignore[reportUnreachable]
"LangGraphPlugin requires Python >= 3.11 for full async support. "
Expand All @@ -79,6 +87,8 @@ def __init__(
)

self.activities: list = []
self._streaming_topic = streaming_topic
self._streaming_batch_interval = streaming_batch_interval

# Graph API: Wrap graph nodes as Temporal Activities.
if graphs:
Expand Down Expand Up @@ -197,7 +207,12 @@ def execute(
execute_in = opts.pop("execute_in")

if execute_in == "activity":
a = activity.defn(name=activity_name)(wrap_activity(func))
wrapped = wrap_activity(
func,
streaming_topic=self._streaming_topic,
streaming_batch_interval=self._streaming_batch_interval,
)
a = activity.defn(name=activity_name)(wrapped)
self.activities.append(a)
return wrap_execute_activity(a, task_id=task_id(func), **opts)
elif execute_in == "workflow":
Expand Down
135 changes: 117 additions & 18 deletions tests/contrib/langgraph/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,67 +2,166 @@
from typing import Any
from uuid import uuid4

from langgraph.config import (
get_stream_writer, # pyright: ignore[reportMissingTypeStubs]
)
from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs]
from typing_extensions import TypedDict

from temporalio import workflow
from temporalio.client import Client
from temporalio.contrib.langgraph import LangGraphPlugin, graph
from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient
from temporalio.worker import Worker


class State(TypedDict):
    """Shared LangGraph state passed between graph nodes."""

    value: str  # accumulated output string; nodes append tokens to it

async def node_a(state: State) -> dict[str, str]:
return {"value": state["value"] + "a"}
async def token_node(state: State) -> dict[str, str]:
    """Emit one stream event per token plus a final ``done`` marker.

    Returns the state update with all tokens appended to ``value``.
    """
    writer = get_stream_writer()
    emitted: list[str] = []
    for piece in ("a", "b", "c"):
        writer({"token": piece})
        emitted.append(piece)
    writer({"done": True})
    return {"value": state["value"] + "".join(emitted)}


async def node_b(state: State) -> dict[str, str]:
return {"value": state["value"] + "b"}
@workflow.defn
class StreamingWorkflowStreamsWorkflow:
    """Workflow that runs the ``streaming-ws`` graph with a WorkflowStream attached."""

    def __init__(self) -> None:
        # Constructed for its registration side effect — the plugin validates
        # that a WorkflowStream exists when streaming_topic is configured.
        WorkflowStream()
        self.app = graph("streaming-ws").compile()

    @workflow.run
    async def run(self, input: str) -> str:
        final_state = await self.app.ainvoke({"value": input})
        return final_state["value"]


async def test_streaming_via_workflow_streams(client: Client):
    """Tokens written from an activity node arrive on the ``tokens`` stream topic."""
    builder = StateGraph(State)
    builder.add_node("token_node", token_node, metadata={"execute_in": "activity"})
    builder.add_edge(START, "token_node")

    task_queue = f"streaming-ws-{uuid4()}"

    async with Worker(
        client,
        task_queue=task_queue,
        workflows=[StreamingWorkflowStreamsWorkflow],
        plugins=[
            LangGraphPlugin(
                graphs={"streaming-ws": builder},
                default_activity_options={
                    "start_to_close_timeout": timedelta(seconds=10)
                },
                streaming_topic="tokens",
            )
        ],
    ):
        handle = await client.start_workflow(
            StreamingWorkflowStreamsWorkflow.run,
            "",
            id=f"test-streaming-ws-{uuid4()}",
            task_queue=task_queue,
        )

        # Subscribe from the beginning so no early tokens are missed; the
        # ``done`` marker tells us the node has finished publishing.
        ws_client = WorkflowStreamClient.create(client, handle.id)
        received: list[dict[str, Any]] = []
        async for item in ws_client.topic("tokens", type=dict).subscribe(
            from_offset=0,
            poll_cooldown=timedelta(0),
        ):
            received.append(item.data)
            if received[-1].get("done"):
                break

        result = await handle.result()

    assert result == "abc"
    assert received == [
        {"token": "a"},
        {"token": "b"},
        {"token": "c"},
        {"done": True},
    ]


# ---------------------------------------------------------------------------
# Workflow-side publish: iterate astream() in the workflow and forward each
# chunk via self.stream.topic("astream").publish(...) so external subscribers
# see node-level progress alongside any activity-emitted tokens.
# ---------------------------------------------------------------------------


@workflow.defn
class AstreamPublishWorkflow:
    """Workflow that forwards each ``astream()`` chunk to a stream topic."""

    def __init__(self) -> None:
        self.stream = WorkflowStream()
        self.app = graph("astream-publish").compile()

    @workflow.run
    async def run(self, input: str) -> str:
        topic = self.stream.topic("astream")
        # Publish node-level progress as it happens, then a terminal marker
        # so subscribers know when to stop reading.
        async for chunk in self.app.astream({"value": input}):
            topic.publish(chunk)
        topic.publish({"done": True})
        return "done"


async def node_a(state: State) -> dict[str, str]:
    """Append ``"a"`` to the running value."""
    updated = state["value"] + "a"
    return {"value": updated}


async def node_b(state: State) -> dict[str, str]:
    """Append ``"b"`` to the running value."""
    updated = state["value"] + "b"
    return {"value": updated}


async def test_streaming(client: Client):
async def test_workflow_publishes_astream_chunks(client: Client):
g = StateGraph(State)
g.add_node("node_a", node_a, metadata={"execute_in": "activity"})
g.add_node("node_b", node_b, metadata={"execute_in": "activity"})
g.add_edge(START, "node_a")
g.add_edge("node_a", "node_b")

task_queue = f"streaming-{uuid4()}"
task_queue = f"astream-publish-{uuid4()}"

async with Worker(
client,
task_queue=task_queue,
workflows=[StreamingWorkflow],
workflows=[AstreamPublishWorkflow],
plugins=[
LangGraphPlugin(
graphs={"streaming": g},
graphs={"astream-publish": g},
default_activity_options={
"start_to_close_timeout": timedelta(seconds=10)
},
)
],
):
chunks = await client.execute_workflow(
StreamingWorkflow.run,
handle = await client.start_workflow(
AstreamPublishWorkflow.run,
"",
id=f"test-streaming-{uuid4()}",
id=f"test-astream-publish-{uuid4()}",
task_queue=task_queue,
)

assert chunks == [{"node_a": {"value": "a"}}, {"node_b": {"value": "ab"}}]
ws_client = WorkflowStreamClient.create(client, handle.id)
chunks: list[dict[str, Any]] = []
async for item in ws_client.topic("astream", type=dict).subscribe(
from_offset=0,
poll_cooldown=timedelta(0),
Comment thread
brianstrauch marked this conversation as resolved.
Outdated
):
chunks.append(item.data)
if chunks[-1].get("done"):
break

await handle.result()

assert chunks == [
{"node_a": {"value": "a"}},
{"node_b": {"value": "ab"}},
{"done": True},
]
Loading