diff --git a/.env.example b/.env.example index e487fba..15ac844 100644 --- a/.env.example +++ b/.env.example @@ -128,4 +128,16 @@ PLUGIN_PUBLISHER_ID="your-publisher-id" # MCP MCP_OAUTH_ISSUER_URL="https://your-mcp-oauth-issuer-url.com" -MCP_RESOURCE_SERVER_URL="https://your-mcp-resource-server-url.com" \ No newline at end of file +MCP_RESOURCE_SERVER_URL="https://your-mcp-resource-server-url.com" + +# Tracing +TRACE_TRACKER_ENABLED=false +TRACE_QUEUE_MAX_SIZE=10000 +TRACE_SLOW_REQUEST_MS=1000 +TRACE_AGENT_SLOW_OPERATION_MS=30000 +TRACE_HEARTBEAT_INTERVAL_SECONDS=30 +TRACE_HEALTH_INTERVAL_SECONDS=30 +TRACE_RESOURCE_INTERVAL_SECONDS=30 +TRACE_HEALTH_TIMEOUT_SECONDS=1 +TRACE_AGENT_LOOP_ITERATIONS=20 +TRACE_AGENT_TOOL_LOOP_ITERATIONS=20 \ No newline at end of file diff --git a/README.md b/README.md index d3e24f5..64d1e92 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@

Discord - Version + Version Python License

diff --git a/pyproject.toml b/pyproject.toml index fec0eb0..5c908d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "brainapi2" -version = "2.11.9-dev" +version = "2.11.10-dev" description = "Version 2.x.x of the BrainAPI memory layer." authors = [ {name = "Christian",email = "alch.infoemail@gmail.com"} diff --git a/src/core/agents/core/invoke_loop.py b/src/core/agents/core/invoke_loop.py index 406ed2c..a27c96e 100644 --- a/src/core/agents/core/invoke_loop.py +++ b/src/core/agents/core/invoke_loop.py @@ -1,3 +1,6 @@ +import os + +from src.lib.tracing import TraceSeverity, tracer from src.utils.cleanup import strip_json from .parsing import ( @@ -59,7 +62,51 @@ def _log_token_usage(response) -> None: ) +def _trace_metadata(agent, config) -> dict: + metadata = (config or {}).get("metadata") or {} + return { + "agent": metadata.get("agent") or agent.__class__.__name__, + "brain_id": metadata.get("brain_id") or (config or {}).get("brain_id"), + "tags": (config or {}).get("tags") or [], + "tools_count": len(getattr(agent, "tools", []) or []), + "has_output_schema": agent.output_schema is not None, + } + + def run_invoke_loop(agent, messages, config): + trace_metadata = _trace_metadata(agent, config) + trace_tenant_id = trace_metadata.get("brain_id") + with tracer.span( + "agent.invoke_loop", + service="brainapi-agent", + operation="invoke_loop", + tenant_id=trace_tenant_id, + metadata=trace_metadata, + slow_operation_ms=float(os.getenv("TRACE_AGENT_SLOW_OPERATION_MS", "30000")), + ): + return _run_invoke_loop_impl( + agent, + messages, + config, + trace_metadata=trace_metadata, + trace_tenant_id=trace_tenant_id, + ) + + +def _run_invoke_loop_impl( + agent, + messages, + config, + *, + trace_metadata: dict, + trace_tenant_id: str | None, +): + agent_loop_threshold = int(os.getenv("TRACE_AGENT_LOOP_ITERATIONS", "20")) + tool_loop_threshold = int(os.getenv("TRACE_AGENT_TOOL_LOOP_ITERATIONS", "20")) + outer_loop_count = 0 + tool_loop_count = 0 + model_invoke_attempt_count = 0 + schema_retry_count = 0 model_responses = [] agent.messages = [] agent.messages.append( @@ -92,6 +139,16 @@ def run_invoke_loop(agent, messages, config): structured_response = None _schema_retry_count = 0 while True: + outer_loop_count += 1 + tracer.expensive_loop( + "agent.invoke_loop.outer_loop", + outer_loop_count, + service="brainapi-agent", + operation="invoke_loop", + tenant_id=trace_tenant_id, + metadata=trace_metadata, + threshold=agent_loop_threshold, + ) _did_retry_recovered_tool_call = False while n_message is None: _did_retry_unknown_finish = False @@ -120,9 +177,21 @@ def run_invoke_loop(agent, messages, config): else: _invoke_attempts = 3 for _invoke_attempt in range(_invoke_attempts): + model_invoke_attempt_count += 1 try: _n_message = agent.model.invoke(agent.messages, config) except Exception as _invoke_exc: + tracer.exception( + "agent.model.invoke.failed", + _invoke_exc, + service="brainapi-agent", + operation="model.invoke", + tenant_id=trace_tenant_id, + metadata={ + **trace_metadata, + "attempt": _invoke_attempt + 1, + }, + ) _exc_str = str(_invoke_exc).lower() if agent._tools_bound and ( "invalid message format" in _exc_str @@ -355,6 +424,19 @@ def run_invoke_loop(agent, messages, config): _did_retry_recovered_next_tool_call = False _last_called_tool_name = tool_name while True: + tool_loop_count += 1 + tracer.expensive_loop( + "agent.invoke_loop.tool_loop", + tool_loop_count, + service="brainapi-agent", + operation="tool_loop", + tenant_id=trace_tenant_id, + metadata={ + **trace_metadata, + "last_tool_name": _last_called_tool_name, + }, + threshold=tool_loop_threshold, + ) _next_attempts = 3 for _next_attempt in range(_next_attempts): if ( @@ -386,6 +468,18 @@ def run_invoke_loop(agent, messages, config): try: next_response = agent.model.invoke(agent.messages, config) except Exception as _next_exc: + tracer.exception( + "agent.model.next_invoke.failed", + _next_exc, + service="brainapi-agent", + operation="model.next_invoke", + tenant_id=trace_tenant_id, + metadata={ + **trace_metadata, + "attempt": _next_attempt + 1, + "last_tool_name": _last_called_tool_name, + }, + ) _next_exc_str = str(_next_exc).lower() if agent._tools_bound and ( "invalid message format" in _next_exc_str @@ -739,6 +833,20 @@ def run_invoke_loop(agent, messages, config): ) break except Exception as e: + schema_retry_count += 1 + tracer.error( + "agent.structured_output.parse_failed", + service="brainapi-agent", + operation="structured_output.parse", + tenant_id=trace_tenant_id, + severity=TraceSeverity.WARNING, + message=str(e), + metadata={ + **trace_metadata, + "schema_retry_count": schema_retry_count, + "error_type": type(e).__name__, + }, + ) if agent.debug: print("[DEBUG (agent_base)]: ", type(e).__name__, e) print( @@ -761,6 +869,16 @@ def run_invoke_loop(agent, messages, config): n_message = None continue + tracer.expensive_loop( + "agent.model.invoke_attempts", + model_invoke_attempt_count, + service="brainapi-agent", + operation="model.invoke", + tenant_id=trace_tenant_id, + metadata=trace_metadata, + threshold=int(os.getenv("TRACE_AGENT_MODEL_INVOKE_ATTEMPTS", "10")), + ) + try: from langchain_core.tracers.langchain import wait_for_all_tracers diff --git a/src/lib/tracing/__init__.py b/src/lib/tracing/__init__.py new file mode 100644 index 0000000..33ee078 --- /dev/null +++ b/src/lib/tracing/__init__.py @@ -0,0 +1,28 @@ +from src.lib.tracing.events import ( + TraceEvent, + TraceEventType, + TraceSeverity, +) +from src.lib.tracing.runtime import ( + HealthProbe, + RuntimeMonitor, + start_runtime_monitoring, + stop_runtime_monitoring, +) +from src.lib.tracing.subscribers import TraceSubscriber +from src.lib.tracing.tracker import LocalTraceQueue, TraceTracker, trace_tracker, tracer + +__all__ = [ + "HealthProbe", + "LocalTraceQueue", + "RuntimeMonitor", + "TraceEvent", + "TraceEventType", + "TraceSeverity", + "TraceSubscriber", + "TraceTracker", + "start_runtime_monitoring", + "stop_runtime_monitoring", + "trace_tracker", + "tracer", +] diff --git a/src/lib/tracing/events.py b/src/lib/tracing/events.py new file mode 100644 index 0000000..bd8ffca --- /dev/null +++ b/src/lib/tracing/events.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass, field +from enum import Enum +import time +from typing import Any + + +class TraceEventType(str, Enum): + ERROR = "error" + EXCEPTION = "exception" + DOWNTIME = "downtime" + EXPENSIVE_LOOP = "expensive_loop" + SLA_BREACH = "sla_breach" + LATENCY = "latency" + HEARTBEAT = "heartbeat" + HEALTH_CHECK = "health_check" + RESOURCE_SAMPLE = "resource_sample" + PROCESS = "process" + + +class TraceSeverity(str, Enum): + INFO = "info" + WARNING = "warning" + ERROR = "error" + CRITICAL = "critical" + + +@dataclass(frozen=True) +class TraceEvent: + event_type: TraceEventType + name: str + severity: TraceSeverity = TraceSeverity.INFO + service: str | None = None + operation: str | None = None + tenant_id: str | None = None + trace_id: str | None = None + duration_ms: float | None = None + threshold_ms: float | None = None + status_code: int | None = None + error_type: str | None = None + message: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + stack_trace: str | None = None + created_at: float = field(default_factory=time.time) + + def to_dict(self) -> dict[str, Any]: + return { + "event_type": self.event_type.value, + "name": self.name, + "severity": self.severity.value, + "service": self.service, + "operation": self.operation, + "tenant_id": self.tenant_id, + "trace_id": self.trace_id, + "duration_ms": self.duration_ms, + "threshold_ms": self.threshold_ms, + "status_code": self.status_code, + "error_type": self.error_type, + "message": self.message, + "metadata": dict(self.metadata), + "stack_trace": self.stack_trace, + "created_at": self.created_at, + } diff --git a/src/lib/tracing/middleware.py b/src/lib/tracing/middleware.py new file mode 100644 index 0000000..667fa2c --- /dev/null +++ b/src/lib/tracing/middleware.py @@ -0,0 +1,124 @@ +import os +import time +from uuid import uuid4 + +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +from src.lib.tracing.events import TraceEventType, TraceSeverity +from src.lib.tracing.tracker import tracer + + +class TraceMiddleware: + def __init__( + self, + app: ASGIApp, + service_name: str, + slow_request_ms: float | None = None, + ): + self.app = app + self.service_name = service_name + self.slow_request_ms = ( + slow_request_ms + if slow_request_ms is not None + else float(os.getenv("TRACE_SLOW_REQUEST_MS", "1000")) + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ("http", "websocket"): + await self.app(scope, receive, send) + return + + started_at = time.perf_counter() + status_code: int | None = None + tenant_id = self._tenant_id(scope) + trace_id = self._header(scope, "x-trace-id") or str(uuid4()) + tokens = tracer.set_context(trace_id=trace_id, tenant_id=tenant_id) + + async def send_wrapper(message: Message) -> None: + nonlocal status_code + if message["type"] == "http.response.start": + status_code = message.get("status") + await send(message) + + try: + await self.app(scope, receive, send_wrapper) + except Exception as exc: + duration_ms = (time.perf_counter() - started_at) * 1000 + tracer.exception( + name=self._operation(scope), + exception=exc, + service=self.service_name, + operation=self._operation(scope), + duration_ms=duration_ms, + status_code=status_code, + threshold_ms=self.slow_request_ms, + metadata=self._scope_metadata(scope, status_code), + ) + tracer.reset_context(tokens) + raise + + duration_ms = (time.perf_counter() - started_at) * 1000 + operation = self._operation(scope) + if status_code is not None and status_code >= 500: + tracer.error( + name=operation, + service=self.service_name, + operation=operation, + message=f"{self.service_name} returned HTTP {status_code}", + severity=TraceSeverity.ERROR, + duration_ms=duration_ms, + status_code=status_code, + threshold_ms=self.slow_request_ms, + metadata=self._scope_metadata(scope, status_code), + ) + + if self.slow_request_ms is not None and duration_ms >= self.slow_request_ms: + tracer.publish( + TraceEventType.SLA_BREACH, + name=operation, + severity=TraceSeverity.WARNING, + service=self.service_name, + operation=operation, + duration_ms=duration_ms, + threshold_ms=self.slow_request_ms, + status_code=status_code, + message=f"{operation} exceeded SLA threshold", + metadata=self._scope_metadata(scope, status_code), + ) + tracer.reset_context(tokens) + + def _operation(self, scope: Scope) -> str: + method = scope.get("method", scope["type"]) + return f"{method} {scope.get('path', '')}".strip() + + def _scope_metadata(self, scope: Scope, status_code: int | None) -> dict[str, str]: + metadata = { + "service": self.service_name, + "protocol": scope["type"], + "path": str(scope.get("path", "")), + } + if status_code is not None: + metadata["status_code"] = str(status_code) + return metadata + + def _tenant_id(self, scope: Scope) -> str | None: + return ( + self._header(scope, "x-brain-id") + or self._query_param(scope, "brain_id") + or self._header(scope, "x-tenant-id") + ) + + def _header(self, scope: Scope, name: str) -> str | None: + encoded_name = name.lower().encode() + for header_name, value in scope.get("headers", []): + if header_name.lower() == encoded_name: + return value.decode(errors="replace").rstrip() or None + return None + + def _query_param(self, scope: Scope, name: str) -> str | None: + raw_query = scope.get("query_string", b"").decode(errors="replace") + for item in raw_query.split("&"): + key, _, value = item.partition("=") + if key == name and value: + return value.rstrip() + return None diff --git a/src/lib/tracing/runtime.py b/src/lib/tracing/runtime.py new file mode 100644 index 0000000..893148a --- /dev/null +++ b/src/lib/tracing/runtime.py @@ -0,0 +1,336 @@ +import atexit +import asyncio +import os +import resource +import socket +import sys +import threading +import time +import traceback +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Callable + +from src.lib.tracing.events import TraceEventType, TraceSeverity +from src.lib.tracing.tracker import tracer + + +@dataclass(frozen=True) +class HealthProbe: + name: str + host: str + port: int + timeout_seconds: float = 1.0 + + +class RuntimeMonitor: + def __init__( + self, + *, + service_name: str, + heartbeat_interval_seconds: float | None = None, + health_interval_seconds: float | None = None, + resource_interval_seconds: float | None = None, + probes: list[HealthProbe] | None = None, + resource_sampler: Callable[[], dict] | None = None, + ): + self.service_name = service_name + self.heartbeat_interval_seconds = _interval( + heartbeat_interval_seconds, "TRACE_HEARTBEAT_INTERVAL_SECONDS", 30 + ) + self.health_interval_seconds = _interval( + health_interval_seconds, "TRACE_HEALTH_INTERVAL_SECONDS", 30 + ) + self.resource_interval_seconds = _interval( + resource_interval_seconds, "TRACE_RESOURCE_INTERVAL_SECONDS", 30 + ) + self.probes = probes or default_health_probes() + self.resource_sampler = resource_sampler or sample_process_resources + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + self._hooks_installed = False + + def start(self) -> None: + self.install_exception_hooks() + tracer.publish( + TraceEventType.PROCESS, + "process.started", + service=self.service_name, + operation="runtime.start", + metadata=_process_metadata(), + ) + if self._thread and self._thread.is_alive(): + return + self._thread = threading.Thread( + target=self._run, + name=f"{self.service_name}-runtime-monitor", + daemon=True, + ) + self._thread.start() + + def stop(self) -> None: + self._stop_event.set() + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=2) + tracer.publish( + TraceEventType.PROCESS, + "process.stopped", + service=self.service_name, + operation="runtime.stop", + metadata=_process_metadata(), + ) + + def install_exception_hooks(self) -> None: + if self._hooks_installed: + return + previous_sys_hook = sys.excepthook + previous_thread_hook = getattr(threading, "excepthook", None) + + def sys_hook(exc_type, exc_value, exc_traceback): + self.record_unhandled_exception(exc_value, exc_traceback, "sys.excepthook") + previous_sys_hook(exc_type, exc_value, exc_traceback) + + def thread_hook(args): + self.record_unhandled_exception( + args.exc_value, + args.exc_traceback, + f"threading.excepthook:{getattr(args.thread, 'name', 'unknown')}", + ) + if previous_thread_hook: + previous_thread_hook(args) + + sys.excepthook = sys_hook + if previous_thread_hook: + threading.excepthook = thread_hook + self._hooks_installed = True + + def install_asyncio_exception_handler(self, loop: asyncio.AbstractEventLoop) -> None: + previous_handler = loop.get_exception_handler() + + def handler(event_loop, context): + exc = context.get("exception") + if exc: + self.record_unhandled_exception(exc, exc.__traceback__, "asyncio") + else: + tracer.error( + "runtime.asyncio.unhandled_error", + service=self.service_name, + operation="asyncio", + severity=TraceSeverity.ERROR, + message=context.get("message"), + metadata={k: str(v) for k, v in context.items() if k != "exception"}, + ) + if previous_handler: + previous_handler(event_loop, context) + else: + event_loop.default_exception_handler(context) + + loop.set_exception_handler(handler) + + def record_unhandled_exception( + self, + exc: BaseException, + exc_traceback, + operation: str, + ) -> None: + tracer.publish( + TraceEventType.EXCEPTION, + "runtime.unhandled_exception", + service=self.service_name, + operation=operation, + severity=TraceSeverity.CRITICAL, + message=str(exc), + exception=exc, + metadata={ + **_process_metadata(), + "traceback": "".join( + traceback.format_exception(type(exc), exc, exc_traceback) + ), + }, + ) + + def heartbeat(self) -> None: + tracer.publish( + TraceEventType.HEARTBEAT, + "runtime.heartbeat", + service=self.service_name, + operation="heartbeat", + metadata=_process_metadata(), + ) + + def sample_resources(self) -> None: + tracer.publish( + TraceEventType.RESOURCE_SAMPLE, + "runtime.resources", + service=self.service_name, + operation="resource_sampler", + metadata=self.resource_sampler(), + ) + + def check_health(self) -> None: + for probe in self.probes: + started_at = time.perf_counter() + ok = probe_tcp(probe) + duration_ms = (time.perf_counter() - started_at) * 1000 + metadata = { + "probe": probe.name, + "host": probe.host, + "port": probe.port, + "ok": ok, + } + if ok: + tracer.publish( + TraceEventType.HEALTH_CHECK, + f"health.{probe.name}", + service=self.service_name, + operation="health_check", + duration_ms=duration_ms, + metadata=metadata, + ) + else: + tracer.downtime( + f"health.{probe.name}.down", + duration_ms, + service=self.service_name, + operation="health_check", + metadata=metadata, + ) + + def _run(self) -> None: + last_heartbeat = last_health = last_resource = 0.0 + while not self._stop_event.is_set(): + now = time.monotonic() + if _due(now, last_heartbeat, self.heartbeat_interval_seconds): + self.heartbeat() + last_heartbeat = now + if _due(now, last_health, self.health_interval_seconds): + self.check_health() + last_health = now + if _due(now, last_resource, self.resource_interval_seconds): + self.sample_resources() + last_resource = now + self._stop_event.wait(1) + + +_monitors: dict[str, RuntimeMonitor] = {} + + +def start_runtime_monitoring(service_name: str, **kwargs) -> RuntimeMonitor: + monitor = _monitors.get(service_name) + if monitor is None: + monitor = RuntimeMonitor(service_name=service_name, **kwargs) + _monitors[service_name] = monitor + atexit.register(monitor.stop) + monitor.start() + return monitor + + +def stop_runtime_monitoring(service_name: str) -> None: + monitor = _monitors.get(service_name) + if monitor: + monitor.stop() + + +def runtime_tracing_lifespan(service_name: str, nested_lifespan=None): + @asynccontextmanager + async def lifespan(app): + start_runtime_monitoring(service_name) + if nested_lifespan is None: + try: + yield + finally: + stop_runtime_monitoring(service_name) + return + async with nested_lifespan(app): + try: + yield + finally: + stop_runtime_monitoring(service_name) + + return lifespan + + +def default_health_probes() -> list[HealthProbe]: + return [ + _probe_from_env("redis", "REDIS_HOST", "REDIS_PORT"), + _probe_from_env("mongo", "MONGO_HOST", "MONGO_PORT"), + _probe_from_env("neo4j", "NEO4J_HOST", "NEO4J_PORT"), + _probe_from_env("milvus", "MILVUS_HOST", "MILVUS_PORT"), + _probe_from_env("rabbitmq", "RABBITMQ_HOST", "RABBITMQ_PORT"), + ] + + +def _probe_from_env(name: str, host_var: str, port_var: str) -> HealthProbe: + defaults = { + "redis": ("localhost", 6379), + "mongo": ("localhost", 27017), + "neo4j": ("localhost", 7687), + "milvus": ("localhost", 19530), + "rabbitmq": ("localhost", 5672), + } + default_host, default_port = defaults[name] + return HealthProbe( + name=name, + host=os.getenv(host_var, default_host), + port=int(os.getenv(port_var, str(default_port))), + timeout_seconds=float(os.getenv("TRACE_HEALTH_TIMEOUT_SECONDS", "1")), + ) + + +def probe_tcp(probe: HealthProbe) -> bool: + try: + with socket.create_connection( + (probe.host, probe.port), timeout=probe.timeout_seconds + ): + return True + except OSError: + return False + + +def sample_process_resources() -> dict: + usage = resource.getrusage(resource.RUSAGE_SELF) + return { + **_process_metadata(), + "rss_kb": usage.ru_maxrss, + "user_cpu_seconds": usage.ru_utime, + "system_cpu_seconds": usage.ru_stime, + "threads": threading.active_count(), + "load_avg": os.getloadavg() if hasattr(os, "getloadavg") else None, + **_proc_status(), + } + + +def _proc_status() -> dict: + status_path = "/proc/self/status" + if not os.path.exists(status_path): + return {} + keys = { + "VmRSS": "vm_rss", + "VmSize": "vm_size", + "Threads": "proc_threads", + } + values = {} + try: + with open(status_path, encoding="utf-8") as status_file: + for line in status_file: + key, _, value = line.partition(":") + if key in keys: + values[keys[key]] = value.strip() + except OSError: + return {} + return values + + +def _process_metadata() -> dict: + return { + "pid": os.getpid(), + "process": os.path.basename(sys.argv[0] or "python"), + } + + +def _interval(value: float | None, env_name: str, default: float) -> float: + return value if value is not None else float(os.getenv(env_name, str(default))) + + +def _due(now: float, last: float, interval: float) -> bool: + return interval >= 0 and now - last >= interval diff --git a/src/lib/tracing/subscribers.py b/src/lib/tracing/subscribers.py new file mode 100644 index 0000000..6814264 --- /dev/null +++ b/src/lib/tracing/subscribers.py @@ -0,0 +1,15 @@ +from abc import ABC, abstractmethod +from typing import Iterable + +from src.lib.tracing.events import TraceEvent + + +class TraceSubscriber(ABC): + @abstractmethod + def handle(self, event: TraceEvent) -> None: + raise NotImplementedError("handle method not implemented") + + def handle_many(self, events: Iterable[TraceEvent]) -> None: + for event in events: + self.handle(event) + diff --git a/src/lib/tracing/tracker.py b/src/lib/tracing/tracker.py new file mode 100644 index 0000000..77f6fc1 --- /dev/null +++ b/src/lib/tracing/tracker.py @@ -0,0 +1,376 @@ +import asyncio +import inspect +import logging +import os +import time +import traceback +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any, Iterator, Optional + +from src.lib.tracing.events import TraceEvent, TraceEventType, TraceSeverity +from src.lib.tracing.subscribers import TraceSubscriber + +logger = logging.getLogger(__name__) + +_trace_id_var: ContextVar[Optional[str]] = ContextVar("trace_id", default=None) +_tenant_id_var: ContextVar[Optional[str]] = ContextVar("tenant_id", default=None) + + +class LocalTraceQueue: + def __init__(self, max_size: int = 10000): + self._queue: asyncio.Queue[TraceEvent] = asyncio.Queue(maxsize=max_size) + + @property + def size(self) -> int: + return self._queue.qsize() + + def put_nowait(self, event: TraceEvent) -> bool: + try: + self._queue.put_nowait(event) + return True + except asyncio.QueueFull: + return False + + async def get(self) -> TraceEvent: + return await self._queue.get() + + def task_done(self) -> None: + self._queue.task_done() + + def drain(self, limit: Optional[int] = None) -> list[TraceEvent]: + events: list[TraceEvent] = [] + while not self._queue.empty() and (limit is None or len(events) < limit): + events.append(self._queue.get_nowait()) + self._queue.task_done() + return events + + +class TraceTracker: + def __init__( + self, + *, + queue: Optional[LocalTraceQueue] = None, + max_queue_size: Optional[int] = None, + slow_operation_ms: Optional[float] = None, + downtime_ms: Optional[float] = None, + expensive_loop_iterations: Optional[int] = None, + enabled: Optional[bool] = None, + ): + self.queue = queue or LocalTraceQueue( + max_queue_size + if max_queue_size is not None + else int(os.getenv("TRACE_QUEUE_MAX_SIZE", "10000")) + ) + self.slow_operation_ms = ( + slow_operation_ms + if slow_operation_ms is not None + else float(os.getenv("TRACE_SLOW_OPERATION_MS", "1000")) + ) + self.downtime_ms = ( + downtime_ms + if downtime_ms is not None + else float(os.getenv("TRACE_DOWNTIME_MS", "5000")) + ) + self.expensive_loop_iterations = ( + expensive_loop_iterations + if expensive_loop_iterations is not None + else int(os.getenv("TRACE_EXPENSIVE_LOOP_ITERATIONS", "10000")) + ) + self.enabled = ( + enabled + if enabled is not None + else os.getenv("TRACE_TRACKER_ENABLED", "true") == "true" + ) + self._subscribers: list[TraceSubscriber] = [] + self._subscriber_tasks: list[asyncio.Task] = [] + + @property + def default_sla_ms(self) -> float: + return self.slow_operation_ms + + def set_context( + self, *, trace_id: Optional[str] = None, tenant_id: Optional[str] = None + ): + trace_token = _trace_id_var.set(trace_id) if trace_id is not None else None + tenant_token = _tenant_id_var.set(tenant_id) if tenant_id is not None else None + return trace_token, tenant_token + + def reset_context(self, tokens) -> None: + trace_token, tenant_token = tokens + if trace_token is not None: + _trace_id_var.reset(trace_token) + if tenant_token is not None: + _tenant_id_var.reset(tenant_token) + + def subscribe(self, subscriber: TraceSubscriber) -> None: + self._subscribers.append(subscriber) + + def clear_subscribers(self) -> None: + self._subscribers.clear() + + def publish( + self, + event_type: TraceEventType | str, + name: str, + *, + severity: TraceSeverity | str = TraceSeverity.INFO, + tenant_id: Optional[str] = None, + trace_id: Optional[str] = None, + service: Optional[str] = None, + operation: Optional[str] = None, + duration_ms: Optional[float] = None, + threshold_ms: Optional[float] = None, + status_code: Optional[int] = None, + message: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, + exception: Optional[BaseException] = None, + ) -> Optional[TraceEvent]: + if not self.enabled: + return None + event_trace_id = trace_id or _trace_id_var.get() + event_message = message or (str(exception) if exception else None) + event = TraceEvent( + event_type=TraceEventType(event_type), + severity=TraceSeverity(severity), + name=name, + service=service, + operation=operation, + tenant_id=tenant_id or _tenant_id_var.get(), + trace_id=event_trace_id, + duration_ms=duration_ms, + threshold_ms=threshold_ms, + status_code=status_code, + error_type=type(exception).__name__ if exception else None, + message=event_message, + metadata=dict(metadata or {}), + stack_trace=( + "".join(traceback.format_exception(exception)) if exception else None + ), + ) + if not self.queue.put_nowait(event): + logger.warning("Trace queue is full; dropping event %s", event.name) + return None + return event + + def exception( + self, + name: str, + exception: BaseException, + *, + service: Optional[str] = None, + operation: Optional[str] = None, + tenant_id: Optional[str] = None, + trace_id: Optional[str] = None, + duration_ms: Optional[float] = None, + threshold_ms: Optional[float] = None, + status_code: Optional[int] = None, + metadata: Optional[dict[str, Any]] = None, + ) -> Optional[TraceEvent]: + return self.publish( + TraceEventType.EXCEPTION, + name, + severity=TraceSeverity.ERROR, + service=service, + operation=operation, + tenant_id=tenant_id, + trace_id=trace_id, + duration_ms=duration_ms, + threshold_ms=threshold_ms, + status_code=status_code, + metadata=metadata, + exception=exception, + ) + + def error( + self, + name: str, + *, + service: Optional[str] = None, + operation: Optional[str] = None, + tenant_id: Optional[str] = None, + trace_id: Optional[str] = None, + severity: TraceSeverity | str = TraceSeverity.ERROR, + duration_ms: Optional[float] = None, + threshold_ms: Optional[float] = None, + status_code: Optional[int] = None, + message: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, + ) -> Optional[TraceEvent]: + return self.publish( + TraceEventType.ERROR, + name, + severity=severity, + service=service, + operation=operation, + tenant_id=tenant_id, + trace_id=trace_id, + duration_ms=duration_ms, + threshold_ms=threshold_ms, + status_code=status_code, + message=message, + metadata=metadata, + ) + + def downtime( + self, + name: str, + duration_ms: float, + *, + service: Optional[str] = None, + operation: Optional[str] = None, + tenant_id: Optional[str] = None, + trace_id: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, + ) -> Optional[TraceEvent]: + return self.publish( + TraceEventType.DOWNTIME, + name, + severity=TraceSeverity.CRITICAL, + service=service, + operation=operation, + tenant_id=tenant_id, + trace_id=trace_id, + duration_ms=duration_ms, + metadata=metadata, + ) + + def expensive_loop( + self, + name: str, + iterations: int, + *, + service: Optional[str] = None, + operation: Optional[str] = None, + tenant_id: Optional[str] = None, + trace_id: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, + threshold: Optional[int] = None, + ) -> Optional[TraceEvent]: + limit = ( + threshold + if threshold is not None + else self.expensive_loop_iterations + ) + if iterations < limit: + return None + return self.publish( + TraceEventType.EXPENSIVE_LOOP, + name, + severity=TraceSeverity.WARNING, + service=service, + operation=operation, + tenant_id=tenant_id, + trace_id=trace_id, + metadata={**(metadata or {}), "iterations": iterations}, + ) + + @contextmanager + def span( + self, + name: str, + *, + tenant_id: Optional[str] = None, + trace_id: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, + slow_operation_ms: Optional[float] = None, + service: Optional[str] = None, + operation: Optional[str] = None, + ) -> Iterator[None]: + tokens = self.set_context(trace_id=trace_id, tenant_id=tenant_id) + started_at = time.perf_counter() + try: + yield + except Exception as exc: + duration_ms = (time.perf_counter() - started_at) * 1000 + self.exception( + name, + exc, + service=service, + operation=operation, + metadata={**(metadata or {}), "duration_ms": duration_ms}, + ) + raise + finally: + duration_ms = (time.perf_counter() - started_at) * 1000 + threshold = ( + slow_operation_ms + if slow_operation_ms is not None + else self.slow_operation_ms + ) + if duration_ms >= threshold: + self.publish( + TraceEventType.SLA_BREACH, + name, + severity=TraceSeverity.WARNING, + service=service, + operation=operation, + duration_ms=duration_ms, + threshold_ms=threshold, + metadata=metadata, + ) + self.reset_context(tokens) + + def track_loop( + self, + name: str, + iterable, + *, + tenant_id: Optional[str] = None, + trace_id: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, + threshold: Optional[int] = None, + ): + count = 0 + for item in iterable: + count += 1 + yield item + if count >= (threshold or self.expensive_loop_iterations): + self.expensive_loop( + name, + count, + tenant_id=tenant_id, + trace_id=trace_id, + metadata=metadata, + threshold=threshold, + ) + + async def dispatch_once(self) -> TraceEvent: + event = await self.queue.get() + try: + for subscriber in list(self._subscribers): + result = subscriber.handle(event) + if inspect.isawaitable(result): + await result + finally: + self.queue.task_done() + return event + + async def subscribe_forever(self) -> None: + while True: + await self.dispatch_once() + + def start_subscribers(self) -> None: + if not self._subscribers or self._subscriber_tasks: + return + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + self._subscriber_tasks = [ + loop.create_task(self.subscribe_forever()) for _ in self._subscribers + ] + + async def stop_subscribers(self) -> None: + for task in self._subscriber_tasks: + task.cancel() + for task in self._subscriber_tasks: + try: + await task + except asyncio.CancelledError: + pass + self._subscriber_tasks = [] + + +trace_tracker = TraceTracker() +tracer = trace_tracker diff --git a/src/lib/tracing/workers.py b/src/lib/tracing/workers.py new file mode 100644 index 0000000..95fe9ec --- /dev/null +++ b/src/lib/tracing/workers.py @@ -0,0 +1,127 @@ +import time +from typing import Any + +from celery import signals + +from src.lib.tracing.tracker import tracer + +_task_starts: dict[str, float] = {} +_task_context_tokens: dict[str, Any] = {} + + +def _task_id(task_id: str | None, task=None) -> str: + if task_id: + return task_id + request = getattr(task, "request", None) + return getattr(request, "id", None) or "unknown" + + +def _task_name(task=None, sender=None) -> str: + return getattr(task, "name", None) or getattr(sender, "name", None) or "unknown" + + +def _tenant_id(args=None, kwargs=None) -> str | None: + args = args or () + kwargs = kwargs or {} + candidates = list(args) + candidates.extend(kwargs.values()) + for candidate in candidates: + if isinstance(candidate, dict): + brain_id = candidate.get("brain_id") + if brain_id: + return str(brain_id) + brain_id = kwargs.get("brain_id") + return str(brain_id) if brain_id else None + + +def _trace_id(task_id: str, kwargs=None) -> str: + kwargs = kwargs or {} + trace_id = kwargs.get("trace_id") + return str(trace_id) if trace_id else task_id + + +def _metadata(task_id: str, task_name: str, args=None, kwargs=None) -> dict[str, Any]: + return { + "task_id": task_id, + "task_name": task_name, + "args_count": len(args or ()), + "kwargs_keys": sorted((kwargs or {}).keys()), + } + + +def install_celery_tracing(_app=None, service_name: str = "brainapi-worker") -> None: + signals.task_prerun.connect( + _trace_task_prerun, + weak=False, + dispatch_uid=f"{service_name}.trace.task_prerun", + ) + signals.task_postrun.connect( + _trace_task_postrun, + weak=False, + dispatch_uid=f"{service_name}.trace.task_postrun", + ) + signals.task_failure.connect( + _trace_task_failure, + weak=False, + dispatch_uid=f"{service_name}.trace.task_failure", + ) + + +def _trace_task_prerun(sender=None, task_id=None, task=None, args=None, kwargs=None, **_): + resolved_task_id = _task_id(task_id, task) + resolved_task_name = _task_name(task, sender) + _task_starts[resolved_task_id] = time.perf_counter() + _task_context_tokens[resolved_task_id] = tracer.set_context( + trace_id=_trace_id(resolved_task_id, kwargs), + tenant_id=_tenant_id(args, kwargs), + ) + tracer.publish( + "latency", + f"{resolved_task_name}.started", + service="brainapi-worker", + operation=resolved_task_name, + metadata=_metadata(resolved_task_id, resolved_task_name, args, kwargs), + ) + + +def _trace_task_postrun(sender=None, task_id=None, task=None, args=None, kwargs=None, retval=None, state=None, **_): + resolved_task_id = _task_id(task_id, task) + resolved_task_name = _task_name(task, sender) + started_at = _task_starts.pop(resolved_task_id, None) + duration_ms = (time.perf_counter() - started_at) * 1000 if started_at else None + metadata = { + **_metadata(resolved_task_id, resolved_task_name, args, kwargs), + "state": state, + "returned": retval is not None, + } + if duration_ms is not None and duration_ms >= tracer.default_sla_ms: + tracer.publish( + "sla_breach", + f"{resolved_task_name}.sla_breach", + service="brainapi-worker", + operation=resolved_task_name, + duration_ms=duration_ms, + threshold_ms=tracer.default_sla_ms, + metadata=metadata, + ) + token = _task_context_tokens.pop(resolved_task_id, None) + if token is not None: + tracer.reset_context(token) + + +def _trace_task_failure(sender=None, task_id=None, exception=None, args=None, kwargs=None, einfo=None, **_): + resolved_task_id = _task_id(task_id, sender) + resolved_task_name = _task_name(sender=sender) + started_at = _task_starts.get(resolved_task_id) + duration_ms = (time.perf_counter() - started_at) * 1000 if started_at else None + metadata = _metadata(resolved_task_id, resolved_task_name, args, kwargs) + if einfo is not None: + metadata["traceback"] = str(einfo) + tracer.exception( + f"{resolved_task_name}.failed", + exception or RuntimeError("Celery task failed"), + service="brainapi-worker", + operation=resolved_task_name, + duration_ms=duration_ms, + metadata=metadata, + ) diff --git a/src/services/api/app.py b/src/services/api/app.py index dbf6237..edfd38d 100644 --- a/src/services/api/app.py +++ b/src/services/api/app.py @@ -20,6 +20,8 @@ from src.services.api.routes.retrieve import retrieve_router from src.services.api.routes.system import system_router from src.services.api.routes.tasks import tasks_router +from src.lib.tracing.middleware import TraceMiddleware +from src.lib.tracing.runtime import start_runtime_monitoring, stop_runtime_monitoring logger = logging.getLogger("brainapi.plugins") @@ -31,6 +33,7 @@ async def lifespan(app: FastAPI): from src.core.plugins.context import PluginContext from src.core.plugins.loader import PluginLoader + start_runtime_monitoring("brainapi-api") ctx = PluginContext.from_app(app) loader = PluginLoader(plugins_dir=PLUGINS_DIR, context=ctx) results = loader.load_all() @@ -42,7 +45,10 @@ async def lifespan(app: FastAPI): for handler in handlers: await handler() if _is_coroutine(handler) else handler() - yield + try: + yield + finally: + stop_runtime_monitoring("brainapi-api") for event_name, handlers in ctx._event_handlers.items(): if event_name == "shutdown": @@ -118,6 +124,7 @@ def _log_plugin_banner(loader, results: dict[str, bool]): allow_methods=["*"], allow_headers=["*"], ) +app.add_middleware(TraceMiddleware, service_name="brainapi-api") app.include_router(ingest_router) app.include_router(retrieve_router) diff --git a/src/services/mcp/app.py b/src/services/mcp/app.py index 1f04f82..d1fe0fa 100644 --- a/src/services/mcp/app.py +++ b/src/services/mcp/app.py @@ -12,6 +12,10 @@ _project_root = Path(__file__).resolve().parent.parent.parent.parent dotenv.load_dotenv(_project_root / ".env") +from contextlib import asynccontextmanager + +from src.lib.tracing.middleware import TraceMiddleware +from src.lib.tracing.runtime import start_runtime_monitoring, stop_runtime_monitoring from src.services.mcp.main import auth_token_var, mcp, oauth_provider PLUGINS_DIR = Path(os.getenv("PLUGINS_DIR", str(_project_root / "plugins"))) @@ -80,6 +84,16 @@ def _log_plugin_banner(loader, results: dict[str, bool]): _mcp_app = mcp.streamable_http_app() +@asynccontextmanager +async def _lifespan(app): + start_runtime_monitoring("brainapi-mcp") + async with _mcp_app.router.lifespan_context(app): + try: + yield + finally: + stop_runtime_monitoring("brainapi-mcp") + + class AuthContextMiddleware: def __init__(self, app): self.app = app @@ -137,6 +151,9 @@ async def _lifespan(app): ] app = Starlette( routes=_custom_routes + list(_mcp_app.routes), - middleware=[Middleware(AuthContextMiddleware)], + middleware=[ + Middleware(TraceMiddleware, service_name="brainapi-mcp"), + Middleware(AuthContextMiddleware), + ], lifespan=_lifespan, ) diff --git a/src/services/mcp/main.py b/src/services/mcp/main.py index e694251..b37e66d 100644 --- a/src/services/mcp/main.py +++ b/src/services/mcp/main.py @@ -93,6 +93,7 @@ async def _mcp_oauth_consent(request: Request) -> HTMLResponse | RedirectRespons scope = q.get("scope") or " ".join(_oauth_scopes) resource = q.get("resource") or "" state = q.get("state") or "" + redirect_uri_provided = q.get("redirect_uri_provided_explicitly") or "1" if not (client_id and redirect_uri and code_challenge): return HTMLResponse("Missing OAuth parameters", status_code=400) client = await oauth_provider.get_client(client_id) @@ -111,6 +112,7 @@ async def _mcp_oauth_consent(request: Request) -> HTMLResponse | RedirectRespons + @@ -125,6 +127,9 @@ async def _mcp_oauth_consent(request: Request) -> HTMLResponse | RedirectRespons scope_str = str(form.get("scope") or "") resource = str(form.get("resource") or "") or None state = str(form.get("state") or "") or None + redirect_uri_provided_explicitly = str( + form.get("redirect_uri_provided_explicitly") or "1" + ) == "1" brainpat = str(form.get("brainpat") or "").strip() if not (client_id and redirect_uri and code_challenge and brainpat): return HTMLResponse("Missing fields", status_code=400) @@ -145,6 +150,7 @@ async def _mcp_oauth_consent(request: Request) -> HTMLResponse | RedirectRespons code = oauth_provider.issue_auth_code( client_id=client_id, redirect_uri=ru, + redirect_uri_provided_explicitly=redirect_uri_provided_explicitly, code_challenge=code_challenge, scopes=scopes, resource=resource, diff --git a/src/services/mcp/oauth_provider.py b/src/services/mcp/oauth_provider.py index 5a9da6d..6f31e95 100644 --- a/src/services/mcp/oauth_provider.py +++ b/src/services/mcp/oauth_provider.py @@ -142,6 +142,9 @@ async def authorize(self, client: OAuthClientInformationFull, params: Authorizat "client_id": client.client_id or "", "redirect_uri": str(params.redirect_uri), "code_challenge": params.code_challenge, + "redirect_uri_provided_explicitly": "1" + if params.redirect_uri_provided_explicitly + else "0", "scope": " ".join(scopes), } if params.resource: @@ -283,6 +286,7 @@ def issue_auth_code( *, client_id: str, redirect_uri: AnyUrl, + redirect_uri_provided_explicitly: bool = True, code_challenge: str, scopes: list[str], resource: str | None, @@ -297,7 +301,7 @@ def issue_auth_code( client_id=client_id, code_challenge=code_challenge, redirect_uri=redirect_uri, - redirect_uri_provided_explicitly=True, + redirect_uri_provided_explicitly=redirect_uri_provided_explicitly, resource=resource, ) self._set_model(self._auth_code_key(code), ac, self._auth_code_ttl) diff --git a/src/workers/app.py b/src/workers/app.py index fdaf412..a013063 100644 --- a/src/workers/app.py +++ b/src/workers/app.py @@ -31,6 +31,8 @@ from kombu import Queue from src.config import config +from src.lib.tracing import start_runtime_monitoring +from src.lib.tracing.workers import install_celery_tracing os.environ.setdefault("GRPC_ENABLE_FORK_SUPPORT", "1") @@ -98,3 +100,5 @@ "fanout_patterns": True, }, ) +install_celery_tracing(service_name="brainapi-worker") +start_runtime_monitoring(service_name="brainapi-worker") diff --git a/tests/test_mcp_oauth_redis_storage.py b/tests/test_mcp_oauth_redis_storage.py index 8ccd0a1..ae3a136 100644 --- a/tests/test_mcp_oauth_redis_storage.py +++ b/tests/test_mcp_oauth_redis_storage.py @@ -1,12 +1,14 @@ import asyncio import os import unittest +from urllib.parse import parse_qs, urlparse from pydantic import AnyUrl os.environ.setdefault("REDIS_HOST", "localhost") os.environ.setdefault("REDIS_PORT", "6379") +from mcp.server.auth.provider import AuthorizationParams from mcp.shared.auth import OAuthClientInformationFull from src.services.mcp.oauth_provider import BrainapiMcpOAuthProvider @@ -99,6 +101,44 @@ def test_authorization_code_exchange_survives_provider_recreation(self): self.assertIsNone(asyncio.run(recreated.load_authorization_code(loaded_client, code))) self.assertIsNotNone(asyncio.run(recreated.load_access_token(token.access_token))) + def test_authorization_code_preserves_implicit_redirect_uri(self): + redis_client = FakeRedis() + provider = make_provider(redis_client) + + code = provider.issue_auth_code( + client_id="client-1", + redirect_uri=AnyUrl("https://claude.ai/api/mcp/auth_callback"), + redirect_uri_provided_explicitly=False, + code_challenge="challenge", + scopes=["brainapi"], + resource="https://brainapi.example/mcp", + state="state", + brainpat="brain-pat", + ) + auth_code = asyncio.run(provider.load_authorization_code(None, code)) + + self.assertFalse(auth_code.redirect_uri_provided_explicitly) + + def test_authorize_carries_redirect_uri_explicitness_to_consent(self): + provider = make_provider(FakeRedis()) + client = OAuthClientInformationFull( + client_id="client-1", + redirect_uris=[AnyUrl("https://claude.ai/api/mcp/auth_callback")], + ) + params = AuthorizationParams( + state="state", + scopes=["brainapi"], + code_challenge="challenge", + redirect_uri=AnyUrl("https://claude.ai/api/mcp/auth_callback"), + redirect_uri_provided_explicitly=False, + resource="https://brainapi.example/mcp", + ) + + url = asyncio.run(provider.authorize(client, params)) + query = parse_qs(urlparse(url).query) + + self.assertEqual(query["redirect_uri_provided_explicitly"], ["0"]) + def test_refresh_exchange_survives_provider_recreation_and_revokes_old_access_token(self): redis_client = FakeRedis() provider = make_provider(redis_client) diff --git a/tests/test_tracing_runtime_monitor.py b/tests/test_tracing_runtime_monitor.py new file mode 100644 index 0000000..6934c2f --- /dev/null +++ b/tests/test_tracing_runtime_monitor.py @@ -0,0 +1,88 @@ +import asyncio +import unittest +from unittest.mock import patch + +from src.lib.tracing import LocalTraceQueue, TraceEventType, TraceTracker +from src.lib.tracing import runtime as tracing_runtime +from src.lib.tracing.runtime import HealthProbe, RuntimeMonitor + + +class RuntimeMonitorTests(unittest.TestCase): + def setUp(self): + self.original_tracer = tracing_runtime.tracer + self.tracker = TraceTracker(queue=LocalTraceQueue(), enabled=True) + tracing_runtime.tracer = self.tracker + + def tearDown(self): + tracing_runtime.tracer = self.original_tracer + tracing_runtime._monitors.clear() + + def test_runtime_monitor_records_heartbeat_resources_and_shutdown(self): + monitor = RuntimeMonitor( + service_name="test-service", + probes=[], + resource_sampler=lambda: {"rss_kb": 12}, + ) + + monitor.heartbeat() + monitor.sample_resources() + monitor.stop() + + events = self.tracker.queue.drain() + self.assertEqual(events[0].event_type, TraceEventType.HEARTBEAT) + self.assertEqual(events[1].event_type, TraceEventType.RESOURCE_SAMPLE) + self.assertEqual(events[1].metadata["rss_kb"], 12) + self.assertEqual(events[2].event_type, TraceEventType.PROCESS) + self.assertEqual(events[2].name, "process.stopped") + + def test_health_check_records_success_and_downtime(self): + monitor = RuntimeMonitor( + service_name="test-service", + probes=[ + HealthProbe("up", "127.0.0.1", 1), + HealthProbe("down", "127.0.0.1", 2), + ], + ) + + with patch.object(tracing_runtime, "probe_tcp", side_effect=[True, False]): + monitor.check_health() + + events = self.tracker.queue.drain() + self.assertEqual(events[0].event_type, TraceEventType.HEALTH_CHECK) + self.assertEqual(events[0].name, "health.up") + self.assertEqual(events[1].event_type, TraceEventType.DOWNTIME) + self.assertEqual(events[1].name, "health.down.down") + + def test_unhandled_exception_hook_records_exception(self): + monitor = RuntimeMonitor(service_name="test-service", probes=[]) + exc = RuntimeError("boom") + + monitor.record_unhandled_exception(exc, exc.__traceback__, "unit-test") + + events = self.tracker.queue.drain() + self.assertEqual(len(events), 1) + self.assertEqual(events[0].event_type, TraceEventType.EXCEPTION) + self.assertEqual(events[0].severity.value, "critical") + self.assertEqual(events[0].operation, "unit-test") + + def test_asyncio_exception_handler_records_context(self): + async def run_test(): + loop = asyncio.get_running_loop() + previous_handler = loop.get_exception_handler() + loop.set_exception_handler(lambda _loop, _context: None) + monitor = RuntimeMonitor(service_name="test-service", probes=[]) + monitor.install_asyncio_exception_handler(loop) + try: + loop.call_exception_handler({"message": "async problem"}) + finally: + loop.set_exception_handler(previous_handler) + + asyncio.run(run_test()) + + events = self.tracker.queue.drain() + self.assertEqual(events[0].event_type, TraceEventType.ERROR) + self.assertEqual(events[0].name, "runtime.asyncio.unhandled_error") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_tracing_sla_tracker.py b/tests/test_tracing_sla_tracker.py new file mode 100644 index 0000000..612ef5c --- /dev/null +++ b/tests/test_tracing_sla_tracker.py @@ -0,0 +1,167 @@ +import asyncio +from unittest.mock import patch +import unittest + +from starlette.applications import Starlette +from starlette.responses import JSONResponse, PlainTextResponse +from starlette.middleware import Middleware +from starlette.routing import Route +from starlette.testclient import TestClient + +from src.lib.tracing import ( + LocalTraceQueue, + TraceEvent, + TraceEventType, + TraceSeverity, + TraceSubscriber, + TraceTracker, +) +from src.core.agents.core.invoke_loop import run_invoke_loop +from src.lib.tracing.middleware import TraceMiddleware +from src.lib.tracing import middleware as tracing_middleware +from src.lib.tracing import workers as tracing_workers + + +class MemorySubscriber(TraceSubscriber): + def __init__(self): + self.events = [] + + def handle(self, event: TraceEvent) -> None: + self.events.append(event) + + +class TraceTrackerTests(unittest.TestCase): + def test_publish_records_event_with_tenant_and_trace_context(self): + tracker = TraceTracker(queue=LocalTraceQueue(), enabled=True) + tokens = tracker.set_context(trace_id="trace-1", tenant_id="tenant-1") + try: + event = tracker.error( + "operation.failed", + service="api", + operation="GET /demo", + message="failed", + ) + finally: + tracker.reset_context(tokens) + + self.assertIsNotNone(event) + drained = tracker.queue.drain() + self.assertEqual(len(drained), 1) + self.assertEqual(drained[0].event_type, TraceEventType.ERROR) + self.assertEqual(drained[0].severity, TraceSeverity.ERROR) + self.assertEqual(drained[0].tenant_id, "tenant-1") + self.assertEqual(drained[0].trace_id, "trace-1") + self.assertEqual(drained[0].service, "api") + self.assertEqual(drained[0].operation, "GET /demo") + + def test_span_records_exception_and_sla_breach(self): + tracker = TraceTracker( + queue=LocalTraceQueue(), + enabled=True, + slow_operation_ms=0, + ) + + with self.assertRaises(ValueError): + with tracker.span("failing-span", tenant_id="brain1"): + raise ValueError("boom") + + events = tracker.queue.drain() + self.assertEqual([event.event_type for event in events], [ + TraceEventType.EXCEPTION, + TraceEventType.SLA_BREACH, + ]) + self.assertEqual(events[0].error_type, "ValueError") + self.assertIn("boom", events[0].message) + self.assertEqual(events[0].tenant_id, "brain1") + self.assertIsNotNone(events[0].stack_trace) + + def test_track_loop_records_expensive_loop_when_threshold_is_met(self): + tracker = TraceTracker( + queue=LocalTraceQueue(), + enabled=True, + expensive_loop_iterations=3, + ) + + self.assertEqual(list(tracker.track_loop("loop", range(3))), [0, 1, 2]) + + events = tracker.queue.drain() + self.assertEqual(len(events), 1) + self.assertEqual(events[0].event_type, TraceEventType.EXPENSIVE_LOOP) + self.assertEqual(events[0].metadata["iterations"], 3) + + def test_dispatch_once_sends_event_to_subscriber(self): + async def run_test(): + tracker = TraceTracker(queue=LocalTraceQueue(), enabled=True) + subscriber = MemorySubscriber() + tracker.subscribe(subscriber) + event = tracker.error("error") + + dispatched = await tracker.dispatch_once() + + self.assertIs(dispatched, event) + self.assertEqual(subscriber.events, [event]) + + asyncio.run(run_test()) + + +class TraceMiddlewareTests(unittest.TestCase): + def setUp(self): + self.original_tracer = tracing_middleware.tracer + self.tracker = TraceTracker(queue=LocalTraceQueue(), enabled=True) + tracing_middleware.tracer = self.tracker + + def tearDown(self): + tracing_middleware.tracer = self.original_tracer + + def test_middleware_records_server_errors_and_tenant_id(self): + async def error_route(_request): + return JSONResponse({"detail": "bad"}, status_code=503) + + app = Starlette( + routes=[Route("/bad", error_route)], + middleware=[ + Middleware( + TraceMiddleware, + service_name="test-api", + slow_request_ms=100000, + ) + ], + ) + + response = TestClient(app).get("/bad", headers={"X-Brain-ID": "tenant-a"}) + + self.assertEqual(response.status_code, 503) + events = self.tracker.queue.drain() + self.assertEqual(len(events), 1) + self.assertEqual(events[0].event_type, TraceEventType.ERROR) + self.assertEqual(events[0].tenant_id, "tenant-a") + self.assertEqual(events[0].service, "test-api") + self.assertEqual(events[0].status_code, 503) + + def test_middleware_records_sla_breach(self): + async def ok_route(_request): + return PlainTextResponse("ok") + + app = Starlette( + routes=[Route("/ok", ok_route)], + middleware=[ + Middleware( + TraceMiddleware, + service_name="test-api", + slow_request_ms=0, + ) + ], + ) + + response = TestClient(app).get("/ok?brain_id=tenant-b") + + self.assertEqual(response.status_code, 200) + events = self.tracker.queue.drain() + self.assertEqual(len(events), 1) + self.assertEqual(events[0].event_type, TraceEventType.SLA_BREACH) + self.assertEqual(events[0].tenant_id, "tenant-b") + self.assertEqual(events[0].operation, "GET /ok") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_tracing_workers_and_agents.py b/tests/test_tracing_workers_and_agents.py new file mode 100644 index 0000000..4913fa8 --- /dev/null +++ b/tests/test_tracing_workers_and_agents.py @@ -0,0 +1,140 @@ +import os +import unittest +from unittest.mock import patch + +from src.lib.tracing import LocalTraceQueue, TraceEventType, TraceTracker +from src.lib.tracing import workers as tracing_workers + + +class WorkerTracingTests(unittest.TestCase): + def setUp(self): + self.original_tracer = tracing_workers.tracer + self.tracker = TraceTracker(queue=LocalTraceQueue(), enabled=True) + tracing_workers.tracer = self.tracker + + def tearDown(self): + tracing_workers.tracer = self.original_tracer + tracing_workers._task_starts.clear() + tracing_workers._task_context_tokens.clear() + + def test_worker_signals_record_start_sla_and_failure(self): + tracing_workers._trace_task_prerun( + sender=type("Task", (), {"name": "demo.task"})(), + task_id="task-1", + args=({"brain_id": "tenant-1"},), + kwargs={}, + ) + tracing_workers._task_starts["task-1"] -= 2 + tracing_workers._trace_task_failure( + sender=type("Task", (), {"name": "demo.task"})(), + task_id="task-1", + exception=RuntimeError("boom"), + args=({"brain_id": "tenant-1"},), + kwargs={}, + ) + tracing_workers._trace_task_postrun( + sender=type("Task", (), {"name": "demo.task"})(), + task_id="task-1", + args=({"brain_id": "tenant-1"},), + kwargs={}, + state="FAILURE", + ) + + events = self.tracker.queue.drain() + self.assertEqual(events[0].event_type, TraceEventType.LATENCY) + self.assertEqual(events[0].tenant_id, "tenant-1") + self.assertEqual(events[1].event_type, TraceEventType.EXCEPTION) + self.assertEqual(events[1].error_type, "RuntimeError") + self.assertEqual(events[2].event_type, TraceEventType.SLA_BREACH) + self.assertEqual(events[2].operation, "demo.task") + + +class AgentLoopTracingTests(unittest.TestCase): + def test_agent_loop_records_model_exception(self): + from src.core.agents.core import invoke_loop + + tracker = TraceTracker(queue=LocalTraceQueue(), enabled=True) + + class FailingModel: + def invoke(self, *_args, **_kwargs): + raise RuntimeError("model down") + + class Agent: + class Tool: + name = "demo_tool" + description = "Demo tool" + args_schema = {} + + tools = [Tool()] + output_schema = None + model = FailingModel() + thinking = False + _tools_bound = False + system_prompt = "system" + debug = False + + def _model_requires_thought_signatures(self): + return False + + with patch.object(invoke_loop, "tracer", tracker): + with self.assertRaises(RuntimeError): + invoke_loop.run_invoke_loop( + Agent(), + [{"role": "user", "content": "hello"}], + {"metadata": {"brain_id": "tenant-2", "agent": "test-agent"}}, + ) + + events = tracker.queue.drain() + event_types = [event.event_type for event in events] + self.assertIn(TraceEventType.EXCEPTION, event_types) + self.assertTrue( + any(event.name == "agent.model.invoke.failed" for event in events) + ) + self.assertTrue(any(event.name == "agent.invoke_loop" for event in events)) + self.assertTrue(all( + event.tenant_id == "tenant-2" for event in events if event.tenant_id + )) + + def test_agent_loop_records_outer_loop_threshold(self): + from src.core.agents.core import invoke_loop + + tracker = TraceTracker(queue=LocalTraceQueue(), enabled=True) + + class SimpleModel: + def invoke(self, *_args, **_kwargs): + return {"content": "done"} + + class Agent: + tools = [] + output_schema = None + model = SimpleModel() + thinking = False + _tools_bound = False + system_prompt = "system" + debug = False + _get_effective_output_schema = None + + def _model_requires_thought_signatures(self): + return False + + with patch.object(invoke_loop, "tracer", tracker): + with patch.dict(os.environ, {"TRACE_AGENT_LOOP_ITERATIONS": "1"}): + result = invoke_loop.run_invoke_loop( + Agent(), + [{"role": "user", "content": "hello"}], + {"metadata": {"brain_id": "tenant-3", "agent": "test-agent"}}, + ) + + self.assertGreaterEqual(len(result["messages"]), 1) + events = tracker.queue.drain() + self.assertTrue( + any( + event.event_type == TraceEventType.EXPENSIVE_LOOP + and event.name == "agent.invoke_loop.outer_loop" + for event in events + ) + ) + + +if __name__ == "__main__": + unittest.main()