Skip to content

Commit 44cb1c3

Browse files
authored
fix: Orient dataframe JSON row-wise & support more data-frame types (#183)
* Close #180: Improved ContentToolResult() handling for data frames * Update test; tweak changelog
1 parent 27c9609 commit 44cb1c3

File tree

5 files changed

+133
-4
lines changed

5 files changed

+133
-4
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
## [UNRELEASED]
1111

12+
### Improvements
13+
14+
* `ContentToolResult`'s `.get_model_value()` method now calls `.to_json(orient="records")` (instead of `.to_json()`) when relevant. As a result, if a tool call returns a Pandas `DataFrame` (or similar), the model now receives a less confusing (and smaller) row-oriented JSON format. (#183)
15+
1216
### Bug fixes
1317

1418
* `ChatAzureOpenAI()` and `ChatDatabricks()` now work as expected when an `OPENAI_API_KEY` environment variable isn't present. (#185)

chatlas/_content.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import inspect
4+
import warnings
35
from pprint import pformat
4-
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
6+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
57

68
import orjson
79
from pydantic import BaseModel, ConfigDict
@@ -465,8 +467,36 @@ def get_model_value(self) -> object:
465467

466468
@staticmethod
467469
def _to_json(value: Any) -> object:
470+
if hasattr(value, "to_pandas") and callable(value.to_pandas):
471+
# Many (most?) df libs (polars, pyarrow, ...) have a .to_pandas()
472+
# method, and pandas has a .to_json() method
473+
value = value.to_pandas()
474+
468475
if hasattr(value, "to_json") and callable(value.to_json):
469-
return value.to_json()
476+
# pandas defaults to "columns", which is not ideal for LLMs
477+
# https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_json.html
478+
sig = inspect.signature(value.to_json)
479+
if "orient" in list(sig.parameters.keys()):
480+
return value.to_json(orient="records")
481+
else:
482+
return value.to_json()
483+
484+
# Support for df libs (beyond those with a .to_pandas() method)
485+
if hasattr(value, "__narwhals_dataframe__"):
486+
try:
487+
import narwhals
488+
489+
val = cast(narwhals.DataFrame, narwhals.from_native(value))
490+
return val.to_pandas().to_json(orient="records")
491+
except ImportError:
492+
warnings.warn(
493+
f"Tool result object of type {type(value)} appears to be a "
494+
"narwhals-compatible DataFrame. If you run into issues with "
495+
"the LLM not understanding this value, try installing narwhals: "
496+
"`pip install narwhals`.",
497+
ImportWarning,
498+
stacklevel=2,
499+
)
470500

471501
if hasattr(value, "to_dict") and callable(value.to_dict):
472502
value = value.to_dict()

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ dev = [
4848
"matplotlib",
4949
"Pillow",
5050
"shiny",
51+
"htmltools",
5152
"shinychat",
53+
"narwhals",
54+
"pandas",
55+
"polars",
5256
"openai",
5357
"anthropic[bedrock]",
5458
"google-genai>=1.14.0",
@@ -58,7 +62,6 @@ dev = [
5862
"snowflake-ml-python>=1.8.4",
5963
# torch (a dependency of snowflake-ml-python) is not yet compatible with Python >3.11
6064
"torch;python_version<='3.11'",
61-
"htmltools",
6265
"tenacity"
6366
]
6467
docs = [

tests/test_content_tools.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any, Optional, Union
2+
from unittest.mock import Mock
23

4+
import orjson
35
import pytest
46

57
from chatlas import ChatOpenAI
@@ -403,3 +405,93 @@ def add(x: int, y: int) -> int:
403405
assert parsed.tool is not None
404406
assert parsed.tool.name == "add"
405407
assert parsed.tool.description == "Add two numbers"
408+
409+
410+
def test_content_tool_result_pandas_dataframe():
411+
"""Test ContentToolResult with pandas DataFrame using orient='records'"""
412+
pandas = pytest.importorskip("pandas")
413+
414+
# Create a simple pandas DataFrame
415+
df = pandas.DataFrame(
416+
{"name": ["Alice", "Bob"], "age": [25, 30], "city": ["New York", "London"]}
417+
)
418+
419+
# Create ContentToolResult with DataFrame value
420+
result = ContentToolResult(value=df).get_model_value()
421+
expected = df.to_json(orient="records")
422+
assert result == expected
423+
424+
parsed = orjson.loads(str(result))
425+
assert isinstance(parsed, list)
426+
assert len(parsed) == 2
427+
assert parsed[0] == {"name": "Alice", "age": 25, "city": "New York"}
428+
assert parsed[1] == {"name": "Bob", "age": 30, "city": "London"}
429+
430+
431+
def test_content_tool_result_object_with_to_pandas():
432+
"""Test ContentToolResult with objects that have .to_pandas() method"""
433+
pandas = pytest.importorskip("pandas")
434+
435+
# Create mock object with to_pandas method (like Polars, PyArrow)
436+
mock_df_lib = Mock()
437+
pandas_df = pandas.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]})
438+
mock_df_lib.to_pandas.return_value = pandas_df
439+
440+
result = ContentToolResult(value=mock_df_lib).get_model_value()
441+
mock_df_lib.to_pandas.assert_called_once()
442+
expected = pandas_df.to_json(orient="records")
443+
assert result == expected
444+
445+
446+
def test_content_tool_result_narwhals_dataframe():
447+
"""Test ContentToolResult with narwhals DataFrame"""
448+
narwhals = pytest.importorskip("narwhals")
449+
pandas = pytest.importorskip("pandas")
450+
451+
pandas_df = pandas.DataFrame({"a": [1, 2], "b": ["x", "y"]})
452+
nw_df = narwhals.from_native(pandas_df)
453+
result = ContentToolResult(value=nw_df).get_model_value()
454+
expected = pandas_df.to_json(orient="records")
455+
assert result == expected
456+
457+
458+
def test_content_tool_result_object_with_to_dict():
459+
"""Test ContentToolResult with objects that have to_dict method"""
460+
# Mock object with to_dict method but no to_pandas or to_json
461+
mock_obj = Mock(spec=["to_dict"])
462+
mock_obj.to_dict.return_value = {"key": "value"}
463+
result = ContentToolResult(value=mock_obj).get_model_value()
464+
mock_obj.to_dict.assert_called_once()
465+
# Result should be JSON string representation (orjson format)
466+
assert result == '{"key":"value"}'
467+
468+
469+
def test_content_tool_result_string_passthrough():
470+
"""Test ContentToolResult with string values (special case - passed through as-is)"""
471+
result = ContentToolResult(value="plain string").get_model_value()
472+
assert result == "plain string"
473+
474+
475+
def test_content_tool_result_fallback_serialization():
476+
"""Test ContentToolResult fallback for objects without special methods"""
477+
# Regular object without to_json, to_pandas, or to_dict (non-string to avoid the string special case)
478+
result = ContentToolResult(value={"key": "value"}).get_model_value()
479+
assert result == '{"key":"value"}'
480+
481+
482+
def test_content_tool_result_explicit_json_mode():
483+
"""Test ContentToolResult with explicit JSON mode forces _to_json for non-strings"""
484+
# Test with non-string object and explicit JSON mode
485+
result = ContentToolResult(
486+
value={"key": "value"},
487+
model_format="json",
488+
).get_model_value()
489+
# With explicit JSON mode, objects get JSON-encoded
490+
assert result == '{"key":"value"}'
491+
# Test that strings still get special treatment even in JSON mode
492+
string_result = ContentToolResult(
493+
value="plain string",
494+
model_format="json",
495+
).get_model_value()
496+
# Strings are still returned as-is even in JSON mode (current behavior)
497+
assert string_result == "plain string"

tests/test_tokens.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_token_count_method():
7070
chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
7171
assert chat.token_count("What is 1 + 1?") == 16
7272

73-
chat = ChatGoogle(model="gemini-1.5-flash")
73+
chat = ChatGoogle(model="gemini-2.5-flash")
7474
assert chat.token_count("What is 1 + 1?") == 9
7575

7676

0 commit comments

Comments
 (0)