119 changes: 109 additions & 10 deletions src/cellsem_llm_client/agents/agent_connection.py
@@ -176,9 +176,10 @@ def query_unified(

raw_response: Any | None = None
response_content: str | None = None
all_tool_responses: list[Any] | None = None

if tools:
response_content, raw_response = self._run_tool_loop(
response_content, raw_response, all_tool_responses = self._run_tool_loop(
messages=messages,
tools=tools,
tool_handlers=tool_handlers or {},
@@ -237,11 +238,20 @@ def query_unified(

usage_metrics: UsageMetrics | None = None
if track_usage and raw_response is not None and hasattr(raw_response, "usage"):
usage_metrics = self._build_usage_metrics(
raw_response=raw_response,
provider=provider,
cost_calculator=cost_calculator,
)
# When tools were used, accumulate usage from all API calls
if all_tool_responses is not None:
usage_metrics = self._accumulate_usage_metrics(
responses=all_tool_responses,
provider=provider,
cost_calculator=cost_calculator,
)
else:
# Single API call without tools
usage_metrics = self._build_usage_metrics(
raw_response=raw_response,
provider=provider,
cost_calculator=cost_calculator,
)

return QueryResult(
text=response_content,
@@ -409,7 +419,9 @@ def query_with_schema_and_tracking(
max_retries=max_retries,
)
if result.model is None or result.usage is None:
raise RuntimeError("Expected model instance and usage metrics but one or both were not populated")
raise RuntimeError(
"Expected model instance and usage metrics but one or both were not populated"
)
return result.model, result.usage

def _pydantic_model_to_schema(self, model_class: type[BaseModel]) -> dict[str, Any]:
@@ -535,9 +547,15 @@ def _run_tool_loop(
tools: list[dict[str, Any]],
tool_handlers: dict[str, Callable[[dict[str, Any]], str | None]],
max_turns: int,
) -> tuple[str, Any]:
"""Execute tool calls until completion."""
) -> tuple[str, Any, list[Any]]:
"""Execute tool calls until completion.

Returns:
A tuple of (final_content, final_response, all_responses) where
all_responses contains every API response from all turns for usage tracking.
"""
working_messages = list(messages)
all_responses: list[Any] = []

for _turn in range(max_turns):
response = completion(
@@ -546,6 +564,7 @@ def _run_tool_loop(
tools=tools,
max_tokens=self.max_tokens,
)
all_responses.append(response)

response_message = response.choices[0].message
tool_calls = getattr(response_message, "tool_calls", None)
@@ -598,7 +617,7 @@ def _run_tool_loop(
)
continue

return str(response_message.content), response
return str(response_message.content), response, all_responses

raise RuntimeError("Max tool-call turns reached without a final response.")

@@ -654,6 +673,86 @@ def _build_usage_metrics(
cost_source="estimated",
)

def _accumulate_usage_metrics(
self,
responses: list[Any],
provider: str,
cost_calculator: Optional["FallbackCostCalculator"] = None,
) -> UsageMetrics:
"""Accumulate usage metrics from multiple API responses.

When tools are used, multiple API calls are made across iterations.
This method sums up the token usage from all calls to provide accurate
cumulative metrics.

Args:
responses: List of LiteLLM response objects from all API calls
provider: The LLM provider name
cost_calculator: Optional cost calculator for estimating costs

Returns:
Accumulated UsageMetrics with total token counts and costs
"""
total_input_tokens = 0
total_output_tokens = 0
total_cached_tokens = 0
total_thinking_tokens = 0
has_cached = False
has_thinking = False

for response in responses:
if not hasattr(response, "usage"):
continue

usage = response.usage
total_input_tokens += usage.prompt_tokens
total_output_tokens += usage.completion_tokens

# Accumulate cached tokens if present
if (
hasattr(usage, "prompt_tokens_details")
and usage.prompt_tokens_details
and hasattr(usage.prompt_tokens_details, "cached_tokens")
and usage.prompt_tokens_details.cached_tokens is not None
):
total_cached_tokens += usage.prompt_tokens_details.cached_tokens
has_cached = True

cached_tokens = total_cached_tokens if has_cached else None
thinking_tokens = total_thinking_tokens if has_thinking else None

# Calculate cost based on accumulated tokens
estimated_cost_usd = None
if cost_calculator:
try:
temp_usage_metrics = UsageMetrics(
input_tokens=total_input_tokens,
output_tokens=total_output_tokens,
cached_tokens=cached_tokens,
thinking_tokens=thinking_tokens,
provider=provider,
model=self.model,
timestamp=datetime.now(),
cost_source="estimated",
)
estimated_cost_usd = cost_calculator.calculate_cost(temp_usage_metrics)
except Exception as e:
logging.warning(
f"Cost calculation failed for {provider}/{self.model}: {e}"
)

return UsageMetrics(
input_tokens=total_input_tokens,
output_tokens=total_output_tokens,
cached_tokens=cached_tokens,
thinking_tokens=thinking_tokens,
estimated_cost_usd=estimated_cost_usd,
provider=provider,
model=self.model,
timestamp=datetime.now(),
cost_source="estimated",
)

def _get_provider_from_model(self, model: str) -> str:
"""Determine provider from model name.

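For orientation, the accumulation added in `_accumulate_usage_metrics` boils down to summing `prompt_tokens` and `completion_tokens` across every response that `_run_tool_loop` now collects. Below is a minimal, self-contained sketch of that behaviour outside the class; `FakeUsage` and `FakeResponse` are hypothetical stand-ins for illustration, not names from this codebase.

# Sketch of the accumulation behaviour using stand-in response objects.
# FakeUsage / FakeResponse are hypothetical stubs, not part of the PR.
from dataclasses import dataclass

@dataclass
class FakeUsage:
    prompt_tokens: int
    completion_tokens: int

@dataclass
class FakeResponse:
    usage: FakeUsage

def accumulate(responses: list[FakeResponse]) -> tuple[int, int]:
    # Sum input/output tokens across all tool-loop API calls, skipping
    # responses that carry no usage payload (mirroring the hasattr guard).
    total_in = sum(r.usage.prompt_tokens for r in responses if hasattr(r, "usage"))
    total_out = sum(r.usage.completion_tokens for r in responses if hasattr(r, "usage"))
    return total_in, total_out

# Two calls of 100/20 and 150/30 tokens accumulate to 250 in and 50 out,
# matching the expectations in the unit test below.
assert accumulate([FakeResponse(FakeUsage(100, 20)), FakeResponse(FakeUsage(150, 30))]) == (250, 50)

The real method additionally folds in cached-token details and an optional cost estimate, as shown in the diff above.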
72 changes: 72 additions & 0 deletions tests/unit/test_agent_connection.py
@@ -412,6 +412,78 @@ def test_litellm_agent_query_with_tools_missing_handler_raises(
tool_handlers={},
)

@pytest.mark.unit
@patch("cellsem_llm_client.agents.agent_connection.completion")
def test_query_unified_with_tools_accumulates_usage(
self, mock_completion: Any
) -> None:
"""Test that usage metrics accumulate across multiple tool call iterations."""
tools = [
{
"type": "function",
"function": {
"name": "get_data",
"description": "Get data",
"parameters": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
},
},
}
]

# First API call - tool call requested
tool_call = Mock()
tool_call.id = "call_1"
tool_call.type = "function"
tool_call.function = Mock()
tool_call.function.name = "get_data"
tool_call.function.arguments = '{"query": "test"}'

first_response = Mock()
first_response.choices = [Mock()]
first_response.choices[0].message.content = None
first_response.choices[0].message.tool_calls = [tool_call]
first_response.usage = Mock()
first_response.usage.prompt_tokens = 100
first_response.usage.completion_tokens = 20
first_response.usage.prompt_tokens_details = None

# Second API call - final response
final_response = Mock()
final_response.choices = [Mock()]
final_response.choices[0].message.content = "Here is the result."
final_response.choices[0].message.tool_calls = []
final_response.usage = Mock()
final_response.usage.prompt_tokens = 150
final_response.usage.completion_tokens = 30
final_response.usage.prompt_tokens_details = None

mock_completion.side_effect = [first_response, final_response]

def get_data(args: dict[str, Any]) -> str:
return "data result"

agent = LiteLLMAgent(model="gpt-4", api_key="test-key")
result = agent.query_unified(
message="Get me data",
tools=tools,
tool_handlers={"get_data": get_data},
track_usage=True,
)

# Verify response
assert result.text == "Here is the result."
assert result.usage is not None

# Verify cumulative usage: 100 + 150 = 250 input, 20 + 30 = 50 output
assert result.usage.input_tokens == 250
assert result.usage.output_tokens == 50
assert result.usage.total_tokens == 300
assert result.usage.provider == "openai"
assert result.usage.model == "gpt-4"

@pytest.mark.unit
@patch("cellsem_llm_client.agents.agent_connection.completion")
def test_query_unified_basic(self, mock_completion: Any) -> None:
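As a rough complement to the test above, which leaves cost assertions aside, the estimated-cost path simply applies the configured calculator to the accumulated totals. A minimal sketch under assumed, made-up per-million-token rates (not the real FallbackCostCalculator pricing):

# Sketch only: how a cost estimate follows from accumulated token totals.
# The rates below are placeholder assumptions, not actual provider pricing.
INPUT_USD_PER_MILLION = 2.50
OUTPUT_USD_PER_MILLION = 10.00

def estimate_cost_usd(input_tokens: int, output_tokens: int) -> float:
    return (
        input_tokens * INPUT_USD_PER_MILLION / 1_000_000
        + output_tokens * OUTPUT_USD_PER_MILLION / 1_000_000
    )

# With the accumulated totals from the test (250 input, 50 output tokens):
print(round(estimate_cost_usd(250, 50), 6))  # 0.001125

In the diff itself this step happens inside `_accumulate_usage_metrics`, which builds a temporary `UsageMetrics` and hands it to `cost_calculator.calculate_cost`, logging a warning and leaving the estimate as `None` if the calculation fails.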