diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 684ec39..868a397 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -40,7 +40,7 @@ jobs: run: | uv run ruff check src/ tests/ uv run ruff format --check src/ tests/ - uv run mypy src/ + uv run mypy src/ tests/ - name: Upload coverage reports if: matrix.python-version == '3.12' @@ -49,4 +49,3 @@ jobs: file: ./coverage.xml fail_ci_if_error: false verbose: true - diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9e94c7b..094f563 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,11 @@ repos: hooks: - id: mypy additional_dependencies: [types-requests] - args: [--ignore-missing-imports] + args: + - --config-file=pyproject.toml + - src + - tests + pass_filenames: false - repo: local hooks: diff --git a/pyproject.toml b/pyproject.toml index 1678f98..0d8872e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,10 +9,10 @@ description = "LLM client implemented using LiteLLM + Pydantic-AI for seamless m readme = "README.md" license = {text = "MIT"} authors = [ - {name = "Cellular Semantics", email = "info@cellularsemantics.org"} + {name = "Cellular Semantics", email = "cellsemantics@gmail.com"} ] classifiers = [ - "Development Status :: 3 - Alpha", + "Development Status :: 4 - Beta", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", @@ -129,6 +129,7 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = "tests.*" disallow_untyped_decorators = false +disable_error_code = ["attr-defined"] [dependency-groups] dev = [ diff --git a/src/cellsem_llm_client/agents/agent_connection.py b/src/cellsem_llm_client/agents/agent_connection.py index 43abd34..fdc0e64 100644 --- a/src/cellsem_llm_client/agents/agent_connection.py +++ b/src/cellsem_llm_client/agents/agent_connection.py @@ -2,8 +2,10 @@ import json import logging +import warnings from abc import ABC, abstractmethod from collections.abc import Callable +from dataclasses import dataclass from datetime import datetime from typing import TYPE_CHECKING, Any, Optional @@ -18,6 +20,17 @@ ) from cellsem_llm_client.tracking.usage_metrics import UsageMetrics + +@dataclass +class QueryResult: + """Structured result returned by unified query interface.""" + + text: str | None + model: BaseModel | None = None + usage: UsageMetrics | None = None + raw_response: Any | None = None + + if TYPE_CHECKING: from cellsem_llm_client.tracking.cost_calculator import FallbackCostCalculator @@ -135,30 +148,159 @@ def __init__( self._schema_validator = SchemaValidator() self._adapter_factory = SchemaAdapterFactory() - def query(self, message: str, system_message: str | None = None) -> str: - """Send a query to the LLM using LiteLLM. + def query_unified( + self, + message: str, + system_message: str | None = None, + schema: dict[str, Any] | type[BaseModel] | str | None = None, + tools: list[dict[str, Any]] | None = None, + tool_handlers: dict[str, Callable[[dict[str, Any]], str | None]] | None = None, + max_turns: int = 5, + track_usage: bool = False, + cost_calculator: Optional["FallbackCostCalculator"] = None, + max_retries: int = 2, + ) -> QueryResult: + """Unified query interface with optional tools, schema enforcement, and tracking. + + This method consolidates the previous `query*` variants. 
Use feature flags/args + instead of separate methods: Args: - message: The user message to send - system_message: Optional system message to set context + message: User message. + system_message: Optional system prompt. + schema: JSON Schema dict, Pydantic model class, or schema name for + enforcement + validation. If provided with tools, validation runs + on the final assistant message after tool calls finish. + tools: LiteLLM tool definitions. Enables tool-call loop. + tool_handlers: Mapping of tool names to callables for execution. + max_turns: Max tool-call iterations before giving up. + track_usage: Whether to return usage metrics. + cost_calculator: Optional cost calculator for estimated cost. + max_retries: Validation retry limit when `schema` is provided. Returns: - The LLM's response as a string - """ - messages = [] + QueryResult containing final text, optional validated Pydantic model, + optional usage metrics, and the raw LiteLLM response. + Raises: + SchemaValidationException: If schema validation fails after retries. + ValueError: For missing tool handlers or argument parsing failures. + RuntimeError: If tool loop exceeds `max_turns`. + """ + messages: list[dict[str, Any]] = [] if system_message: messages.append({"role": "system", "content": system_message}) - messages.append({"role": "user", "content": message}) - response = completion( - model=self.model, - messages=messages, - max_tokens=self.max_tokens, + provider = self._get_provider_from_model(self.model) + pydantic_model: type[BaseModel] | None = None + schema_dict: dict[str, Any] | None = None + + if schema is not None: + pydantic_model = self._schema_manager.get_pydantic_model(schema) + schema_dict = self._pydantic_model_to_schema(pydantic_model) + + raw_response: Any | None = None + response_content: str | None = None + all_tool_responses: list[Any] | None = None + + if tools: + response_content, raw_response, all_tool_responses = self._run_tool_loop( + messages=messages, + tools=tools, + tool_handlers=tool_handlers or {}, + max_turns=max_turns, + ) + elif schema_dict is not None: + adapter = self._adapter_factory.get_adapter(provider, self.model) + from cellsem_llm_client.schema.adapters import AnthropicSchemaAdapter + + if isinstance(adapter, AnthropicSchemaAdapter): + raw_response = completion( + model=self.model, + messages=messages, + tools=[adapter._create_tool_definition(schema_dict)], + tool_choice={ + "type": "function", + "function": {"name": "structured_response"}, + }, + max_tokens=self.max_tokens, + ) + extracted = adapter._extract_tool_response(raw_response) + response_content = json.dumps(extracted) + else: + raw_response = adapter.apply_schema( + messages=messages, + schema_dict=schema_dict, + model=self.model, + max_tokens=self.max_tokens, + ) + response_content = str(raw_response.choices[0].message.content) + else: + raw_response = completion( + model=self.model, + messages=messages, + max_tokens=self.max_tokens, + ) + response_content = str(raw_response.choices[0].message.content) + + validated_model: BaseModel | None = None + if pydantic_model is not None and response_content is not None: + validation_result = self._schema_validator.validate_with_retry( + response_text=response_content, + target_model=pydantic_model, + max_retries=max_retries, + ) + if not validation_result.success: + raise SchemaValidationException( + f"Schema validation failed after {max_retries} retries: {validation_result.error}", + schema=str(schema), + response_text=response_content, + 
validation_errors=[str(validation_result.error)] + if validation_result.error + else [], + ) + validated_model = validation_result.model_instance + + usage_metrics: UsageMetrics | None = None + if track_usage and raw_response is not None and hasattr(raw_response, "usage"): + # When tools were used, accumulate usage from all API calls + if all_tool_responses is not None: + usage_metrics = self._accumulate_usage_metrics( + responses=all_tool_responses, + provider=provider, + cost_calculator=cost_calculator, + ) + else: + # Single API call without tools + usage_metrics = self._build_usage_metrics( + raw_response=raw_response, + provider=provider, + cost_calculator=cost_calculator, + ) + + return QueryResult( + text=response_content, + model=validated_model, + usage=usage_metrics, + raw_response=raw_response, ) - return str(response.choices[0].message.content) + def query(self, message: str, system_message: str | None = None) -> str: + """Send a query to the LLM using LiteLLM. + + Args: + message: The user message to send + system_message: Optional system message to set context + + Returns: + The LLM's response as a string + """ + result = self.query_unified( + message=message, + system_message=system_message, + ) + return result.text or "" def query_with_tools( self, @@ -186,79 +328,19 @@ def query_with_tools( RuntimeError: If the conversation does not terminate within ``max_turns`` iterations. """ - messages: list[dict[str, Any]] = [] - - if system_message: - messages.append({"role": "system", "content": system_message}) - - messages.append({"role": "user", "content": message}) - - for _turn in range(max_turns): - response = completion( - model=self.model, - messages=[*messages], - tools=tools, - max_tokens=self.max_tokens, - ) - - response_message = response.choices[0].message - tool_calls = getattr(response_message, "tool_calls", None) - - if tool_calls: - handler_map: dict[str, Callable[[dict[str, Any]], str | None]] = ( - tool_handlers or {} - ) - - assistant_message: dict[str, Any] = { - "role": "assistant", - "content": response_message.content, - "tool_calls": [], - } - messages.append(assistant_message) - - for tool_call in tool_calls: - function_call = getattr(tool_call, "function", None) - tool_name = getattr(function_call, "name", None) - tool_arguments = getattr(function_call, "arguments", {}) - - assistant_message["tool_calls"].append( - { - "id": getattr(tool_call, "id", ""), - "type": getattr(tool_call, "type", "function"), - "function": { - "name": tool_name, - "arguments": tool_arguments, - }, - } - ) - - if not tool_name or tool_name not in handler_map: - raise ValueError(f"No handler found for tool '{tool_name}'.") - - try: - parsed_args = ( - json.loads(tool_arguments) - if isinstance(tool_arguments, str) - else tool_arguments - ) - except Exception as exc: - raise ValueError( - f"Failed to parse arguments for tool '{tool_name}'." 
- ) from exc - - tool_result = handler_map[tool_name](parsed_args) - messages.append( - { - "role": "tool", - "tool_call_id": getattr(tool_call, "id", tool_name), - "content": tool_result if tool_result is not None else "", - } - ) - continue - - return str(response_message.content) - - raise RuntimeError("Max tool-call turns reached without a final response.") + warnings.warn( + "query_with_tools is deprecated; use query_unified with tools/tool_handlers.", + PendingDeprecationWarning, + stacklevel=2, + ) + result = self.query_unified( + message=message, + system_message=system_message, + tools=tools, + tool_handlers=tool_handlers, + max_turns=max_turns, + ) + return result.text or "" def query_with_tracking( self, @@ -276,78 +358,20 @@ def query_with_tracking( Returns: Tuple of (response, usage_metrics) """ - messages = [] - - if system_message: - messages.append({"role": "system", "content": system_message}) - - messages.append({"role": "user", "content": message}) - - response = completion( - model=self.model, - messages=messages, - max_tokens=self.max_tokens, + warnings.warn( + "query_with_tracking is deprecated; use query_unified with track_usage=True.", + PendingDeprecationWarning, + stacklevel=2, ) - - # Extract usage information from LiteLLM response - usage = response.usage - input_tokens = usage.prompt_tokens - output_tokens = usage.completion_tokens - - # Extract cached tokens (OpenAI specific) - cached_tokens = None - if ( - hasattr(usage, "prompt_tokens_details") - and usage.prompt_tokens_details - and hasattr(usage.prompt_tokens_details, "cached_tokens") - ): - cached_tokens = usage.prompt_tokens_details.cached_tokens - - # Extract thinking tokens (Anthropic specific) - # Note: LiteLLM may not expose thinking tokens directly yet - thinking_tokens = None - - # Determine provider from model name - provider = self._get_provider_from_model(self.model) - - # Calculate cost if calculator provided - estimated_cost_usd = None - if cost_calculator: - try: - # Create temporary metrics for cost calculation - temp_usage_metrics = UsageMetrics( - input_tokens=input_tokens, - output_tokens=output_tokens, - cached_tokens=cached_tokens, - thinking_tokens=thinking_tokens, - provider=provider, - model=self.model, - timestamp=datetime.now(), - cost_source="estimated", - ) - estimated_cost_usd = cost_calculator.calculate_cost(temp_usage_metrics) - except Exception as e: - # Log cost calculation failure but continue without cost estimation - logging.warning( - f"Cost calculation failed for {provider}/{self.model}: {e}" - ) - estimated_cost_usd = None - - # Create final usage metrics with cost - usage_metrics = UsageMetrics( - input_tokens=input_tokens, - output_tokens=output_tokens, - cached_tokens=cached_tokens, - thinking_tokens=thinking_tokens, - estimated_cost_usd=estimated_cost_usd, - provider=provider, - model=self.model, - timestamp=datetime.now(), - cost_source="estimated", + result = self.query_unified( + message=message, + system_message=system_message, + track_usage=True, + cost_calculator=cost_calculator, ) - - response_text = str(response.choices[0].message.content) - return response_text, usage_metrics + if result.usage is None: + raise RuntimeError("Expected usage metrics but none were populated") + return result.text or "", result.usage def query_with_schema( self, @@ -370,15 +394,20 @@ def query_with_schema( Raises: SchemaValidationException: If schema validation fails after max retries """ - # Call the tracking version and discard usage metrics - result, _ = 
self.query_with_schema_and_tracking( + warnings.warn( + "query_with_schema is deprecated; use query_unified with schema=...", + PendingDeprecationWarning, + stacklevel=2, + ) + result = self.query_unified( message=message, - schema=schema, system_message=system_message, - cost_calculator=None, + schema=schema, max_retries=max_retries, ) - return result + if result.model is None: + raise RuntimeError("Expected model instance but none was populated") + return result.model def query_with_schema_and_tracking( self, @@ -403,133 +432,24 @@ def query_with_schema_and_tracking( Raises: SchemaValidationException: If schema validation fails after max retries """ - # Get Pydantic model from schema input - pydantic_model = self._schema_manager.get_pydantic_model(schema) - - # Convert Pydantic model to JSON schema for adapter - schema_dict = self._pydantic_model_to_schema(pydantic_model) - - # Get appropriate adapter for this provider/model - provider = self._get_provider_from_model(self.model) - adapter = self._adapter_factory.get_adapter(provider, self.model) - - # Prepare messages - messages = [] - if system_message: - messages.append({"role": "system", "content": system_message}) - messages.append({"role": "user", "content": message}) - - # For Anthropic, we need to handle differently to preserve usage info - from cellsem_llm_client.schema.adapters import AnthropicSchemaAdapter - - if isinstance(adapter, AnthropicSchemaAdapter): - # Call completion directly to get both response and usage info - raw_response = completion( - model=self.model, - messages=messages, - tools=[adapter._create_tool_definition(schema_dict)], - tool_choice={ - "type": "function", - "function": {"name": "structured_response"}, - }, - max_tokens=self.max_tokens, - ) - - # Extract structured response using adapter - response = adapter._extract_tool_response(raw_response) - response_for_usage = raw_response - else: - # Apply schema enforcement using adapter - response = adapter.apply_schema( - messages=messages, - schema_dict=schema_dict, - model=self.model, - max_tokens=self.max_tokens, - ) - response_for_usage = response - - # Extract usage information from response - usage = response_for_usage.usage - input_tokens = usage.prompt_tokens - output_tokens = usage.completion_tokens - - # Extract cached tokens (OpenAI specific) - cached_tokens = None - if ( - hasattr(usage, "prompt_tokens_details") - and usage.prompt_tokens_details - and hasattr(usage.prompt_tokens_details, "cached_tokens") - ): - cached_tokens = usage.prompt_tokens_details.cached_tokens - - # Extract thinking tokens (future enhancement) - thinking_tokens = None - - # Calculate cost if calculator provided - estimated_cost_usd = None - if cost_calculator: - try: - temp_usage_metrics = UsageMetrics( - input_tokens=input_tokens, - output_tokens=output_tokens, - cached_tokens=cached_tokens, - thinking_tokens=thinking_tokens, - provider=provider, - model=self.model, - timestamp=datetime.now(), - cost_source="estimated", - ) - estimated_cost_usd = cost_calculator.calculate_cost(temp_usage_metrics) - except Exception as e: - # Log cost calculation failure but continue without cost estimation - logging.warning( - f"Cost calculation failed for {provider}/{self.model}: {e}" - ) - estimated_cost_usd = None - - # Create usage metrics - usage_metrics = UsageMetrics( - input_tokens=input_tokens, - output_tokens=output_tokens, - cached_tokens=cached_tokens, - thinking_tokens=thinking_tokens, - estimated_cost_usd=estimated_cost_usd, - provider=provider, - model=self.model, - 
timestamp=datetime.now(), - cost_source="estimated", + warnings.warn( + "query_with_schema_and_tracking is deprecated; use query_unified with schema=... and track_usage=True.", + PendingDeprecationWarning, + stacklevel=2, ) - - # Extract response content based on adapter type - if isinstance(adapter, AnthropicSchemaAdapter): - # Anthropic: response is already the extracted dict - response_content = json.dumps(response) - elif adapter.supports_native_schema(): - # For other native schema adapters (OpenAI), get content from message - response_content = str(response.choices[0].message.content) # type: ignore - else: - # For fallback adapters, content is in message - response_content = str(response.choices[0].message.content) # type: ignore - - # Validate response against schema with retry logic - validation_result = self._schema_validator.validate_with_retry( - response_text=response_content, - target_model=pydantic_model, + result = self.query_unified( + message=message, + schema=schema, + system_message=system_message, + track_usage=True, + cost_calculator=cost_calculator, max_retries=max_retries, ) - - if not validation_result.success: - raise SchemaValidationException( - f"Schema validation failed after {max_retries} retries: {validation_result.error}", - schema=str(schema), - response_text=response_content, - validation_errors=[str(validation_result.error)] - if validation_result.error - else [], + if result.model is None or result.usage is None: + raise RuntimeError( + "Expected model instance and usage metrics but one or both were not populated" ) - - assert validation_result.model_instance is not None - return validation_result.model_instance, usage_metrics + return result.model, result.usage def _pydantic_model_to_schema(self, model_class: type[BaseModel]) -> dict[str, Any]: """Convert Pydantic model to JSON schema dict. @@ -648,6 +568,218 @@ def _ensure_no_additional_properties(self, schema_dict: dict[str, Any]) -> None: if isinstance(item, dict): self._ensure_no_additional_properties(item) + def _run_tool_loop( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]], + tool_handlers: dict[str, Callable[[dict[str, Any]], str | None]], + max_turns: int, + ) -> tuple[str, Any, list[Any]]: + """Execute tool calls until completion. + + Returns: + A tuple of (final_content, final_response, all_responses) where + all_responses contains every API response from all turns for usage tracking. 
+ """ + working_messages = list(messages) + all_responses: list[Any] = [] + + for _turn in range(max_turns): + response = completion( + model=self.model, + messages=[*working_messages], + tools=tools, + max_tokens=self.max_tokens, + ) + all_responses.append(response) + + response_message = response.choices[0].message + tool_calls = getattr(response_message, "tool_calls", None) + + if tool_calls: + assistant_message: dict[str, Any] = { + "role": "assistant", + "content": response_message.content, + "tool_calls": [], + } + working_messages.append(assistant_message) + + for tool_call in tool_calls: + function_call = getattr(tool_call, "function", None) + tool_name = getattr(function_call, "name", None) + tool_arguments = getattr(function_call, "arguments", {}) + + assistant_message["tool_calls"].append( + { + "id": getattr(tool_call, "id", ""), + "type": getattr(tool_call, "type", "function"), + "function": { + "name": tool_name, + "arguments": tool_arguments, + }, + } + ) + + if not tool_name or tool_name not in tool_handlers: + raise ValueError(f"No handler found for tool '{tool_name}'.") + + try: + parsed_args = ( + json.loads(tool_arguments) + if isinstance(tool_arguments, str) + else tool_arguments + ) + except Exception as exc: + raise ValueError( + f"Failed to parse arguments for tool '{tool_name}'." + ) from exc + + tool_result = tool_handlers[tool_name](parsed_args) + working_messages.append( + { + "role": "tool", + "tool_call_id": getattr(tool_call, "id", tool_name), + "content": tool_result if tool_result is not None else "", + } + ) + continue + + return str(response_message.content), response, all_responses + + raise RuntimeError("Max tool-call turns reached without a final response.") + + def _build_usage_metrics( + self, + raw_response: Any, + provider: str, + cost_calculator: Optional["FallbackCostCalculator"] = None, + ) -> UsageMetrics: + """Construct UsageMetrics from a LiteLLM response.""" + usage = raw_response.usage + input_tokens = usage.prompt_tokens + output_tokens = usage.completion_tokens + + cached_tokens = None + if ( + hasattr(usage, "prompt_tokens_details") + and usage.prompt_tokens_details + and hasattr(usage.prompt_tokens_details, "cached_tokens") + ): + cached_tokens = usage.prompt_tokens_details.cached_tokens + + thinking_tokens = None + + estimated_cost_usd = None + if cost_calculator: + try: + temp_usage_metrics = UsageMetrics( + input_tokens=input_tokens, + output_tokens=output_tokens, + cached_tokens=cached_tokens, + thinking_tokens=thinking_tokens, + provider=provider, + model=self.model, + timestamp=datetime.now(), + cost_source="estimated", + ) + estimated_cost_usd = cost_calculator.calculate_cost(temp_usage_metrics) + except Exception as e: + logging.warning( + f"Cost calculation failed for {provider}/{self.model}: {e}" + ) + + return UsageMetrics( + input_tokens=input_tokens, + output_tokens=output_tokens, + cached_tokens=cached_tokens, + thinking_tokens=thinking_tokens, + estimated_cost_usd=estimated_cost_usd, + provider=provider, + model=self.model, + timestamp=datetime.now(), + cost_source="estimated", + ) + + def _accumulate_usage_metrics( + self, + responses: list[Any], + provider: str, + cost_calculator: Optional["FallbackCostCalculator"] = None, + ) -> UsageMetrics: + """Accumulate usage metrics from multiple API responses. + + When tools are used, multiple API calls are made across iterations. + This method sums up the token usage from all calls to provide accurate + cumulative metrics. 
+ + Args: + responses: List of LiteLLM response objects from all API calls + provider: The LLM provider name + cost_calculator: Optional cost calculator for estimating costs + + Returns: + Accumulated UsageMetrics with total token counts and costs + """ + total_input_tokens = 0 + total_output_tokens = 0 + total_cached_tokens = 0 + total_thinking_tokens = 0 + has_cached = False + has_thinking = False + + for response in responses: + if not hasattr(response, "usage"): + continue + + usage = response.usage + total_input_tokens += usage.prompt_tokens + total_output_tokens += usage.completion_tokens + + # Accumulate cached tokens if present + if ( + hasattr(usage, "prompt_tokens_details") + and usage.prompt_tokens_details + and hasattr(usage.prompt_tokens_details, "cached_tokens") + and usage.prompt_tokens_details.cached_tokens is not None + ): + total_cached_tokens += usage.prompt_tokens_details.cached_tokens + has_cached = True + + cached_tokens = total_cached_tokens if has_cached else None + thinking_tokens = total_thinking_tokens if has_thinking else None + + # Calculate cost based on accumulated tokens + estimated_cost_usd = None + if cost_calculator: + try: + temp_usage_metrics = UsageMetrics( + input_tokens=total_input_tokens, + output_tokens=total_output_tokens, + cached_tokens=cached_tokens, + thinking_tokens=thinking_tokens, + provider=provider, + model=self.model, + timestamp=datetime.now(), + cost_source="estimated", + ) + estimated_cost_usd = cost_calculator.calculate_cost(temp_usage_metrics) + except Exception as e: + logging.warning( + f"Cost calculation failed for {provider}/{self.model}: {e}" + ) + + return UsageMetrics( + input_tokens=total_input_tokens, + output_tokens=total_output_tokens, + cached_tokens=cached_tokens, + thinking_tokens=thinking_tokens, + estimated_cost_usd=estimated_cost_usd, + provider=provider, + model=self.model, + timestamp=datetime.now(), + cost_source="estimated", + ) + def _get_provider_from_model(self, model: str) -> str: """Determine provider from model name. 
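
Reviewer note (not part of the patch): the tool-call loop above is reached through the same unified entry point. A minimal sketch of how a caller might wire tools and handlers into query_unified, assuming the weather tool, its handler, the model name, and the API key below are illustrative placeholders rather than anything defined in this repository:

    from typing import Any

    from cellsem_llm_client.agents.agent_connection import LiteLLMAgent

    def get_weather(args: dict[str, Any]) -> str:
        # Hypothetical handler; receives the parsed JSON arguments of the tool call.
        return f"Sunny in {args.get('city', 'unknown')}"

    # LiteLLM/OpenAI-style function-tool definition, mirroring the format used in the tests.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ]

    agent = LiteLLMAgent(model="gpt-4", api_key="sk-...")  # placeholder credentials
    result = agent.query_unified(
        message="What is the weather in Berlin?",
        tools=tools,
        tool_handlers={"get_weather": get_weather},
        track_usage=True,  # usage is accumulated across every turn of the tool loop
    )
    print(result.text)
    if result.usage is not None:
        print(result.usage.input_tokens, result.usage.output_tokens)

Because every per-turn response is collected by _run_tool_loop, the usage figures printed here reflect the sum over all API calls, not just the final one.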
diff --git a/tests/unit/test_agent_connection.py b/tests/unit/test_agent_connection.py index 7a7b33b..67cf56d 100644 --- a/tests/unit/test_agent_connection.py +++ b/tests/unit/test_agent_connection.py @@ -1,9 +1,12 @@ """Unit tests for agent connection classes.""" +import json +import warnings from typing import Any from unittest.mock import Mock, patch import pytest +from pydantic import BaseModel # Import will fail initially - that's expected for TDD from cellsem_llm_client.agents.agent_connection import ( @@ -405,10 +408,222 @@ def test_litellm_agent_query_with_tools_missing_handler_raises( with pytest.raises(ValueError, match="missing_tool"): agent.query_with_tools( message="Trigger a tool call", - tools=[], + tools=[{"type": "function", "function": {"name": "missing_tool"}}], tool_handlers={}, ) + @pytest.mark.unit + @patch("cellsem_llm_client.agents.agent_connection.completion") + def test_query_unified_tools_and_schema(self, mock_completion: Any) -> None: + """Tool loop should run and then validate final content against schema.""" + + class SampleModel(BaseModel): + answer: str + + tool_call = Mock() + tool_call.id = "call_1" + tool_call.type = "function" + tool_call.function = Mock() + tool_call.function.name = "fetch_data" + tool_call.function.arguments = "{}" + + first_response = Mock() + first_response.choices = [Mock()] + first_response.choices[0].message.content = None + first_response.choices[0].message.tool_calls = [tool_call] + first_response.usage = Mock() + first_response.usage.prompt_tokens = 1 + first_response.usage.completion_tokens = 1 + first_response.usage.prompt_tokens_details = None + + final_payload = {"answer": "done"} + final_response = Mock() + final_response.choices = [Mock()] + final_response.choices[0].message.content = json.dumps(final_payload) + final_response.choices[0].message.tool_calls = [] + final_response.usage = Mock() + final_response.usage.prompt_tokens = 1 + final_response.usage.completion_tokens = 1 + final_response.usage.prompt_tokens_details = None + + mock_completion.side_effect = [first_response, final_response] + + agent = LiteLLMAgent(model="gpt-3.5-turbo", api_key="test-key") + + executed = {} + + def fetch_data(_: dict[str, Any]) -> str: + executed["ran"] = True + return "ok" + + result = agent.query_unified( + message="Use the tool then answer in JSON", + tools=[ + { + "type": "function", + "function": { + "name": "fetch_data", + "parameters": {"type": "object"}, + "description": "dummy", + }, + } + ], + tool_handlers={"fetch_data": fetch_data}, + schema=SampleModel, + ) + + assert executed.get("ran") is True + assert result.model is not None + assert result.model.answer == "done" + + @pytest.mark.unit + @patch("cellsem_llm_client.agents.agent_connection.completion") + def test_query_unified_with_tools_accumulates_usage( + self, mock_completion: Any + ) -> None: + """Test that usage metrics accumulate across multiple tool call iterations.""" + tools = [ + { + "type": "function", + "function": { + "name": "get_data", + "description": "Get data", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, + } + ] + + # First API call - tool call requested + tool_call = Mock() + tool_call.id = "call_1" + tool_call.type = "function" + tool_call.function = Mock() + tool_call.function.name = "get_data" + tool_call.function.arguments = '{"query": "test"}' + + first_response = Mock() + first_response.choices = [Mock()] + first_response.choices[0].message.content = None + 
first_response.choices[0].message.tool_calls = [tool_call] + first_response.usage = Mock() + first_response.usage.prompt_tokens = 100 + first_response.usage.completion_tokens = 20 + first_response.usage.prompt_tokens_details = None + + # Second API call - final response + final_response = Mock() + final_response.choices = [Mock()] + final_response.choices[0].message.content = "Here is the result." + final_response.choices[0].message.tool_calls = [] + final_response.usage = Mock() + final_response.usage.prompt_tokens = 150 + final_response.usage.completion_tokens = 30 + final_response.usage.prompt_tokens_details = None + + mock_completion.side_effect = [first_response, final_response] + + def get_data(args: dict[str, Any]) -> str: + return "data result" + + agent = LiteLLMAgent(model="gpt-4", api_key="test-key") + result = agent.query_unified( + message="Get me data", + tools=tools, + tool_handlers={"get_data": get_data}, + track_usage=True, + ) + + # Verify response + assert result.text == "Here is the result." + assert result.usage is not None + + # Verify cumulative usage: 100 + 150 = 250 input, 20 + 30 = 50 output + assert result.usage.input_tokens == 250 + assert result.usage.output_tokens == 50 + assert result.usage.total_tokens == 300 + assert result.usage.provider == "openai" + assert result.usage.model == "gpt-4" + + @pytest.mark.unit + @patch("cellsem_llm_client.agents.agent_connection.completion") + def test_query_unified_basic(self, mock_completion: Any) -> None: + """Test unified query returns text.""" + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "Unified response" + mock_response.usage = Mock() + mock_response.usage.prompt_tokens = 1 + mock_response.usage.completion_tokens = 1 + mock_response.usage.prompt_tokens_details = None + mock_completion.return_value = mock_response + + agent = LiteLLMAgent(model="gpt-3.5-turbo", api_key="test-key") + result = agent.query_unified(message="Hello") + + assert result.text == "Unified response" + assert result.model is None + + @pytest.mark.unit + @patch("cellsem_llm_client.schema.adapters.litellm.completion") + @patch("cellsem_llm_client.agents.agent_connection.completion") + def test_query_unified_with_schema( + self, mock_agent_completion: Any, mock_adapter_completion: Any + ) -> None: + """Test unified query enforces schema.""" + + class SampleModel(BaseModel): + term: str + iri: str + + payload = {"term": "cell", "iri": "http://example.com"} + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = json.dumps(payload) + mock_response.usage = Mock() + mock_response.usage.prompt_tokens = 1 + mock_response.usage.completion_tokens = 1 + mock_response.usage.prompt_tokens_details = None + mock_agent_completion.return_value = mock_response + mock_adapter_completion.return_value = mock_response + + agent = LiteLLMAgent(model="gpt-4o", api_key="test-key") + result = agent.query_unified(message="Return term and iri", schema=SampleModel) + + assert result.model is not None + assert result.model.term == "cell" + + @pytest.mark.unit + @patch("cellsem_llm_client.schema.adapters.litellm.completion") + @patch("cellsem_llm_client.agents.agent_connection.completion") + def test_deprecated_wrappers_warn( + self, mock_agent_completion: Any, mock_adapter_completion: Any + ) -> None: + """Deprecated methods should emit PendingDeprecationWarning.""" + + class SampleModel(BaseModel): + reply: str + + payload = {"reply": "ok"} + mock_response = Mock() + 
mock_response.choices = [Mock()] + mock_response.choices[0].message.content = json.dumps(payload) + mock_response.usage = Mock() + mock_response.usage.prompt_tokens = 1 + mock_response.usage.completion_tokens = 1 + mock_response.usage.prompt_tokens_details = None + mock_agent_completion.return_value = mock_response + mock_adapter_completion.return_value = mock_response + + agent = LiteLLMAgent(model="gpt-4o", api_key="test-key") + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always", PendingDeprecationWarning) + agent.query_with_schema("test", SampleModel) + assert any(issubclass(wi.category, PendingDeprecationWarning) for wi in w) + class TestOpenAIAgent: """Test the OpenAI-specific agent implementation."""
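
Reviewer note (not part of the patch): a sketch of the schema-plus-tracking path that replaces query_with_schema_and_tracking, assuming valid provider credentials are configured and using an illustrative Pydantic model and model name:

    from pydantic import BaseModel

    from cellsem_llm_client.agents.agent_connection import LiteLLMAgent, QueryResult

    class OntologyTerm(BaseModel):
        # Illustrative schema; any Pydantic model class is accepted as `schema`.
        term: str
        iri: str

    agent = LiteLLMAgent(model="gpt-4o", api_key="sk-...")  # placeholder credentials

    # One call selects schema enforcement, validation with retries, and usage
    # accounting via arguments instead of separate query_* methods.
    result: QueryResult = agent.query_unified(
        message="Return the ontology term and IRI for 'neuron'.",
        schema=OntologyTerm,
        track_usage=True,
    )

    print(result.model)  # validated OntologyTerm instance, or None if no schema was given
    print(result.text)   # raw text of the final assistant message
    if result.usage is not None:
        print(result.usage.input_tokens, result.usage.output_tokens)

The deprecated wrappers retained in the patch (query_with_schema, query_with_tracking, query_with_schema_and_tracking, query_with_tools) delegate to this same call and emit PendingDeprecationWarning, so existing callers keep working while new code targets query_unified directly.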