diff --git a/src/cellsem_llm_client/agents/agent_connection.py b/src/cellsem_llm_client/agents/agent_connection.py
index 3c33a9f..ae4bbbf 100644
--- a/src/cellsem_llm_client/agents/agent_connection.py
+++ b/src/cellsem_llm_client/agents/agent_connection.py
@@ -176,9 +176,10 @@ def query_unified(

         raw_response: Any | None = None
         response_content: str | None = None
+        all_tool_responses: list[Any] | None = None

         if tools:
-            response_content, raw_response = self._run_tool_loop(
+            response_content, raw_response, all_tool_responses = self._run_tool_loop(
                 messages=messages,
                 tools=tools,
                 tool_handlers=tool_handlers or {},
@@ -237,11 +238,20 @@ def query_unified(

         usage_metrics: UsageMetrics | None = None
         if track_usage and raw_response is not None and hasattr(raw_response, "usage"):
-            usage_metrics = self._build_usage_metrics(
-                raw_response=raw_response,
-                provider=provider,
-                cost_calculator=cost_calculator,
-            )
+            # When tools were used, accumulate usage from all API calls
+            if all_tool_responses is not None:
+                usage_metrics = self._accumulate_usage_metrics(
+                    responses=all_tool_responses,
+                    provider=provider,
+                    cost_calculator=cost_calculator,
+                )
+            else:
+                # Single API call without tools
+                usage_metrics = self._build_usage_metrics(
+                    raw_response=raw_response,
+                    provider=provider,
+                    cost_calculator=cost_calculator,
+                )

         return QueryResult(
             text=response_content,
@@ -409,7 +419,9 @@ def query_with_schema_and_tracking(
             max_retries=max_retries,
         )
         if result.model is None or result.usage is None:
-            raise RuntimeError("Expected model instance and usage metrics but one or both were not populated")
+            raise RuntimeError(
+                "Expected model instance and usage metrics but one or both were not populated"
+            )
         return result.model, result.usage

     def _pydantic_model_to_schema(self, model_class: type[BaseModel]) -> dict[str, Any]:
@@ -535,9 +547,15 @@ def _run_tool_loop(
         tools: list[dict[str, Any]],
         tool_handlers: dict[str, Callable[[dict[str, Any]], str | None]],
         max_turns: int,
-    ) -> tuple[str, Any]:
-        """Execute tool calls until completion."""
+    ) -> tuple[str, Any, list[Any]]:
+        """Execute tool calls until completion.
+
+        Returns:
+            A tuple of (final_content, final_response, all_responses) where
+            all_responses contains every API response from all turns for usage tracking.
+        """
         working_messages = list(messages)
+        all_responses: list[Any] = []

         for _turn in range(max_turns):
             response = completion(
@@ -546,6 +564,7 @@ def _run_tool_loop(
                 tools=tools,
                 max_tokens=self.max_tokens,
             )
+            all_responses.append(response)

             response_message = response.choices[0].message
             tool_calls = getattr(response_message, "tool_calls", None)
@@ -598,7 +617,7 @@ def _run_tool_loop(
                 )
                 continue

-            return str(response_message.content), response
+            return str(response_message.content), response, all_responses

         raise RuntimeError("Max tool-call turns reached without a final response.")

@@ -654,6 +673,86 @@ def _build_usage_metrics(
             cost_source="estimated",
         )

+    def _accumulate_usage_metrics(
+        self,
+        responses: list[Any],
+        provider: str,
+        cost_calculator: Optional["FallbackCostCalculator"] = None,
+    ) -> UsageMetrics:
+        """Accumulate usage metrics from multiple API responses.
+
+        When tools are used, multiple API calls are made across iterations.
+        This method sums up the token usage from all calls to provide accurate
+        cumulative metrics.
+
+        Args:
+            responses: List of LiteLLM response objects from all API calls
+            provider: The LLM provider name
+            cost_calculator: Optional cost calculator for estimating costs
+
+        Returns:
+            Accumulated UsageMetrics with total token counts and costs
+        """
+        total_input_tokens = 0
+        total_output_tokens = 0
+        total_cached_tokens = 0
+        total_thinking_tokens = 0
+        has_cached = False
+        has_thinking = False
+
+        for response in responses:
+            if not hasattr(response, "usage"):
+                continue
+
+            usage = response.usage
+            total_input_tokens += usage.prompt_tokens
+            total_output_tokens += usage.completion_tokens
+
+            # Accumulate cached tokens if present
+            if (
+                hasattr(usage, "prompt_tokens_details")
+                and usage.prompt_tokens_details
+                and hasattr(usage.prompt_tokens_details, "cached_tokens")
+                and usage.prompt_tokens_details.cached_tokens is not None
+            ):
+                total_cached_tokens += usage.prompt_tokens_details.cached_tokens
+                has_cached = True
+
+        cached_tokens = total_cached_tokens if has_cached else None
+        thinking_tokens = total_thinking_tokens if has_thinking else None
+
+        # Calculate cost based on accumulated tokens
+        estimated_cost_usd = None
+        if cost_calculator:
+            try:
+                temp_usage_metrics = UsageMetrics(
+                    input_tokens=total_input_tokens,
+                    output_tokens=total_output_tokens,
+                    cached_tokens=cached_tokens,
+                    thinking_tokens=thinking_tokens,
+                    provider=provider,
+                    model=self.model,
+                    timestamp=datetime.now(),
+                    cost_source="estimated",
+                )
+                estimated_cost_usd = cost_calculator.calculate_cost(temp_usage_metrics)
+            except Exception as e:
+                logging.warning(
+                    f"Cost calculation failed for {provider}/{self.model}: {e}"
+                )
+
+        return UsageMetrics(
+            input_tokens=total_input_tokens,
+            output_tokens=total_output_tokens,
+            cached_tokens=cached_tokens,
+            thinking_tokens=thinking_tokens,
+            estimated_cost_usd=estimated_cost_usd,
+            provider=provider,
+            model=self.model,
+            timestamp=datetime.now(),
+            cost_source="estimated",
+        )
+
     def _get_provider_from_model(self, model: str) -> str:
         """Determine provider from model name.
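
Reviewer note: the accumulation in _accumulate_usage_metrics boils down to summing prompt_tokens and completion_tokens (plus optional prompt_tokens_details.cached_tokens) across every response collected by _run_tool_loop. A minimal standalone sketch of that summing step follows; FakeUsage and FakeResponse are hypothetical stand-ins for LiteLLM response objects, not part of this codebase.

    # Illustrative sketch only: FakeUsage/FakeResponse are hypothetical stand-ins
    # for LiteLLM response objects; the real method reads the same attributes.
    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class FakeUsage:
        prompt_tokens: int
        completion_tokens: int
        prompt_tokens_details: Any = None


    @dataclass
    class FakeResponse:
        usage: FakeUsage


    def sum_token_usage(responses: list[FakeResponse]) -> tuple[int, int]:
        """Sum input/output tokens over all responses from the tool loop."""
        total_input = total_output = 0
        for response in responses:
            if not hasattr(response, "usage"):
                continue
            total_input += response.usage.prompt_tokens
            total_output += response.usage.completion_tokens
        return total_input, total_output


    # Mirrors the unit test below: 100 + 150 = 250 input, 20 + 30 = 50 output.
    assert sum_token_usage(
        [FakeResponse(FakeUsage(100, 20)), FakeResponse(FakeUsage(150, 30))]
    ) == (250, 50)
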
diff --git a/tests/unit/test_agent_connection.py b/tests/unit/test_agent_connection.py
index fe0d2c9..02adcbc 100644
--- a/tests/unit/test_agent_connection.py
+++ b/tests/unit/test_agent_connection.py
@@ -412,6 +412,78 @@ def test_litellm_agent_query_with_tools_missing_handler_raises(
                 tool_handlers={},
             )

+    @pytest.mark.unit
+    @patch("cellsem_llm_client.agents.agent_connection.completion")
+    def test_query_unified_with_tools_accumulates_usage(
+        self, mock_completion: Any
+    ) -> None:
+        """Test that usage metrics accumulate across multiple tool call iterations."""
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_data",
+                    "description": "Get data",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {"query": {"type": "string"}},
+                        "required": ["query"],
+                    },
+                },
+            }
+        ]
+
+        # First API call - tool call requested
+        tool_call = Mock()
+        tool_call.id = "call_1"
+        tool_call.type = "function"
+        tool_call.function = Mock()
+        tool_call.function.name = "get_data"
+        tool_call.function.arguments = '{"query": "test"}'
+
+        first_response = Mock()
+        first_response.choices = [Mock()]
+        first_response.choices[0].message.content = None
+        first_response.choices[0].message.tool_calls = [tool_call]
+        first_response.usage = Mock()
+        first_response.usage.prompt_tokens = 100
+        first_response.usage.completion_tokens = 20
+        first_response.usage.prompt_tokens_details = None
+
+        # Second API call - final response
+        final_response = Mock()
+        final_response.choices = [Mock()]
+        final_response.choices[0].message.content = "Here is the result."
+        final_response.choices[0].message.tool_calls = []
+        final_response.usage = Mock()
+        final_response.usage.prompt_tokens = 150
+        final_response.usage.completion_tokens = 30
+        final_response.usage.prompt_tokens_details = None
+
+        mock_completion.side_effect = [first_response, final_response]
+
+        def get_data(args: dict[str, Any]) -> str:
+            return "data result"
+
+        agent = LiteLLMAgent(model="gpt-4", api_key="test-key")
+        result = agent.query_unified(
+            message="Get me data",
+            tools=tools,
+            tool_handlers={"get_data": get_data},
+            track_usage=True,
+        )
+
+        # Verify response
+        assert result.text == "Here is the result."
+        assert result.usage is not None
+
+        # Verify cumulative usage: 100 + 150 = 250 input, 20 + 30 = 50 output
+        assert result.usage.input_tokens == 250
+        assert result.usage.output_tokens == 50
+        assert result.usage.total_tokens == 300
+        assert result.usage.provider == "openai"
+        assert result.usage.model == "gpt-4"
+
     @pytest.mark.unit
     @patch("cellsem_llm_client.agents.agent_connection.completion")
     def test_query_unified_basic(self, mock_completion: Any) -> None:
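
For context, a sketch of how a caller would hit the new accumulation path outside the mocked test. It assumes LiteLLMAgent is importable from cellsem_llm_client.agents.agent_connection, a real API key, and network access; the tool schema and get_data handler are illustrative examples rather than anything shipped with the library.

    # Illustrative usage sketch (assumes a real API key; tool schema and handler
    # names are examples, not part of the library).
    from typing import Any

    from cellsem_llm_client.agents.agent_connection import LiteLLMAgent

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_data",
                "description": "Get data",
                "parameters": {
                    "type": "object",
                    "properties": {"query": {"type": "string"}},
                    "required": ["query"],
                },
            },
        }
    ]


    def get_data(args: dict[str, Any]) -> str:
        return f"data for {args['query']}"


    agent = LiteLLMAgent(model="gpt-4", api_key="sk-...")
    result = agent.query_unified(
        message="Get me data",
        tools=tools,
        tool_handlers={"get_data": get_data},
        track_usage=True,
    )

    if result.usage is not None:
        # With this change, the token counts cover every completion() call made
        # during the tool loop, not just the final response.
        print(
            result.usage.input_tokens,
            result.usage.output_tokens,
            result.usage.estimated_cost_usd,
        )

With track_usage=True, result.usage now reflects the cumulative cost of the whole tool loop rather than only the last API call.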