diff --git a/openhands-sdk/openhands/sdk/conversation/__init__.py b/openhands-sdk/openhands/sdk/conversation/__init__.py
index d89d3040d..82a518483 100644
--- a/openhands-sdk/openhands/sdk/conversation/__init__.py
+++ b/openhands-sdk/openhands/sdk/conversation/__init__.py
@@ -8,6 +8,12 @@
 from openhands.sdk.conversation.secret_registry import SecretRegistry
 from openhands.sdk.conversation.state import ConversationState
 from openhands.sdk.conversation.stuck_detector import StuckDetector
+from openhands.sdk.conversation.token_display import (
+    TokenDisplay,
+    TokenDisplayMode,
+    compute_token_display,
+    get_default_mode_from_env,
+)
 from openhands.sdk.conversation.types import ConversationCallbackType
 from openhands.sdk.conversation.visualizer import ConversationVisualizer
@@ -25,4 +31,9 @@
     "RemoteConversation",
     "EventsListBase",
     "get_agent_final_response",
+    # Token display utilities (public SDK API)
+    "TokenDisplay",
+    "TokenDisplayMode",
+    "compute_token_display",
+    "get_default_mode_from_env",
 ]
diff --git a/openhands-sdk/openhands/sdk/conversation/token_display.py b/openhands-sdk/openhands/sdk/conversation/token_display.py
new file mode 100644
index 000000000..96a733a4a
--- /dev/null
+++ b/openhands-sdk/openhands/sdk/conversation/token_display.py
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+from enum import Enum
+
+from openhands.sdk.conversation.conversation_stats import ConversationStats
+from openhands.sdk.llm.utils.metrics import Metrics, TokenUsage
+
+
+class TokenDisplayMode(str, Enum):
+    PER_CONTEXT = "per_context"  # show metrics for the latest request only
+    ACCUMULATED = "accumulated"  # show accumulated tokens across all requests
+
+    @classmethod
+    def from_str(cls, value: str | None) -> TokenDisplayMode:
+        if not value:
+            return cls.PER_CONTEXT
+        v = value.strip().lower().replace("-", "_")
+        if v in {"per_context", "per_request", "latest", "current"}:
+            return cls.PER_CONTEXT
+        if v in {"accumulated", "total", "sum"}:
+            return cls.ACCUMULATED
+        # default to per-context to match current visual default and tests
+        return cls.PER_CONTEXT
+
+
+@dataclass(frozen=True)
+class TokenDisplay:
+    # Raw counts (not abbreviated)
+    input_tokens: int
+    output_tokens: int
+    cache_read_tokens: int
+    reasoning_tokens: int
+    context_window: int
+    # Rate in [0.0, 1.0], or None if undefined
+    cache_hit_rate: float | None
+    # Total accumulated cost in USD
+    total_cost: float
+    # Optional delta of input tokens compared to the previous request
+    since_last_input_tokens: int | None = None
+
+
+def _get_combined_metrics(stats: ConversationStats | None) -> Metrics | None:
+    if not stats:
+        return None
+    try:
+        return stats.get_combined_metrics()
+    except Exception:
+        return None
+
+
+def compute_token_display(
+    stats: ConversationStats | None,
+    mode: TokenDisplayMode = TokenDisplayMode.PER_CONTEXT,
+    include_since_last: bool = False,
+) -> TokenDisplay | None:
+    """Compute token display values from conversation stats.
+
+    Args:
+        stats: ConversationStats to read metrics from.
+        mode: Whether to show per-context (latest request) or accumulated values.
+        include_since_last: If True, include the delta of input tokens compared
+            to the previous request (when available).
+
+    Returns:
+        TokenDisplay with raw numeric values, or None if metrics are unavailable.
+    """
+    combined = _get_combined_metrics(stats)
+    if not combined:
+        return None
+
+    # No token usage recorded yet
+    if not combined.token_usages:
+        return None
+
+    total_cost = combined.accumulated_cost or 0.0
+
+    if mode == TokenDisplayMode.ACCUMULATED:
+        usage: TokenUsage | None = combined.accumulated_token_usage
+        if usage is None:
+            return None
+        input_tokens = usage.prompt_tokens or 0
+        output_tokens = usage.completion_tokens or 0
+        cache_read = usage.cache_read_tokens or 0
+        reasoning_tokens = usage.reasoning_tokens or 0
+        context_window = usage.context_window or 0
+        cache_hit_rate = (cache_read / input_tokens) if input_tokens > 0 else None
+        since_last: int | None = None
+    else:  # PER_CONTEXT
+        usage = combined.token_usages[-1]
+        input_tokens = usage.prompt_tokens or 0
+        output_tokens = usage.completion_tokens or 0
+        cache_read = usage.cache_read_tokens or 0
+        reasoning_tokens = usage.reasoning_tokens or 0
+        context_window = usage.context_window or 0
+        cache_hit_rate = (cache_read / input_tokens) if input_tokens > 0 else None
+        since_last = None
+        if include_since_last and len(combined.token_usages) >= 2:
+            prev = combined.token_usages[-2]
+            since_last = max(0, (usage.prompt_tokens or 0) - (prev.prompt_tokens or 0))
+
+    return TokenDisplay(
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        cache_read_tokens=cache_read,
+        reasoning_tokens=reasoning_tokens,
+        context_window=context_window,
+        cache_hit_rate=cache_hit_rate,
+        total_cost=total_cost,
+        since_last_input_tokens=since_last,
+    )
+
+
+def get_default_mode_from_env() -> TokenDisplayMode:
+    """Resolve the default token display mode from an env var.
+
+    Env var: OH_TOKENS_VIEW_MODE
+      - "per_context" (default)
+      - "accumulated"
+      - also accepts aliases: per_request/latest/current and total/sum
+    """
+    value = os.environ.get("OH_TOKENS_VIEW_MODE")
+    return TokenDisplayMode.from_str(value)
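
Reviewer note: a minimal usage sketch of the new public API, not part of the patch. It builds a `ConversationStats` by hand the same way the tests below do, so the model name, service key, and token figures are illustrative only; in real use the stats object would come from a live conversation.

```python
from openhands.sdk.conversation import TokenDisplayMode, compute_token_display
from openhands.sdk.conversation.conversation_stats import ConversationStats
from openhands.sdk.llm.utils.metrics import Metrics

stats = ConversationStats()
metrics = Metrics(model_name="example-model")  # illustrative name
metrics.add_cost(0.02)
metrics.add_token_usage(
    prompt_tokens=1000,
    completion_tokens=150,
    cache_read_tokens=400,
    cache_write_tokens=0,
    reasoning_tokens=0,
    context_window=8000,
    response_id="r1",
)
stats.service_to_metrics["example-service"] = metrics  # key is arbitrary

display = compute_token_display(stats=stats, mode=TokenDisplayMode.PER_CONTEXT)
assert display is not None
assert display.input_tokens == 1000
assert display.cache_hit_rate == 0.4  # cache_read / input = 400 / 1000
assert display.total_cost == 0.02
```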
diff --git a/openhands-sdk/openhands/sdk/conversation/visualizer.py b/openhands-sdk/openhands/sdk/conversation/visualizer.py
index df12c129c..8150d8025 100644
--- a/openhands-sdk/openhands/sdk/conversation/visualizer.py
+++ b/openhands-sdk/openhands/sdk/conversation/visualizer.py
@@ -1,3 +1,4 @@
+import os
 import re
 from typing import TYPE_CHECKING
@@ -5,6 +6,10 @@
 from rich.panel import Panel
 from rich.text import Text
+from openhands.sdk.conversation.token_display import (
+    compute_token_display,
+    get_default_mode_from_env,
+)
 from openhands.sdk.event import (
     ActionEvent,
     AgentErrorEvent,
@@ -269,19 +274,27 @@ def _create_event_panel(self, event: Event) -> Panel | None:
         )
 
     def _format_metrics_subtitle(self) -> str | None:
-        """Format LLM metrics as a visually appealing subtitle string with icons,
-        colors, and k/m abbreviations using conversation stats."""
+        """Format the LLM metrics subtitle based on conversation stats.
+
+        Uses the TokenDisplay utility to compute values and supports an
+        env-configured mode (per-context vs accumulated) plus an optional
+        since-last delta.
+        """
+        display_mode = get_default_mode_from_env()
+        include_since_last = os.environ.get(
+            "OH_TOKENS_VIEW_DELTA", "false"
+        ).lower() in {"1", "true", "yes"}
+
         if not self._conversation_stats:
             return None
-        combined_metrics = self._conversation_stats.get_combined_metrics()
-        if not combined_metrics or not combined_metrics.accumulated_token_usage:
+        data = compute_token_display(
+            stats=self._conversation_stats,
+            mode=display_mode,
+            include_since_last=include_since_last,
+        )
+        if not data:
             return None
-        usage = combined_metrics.accumulated_token_usage
-        cost = combined_metrics.accumulated_cost or 0.0
-
-        # helper: 1234 -> "1.2K", 1200000 -> "1.2M"
         def abbr(n: int | float) -> str:
             n = int(n or 0)
             if n >= 1_000_000_000:
@@ -294,26 +307,25 @@ def abbr(n: int | float) -> str:
                 return str(n)
             return s.replace(".0", "")
 
-        input_tokens = abbr(usage.prompt_tokens or 0)
-        output_tokens = abbr(usage.completion_tokens or 0)
-
-        # Cache hit rate (prompt + cache)
-        prompt = usage.prompt_tokens or 0
-        cache_read = usage.cache_read_tokens or 0
-        cache_rate = f"{(cache_read / prompt * 100):.2f}%" if prompt > 0 else "N/A"
-        reasoning_tokens = usage.reasoning_tokens or 0
-
-        # Cost
-        cost_str = f"{cost:.4f}" if cost > 0 else "0.00"
+        cache_rate = (
+            f"{(data.cache_hit_rate * 100):.2f}%"
+            if data.cache_hit_rate is not None
+            else "N/A"
+        )
+        cost_str = f"{data.total_cost:.4f}"
 
-        # Build with fixed color scheme
         parts: list[str] = []
-        parts.append(f"[cyan]↑ input {input_tokens}[/cyan]")
+        input_part = f"[cyan]↑ input {abbr(data.input_tokens)}"
+        if include_since_last and data.since_last_input_tokens is not None:
+            input_part += f" (+{abbr(data.since_last_input_tokens)})"
+        input_part += "[/cyan]"
+        parts.append(input_part)
+
         parts.append(f"[magenta]cache hit {cache_rate}[/magenta]")
-        if reasoning_tokens > 0:
-            parts.append(f"[yellow] reasoning {abbr(reasoning_tokens)}[/yellow]")
-        parts.append(f"[blue]↓ output {output_tokens}[/blue]")
-        parts.append(f"[green]$ {cost_str}[/green]")
+        if data.reasoning_tokens > 0:
+            parts.append(f"[yellow] reasoning {abbr(data.reasoning_tokens)}[/yellow]")
+        parts.append(f"[blue]↓ output {abbr(data.output_tokens)}[/blue]")
+        parts.append(f"[green]$ {cost_str} (total)[/green]")
 
         return "Tokens: " + " • ".join(parts)
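
Reviewer note: the second hunk above elides the middle of `abbr()`. A sketch consistent with the visible branches and the removed `# helper: 1234 -> "1.2K", 1200000 -> "1.2M"` comment follows; the actual body in `visualizer.py` may differ in detail (the "B" branch in particular is inferred).

```python
def abbr(n: int | float) -> str:
    # Sketch only: inferred from the hunk's visible first/last lines.
    n = int(n or 0)
    if n >= 1_000_000_000:
        s = f"{n / 1_000_000_000:.1f}B"
    elif n >= 1_000_000:
        s = f"{n / 1_000_000:.1f}M"
    elif n >= 1_000:
        s = f"{n / 1_000:.1f}K"
    else:
        return str(n)
    return s.replace(".0", "")


assert abbr(1234) == "1.2K"
assert abbr(1_200_000) == "1.2M"
assert abbr(2_000) == "2K"  # trailing ".0" is stripped
```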
diff --git a/tests/sdk/conversation/test_token_display.py b/tests/sdk/conversation/test_token_display.py
new file mode 100644
index 000000000..6357242d7
--- /dev/null
+++ b/tests/sdk/conversation/test_token_display.py
@@ -0,0 +1,104 @@
+import pytest
+
+from openhands.sdk.conversation import (
+    TokenDisplayMode,
+    compute_token_display,
+)
+from openhands.sdk.conversation.conversation_stats import ConversationStats
+from openhands.sdk.conversation.visualizer import ConversationVisualizer
+from openhands.sdk.llm.utils.metrics import Metrics
+
+
+@pytest.fixture(autouse=True)
+def _clear_env(monkeypatch):
+    # Ensure env vars do not leak between tests
+    monkeypatch.delenv("OH_TOKENS_VIEW_MODE", raising=False)
+    monkeypatch.delenv("OH_TOKENS_VIEW_DELTA", raising=False)
+
+
+def _make_stats_with_two_requests():
+    stats = ConversationStats()
+    m = Metrics(model_name="test-model")
+    # First call
+    m.add_cost(0.1)
+    m.add_token_usage(
+        prompt_tokens=100,
+        completion_tokens=25,
+        cache_read_tokens=10,
+        cache_write_tokens=0,
+        reasoning_tokens=5,
+        context_window=8000,
+        response_id="first",
+    )
+    # Second call
+    m.add_cost(0.05)
+    m.add_token_usage(
+        prompt_tokens=220,
+        completion_tokens=80,
+        cache_read_tokens=44,
+        cache_write_tokens=0,
+        reasoning_tokens=0,
+        context_window=8000,
+        response_id="second",
+    )
+    stats.usage_to_metrics["usage-1"] = m
+    return stats
+
+
+def test_compute_token_display_per_context_with_delta():
+    stats = _make_stats_with_two_requests()
+
+    td = compute_token_display(
+        stats=stats, mode=TokenDisplayMode.PER_CONTEXT, include_since_last=True
+    )
+    assert td is not None
+
+    # Latest request values
+    assert td.input_tokens == 220
+    assert td.output_tokens == 80
+    assert td.reasoning_tokens == 0
+    assert td.cache_read_tokens == 44
+    assert td.context_window == 8000
+    assert td.total_cost == pytest.approx(0.15)
+    assert td.cache_hit_rate == pytest.approx(44 / 220)
+
+    # Delta vs previous
+    assert td.since_last_input_tokens == 120  # 220 - 100
+
+
+def test_compute_token_display_accumulated():
+    stats = _make_stats_with_two_requests()
+
+    td = compute_token_display(stats=stats, mode=TokenDisplayMode.ACCUMULATED)
+    assert td is not None
+
+    # Accumulated values: sums of prompt/completion/cache_read; max context_window
+    assert td.input_tokens == 100 + 220
+    assert td.output_tokens == 25 + 80
+    assert td.cache_read_tokens == 10 + 44
+    assert td.reasoning_tokens == 5 + 0
+    assert td.context_window == 8000
+    assert td.total_cost == pytest.approx(0.15)
+    assert td.cache_hit_rate == pytest.approx((10 + 44) / (100 + 220))
+
+    # No since-last in accumulated mode
+    assert td.since_last_input_tokens is None
+
+
+def test_visualizer_env_vars_toggle_delta(monkeypatch):
+    stats = _make_stats_with_two_requests()
+
+    # Force per-context and delta
+    monkeypatch.setenv("OH_TOKENS_VIEW_MODE", "per_context")
+    monkeypatch.setenv("OH_TOKENS_VIEW_DELTA", "true")
+
+    viz = ConversationVisualizer(conversation_stats=stats)
+    subtitle = viz._format_metrics_subtitle()
+    assert subtitle is not None
+    assert "(+" in subtitle  # shows since-last delta
+
+    # Force accumulated mode: should hide delta even if env says true
+    monkeypatch.setenv("OH_TOKENS_VIEW_MODE", "accumulated")
+    subtitle2 = viz._format_metrics_subtitle()
+    assert subtitle2 is not None
+    assert "(+" not in subtitle2
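
Reviewer note: to make the env contract exercised above concrete, a small standalone sketch of how `OH_TOKENS_VIEW_MODE` values resolve (this just drives `TokenDisplayMode.from_str` through `get_default_mode_from_env`):

```python
import os

from openhands.sdk.conversation import TokenDisplayMode, get_default_mode_from_env

# Unset (or unrecognized) values fall back to PER_CONTEXT.
os.environ.pop("OH_TOKENS_VIEW_MODE", None)
assert get_default_mode_from_env() is TokenDisplayMode.PER_CONTEXT

# Aliases are normalized first: lowercased, hyphens mapped to underscores.
os.environ["OH_TOKENS_VIEW_MODE"] = "Per-Request"
assert get_default_mode_from_env() is TokenDisplayMode.PER_CONTEXT

os.environ["OH_TOKENS_VIEW_MODE"] = "total"
assert get_default_mode_from_env() is TokenDisplayMode.ACCUMULATED
```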
diff --git a/tests/sdk/conversation/test_visualizer.py b/tests/sdk/conversation/test_visualizer.py
index 7b479c2c0..5e1a1c9e0 100644
--- a/tests/sdk/conversation/test_visualizer.py
+++ b/tests/sdk/conversation/test_visualizer.py
@@ -337,7 +337,47 @@ def test_metrics_formatting():
     assert "500" in subtitle  # Output tokens
     assert "20.00%" in subtitle  # Cache hit rate
     assert "200" in subtitle  # Reasoning tokens
-    assert "0.0234" in subtitle  # Cost
+    assert "$ 0.0234 (total)" in subtitle
+
+
+def test_metrics_formatting_uses_latest_request():
+    """Tokens should reflect the latest request while cost stays cumulative."""
+    from openhands.sdk.conversation.conversation_stats import ConversationStats
+    from openhands.sdk.llm.utils.metrics import Metrics
+
+    conversation_stats = ConversationStats()
+    metrics = Metrics(model_name="test-model")
+    metrics.add_cost(0.1)
+    metrics.add_token_usage(
+        prompt_tokens=120,
+        completion_tokens=40,
+        cache_read_tokens=12,
+        cache_write_tokens=0,
+        reasoning_tokens=5,
+        context_window=8000,
+        response_id="first",
+    )
+    metrics.add_cost(0.05)
+    metrics.add_token_usage(
+        prompt_tokens=200,
+        completion_tokens=75,
+        cache_read_tokens=25,
+        cache_write_tokens=0,
+        reasoning_tokens=0,
+        context_window=8000,
+        response_id="second",
+    )
+    conversation_stats.service_to_metrics["test_service"] = metrics
+
+    visualizer = ConversationVisualizer(conversation_stats=conversation_stats)
+
+    subtitle = visualizer._format_metrics_subtitle()
+    assert subtitle is not None
+    assert "input 200" in subtitle
+    assert "output 75" in subtitle
+    assert "cache hit 10.00%" not in subtitle  # ensure using latest cache values
+    assert "cache hit 12.50%" in subtitle
+    assert "$ 0.1500 (total)" in subtitle
 
 
 def test_event_base_fallback_visualize():
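
Reviewer note: working the fixture above through `_format_metrics_subtitle` by hand, per-context mode with the delta env var unset picks the second request (200 prompt / 75 completion / 25 cache-read tokens, $0.15 cumulative cost), so the subtitle, rich markup included, should come out as:

```python
# Hand-derived from the implementation in visualizer.py above; reasoning
# tokens are 0 for the second request, so the [yellow] segment is omitted.
expected = (
    "Tokens: [cyan]↑ input 200[/cyan]"
    " • [magenta]cache hit 12.50%[/magenta]"  # 25 / 200 = 12.50%
    " • [blue]↓ output 75[/blue]"
    " • [green]$ 0.1500 (total)[/green]"
)
```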