Skip to content

Commit b4bdc54

Browse files
committed
feat(pydantic): Wrap Agent.to_cli_sync
1 parent 9548ae3 commit b4bdc54

2 files changed

Lines changed: 75 additions & 0 deletions

File tree

py/src/braintrust/wrappers/pydantic_ai.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -127,6 +127,26 @@ def agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
127127

128128
# Install agent_run_sync_wrapper around Agent.run_sync.
wrap_function_wrapper(Agent, "run_sync", agent_run_sync_wrapper)
129129

130+
def agent_to_cli_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
    """Run ``Agent.to_cli_sync`` inside a single span and log its wall-clock timing.

    Args:
        wrapped: The original ``to_cli_sync`` callable.
        instance: The agent the call was made on; its ``name`` (when set and
            truthy) is appended to the span name.
        args: Positional arguments of the intercepted call.
        kwargs: Keyword arguments of the intercepted call.

    Returns:
        Whatever ``wrapped(*args, **kwargs)`` returns, unchanged.
    """
    _ensure_model_wrapped(instance)
    input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)

    # Agents without a usable name fall back to the bare span label.
    agent_name = getattr(instance, "name", None)
    span_name = f"agent_to_cli_sync [{agent_name}]" if agent_name else "agent_to_cli_sync"

    with start_span(
        name=span_name,
        type=SpanTypeAttribute.LLM,
        input=input_data or None,
        metadata=metadata,
    ) as cli_span:
        started = time.time()
        outcome = wrapped(*args, **kwargs)
        finished = time.time()
        # Only timing is recorded here; no output is logged for the CLI call.
        cli_span.log(metrics={"start": started, "end": finished, "duration": finished - started})
        return outcome
147+
148+
# Install agent_to_cli_sync_wrapper around Agent.to_cli_sync.
wrap_function_wrapper(Agent, "to_cli_sync", agent_to_cli_sync_wrapper)
149+
130150
def agent_run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
131151
_ensure_model_wrapped(instance)
132152
input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)

py/src/braintrust/wrappers/test_pydantic_ai_integration.py

Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
from pydantic import BaseModel
1414
from pydantic_ai import Agent, ModelSettings
1515
from pydantic_ai.messages import ModelRequest, UserPromptPart
16+
from pydantic_ai.usage import UsageLimits
1617

1718
PROJECT_NAME = "test-pydantic-ai-integration"
1819
MODEL = "openai:gpt-4o-mini" # Use cheaper model for tests
@@ -168,6 +169,60 @@ def is_descendant(child_span, ancestor_id):
168169
assert "completion_tokens" in agent_sync_span["metrics"]
169170

170171

172+
def test_agent_to_cli_sync(memory_logger, monkeypatch):
    """Agent.to_cli_sync() should produce exactly one span describing the CLI call."""
    # The logger must start out empty.
    assert not memory_logger.pop()

    prior_messages = [ModelRequest(parts=[UserPromptPart(content="Previous question")])]
    limits = UsageLimits(request_limit=3)
    cli_agent = Agent(MODEL, name="cli-agent", model_settings=ModelSettings(max_tokens=50))

    async def fake_run_chat(
        *,
        stream,
        agent,
        deps,
        console,
        code_theme,
        prog_name,
        message_history,
        model_settings,
        usage_limits,
    ):
        # Stand-in for pydantic_ai._cli.run_chat: check what was forwarded, then return.
        assert stream is True
        assert prog_name == "braintrust-cli"
        assert message_history is not None
        assert model_settings is not None
        assert usage_limits is not None
        return 0

    monkeypatch.setattr("pydantic_ai._cli.run_chat", fake_run_chat)

    # Bracket the call so the logged timing metrics can be range-checked below.
    started = time.time()
    cli_agent.to_cli_sync(
        prog_name="braintrust-cli",
        message_history=prior_messages,
        model_settings=ModelSettings(max_tokens=20, temperature=0.2),
        usage_limits=limits,
    )
    finished = time.time()

    spans = memory_logger.pop()
    assert len(spans) == 1, f"Expected 1 CLI span, got {len(spans)}"

    cli_span = spans[0]

    attributes = cli_span["span_attributes"]
    assert attributes["type"] == SpanTypeAttribute.LLM
    assert attributes["name"] == "agent_to_cli_sync [cli-agent]"

    span_metadata = cli_span["metadata"]
    assert span_metadata["model"] == "gpt-4o-mini"
    assert span_metadata["provider"] == "openai"

    # The per-call model settings (20 / 0.2) are the ones expected in the logged input.
    logged_input = cli_span["input"]
    assert logged_input["prog_name"] == "braintrust-cli"
    assert "message_history" in logged_input
    logged_settings = logged_input["model_settings"]
    assert logged_settings["max_tokens"] == 20
    assert logged_settings["temperature"] == 0.2
    assert logged_input["usage_limits"]["request_limit"] == 3

    _assert_metrics_are_valid(cli_span["metrics"], started, finished)
224+
225+
171226
@pytest.mark.vcr
172227
@pytest.mark.asyncio
173228
async def test_multiple_identical_sequential_streams(memory_logger):

0 commit comments

Comments
 (0)