41 changes: 29 additions & 12 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -531,9 +531,13 @@ def stream(
index = index + 1
if "index" not in block:
block["index"] = index
run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
)

if isinstance(chunk.message.content, list):
run_manager.on_llm_new_token(
json.dumps(chunk.message.content), chunk=chunk
)
Comment on lines +536 to +538 — Collaborator:
do we want to json.dumps the content or just call chunk.message.text as the original issue proposed?

my bias is to do chunk.message.text, which will represent tokens of text content. it will miss tokens from reasoning, tool calls, and other content, but I don't think on_llm_new_token was designed to support these well. so my thought is to go with well-defined behavior for now.

else:
run_manager.on_llm_new_token(chunk.message.content, chunk=chunk)
chunks.append(chunk)
yield cast("AIMessageChunk", chunk.message)
yielded = True
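
Regarding the review comment above — whether to `json.dumps` the full content or fall back to `chunk.message.text` — here is a minimal illustrative sketch (not part of the PR) contrasting what each option would hand to `on_llm_new_token` for a list-content chunk. The hand-rolled text extraction below only stands in for what `chunk.message.text` represents; the exact `.text` behavior (property vs. method) varies across `langchain-core` versions, so treat it as an assumption.

```python
import json

from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk

# A streamed chunk whose content is a list of blocks (text + tool use),
# as emitted by providers that stream structured content.
chunk = ChatGenerationChunk(
    message=AIMessageChunk(
        content=[
            {"type": "text", "text": "Hello"},
            {"type": "tool_use", "id": "1", "name": "test_tool"},
        ]
    )
)

# Option taken in this PR: serialize the whole block list. Nothing is
# dropped, but the "token" handed to on_llm_new_token is a JSON string.
as_json = json.dumps(chunk.message.content)
# -> '[{"type": "text", "text": "Hello"}, {"type": "tool_use", ...}]'

# Option raised in the review: pass only the text portion (what
# chunk.message.text represents). Reasoning / tool-call blocks are lost,
# but callbacks keep receiving plain text. This manual extraction is a
# stand-in for illustration only.
text_only = "".join(
    block["text"]
    for block in chunk.message.content
    if isinstance(block, dict) and block.get("type") == "text"
)
# -> 'Hello'
```

Either way the callback parameter stays a `str`, which is what the tests further down assert for every received token.
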
@@ -663,9 +667,14 @@ async def astream(
index = index + 1
if "index" not in block:
block["index"] = index
await run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
)
if isinstance(chunk.message.content, list):
await run_manager.on_llm_new_token(
json.dumps(chunk.message.content), chunk=chunk
)
else:
await run_manager.on_llm_new_token(
chunk.message.content, chunk=chunk
)
chunks.append(chunk)
yield cast("AIMessageChunk", chunk.message)
yielded = True
@@ -1164,9 +1173,12 @@ def _generate_with_cache(
if run_manager:
if chunk.message.id is None:
chunk.message.id = run_id
run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
)
if isinstance(chunk.message.content, list):
run_manager.on_llm_new_token(
json.dumps(chunk.message.content), chunk=chunk
)
else:
run_manager.on_llm_new_token(chunk.message.content, chunk=chunk)
chunks.append(chunk)
yielded = True

@@ -1282,9 +1294,14 @@ async def _agenerate_with_cache(
if run_manager:
if chunk.message.id is None:
chunk.message.id = run_id
await run_manager.on_llm_new_token(
cast("str", chunk.message.content), chunk=chunk
)
if isinstance(chunk.message.content, list):
await run_manager.on_llm_new_token(
json.dumps(chunk.message.content), chunk=chunk
)
else:
await run_manager.on_llm_new_token(
chunk.message.content, chunk=chunk
)
chunks.append(chunk)
yielded = True

@@ -1,15 +1,18 @@
"""Test base chat model."""

import json
import uuid
import warnings
from collections.abc import AsyncIterator, Iterator
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, cast

import pytest
from typing_extensions import override

from langchain_core.callbacks import (
AsyncCallbackHandler,
AsyncCallbackManagerForLLMRun,
BaseCallbackHandler,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import (
@@ -178,6 +181,216 @@ def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
eval_response(cb_sync, i)


def test_on_llm_new_token_with_string_content() -> None:
"""Test that on_llm_new_token receives string when content is string."""

class TokenCollector(BaseCallbackHandler):
"""Callback handler that collects tokens."""

def __init__(self) -> None:
self.tokens: list[str] = []
self.token_types: list[type] = []

def on_llm_new_token(
self,
token: str,
*,
_chunk: ChatGenerationChunk | None = None,
**_kwargs: Any,
) -> None:
"""Store token and its type."""
self.tokens.append(token)
self.token_types.append(type(token))

# Test with string content
llm = FakeListChatModel(responses=["hello world"])
callback = TokenCollector()

list(llm.stream("test", config={"callbacks": [callback]}))

# Verify all tokens are strings
assert len(callback.tokens) > 0
assert all(isinstance(t, str) for t in callback.tokens)
assert all(t is str for t in callback.token_types)


def test_on_llm_new_token_with_list_content() -> None:
"""Test that on_llm_new_token receives JSON string when content is list."""

class TokenCollector(BaseCallbackHandler):
"""Callback handler that collects tokens."""

def __init__(self) -> None:
self.tokens: list[str] = []
self.token_types: list[type] = []

def on_llm_new_token(
self,
token: str,
*,
_chunk: ChatGenerationChunk | None = None,
**_kwargs: Any,
) -> None:
"""Store token and its type."""
self.tokens.append(token)
self.token_types.append(type(token))

class FakeChatModelWithListContent(BaseChatModel):
"""Fake chat model that returns structured content as list."""

@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""Generate with list content."""
content: list[dict[str, Any]] = [
{"type": "text", "text": "Hello"},
{"type": "tool_use", "id": "1", "name": "test_tool"},
]
message = AIMessage(content=cast("Any", content))
return ChatResult(generations=[ChatGeneration(message=message)])

@override
def _stream(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream chunks with list content."""
# First chunk with partial content
yield ChatGenerationChunk(
message=AIMessageChunk(content=[{"type": "text", "text": "Hello"}])
)
# Second chunk with more content
yield ChatGenerationChunk(
message=AIMessageChunk(
content=[
{"type": "text", "text": "Hello"},
{"type": "tool_use", "id": "1", "name": "test_tool"},
]
)
)

@property
def _llm_type(self) -> str:
return "fake-chat-model-with-list-content"

llm = FakeChatModelWithListContent()
callback = TokenCollector()

list(llm.stream("test", config={"callbacks": [callback]}))

# Verify all tokens are strings (JSON-dumped)
assert len(callback.tokens) > 0
assert all(isinstance(t, str) for t in callback.tokens)
assert all(t is str for t in callback.token_types)

# Verify tokens are valid JSON strings
for token in callback.tokens:
if token: # Skip empty strings (like final chunk markers)
try:
parsed = json.loads(token)
assert isinstance(parsed, list)
except json.JSONDecodeError:
# This is ok if it's an empty string for the final chunk
assert token == ""


async def test_on_llm_new_token_with_list_content_async() -> None:
"""Test that on_llm_new_token receives JSON string when content is list (async)."""

class AsyncTokenCollector(AsyncCallbackHandler):
"""Async callback handler that collects tokens."""

def __init__(self) -> None:
self.tokens: list[str] = []
self.token_types: list[type] = []

async def on_llm_new_token(
self,
token: str,
*,
_chunk: ChatGenerationChunk | None = None,
**_kwargs: Any,
) -> None:
"""Store token and its type."""
self.tokens.append(token)
self.token_types.append(type(token))

class FakeChatModelWithListContent(BaseChatModel):
"""Fake chat model that returns structured content as list."""

@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""Generate with list content."""
content: list[dict[str, Any]] = [
{"type": "text", "text": "Hello"},
{"type": "tool_use", "id": "1", "name": "test_tool"},
]
message = AIMessage(content=cast("Any", content))
return ChatResult(generations=[ChatGeneration(message=message)])

@override
async def _astream(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
"""Stream chunks with list content."""
# First chunk with partial content
yield ChatGenerationChunk(
message=AIMessageChunk(content=[{"type": "text", "text": "Hello"}])
)
# Second chunk with more content
yield ChatGenerationChunk(
message=AIMessageChunk(
content=[
{"type": "text", "text": "Hello"},
{"type": "tool_use", "id": "1", "name": "test_tool"},
]
)
)

@property
def _llm_type(self) -> str:
return "fake-chat-model-with-list-content"

llm = FakeChatModelWithListContent()
callback = AsyncTokenCollector()

async for _ in llm.astream("test", config={"callbacks": [callback]}):
pass

# Verify all tokens are strings (JSON-dumped)
assert len(callback.tokens) > 0
assert all(isinstance(t, str) for t in callback.tokens)
assert all(t is str for t in callback.token_types)

# Verify tokens are valid JSON strings
for token in callback.tokens:
if token: # Skip empty strings (like final chunk markers)
try:
parsed = json.loads(token)
assert isinstance(parsed, list)
except json.JSONDecodeError:
# This is ok if it's an empty string for the final chunk
assert token == ""


async def test_astream_fallback_to_ainvoke() -> None:
"""Test `astream()` uses appropriate implementation."""
