Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions py/src/braintrust/wrappers/agno/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
response = agent.run(...)
"""

__all__ = ["setup_agno", "wrap_agent", "wrap_function_call", "wrap_model", "wrap_team"]
__all__ = ["setup_agno", "wrap_agent", "wrap_function_call", "wrap_model", "wrap_team", "wrap_workflow"]

import logging

Expand All @@ -28,6 +28,7 @@
from .function_call import wrap_function_call
from .model import wrap_model
from .team import wrap_team
from .workflow import wrap_workflow

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -61,7 +62,16 @@ def setup_agno(
team.Team = wrap_team(team.Team) # pyright: ignore[reportUnknownMemberType]
models.base.Model = wrap_model(models.base.Model) # pyright: ignore[reportUnknownMemberType]
tools.function.FunctionCall = wrap_function_call(tools.function.FunctionCall) # pyright: ignore[reportUnknownMemberType]
return True
except ImportError:
# Not installed - this is expected when using auto_instrument()
return False

try:
from agno import workflow # pyright: ignore

workflow.Workflow = wrap_workflow(workflow.Workflow) # pyright: ignore[reportUnknownMemberType]
except ImportError:
# agno.workflow requires fastapi which may not be installed
pass

return True
5 changes: 5 additions & 0 deletions py/src/braintrust/wrappers/agno/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ def extract_metadata(instance: Any, component: str) -> dict[str, Any]:
model = getattr(instance, "model", None)
if model:
metadata["model"] = getattr(model, "id", None) or model.__class__.__name__
elif component == "workflow":
metadata["workflow_name"] = getattr(instance, "name", None)
steps = getattr(instance, "steps", None)
if steps:
metadata["steps_count"] = len(steps)

return metadata

Expand Down
204 changes: 204 additions & 0 deletions py/src/braintrust/wrappers/agno/workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import time
from typing import Any

from braintrust.logger import start_span
from braintrust.span_types import SpanTypeAttribute
from wrapt import wrap_function_wrapper

from .utils import (
_aggregate_agent_chunks,
_try_to_dict,
extract_metadata,
extract_metrics,
extract_streaming_metrics,
is_patched,
mark_patched,
)


def _extract_workflow_input(args: Any, kwargs: Any) -> dict[str, Any]:
"""Extract the input from _execute parameters.

_execute signature: (self, session, execution_input, workflow_run_response, run_context, ...)
- args[0]: session (WorkflowSession)
- args[1]: execution_input (WorkflowExecutionInput) - contains .input
- args[2]: workflow_run_response (WorkflowRunOutput) - contains .input, accumulates results
"""
execution_input = args[1] if len(args) > 1 else kwargs.get("execution_input")
workflow_run_response = args[2] if len(args) > 2 else kwargs.get("workflow_run_response")

result: dict[str, Any] = {}

if execution_input:
if hasattr(execution_input, "input"):
result["input"] = execution_input.input
result["execution_input"] = _try_to_dict(execution_input)

if workflow_run_response:
result["run_response"] = _try_to_dict(workflow_run_response)

return result


def wrap_workflow(Workflow: Any) -> Any:
    """Instrument a workflow class so its execution methods emit tracing spans.

    Each of ``_execute``, ``_execute_stream``, ``_aexecute`` and
    ``_aexecute_stream`` that exists on ``Workflow`` is wrapped in place with
    a wrapper that opens a ``TASK`` span, records the call input and the
    workflow metadata, and logs the output (and metrics) when the call
    finishes. The class is marked as patched afterwards; a class that is
    already marked is returned unchanged.

    Args:
        Workflow: The workflow class to patch.

    Returns:
        The same ``Workflow`` object that was passed in.
    """
    if is_patched(Workflow):
        return Workflow

    def _workflow_label(instance: Any) -> Any:
        # Workflows without a (truthy) name use the generic "Workflow" label.
        return getattr(instance, "name", None) or "Workflow"

    def execute_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        """Trace a synchronous, non-streaming run as ``<name>.run``."""
        label = _workflow_label(instance)
        span_name = f"{label}.run"
        payload = _extract_workflow_input(args, kwargs)

        with start_span(
            name=span_name,
            type=SpanTypeAttribute.TASK,
            input=payload,
            metadata=extract_metadata(instance, "workflow"),
        ) as span:
            outcome = wrapped(*args, **kwargs)
            span.log(output=outcome, metrics=extract_metrics(outcome))
            return outcome

    def execute_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        """Trace a synchronous streaming run as ``<name>.run_stream``."""
        label = _workflow_label(instance)
        span_name = f"{label}.run_stream"
        # The input is captured when the wrapper is called; the span itself is
        # only created once the returned generator is first iterated.
        payload = _extract_workflow_input(args, kwargs)

        def _trace_stream():
            started_at = time.time()
            span = start_span(
                name=span_name,
                type=SpanTypeAttribute.TASK,
                input=payload,
                metadata=extract_metadata(instance, "workflow"),
            )
            span.set_current()

            release_context = True
            collected = []
            try:
                for chunk in wrapped(*args, **kwargs):
                    if not collected:
                        # Nothing collected yet, so this is the first chunk.
                        span.log(metrics={"time_to_first_token": time.time() - started_at})
                    collected.append(chunk)
                    yield chunk

                aggregated = _aggregate_agent_chunks(collected)
                span.log(
                    output=aggregated,
                    metrics=extract_streaming_metrics(aggregated, started_at),
                )
            except GeneratorExit:
                # The consumer closed the generator early: still end the span
                # below, but do not unset it as current.
                # NOTE(review): presumably because the generator may be
                # finalized in a different context - TODO confirm.
                release_context = False
                raise
            except Exception as exc:
                span.log(error=str(exc))
                raise
            finally:
                if release_context:
                    span.unset_current()
                span.end()

        return _trace_stream()

    async def aexecute_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        """Trace an asynchronous, non-streaming run as ``<name>.arun``."""
        label = _workflow_label(instance)
        span_name = f"{label}.arun"
        payload = _extract_workflow_input(args, kwargs)

        with start_span(
            name=span_name,
            type=SpanTypeAttribute.TASK,
            input=payload,
            metadata=extract_metadata(instance, "workflow"),
        ) as span:
            outcome = await wrapped(*args, **kwargs)
            span.log(output=outcome, metrics=extract_metrics(outcome))
            return outcome

    def aexecute_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        """Trace an asynchronous streaming run as ``<name>.arun_stream``."""
        label = _workflow_label(instance)
        span_name = f"{label}.arun_stream"
        # Same split as the sync streaming wrapper: input is captured now, the
        # span is created when the async generator starts running.
        payload = _extract_workflow_input(args, kwargs)

        async def _trace_stream():
            started_at = time.time()
            span = start_span(
                name=span_name,
                type=SpanTypeAttribute.TASK,
                input=payload,
                metadata=extract_metadata(instance, "workflow"),
            )
            span.set_current()

            release_context = True
            collected = []
            try:
                async for chunk in wrapped(*args, **kwargs):
                    if not collected:
                        # Nothing collected yet, so this is the first chunk.
                        span.log(metrics={"time_to_first_token": time.time() - started_at})
                    collected.append(chunk)
                    yield chunk

                aggregated = _aggregate_agent_chunks(collected)
                span.log(
                    output=aggregated,
                    metrics=extract_streaming_metrics(aggregated, started_at),
                )
            except GeneratorExit:
                # Early close by the consumer: end the span, skip the unset.
                release_context = False
                raise
            except Exception as exc:
                span.log(error=str(exc))
                raise
            finally:
                if release_context:
                    span.unset_current()
                span.end()

        return _trace_stream()

    # Patch only the methods this Workflow class actually provides, in a
    # fixed order.
    for method_name, wrapper in (
        ("_execute", execute_wrapper),
        ("_execute_stream", execute_stream_wrapper),
        ("_aexecute", aexecute_wrapper),
        ("_aexecute_stream", aexecute_stream_wrapper),
    ):
        if hasattr(Workflow, method_name):
            wrap_function_wrapper(Workflow, method_name, wrapper)

    mark_patched(Workflow)
    return Workflow
57 changes: 57 additions & 0 deletions py/src/braintrust/wrappers/test_agno.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,63 @@ def test_agno_simple_agent_execution(memory_logger):
assert llm_span["metrics"]["tokens"] == 42


@pytest.mark.vcr
def test_agno_workflow_with_agent(memory_logger):
    """Test that workflows create a parent span and agents nest under it."""
    # NOTE(review): flagged in review - this test should record a new VCR
    # cassette; confirm one is checked in alongside it.

    # pytest.importorskip() imports *modules*. Passing a dotted path that ends
    # in a class name (e.g. "agno.agent.Agent") can never be imported, which
    # would skip this test unconditionally. Import the module and then read
    # the class off it.
    Agent = pytest.importorskip("agno.agent").Agent
    Workflow = pytest.importorskip("agno.workflow").Workflow
    OpenAIChat = pytest.importorskip("agno.models.openai").OpenAIChat

    setup_agno(project_name=PROJECT_NAME)

    # Nothing should have been logged before the workflow runs.
    assert not memory_logger.pop()

    author_agent = Agent(
        name="Author Agent",
        model=OpenAIChat(id="gpt-4o-mini"),
        instructions="You are librarian. Answer the questions by only replying with the author that wrote the book.",
    )

    workflow = Workflow(
        name="Book Lookup Workflow",
        steps=[author_agent],
    )

    response = workflow.run("Charlotte's Web")

    assert response
    assert response.content
    assert len(response.content) > 0

    spans = memory_logger.pop()
    assert len(spans) >= 3, f"Expected at least 3 spans (workflow + agent + llm), got {len(spans)}"

    # The first logged span is expected to be the workflow-level span.
    workflow_span = spans[0]
    assert workflow_span["span_attributes"]["name"] == "Book Lookup Workflow.run"
    assert workflow_span["span_attributes"]["type"].value == "task"
    assert workflow_span["metadata"]["component"] == "workflow"
    assert workflow_span["metadata"]["workflow_name"] == "Book Lookup Workflow"
    assert workflow_span["metadata"]["steps_count"] == 1

    # First span after the workflow span whose name looks like an agent run.
    agent_span = next(
        (
            span
            for span in spans[1:]
            if "Agent" in span["span_attributes"]["name"] and ".run" in span["span_attributes"]["name"]
        ),
        None,
    )

    assert agent_span is not None, "Could not find agent span"
    assert agent_span["span_parents"] == [workflow_span["span_id"]], "Agent span should be child of workflow span"

    # First span of type "llm" anywhere in the trace.
    llm_span = next(
        (span for span in spans if span["span_attributes"]["type"].value == "llm"),
        None,
    )

    assert llm_span is not None, "Could not find LLM span"
    assert llm_span["span_parents"] == [agent_span["span_id"]], "LLM span should be child of agent span"


class TestAutoInstrumentAgno:
"""Tests for auto_instrument() with Agno."""

Expand Down
Loading