diff --git a/src/strands_evals/extractors/trace_extractor.py b/src/strands_evals/extractors/trace_extractor.py index f86780e..e55ef85 100644 --- a/src/strands_evals/extractors/trace_extractor.py +++ b/src/strands_evals/extractors/trace_extractor.py @@ -45,9 +45,11 @@ def extract(self, session: Session) -> Union[list[TraceLevelInput], list[ToolLev def _extract_trace_level(self, session: Session) -> list[TraceLevelInput]: """Extract trace-level inputs with session history up to each turn.""" evaluation_inputs: list[TraceLevelInput] = [] - previous_turns: list[Union[UserMessage, AssistantMessage]] = [] + previous_turns: list[Union[UserMessage, list[ToolExecution], AssistantMessage]] = [] for trace in session.traces: + tool_spans = self._find_tool_execution_spans(trace) + for span in trace.spans: if not isinstance(span, AgentInvocationSpan): continue @@ -59,6 +61,17 @@ def _extract_trace_level(self, session: Session) -> list[TraceLevelInput]: logger.warning(f"Failed to create user message: {e}") continue + # Include tool executions in session history + if tool_spans: + try: + tool_executions = [ + ToolExecution(tool_call=ts.tool_call, tool_result=ts.tool_result) + for ts in tool_spans + ] + previous_turns.append(tool_executions) + except (AttributeError, TypeError, ValueError) as e: + logger.warning(f"Failed to create tool executions: {e}") + trace_input = TraceLevelInput( span_info=span.span_info, agent_response=TextContent(text=span.agent_response),