Skip to content

Commit

Permalink
astream_events: Add version parameter while method is in beta (langch…
Browse files Browse the repository at this point in the history
…ain-ai#16290)

Add a version parameter while the method is in beta phase.

The idea is to make it possible to minimize making breaking changes for users while we're iterating on schema.

Once the API is stable we can assign a default version requirement.
  • Loading branch information
eyurtsev authored Jan 19, 2024
1 parent 91230ef commit 4ef0ed4
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 17 deletions.
13 changes: 12 additions & 1 deletion libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,7 @@ async def astream_events(
input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1"],
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
Expand Down Expand Up @@ -793,7 +794,9 @@ async def reverse(s: str) -> str:
chain = RunnableLambda(func=reverse)
events = [event async for event in chain.astream_events("hello")]
events = [
event async for event in chain.astream_events("hello", version="v1")
]
# will produce the following events (run_id has been omitted for brevity):
[
Expand Down Expand Up @@ -823,6 +826,9 @@ async def reverse(s: str) -> str:
Args:
input: The input to the runnable.
config: The config to use for the runnable.
version: The version of the schema to use.
Currently only version 1 is available.
No default will be assigned until the API is stabilized.
include_names: Only include events from runnables with matching names.
include_types: Only include events from runnables with matching types.
include_tags: Only include events from runnables with matching tags.
Expand All @@ -836,6 +842,11 @@ async def reverse(s: str) -> str:
Returns:
An async stream of StreamEvents.
""" # noqa: E501
if version != "v1":
raise NotImplementedError(
'Only version "v1" of the schema is currently supported.'
)

from langchain_core.runnables.utils import (
_RootEventFilter,
)
Expand Down
48 changes: 32 additions & 16 deletions libs/core/tests/unit_tests/runnables/test_runnable_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def reverse(s: str) -> str:

chain = RunnableLambda(func=reverse)

events = await _collect_events(chain.astream_events("hello"))
events = await _collect_events(chain.astream_events("hello", version="v1"))
assert events == [
{
"data": {"input": "hello"},
Expand Down Expand Up @@ -94,7 +94,7 @@ def reverse(s: str) -> str:
| r.with_config({"run_name": "2"})
| r.with_config({"run_name": "3"})
)
events = await _collect_events(chain.astream_events("hello"))
events = await _collect_events(chain.astream_events("hello", version="v1"))
assert events == [
{
"data": {"input": "hello"},
Expand Down Expand Up @@ -209,7 +209,9 @@ def reverse(s: str) -> str:
| r.with_config({"run_name": "2", "tags": ["my_tag"]})
| r.with_config({"run_name": "3", "tags": ["my_tag"]})
)
events = await _collect_events(chain.astream_events("hello", include_names=["1"]))
events = await _collect_events(
chain.astream_events("hello", include_names=["1"], version="v1")
)
assert events == [
{
"data": {},
Expand Down Expand Up @@ -238,7 +240,9 @@ def reverse(s: str) -> str:
]

events = await _collect_events(
chain.astream_events("hello", include_tags=["my_tag"], exclude_names=["2"])
chain.astream_events(
"hello", include_tags=["my_tag"], exclude_names=["2"], version="v1"
)
)
assert events == [
{
Expand Down Expand Up @@ -272,7 +276,9 @@ async def test_event_stream_with_lambdas_from_lambda() -> None:
as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config(
{"run_name": "my_lambda"}
)
events = await _collect_events(as_lambdas.astream_events({"question": "hello"}))
events = await _collect_events(
as_lambdas.astream_events({"question": "hello"}, version="v1")
)
assert events == [
{
"data": {"input": {"question": "hello"}},
Expand Down Expand Up @@ -331,7 +337,9 @@ async def test_event_stream_with_simple_chain() -> None:
}
)

events = await _collect_events(chain.astream_events({"question": "hello"}))
events = await _collect_events(
chain.astream_events({"question": "hello"}, version="v1")
)
assert events == [
{
"data": {"input": {"question": "hello"}},
Expand Down Expand Up @@ -497,7 +505,7 @@ def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:

# type ignores below because the tools don't appear to be runnables to type checkers
# we can remove as soon as that's fixed
events = await _collect_events(parameterless.astream_events({})) # type: ignore
events = await _collect_events(parameterless.astream_events({}, version="v1")) # type: ignore
assert events == [
{
"data": {"input": {}},
Expand Down Expand Up @@ -525,7 +533,7 @@ def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
},
]

events = await _collect_events(with_callbacks.astream_events({})) # type: ignore
events = await _collect_events(with_callbacks.astream_events({}, version="v1")) # type: ignore
assert events == [
{
"data": {"input": {}},
Expand All @@ -552,7 +560,9 @@ def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
"tags": [],
},
]
events = await _collect_events(with_parameters.astream_events({"x": 1, "y": "2"})) # type: ignore
events = await _collect_events(
with_parameters.astream_events({"x": 1, "y": "2"}, version="v1") # type: ignore
)
assert events == [
{
"data": {"input": {"x": 1, "y": "2"}},
Expand Down Expand Up @@ -581,7 +591,7 @@ def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict:
]

events = await _collect_events(
with_parameters_and_callbacks.astream_events({"x": 1, "y": "2"}) # type: ignore
with_parameters_and_callbacks.astream_events({"x": 1, "y": "2"}, version="v1") # type: ignore
)
assert events == [
{
Expand Down Expand Up @@ -634,7 +644,9 @@ async def test_event_stream_with_retriever() -> None:
),
]
)
events = await _collect_events(retriever.astream_events({"query": "hello"}))
events = await _collect_events(
retriever.astream_events({"query": "hello"}, version="v1")
)
assert events == [
{
"data": {
Expand Down Expand Up @@ -695,7 +707,7 @@ def format_docs(docs: List[Document]) -> str:
return ", ".join([doc.page_content for doc in docs])

chain = retriever | format_docs
events = await _collect_events(chain.astream_events("hello"))
events = await _collect_events(chain.astream_events("hello", version="v1"))
assert events == [
{
"data": {"input": "hello"},
Expand Down Expand Up @@ -796,7 +808,9 @@ def reverse(s: str) -> str:
# does not appear to be a runnable
chain = concat | reverse # type: ignore

events = await _collect_events(chain.astream_events({"a": "hello", "b": "world"}))
events = await _collect_events(
chain.astream_events({"a": "hello", "b": "world"}, version="v1")
)
assert events == [
{
"data": {"input": {"a": "hello", "b": "world"}},
Expand Down Expand Up @@ -878,7 +892,7 @@ def fail(inputs: str) -> None:
chain = RunnableLambda(success) | RunnableLambda(fail).with_retry(
stop_after_attempt=1,
)
iterable = chain.astream_events("q")
iterable = chain.astream_events("q", version="v1")

events = []

Expand Down Expand Up @@ -953,7 +967,9 @@ async def test_with_llm() -> None:
llm = FakeStreamingListLLM(responses=["abc"])

chain = prompt | llm
events = await _collect_events(chain.astream_events({"question": "hello"}))
events = await _collect_events(
chain.astream_events({"question": "hello"}, version="v1")
)
assert events == [
{
"data": {"input": {"question": "hello"}},
Expand Down Expand Up @@ -1061,5 +1077,5 @@ async def add_one(x: int) -> int:
assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4]

with pytest.raises(NotImplementedError):
async for _ in add_one_map.astream_events([1, 2, 3]):
async for _ in add_one_map.astream_events([1, 2, 3], version="v1"):
pass

0 comments on commit 4ef0ed4

Please sign in to comment.