diff --git a/app/cli/run.py b/app/cli/run.py index c7bf948..f03db3c 100644 --- a/app/cli/run.py +++ b/app/cli/run.py @@ -28,7 +28,7 @@ from app.runtime.agent.agent import Agent from app.runtime.config.settings import cfg -from app.runtime.state.guardrails_config import GuardrailsConfigStore +from app.runtime.state.guardrails import GuardrailsConfigStore from app.runtime.state.memory import get_memory from app.runtime.state.sandbox_config import SandboxConfigStore from app.runtime.state.session_store import SessionStore diff --git a/app/runtime/agent/__init__.py b/app/runtime/agent/__init__.py index 67a213d..069c153 100644 --- a/app/runtime/agent/__init__.py +++ b/app/runtime/agent/__init__.py @@ -1,3 +1,14 @@ -"""Core Copilot SDK integration -- agent, sessions, tools, and prompts.""" +"""Core Copilot SDK integration -- agent, sessions, tools, and prompts. -__all__ = ["Agent", "auto_approve", "run_one_shot"] +Public submodules (import directly): + +- ``agent.agent`` -- ``Agent``, ``MAX_START_RETRIES`` +- ``agent.aitl`` -- ``AitlReviewer`` +- ``agent.event_handler`` -- ``EventHandler`` +- ``agent.hitl`` -- ``HitlInterceptor`` +- ``agent.one_shot`` -- ``run_one_shot``, ``auto_approve`` +- ``agent.phone_verify`` -- ``PhoneVerifier`` +- ``agent.policy_bridge`` -- ``build_engine``, ``config_to_yaml``, ... +- ``agent.prompt`` -- ``build_system_prompt``, ``load_soul``, ``TEMPLATES_DIR`` +- ``agent.tools`` -- ``get_all_tools``, ``ALL_TOOLS``, tool functions +""" diff --git a/app/runtime/agent/agent.py b/app/runtime/agent/agent.py index a51f7e9..0b42cce 100644 --- a/app/runtime/agent/agent.py +++ b/app/runtime/agent/agent.py @@ -14,7 +14,7 @@ from ..config.settings import cfg from ..sandbox import SandboxExecutor, SandboxToolInterceptor from ..services.otel import invoke_agent_span, set_span_attribute -from ..state.guardrails_config import GuardrailsConfigStore +from ..state.guardrails import GuardrailsConfigStore from ..state.mcp_config import McpConfigStore from .event_handler import EventHandler from .hitl import HitlInterceptor @@ -352,26 +352,20 @@ async def list_models(self) -> list[dict]: logger.warning("Failed to list models: %s", exc) return [] - def _build_session_config(self) -> dict[str, Any]: - sandbox_active = self._interceptor and self._sandbox and self._sandbox.enabled - # Always register the HITL hook when an interceptor exists so that - # guardrails config changes (enable/disable) take effect without - # requiring a session restart. The hook itself checks hitl_enabled - # at call time via resolve_action() which returns "allow" when off. - hitl_available = self._hitl is not None - - logger.info( - "[agent.config] building session config: " - "sandbox_active=%s hitl_available=%s hitl_enabled=%s", - sandbox_active, hitl_available, - self._guardrails.hitl_enabled if self._guardrails else "(no store)", + def _build_hooks(self) -> dict[str, Any]: + """Compose pre/post-tool-use hooks from active interceptors.""" + sandbox_active = ( + self._interceptor and self._sandbox and self._sandbox.enabled ) + hitl_available = self._hitl is not None if sandbox_active and hitl_available: hitl = self._hitl sandbox = self._interceptor - async def chained_pre_tool_use(input_data: dict, invocation: Any) -> dict: + async def chained_pre_tool_use( + input_data: dict, invocation: Any, + ) -> dict: logger.info( "[agent.hook] chained_pre_tool_use called: tool=%s", input_data.get("toolName", "?"), @@ -380,7 +374,9 @@ async def chained_pre_tool_use(input_data: dict, invocation: Any) -> dict: if result.get("permissionDecision") != "allow": logger.info("[agent.hook] hitl denied, skipping sandbox") return result - logger.info("[agent.hook] hitl allowed, proceeding to sandbox") + logger.info( + "[agent.hook] hitl allowed, proceeding to sandbox", + ) return await sandbox.on_pre_tool_use(input_data, invocation) hooks: dict[str, Any] = { @@ -399,30 +395,66 @@ async def chained_pre_tool_use(input_data: dict, invocation: Any) -> dict: logger.info("[agent.config] hooks: sandbox only") else: hooks = {"on_pre_tool_use": auto_approve} - logger.info("[agent.config] hooks: auto_approve (no hitl, no sandbox)") + logger.info( + "[agent.config] hooks: auto_approve (no hitl, no sandbox)", + ) + + return hooks + + def _build_session_config(self) -> dict[str, Any]: + """Assemble the full session configuration for the Copilot SDK.""" + sandbox_active = ( + self._interceptor and self._sandbox and self._sandbox.enabled + ) + logger.info( + "[agent.config] building session config: " + "sandbox_active=%s hitl_available=%s hitl_enabled=%s", + sandbox_active, self._hitl is not None, + self._guardrails.hitl_enabled if self._guardrails else "(no store)", + ) session_cfg: dict[str, Any] = { "model": cfg.copilot_model, "streaming": True, "tools": get_all_tools(), - "system_message": {"mode": "replace", "content": build_system_prompt()}, - "hooks": hooks, - "skill_directories": [str(cfg.builtin_skills_dir), str(cfg.user_skills_dir)], + "system_message": { + "mode": "replace", + "content": build_system_prompt(), + }, + "hooks": self._build_hooks(), + "skill_directories": [ + str(cfg.builtin_skills_dir), + str(cfg.user_skills_dir), + ], } if sandbox_active: - session_cfg["excluded_tools"] = ["create", "view", "edit", "grep", "glob"] + session_cfg["excluded_tools"] = [ + "create", "view", "edit", "grep", "glob", + ] try: - session_cfg["mcp_servers"] = McpConfigStore().get_enabled_servers() + session_cfg["mcp_servers"] = ( + McpConfigStore().get_enabled_servers() + ) except Exception: - logger.warning("Failed to load MCP config, using defaults", exc_info=True) + logger.warning( + "Failed to load MCP config, using defaults", + exc_info=True, + ) session_cfg["mcp_servers"] = { "playwright": { "type": "local", "command": "npx", - "args": ["-y", "@playwright/mcp@latest", "--browser", "chromium", "--headless", "--isolated"], - "env": {"PLAYWRIGHT_CHROMIUM_ARGS": "--no-sandbox --disable-setuid-sandbox"}, + "args": [ + "-y", "@playwright/mcp@latest", + "--browser", "chromium", + "--headless", "--isolated", + ], + "env": { + "PLAYWRIGHT_CHROMIUM_ARGS": + "--no-sandbox --disable-setuid-sandbox", + }, "tools": ["*"], }, } diff --git a/app/runtime/agent/hitl.py b/app/runtime/agent/hitl.py index 6cf3bad..4691789 100644 --- a/app/runtime/agent/hitl.py +++ b/app/runtime/agent/hitl.py @@ -8,56 +8,87 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any -from ..state.guardrails_config import GuardrailsConfigStore -from ..util.async_helpers import run_sync +from ..state.guardrails import GuardrailsConfigStore +from .hitl_channels import ( + apply_aitl_review, + apply_filter_check, + ask_bot_approval, + ask_chat_approval, + ask_phone_approval, +) if TYPE_CHECKING: - from ..services.prompt_shield import PromptShieldService + from ..services.security.prompt_shield import PromptShieldService from ..state.tool_activity_store import ToolActivityStore from .aitl import AitlReviewer from .phone_verify import PhoneVerifier logger = logging.getLogger(__name__) -_APPROVAL_TIMEOUT = 300.0 - _ALWAYS_APPROVED_TOOLS: frozenset[str] = frozenset({"report_intent"}) +_ALLOW: dict[str, str] = {"permissionDecision": "allow"} +_DENY: dict[str, str] = {"permissionDecision": "deny"} + class HitlInterceptor: + """Human-in-the-loop tool approval interceptor. + + Per-turn state (emit, model, session context) is bound via + :meth:`bind_turn` and released via :meth:`unbind_turn`. Persistent + wiring (phone verifier, AITL reviewer, prompt shield) is set once + during application startup. + """ def __init__(self, guardrails: GuardrailsConfigStore) -> None: self._guardrails = guardrails + + # -- per-turn state (bound/unbound each agent turn) ---------------- self._emit: Callable[[str, dict[str, Any]], None] | None = None self._bot_reply_fn: Callable[[str], Awaitable[None]] | None = None self._execution_context: str = "" self._model: str = "" self._session_id: str = "" + self._tool_activity: ToolActivityStore | None = None + + # -- persistent state ---------------------------------------------- self._pending: dict[str, asyncio.Future[bool]] = {} self._phone_verifier: PhoneVerifier | None = None self._aitl_reviewer: AitlReviewer | None = None self._prompt_shield: PromptShieldService | None = None - self._tool_activity: ToolActivityStore | None = None self._resolved_strategies: dict[str, list[str]] = {} self._last_shield_result: dict[str, Any] | None = None - def set_emit(self, emit: Callable[[str, dict[str, Any]], None]) -> None: + # -- per-turn lifecycle ------------------------------------------------ + + def bind_turn( + self, + *, + emit: Callable[[str, dict[str, Any]], None] | None = None, + bot_reply_fn: Callable[[str], Awaitable[None]] | None = None, + execution_context: str = "", + model: str = "", + session_id: str = "", + tool_activity: ToolActivityStore | None = None, + ) -> None: + """Bind per-turn state before an agent send.""" self._emit = emit + self._bot_reply_fn = bot_reply_fn + self._execution_context = execution_context + self._model = model + self._session_id = session_id + self._tool_activity = tool_activity - def clear_emit(self) -> None: + def unbind_turn(self) -> None: + """Clear per-turn state after an agent send completes.""" self._emit = None - - def set_bot_reply_fn(self, fn: Callable[[str], Awaitable[None]]) -> None: - self._bot_reply_fn = fn - - def clear_bot_reply_fn(self) -> None: self._bot_reply_fn = None + self._execution_context = "" + self._model = "" + self._session_id = "" + self._tool_activity = None - def set_execution_context(self, context: str) -> None: - self._execution_context = context - - def set_model(self, model: str) -> None: - self._model = model + # -- persistent wiring ------------------------------------------------- def set_phone_verifier(self, verifier: PhoneVerifier) -> None: self._phone_verifier = verifier @@ -68,12 +99,6 @@ def set_aitl_reviewer(self, reviewer: AitlReviewer) -> None: def set_prompt_shield(self, shield: PromptShieldService) -> None: self._prompt_shield = shield - def set_tool_activity(self, store: ToolActivityStore) -> None: - self._tool_activity = store - - def set_session_id(self, session_id: str) -> None: - self._session_id = session_id - def pop_resolved_strategy(self, tool_name: str) -> str: queue = self._resolved_strategies.get(tool_name) if not queue: @@ -136,6 +161,7 @@ async def on_pre_tool_use(self, input_data: dict, invocation: Any) -> dict: return result async def _evaluate_tool(self, input_data: dict, tool_name: str) -> dict: + """Evaluate a tool invocation against the guardrails policy.""" self._last_shield_result = None call_id = input_data.get("toolCallId") or str(uuid.uuid4())[:8] @@ -151,10 +177,6 @@ async def _evaluate_tool(self, input_data: dict, tool_name: str) -> dict: self._model, mcp_server or "(none)", self._guardrails.hitl_enabled, ) - logger.info( - "[hitl.hook] input_data keys=%s", - list(input_data.keys()), - ) strategy = self._guardrails.resolve_action( tool_name, @@ -167,22 +189,20 @@ async def _evaluate_tool(self, input_data: dict, tool_name: str) -> dict: strategy, tool_name, ) + # Terminal strategies if strategy == "allow": logger.info("[hitl.hook] ALLOW tool=%s call_id=%s", tool_name, call_id) - return {"permissionDecision": "allow"} + return _ALLOW if strategy == "deny": - logger.info("[hitl.hook] DENY tool=%s call_id=%s", tool_name, call_id) - self._resolved_strategies.setdefault(tool_name, []).append("deny") - if self._emit: - self._emit("tool_denied", { - "call_id": call_id, - "tool": tool_name, - "reason": "Denied by guardrail rule", - }) - return {"permissionDecision": "deny"} - - if self._prompt_shield and self._prompt_shield.configured and strategy != "filter": + return self._make_deny(call_id, tool_name) + + # Pre-filter: run Prompt Shield before non-filter strategies + if ( + self._prompt_shield + and self._prompt_shield.configured + and strategy != "filter" + ): shield_result = await self._apply_filter(call_id, tool_name, args_str) if shield_result is not None: logger.info( @@ -192,53 +212,117 @@ async def _evaluate_tool(self, input_data: dict, tool_name: str) -> dict: ) return shield_result - if strategy == "aitl": - self._resolved_strategies.setdefault(tool_name, []).append("aitl") - if self._aitl_reviewer: - result = await self._apply_aitl(call_id, tool_name, args_str) - if result is not None: - return result - logger.warning( - "[hitl] AITL requested but unavailable, falling back to interactive: tool=%s", - tool_name, - ) + # Strategy-specific handler + result = await self._dispatch_strategy( + strategy, call_id, tool_name, args_str, + ) + if result is not None: + return result - if strategy == "filter": - self._resolved_strategies.setdefault(tool_name, []).append("filter") - if self._prompt_shield: - result = await self._apply_filter(call_id, tool_name, args_str) - if result is not None: - return result - logger.info( - "[hitl.hook] shield passed, ALLOW tool=%s call_id=%s", - tool_name, call_id, - ) - return {"permissionDecision": "allow"} - logger.warning( - "[hitl] no prompt shield available, allowing tool=%s (Content Safety not deployed)", - tool_name, - ) - self._last_shield_result = { - "result": "skipped", - "detail": "Content Safety not deployed", - "elapsed_ms": None, - } - return {"permissionDecision": "allow"} + # Fallback: interactive approval + return await self._route_interactive( + call_id, tool_name, args_str, mcp_server, + ) + def _make_deny(self, call_id: str, tool_name: str) -> dict: + """Build a deny response and emit an event.""" + logger.info("[hitl.hook] DENY tool=%s call_id=%s", tool_name, call_id) + self._resolved_strategies.setdefault(tool_name, []).append("deny") + if self._emit: + self._emit("tool_denied", { + "call_id": call_id, + "tool": tool_name, + "reason": "Denied by guardrail rule", + }) + return dict(_DENY) + + async def _dispatch_strategy( + self, + strategy: str, + call_id: str, + tool_name: str, + args_str: str, + ) -> dict | None: + """Delegate to a strategy-specific handler. + + Returns a decision dict, or ``None`` to fall through to + interactive approval. + """ + if strategy == "aitl": + return await self._handle_aitl(call_id, tool_name, args_str) + if strategy == "filter": + return await self._handle_filter(call_id, tool_name, args_str) if strategy == "pitl": - self._resolved_strategies.setdefault(tool_name, []).append("pitl") - if self._phone_verifier: - logger.info("[hitl.hook] PITL routing to phone: tool=%s", tool_name) - return await self._ask_phone(call_id, tool_name, args_str) - logger.warning( - "[hitl] PITL requested but phone verifier unavailable, " - "falling back to chat: tool=%s", tool_name, + return await self._handle_pitl(call_id, tool_name, args_str) + return None + + async def _handle_aitl( + self, call_id: str, tool_name: str, args_str: str, + ) -> dict | None: + """AI-in-the-loop review.""" + self._resolved_strategies.setdefault(tool_name, []).append("aitl") + if self._aitl_reviewer: + return await self._apply_aitl(call_id, tool_name, args_str) + logger.warning( + "[hitl] AITL requested but unavailable, " + "falling back to interactive: tool=%s", + tool_name, + ) + return None + + async def _handle_filter( + self, call_id: str, tool_name: str, args_str: str, + ) -> dict | None: + """Content-safety filter.""" + self._resolved_strategies.setdefault(tool_name, []).append("filter") + if self._prompt_shield: + result = await self._apply_filter(call_id, tool_name, args_str) + if result is not None: + return result + logger.info( + "[hitl.hook] shield passed, ALLOW tool=%s call_id=%s", + tool_name, call_id, ) + return dict(_ALLOW) + logger.warning( + "[hitl] no prompt shield available, allowing tool=%s " + "(Content Safety not deployed)", + tool_name, + ) + self._last_shield_result = { + "result": "skipped", + "detail": "Content Safety not deployed", + "elapsed_ms": None, + } + return dict(_ALLOW) + + async def _handle_pitl( + self, call_id: str, tool_name: str, args_str: str, + ) -> dict | None: + """Phone-in-the-loop verification.""" + self._resolved_strategies.setdefault(tool_name, []).append("pitl") + if self._phone_verifier: + logger.info("[hitl.hook] PITL routing to phone: tool=%s", tool_name) + return await self._ask_phone(call_id, tool_name, args_str) + logger.warning( + "[hitl] PITL requested but phone verifier unavailable, " + "falling back to chat: tool=%s", + tool_name, + ) + return None + async def _route_interactive( + self, + call_id: str, + tool_name: str, + args_str: str, + mcp_server: str, + ) -> dict: + """Route to the best available interactive approval channel.""" logger.info( - "[hitl.hook] interactive approval needed: tool=%s strategy=%s " + "[hitl.hook] interactive approval needed: tool=%s " "has_emit=%s has_bot_reply=%s has_phone=%s", - tool_name, strategy, + tool_name, self._emit is not None, self._bot_reply_fn is not None, self._phone_verifier is not None, @@ -251,26 +335,32 @@ async def _evaluate_tool(self, input_data: dict, tool_name: str) -> dict: ) if channel == "phone" and self._phone_verifier: - logger.info("[hitl.hook] routing to phone channel: tool=%s", tool_name) + logger.info( + "[hitl.hook] routing to phone channel: tool=%s", tool_name, + ) self._resolved_strategies.setdefault(tool_name, []).append("pitl") return await self._ask_phone(call_id, tool_name, args_str) if self._bot_reply_fn: - logger.info("[hitl.hook] routing to bot channel: tool=%s", tool_name) + logger.info( + "[hitl.hook] routing to bot channel: tool=%s", tool_name, + ) self._resolved_strategies.setdefault(tool_name, []).append("hitl") return await self._ask_bot_channel(call_id, tool_name, args_str) if self._emit: - logger.info("[hitl.hook] routing to web chat: tool=%s", tool_name) + logger.info( + "[hitl.hook] routing to web chat: tool=%s", tool_name, + ) self._resolved_strategies.setdefault(tool_name, []).append("hitl") return await self._ask_chat(call_id, tool_name, args_str) logger.error( - "[hitl.hook] NO APPROVAL CHANNEL available (no bot_reply_fn, " - "no emit) -- denying tool=%s call_id=%s to avoid silent hang", + "[hitl.hook] NO APPROVAL CHANNEL available -- " + "denying tool=%s call_id=%s to avoid silent hang", tool_name, call_id, ) - return {"permissionDecision": "deny"} + return dict(_DENY) def resolve_approval(self, call_id: str, approved: bool) -> bool: future = self._pending.get(call_id) @@ -299,211 +389,55 @@ async def _ask_chat(self, call_id: str, tool_name: str, args_str: str) -> dict: "denying tool=%s immediately", tool_name, ) return {"permissionDecision": "deny"} - - logger.info( - "[hitl.chat] sending approval_request via WebSocket: " - "tool=%s call_id=%s", - tool_name, call_id, - ) - self._emit("approval_request", { - "call_id": call_id, - "tool": tool_name, - "arguments": args_str, - }) - logger.info("[hitl.chat] approval_request emitted, waiting for response...") - - loop = asyncio.get_running_loop() - future: asyncio.Future[bool] = loop.create_future() - self._pending[call_id] = future - - try: - approved = await asyncio.wait_for(future, timeout=_APPROVAL_TIMEOUT) - except asyncio.TimeoutError: - logger.warning("[hitl] approval timed out: call_id=%s tool=%s", call_id, tool_name) - approved = False - finally: - self._pending.pop(call_id, None) - - decision = "allow" if approved else "deny" - logger.info( - "[hitl.chat] decision: tool=%s call_id=%s approved=%s decision=%s", - tool_name, call_id, approved, decision, + return await ask_chat_approval( + emit=self._emit, + pending=self._pending, + call_id=call_id, + tool_name=tool_name, + args_str=args_str, ) - if self._emit: - self._emit("approval_resolved", { - "call_id": call_id, - "tool": tool_name, - "approved": approved, - }) - return {"permissionDecision": decision} async def _ask_bot_channel( self, call_id: str, tool_name: str, args_str: str, ) -> dict: assert self._bot_reply_fn is not None - truncated = args_str if len(args_str) <= 200 else args_str[:197] + "..." - confirmation_msg = ( - f"The agent wants to use the tool **{tool_name}**.\n\n" - f"Arguments: `{truncated}`\n\n" - f"Reply **y** to approve or anything else to deny." - ) - logger.info( - "[hitl] bot-channel approval request: tool=%s call_id=%s", - tool_name, call_id, - ) - try: - await self._bot_reply_fn(confirmation_msg) - except Exception: - logger.exception("[hitl] failed to send bot approval message: call_id=%s", call_id) - return {"permissionDecision": "deny"} - - loop = asyncio.get_running_loop() - future: asyncio.Future[bool] = loop.create_future() - self._pending[call_id] = future - - try: - approved = await asyncio.wait_for(future, timeout=_APPROVAL_TIMEOUT) - except asyncio.TimeoutError: - logger.warning("[hitl] bot approval timed out: call_id=%s tool=%s", call_id, tool_name) - approved = False - finally: - self._pending.pop(call_id, None) - - decision = "allow" if approved else "deny" - logger.info( - "[hitl] bot-channel decision: tool=%s call_id=%s decision=%s", - tool_name, call_id, decision, - ) - - outcome_msg = ( - f"Tool **{tool_name}** {'approved' if approved else 'denied'}." + return await ask_bot_approval( + bot_reply_fn=self._bot_reply_fn, + pending=self._pending, + call_id=call_id, + tool_name=tool_name, + args_str=args_str, ) - try: - await self._bot_reply_fn(outcome_msg) - except Exception: - logger.exception("[hitl] failed to send bot outcome message: call_id=%s", call_id) - - return {"permissionDecision": decision} async def _ask_phone(self, call_id: str, tool_name: str, args_str: str) -> dict: assert self._phone_verifier is not None - logger.info("[hitl] phone verification: tool=%s call_id=%s", tool_name, call_id) - - if self._emit: - self._emit("phone_verification_started", { - "call_id": call_id, - "tool": tool_name, - "arguments": args_str, - }) - - try: - approved = await self._phone_verifier.request_verification( - call_id=call_id, - tool_name=tool_name, - tool_args=args_str, - ) - except Exception: - logger.exception("[hitl] phone verification failed: call_id=%s", call_id) - approved = False - - decision = "allow" if approved else "deny" - logger.info("[hitl] phone decision: tool=%s call_id=%s decision=%s", tool_name, call_id, decision) - - if self._emit: - self._emit("phone_verification_complete", { - "call_id": call_id, - "tool": tool_name, - "approved": approved, - }) - - return {"permissionDecision": decision} + return await ask_phone_approval( + phone_verifier=self._phone_verifier, + emit=self._emit, + call_id=call_id, + tool_name=tool_name, + args_str=args_str, + ) async def _apply_aitl(self, call_id: str, tool_name: str, args_str: str) -> dict | None: assert self._aitl_reviewer is not None - if self._emit: - self._emit("aitl_review_started", { - "call_id": call_id, - "tool": tool_name, - }) - try: - approved, reason = await self._aitl_reviewer.review( - tool_name=tool_name, - arguments=args_str, - ) - except Exception: - logger.exception("[hitl] AITL review error: call_id=%s", call_id) - return None - - if self._emit: - self._emit("aitl_review_complete", { - "call_id": call_id, - "tool": tool_name, - "approved": approved, - "reason": reason, - }) - - decision = "allow" if approved else "deny" - logger.info( - "[hitl] AITL decision: tool=%s call_id=%s decision=%s reason=%s", - tool_name, call_id, decision, reason, + return await apply_aitl_review( + aitl_reviewer=self._aitl_reviewer, + emit=self._emit, + call_id=call_id, + tool_name=tool_name, + args_str=args_str, ) - return {"permissionDecision": decision} async def _apply_filter(self, call_id: str, tool_name: str, args_str: str) -> dict | None: assert self._prompt_shield is not None - import time as _time - - t0 = _time.monotonic() - try: - result = await run_sync(self._prompt_shield.check, args_str) - except Exception: - elapsed_ms = (_time.monotonic() - t0) * 1000 - logger.exception("[hitl] Prompt Shield error: call_id=%s", call_id) - self._last_shield_result = { - "result": "error", - "detail": "Shield check raised an exception", - "elapsed_ms": round(elapsed_ms, 1), - } - if self._tool_activity: - self._tool_activity.update_shield_result( - call_id=call_id, shield_result="error", - shield_detail="Shield check raised an exception", - shield_elapsed_ms=round(elapsed_ms, 1), - ) - return None - - elapsed_ms = (_time.monotonic() - t0) * 1000 - shield_status = "attack" if result.attack_detected else "clean" - self._last_shield_result = { - "result": shield_status, - "detail": result.detail, - "elapsed_ms": round(elapsed_ms, 1), - } - - if self._tool_activity: - self._tool_activity.update_shield_result( - call_id=call_id, - shield_result=shield_status, - shield_detail=result.detail, - shield_elapsed_ms=round(elapsed_ms, 1), - ) - - if result.attack_detected: - logger.info( - "[hitl] Prompt Shield denied: tool=%s call_id=%s detail=%s elapsed=%.0fms", - tool_name, call_id, result.detail, elapsed_ms, - ) - if self._emit: - self._emit("tool_denied", { - "call_id": call_id, - "tool": tool_name, - "reason": "Blocked by content filter", - "shield_detail": result.detail, - }) - return {"permissionDecision": "deny"} - - logger.info( - "[hitl] Prompt Shield passed: tool=%s call_id=%s elapsed=%.0fms", - tool_name, call_id, elapsed_ms, + decision, shield_info = await apply_filter_check( + prompt_shield=self._prompt_shield, + tool_activity=self._tool_activity, + emit=self._emit, + call_id=call_id, + tool_name=tool_name, + args_str=args_str, ) - return None + self._last_shield_result = shield_info + return decision diff --git a/app/runtime/agent/hitl_channels.py b/app/runtime/agent/hitl_channels.py new file mode 100644 index 0000000..f5bc063 --- /dev/null +++ b/app/runtime/agent/hitl_channels.py @@ -0,0 +1,275 @@ +"""Approval-channel implementations for the HITL interceptor.""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +from ..util.async_helpers import run_sync + +if TYPE_CHECKING: + from ..services.security.prompt_shield import PromptShieldService + from ..state.tool_activity_store import ToolActivityStore + from .aitl import AitlReviewer + from .phone_verify import PhoneVerifier + +logger = logging.getLogger(__name__) + +_APPROVAL_TIMEOUT = 300.0 + + +async def ask_chat_approval( + *, + emit: Callable[[str, dict[str, Any]], None], + pending: dict[str, asyncio.Future[bool]], + call_id: str, + tool_name: str, + args_str: str, + timeout: float = _APPROVAL_TIMEOUT, +) -> dict[str, str]: + """Request approval via the WebSocket chat channel.""" + logger.info( + "[hitl.chat] sending approval_request via WebSocket: " + "tool=%s call_id=%s", + tool_name, call_id, + ) + emit("approval_request", { + "call_id": call_id, + "tool": tool_name, + "arguments": args_str, + }) + logger.info("[hitl.chat] approval_request emitted, waiting for response...") + + loop = asyncio.get_running_loop() + future: asyncio.Future[bool] = loop.create_future() + pending[call_id] = future + + try: + approved = await asyncio.wait_for(future, timeout=timeout) + except asyncio.TimeoutError: + logger.warning("[hitl] approval timed out: call_id=%s tool=%s", call_id, tool_name) + approved = False + finally: + pending.pop(call_id, None) + + decision = "allow" if approved else "deny" + logger.info( + "[hitl.chat] decision: tool=%s call_id=%s approved=%s decision=%s", + tool_name, call_id, approved, decision, + ) + emit("approval_resolved", { + "call_id": call_id, + "tool": tool_name, + "approved": approved, + }) + return {"permissionDecision": decision} + + +async def ask_bot_approval( + *, + bot_reply_fn: Callable[[str], Awaitable[None]], + pending: dict[str, asyncio.Future[bool]], + call_id: str, + tool_name: str, + args_str: str, + timeout: float = _APPROVAL_TIMEOUT, +) -> dict[str, str]: + """Request approval via a messaging-bot reply channel.""" + truncated = args_str if len(args_str) <= 200 else args_str[:197] + "..." + confirmation_msg = ( + f"The agent wants to use the tool **{tool_name}**.\n\n" + f"Arguments: `{truncated}`\n\n" + f"Reply **y** to approve or anything else to deny." + ) + logger.info( + "[hitl] bot-channel approval request: tool=%s call_id=%s", + tool_name, call_id, + ) + try: + await bot_reply_fn(confirmation_msg) + except Exception: + logger.exception("[hitl] failed to send bot approval message: call_id=%s", call_id) + return {"permissionDecision": "deny"} + + loop = asyncio.get_running_loop() + future: asyncio.Future[bool] = loop.create_future() + pending[call_id] = future + + try: + approved = await asyncio.wait_for(future, timeout=timeout) + except asyncio.TimeoutError: + logger.warning("[hitl] bot approval timed out: call_id=%s tool=%s", call_id, tool_name) + approved = False + finally: + pending.pop(call_id, None) + + decision = "allow" if approved else "deny" + logger.info( + "[hitl] bot-channel decision: tool=%s call_id=%s decision=%s", + tool_name, call_id, decision, + ) + + outcome_msg = ( + f"Tool **{tool_name}** {'approved' if approved else 'denied'}." + ) + try: + await bot_reply_fn(outcome_msg) + except Exception: + logger.exception("[hitl] failed to send bot outcome message: call_id=%s", call_id) + + return {"permissionDecision": decision} + + +async def ask_phone_approval( + *, + phone_verifier: PhoneVerifier, + emit: Callable[[str, dict[str, Any]], None] | None, + call_id: str, + tool_name: str, + args_str: str, +) -> dict[str, str]: + """Request approval via phone verification.""" + logger.info("[hitl] phone verification: tool=%s call_id=%s", tool_name, call_id) + + if emit: + emit("phone_verification_started", { + "call_id": call_id, + "tool": tool_name, + "arguments": args_str, + }) + + try: + approved = await phone_verifier.request_verification( + call_id=call_id, + tool_name=tool_name, + tool_args=args_str, + ) + except Exception: + logger.exception("[hitl] phone verification failed: call_id=%s", call_id) + approved = False + + decision = "allow" if approved else "deny" + logger.info("[hitl] phone decision: tool=%s call_id=%s decision=%s", tool_name, call_id, decision) + + if emit: + emit("phone_verification_complete", { + "call_id": call_id, + "tool": tool_name, + "approved": approved, + }) + + return {"permissionDecision": decision} + + +async def apply_aitl_review( + *, + aitl_reviewer: AitlReviewer, + emit: Callable[[str, dict[str, Any]], None] | None, + call_id: str, + tool_name: str, + args_str: str, +) -> dict[str, str] | None: + """Run an AI-in-the-loop review. Returns decision or ``None`` on error.""" + if emit: + emit("aitl_review_started", { + "call_id": call_id, + "tool": tool_name, + }) + try: + approved, reason = await aitl_reviewer.review( + tool_name=tool_name, + arguments=args_str, + ) + except Exception: + logger.exception("[hitl] AITL review error: call_id=%s", call_id) + return None + + if emit: + emit("aitl_review_complete", { + "call_id": call_id, + "tool": tool_name, + "approved": approved, + "reason": reason, + }) + + decision = "allow" if approved else "deny" + logger.info( + "[hitl] AITL decision: tool=%s call_id=%s decision=%s reason=%s", + tool_name, call_id, decision, reason, + ) + return {"permissionDecision": decision} + + +async def apply_filter_check( + *, + prompt_shield: PromptShieldService, + tool_activity: ToolActivityStore | None, + emit: Callable[[str, dict[str, Any]], None] | None, + call_id: str, + tool_name: str, + args_str: str, +) -> tuple[dict[str, str] | None, dict[str, Any]]: + """Run a Prompt Shield content-safety check. + + Returns ``(decision | None, shield_result_info)``. When ``decision`` + is ``None`` the content passed the filter and the caller should + continue with the next step. + """ + import time as _time + + t0 = _time.monotonic() + try: + result = await run_sync(prompt_shield.check, args_str) + except Exception: + elapsed_ms = (_time.monotonic() - t0) * 1000 + logger.exception("[hitl] Prompt Shield error: call_id=%s", call_id) + shield_info: dict[str, Any] = { + "result": "error", + "detail": "Shield check raised an exception", + "elapsed_ms": round(elapsed_ms, 1), + } + if tool_activity: + tool_activity.update_shield_result( + call_id=call_id, shield_result="error", + shield_detail="Shield check raised an exception", + shield_elapsed_ms=round(elapsed_ms, 1), + ) + return None, shield_info + + elapsed_ms = (_time.monotonic() - t0) * 1000 + shield_status = "attack" if result.attack_detected else "clean" + shield_info = { + "result": shield_status, + "detail": result.detail, + "elapsed_ms": round(elapsed_ms, 1), + } + + if tool_activity: + tool_activity.update_shield_result( + call_id=call_id, + shield_result=shield_status, + shield_detail=result.detail, + shield_elapsed_ms=round(elapsed_ms, 1), + ) + + if result.attack_detected: + logger.info( + "[hitl] Prompt Shield denied: tool=%s call_id=%s detail=%s elapsed=%.0fms", + tool_name, call_id, result.detail, elapsed_ms, + ) + if emit: + emit("tool_denied", { + "call_id": call_id, + "tool": tool_name, + "reason": "Blocked by content filter", + "shield_detail": result.detail, + }) + return {"permissionDecision": "deny"}, shield_info + + logger.info( + "[hitl] Prompt Shield passed: tool=%s call_id=%s elapsed=%.0fms", + tool_name, call_id, elapsed_ms, + ) + return None, shield_info diff --git a/app/runtime/agent/prompt.py b/app/runtime/agent/prompt.py index 9b2c146..6a1c2d4 100644 --- a/app/runtime/agent/prompt.py +++ b/app/runtime/agent/prompt.py @@ -6,11 +6,11 @@ from ..config.settings import cfg -_TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "templates" +TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "templates" def _load_template(name: str) -> str: - return (_TEMPLATES_DIR / name).read_text() + return (TEMPLATES_DIR / name).read_text() def load_soul() -> str: diff --git a/app/runtime/agent/tools.py b/app/runtime/agent/tools.py deleted file mode 100644 index dd724f8..0000000 --- a/app/runtime/agent/tools.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Custom tools exposed to the Copilot agent.""" - -from __future__ import annotations - -import json -import logging -import threading -import urllib.error -import urllib.request - -from copilot import define_tool -from pydantic import BaseModel, Field - -from ..config.settings import cfg -from ..messaging.cards import CARD_TOOLS - -logger = logging.getLogger(__name__) - - -class ScheduleTaskParams(BaseModel): - description: str = Field(description="Human-readable description of the task") - prompt: str = Field(description="The prompt to send to the agent when this task fires") - cron: str | None = Field( - default=None, - description=( - "Cron expression for recurring tasks (minute hour day month weekday). " - "Minimum interval is every 1 hour. " - "Example: '0 9 * * *' for every day at 09:00 UTC." - ), - ) - run_at: str | None = Field( - default=None, - description="ISO datetime for one-shot tasks (e.g. '2026-02-07T14:00:00')", - ) - - -class CancelTaskParams(BaseModel): - task_id: str = Field(description="ID of the scheduled task to cancel") - - -class MakeCallParams(BaseModel): - prompt: str | None = Field( - default=None, - description="Optional custom prompt / instructions for the voice AI agent.", - ) - opening_message: str | None = Field( - default=None, - description="Optional opening message the AI should speak when the call connects.", - ) - - -class SearchMemoriesParams(BaseModel): - query: str = Field( - description="Natural language search query to find relevant memories.", - ) - top: int = Field(default=5, description="Maximum number of results to return (1-10).") - - -@define_tool( - description=( - "Schedule a future task. Provide either a cron expression for recurring " - "tasks (minimum every 1 hour) or a run_at datetime for one-shot tasks." - ) -) -def schedule_task(params: ScheduleTaskParams) -> dict: - from ..scheduler import get_scheduler - - scheduler = get_scheduler() - logger.info( - "[schedule_task] called: desc=%r, cron=%r, run_at=%r, prompt=%r", - params.description, params.cron, params.run_at, params.prompt[:80] if params.prompt else None, - ) - try: - task = scheduler.add( - description=params.description, - prompt=params.prompt, - cron=params.cron, - run_at=params.run_at, - ) - logger.info( - "[schedule_task] created task id=%s, run_at=%s, cron=%s, notify_cb=%s", - task.id, task.run_at, task.cron, - "SET" if scheduler._notify else "NOT SET", - ) - return {"id": task.id, "description": task.description, "status": "scheduled"} - except ValueError as exc: - logger.warning("[schedule_task] rejected: %s", exc) - return {"error": str(exc)} - - -@define_tool(description="Cancel a scheduled task by ID.") -def cancel_task(params: CancelTaskParams) -> str: - from ..scheduler import get_scheduler - - scheduler = get_scheduler() - return f"Task {params.task_id} cancelled." if scheduler.remove(params.task_id) else f"Task {params.task_id} not found." - - -@define_tool(description="List all scheduled tasks with their ID, description, schedule, and status.") -def list_scheduled_tasks() -> list[dict]: - from ..scheduler import get_scheduler - - return [ - { - "id": t.id, - "description": t.description, - "cron": t.cron, - "run_at": t.run_at, - "enabled": t.enabled, - "last_run": t.last_run, - } - for t in get_scheduler().list_tasks() - ] - - -@define_tool( - description=( - "Initiate an outbound voice call to the user. ALWAYS call this tool " - "when the user asks to be called -- the target phone number is managed " - "internally and you do not need to ask the user for it." - ) -) -def make_voice_call(params: MakeCallParams) -> dict: - target = cfg.voice_target_number - if not target: - return { - "status": "error", - "message": ( - "No target phone number configured yet. " - "Ask the user to run: /phone (e.g. /phone +14155551234)" - ), - } - url = f"http://127.0.0.1:{cfg.admin_port}/api/voice/call" - body: dict[str, str] = {"number": target} - if params.prompt: - body["prompt"] = params.prompt - if params.opening_message: - body["opening_message"] = params.opening_message - payload = json.dumps(body).encode("utf-8") - headers: dict[str, str] = {"Content-Type": "application/json"} - if cfg.admin_secret: - headers["Authorization"] = f"Bearer {cfg.admin_secret}" - req = urllib.request.Request(url, data=payload, headers=headers, method="POST") - - def _fire() -> None: - try: - with urllib.request.urlopen(req, timeout=30) as resp: - logger.info("Voice call API responded: %s", resp.read().decode()[:200]) - except Exception as exc: - logger.error("Voice call API request failed: %s", exc) - - threading.Thread(target=_fire, daemon=True).start() - return {"status": "ok", "message": "Call triggered"} - - -@define_tool( - description=( - "Search through indexed memories using Azure AI Search with vector " - "embeddings. Only works when Foundry IQ is enabled." - ) -) -def search_memories_tool(params: SearchMemoriesParams) -> dict: - from ..services.foundry_iq import search_memories - from ..state.foundry_iq_config import get_foundry_iq_config - - config = get_foundry_iq_config() - if not config.enabled or not config.is_configured: - return {"status": "skipped", "message": "Foundry IQ is not enabled."} - - try: - top = min(max(params.top, 1), 10) - data = search_memories(params.query, top, config) - if data.get("status") == "ok" and data.get("results"): - formatted = [ - { - "title": r.get("title", ""), - "content": r.get("content", ""), - "source_type": r.get("source_type", ""), - "date": r.get("date", ""), - } - for r in data["results"] - ] - return {"status": "ok", "results": formatted, "count": len(formatted)} - return {"status": "ok", "results": [], "count": 0, "message": "No matching memories found."} - except Exception as exc: - return {"status": "error", "message": f"Memory search failed: {exc}"} - - -ALL_TOOLS = [schedule_task, cancel_task, list_scheduled_tasks, make_voice_call] + CARD_TOOLS - - -def get_all_tools() -> list: - from ..state.foundry_iq_config import get_foundry_iq_config - - tools = list(ALL_TOOLS) - try: - fiq = get_foundry_iq_config() - if fiq.enabled and fiq.is_configured: - tools.append(search_memories_tool) - except Exception: - pass - return tools diff --git a/app/runtime/agent/tools/__init__.py b/app/runtime/agent/tools/__init__.py new file mode 100644 index 0000000..cfe850f --- /dev/null +++ b/app/runtime/agent/tools/__init__.py @@ -0,0 +1,42 @@ +"""Custom tools exposed to the Copilot agent.""" + +from .cards import CARD_TOOLS +from .memory import SearchMemoriesParams, search_memories_tool +from .scheduler import ( + CancelTaskParams, + ScheduleTaskParams, + cancel_task, + list_scheduled_tasks, + schedule_task, +) +from .voice import MakeCallParams, make_voice_call + +ALL_TOOLS = [schedule_task, cancel_task, list_scheduled_tasks, make_voice_call] + CARD_TOOLS + + +def get_all_tools() -> list: + from ...state.foundry_iq_config import get_foundry_iq_config + + tools = list(ALL_TOOLS) + try: + fiq = get_foundry_iq_config() + if fiq.enabled and fiq.is_configured: + tools.append(search_memories_tool) + except Exception: + pass + return tools + + +__all__ = [ + "ALL_TOOLS", + "CancelTaskParams", + "MakeCallParams", + "ScheduleTaskParams", + "SearchMemoriesParams", + "cancel_task", + "get_all_tools", + "list_scheduled_tasks", + "make_voice_call", + "schedule_task", + "search_memories_tool", +] diff --git a/app/runtime/agent/tools/cards.py b/app/runtime/agent/tools/cards.py new file mode 100644 index 0000000..bfbba5f --- /dev/null +++ b/app/runtime/agent/tools/cards.py @@ -0,0 +1,114 @@ +"""Card tool definitions for the Copilot agent. + +Wraps the card queue and attachment builders from the messaging layer +into ``@define_tool`` functions that the LLM can invoke. +""" + +from __future__ import annotations + +import json + +from copilot import define_tool +from pydantic import BaseModel, Field + +from ...messaging.cards import ( + _adaptive_card_attachment, + _default_queue, + _hero_card_attachment, + _thumbnail_card_attachment, +) + + +# -- parameter models ------------------------------------------------------ + + +class AdaptiveCardParams(BaseModel): + card_json: str = Field(description="The Adaptive Card payload as a JSON string.") + fallback_text: str = Field(default="", description="Plain-text fallback for unsupported clients.") + + +class HeroCardParams(BaseModel): + title: str = Field(default="", description="Card title") + subtitle: str = Field(default="", description="Card subtitle") + text: str = Field(default="", description="Card body text") + image_url: str | None = Field(default=None, description="URL of the card image") + buttons: str = Field(default="[]", description="JSON array of button objects.") + + +class ThumbnailCardParams(BaseModel): + title: str = Field(default="", description="Card title") + subtitle: str = Field(default="", description="Card subtitle") + text: str = Field(default="", description="Card body text") + image_url: str | None = Field(default=None, description="URL of the thumbnail image") + buttons: str = Field(default="[]", description="JSON array of button objects.") + + +class CardCarouselParams(BaseModel): + cards_json: str = Field(description="JSON array of card objects.") + + +# -- tool definitions ------------------------------------------------------ + + +@define_tool(description="Send an Adaptive Card to the user with rich layout support.") +def send_adaptive_card(params: AdaptiveCardParams) -> dict: + try: + card_data = json.loads(params.card_json) + except json.JSONDecodeError as exc: + return {"error": f"Invalid JSON: {exc}"} + if not isinstance(card_data, dict): + return {"error": "card_json must be a JSON object."} + _default_queue.enqueue(_adaptive_card_attachment(card_data)) + return {"status": "queued", "fallback_text": params.fallback_text, "elements": len(card_data.get("body", []))} + + +@define_tool(description="Send a Hero Card with large image, title, and action buttons.") +def send_hero_card(params: HeroCardParams) -> dict: + try: + buttons = json.loads(params.buttons) if params.buttons else [] + except json.JSONDecodeError: + buttons = [] + _default_queue.enqueue(_hero_card_attachment(title=params.title, subtitle=params.subtitle, text=params.text, image_url=params.image_url, buttons=buttons)) + return {"status": "queued", "title": params.title} + + +@define_tool(description="Send a Thumbnail Card with smaller image and compact layout.") +def send_thumbnail_card(params: ThumbnailCardParams) -> dict: + try: + buttons = json.loads(params.buttons) if params.buttons else [] + except json.JSONDecodeError: + buttons = [] + _default_queue.enqueue(_thumbnail_card_attachment(title=params.title, subtitle=params.subtitle, text=params.text, image_url=params.image_url, buttons=buttons)) + return {"status": "queued", "title": params.title} + + +@define_tool(description="Send multiple cards as a horizontal carousel.") +def send_card_carousel(params: CardCarouselParams) -> dict: + try: + cards = json.loads(params.cards_json) + except json.JSONDecodeError as exc: + return {"error": f"Invalid JSON: {exc}"} + if not isinstance(cards, list): + return {"error": "cards_json must be a JSON array."} + + count = 0 + for card in cards: + card_type = card.pop("type", "hero") + if card_type == "adaptive": + _default_queue.enqueue(_adaptive_card_attachment(card)) + elif card_type == "thumbnail": + buttons = card.get("buttons", []) + if isinstance(buttons, str): + buttons = json.loads(buttons) + _default_queue.enqueue(_thumbnail_card_attachment(title=card.get("title", ""), subtitle=card.get("subtitle", ""), text=card.get("text", ""), image_url=card.get("image_url"), buttons=buttons)) + else: + buttons = card.get("buttons", []) + if isinstance(buttons, str): + buttons = json.loads(buttons) + _default_queue.enqueue(_hero_card_attachment(title=card.get("title", ""), subtitle=card.get("subtitle", ""), text=card.get("text", ""), image_url=card.get("image_url"), buttons=buttons)) + count += 1 + + return {"status": "queued", "card_count": count} + + +CARD_TOOLS = [send_adaptive_card, send_hero_card, send_thumbnail_card, send_card_carousel] diff --git a/app/runtime/agent/tools/memory.py b/app/runtime/agent/tools/memory.py new file mode 100644 index 0000000..304fe8a --- /dev/null +++ b/app/runtime/agent/tools/memory.py @@ -0,0 +1,51 @@ +"""Memory search tool -- Foundry IQ vector search over indexed memories.""" + +from __future__ import annotations + +from copilot import define_tool +from pydantic import BaseModel, Field + + +class SearchMemoriesParams(BaseModel): + query: str = Field( + description="Natural language search query to find relevant memories.", + ) + top: int = Field(default=5, description="Maximum number of results to return (1-10).") + + +@define_tool( + description=( + "Search through indexed memories using Azure AI Search with vector " + "embeddings. Only works when Foundry IQ is enabled." + ) +) +def search_memories_tool(params: SearchMemoriesParams) -> dict: + from ...services.foundry_iq import search_memories + from ...state.foundry_iq_config import get_foundry_iq_config + + config = get_foundry_iq_config() + if not config.enabled or not config.is_configured: + return {"status": "skipped", "message": "Foundry IQ is not enabled."} + + try: + top = min(max(params.top, 1), 10) + data = search_memories(params.query, top, config) + if data.get("status") == "ok" and data.get("results"): + formatted = [ + { + "title": r.get("title", ""), + "content": r.get("content", ""), + "source_type": r.get("source_type", ""), + "date": r.get("date", ""), + } + for r in data["results"] + ] + return {"status": "ok", "results": formatted, "count": len(formatted)} + return { + "status": "ok", + "results": [], + "count": 0, + "message": "No matching memories found.", + } + except Exception as exc: + return {"status": "error", "message": f"Memory search failed: {exc}"} diff --git a/app/runtime/agent/tools/scheduler.py b/app/runtime/agent/tools/scheduler.py new file mode 100644 index 0000000..344a226 --- /dev/null +++ b/app/runtime/agent/tools/scheduler.py @@ -0,0 +1,94 @@ +"""Scheduler tools -- create, cancel, and list scheduled tasks.""" + +from __future__ import annotations + +import logging + +from copilot import define_tool +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + + +class ScheduleTaskParams(BaseModel): + description: str = Field(description="Human-readable description of the task") + prompt: str = Field(description="The prompt to send to the agent when this task fires") + cron: str | None = Field( + default=None, + description=( + "Cron expression for recurring tasks (minute hour day month weekday). " + "Minimum interval is every 1 hour. " + "Example: '0 9 * * *' for every day at 09:00 UTC." + ), + ) + run_at: str | None = Field( + default=None, + description="ISO datetime for one-shot tasks (e.g. '2026-02-07T14:00:00')", + ) + + +class CancelTaskParams(BaseModel): + task_id: str = Field(description="ID of the scheduled task to cancel") + + +@define_tool( + description=( + "Schedule a future task. Provide either a cron expression for recurring " + "tasks (minimum every 1 hour) or a run_at datetime for one-shot tasks." + ) +) +def schedule_task(params: ScheduleTaskParams) -> dict: + from ...scheduler import get_scheduler + + scheduler = get_scheduler() + logger.info( + "[schedule_task] called: desc=%r, cron=%r, run_at=%r, prompt=%r", + params.description, params.cron, params.run_at, params.prompt[:80] if params.prompt else None, + ) + try: + task = scheduler.add( + description=params.description, + prompt=params.prompt, + cron=params.cron, + run_at=params.run_at, + ) + logger.info( + "[schedule_task] created task id=%s, run_at=%s, cron=%s, notify_cb=%s", + task.id, task.run_at, task.cron, + "SET" if scheduler._notify else "NOT SET", + ) + return {"id": task.id, "description": task.description, "status": "scheduled"} + except ValueError as exc: + logger.warning("[schedule_task] rejected: %s", exc) + return {"error": str(exc)} + + +@define_tool(description="Cancel a scheduled task by ID.") +def cancel_task(params: CancelTaskParams) -> str: + from ...scheduler import get_scheduler + + scheduler = get_scheduler() + return ( + f"Task {params.task_id} cancelled." + if scheduler.remove(params.task_id) + else f"Task {params.task_id} not found." + ) + + +@define_tool( + description="List all scheduled tasks with their ID, description, schedule, and status.", +) +def list_scheduled_tasks() -> list[dict]: + from ...scheduler import get_scheduler + + return [ + { + "id": t.id, + "description": t.description, + "cron": t.cron, + "run_at": t.run_at, + "enabled": t.enabled, + "last_run": t.last_run, + } + for t in get_scheduler().list_tasks() + ] diff --git a/app/runtime/agent/tools/voice.py b/app/runtime/agent/tools/voice.py new file mode 100644 index 0000000..db88143 --- /dev/null +++ b/app/runtime/agent/tools/voice.py @@ -0,0 +1,67 @@ +"""Voice call tool -- initiate outbound calls to the user.""" + +from __future__ import annotations + +import json +import logging +import threading +import urllib.error +import urllib.request + +from copilot import define_tool +from pydantic import BaseModel, Field + +from ...config.settings import cfg + +logger = logging.getLogger(__name__) + + +class MakeCallParams(BaseModel): + prompt: str | None = Field( + default=None, + description="Optional custom prompt / instructions for the voice AI agent.", + ) + opening_message: str | None = Field( + default=None, + description="Optional opening message the AI should speak when the call connects.", + ) + + +@define_tool( + description=( + "Initiate an outbound voice call to the user. ALWAYS call this tool " + "when the user asks to be called -- the target phone number is managed " + "internally and you do not need to ask the user for it." + ) +) +def make_voice_call(params: MakeCallParams) -> dict: + target = cfg.voice_target_number + if not target: + return { + "status": "error", + "message": ( + "No target phone number configured yet. " + "Ask the user to run: /phone (e.g. /phone +14155551234)" + ), + } + url = f"http://127.0.0.1:{cfg.admin_port}/api/voice/call" + body: dict[str, str] = {"number": target} + if params.prompt: + body["prompt"] = params.prompt + if params.opening_message: + body["opening_message"] = params.opening_message + payload = json.dumps(body).encode("utf-8") + headers: dict[str, str] = {"Content-Type": "application/json"} + if cfg.admin_secret: + headers["Authorization"] = f"Bearer {cfg.admin_secret}" + req = urllib.request.Request(url, data=payload, headers=headers, method="POST") + + def _fire() -> None: + try: + with urllib.request.urlopen(req, timeout=30) as resp: + logger.info("Voice call API responded: %s", resp.read().decode()[:200]) + except Exception as exc: + logger.error("Voice call API request failed: %s", exc) + + threading.Thread(target=_fire, daemon=True).start() + return {"status": "ok", "message": "Call triggered"} diff --git a/app/runtime/env_cli.py b/app/runtime/env_cli.py index 239ffd6..75dbacb 100644 --- a/app/runtime/env_cli.py +++ b/app/runtime/env_cli.py @@ -7,8 +7,8 @@ import sys from dataclasses import asdict -from .services.azure import AzureCLI -from .services.misconfig_checker import MisconfigChecker +from .services.cloud.azure import AzureCLI +from .services.security.misconfig_checker import MisconfigChecker from .services.resource_tracker import ResourceTracker from .state.deploy_state import DeployStateStore diff --git a/app/runtime/media/__init__.py b/app/runtime/media/__init__.py index 75d37a6..51505bb 100644 --- a/app/runtime/media/__init__.py +++ b/app/runtime/media/__init__.py @@ -1,10 +1,11 @@ """Media handling -- type classification, download, and outgoing extraction.""" from .classify import EXTENSION_TO_MIME, classify -from .incoming import build_media_prompt, download_attachment, extract_outgoing_attachments +from .incoming import build_media_prompt, download_attachment from .outgoing import ( MAX_OUTGOING_FILE_BYTES, collect_pending_outgoing, + extract_outgoing_attachments, move_attachments_to_error, read_error_details, ) diff --git a/app/runtime/media/incoming.py b/app/runtime/media/incoming.py index 9cdf929..19721f1 100644 --- a/app/runtime/media/incoming.py +++ b/app/runtime/media/incoming.py @@ -2,11 +2,9 @@ from __future__ import annotations -import base64 import logging import mimetypes import os -import re import uuid from pathlib import Path @@ -14,7 +12,7 @@ from ..config.settings import cfg from ..util.async_helpers import run_sync -from .classify import EXTENSION_TO_MIME, classify +from .classify import classify logger = logging.getLogger(__name__) @@ -70,42 +68,3 @@ def build_media_prompt(user_text: str, saved_files: list[dict]) -> str: ] block = "\n".join(descriptions) return f"{block}\n\n{user_text}" if user_text else block - - -_FILE_PATH_RE = re.compile( - r"(?:^|\s)(/[\w./-]+\.(?:" + "|".join(ext.lstrip(".") for ext in EXTENSION_TO_MIME) + r"))\b", - re.IGNORECASE, -) - - -def extract_outgoing_attachments(response: str) -> list[Attachment]: - matches = _FILE_PATH_RE.findall(response) - attachments: list[Attachment] = [] - seen: set[str] = set() - - for file_path in matches: - if file_path in seen: - continue - seen.add(file_path) - - p = Path(file_path) - if not p.is_file(): - continue - - content_type = EXTENSION_TO_MIME.get(p.suffix.lower()) - if not content_type: - continue - - try: - data = base64.b64encode(p.read_bytes()).decode("ascii") - attachments.append( - Attachment( - name=p.name, - content_type=content_type, - content_url=f"data:{content_type};base64,{data}", - ) - ) - except Exception: - logger.exception("Failed to read media file %s", file_path) - - return attachments diff --git a/app/runtime/media/outgoing.py b/app/runtime/media/outgoing.py index 324924e..97fb736 100644 --- a/app/runtime/media/outgoing.py +++ b/app/runtime/media/outgoing.py @@ -5,6 +5,7 @@ import base64 import logging import mimetypes +import re import shutil import uuid from pathlib import Path @@ -200,3 +201,45 @@ def read_error_details() -> list[dict]: except OSError: continue return details + + +# -- inline response attachment extraction ---------------------------------- + +_FILE_PATH_RE = re.compile( + r"(?:^|\s)(/[\w./-]+\.(?:" + "|".join(ext.lstrip(".") for ext in EXTENSION_TO_MIME) + r"))\b", + re.IGNORECASE, +) + + +def extract_outgoing_attachments(response: str) -> list[Attachment]: + """Scan LLM response text for file paths and return base64-encoded attachments.""" + matches = _FILE_PATH_RE.findall(response) + attachments: list[Attachment] = [] + seen: set[str] = set() + + for file_path in matches: + if file_path in seen: + continue + seen.add(file_path) + + p = Path(file_path) + if not p.is_file(): + continue + + content_type = EXTENSION_TO_MIME.get(p.suffix.lower()) + if not content_type: + continue + + try: + data = base64.b64encode(p.read_bytes()).decode("ascii") + attachments.append( + Attachment( + name=p.name, + content_type=content_type, + content_url=f"data:{content_type};base64,{data}", + ) + ) + except Exception: + logger.exception("Failed to read media file %s", file_path) + + return attachments diff --git a/app/runtime/messaging/__init__.py b/app/runtime/messaging/__init__.py index 260f623..0564649 100644 --- a/app/runtime/messaging/__init__.py +++ b/app/runtime/messaging/__init__.py @@ -1,11 +1,12 @@ """Channel messaging pipeline -- bot handler, commands, cards, and formatting.""" +from .cards import CardQueue, drain_pending_cards +from .formatting import markdown_to_telegram, strip_markdown +from .proactive import ConversationReferenceStore, send_proactive_message + __all__ = [ "CardQueue", - "CommandDispatcher", "ConversationReferenceStore", - "MessageProcessor", - "Bot", "drain_pending_cards", "markdown_to_telegram", "send_proactive_message", diff --git a/app/runtime/messaging/cards.py b/app/runtime/messaging/cards.py index 703d015..7dc5a89 100644 --- a/app/runtime/messaging/cards.py +++ b/app/runtime/messaging/cards.py @@ -6,7 +6,6 @@ from __future__ import annotations -import json import logging import threading from typing import Any @@ -19,8 +18,8 @@ HeroCard, ThumbnailCard, ) -from copilot import define_tool -from pydantic import BaseModel, Field + +from ..util.singletons import register_singleton logger = logging.getLogger(__name__) @@ -52,6 +51,14 @@ def drain_pending_cards() -> list[Attachment]: return _default_queue.drain() +def _reset_default_queue() -> None: + """Drain the global card queue (for test isolation).""" + _default_queue.drain() + + +register_singleton(_reset_default_queue) + + # -- attachment builders --------------------------------------------------- @@ -148,98 +155,3 @@ def _serialize_model(obj: Any) -> Any: def _to_camel(snake: str) -> str: parts = snake.split("_") return parts[0] + "".join(p.capitalize() for p in parts[1:]) - - -# -- parameter models ------------------------------------------------------ - - -class AdaptiveCardParams(BaseModel): - card_json: str = Field(description="The Adaptive Card payload as a JSON string.") - fallback_text: str = Field(default="", description="Plain-text fallback for unsupported clients.") - - -class HeroCardParams(BaseModel): - title: str = Field(default="", description="Card title") - subtitle: str = Field(default="", description="Card subtitle") - text: str = Field(default="", description="Card body text") - image_url: str | None = Field(default=None, description="URL of the card image") - buttons: str = Field(default="[]", description="JSON array of button objects.") - - -class ThumbnailCardParams(BaseModel): - title: str = Field(default="", description="Card title") - subtitle: str = Field(default="", description="Card subtitle") - text: str = Field(default="", description="Card body text") - image_url: str | None = Field(default=None, description="URL of the thumbnail image") - buttons: str = Field(default="[]", description="JSON array of button objects.") - - -class CardCarouselParams(BaseModel): - cards_json: str = Field(description="JSON array of card objects.") - - -# -- tool definitions ------------------------------------------------------ - - -@define_tool(description="Send an Adaptive Card to the user with rich layout support.") -def send_adaptive_card(params: AdaptiveCardParams) -> dict: - try: - card_data = json.loads(params.card_json) - except json.JSONDecodeError as exc: - return {"error": f"Invalid JSON: {exc}"} - if not isinstance(card_data, dict): - return {"error": "card_json must be a JSON object."} - _default_queue.enqueue(_adaptive_card_attachment(card_data)) - return {"status": "queued", "fallback_text": params.fallback_text, "elements": len(card_data.get("body", []))} - - -@define_tool(description="Send a Hero Card with large image, title, and action buttons.") -def send_hero_card(params: HeroCardParams) -> dict: - try: - buttons = json.loads(params.buttons) if params.buttons else [] - except json.JSONDecodeError: - buttons = [] - _default_queue.enqueue(_hero_card_attachment(title=params.title, subtitle=params.subtitle, text=params.text, image_url=params.image_url, buttons=buttons)) - return {"status": "queued", "title": params.title} - - -@define_tool(description="Send a Thumbnail Card with smaller image and compact layout.") -def send_thumbnail_card(params: ThumbnailCardParams) -> dict: - try: - buttons = json.loads(params.buttons) if params.buttons else [] - except json.JSONDecodeError: - buttons = [] - _default_queue.enqueue(_thumbnail_card_attachment(title=params.title, subtitle=params.subtitle, text=params.text, image_url=params.image_url, buttons=buttons)) - return {"status": "queued", "title": params.title} - - -@define_tool(description="Send multiple cards as a horizontal carousel.") -def send_card_carousel(params: CardCarouselParams) -> dict: - try: - cards = json.loads(params.cards_json) - except json.JSONDecodeError as exc: - return {"error": f"Invalid JSON: {exc}"} - if not isinstance(cards, list): - return {"error": "cards_json must be a JSON array."} - - count = 0 - for card in cards: - card_type = card.pop("type", "hero") - if card_type == "adaptive": - _default_queue.enqueue(_adaptive_card_attachment(card)) - elif card_type == "thumbnail": - buttons = card.get("buttons", []) - if isinstance(buttons, str): - buttons = json.loads(buttons) - _default_queue.enqueue(_thumbnail_card_attachment(title=card.get("title", ""), subtitle=card.get("subtitle", ""), text=card.get("text", ""), image_url=card.get("image_url"), buttons=buttons)) - else: - buttons = card.get("buttons", []) - if isinstance(buttons, str): - buttons = json.loads(buttons) - _default_queue.enqueue(_hero_card_attachment(title=card.get("title", ""), subtitle=card.get("subtitle", ""), text=card.get("text", ""), image_url=card.get("image_url"), buttons=buttons)) - count += 1 - - return {"status": "queued", "card_count": count} - - -CARD_TOOLS = [send_adaptive_card, send_hero_card, send_thumbnail_card, send_card_carousel] diff --git a/app/runtime/messaging/commands.py b/app/runtime/messaging/commands.py deleted file mode 100644 index b0c4efe..0000000 --- a/app/runtime/messaging/commands.py +++ /dev/null @@ -1,582 +0,0 @@ -"""Shared slash-command dispatcher. - -Centralises all slash-command logic so both the Bot Framework handler -and the WebSocket chat handler share a single implementation. -""" - -from __future__ import annotations - -import time -import uuid -from collections.abc import Awaitable, Callable -from dataclasses import dataclass -from typing import Any, Protocol - -from ..agent.agent import Agent -from ..config.settings import cfg -from ..registries.plugins import get_plugin_registry -from ..registries.skills import get_registry as get_skill_registry -from ..scheduler import get_scheduler -from ..state.infra_config import InfraConfigStore -from ..state.mcp_config import McpConfigStore -from ..state.profile import load_profile -from ..state.session_store import SessionStore - -BOOT_TIME = time.monotonic() - -ReplyFn = Callable[[str], Awaitable[None]] - - -class ChannelContext(Protocol): - @property - def conversation_refs_count(self) -> int: ... - - @property - def connected_channels(self) -> set[str]: ... - - @property - def conversation_refs(self) -> list[Any]: ... - - -@dataclass -class CommandContext: - text: str - reply: ReplyFn - channel: str - channel_ctx: ChannelContext | None = None - - -class CommandDispatcher: - _EXACT_COMMANDS: dict[str, str] = { - "/new": "_cmd_new", - "/status": "_cmd_status", - "/skills": "_cmd_skills", - "/session": "_cmd_session", - "/channels": "_cmd_channels", - "/clear": "_cmd_clear", - "/help": "_cmd_help", - "/plugins": "_cmd_plugins", - "/mcp": "_cmd_mcp", - "/schedules": "_cmd_schedules", - "/sessions": "_cmd_sessions", - "/profile": "_cmd_profile", - "/config": "_cmd_config", - "/preflight": "_cmd_preflight", - "/call": "_cmd_call", - "/models": "_cmd_models", - "/change": "_cmd_change", - } - - _PREFIX_COMMANDS: tuple[tuple[str, str], ...] = ( - ("/removeskill", "_cmd_removeskill"), - ("/addskill", "_cmd_addskill"), - ("/model", "_cmd_model"), - ("/plugin", "_cmd_plugin"), - ("/mcp", "_cmd_mcp"), - ("/schedule", "_cmd_schedule"), - ("/sessions", "_cmd_sessions_sub"), - ("/session", "_cmd_session_sub"), - ("/config", "_cmd_config"), - ("/phone", "_cmd_phone"), - ("/lockdown", "_cmd_lockdown"), - ) - - def __init__( - self, - agent: Agent, - session_store: SessionStore | None = None, - infra: InfraConfigStore | None = None, - ) -> None: - self._agent = agent - self._session_store = session_store - self._infra = infra - - @property - def infra(self) -> InfraConfigStore: - if self._infra is None: - self._infra = InfraConfigStore() - return self._infra - - async def try_handle( - self, - text: str, - reply: ReplyFn, - channel: str = "web", - *, - channel_ctx: ChannelContext | None = None, - ) -> bool: - lower = text.lower() - ctx = CommandContext(text=text, reply=reply, channel=channel, channel_ctx=channel_ctx) - - handler_name = self._EXACT_COMMANDS.get(lower) - if handler_name: - await getattr(self, handler_name)(ctx) - return True - - for prefix, handler_name in self._PREFIX_COMMANDS: - if lower.startswith(prefix): - await getattr(self, handler_name)(ctx) - return True - - return False - - async def _cmd_new(self, ctx: CommandContext) -> None: - await self._agent.new_session() - if self._session_store: - self._session_store.start_session(uuid.uuid4().hex[:12], model=cfg.copilot_model) - await ctx.reply("New session started.") - - async def _cmd_model(self, ctx: CommandContext) -> None: - parts = ctx.text.split(maxsplit=1) - if len(parts) < 2: - await ctx.reply(f"Current model: {cfg.copilot_model}\n\nUsage: /model ") - return - new_model = parts[1].strip() - old_model = cfg.copilot_model - cfg.write_env(COPILOT_MODEL=new_model) - await self._agent.new_session() - if self._session_store: - self._session_store.start_session(uuid.uuid4().hex[:12], model=new_model) - await ctx.reply(f"Model switched: {old_model} -> {new_model}\nNew session started.") - - async def _cmd_models(self, ctx: CommandContext) -> None: - models = await self._agent.list_models() - if not models: - await ctx.reply("No models available.") - return - current = cfg.copilot_model - lines = ["Available Models", ""] - for m in models: - marker = " *" if m["id"] == current else "" - cost = f" ({m['billing_multiplier']}x)" if m.get("billing_multiplier", 1.0) != 1.0 else "" - reasoning = f" [reasoning: {', '.join(m['reasoning_efforts'])}]" if m.get("reasoning_efforts") else "" - policy = m.get("policy", "enabled") - if policy != "enabled": - lines.append(f" {m['id']}{marker}{cost} ({policy})") - else: - lines.append(f" {m['id']}{marker}{cost}{reasoning}") - lines.append(f"\nCurrent: {current}\nUse /model to switch.") - await ctx.reply("\n".join(lines)) - - async def _cmd_status(self, ctx: CommandContext) -> None: - uptime_seconds = int(time.monotonic() - BOOT_TIME) - hours, remainder = divmod(uptime_seconds, 3600) - minutes, seconds = divmod(remainder, 60) - - sched = get_scheduler() - tasks = sched.list_tasks() - active_tasks = [t for t in tasks if t.enabled] - total_reqs = sum(self._agent.request_counts.values()) - - lines = [ - "System Status", - f" Model: {cfg.copilot_model}", - f" Uptime: {hours}h {minutes}m {seconds}s", - f" Total requests: {total_reqs}", - ] - for model, count in sorted(self._agent.request_counts.items()): - lines.append(f" {model}: {count}") - if ctx.channel_ctx is not None: - channels = ctx.channel_ctx.connected_channels - lines.append(f" Connected channels: {', '.join(sorted(channels)) or 'none'}") - lines.append(f" Conversation refs: {ctx.channel_ctx.conversation_refs_count}") - lines.append(f" Scheduled tasks: {len(active_tasks)} active / {len(tasks)} total") - lines.append(f" Data dir: {cfg.data_dir}") - await ctx.reply("\n".join(lines)) - - async def _cmd_skills(self, ctx: CommandContext) -> None: - skills: list[str] = [] - if cfg.user_skills_dir.is_dir(): - for d in sorted(cfg.user_skills_dir.iterdir()): - if d.is_dir() and (d / "SKILL.md").exists(): - skills.append(d.name) - lines = [f"Skills ({len(skills)}):"] + [f" - {name}" for name in skills] - if not skills: - lines.append(" (none)") - await ctx.reply("\n".join(lines)) - - async def _cmd_session(self, ctx: CommandContext) -> None: - lines = [ - "Session Info", - f" Active: {'yes' if self._agent.has_session else 'no'}", - f" Model: {cfg.copilot_model}", - " Playwright MCP: enabled", - ] - await ctx.reply("\n".join(lines)) - - async def _cmd_channels(self, ctx: CommandContext) -> None: - lines = ["Channel Configuration\n"] - tg = self.infra.channels.telegram - if tg.token: - masked = tg.token[:8] + "..." + tg.token[-4:] if len(tg.token) > 12 else "***" - lines.append(f"Telegram:\n Token: {masked}\n Whitelist: {tg.whitelist or '(none)'}") - else: - lines.append("Telegram: not configured") - lines.append(f"\nBot Framework:\n App ID: {cfg.bot_app_id[:8] + '...' if cfg.bot_app_id else 'not set'}") - lines.append(f" Tenant: {cfg.bot_app_tenant_id[:8] + '...' if cfg.bot_app_tenant_id else 'not set'}") - lines.append(f" Admin secret: {'set' if cfg.admin_secret else 'not set'}") - if ctx.channel_ctx is not None: - refs = ctx.channel_ctx.conversation_refs - lines.append(f"\nActive Conversations ({len(refs)}):") - for r in refs: - user_name = r.user.name if r.user else "?" - lines.append(f" - {r.channel_id}: {user_name}") - await ctx.reply("\n".join(lines)) - - async def _cmd_clear(self, ctx: CommandContext) -> None: - cleared = 0 - if cfg.memory_dir.is_dir(): - for f in cfg.memory_dir.rglob("*"): - if f.is_file(): - f.unlink() - cleared += 1 - await ctx.reply(f"Memory cleared ({cleared} files removed).") - - async def _cmd_addskill(self, ctx: CommandContext) -> None: - parts = ctx.text.split(maxsplit=1) - if len(parts) < 2: - reg = get_skill_registry() - try: - catalog = await reg.fetch_catalog() - available = [s for s in catalog if not s.installed] - if available: - lines = [f"Available skills ({len(available)}):"] - for s in available: - desc = f" - {s.description}" if s.description else "" - lines.append(f" {s.name}{desc} [{s.source}]") - lines.append("\nUsage: /addskill ") - else: - lines = ["All catalog skills already installed.", "Usage: /addskill "] - except Exception as exc: - lines = [f"Failed to fetch catalog: {exc}", "Usage: /addskill "] - await ctx.reply("\n".join(lines)) - return - name = parts[1].strip() - reg = get_skill_registry() - await ctx.reply(f"Installing skill '{name}'...") - ok = await reg.install(name) - await ctx.reply(f"Skill '{name}' installed." if ok else f"Failed to install skill '{name}'.") - - async def _cmd_removeskill(self, ctx: CommandContext) -> None: - parts = ctx.text.split(maxsplit=1) - if len(parts) < 2: - reg = get_skill_registry() - installed = reg.list_installed() - if installed: - lines = [f"Installed skills ({len(installed)}):"] + [f" {s.name}" for s in installed] - lines.append("\nUsage: /removeskill ") - else: - lines = ["No skills installed.", "Usage: /removeskill "] - await ctx.reply("\n".join(lines)) - return - name = parts[1].strip() - reg = get_skill_registry() - removed = reg.remove(name) - await ctx.reply(f"Skill '{name}' removed." if removed else f"Skill '{name}' not found.") - - async def _cmd_plugins(self, ctx: CommandContext) -> None: - reg = get_plugin_registry() - plugins = reg.list_plugins() - if not plugins: - await ctx.reply("No plugins found.") - return - lines = [f"Plugins ({len(plugins)}):"] - for p in plugins: - icon = "+" if p.get("enabled") else "-" - desc = f" - {p['description']}" if p.get("description") else "" - lines.append(f" [{icon}] {p['id']}{desc} ({p.get('skill_count', 0)} skills)") - lines.append("\nUsage: /plugin enable , /plugin disable ") - await ctx.reply("\n".join(lines)) - - async def _cmd_plugin(self, ctx: CommandContext) -> None: - parts = ctx.text.split() - if len(parts) < 3: - await ctx.reply("Usage: /plugin enable or /plugin disable ") - return - action, plugin_id = parts[1].lower(), parts[2].strip() - reg = get_plugin_registry() - if action == "enable": - result = reg.enable_plugin(plugin_id) - await ctx.reply(f"Plugin '{plugin_id}' enabled." if result else f"Plugin '{plugin_id}' not found.") - elif action == "disable": - result = reg.disable_plugin(plugin_id) - await ctx.reply(f"Plugin '{plugin_id}' disabled." if result else f"Plugin '{plugin_id}' not found.") - else: - await ctx.reply(f"Unknown action '{action}'.") - - async def _cmd_mcp(self, ctx: CommandContext) -> None: - parts = ctx.text.split() - store = McpConfigStore() - if len(parts) == 1: - servers = store.list_servers() - if not servers: - await ctx.reply("No MCP servers configured.") - return - lines = [f"MCP Servers ({len(servers)}):"] - for s in servers: - icon = "+" if s.get("enabled") else "-" - builtin = " [builtin]" if s.get("builtin") else "" - lines.append(f" [{icon}] {s['name']} ({s.get('type', '?')}){builtin}") - if s.get("description"): - lines.append(f" {s['description']}") - await ctx.reply("\n".join(lines)) - return - - action = parts[1].lower() - if action == "add": - if len(parts) < 4: - await ctx.reply("Usage: /mcp add ") - return - try: - store.add_server(parts[2], "http", url=parts[3]) - await ctx.reply(f"MCP server '{parts[2]}' added. Start a /new session to activate.") - except ValueError as exc: - await ctx.reply(f"Error: {exc}") - elif action == "remove": - if len(parts) < 3: - await ctx.reply("Usage: /mcp remove ") - return - try: - ok = store.remove_server(parts[2]) - await ctx.reply(f"MCP server '{parts[2]}' removed." if ok else f"MCP server '{parts[2]}' not found.") - except ValueError as exc: - await ctx.reply(f"Error: {exc}") - elif action in ("enable", "disable"): - if len(parts) < 3: - await ctx.reply(f"Usage: /mcp {action} ") - return - ok = store.set_enabled(parts[2], action == "enable") - await ctx.reply(f"MCP server '{parts[2]}' {action}d." if ok else f"MCP server '{parts[2]}' not found.") - else: - await ctx.reply(f"Unknown MCP action '{action}'.") - - async def _cmd_schedules(self, ctx: CommandContext) -> None: - sched = get_scheduler() - tasks = sched.list_tasks() - if not tasks: - await ctx.reply("No scheduled tasks.\n\nUsage: /schedule add ") - return - lines = [f"Scheduled Tasks ({len(tasks)}):"] - for t in tasks: - icon = "+" if t.enabled else "-" - schedule = t.cron or (f"once at {t.run_at}" if t.run_at else "?") - lines.append(f" [{icon}] {t.id} - {t.description}") - lines.append(f" Schedule: {schedule} | Last run: {t.last_run[:16] if t.last_run else 'never'}") - await ctx.reply("\n".join(lines)) - - async def _cmd_schedule(self, ctx: CommandContext) -> None: - parts = ctx.text.split() - if len(parts) < 2: - await ctx.reply("Usage: /schedule add or /schedule remove ") - return - action = parts[1].lower() - sched = get_scheduler() - if action == "add": - if len(parts) < 8: - await ctx.reply("Usage: /schedule add ") - return - cron = " ".join(parts[2:7]) - prompt = " ".join(parts[7:]) - try: - task = sched.add(description=prompt[:60], prompt=prompt, cron=cron) - await ctx.reply(f"Scheduled task created:\n ID: {task.id}\n Cron: {cron}\n Prompt: {prompt}") - except ValueError as exc: - await ctx.reply(f"Error: {exc}") - elif action == "remove": - if len(parts) < 3: - await ctx.reply("Usage: /schedule remove ") - return - ok = sched.remove(parts[2]) - await ctx.reply(f"Task '{parts[2]}' removed." if ok else f"Task '{parts[2]}' not found.") - - async def _cmd_sessions(self, ctx: CommandContext) -> None: - if not self._session_store: - await ctx.reply("Session store not available.") - return - sessions = self._session_store.list_sessions() - if not sessions: - await ctx.reply("No recorded sessions.") - return - stats = self._session_store.get_session_stats() - lines = [f"Sessions ({stats['total_sessions']} total, {stats['total_messages']} messages)", ""] - for s in sessions[:10]: - started = s.get("started_at", "?")[:16] - preview = s.get("first_message", "")[:50] - lines.append(f" {s['id']} {started} {s.get('model', '?')} ({s.get('message_count', 0)} msgs)") - if len(sessions) > 10: - lines.append(f" ... and {len(sessions) - 10} more") - await ctx.reply("\n".join(lines)) - - async def _cmd_sessions_sub(self, ctx: CommandContext) -> None: - parts = ctx.text.split() - if len(parts) >= 2 and parts[1].lower() == "clear": - if not self._session_store: - await ctx.reply("Session store not available.") - return - count = self._session_store.clear_all() - await ctx.reply(f"All sessions cleared ({count} deleted).") - else: - await self._cmd_sessions(ctx) - - async def _cmd_session_sub(self, ctx: CommandContext) -> None: - parts = ctx.text.split() - if len(parts) >= 3 and parts[1].lower() == "delete": - if not self._session_store: - await ctx.reply("Session store not available.") - return - ok = self._session_store.delete_session(parts[2]) - await ctx.reply(f"Session '{parts[2]}' deleted." if ok else f"Session '{parts[2]}' not found.") - else: - await self._cmd_session(ctx) - - async def _cmd_profile(self, ctx: CommandContext) -> None: - profile = load_profile() - lines = [ - "Agent Profile", - f" Name: {profile.get('name') or '(not set)'}", - f" Location: {profile.get('location') or '(not set)'}", - f" Emotional state: {profile.get('emotional_state', 'neutral')}", - ] - prefs = profile.get("preferences", {}) - if prefs: - lines.append(" Preferences:") - for k, v in prefs.items(): - lines.append(f" {k}: {v}") - await ctx.reply("\n".join(lines)) - - async def _cmd_config(self, ctx: CommandContext) -> None: - parts = ctx.text.split(maxsplit=2) - if len(parts) == 1: - lines = [ - "Runtime Configuration", - f" Model: {cfg.copilot_model}", - f" Admin port: {cfg.admin_port}", - f" Bot port: {cfg.bot_port}", - f" Data dir: {cfg.data_dir}", - f" Admin secret: {'set' if cfg.admin_secret else 'not set'}", - "\nUsage: /config ", - ] - await ctx.reply("\n".join(lines)) - return - if len(parts) < 3: - await ctx.reply("Usage: /config ") - return - key = parts[1].upper() - allowed = {"COPILOT_MODEL", "ADMIN_PORT", "BOT_PORT", "VOICE_TARGET_NUMBER", "ACS_SOURCE_NUMBER"} - if key not in allowed: - await ctx.reply(f"Cannot set '{key}'. Allowed keys: {', '.join(sorted(allowed))}") - return - cfg.write_env(**{key: parts[2]}) - await ctx.reply(f"Config updated: {key} = {parts[2]}") - - async def _cmd_preflight(self, ctx: CommandContext) -> None: - import aiohttp as _aiohttp - - base = f"http://127.0.0.1:{cfg.admin_port}" - headers = {"Authorization": f"Bearer {cfg.admin_secret}"} if cfg.admin_secret else {} - try: - async with _aiohttp.ClientSession() as session: - async with session.get(f"{base}/api/setup/preflight", headers=headers, timeout=_aiohttp.ClientTimeout(total=30)) as resp: - if resp.status != 200: - await ctx.reply(f"Preflight check failed (HTTP {resp.status}).") - return - data = await resp.json() - except Exception as exc: - await ctx.reply(f"Cannot reach preflight endpoint: {exc}") - return - - checks = data.get("checks", []) - lines = [f"Preflight Checks ({data.get('status', '?').upper()})"] - for c in checks: - icon = "OK" if c.get("ok") else "!!" - lines.append(f" [{icon}] {c['check']}: {c.get('detail', '')}") - await ctx.reply("\n".join(lines)) - - async def _cmd_phone(self, ctx: CommandContext) -> None: - parts = ctx.text.split(maxsplit=1) - if len(parts) < 2: - await ctx.reply(f"Current target number: {cfg.voice_target_number or '(not set)'}\n\nUsage: /phone ") - return - number = parts[1].strip() - if not number.startswith("+"): - await ctx.reply("Phone number must start with + country code.") - return - cfg.write_env(VOICE_TARGET_NUMBER=number) - await ctx.reply(f"Voice target number set to {number}.") - - async def _cmd_call(self, ctx: CommandContext) -> None: - import aiohttp as _aiohttp - - target = cfg.voice_target_number - if not target: - await ctx.reply("No target number configured. Use /phone first.") - return - base = f"http://127.0.0.1:{cfg.admin_port}" - headers = {"Authorization": f"Bearer {cfg.admin_secret}"} if cfg.admin_secret else {} - try: - async with _aiohttp.ClientSession() as session: - async with session.post(f"{base}/api/voice/call", json={"target_number": target}, headers=headers, timeout=_aiohttp.ClientTimeout(total=30)) as resp: - data = await resp.json() - if resp.status == 200: - await ctx.reply(f"Calling {target}...") - else: - await ctx.reply(f"Call failed: {data.get('error', f'HTTP {resp.status}')}") - except Exception as exc: - await ctx.reply(f"Call failed: {exc}") - - async def _cmd_change(self, ctx: CommandContext) -> None: - if not self._session_store: - await ctx.reply("Session store not available.") - return - sessions = self._session_store.list_sessions() - if not sessions: - await ctx.reply("No sessions to switch to. Use /new to start one.") - return - lines = ["Recent Sessions:", ""] - for i, s in enumerate(sessions[:5], 1): - started = s.get("started_at", "?")[:16] - lines.append(f" {i}. {started} {s.get('model', '?')} ({s.get('message_count', 0)} msgs)") - lines.append(f" ID: {s['id']}") - await ctx.reply("\n".join(lines)) - - async def _cmd_lockdown(self, ctx: CommandContext) -> None: - parts = ctx.text.split() - if len(parts) < 2: - state = "ENABLED" if cfg.lockdown_mode else "disabled" - await ctx.reply(f"Lock Down Mode: {state}\n\nUsage: /lockdown on | /lockdown off") - return - action = parts[1].lower() - if action not in ("on", "off"): - await ctx.reply("Usage: /lockdown on | /lockdown off") - return - if action == "on": - if cfg.lockdown_mode: - await ctx.reply("Lock Down Mode is already enabled.") - return - cfg.write_env(LOCKDOWN_MODE="1", TUNNEL_RESTRICTED="1") - from ..services.azure import AzureCLI - az = AzureCLI() - az.ok("logout") - az.invalidate_cache("account", "show") - await ctx.reply("Lock Down Mode ENABLED\n\n - Azure CLI logged out\n - Admin panel disabled") - else: - if not cfg.lockdown_mode: - await ctx.reply("Lock Down Mode is already disabled.") - return - cfg.write_env(LOCKDOWN_MODE="", TUNNEL_RESTRICTED="") - await ctx.reply("Lock Down Mode DISABLED\n\n - Admin panel re-enabled") - - async def _cmd_help(self, ctx: CommandContext) -> None: - lines = [ - "Available Commands", - "", - " /new, /model , /models, /status, /session, /config", - " /skills, /addskill , /removeskill ", - " /plugins, /plugin enable|disable ", - " /mcp, /mcp add|remove|enable|disable ", - " /schedules, /schedule add|remove", - " /sessions, /session delete , /sessions clear", - " /change, /profile, /channels, /clear", - " /phone , /call, /preflight, /lockdown, /help", - ] - await ctx.reply("\n".join(lines)) diff --git a/app/runtime/messaging/commands/__init__.py b/app/runtime/messaging/commands/__init__.py new file mode 100644 index 0000000..209745d --- /dev/null +++ b/app/runtime/messaging/commands/__init__.py @@ -0,0 +1,24 @@ +"""Slash-command dispatcher and command implementations. + +Sub-modules group commands by domain: + +- ``agent`` -- skills, plugins, MCP, schedules +- ``session`` -- session lifecycle and model switching +- ``system`` -- status, infra, and connectivity commands +""" + +from ._dispatcher import ( + ChannelContext, + CommandContext, + CommandDispatcher, + ReplyFn, +) +from .system import BOOT_TIME + +__all__ = [ + "BOOT_TIME", + "ChannelContext", + "CommandContext", + "CommandDispatcher", + "ReplyFn", +] diff --git a/app/runtime/messaging/commands/_dispatcher.py b/app/runtime/messaging/commands/_dispatcher.py new file mode 100644 index 0000000..a8beee3 --- /dev/null +++ b/app/runtime/messaging/commands/_dispatcher.py @@ -0,0 +1,199 @@ +"""Shared slash-command dispatcher. + +Centralises all slash-command logic so both the Bot Framework handler +and the WebSocket chat handler share a single implementation. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Any, Protocol + +from ...agent.agent import Agent +from ...state.infra_config import InfraConfigStore +from ...state.session_store import SessionStore + +from . import agent as _agent_cmds +from . import session as _session_cmds +from . import system as _system_cmds + +ReplyFn = Callable[[str], Awaitable[None]] + + +class ChannelContext(Protocol): + @property + def conversation_refs_count(self) -> int: ... + + @property + def connected_channels(self) -> set[str]: ... + + @property + def conversation_refs(self) -> list[Any]: ... + + +@dataclass +class CommandContext: + text: str + reply: ReplyFn + channel: str + channel_ctx: ChannelContext | None = None + + +class CommandDispatcher: + _EXACT_COMMANDS: dict[str, str] = { + "/new": "_cmd_new", + "/status": "_cmd_status", + "/skills": "_cmd_skills", + "/session": "_cmd_session", + "/channels": "_cmd_channels", + "/clear": "_cmd_clear", + "/help": "_cmd_help", + "/plugins": "_cmd_plugins", + "/mcp": "_cmd_mcp", + "/schedules": "_cmd_schedules", + "/sessions": "_cmd_sessions", + "/profile": "_cmd_profile", + "/config": "_cmd_config", + "/preflight": "_cmd_preflight", + "/call": "_cmd_call", + "/models": "_cmd_models", + "/change": "_cmd_change", + } + + _PREFIX_COMMANDS: tuple[tuple[str, str], ...] = ( + ("/removeskill", "_cmd_removeskill"), + ("/addskill", "_cmd_addskill"), + ("/model", "_cmd_model"), + ("/plugin", "_cmd_plugin"), + ("/mcp", "_cmd_mcp"), + ("/schedule", "_cmd_schedule"), + ("/sessions", "_cmd_sessions_sub"), + ("/session", "_cmd_session_sub"), + ("/config", "_cmd_config"), + ("/phone", "_cmd_phone"), + ("/lockdown", "_cmd_lockdown"), + ) + + def __init__( + self, + agent: Agent, + session_store: SessionStore | None = None, + infra: InfraConfigStore | None = None, + ) -> None: + self._agent = agent + self._session_store = session_store + self._infra = infra + + @property + def infra(self) -> InfraConfigStore: + if self._infra is None: + self._infra = InfraConfigStore() + return self._infra + + async def try_handle( + self, + text: str, + reply: ReplyFn, + channel: str = "web", + *, + channel_ctx: ChannelContext | None = None, + ) -> bool: + lower = text.lower() + ctx = CommandContext(text=text, reply=reply, channel=channel, channel_ctx=channel_ctx) + + handler_name = self._EXACT_COMMANDS.get(lower) + if handler_name: + await getattr(self, handler_name)(ctx) + return True + + for prefix, handler_name in self._PREFIX_COMMANDS: + if lower.startswith(prefix): + await getattr(self, handler_name)(ctx) + return True + + return False + + # -- Session & model commands (delegated to commands_session) ----------- + + async def _cmd_new(self, ctx: CommandContext) -> None: + await _session_cmds.cmd_new(self, ctx) + + async def _cmd_model(self, ctx: CommandContext) -> None: + await _session_cmds.cmd_model(self, ctx) + + async def _cmd_models(self, ctx: CommandContext) -> None: + await _session_cmds.cmd_models(self, ctx) + + async def _cmd_session(self, ctx: CommandContext) -> None: + await _session_cmds.cmd_session(self, ctx) + + async def _cmd_sessions(self, ctx: CommandContext) -> None: + await _session_cmds.cmd_sessions(self, ctx) + + async def _cmd_sessions_sub(self, ctx: CommandContext) -> None: + await _session_cmds.cmd_sessions_sub(self, ctx) + + async def _cmd_session_sub(self, ctx: CommandContext) -> None: + await _session_cmds.cmd_session_sub(self, ctx) + + async def _cmd_change(self, ctx: CommandContext) -> None: + await _session_cmds.cmd_change(self, ctx) + + async def _cmd_clear(self, ctx: CommandContext) -> None: + await _session_cmds.cmd_clear(self, ctx) + + # -- Agent commands (delegated to commands_agent) ---------------------- + + async def _cmd_skills(self, ctx: CommandContext) -> None: + await _agent_cmds.cmd_skills(self, ctx) + + async def _cmd_addskill(self, ctx: CommandContext) -> None: + await _agent_cmds.cmd_addskill(self, ctx) + + async def _cmd_removeskill(self, ctx: CommandContext) -> None: + await _agent_cmds.cmd_removeskill(self, ctx) + + async def _cmd_plugins(self, ctx: CommandContext) -> None: + await _agent_cmds.cmd_plugins(self, ctx) + + async def _cmd_plugin(self, ctx: CommandContext) -> None: + await _agent_cmds.cmd_plugin(self, ctx) + + async def _cmd_mcp(self, ctx: CommandContext) -> None: + await _agent_cmds.cmd_mcp(self, ctx) + + async def _cmd_schedules(self, ctx: CommandContext) -> None: + await _agent_cmds.cmd_schedules(self, ctx) + + async def _cmd_schedule(self, ctx: CommandContext) -> None: + await _agent_cmds.cmd_schedule(self, ctx) + + # -- System commands (delegated to commands_system) -------------------- + + async def _cmd_status(self, ctx: CommandContext) -> None: + await _system_cmds.cmd_status(self, ctx) + + async def _cmd_channels(self, ctx: CommandContext) -> None: + await _system_cmds.cmd_channels(self, ctx) + + async def _cmd_profile(self, ctx: CommandContext) -> None: + await _system_cmds.cmd_profile(self, ctx) + + async def _cmd_config(self, ctx: CommandContext) -> None: + await _system_cmds.cmd_config(self, ctx) + + async def _cmd_preflight(self, ctx: CommandContext) -> None: + await _system_cmds.cmd_preflight(self, ctx) + + async def _cmd_phone(self, ctx: CommandContext) -> None: + await _system_cmds.cmd_phone(self, ctx) + + async def _cmd_call(self, ctx: CommandContext) -> None: + await _system_cmds.cmd_call(self, ctx) + + async def _cmd_lockdown(self, ctx: CommandContext) -> None: + await _system_cmds.cmd_lockdown(self, ctx) + + async def _cmd_help(self, ctx: CommandContext) -> None: + await _system_cmds.cmd_help(self, ctx) diff --git a/app/runtime/messaging/commands/agent.py b/app/runtime/messaging/commands/agent.py new file mode 100644 index 0000000..dac03b1 --- /dev/null +++ b/app/runtime/messaging/commands/agent.py @@ -0,0 +1,191 @@ +"""Agent-related commands -- skills, plugins, MCP, schedules.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ...registries.plugins import get_plugin_registry +from ...registries.skills import get_registry as get_skill_registry +from ...scheduler import get_scheduler +from ...state.mcp_config import McpConfigStore + +if TYPE_CHECKING: + from ._dispatcher import CommandContext, CommandDispatcher + + +async def cmd_skills(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + from ...config.settings import cfg + + skills: list[str] = [] + if cfg.user_skills_dir.is_dir(): + for d in sorted(cfg.user_skills_dir.iterdir()): + if d.is_dir() and (d / "SKILL.md").exists(): + skills.append(d.name) + lines = [f"Skills ({len(skills)}):"] + [f" - {name}" for name in skills] + if not skills: + lines.append(" (none)") + await ctx.reply("\n".join(lines)) + + +async def cmd_addskill(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + parts = ctx.text.split(maxsplit=1) + if len(parts) < 2: + reg = get_skill_registry() + try: + catalog = await reg.fetch_catalog() + available = [s for s in catalog if not s.installed] + if available: + lines = [f"Available skills ({len(available)}):"] + for s in available: + desc = f" - {s.description}" if s.description else "" + lines.append(f" {s.name}{desc} [{s.source}]") + lines.append("\nUsage: /addskill ") + else: + lines = ["All catalog skills already installed.", "Usage: /addskill "] + except Exception as exc: + lines = [f"Failed to fetch catalog: {exc}", "Usage: /addskill "] + await ctx.reply("\n".join(lines)) + return + name = parts[1].strip() + reg = get_skill_registry() + await ctx.reply(f"Installing skill '{name}'...") + ok = await reg.install(name) + await ctx.reply(f"Skill '{name}' installed." if ok else f"Failed to install skill '{name}'.") + + +async def cmd_removeskill(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + parts = ctx.text.split(maxsplit=1) + if len(parts) < 2: + reg = get_skill_registry() + installed = reg.list_installed() + if installed: + lines = [f"Installed skills ({len(installed)}):"] + [f" {s.name}" for s in installed] + lines.append("\nUsage: /removeskill ") + else: + lines = ["No skills installed.", "Usage: /removeskill "] + await ctx.reply("\n".join(lines)) + return + name = parts[1].strip() + reg = get_skill_registry() + removed = reg.remove(name) + await ctx.reply(f"Skill '{name}' removed." if removed else f"Skill '{name}' not found.") + + +async def cmd_plugins(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + reg = get_plugin_registry() + plugins = reg.list_plugins() + if not plugins: + await ctx.reply("No plugins found.") + return + lines = [f"Plugins ({len(plugins)}):"] + for p in plugins: + icon = "+" if p.get("enabled") else "-" + desc = f" - {p['description']}" if p.get("description") else "" + lines.append(f" [{icon}] {p['id']}{desc} ({p.get('skill_count', 0)} skills)") + lines.append("\nUsage: /plugin enable , /plugin disable ") + await ctx.reply("\n".join(lines)) + + +async def cmd_plugin(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + parts = ctx.text.split() + if len(parts) < 3: + await ctx.reply("Usage: /plugin enable or /plugin disable ") + return + action, plugin_id = parts[1].lower(), parts[2].strip() + reg = get_plugin_registry() + if action == "enable": + result = reg.enable_plugin(plugin_id) + await ctx.reply(f"Plugin '{plugin_id}' enabled." if result else f"Plugin '{plugin_id}' not found.") + elif action == "disable": + result = reg.disable_plugin(plugin_id) + await ctx.reply(f"Plugin '{plugin_id}' disabled." if result else f"Plugin '{plugin_id}' not found.") + else: + await ctx.reply(f"Unknown action '{action}'.") + + +async def cmd_mcp(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + parts = ctx.text.split() + store = McpConfigStore() + if len(parts) == 1: + servers = store.list_servers() + if not servers: + await ctx.reply("No MCP servers configured.") + return + lines = [f"MCP Servers ({len(servers)}):"] + for s in servers: + icon = "+" if s.get("enabled") else "-" + builtin = " [builtin]" if s.get("builtin") else "" + lines.append(f" [{icon}] {s['name']} ({s.get('type', '?')}){builtin}") + if s.get("description"): + lines.append(f" {s['description']}") + await ctx.reply("\n".join(lines)) + return + + action = parts[1].lower() + if action == "add": + if len(parts) < 4: + await ctx.reply("Usage: /mcp add ") + return + try: + store.add_server(parts[2], "http", url=parts[3]) + await ctx.reply(f"MCP server '{parts[2]}' added. Start a /new session to activate.") + except ValueError as exc: + await ctx.reply(f"Error: {exc}") + elif action == "remove": + if len(parts) < 3: + await ctx.reply("Usage: /mcp remove ") + return + try: + ok = store.remove_server(parts[2]) + await ctx.reply(f"MCP server '{parts[2]}' removed." if ok else f"MCP server '{parts[2]}' not found.") + except ValueError as exc: + await ctx.reply(f"Error: {exc}") + elif action in ("enable", "disable"): + if len(parts) < 3: + await ctx.reply(f"Usage: /mcp {action} ") + return + ok = store.set_enabled(parts[2], action == "enable") + await ctx.reply(f"MCP server '{parts[2]}' {action}d." if ok else f"MCP server '{parts[2]}' not found.") + else: + await ctx.reply(f"Unknown MCP action '{action}'.") + + +async def cmd_schedules(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + sched = get_scheduler() + tasks = sched.list_tasks() + if not tasks: + await ctx.reply("No scheduled tasks.\n\nUsage: /schedule add ") + return + lines = [f"Scheduled Tasks ({len(tasks)}):"] + for t in tasks: + icon = "+" if t.enabled else "-" + schedule = t.cron or (f"once at {t.run_at}" if t.run_at else "?") + lines.append(f" [{icon}] {t.id} - {t.description}") + lines.append(f" Schedule: {schedule} | Last run: {t.last_run[:16] if t.last_run else 'never'}") + await ctx.reply("\n".join(lines)) + + +async def cmd_schedule(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + parts = ctx.text.split() + if len(parts) < 2: + await ctx.reply("Usage: /schedule add or /schedule remove ") + return + action = parts[1].lower() + sched = get_scheduler() + if action == "add": + if len(parts) < 8: + await ctx.reply("Usage: /schedule add ") + return + cron = " ".join(parts[2:7]) + prompt = " ".join(parts[7:]) + try: + task = sched.add(description=prompt[:60], prompt=prompt, cron=cron) + await ctx.reply(f"Scheduled task created:\n ID: {task.id}\n Cron: {cron}\n Prompt: {prompt}") + except ValueError as exc: + await ctx.reply(f"Error: {exc}") + elif action == "remove": + if len(parts) < 3: + await ctx.reply("Usage: /schedule remove ") + return + ok = sched.remove(parts[2]) + await ctx.reply(f"Task '{parts[2]}' removed." if ok else f"Task '{parts[2]}' not found.") diff --git a/app/runtime/messaging/commands/session.py b/app/runtime/messaging/commands/session.py new file mode 100644 index 0000000..c1f631d --- /dev/null +++ b/app/runtime/messaging/commands/session.py @@ -0,0 +1,130 @@ +"""Session and model management commands.""" + +from __future__ import annotations + +import uuid +from typing import TYPE_CHECKING + +from ...config.settings import cfg + +if TYPE_CHECKING: + from ._dispatcher import CommandContext, CommandDispatcher + + +async def cmd_new(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + await dispatcher._agent.new_session() + if dispatcher._session_store: + dispatcher._session_store.start_session(uuid.uuid4().hex[:12], model=cfg.copilot_model) + await ctx.reply("New session started.") + + +async def cmd_model(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + parts = ctx.text.split(maxsplit=1) + if len(parts) < 2: + await ctx.reply(f"Current model: {cfg.copilot_model}\n\nUsage: /model ") + return + new_model = parts[1].strip() + old_model = cfg.copilot_model + cfg.write_env(COPILOT_MODEL=new_model) + await dispatcher._agent.new_session() + if dispatcher._session_store: + dispatcher._session_store.start_session(uuid.uuid4().hex[:12], model=new_model) + await ctx.reply(f"Model switched: {old_model} -> {new_model}\nNew session started.") + + +async def cmd_models(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + models = await dispatcher._agent.list_models() + if not models: + await ctx.reply("No models available.") + return + current = cfg.copilot_model + lines = ["Available Models", ""] + for m in models: + marker = " *" if m["id"] == current else "" + cost = f" ({m['billing_multiplier']}x)" if m.get("billing_multiplier", 1.0) != 1.0 else "" + reasoning = f" [reasoning: {', '.join(m['reasoning_efforts'])}]" if m.get("reasoning_efforts") else "" + policy = m.get("policy", "enabled") + if policy != "enabled": + lines.append(f" {m['id']}{marker}{cost} ({policy})") + else: + lines.append(f" {m['id']}{marker}{cost}{reasoning}") + lines.append(f"\nCurrent: {current}\nUse /model to switch.") + await ctx.reply("\n".join(lines)) + + +async def cmd_session(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + lines = [ + "Session Info", + f" Active: {'yes' if dispatcher._agent.has_session else 'no'}", + f" Model: {cfg.copilot_model}", + " Playwright MCP: enabled", + ] + await ctx.reply("\n".join(lines)) + + +async def cmd_sessions(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + if not dispatcher._session_store: + await ctx.reply("Session store not available.") + return + sessions = dispatcher._session_store.list_sessions() + if not sessions: + await ctx.reply("No recorded sessions.") + return + stats = dispatcher._session_store.get_session_stats() + lines = [f"Sessions ({stats['total_sessions']} total, {stats['total_messages']} messages)", ""] + for s in sessions[:10]: + started = s.get("started_at", "?")[:16] + lines.append(f" {s['id']} {started} {s.get('model', '?')} ({s.get('message_count', 0)} msgs)") + if len(sessions) > 10: + lines.append(f" ... and {len(sessions) - 10} more") + await ctx.reply("\n".join(lines)) + + +async def cmd_sessions_sub(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + parts = ctx.text.split() + if len(parts) >= 2 and parts[1].lower() == "clear": + if not dispatcher._session_store: + await ctx.reply("Session store not available.") + return + count = dispatcher._session_store.clear_all() + await ctx.reply(f"All sessions cleared ({count} deleted).") + else: + await cmd_sessions(dispatcher, ctx) + + +async def cmd_session_sub(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + parts = ctx.text.split() + if len(parts) >= 3 and parts[1].lower() == "delete": + if not dispatcher._session_store: + await ctx.reply("Session store not available.") + return + ok = dispatcher._session_store.delete_session(parts[2]) + await ctx.reply(f"Session '{parts[2]}' deleted." if ok else f"Session '{parts[2]}' not found.") + else: + await cmd_session(dispatcher, ctx) + + +async def cmd_change(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + if not dispatcher._session_store: + await ctx.reply("Session store not available.") + return + sessions = dispatcher._session_store.list_sessions() + if not sessions: + await ctx.reply("No sessions to switch to. Use /new to start one.") + return + lines = ["Recent Sessions:", ""] + for i, s in enumerate(sessions[:5], 1): + started = s.get("started_at", "?")[:16] + lines.append(f" {i}. {started} {s.get('model', '?')} ({s.get('message_count', 0)} msgs)") + lines.append(f" ID: {s['id']}") + await ctx.reply("\n".join(lines)) + + +async def cmd_clear(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + cleared = 0 + if cfg.memory_dir.is_dir(): + for f in cfg.memory_dir.rglob("*"): + if f.is_file(): + f.unlink() + cleared += 1 + await ctx.reply(f"Memory cleared ({cleared} files removed).") diff --git a/app/runtime/messaging/commands/system.py b/app/runtime/messaging/commands/system.py new file mode 100644 index 0000000..918b110 --- /dev/null +++ b/app/runtime/messaging/commands/system.py @@ -0,0 +1,206 @@ +"""System, status, and infrastructure commands.""" + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING + +from ...config.settings import cfg +from ...scheduler import get_scheduler +from ...state.profile import load_profile + +if TYPE_CHECKING: + from ._dispatcher import CommandContext, CommandDispatcher + +BOOT_TIME = time.monotonic() + + +async def cmd_status(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + uptime_seconds = int(time.monotonic() - BOOT_TIME) + hours, remainder = divmod(uptime_seconds, 3600) + minutes, seconds = divmod(remainder, 60) + + sched = get_scheduler() + tasks = sched.list_tasks() + active_tasks = [t for t in tasks if t.enabled] + total_reqs = sum(dispatcher._agent.request_counts.values()) + + lines = [ + "System Status", + f" Model: {cfg.copilot_model}", + f" Uptime: {hours}h {minutes}m {seconds}s", + f" Total requests: {total_reqs}", + ] + for model, count in sorted(dispatcher._agent.request_counts.items()): + lines.append(f" {model}: {count}") + if ctx.channel_ctx is not None: + channels = ctx.channel_ctx.connected_channels + lines.append(f" Connected channels: {', '.join(sorted(channels)) or 'none'}") + lines.append(f" Conversation refs: {ctx.channel_ctx.conversation_refs_count}") + lines.append(f" Scheduled tasks: {len(active_tasks)} active / {len(tasks)} total") + lines.append(f" Data dir: {cfg.data_dir}") + await ctx.reply("\n".join(lines)) + + +async def cmd_channels(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + lines = ["Channel Configuration\n"] + tg = dispatcher.infra.channels.telegram + if tg.token: + masked = tg.token[:8] + "..." + tg.token[-4:] if len(tg.token) > 12 else "***" + lines.append(f"Telegram:\n Token: {masked}\n Whitelist: {tg.whitelist or '(none)'}") + else: + lines.append("Telegram: not configured") + lines.append(f"\nBot Framework:\n App ID: {cfg.bot_app_id[:8] + '...' if cfg.bot_app_id else 'not set'}") + lines.append(f" Tenant: {cfg.bot_app_tenant_id[:8] + '...' if cfg.bot_app_tenant_id else 'not set'}") + lines.append(f" Admin secret: {'set' if cfg.admin_secret else 'not set'}") + if ctx.channel_ctx is not None: + refs = ctx.channel_ctx.conversation_refs + lines.append(f"\nActive Conversations ({len(refs)}):") + for r in refs: + user_name = r.user.name if r.user else "?" + lines.append(f" - {r.channel_id}: {user_name}") + await ctx.reply("\n".join(lines)) + + +async def cmd_profile(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + profile = load_profile() + lines = [ + "Agent Profile", + f" Name: {profile.get('name') or '(not set)'}", + f" Location: {profile.get('location') or '(not set)'}", + f" Emotional state: {profile.get('emotional_state', 'neutral')}", + ] + prefs = profile.get("preferences", {}) + if prefs: + lines.append(" Preferences:") + for k, v in prefs.items(): + lines.append(f" {k}: {v}") + await ctx.reply("\n".join(lines)) + + +async def cmd_config(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + parts = ctx.text.split(maxsplit=2) + if len(parts) == 1: + lines = [ + "Runtime Configuration", + f" Model: {cfg.copilot_model}", + f" Admin port: {cfg.admin_port}", + f" Bot port: {cfg.bot_port}", + f" Data dir: {cfg.data_dir}", + f" Admin secret: {'set' if cfg.admin_secret else 'not set'}", + "\nUsage: /config ", + ] + await ctx.reply("\n".join(lines)) + return + if len(parts) < 3: + await ctx.reply("Usage: /config ") + return + key = parts[1].upper() + allowed = {"COPILOT_MODEL", "ADMIN_PORT", "BOT_PORT", "VOICE_TARGET_NUMBER", "ACS_SOURCE_NUMBER"} + if key not in allowed: + await ctx.reply(f"Cannot set '{key}'. Allowed keys: {', '.join(sorted(allowed))}") + return + cfg.write_env(**{key: parts[2]}) + await ctx.reply(f"Config updated: {key} = {parts[2]}") + + +async def cmd_preflight(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + import aiohttp as _aiohttp + + base = f"http://127.0.0.1:{cfg.admin_port}" + headers = {"Authorization": f"Bearer {cfg.admin_secret}"} if cfg.admin_secret else {} + try: + async with _aiohttp.ClientSession() as session: + async with session.get(f"{base}/api/setup/preflight", headers=headers, timeout=_aiohttp.ClientTimeout(total=30)) as resp: + if resp.status != 200: + await ctx.reply(f"Preflight check failed (HTTP {resp.status}).") + return + data = await resp.json() + except Exception as exc: + await ctx.reply(f"Cannot reach preflight endpoint: {exc}") + return + + checks = data.get("checks", []) + lines = [f"Preflight Checks ({data.get('status', '?').upper()})"] + for c in checks: + icon = "OK" if c.get("ok") else "!!" + lines.append(f" [{icon}] {c['check']}: {c.get('detail', '')}") + await ctx.reply("\n".join(lines)) + + +async def cmd_phone(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + parts = ctx.text.split(maxsplit=1) + if len(parts) < 2: + await ctx.reply(f"Current target number: {cfg.voice_target_number or '(not set)'}\n\nUsage: /phone ") + return + number = parts[1].strip() + if not number.startswith("+"): + await ctx.reply("Phone number must start with + country code.") + return + cfg.write_env(VOICE_TARGET_NUMBER=number) + await ctx.reply(f"Voice target number set to {number}.") + + +async def cmd_call(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + import aiohttp as _aiohttp + + target = cfg.voice_target_number + if not target: + await ctx.reply("No target number configured. Use /phone first.") + return + base = f"http://127.0.0.1:{cfg.admin_port}" + headers = {"Authorization": f"Bearer {cfg.admin_secret}"} if cfg.admin_secret else {} + try: + async with _aiohttp.ClientSession() as session: + async with session.post(f"{base}/api/voice/call", json={"target_number": target}, headers=headers, timeout=_aiohttp.ClientTimeout(total=30)) as resp: + data = await resp.json() + if resp.status == 200: + await ctx.reply(f"Calling {target}...") + else: + await ctx.reply(f"Call failed: {data.get('error', f'HTTP {resp.status}')}") + except Exception as exc: + await ctx.reply(f"Call failed: {exc}") + + +async def cmd_lockdown(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + parts = ctx.text.split() + if len(parts) < 2: + state = "ENABLED" if cfg.lockdown_mode else "disabled" + await ctx.reply(f"Lock Down Mode: {state}\n\nUsage: /lockdown on | /lockdown off") + return + action = parts[1].lower() + if action not in ("on", "off"): + await ctx.reply("Usage: /lockdown on | /lockdown off") + return + if action == "on": + if cfg.lockdown_mode: + await ctx.reply("Lock Down Mode is already enabled.") + return + cfg.write_env(LOCKDOWN_MODE="1", TUNNEL_RESTRICTED="1") + from ...services.cloud.azure import AzureCLI + az = AzureCLI() + az.ok("logout") + az.invalidate_cache("account", "show") + await ctx.reply("Lock Down Mode ENABLED\n\n - Azure CLI logged out\n - Admin panel disabled") + else: + if not cfg.lockdown_mode: + await ctx.reply("Lock Down Mode is already disabled.") + return + cfg.write_env(LOCKDOWN_MODE="", TUNNEL_RESTRICTED="") + await ctx.reply("Lock Down Mode DISABLED\n\n - Admin panel re-enabled") + + +async def cmd_help(dispatcher: CommandDispatcher, ctx: CommandContext) -> None: + lines = [ + "Available Commands", + "", + " /new, /model , /models, /status, /session, /config", + " /skills, /addskill , /removeskill ", + " /plugins, /plugin enable|disable ", + " /mcp, /mcp add|remove|enable|disable ", + " /schedules, /schedule add|remove", + " /sessions, /session delete , /sessions clear", + " /change, /profile, /channels, /clear", + " /phone , /call, /preflight, /lockdown, /help", + ] + await ctx.reply("\n".join(lines)) diff --git a/app/runtime/messaging/message_processor.py b/app/runtime/messaging/message_processor.py index 13d9502..ab08992 100644 --- a/app/runtime/messaging/message_processor.py +++ b/app/runtime/messaging/message_processor.py @@ -62,14 +62,16 @@ async def process(self, ref: ConversationReference, prompt: str, channel: str) - async def bot_reply(text: str) -> None: await self._send_proactive_reply(ref, text, channel) - self._hitl.set_bot_reply_fn(bot_reply) - self._hitl.set_execution_context("bot_processor") - self._hitl.set_model(cfg.copilot_model) + self._hitl.bind_turn( + bot_reply_fn=bot_reply, + execution_context="bot_processor", + model=cfg.copilot_model, + ) try: response = await self._agent.send(prompt) finally: if self._hitl: - self._hitl.clear_bot_reply_fn() + self._hitl.unbind_turn() if response: self._memory.record("assistant", response) self.session_store.record("assistant", response) diff --git a/app/runtime/proactive_loop.py b/app/runtime/messaging/proactive_loop.py similarity index 96% rename from app/runtime/proactive_loop.py rename to app/runtime/messaging/proactive_loop.py index b3abe7c..d09579b 100644 --- a/app/runtime/proactive_loop.py +++ b/app/runtime/messaging/proactive_loop.py @@ -16,16 +16,16 @@ from pathlib import Path from typing import TYPE_CHECKING -from .config.settings import cfg -from .state.proactive import get_proactive_store -from .state.profile import log_interaction +from ..config.settings import cfg +from ..state.proactive import get_proactive_store +from ..state.profile import log_interaction if TYPE_CHECKING: - from .state.session_store import SessionStore + from ..state.session_store import SessionStore logger = logging.getLogger(__name__) -_TEMPLATES_DIR = Path(__file__).resolve().parent / "templates" +_TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "templates" # Minimum hours since last user activity before we proactively reach out. _MIN_USER_IDLE_HOURS = 1.0 @@ -94,9 +94,9 @@ def _gather_memory_context() -> str: def _gather_profile_context() -> str: """Read the user/agent profile JSON for LLM context.""" - from .state.profile import _profile_path + from ..state.profile import profile_path - path = _profile_path() + path = profile_path() if path.exists(): try: return path.read_text()[:1000] @@ -107,7 +107,7 @@ def _gather_profile_context() -> str: def _hours_since_last_session() -> float | None: """Return hours since the most recent session's last update.""" - from .state.session_store import SessionStore + from ..state.session_store import SessionStore store = SessionStore() sessions = store.list_sessions() @@ -137,7 +137,7 @@ async def _generate_proactive_message() -> str | None: Returns the message string or ``None`` if the LLM decided nothing is worth sending (``NO_FOLLOWUP``). """ - from .agent.one_shot import run_one_shot + from ..agent.one_shot import run_one_shot template = (_TEMPLATES_DIR / "proactive_generate_prompt.md").read_text() store = get_proactive_store() @@ -298,7 +298,7 @@ async def _deliver_message( session_store: "SessionStore | None", ) -> None: """Attempt to deliver a pending proactive message.""" - from .state.proactive import PendingMessage # noqa: F811 + from ..state.proactive import PendingMessage # noqa: F811 now = datetime.now(UTC) logger.info( diff --git a/app/runtime/realtime/__init__.py b/app/runtime/realtime/__init__.py index 25fe62e..8d8d152 100644 --- a/app/runtime/realtime/__init__.py +++ b/app/runtime/realtime/__init__.py @@ -1,5 +1,7 @@ """Realtime voice call module -- ACS + OpenAI Realtime API integration.""" +from __future__ import annotations + from .caller import AcsCaller from .middleware import RealtimeMiddleTier from .routes import RealtimeRoutes diff --git a/app/runtime/realtime/auth.py b/app/runtime/realtime/auth.py index 154104f..f5a903e 100644 --- a/app/runtime/realtime/auth.py +++ b/app/runtime/realtime/auth.py @@ -15,6 +15,8 @@ from aiohttp import web +from ..util.singletons import register_singleton + logger = logging.getLogger(__name__) _ACS_ISSUER = "https://acscallautomation.communication.azure.com" @@ -38,6 +40,16 @@ def get_learned_audience() -> str: return _learned_audience +def _reset_learned_audience() -> None: + """Clear the auto-learned audience (for test isolation).""" + global _learned_audience + with _audience_lock: + _learned_audience = "" + + +register_singleton(_reset_learned_audience) + + def validate_token_param(request: web.Request, expected_token: str) -> bool: return request.query.get("token", "") == expected_token diff --git a/app/runtime/realtime/caller.py b/app/runtime/realtime/caller.py index 552ea53..7356be3 100644 --- a/app/runtime/realtime/caller.py +++ b/app/runtime/realtime/caller.py @@ -58,16 +58,30 @@ def _ensure_client(self) -> Any: self._client = CallAutomationClient.from_connection_string(self.acs_connection_string) return self._client - async def initiate_call(self, target_number: str) -> None: + @staticmethod + def _build_media_config(ws_url: str) -> Any: + """Build the shared ``MediaStreamingOptions`` for ACS calls.""" from azure.communication.callautomation import ( AudioFormat, MediaStreamingAudioChannelType, MediaStreamingContentType, MediaStreamingOptions, - PhoneNumberIdentifier, StreamingTransportType, ) + return MediaStreamingOptions( + transport_url=ws_url, + transport_type=StreamingTransportType.WEBSOCKET, + content_type=MediaStreamingContentType.AUDIO, + audio_channel_type=MediaStreamingAudioChannelType.MIXED, + start_media_streaming=True, + enable_bidirectional=True, + audio_format=AudioFormat.PCM24_K_MONO, + ) + + async def initiate_call(self, target_number: str) -> None: + from azure.communication.callautomation import PhoneNumberIdentifier + callback_url = self.acs_callback_path ws_url = self.acs_media_streaming_websocket_path if not callback_url or not callback_url.startswith("https://"): @@ -78,16 +92,7 @@ async def initiate_call(self, target_number: str) -> None: client = self._ensure_client() target = PhoneNumberIdentifier(target_number) source = PhoneNumberIdentifier(self.source_number) - - media_config = MediaStreamingOptions( - transport_url=ws_url, - transport_type=StreamingTransportType.WEBSOCKET, - content_type=MediaStreamingContentType.AUDIO, - audio_channel_type=MediaStreamingAudioChannelType.MIXED, - start_media_streaming=True, - enable_bidirectional=True, - audio_format=AudioFormat.PCM24_K_MONO, - ) + media_config = self._build_media_config(ws_url) logger.info( "Initiating outbound call: target=%s, source=%s, callback=%s, ws=%s", @@ -108,24 +113,8 @@ async def initiate_call(self, target_number: str) -> None: raise async def answer_inbound_call(self, incoming_call_context: str) -> None: - from azure.communication.callautomation import ( - AudioFormat, - MediaStreamingAudioChannelType, - MediaStreamingContentType, - MediaStreamingOptions, - StreamingTransportType, - ) - client = self._ensure_client() - media_config = MediaStreamingOptions( - transport_url=self.acs_media_streaming_websocket_path, - transport_type=StreamingTransportType.WEBSOCKET, - content_type=MediaStreamingContentType.AUDIO, - audio_channel_type=MediaStreamingAudioChannelType.MIXED, - start_media_streaming=True, - enable_bidirectional=True, - audio_format=AudioFormat.PCM24_K_MONO, - ) + media_config = self._build_media_config(self.acs_media_streaming_websocket_path) logger.info("Answering inbound call") client.answer_call(incoming_call_context, self.acs_callback_path, media_streaming=media_config) logger.info("Inbound call answered") diff --git a/app/runtime/realtime/middleware.py b/app/runtime/realtime/middleware.py index f71c139..453e56f 100644 --- a/app/runtime/realtime/middleware.py +++ b/app/runtime/realtime/middleware.py @@ -5,7 +5,6 @@ import asyncio import json import logging -from pathlib import Path from typing import Any import aiohttp @@ -13,7 +12,7 @@ from azure.core.credentials import AzureKeyCredential from azure.identity import DefaultAzureCredential, get_bearer_token_provider -from .prompt import REALTIME_SYSTEM_PROMPT +from .prompt import REALTIME_SYSTEM_PROMPT, TEMPLATES_DIR from .tools import ( ALL_REALTIME_TOOL_SCHEMAS, handle_check_agent_task, @@ -23,8 +22,6 @@ logger = logging.getLogger(__name__) -_TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "templates" - class RealtimeMiddleTier: """Proxies WebSocket traffic between a client and the OpenAI Realtime API.""" @@ -125,10 +122,10 @@ def _consume_pending(self) -> tuple[str, list[dict[str, Any]]]: else: parts: list[str] = [base] if prompt: - template = (_TEMPLATES_DIR / "realtime_call_instructions.md").read_text() + template = (TEMPLATES_DIR / "realtime_call_instructions.md").read_text() parts.append(template.format(prompt=prompt)) if opening_message: - template = (_TEMPLATES_DIR / "realtime_opening_message.md").read_text() + template = (TEMPLATES_DIR / "realtime_opening_message.md").read_text() parts.append(template.format(opening_message=opening_message)) effective_prompt = "\n\n".join(parts) if len(parts) > 1 else base @@ -280,12 +277,9 @@ async def _execute_tool(self, item: dict[str, Any], server_ws: ClientWebSocketRe logger.info("Realtime tool call: %s(%s)", name, args_str[:200]) - if name == "invoke_agent": - result = await handle_invoke_agent(args, self.agent) - elif name == "invoke_agent_async": - result = await handle_invoke_agent_async(args, self.agent) - elif name == "check_agent_task": - result = await handle_check_agent_task(args) + handler = _TOOL_DISPATCH.get(name) + if handler: + result = await handler(args, self.agent) else: result = f"Unknown tool: {name}" @@ -302,6 +296,21 @@ def _auth_headers(self) -> dict[str, str]: raise ValueError("No authentication configured for OpenAI Realtime") +# -- tool dispatch table --------------------------------------------------- + + +async def _dispatch_check_agent_task(args: dict[str, Any], agent: Any) -> str: + """Thin adapter so check_agent_task matches the ``(args, agent)`` signature.""" + return await handle_check_agent_task(args) + + +_TOOL_DISPATCH: dict[str, Any] = { + "invoke_agent": handle_invoke_agent, + "invoke_agent_async": handle_invoke_agent_async, + "check_agent_task": _dispatch_check_agent_task, +} + + class _ToolCall: __slots__ = ("call_id", "previous_id") diff --git a/app/runtime/realtime/prompt.py b/app/runtime/realtime/prompt.py index 578d941..e295271 100644 --- a/app/runtime/realtime/prompt.py +++ b/app/runtime/realtime/prompt.py @@ -1,7 +1,9 @@ -"""System prompt for the Realtime voice model.""" +"""System prompt and template directory for the Realtime voice model.""" + +from __future__ import annotations from pathlib import Path -_TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "templates" +TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "templates" -REALTIME_SYSTEM_PROMPT: str = (_TEMPLATES_DIR / "realtime_prompt.md").read_text() +REALTIME_SYSTEM_PROMPT: str = (TEMPLATES_DIR / "realtime_prompt.md").read_text() diff --git a/app/runtime/realtime/tools.py b/app/runtime/realtime/tools.py index 80406f0..cc9e87c 100644 --- a/app/runtime/realtime/tools.py +++ b/app/runtime/realtime/tools.py @@ -13,6 +13,7 @@ from typing import Any from ..config.settings import cfg +from ..util.singletons import register_singleton logger = logging.getLogger(__name__) @@ -81,6 +82,9 @@ def _reset_task_store() -> None: _task_store = None +register_singleton(_reset_task_store) + + INVOKE_AGENT_SCHEMA = { "type": "function", "name": "invoke_agent", @@ -239,12 +243,11 @@ def _make_realtime_hook( guardrails policies are respected during voice-initiated tasks. """ from ..agent.hitl import HitlInterceptor - from ..state.guardrails_config import get_guardrails_config + from ..state.guardrails.config import get_guardrails_config store = get_guardrails_config() interceptor = HitlInterceptor(store) - interceptor.set_execution_context("realtime") - interceptor.set_model(_REALTIME_MODEL) + interceptor.bind_turn(execution_context="realtime", model=_REALTIME_MODEL) # Forward AITL / Prompt Shield / phone from the shared interceptor. shared_hitl = getattr(agent, "hitl_interceptor", None) diff --git a/app/runtime/registries/__init__.py b/app/runtime/registries/__init__.py index f79f857..5eed2e5 100644 --- a/app/runtime/registries/__init__.py +++ b/app/runtime/registries/__init__.py @@ -1,5 +1,10 @@ """Plugin and skill registries.""" +from __future__ import annotations + +from .plugins import PluginManifest, PluginRegistry, get_plugin_registry +from .skills import SkillInfo, SkillRegistry, get_registry + __all__ = [ "PluginManifest", "PluginRegistry", diff --git a/app/runtime/registries/catalog.py b/app/runtime/registries/catalog.py new file mode 100644 index 0000000..db95aaa --- /dev/null +++ b/app/runtime/registries/catalog.py @@ -0,0 +1,287 @@ +"""GitHub skill catalog -- fetch remote skill listings and install them.""" + +from __future__ import annotations + +import asyncio +import json +import logging +import re +import shutil +from pathlib import Path +from typing import Any + +import aiohttp + +from ..config.settings import cfg + +logger = logging.getLogger(__name__) + +_CATALOG_SOURCES: list[dict[str, str]] = [ + { + "owner": "github", + "repo": "awesome-copilot", + "path": "skills", + "branch": "main", + "label": "GitHub Awesome Copilot", + "category": "github-awesome", + }, + { + "owner": "anthropics", + "repo": "skills", + "path": "skills", + "branch": "main", + "label": "Anthropic Skills", + "category": "anthropic", + }, +] + +_GITHUB_API = "https://api.github.com" +_GITHUB_RAW = "https://raw.githubusercontent.com" +_ORIGIN_FILE = ".origin" + + +def _github_headers() -> dict[str, str]: + """Build common headers for GitHub API requests.""" + headers: dict[str, str] = { + "Accept": "application/vnd.github.v3+json", + "User-Agent": "polyclaw-skill-registry", + } + token = cfg.github_token + if token: + headers["Authorization"] = f"token {token}" + return headers + + +async def fetch_catalog( + installed_names: set[str], + parse_frontmatter: Any, + curated_skills: set[str], +) -> tuple[list[Any], bool, int | None]: + """Fetch remote skill catalog from all configured GitHub sources. + + Returns ``(skills, rate_limited, rate_limit_reset)``. + """ + from .skills import SkillInfo + + rate_limited = False + rate_limit_reset: int | None = None + + headers = _github_headers() + all_skills: list[SkillInfo] = [] + + async with aiohttp.ClientSession(headers=headers) as session: + tasks = [ + _fetch_source(session, src, installed_names, parse_frontmatter, curated_skills) + for src in _CATALOG_SOURCES + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, res in enumerate(results): + if isinstance(res, list): + all_skills.extend(res) + elif isinstance(res, _RateLimited): + rate_limited = True + rate_limit_reset = res.reset_at + elif isinstance(res, Exception): + logger.error("Catalog source %s failed: %s", _CATALOG_SOURCES[i]["label"], res) + + try: + await _fetch_commit_counts(all_skills) + except Exception: + pass + + return all_skills, rate_limited, rate_limit_reset + + +class _RateLimited(Exception): + """Raised internally when GitHub returns a 403 rate-limit response.""" + + def __init__(self, reset_at: int | None = None) -> None: + self.reset_at = reset_at + + +async def _fetch_source( + session: aiohttp.ClientSession, + src: dict[str, str], + installed_names: set[str], + parse_frontmatter: Any, + curated_skills: set[str], +) -> list[Any]: + from .skills import SkillInfo + + url = ( + f"{_GITHUB_API}/repos/{src['owner']}/{src['repo']}" + f"/contents/{src['path']}?ref={src['branch']}" + ) + try: + async with session.get(url) as resp: + if resp.status != 200: + remaining = resp.headers.get("X-RateLimit-Remaining", "?") + if resp.status == 403 and remaining == "0": + reset_at: int | None = None + try: + reset_at = int(resp.headers.get("X-RateLimit-Reset", "0")) + except (ValueError, TypeError): + pass + raise _RateLimited(reset_at) + return [] + entries = await resp.json() + except _RateLimited: + raise + except Exception as exc: + logger.error("GitHub API request failed: %s", exc) + return [] + + if not isinstance(entries, list): + return [] + + sem = asyncio.Semaphore(20) + + async def _get_skill(name: str) -> SkillInfo | None: + async with sem: + raw_url = ( + f"{_GITHUB_RAW}/{src['owner']}/{src['repo']}" + f"/{src['branch']}/{src['path']}/{name}/SKILL.md" + ) + try: + async with session.get(raw_url) as r: + fm = parse_frontmatter(await r.text()) if r.status == 200 else {} + except Exception: + fm = {} + skill_name = fm.get("name", name) + return SkillInfo( + name=skill_name, + verb=fm.get("verb", name), + description=fm.get("description", ""), + source=src["label"], + category=src.get("category", ""), + repo_owner=src["owner"], + repo_name=src["repo"], + repo_path=f"{src['path']}/{name}", + repo_branch=src["branch"], + installed=skill_name in installed_names, + recommended=skill_name in curated_skills, + ) + + results = await asyncio.gather( + *[_get_skill(e["name"]) for e in entries if e.get("type") == "dir"], + return_exceptions=True, + ) + return [r for r in results if isinstance(r, SkillInfo)] + + +async def _fetch_commit_counts(skills: list[Any]) -> None: + headers = _github_headers() + sem = asyncio.Semaphore(10) + + async def _get_count(session: aiohttp.ClientSession, skill: Any) -> None: + if not skill.repo_owner: + return + async with sem: + url = ( + f"{_GITHUB_API}/repos/{skill.repo_owner}/{skill.repo_name}" + f"/commits?path={skill.repo_path}&sha={skill.repo_branch}&per_page=1" + ) + try: + async with session.get(url) as resp: + if resp.status != 200: + return + link = resp.headers.get("Link", "") + match = re.search(r'page=(\d+)>; rel="last"', link) + if match: + skill.edit_count = int(match.group(1)) + else: + data = await resp.json() + skill.edit_count = len(data) if isinstance(data, list) else 0 + except Exception: + pass + + async with aiohttp.ClientSession(headers=headers) as session: + await asyncio.gather( + *[_get_count(session, s) for s in skills], return_exceptions=True + ) + + +async def install_from_catalog( + skill: Any, + target_dir: Path, +) -> str | None: + """Download a skill from GitHub into *target_dir*. + + Returns ``None`` on success, or an error message string. + """ + headers = _github_headers() + + try: + async with aiohttp.ClientSession(headers=headers) as session: + await _download_dir( + session, + owner=skill.repo_owner, + repo=skill.repo_name, + path=skill.repo_path, + branch=skill.repo_branch, + target=target_dir, + ) + except Exception as exc: + if target_dir.exists(): + shutil.rmtree(target_dir) + return f"Download failed for skill {skill.name!r}: {exc}" + + origin_path = target_dir / _ORIGIN_FILE + origin_path.write_text( + json.dumps( + { + "origin": "marketplace", + "source": skill.source, + "category": skill.category, + "repo_owner": skill.repo_owner, + "repo_name": skill.repo_name, + "repo_path": skill.repo_path, + }, + indent=2, + ) + + "\n" + ) + return None + + +async def _download_dir( + session: aiohttp.ClientSession, + *, + owner: str, + repo: str, + path: str, + branch: str, + target: Path, +) -> None: + """Recursively download a directory from a GitHub repo.""" + url = f"{_GITHUB_API}/repos/{owner}/{repo}/contents/{path}?ref={branch}" + async with session.get(url) as resp: + if resp.status != 200: + body = await resp.text() + raise RuntimeError(f"GitHub API HTTP {resp.status} for {url}: {body[:500]}") + entries = await resp.json() + + if not isinstance(entries, list): + entries = [entries] + + for entry in entries: + if entry["type"] == "file": + raw_url = ( + entry.get("download_url") + or f"{_GITHUB_RAW}/{owner}/{repo}/{branch}/{entry['path']}" + ) + async with session.get(raw_url) as file_resp: + if file_resp.status == 200: + (target / entry["name"]).write_bytes(await file_resp.read()) + elif entry["type"] == "dir": + sub_dir = target / entry["name"] + sub_dir.mkdir(parents=True, exist_ok=True) + await _download_dir( + session, + owner=owner, + repo=repo, + path=entry["path"], + branch=branch, + target=sub_dir, + ) diff --git a/app/runtime/registries/skills.py b/app/runtime/registries/skills.py index 455d122..454cf13 100644 --- a/app/runtime/registries/skills.py +++ b/app/runtime/registries/skills.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import json import logging import re @@ -11,35 +10,11 @@ from pathlib import Path from typing import Any -import aiohttp - from ..config.settings import cfg from ..util.singletons import register_singleton logger = logging.getLogger(__name__) -_CATALOG_SOURCES: list[dict[str, str]] = [ - { - "owner": "github", - "repo": "awesome-copilot", - "path": "skills", - "branch": "main", - "label": "GitHub Awesome Copilot", - "category": "github-awesome", - }, - { - "owner": "anthropics", - "repo": "skills", - "path": "skills", - "branch": "main", - "label": "Anthropic Skills", - "category": "anthropic", - }, -] - -_GITHUB_API = "https://api.github.com" -_GITHUB_RAW = "https://raw.githubusercontent.com" -_CATALOG_CACHE_TTL = 300 _CURATED_SKILLS: set[str] = {"web-search", "summarize-url", "daily-briefing"} _ORIGIN_FILE = ".origin" @@ -200,11 +175,13 @@ def remove(self, name: str) -> bool: async def fetch_catalog(self, *, force: bool = False) -> list[SkillInfo]: import time + from .catalog import fetch_catalog as _fetch_catalog + now = time.monotonic() if ( not force and self._catalog_cache is not None - and (now - self._catalog_ts) < _CATALOG_CACHE_TTL + and (now - self._catalog_ts) < 300 ): return self._catalog_cache @@ -212,137 +189,19 @@ async def fetch_catalog(self, *, force: bool = False) -> list[SkillInfo]: self.rate_limited = False self.rate_limit_reset = None - headers: dict[str, str] = { - "Accept": "application/vnd.github.v3+json", - "User-Agent": "polyclaw-skill-registry", - } - token = cfg.github_token - if token: - headers["Authorization"] = f"token {token}" - - all_skills: list[SkillInfo] = [] - async with aiohttp.ClientSession(headers=headers) as session: - tasks = [self._fetch_source(session, src, installed_names) for src in _CATALOG_SOURCES] - results = await asyncio.gather(*tasks, return_exceptions=True) - - for i, res in enumerate(results): - if isinstance(res, list): - all_skills.extend(res) - elif isinstance(res, Exception): - logger.error("Catalog source %s failed: %s", _CATALOG_SOURCES[i]["label"], res) - - try: - await self._fetch_commit_counts(all_skills) - except Exception: - pass + all_skills, rate_limited, rate_limit_reset = await _fetch_catalog( + installed_names, _parse_frontmatter, _CURATED_SKILLS, + ) + self.rate_limited = rate_limited + self.rate_limit_reset = rate_limit_reset self._catalog_cache = all_skills self._catalog_ts = now return all_skills - async def _fetch_source( - self, - session: aiohttp.ClientSession, - src: dict[str, str], - installed_names: set[str], - ) -> list[SkillInfo]: - url = ( - f"{_GITHUB_API}/repos/{src['owner']}/{src['repo']}" - f"/contents/{src['path']}?ref={src['branch']}" - ) - try: - async with session.get(url) as resp: - if resp.status != 200: - remaining = resp.headers.get("X-RateLimit-Remaining", "?") - if resp.status == 403 and remaining == "0": - self.rate_limited = True - try: - self.rate_limit_reset = int( - resp.headers.get("X-RateLimit-Reset", "0") - ) - except (ValueError, TypeError): - pass - return [] - entries = await resp.json() - except Exception as exc: - logger.error("GitHub API request failed: %s", exc) - return [] - - if not isinstance(entries, list): - return [] - - sem = asyncio.Semaphore(20) - - async def _get_skill(name: str) -> SkillInfo | None: - async with sem: - raw_url = ( - f"{_GITHUB_RAW}/{src['owner']}/{src['repo']}" - f"/{src['branch']}/{src['path']}/{name}/SKILL.md" - ) - try: - async with session.get(raw_url) as r: - fm = _parse_frontmatter(await r.text()) if r.status == 200 else {} - except Exception: - fm = {} - skill_name = fm.get("name", name) - return SkillInfo( - name=skill_name, - verb=fm.get("verb", name), - description=fm.get("description", ""), - source=src["label"], - category=src.get("category", ""), - repo_owner=src["owner"], - repo_name=src["repo"], - repo_path=f"{src['path']}/{name}", - repo_branch=src["branch"], - installed=skill_name in installed_names, - recommended=skill_name in _CURATED_SKILLS, - ) - - results = await asyncio.gather( - *[_get_skill(e["name"]) for e in entries if e.get("type") == "dir"], - return_exceptions=True, - ) - return [r for r in results if isinstance(r, SkillInfo)] - - async def _fetch_commit_counts(self, skills: list[SkillInfo]) -> None: - headers: dict[str, str] = { - "Accept": "application/vnd.github.v3+json", - "User-Agent": "polyclaw-skill-registry", - } - token = cfg.github_token - if token: - headers["Authorization"] = f"token {token}" - sem = asyncio.Semaphore(10) - - async def _get_count(session: aiohttp.ClientSession, skill: SkillInfo) -> None: - if not skill.repo_owner: - return - async with sem: - url = ( - f"{_GITHUB_API}/repos/{skill.repo_owner}/{skill.repo_name}" - f"/commits?path={skill.repo_path}&sha={skill.repo_branch}&per_page=1" - ) - try: - async with session.get(url) as resp: - if resp.status != 200: - return - link = resp.headers.get("Link", "") - match = re.search(r'page=(\d+)>; rel="last"', link) - if match: - skill.edit_count = int(match.group(1)) - else: - data = await resp.json() - skill.edit_count = len(data) if isinstance(data, list) else 0 - except Exception: - pass - - async with aiohttp.ClientSession(headers=headers) as session: - await asyncio.gather( - *[_get_count(session, s) for s in skills], return_exceptions=True - ) - async def install(self, name: str) -> str | None: + from .catalog import install_from_catalog + catalog = await self.fetch_catalog() skill = next((s for s in catalog if s.name == name), None) if not skill: @@ -353,90 +212,14 @@ async def install(self, name: str) -> str | None: target_dir = cfg.user_skills_dir / name target_dir.mkdir(parents=True, exist_ok=True) - headers: dict[str, str] = { - "Accept": "application/vnd.github.v3+json", - "User-Agent": "polyclaw-skill-registry", - } - token = cfg.github_token - if token: - headers["Authorization"] = f"token {token}" - - try: - async with aiohttp.ClientSession(headers=headers) as session: - await self._download_dir( - session, - owner=skill.repo_owner, - repo=skill.repo_name, - path=skill.repo_path, - branch=skill.repo_branch, - target=target_dir, - ) - except Exception as exc: - if target_dir.exists(): - shutil.rmtree(target_dir) - return f"Download failed for skill {name!r}: {exc}" - - origin_path = target_dir / _ORIGIN_FILE - origin_path.write_text( - json.dumps( - { - "origin": "marketplace", - "source": skill.source, - "category": skill.category, - "repo_owner": skill.repo_owner, - "repo_name": skill.repo_name, - "repo_path": skill.repo_path, - }, - indent=2, - ) - + "\n" - ) + error = await install_from_catalog(skill, target_dir) + if error: + return error self._catalog_cache = None logger.info("Installed skill: %s -> %s", name, target_dir) return None - async def _download_dir( - self, - session: aiohttp.ClientSession, - *, - owner: str, - repo: str, - path: str, - branch: str, - target: Path, - ) -> None: - url = f"{_GITHUB_API}/repos/{owner}/{repo}/contents/{path}?ref={branch}" - async with session.get(url) as resp: - if resp.status != 200: - body = await resp.text() - raise RuntimeError(f"GitHub API HTTP {resp.status} for {url}: {body[:500]}") - entries = await resp.json() - - if not isinstance(entries, list): - entries = [entries] - - for entry in entries: - if entry["type"] == "file": - raw_url = ( - entry.get("download_url") - or f"{_GITHUB_RAW}/{owner}/{repo}/{branch}/{entry['path']}" - ) - async with session.get(raw_url) as file_resp: - if file_resp.status == 200: - (target / entry["name"]).write_bytes(await file_resp.read()) - elif entry["type"] == "dir": - sub_dir = target / entry["name"] - sub_dir.mkdir(parents=True, exist_ok=True) - await self._download_dir( - session, - owner=owner, - repo=repo, - path=entry["path"], - branch=branch, - target=sub_dir, - ) - _registry: SkillRegistry | None = None diff --git a/app/runtime/sandbox/__init__.py b/app/runtime/sandbox/__init__.py new file mode 100644 index 0000000..2ade41b --- /dev/null +++ b/app/runtime/sandbox/__init__.py @@ -0,0 +1,20 @@ +"""Agent sandbox executor -- runs agent commands in ACA Dynamic Sessions. + +.. warning:: This feature is experimental and may change or be removed in + future releases. +""" + +from __future__ import annotations + +from .executor import SandboxExecutor +from .helpers import _build_replay_command, _extract_command, _is_shell_tool, _parse_tool_args +from .interceptor import SandboxToolInterceptor + +__all__ = [ + "SandboxExecutor", + "SandboxToolInterceptor", + "_build_replay_command", + "_extract_command", + "_is_shell_tool", + "_parse_tool_args", +] diff --git a/app/runtime/sandbox.py b/app/runtime/sandbox/executor.py similarity index 67% rename from app/runtime/sandbox.py rename to app/runtime/sandbox/executor.py index 2adfbdf..b9be188 100644 --- a/app/runtime/sandbox.py +++ b/app/runtime/sandbox/executor.py @@ -1,8 +1,4 @@ -"""Agent sandbox executor -- runs agent commands in ACA Dynamic Sessions. - -.. warning:: This feature is experimental and may change or be removed in - future releases. -""" +"""Sandbox executor -- runs agent commands in ACA Dynamic Sessions.""" from __future__ import annotations @@ -12,7 +8,6 @@ import json import logging import os -import shlex import shutil import time import uuid @@ -22,8 +17,8 @@ import aiohttp -from .config.settings import cfg -from .state.sandbox_config import SandboxConfigStore +from ..config.settings import cfg +from ..state.sandbox_config import SandboxConfigStore logger = logging.getLogger(__name__) @@ -31,8 +26,6 @@ TOKEN_SCOPE = "https://dynamicsessions.io/.default" MAX_ZIP_SIZE = 100 * 1024 * 1024 -_SHELL_TOOL_PATTERNS = ("terminal", "shell", "bash", "command") -_SESSION_IDLE_TIMEOUT = 60 _UPLOAD_MAX_RETRIES = 3 _UPLOAD_BACKOFF_BASE = 1.0 @@ -87,15 +80,18 @@ async def execute( return self._result(False, f"Failed to create code archive: {exc}", start, session_id) if data_zip: - if not await self._upload_bytes(http, endpoint, session_id, "agent_data.zip", data_zip, headers): - return self._result(False, "Data upload failed", start, session_id) + err = await self._upload_bytes(http, endpoint, session_id, "agent_data.zip", data_zip, headers) + if err: + return self._result(False, f"Data upload failed: {err}", start, session_id) - if not await self._upload_bytes(http, endpoint, session_id, "polyclaw_code.zip", code_zip, headers): - return self._result(False, "Code upload failed", start, session_id) + err = await self._upload_bytes(http, endpoint, session_id, "polyclaw_code.zip", code_zip, headers) + if err: + return self._result(False, f"Code upload failed: {err}", start, session_id) bootstrap = self._build_bootstrap_script(command, has_data=data_zip is not None, env_vars=env_vars) - if not await self._upload_bytes(http, endpoint, session_id, "bootstrap.sh", bootstrap.encode(), headers): - return self._result(False, "Bootstrap upload failed", start, session_id) + err = await self._upload_bytes(http, endpoint, session_id, "bootstrap.sh", bootstrap.encode(), headers) + if err: + return self._result(False, f"Bootstrap upload failed: {err}", start, session_id) exec_result = await self._execute_in_session(http, endpoint, session_id, headers, timeout) if not exec_result["success"]: @@ -261,8 +257,14 @@ async def _get_token(self) -> str: async def _upload_bytes( self, http: aiohttp.ClientSession, endpoint: str, session_id: str, filename: str, data: bytes, headers: dict[str, str], - ) -> bool: + ) -> str: + """Upload bytes to the session. Returns empty string on success, error detail on failure.""" url = f"{endpoint}/files/upload?api-version={API_VERSION}&identifier={session_id}" + size_kb = len(data) / 1024 + logger.info( + "[sandbox.upload] file=%s size=%.1fKB session=%s", + filename, size_kb, session_id, + ) last_error = "" for attempt in range(_UPLOAD_MAX_RETRIES): form = aiohttp.FormData() @@ -275,7 +277,7 @@ async def _upload_bytes( timeout=aiohttp.ClientTimeout(total=120), ) as resp: if resp.status in (200, 201, 202): - return True + return "" body = await resp.text() last_error = f"HTTP {resp.status}: {body[:300]}" logger.warning( @@ -295,7 +297,7 @@ async def _upload_bytes( "Upload %s failed after %d attempts: %s", filename, _UPLOAD_MAX_RETRIES, last_error, ) - return False + return last_error async def _execute_in_session( self, http: aiohttp.ClientSession, endpoint: str, session_id: str, @@ -325,15 +327,10 @@ async def _execute_in_session( return {"success": False, "error": f"Execution failed: {resp.status} {text[:300]}"} result = await resp.json() props = result.get("properties", {}) - raw_stdout = props.get("stdout", "") - try: - output = json.loads(raw_stdout.strip()) - stdout, stderr, rc = output.get("stdout", ""), output.get("stderr", ""), output.get("rc", 0) - except (json.JSONDecodeError, AttributeError, TypeError): - stdout, stderr, rc = raw_stdout, props.get("stderr", ""), 1 - if rc != 0: - return {"success": False, "stdout": stdout, "stderr": stderr, "error": stderr or f"Exit code {rc}"} - return {"success": True, "stdout": stdout, "stderr": stderr} + return self._parse_exec_result( + props.get("stdout", ""), + fallback_stderr=props.get("stderr", ""), + ) except Exception as exc: logger.error("Sandbox exec exception: %s", exc, exc_info=True) return {"success": False, "error": str(exc)} @@ -353,20 +350,29 @@ async def provision_session(self, session_id: str) -> dict[str, Any]: headers = {"Authorization": f"Bearer {token}"} data_zip = self._create_data_zip() if self._store.sync_data else None + has_data = False if data_zip: - if not await self._upload_bytes(http, endpoint, session_id, "agent_data.zip", data_zip, headers): - return self._result(False, "Data upload failed", start, session_id) + err = await self._upload_bytes(http, endpoint, session_id, "agent_data.zip", data_zip, headers) + if err: + logger.warning( + "[sandbox.provision] Data upload failed (non-fatal), " + "continuing without data sync: %s", err, + ) + else: + has_data = True try: code_zip = self._create_code_zip() except Exception as exc: return self._result(False, f"Code archive failed: {exc}", start, session_id) - if not await self._upload_bytes(http, endpoint, session_id, "polyclaw_code.zip", code_zip, headers): - return self._result(False, "Code upload failed", start, session_id) + err = await self._upload_bytes(http, endpoint, session_id, "polyclaw_code.zip", code_zip, headers) + if err: + return self._result(False, f"Code upload failed: {err}", start, session_id) - setup = self._build_bootstrap_script("echo 'Session bootstrapped OK'", has_data=data_zip is not None) - if not await self._upload_bytes(http, endpoint, session_id, "bootstrap.sh", setup.encode(), headers): - return self._result(False, "Bootstrap upload failed", start, session_id) + setup = self._build_bootstrap_script("echo 'Session bootstrapped OK'", has_data=has_data) + err = await self._upload_bytes(http, endpoint, session_id, "bootstrap.sh", setup.encode(), headers) + if err: + return self._result(False, f"Bootstrap upload failed: {err}", start, session_id) exec_result = await self._execute_in_session(http, endpoint, session_id, headers, timeout=120) if not exec_result["success"]: @@ -413,19 +419,36 @@ async def _execute_code( text = await resp.text() return {"success": False, "error": f"HTTP {resp.status}: {text[:300]}"} result = await resp.json() - raw_stdout = result.get("properties", {}).get("stdout", "") - try: - output = json.loads(raw_stdout.strip()) - stdout, stderr, rc = output.get("stdout", ""), output.get("stderr", ""), output.get("rc", 0) - except (json.JSONDecodeError, AttributeError, TypeError): - stdout, stderr, rc = raw_stdout, result.get("properties", {}).get("stderr", ""), 1 - if rc != 0: - return {"success": False, "stdout": stdout, "stderr": stderr, "error": stderr or f"Exit code {rc}"} - return {"success": True, "stdout": stdout, "stderr": stderr} + props = result.get("properties", {}) + return self._parse_exec_result( + props.get("stdout", ""), + fallback_stderr=props.get("stderr", ""), + ) except Exception as exc: logger.error("Session exec exception: %s", exc, exc_info=True) return {"success": False, "error": str(exc)} + @staticmethod + def _parse_exec_result( + raw_stdout: str, fallback_stderr: str = "", + ) -> dict[str, Any]: + """Parse JSON-wrapped subprocess output into a result dict.""" + try: + output = json.loads(raw_stdout.strip()) + stdout = output.get("stdout", "") + stderr = output.get("stderr", "") + rc = output.get("rc", 0) + except (json.JSONDecodeError, AttributeError, TypeError): + stdout, stderr, rc = raw_stdout, fallback_stderr, 1 + if rc != 0: + return { + "success": False, + "stdout": stdout, + "stderr": stderr, + "error": stderr or f"Exit code {rc}", + } + return {"success": True, "stdout": stdout, "stderr": stderr} + async def destroy_session(self, session_id: str) -> None: if self._store.sync_data: try: @@ -457,164 +480,3 @@ def _timing(self, start: float, session_id: str) -> dict[str, Any]: def _result(self, success: bool, error: str, start: float, session_id: str) -> dict[str, Any]: return {"success": success, "error": error, **self._timing(start, session_id)} - - -class SandboxToolInterceptor: - def __init__(self, executor: SandboxExecutor) -> None: - self._executor = executor - self._session_id: str | None = None - self._session_ready: bool = False - self._provisioning: bool = False - self._last_activity: float = 0 - self._idle_task: asyncio.Task | None = None - self._pending_result: dict[str, Any] | None = None - - @property - def session_id(self) -> str | None: - return self._session_id - - async def _ensure_session(self) -> str: - self._last_activity = time.time() - if self._session_id and self._session_ready: - return self._session_id - - self._session_id = str(uuid.uuid4()) - self._session_ready = False - self._provisioning = True - - try: - result = await self._executor.provision_session(self._session_id) - if not result["success"]: - self._session_id = None - raise RuntimeError(f"Sandbox session provision failed: {result.get('error')}") - self._session_ready = True - finally: - self._provisioning = False - - self._start_idle_timer() - return self._session_id - - async def _teardown_session(self) -> None: - if not self._session_id: - return - sid = self._session_id - self._session_id = None - self._session_ready = False - if self._idle_task and not self._idle_task.done(): - self._idle_task.cancel() - self._idle_task = None - try: - await self._executor.destroy_session(sid) - except Exception as exc: - logger.warning("Session teardown error: %s", exc) - - def _start_idle_timer(self) -> None: - if self._idle_task and not self._idle_task.done(): - self._idle_task.cancel() - self._idle_task = asyncio.ensure_future(self._idle_reaper()) - - async def _idle_reaper(self) -> None: - try: - while True: - await asyncio.sleep(10) - if not self._session_id: - return - if time.time() - self._last_activity >= _SESSION_IDLE_TIMEOUT: - await self._teardown_session() - return - except asyncio.CancelledError: - pass - - def touch(self) -> None: - self._last_activity = time.time() - - async def on_pre_tool_use(self, input_data: dict, ctx: dict) -> dict | None: - tool_name = input_data.get("toolName", "") - if not self._executor.enabled: - return {"permissionDecision": "allow"} - if not _is_shell_tool(tool_name): - return {"permissionDecision": "allow"} - - tool_args = _parse_tool_args(input_data.get("toolArgs")) - command = _extract_command(tool_args) - if not command: - return {"permissionDecision": "allow"} - - try: - session_id = await self._ensure_session() - result = await self._executor.run_in_session(session_id, command, timeout=120) - self._last_activity = time.time() - except Exception as exc: - logger.error("Sandbox interceptor failed: %s", exc, exc_info=True) - result = {"success": False, "stdout": "", "stderr": str(exc)} - - self._pending_result = result - replay = _build_replay_command( - result.get("stdout", ""), result.get("stderr", ""), result.get("success", False) - ) - noop_args = dict(tool_args) - noop_args["command"] = replay - if "input" in noop_args: - noop_args["input"] = replay - return {"permissionDecision": "allow", "modifiedArgs": noop_args} - - async def on_post_tool_use(self, input_data: dict, ctx: dict) -> dict | None: - if self._pending_result is None: - return None - - result = self._pending_result - self._pending_result = None - - parts: list[str] = [] - if result.get("stdout"): - parts.append(result["stdout"]) - if result.get("stderr"): - parts.append(f"STDERR:\n{result['stderr']}") - output = "\n".join(parts) if parts else "(no output)" - if not result.get("success"): - output = f"Command failed in sandbox.\n{output}" - return {"modifiedResult": output} - - -def _parse_tool_args(raw: Any) -> dict: - if isinstance(raw, dict): - return raw - if isinstance(raw, str): - try: - parsed = json.loads(raw) - if isinstance(parsed, dict): - return parsed - except (json.JSONDecodeError, TypeError): - pass - return {} - - -def _extract_command(args: Any) -> str: - if isinstance(args, str): - try: - parsed = json.loads(args) - if isinstance(parsed, dict): - args = parsed - else: - return args - except (json.JSONDecodeError, TypeError): - return args - if isinstance(args, dict): - return args.get("command", "") or args.get("cmd", "") or args.get("input", "") or args.get("script", "") - return "" - - -def _is_shell_tool(name: str) -> bool: - lower = name.lower() - return any(p in lower for p in _SHELL_TOOL_PATTERNS) - - -def _build_replay_command(stdout: str, stderr: str, success: bool) -> str: - parts: list[str] = [] - if stdout: - parts.append(f"printf %s {shlex.quote(stdout)}") - if stderr: - parts.append(f"printf %s {shlex.quote(stderr)} >&2") - if not success: - parts.append("exit 1") - return " ; ".join(parts) if parts else "true" diff --git a/app/runtime/sandbox/helpers.py b/app/runtime/sandbox/helpers.py new file mode 100644 index 0000000..c606072 --- /dev/null +++ b/app/runtime/sandbox/helpers.py @@ -0,0 +1,53 @@ +"""Sandbox helper utilities for tool argument parsing and command replay.""" + +from __future__ import annotations + +import json +import shlex +from typing import Any + +_SHELL_TOOL_PATTERNS = ("terminal", "shell", "bash", "command") + + +def _parse_tool_args(raw: Any) -> dict: + if isinstance(raw, dict): + return raw + if isinstance(raw, str): + try: + parsed = json.loads(raw) + if isinstance(parsed, dict): + return parsed + except (json.JSONDecodeError, TypeError): + pass + return {} + + +def _extract_command(args: Any) -> str: + if isinstance(args, str): + try: + parsed = json.loads(args) + if isinstance(parsed, dict): + args = parsed + else: + return args + except (json.JSONDecodeError, TypeError): + return args + if isinstance(args, dict): + return args.get("command", "") or args.get("cmd", "") or args.get("input", "") or args.get("script", "") + return "" + + +def _is_shell_tool(name: str) -> bool: + lower = name.lower() + return any(p in lower for p in _SHELL_TOOL_PATTERNS) + + +def _build_replay_command(stdout: str, stderr: str, success: bool) -> str: + parts: list[str] = [] + if stdout: + parts.append(f"printf %s {shlex.quote(stdout)}") + if stderr: + parts.append(f"printf %s {shlex.quote(stderr)} >&2") + if not success: + parts.append("exit 1") + return " ; ".join(parts) if parts else "true" diff --git a/app/runtime/sandbox/interceptor.py b/app/runtime/sandbox/interceptor.py new file mode 100644 index 0000000..ca6b155 --- /dev/null +++ b/app/runtime/sandbox/interceptor.py @@ -0,0 +1,133 @@ +"""Sandbox tool interceptor -- intercepts shell tool calls for sandbox execution.""" + +from __future__ import annotations + +import asyncio +import logging +import time +import uuid +from typing import Any + +from .executor import SandboxExecutor +from .helpers import _build_replay_command, _extract_command, _is_shell_tool, _parse_tool_args + +logger = logging.getLogger(__name__) + +_SESSION_IDLE_TIMEOUT = 60 + + +class SandboxToolInterceptor: + def __init__(self, executor: SandboxExecutor) -> None: + self._executor = executor + self._session_id: str | None = None + self._session_ready: bool = False + self._provisioning: bool = False + self._last_activity: float = 0 + self._idle_task: asyncio.Task | None = None + self._pending_result: dict[str, Any] | None = None + + @property + def session_id(self) -> str | None: + return self._session_id + + async def _ensure_session(self) -> str: + self._last_activity = time.time() + if self._session_id and self._session_ready: + return self._session_id + + self._session_id = str(uuid.uuid4()) + self._session_ready = False + self._provisioning = True + + try: + result = await self._executor.provision_session(self._session_id) + if not result["success"]: + self._session_id = None + raise RuntimeError(f"Sandbox session provision failed: {result.get('error')}") + self._session_ready = True + finally: + self._provisioning = False + + self._start_idle_timer() + return self._session_id + + async def _teardown_session(self) -> None: + if not self._session_id: + return + sid = self._session_id + self._session_id = None + self._session_ready = False + if self._idle_task and not self._idle_task.done(): + self._idle_task.cancel() + self._idle_task = None + try: + await self._executor.destroy_session(sid) + except Exception as exc: + logger.warning("Session teardown error: %s", exc) + + def _start_idle_timer(self) -> None: + if self._idle_task and not self._idle_task.done(): + self._idle_task.cancel() + self._idle_task = asyncio.ensure_future(self._idle_reaper()) + + async def _idle_reaper(self) -> None: + try: + while True: + await asyncio.sleep(10) + if not self._session_id: + return + if time.time() - self._last_activity >= _SESSION_IDLE_TIMEOUT: + await self._teardown_session() + return + except asyncio.CancelledError: + pass + + def touch(self) -> None: + self._last_activity = time.time() + + async def on_pre_tool_use(self, input_data: dict, ctx: dict) -> dict | None: + tool_name = input_data.get("toolName", "") + if not self._executor.enabled: + return {"permissionDecision": "allow"} + if not _is_shell_tool(tool_name): + return {"permissionDecision": "allow"} + + tool_args = _parse_tool_args(input_data.get("toolArgs")) + command = _extract_command(tool_args) + if not command: + return {"permissionDecision": "allow"} + + try: + session_id = await self._ensure_session() + result = await self._executor.run_in_session(session_id, command, timeout=120) + self._last_activity = time.time() + except Exception as exc: + logger.error("Sandbox interceptor failed: %s", exc, exc_info=True) + result = {"success": False, "stdout": "", "stderr": str(exc)} + + self._pending_result = result + replay = _build_replay_command( + result.get("stdout", ""), result.get("stderr", ""), result.get("success", False) + ) + noop_args = dict(tool_args) + noop_args["command"] = replay + if "input" in noop_args: + noop_args["input"] = replay + return {"permissionDecision": "allow", "modifiedArgs": noop_args} + + async def on_post_tool_use(self, input_data: dict, ctx: dict) -> dict | None: + if self._pending_result is None: + return None + + result = self._pending_result + self._pending_result = None + + parts: list[str] = [] + if result.get("stdout"): + parts.append(result["stdout"]) + if result.get("stderr"): + parts.append(f"STDERR:\n{result['stderr']}") + output = "\n".join(parts) if parts else "(no output)" + if not result.get("success"): + output = f"Command failed in sandbox.\n{output}" + return {"modifiedResult": output} diff --git a/app/runtime/scheduler/__init__.py b/app/runtime/scheduler/__init__.py new file mode 100644 index 0000000..45c6f33 --- /dev/null +++ b/app/runtime/scheduler/__init__.py @@ -0,0 +1,25 @@ +"""Scheduler -- persistent task scheduling that spawns Copilot SDK sessions.""" + +from .engine import ( + MIN_INTERVAL_SECONDS, + SCHEDULED_MODEL, + ScheduledTask, + Scheduler, + _cron_matches, + _validate_cron, + get_scheduler, + scheduler_loop, + set_scheduler, +) + +__all__ = [ + "MIN_INTERVAL_SECONDS", + "SCHEDULED_MODEL", + "ScheduledTask", + "Scheduler", + "_cron_matches", + "_validate_cron", + "get_scheduler", + "scheduler_loop", + "set_scheduler", +] diff --git a/app/runtime/scheduler.py b/app/runtime/scheduler/engine.py similarity index 93% rename from app/runtime/scheduler.py rename to app/runtime/scheduler/engine.py index 72af951..897f65a 100644 --- a/app/runtime/scheduler.py +++ b/app/runtime/scheduler/engine.py @@ -1,4 +1,4 @@ -"""Scheduler -- persistent task scheduling that spawns Copilot SDK sessions.""" +"""Core scheduler engine -- persistent task scheduling that spawns Copilot SDK sessions.""" from __future__ import annotations @@ -14,16 +14,16 @@ from croniter import croniter -from .agent import one_shot as one_shot_mod -from .config.settings import cfg -from .util.singletons import register_singleton +from ..agent import one_shot as one_shot_mod +from ..config.settings import cfg +from ..util.singletons import register_singleton logger = logging.getLogger(__name__) SCHEDULED_MODEL = "gpt-4.1" MIN_INTERVAL_SECONDS = 3600 -_TEMPLATES_DIR = Path(__file__).resolve().parent / "templates" +_TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "templates" @dataclass @@ -214,7 +214,8 @@ def check_due(self) -> list[ScheduledTask]: gap = (now - last_dt).total_seconds() if gap < MIN_INTERVAL_SECONDS: logger.debug( - "[scheduler] task %s (%s) -- cron matches but too soon (%.0fs < %ds)", + "[scheduler] task %s (%s) -- cron matches but too soon" + " (%.0fs < %ds)", task.id, task.description, gap, MIN_INTERVAL_SECONDS, ) continue @@ -238,7 +239,7 @@ def check_due(self) -> list[ScheduledTask]: return due async def run_due_tasks(self) -> None: - from .state.profile import log_interaction + from ..state.profile import log_interaction for task in self.check_due(): logger.info( @@ -260,7 +261,7 @@ async def run_due_tasks(self) -> None: self._active_interceptor = None async def _spawn_session(self, task: ScheduledTask) -> str | None: - from .agent.tools import get_all_tools + from ..agent.tools import get_all_tools template = (_TEMPLATES_DIR / "scheduler_prompt.md").read_text() system_message = template.format( @@ -287,17 +288,16 @@ def _make_background_hook(self, model: str) -> Callable[..., Any]: PITL works if a ``PhoneVerifier`` is configured on the shared interceptor. AITL and Prompt Shields are also forwarded. """ - from .agent.hitl import HitlInterceptor - from .state.guardrails_config import get_guardrails_config + from ..agent.hitl import HitlInterceptor + from ..state.guardrails.config import get_guardrails_config store = get_guardrails_config() interceptor = HitlInterceptor(store) - interceptor.set_execution_context("scheduler") - interceptor.set_model(model) - - # Bind notification channel so HITL can interact with the user. - if self._notify: - interceptor.set_bot_reply_fn(self._notify) + interceptor.bind_turn( + execution_context="scheduler", + model=model, + bot_reply_fn=self._notify if self._notify else None, + ) # Forward AITL / Prompt Shield / phone from the shared interceptor. if self._hitl_interceptor: @@ -324,7 +324,10 @@ async def _send_notification(self, task: ScheduledTask, result: str | None) -> N await self._notify(msg) logger.info("[scheduler] notification sent for task %s", task.id) except Exception as exc: - logger.error("[scheduler] notification send failed for task %s: %s", task.id, exc, exc_info=True) + logger.error( + "[scheduler] notification send failed for task %s: %s", + task.id, exc, exc_info=True, + ) _scheduler: Scheduler | None = None diff --git a/app/runtime/server/__init__.py b/app/runtime/server/__init__.py index 697117e..707cfcf 100644 --- a/app/runtime/server/__init__.py +++ b/app/runtime/server/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations -from .app import AppFactory, create_adapter, create_app, main +from .app import AppFactory, create_app, main +from .wiring import create_adapter __all__ = ["AppFactory", "create_adapter", "create_app", "main"] diff --git a/app/runtime/server/app.py b/app/runtime/server/app.py index 6567742..28e52e2 100644 --- a/app/runtime/server/app.py +++ b/app/runtime/server/app.py @@ -3,283 +3,122 @@ from __future__ import annotations import asyncio -import hmac import logging -import mimetypes import os import secrets -import time from collections.abc import Awaitable, Callable -from pathlib import Path +from typing import TYPE_CHECKING, Any from aiohttp import web -from aiohttp.abc import AbstractAccessLogger from .. import __version__ from ..config.settings import ServerMode, cfg -from ..media import EXTENSION_TO_MIME -logger = logging.getLogger(__name__) - -_FRONTEND_DIR = Path(__file__).resolve().parent.parent.parent / "frontend" / "dist" -_QUIET_PATHS = frozenset({"/api/setup/status", "/health"}) - - -class QuietAccessLogger(AbstractAccessLogger): - """Demotes polling-endpoint and noisy log entries to DEBUG.""" - - def log(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> None: - status = response.status - if request.path in _QUIET_PATHS or status == 401 or status in (502, 503): - level = logging.DEBUG - else: - level = logging.INFO - self.logger.log( - level, - "%s %s %s %s %.3fs", - request.remote, - request.method, - request.path, - status, - time, - ) - - -def create_adapter() -> object: - from botbuilder.core import BotFrameworkAdapter, BotFrameworkAdapterSettings, TurnContext - from botbuilder.schema import Activity, ActivityTypes - - settings = BotFrameworkAdapterSettings( - app_id=cfg.bot_app_id or None, - app_password=cfg.bot_app_password or None, - channel_auth_tenant=cfg.bot_app_tenant_id or None, - ) - adapter = BotFrameworkAdapter(settings) - - async def on_error(context: TurnContext, error: Exception) -> None: - logger.error("Bot turn error: %s", error, exc_info=True) - try: - activity = Activity(type=ActivityTypes.message, text="An error occurred.") - if (context.activity.channel_id or "").lower() == "telegram": - activity.text_format = "plain" - await context.send_activity(activity) - except Exception: - pass - - adapter.on_turn_error = on_error - return adapter - - -_PUBLIC_PREFIXES = ("/health", "/api/messages", "/acs", "/realtime-acs", "/api/voice/acs-callback", "/api/voice/media-streaming") -_PUBLIC_EXACT = ("/api/auth/check",) - -_TUNNEL_ALLOWED_PREFIXES = ( - "/health", - "/api/messages", - "/acs", - "/realtime-acs", - "/api/voice/acs-callback", - "/api/voice/media-streaming", +if TYPE_CHECKING: + from ..agent.agent import Agent + from ..messaging.bot import Bot + from ..messaging.proactive import ConversationReferenceStore + from ..sandbox import SandboxExecutor + from ..scheduler import Scheduler + from ..services.cloud.azure import AzureCLI + from ..services.cloud.github import GitHubAuth + from ..services.deployment.aca_deployer import AcaDeployer + from ..services.deployment.deployer import BotDeployer + from ..services.deployment.provisioner import Provisioner + from ..services.tunnel import CloudflareTunnel + from ..state.deploy_state import DeployStateStore + from ..state.foundry_iq_config import FoundryIQConfigStore + from ..state.guardrails import GuardrailsConfigStore + from ..state.infra_config import InfraConfigStore + from ..state.mcp_config import McpConfigStore + from ..state.monitoring_config import MonitoringConfigStore + from ..state.proactive import ProactiveStore + from ..state.sandbox_config import SandboxConfigStore + from ..state.session_store import SessionStore + from .bot_endpoint import BotEndpoint +from . import lifecycle +from .app_routes import register_admin_routes, register_runtime_routes +from .app_static import ( + FRONTEND_DIR, + make_file_handler, + serve_index, + serve_media, + serve_spa_or_404, ) - -_LOCKDOWN_ALLOWED_PREFIXES = ( - "/health", - "/api/messages", - "/acs", - "/realtime-acs", - "/api/voice/acs-callback", - "/api/voice/media-streaming", - "/api/setup/lockdown", +from .middleware import ( + auth_middleware, + lockdown_middleware, + tunnel_restriction_middleware, ) +from .wiring import create_adapter, create_voice_handler, init_core, init_services -_CF_HEADERS = ("cf-connecting-ip", "cf-ray", "cf-ipcountry") - - -@web.middleware -async def lockdown_middleware(request: web.Request, handler): # type: ignore[type-arg] - if not cfg.lockdown_mode: - return await handler(request) - if any(request.path.startswith(p) for p in _LOCKDOWN_ALLOWED_PREFIXES): - return await handler(request) - return web.json_response( - { - "status": "locked", - "message": ( - "Lock Down Mode is active. The admin panel is disabled. " - "Use /lockdown off via the bot to restore access." - ), - }, - status=403, - ) - - -@web.middleware -async def tunnel_restriction_middleware(request: web.Request, handler): # type: ignore[type-arg] - if not cfg.tunnel_restricted: - return await handler(request) - is_tunnel = any(request.headers.get(h) for h in _CF_HEADERS) - if not is_tunnel: - return await handler(request) - if any(request.path.startswith(p) for p in _TUNNEL_ALLOWED_PREFIXES): - return await handler(request) - return web.json_response({"status": "forbidden"}, status=403) - - -@web.middleware -async def auth_middleware(request: web.Request, handler): # type: ignore[type-arg] - secret = cfg.admin_secret - if not secret: - return await handler(request) - - path = request.path - - # Only protect /api/* endpoints (except public ones); frontend assets are public - if not path.startswith("/api/"): - return await handler(request) - - if path in _PUBLIC_EXACT or any(path.startswith(p) for p in _PUBLIC_PREFIXES): - return await handler(request) - - auth = request.headers.get("Authorization", "") - expected = f"Bearer {secret}" - if hmac.compare_digest(auth, expected): - return await handler(request) - - token_param = request.query.get("token", "") - if token_param and hmac.compare_digest(token_param, secret): - return await handler(request) - - secret_param = request.query.get("secret", "") - if secret_param and hmac.compare_digest(secret_param, secret): - return await handler(request) - - return web.json_response( - {"status": "unauthorized", "message": "Invalid or missing admin secret"}, - status=401, - ) - - -def _append_token(url: str, token: str) -> str: - sep = "&" if "?" in url else "?" - return f"{url}{sep}token={token}" +logger = logging.getLogger(__name__) async def create_app() -> web.Application: + """Public entry point -- build and return the ``aiohttp`` application.""" factory = AppFactory() return await factory.build() -def _create_voice_handler(agent: object, tunnel: object | None = None) -> object | None: - cfg.reload() - if not (cfg.acs_connection_string and cfg.acs_source_number and cfg.azure_openai_endpoint): - logger.info("Voice call not configured (ACS/AOAI settings missing)") - return None - - from azure.core.credentials import AzureKeyCredential as _AKC - - from ..realtime import AcsCaller, RealtimeMiddleTier, RealtimeRoutes - - def _resolve_acs_urls() -> tuple[str, str]: - token = cfg.acs_callback_token - cb_path = cfg.acs_callback_path - ws_path = cfg.acs_media_streaming_websocket_path - - logger.debug("_resolve_acs_urls: cb_path=%r, ws_path=%r, token=%s", cb_path, ws_path, "set" if token else "empty") - - # If both paths are already absolute URLs, use them directly - cb_is_absolute = cb_path.startswith("https://") - ws_is_absolute = ws_path.startswith("wss://") - if cb_is_absolute and ws_is_absolute: - resolved = _append_token(cb_path, token), _append_token(ws_path, token) - logger.info("ACS URLs (absolute): callback=%s, ws=%s", resolved[0], resolved[1]) - return resolved - - # Otherwise, resolve relative paths against the tunnel URL - tunnel_url = (getattr(tunnel, 'url', None) or "").rstrip("/") - if tunnel_url: - cb = cb_path if cb_is_absolute else f"{tunnel_url}{cb_path or '/api/voice/acs-callback'}" - ws = ws_path if ws_is_absolute else ( - tunnel_url.replace("https://", "wss://").replace("http://", "ws://") - + (ws_path or "/api/voice/media-streaming") - ) - resolved = _append_token(cb, token), _append_token(ws, token) - logger.info("ACS URLs (tunnel): callback=%s, ws=%s", resolved[0], resolved[1]) - return resolved - logger.warning("ACS URLs fallback to localhost -- calls will fail") - return ( - cb_path or f"http://localhost:{cfg.admin_port}/api/voice/acs-callback", - ws_path or f"ws://localhost:{cfg.admin_port}/api/voice/media-streaming", - ) - - caller = AcsCaller( - source_number=cfg.acs_source_number, - acs_connection_string=cfg.acs_connection_string, - resolve_urls=_resolve_acs_urls, - resolve_source_number=lambda: cfg.acs_source_number, - ) - - realtime_credential: _AKC | object - if cfg.azure_openai_api_key: - realtime_credential = _AKC(cfg.azure_openai_api_key) - else: - from azure.identity import DefaultAzureCredential as _DAC - - realtime_credential = _DAC() - - rt_middleware = RealtimeMiddleTier( - endpoint=cfg.azure_openai_endpoint, - deployment=cfg.azure_openai_realtime_deployment, - credential=realtime_credential, - agent=agent, - ) - handler = RealtimeRoutes( - caller, - rt_middleware, - callback_token=cfg.acs_callback_token, - acs_resource_id=cfg.acs_resource_id, - ) - logger.info("Voice call (ACS + Realtime) enabled: source=%s", cfg.acs_source_number) - return handler - - -_SCHEDULE_INTERVALS = {"hourly": 3600, "daily": 86400} - - class AppFactory: + """Assembles the aiohttp application with routes, middleware, and lifecycle hooks. + + All dependency references are declared in ``__init__`` so the full shape + of the object is visible in one place. + """ + + def __init__(self) -> None: + self._mode: ServerMode = cfg.server_mode + + # Core components (populated by _init_core) + self._agent: Agent | None = None + self._adapter: Any = None # BotFrameworkAdapter (external) + self._conv_store: ConversationReferenceStore | None = None + self._session_store: SessionStore | None = None + self._bot: Bot | None = None + self._bot_ep: BotEndpoint | None = None + + # State stores (populated by _init_services) + self._deploy_store: DeployStateStore | None = None + self._infra_store: InfraConfigStore | None = None + self._mcp_store: McpConfigStore | None = None + self._sandbox_store: SandboxConfigStore | None = None + self._foundry_iq_store: FoundryIQConfigStore | None = None + self._guardrails_store: GuardrailsConfigStore | None = None + self._monitoring_store: MonitoringConfigStore | None = None + + # External services (populated by _init_services) + self._tunnel: CloudflareTunnel | None = None + self._az: AzureCLI | None = None + self._gh: GitHubAuth | None = None + self._deployer: BotDeployer | None = None + self._provisioner: Provisioner | None = None + self._aca_deployer: AcaDeployer | None = None + + # Runtime-only services (populated by _init_services) + self._scheduler: Scheduler | None = None + self._proactive_store: ProactiveStore | None = None + self._sandbox_executor: SandboxExecutor | None = None + + # Voice handler (populated by _init_voice) + self._voice_routes: Any = None + + # -- Public API -------------------------------------------------------- async def build(self) -> web.Application: + """Wire everything together and return the application.""" self._mode = cfg.server_mode cfg.ensure_dirs() self._ensure_admin_secret() await self._init_core() self._init_services() - - if self._bot and self._agent and self._agent.hitl_interceptor: - self._bot._hitl = self._agent.hitl_interceptor - self._bot._processor._hitl = self._agent.hitl_interceptor - - if self._scheduler and self._agent and self._agent.hitl_interceptor: - self._scheduler.set_hitl_interceptor(self._agent.hitl_interceptor) - if self._bot and self._scheduler: - self._bot._scheduler = self._scheduler - + self._cross_wire() self._init_voice() middlewares = [lockdown_middleware, tunnel_restriction_middleware, auth_middleware] - - # Admin-only mode: proxy unmatched /api/* requests to runtime - proxy_mw = None - if self._is_admin and not self._is_runtime: - from .runtime_proxy import create_runtime_proxy_middleware - - if os.getenv("POLYCLAW_USE_MI"): - aca_fqdn = cfg.env.read("ACA_RUNTIME_FQDN") - if aca_fqdn: - aca_url = f"https://{aca_fqdn}" - os.environ["RUNTIME_URL"] = aca_url - logger.info("[startup] Restored RUNTIME_URL=%s from ACA deployment", aca_url) - - proxy_mw = create_runtime_proxy_middleware() + proxy_mw = self._maybe_create_proxy() + if proxy_mw is not None: middlewares.append(proxy_mw) app = web.Application(middlewares=middlewares) @@ -294,6 +133,8 @@ async def build(self) -> web.Application: return app + # -- Properties -------------------------------------------------------- + @property def _is_admin(self) -> bool: return self._mode in (ServerMode.admin, ServerMode.combined) @@ -302,6 +143,8 @@ def _is_admin(self) -> bool: def _is_runtime(self) -> bool: return self._mode in (ServerMode.runtime, ServerMode.combined) + # -- Initialisation (delegates to wiring module) ----------------------- + @staticmethod def _ensure_admin_secret() -> None: if cfg.admin_secret: @@ -316,120 +159,69 @@ def _ensure_admin_secret() -> None: logger.info("Generated ADMIN_SECRET (persisted to .env)") async def _init_core(self) -> None: - self._agent = None - self._adapter = None - self._conv_store = None - self._session_store = None - self._bot = None - self._bot_ep = None - - if self._is_runtime: - from ..agent.agent import Agent - from ..messaging.bot import Bot - from ..messaging.proactive import ConversationReferenceStore - from ..state.session_store import SessionStore - from .bot_endpoint import BotEndpoint - - logger.info("[init_core] creating Agent ...") - self._agent = Agent() - logger.info("[init_core] starting Agent (Copilot CLI) ...") - await self._agent.start() - logger.info("[init_core] Agent started successfully") - - self._adapter = create_adapter() - self._conv_store = ConversationReferenceStore() - self._session_store = SessionStore() - - hitl = self._agent.hitl_interceptor if self._agent else None - self._bot = Bot(self._agent, self._conv_store, hitl=hitl) - self._bot.session_store = self._session_store - self._bot.adapter = self._adapter - self._bot_ep = BotEndpoint(self._adapter, self._bot) - logger.info("[init_core] core initialization complete") - - if self._is_admin and not self._is_runtime: - from ..state.session_store import SessionStore - - self._session_store = SessionStore() - logger.info("[init_core] admin-only initialization complete") + core = await init_core(self._mode) + self._agent = core["agent"] + self._adapter = core["adapter"] + self._conv_store = core["conv_store"] + self._session_store = core["session_store"] + self._bot = core["bot"] + self._bot_ep = core["bot_ep"] def _init_services(self) -> None: - from ..state.deploy_state import DeployStateStore - from ..state.foundry_iq_config import FoundryIQConfigStore - from ..state.guardrails_config import GuardrailsConfigStore - from ..state.infra_config import InfraConfigStore - from ..state.mcp_config import McpConfigStore - from ..state.monitoring_config import MonitoringConfigStore - from ..state.sandbox_config import SandboxConfigStore - - self._tunnel = None - if self._is_runtime: - from ..services.tunnel import CloudflareTunnel - - self._tunnel = CloudflareTunnel() - self._deploy_store = DeployStateStore() - self._infra_store = InfraConfigStore() - self._mcp_store = McpConfigStore() - self._sandbox_store = SandboxConfigStore() - self._foundry_iq_store = FoundryIQConfigStore() - self._guardrails_store = GuardrailsConfigStore() - self._monitoring_store = MonitoringConfigStore() - - # Admin-side services: Azure CLI, GitHub auth, deployer, provisioner - self._az = None - self._gh = None - self._deployer = None - self._provisioner = None - self._aca_deployer = None - if self._is_admin: - from ..services.aca_deployer import AcaDeployer - from ..services.azure import AzureCLI - from ..services.deployer import BotDeployer - from ..services.github import GitHubAuth - from ..services.provisioner import Provisioner - - self._az = AzureCLI() - self._gh = GitHubAuth() - self._deployer = BotDeployer(self._az, self._deploy_store) - self._provisioner = Provisioner( - self._az, self._deployer, - self._infra_store, self._deploy_store, - tunnel=self._tunnel, - ) - self._aca_deployer = AcaDeployer(self._az, self._deploy_store) - elif self._is_runtime: - from ..services.azure import AzureCLI - from ..services.deployer import BotDeployer - from ..services.provisioner import Provisioner - - self._az = AzureCLI() - self._deployer = BotDeployer(self._az, self._deploy_store) - self._provisioner = Provisioner( - self._az, self._deployer, - self._infra_store, self._deploy_store, - tunnel=self._tunnel, - ) - - # Runtime-side services: scheduler, sandbox, proactive - self._scheduler = None - self._proactive_store = None - self._sandbox_executor = None - if self._is_runtime: - from ..sandbox import SandboxExecutor - from ..scheduler import get_scheduler - from ..state.proactive import get_proactive_store - - self._scheduler = get_scheduler() - self._proactive_store = get_proactive_store() - self._sandbox_executor = SandboxExecutor(self._sandbox_store) - if self._agent: + svc = init_services(self._mode) + self._tunnel = svc["tunnel"] + self._deploy_store = svc["deploy_store"] + self._infra_store = svc["infra_store"] + self._mcp_store = svc["mcp_store"] + self._sandbox_store = svc["sandbox_store"] + self._foundry_iq_store = svc["foundry_iq_store"] + self._guardrails_store = svc["guardrails_store"] + self._monitoring_store = svc["monitoring_store"] + self._az = svc["az"] + self._gh = svc["gh"] + self._deployer = svc["deployer"] + self._provisioner = svc["provisioner"] + self._aca_deployer = svc["aca_deployer"] + self._scheduler = svc["scheduler"] + self._proactive_store = svc["proactive_store"] + self._sandbox_executor = svc["sandbox_executor"] + + # Wire sandbox and guardrails into agent + if self._is_runtime and self._agent: + if self._sandbox_executor: self._agent.set_sandbox(self._sandbox_executor) - self._agent.set_guardrails(self._guardrails_store) + self._agent.set_guardrails(self._guardrails_store) + + def _cross_wire(self) -> None: + """Wire cross-cutting references that span core and services.""" + if self._bot and self._agent and self._agent.hitl_interceptor: + self._bot._hitl = self._agent.hitl_interceptor + self._bot._processor._hitl = self._agent.hitl_interceptor + + if self._scheduler and self._agent and self._agent.hitl_interceptor: + self._scheduler.set_hitl_interceptor(self._agent.hitl_interceptor) + if self._bot and self._scheduler: + self._bot._scheduler = self._scheduler def _init_voice(self) -> None: self._voice_routes = None if self._is_runtime: - self._voice_routes = _create_voice_handler(self._agent, self._tunnel) + self._voice_routes = create_voice_handler(self._agent, self._tunnel) + + def _maybe_create_proxy(self) -> object | None: + """Create the runtime proxy middleware for admin-only mode.""" + if not (self._is_admin and not self._is_runtime): + return None + from .runtime_proxy import create_runtime_proxy_middleware + + if os.getenv("POLYCLAW_USE_MI"): + aca_fqdn = cfg.env.read("ACA_RUNTIME_FQDN") + if aca_fqdn: + aca_url = f"https://{aca_fqdn}" + os.environ["RUNTIME_URL"] = aca_url + logger.info("[startup] Restored RUNTIME_URL=%s from ACA deployment", aca_url) + + return create_runtime_proxy_middleware() def _rebuild_adapter(self) -> object: cfg.reload() @@ -463,7 +255,7 @@ async def auth_check(req: web.Request) -> web.Response: self._register_runtime_routes(app) # Shared routes (both modes) - router.add_get("/api/media/{filename:.+}", _serve_media) + router.add_get("/api/media/{filename:.+}", serve_media) router.add_get("/health", self._health_handler()) # Frontend SPA -- served by admin in split mode, or by combined @@ -472,124 +264,32 @@ async def auth_check(req: web.Request) -> web.Response: def _register_admin_routes(self, router: web.UrlDispatcher) -> None: """Routes available only in ``admin`` or ``combined`` mode.""" - from .setup import SetupRoutes - from .setup_voice import VoiceSetupRoutes - from .workspace import WorkspaceHandler - from .routes.content_safety_routes import ContentSafetyRoutes - from .routes.env_routes import EnvironmentRoutes - from .routes.foundry_iq_routes import FoundryIQRoutes - from .routes.network_routes import NetworkRoutes - from .routes.monitoring_routes import MonitoringRoutes - from .routes.sandbox_routes import SandboxRoutes - - SetupRoutes( - self._az, self._gh, self._tunnel, self._deployer, - self._rebuild_adapter, self._infra_store, - self._provisioner, self._deploy_store, - self._aca_deployer, - ).register(router) - - VoiceSetupRoutes(self._az, self._infra_store).register(router) - WorkspaceHandler().register(router) - EnvironmentRoutes(self._deploy_store, self._az).register(router) - SandboxRoutes( - self._sandbox_store, self._sandbox_executor, self._az, self._deploy_store, - ).register(router) - FoundryIQRoutes(self._foundry_iq_store, self._az, self._deploy_store).register(router) - NetworkRoutes(self._tunnel, self._az, self._sandbox_store, self._foundry_iq_store).register(router) - MonitoringRoutes( - self._monitoring_store, self._az, self._deploy_store, - ).register(router) - ContentSafetyRoutes(self._az, self._guardrails_store).register(router) - - from .routes.identity_routes import IdentityRoutes - IdentityRoutes(self._az, self._guardrails_store).register(router) - - if self._az: - from .routes.security_preflight_routes import SecurityPreflightRoutes - from ..services.security_preflight import SecurityPreflightChecker - - SecurityPreflightRoutes(SecurityPreflightChecker(self._az)).register(router) + register_admin_routes( + router, + az=self._az, gh=self._gh, tunnel=self._tunnel, + deployer=self._deployer, rebuild_adapter=self._rebuild_adapter, + infra_store=self._infra_store, provisioner=self._provisioner, + deploy_store=self._deploy_store, aca_deployer=self._aca_deployer, + sandbox_store=self._sandbox_store, + sandbox_executor=self._sandbox_executor, + foundry_iq_store=self._foundry_iq_store, + monitoring_store=self._monitoring_store, + guardrails_store=self._guardrails_store, + ) def _register_runtime_routes(self, app: web.Application) -> None: """Routes available only in ``runtime`` or ``combined`` mode.""" - from ..agent.aitl import AitlReviewer - from ..agent.phone_verify import PhoneVerifier - from ..registries.plugins import get_plugin_registry - from ..registries.skills import get_registry as get_skill_registry - from ..services.prompt_shield import PromptShieldService - from ..state.plugin_config import PluginConfigStore - from .chat import ChatHandler - from .routes.guardrails_routes import GuardrailsRoutes - from .routes.mcp_routes import McpRoutes - from .routes.plugin_routes import PluginRoutes - from .routes.proactive_routes import ProactiveRoutes - from .routes.profile_routes import ProfileRoutes - from .routes.scheduler_routes import SchedulerRoutes - from .routes.session_routes import SessionRoutes - from .routes.skill_routes import SkillRoutes - from .routes.tool_activity_routes import ToolActivityRoutes - - router = app.router - - router.add_post("/api/internal/reload", self._handle_reload) - - from .routes.network_routes import NetworkRoutes as _NR - _nr_instance = _NR(self._tunnel) - router.add_get("/api/network/endpoints", _nr_instance._endpoints) - - hitl = self._agent.hitl_interceptor if self._agent else None - - # Wire phone verifier into HITL interceptor - if hitl: - phone_verifier = PhoneVerifier(app) - hitl.set_phone_verifier(phone_verifier) - app["_phone_verifier"] = phone_verifier - - # Wire AITL reviewer - gcfg = self._guardrails_store.config - aitl_reviewer = AitlReviewer( - model=gcfg.aitl_model, - spotlighting=gcfg.aitl_spotlighting, - ) - hitl.set_aitl_reviewer(aitl_reviewer) - - prompt_shield = PromptShieldService( - endpoint=gcfg.content_safety_endpoint, - mode=gcfg.filter_mode, - ) - hitl.set_prompt_shield(prompt_shield) - - ChatHandler( - self._agent, - session_store=self._session_store, - sandbox_interceptor=self._sandbox_executor, - hitl_interceptor=hitl, - ).register(router) - - self._bot_ep.register(router) - self._register_voice_dynamic(app) - - SchedulerRoutes(self._scheduler).register(router) - SessionRoutes(self._session_store).register(router) - SkillRoutes(get_skill_registry()).register(router) - McpRoutes(self._mcp_store).register(router) - PluginRoutes(get_plugin_registry(), PluginConfigStore()).register(router) - ProfileRoutes().register(router) - GuardrailsRoutes( - self._guardrails_store, self._mcp_store, - skills_registry=get_skill_registry(), - ).register(router) - - from ..state.tool_activity_store import get_tool_activity_store - ToolActivityRoutes(get_tool_activity_store(), self._session_store).register(router) - - ProactiveRoutes( - self._proactive_store, - adapter=self._adapter, - conv_store=self._conv_store, - app_id=cfg.bot_app_id, - ).register(router) + register_runtime_routes( + app, + agent=self._agent, session_store=self._session_store, + sandbox_executor=self._sandbox_executor, + mcp_store=self._mcp_store, guardrails_store=self._guardrails_store, + scheduler=self._scheduler, proactive_store=self._proactive_store, + adapter=self._adapter, conv_store=self._conv_store, + bot_ep=self._bot_ep, tunnel=self._tunnel, + voice_routes=self._voice_routes, + handle_reload=self._handle_reload, + ) def _health_handler(self) -> Callable: """Return a health handler that includes mode and tunnel info.""" @@ -604,75 +304,19 @@ async def handler(_req: web.Request) -> web.Response: return handler def _register_frontend(self, router: web.UrlDispatcher) -> None: - fe = _FRONTEND_DIR + fe = FRONTEND_DIR if not fe.exists(): return - router.add_get("/", _serve_index) + router.add_get("/", serve_index) if (fe / "assets").is_dir(): router.add_static("/assets/", path=str(fe / "assets"), name="fe_assets") for fname in ("favicon.ico", "logo.png", "headertext.png"): fpath = fe / fname if fpath.exists(): - router.add_get(f"/{fname}", _make_file_handler(fpath)) - router.add_get("/{tail:[^/].*}", _serve_spa_or_404) - - def _register_voice_dynamic(self, app: web.Application) -> None: - app["_voice_handler"] = self._voice_routes - agent = self._agent - - def reinit_voice() -> None: - handler = _create_voice_handler(agent, self._tunnel) - app["_voice_handler"] = handler - app["voice_configured"] = handler is not None - - app["_reinit_voice"] = reinit_voice - - def _not_configured() -> web.Response: - return web.json_response( - { - "status": "error", - "message": ( - "Voice calling is not configured. Deploy ACS + " - "Azure OpenAI resources in the Voice Call section first." - ), - }, - status=400, - ) - - async def voice_call(req: web.Request) -> web.Response: - h = req.app["_voice_handler"] - return _not_configured() if h is None else await h._api_call(req) - - async def voice_status(req: web.Request) -> web.Response: - h = req.app["_voice_handler"] - return _not_configured() if h is None else await h._api_status(req) - - async def acs_callback(req: web.Request) -> web.Response: - h = req.app["_voice_handler"] - logger.info("ACS callback hit: method=%s path=%s handler=%s", req.method, req.path, "configured" if h else "NONE") - return _not_configured() if h is None else await h._acs_callback(req) + router.add_get(f"/{fname}", make_file_handler(fpath)) + router.add_get("/{tail:[^/].*}", serve_spa_or_404) - async def acs_incoming(req: web.Request) -> web.Response: - h = req.app["_voice_handler"] - logger.info("ACS incoming hit: method=%s path=%s handler=%s", req.method, req.path, "configured" if h else "NONE") - return _not_configured() if h is None else await h._acs_incoming(req) - - async def ws_handler_acs(req: web.Request) -> web.WebSocketResponse: - h = req.app["_voice_handler"] - logger.info("ACS media-streaming WS hit: method=%s path=%s handler=%s", req.method, req.path, "configured" if h else "NONE") - return _not_configured() if h is None else await h._ws_handler_acs(req) # type: ignore[return-value] - - router = app.router - router.add_post("/api/voice/call", voice_call) - router.add_get("/api/voice/status", voice_status) - # Legacy routes (kept for backwards compat) - router.add_post("/acs", acs_callback) - router.add_post("/acs/incoming", acs_incoming) - router.add_get("/realtime-acs", ws_handler_acs) - # Routes matching cfg.acs_callback_path / cfg.acs_media_streaming_websocket_path - router.add_post("/api/voice/acs-callback", acs_callback) - router.add_post("/api/voice/acs-callback/incoming", acs_incoming) - router.add_get("/api/voice/media-streaming", ws_handler_acs) + # -- Lifecycle (delegates to lifecycle module) -------------------------- def _make_notify(self) -> Callable[[str], Awaitable[bool]]: from ..messaging.proactive import send_proactive_message @@ -700,214 +344,41 @@ async def notify(message: str) -> None: async def _on_startup(self, app: web.Application) -> None: if self._is_runtime: - await self._on_startup_runtime(app) - if self._is_admin: - await self._on_startup_admin(app) - - async def _on_startup_runtime(self, app: web.Application) -> None: - """Start background tasks and bot infrastructure for the runtime.""" - from ..proactive_loop import proactive_delivery_loop - from ..scheduler import scheduler_loop - from ..services.otel import configure_otel - - # Bootstrap OTel if monitoring is configured - mon = self._monitoring_store - if mon.is_configured: - configure_otel( - mon.connection_string, - sampling_ratio=mon.config.sampling_ratio, - enable_live_metrics=mon.config.enable_live_metrics, - ) - - self._rebuild_adapter() - - app["scheduler_task"] = asyncio.create_task(scheduler_loop()) - app["proactive_task"] = asyncio.create_task( - proactive_delivery_loop(self._make_notify(), session_store=self._session_store), - ) - app["foundry_iq_task"] = asyncio.create_task( - _foundry_iq_index_loop(self._foundry_iq_store), - ) - - logger.info( - "[startup.runtime] mode=%s lockdown=%s bot_configured=%s " - "telegram_configured=%s tunnel=%s provisioner=%s az=%s", - self._mode.value, cfg.lockdown_mode, - self._infra_store.bot_configured if self._infra_store else "", - self._infra_store.telegram_configured if self._infra_store else "", - self._tunnel is not None, - self._provisioner is not None, - self._az is not None, - ) - - if cfg.lockdown_mode: - logger.info("Lock Down Mode active -- skipping infrastructure provisioning") - return - - bot_endpoint = os.environ.get("BOT_ENDPOINT", "") - - if self._mode != ServerMode.combined: - github_token = cfg.github_token - if not github_token: - logger.warning( - "[startup.runtime] Setup incomplete -- missing GITHUB_TOKEN. " - "Complete the setup wizard in the admin container, " - "then recreate the agent container.", - ) - return - - needs_bot = ( - self._infra_store.bot_configured - and self._infra_store.telegram_configured - ) - - if self._mode == ServerMode.combined: - if self._infra_store.bot_configured and self._provisioner: - from ..util.async_helpers import run_sync - - logger.info("Startup: provisioning infrastructure from config ...") - steps = await run_sync(self._provisioner.provision) - self._rebuild_adapter() - for s in steps: - logger.info( - " provision: %s = %s (%s)", - s.get("step"), s.get("status"), s.get("detail", ""), - ) - if needs_bot and self._tunnel: - await self._start_tunnel_and_create_bot() - - elif bot_endpoint: - cfg.reload() - self._rebuild_adapter() - if needs_bot: - logger.info("Static bot endpoint: %s", bot_endpoint) - await self._recreate_bot(endpoint_override=bot_endpoint) - else: - logger.info("No messaging channels configured -- skipping bot service") - - else: - if needs_bot and self._tunnel: - from ..services.deployer import BotDeployer - - bot_app_id = BotDeployer._env("BOT_APP_ID") - if not bot_app_id: - logger.warning( - "Telegram configured but BOT_APP_ID missing -- " - "run Infrastructure Deploy in the admin wizard first" - ) - else: - await self._start_tunnel_and_create_bot() - else: - reasons = [] - if not self._infra_store.bot_configured: - reasons.append("bot not configured") - if not self._infra_store.telegram_configured: - reasons.append("no channels configured") - if not self._tunnel: - reasons.append("no tunnel") - logger.info( - "Skipping bot service: %s", - ", ".join(reasons) or "no reason", - ) - - async def _on_startup_admin(self, app: web.Application) -> None: - """Admin startup: reconcile stale deployments and RBAC.""" - if self._az: - from ..services.resource_tracker import ResourceTracker - from ..util.async_helpers import run_sync - - app["reconcile_task"] = asyncio.create_task(self._reconcile_deployments()) - app["cs_rbac_task"] = asyncio.create_task( - self._ensure_content_safety_rbac(), + await lifecycle.on_startup_runtime( + app, + mode=self._mode, + adapter=self._adapter, + bot=self._bot, + bot_ep=self._bot_ep, + conv_store=self._conv_store, + agent=self._agent, + tunnel=self._tunnel, + infra_store=self._infra_store, + provisioner=self._provisioner, + az=self._az, + monitoring_store=self._monitoring_store, + session_store=self._session_store, + foundry_iq_store=self._foundry_iq_store, + scheduler=self._scheduler, + rebuild_adapter=self._rebuild_adapter, + make_notify=self._make_notify, ) - - async def _ensure_content_safety_rbac(self) -> None: - from .routes.content_safety_routes import ContentSafetyRoutes - - try: - routes = ContentSafetyRoutes( + if self._is_admin: + await lifecycle.on_startup_admin( + app, az=self._az, + deploy_store=self._deploy_store, guardrails_store=self._guardrails_store, ) - steps = await routes.ensure_rbac() - for s in steps: - logger.info( - "[startup.cs_rbac] %s = %s (%s)", - s.get("step"), s.get("status"), s.get("detail", ""), - ) - except Exception: - logger.warning( - "[startup.cs_rbac] Content Safety RBAC check failed", - exc_info=True, - ) - - async def _recreate_bot(self, *, endpoint_override: str | None = None) -> None: - from ..util.async_helpers import run_sync - logger.info( - "[recreate_bot] provisioner=%s az=%s bot_configured=%s endpoint_override=%s", - self._provisioner is not None, - self._az is not None, - self._infra_store.bot_configured if self._infra_store else "?", - endpoint_override, + async def _on_cleanup(self, app: web.Application) -> None: + await lifecycle.on_cleanup( + app, + mode=self._mode, + infra_store=self._infra_store, + provisioner=self._provisioner, + agent=self._agent, ) - if not (self._provisioner and self._az and self._infra_store.bot_configured): - logger.warning( - "[recreate_bot] precondition failed -- provisioner=%s az=%s bot_configured=%s", - self._provisioner is not None, - self._az is not None, - self._infra_store.bot_configured if self._infra_store else "?", - ) - return - - tunnel_url = endpoint_override or getattr(self._tunnel, "url", None) - if not tunnel_url: - logger.warning("Bot recreate: no endpoint URL available -- skipping") - return - - endpoint = tunnel_url - logger.info("Bot recreate: endpoint %s", endpoint) - try: - steps = await run_sync(self._provisioner.recreate_endpoint, endpoint) - self._rebuild_adapter() - for s in steps: - logger.info( - " recreate: %s = %s (%s)", - s.get("step"), s.get("status"), s.get("detail", ""), - ) - except Exception as exc: - logger.warning("Bot recreate: error -- %s", exc, exc_info=True) - - async def _start_tunnel_and_create_bot(self) -> None: - from ..util.async_helpers import run_sync - - logger.info("Starting tunnel for bot service endpoint ...") - tunnel_url = self._tunnel.url - if not tunnel_url and not self._tunnel.is_active: - max_retries = 5 - for attempt in range(1, max_retries + 1): - result = await run_sync(self._tunnel.start, cfg.admin_port) - if result: - logger.info("Tunnel started at %s", result.value) - break - if attempt < max_retries: - logger.warning( - "Tunnel failed (attempt %d/%d): %s -- retrying in %ds ...", - attempt, max_retries, - result.message if result else "unknown", - 2 * attempt, - ) - await asyncio.sleep(2 * attempt) - else: - logger.error( - "Tunnel failed after %d attempts: %s", - max_retries, - result.message if result else "unknown", - ) - return - - self._rebuild_adapter() - await self._recreate_bot() async def _handle_reload(self, request: web.Request) -> web.Response: logger.info("[reload] triggered by admin -- re-reading configuration") @@ -915,20 +386,20 @@ async def _handle_reload(self, request: web.Request) -> web.Response: # 1. Re-read .env from shared volume cfg.reload() - # 2. Reload infra config (bot & channel settings from infra.json) + # 2. Reload infra config if self._infra_store: self._infra_store._load() - # 3. Reload agent auth (GITHUB_TOKEN may have changed) + # 3. Reload agent auth auth_result: dict = {} if self._agent: auth_result = await self._agent.reload_auth() logger.info("[reload] agent auth: %s", auth_result.get("status")) - # 4. Rebuild Bot Framework adapter (BOT_APP_ID/PASSWORD may have changed) + # 4. Rebuild Bot Framework adapter self._rebuild_adapter() - # 5. Reinitialise voice handler (ACS settings may have changed) + # 5. Reinitialise voice handler reinit_voice = request.app.get("_reinit_voice") if reinit_voice: reinit_voice() @@ -940,37 +411,9 @@ async def _handle_reload(self, request: web.Request) -> web.Response: and self._infra_store.telegram_configured ) if needs_bot: - bot_endpoint = os.environ.get("BOT_ENDPOINT", "") - tunnel_active = getattr(self._tunnel, "is_active", False) if self._tunnel else False - - if bot_endpoint: - # Static endpoint (ACA or other) -- no tunnel needed. - async def _deferred_static_bot() -> None: - await self._recreate_bot(endpoint_override=bot_endpoint) - - request.app["reload_bot_task"] = asyncio.create_task( - _deferred_static_bot() - ) - bot_task_started = True - elif self._tunnel and not tunnel_active: - from ..services.deployer import BotDeployer - - bot_app_id = BotDeployer._env("BOT_APP_ID") - if bot_app_id: - async def _deferred_docker_bot() -> None: - await self._start_tunnel_and_create_bot() - - request.app["reload_bot_task"] = asyncio.create_task( - _deferred_docker_bot() - ) - bot_task_started = True - elif self._tunnel and tunnel_active: - async def _deferred_recreate() -> None: - await self._recreate_bot() - - request.app["reload_bot_task"] = asyncio.create_task( - _deferred_recreate() - ) + coro = self._pick_bot_reload_coro() + if coro is not None: + request.app["reload_bot_task"] = asyncio.create_task(coro) bot_task_started = True logger.info( @@ -985,119 +428,55 @@ async def _deferred_recreate() -> None: "bot_task_started": bot_task_started, }) - async def _reconcile_deployments(self) -> None: - from ..services.resource_tracker import ResourceTracker - from ..util.async_helpers import run_sync - - try: - tracker = ResourceTracker(self._az, self._deploy_store) - cleaned = await run_sync(tracker.reconcile) - if cleaned: - logger.info( - "Startup reconcile: removed %d stale deployment(s): %s", - len(cleaned), ", ".join(c["deploy_id"] for c in cleaned), - ) - except Exception as exc: - logger.warning("Startup reconcile failed (non-fatal): %s", exc) - - async def _on_cleanup(self, _app: web.Application) -> None: - for key in ("scheduler_task", "proactive_task", "foundry_iq_task", "reconcile_task"): - task = _app.get(key) - if task and not task.done(): - task.cancel() - - if self._mode == ServerMode.combined: - if cfg.lockdown_mode: - logger.info("Lock Down Mode active -- skipping shutdown decommission") - elif self._infra_store.bot_configured and (cfg.env.read("BOT_NAME") or cfg.env.read("BOT_APP_ID")) and self._provisioner: - from ..util.async_helpers import run_sync - - logger.info("Shutdown: decommissioning infrastructure ...") - steps = await run_sync(self._provisioner.decommission) - for s in steps: - logger.info( - " decommission: %s = %s (%s)", - s.get("step"), s.get("status"), s.get("detail", ""), - ) - - if self._agent: - await self._agent.stop() - - -async def _foundry_iq_index_loop(store: object) -> None: - from ..services.foundry_iq import index_memories - from ..state.foundry_iq_config import FoundryIQConfigStore - from ..util.async_helpers import run_sync - - assert isinstance(store, FoundryIQConfigStore) - await asyncio.sleep(60) - while True: - try: - store._load() - schedule = store.config.index_schedule - if store.enabled and store.is_configured and schedule in _SCHEDULE_INTERVALS: - logger.info("Foundry IQ: running scheduled indexing (%s)...", schedule) - result = await run_sync(index_memories, store) - logger.info("Foundry IQ indexing: %s (indexed=%s)", result.get("status"), result.get("indexed", 0)) - interval = _SCHEDULE_INTERVALS.get(schedule, 86400) - except asyncio.CancelledError: - return - except Exception as exc: - logger.error("Foundry IQ index loop error: %s", exc, exc_info=True) - interval = 3600 - try: - await asyncio.sleep(interval) - except asyncio.CancelledError: - return + def _pick_bot_reload_coro(self) -> Any: + """Return the appropriate bot-reload coroutine, or ``None``.""" + bot_endpoint = os.environ.get("BOT_ENDPOINT", "") + tunnel_active = ( + getattr(self._tunnel, "is_active", False) + if self._tunnel + else False + ) + if bot_endpoint: + return lifecycle.recreate_bot( + provisioner=self._provisioner, az=self._az, + infra_store=self._infra_store, tunnel=self._tunnel, + rebuild_adapter=self._rebuild_adapter, + endpoint_override=bot_endpoint, + ) -async def _serve_media(req: web.Request) -> web.Response: - filename = req.match_info["filename"] - if ".." in filename or filename.startswith("/"): - return web.Response(status=403, text="Forbidden") - file_path = cfg.media_outgoing_sent_dir / filename - if not file_path.is_file(): - return web.Response(status=404, text="Not found") - content_type = ( - EXTENSION_TO_MIME.get(file_path.suffix.lower()) - or mimetypes.guess_type(file_path.name)[0] - or "application/octet-stream" - ) - return web.FileResponse(file_path, headers={"Content-Type": content_type}) + if self._tunnel and not tunnel_active: + from ..services.deployment.deployer import BotDeployer as _BD + if _BD._env("BOT_APP_ID"): + return lifecycle.start_tunnel_and_create_bot( + tunnel=self._tunnel, + provisioner=self._provisioner, + az=self._az, + infra_store=self._infra_store, + rebuild_adapter=self._rebuild_adapter, + ) -def _make_file_handler(fpath: Path): - async def handler(_req: web.Request) -> web.Response: - ct = mimetypes.guess_type(fpath.name)[0] or "application/octet-stream" - return web.FileResponse(fpath, headers={"Content-Type": ct}) - return handler + if self._tunnel and tunnel_active: + return lifecycle.recreate_bot( + provisioner=self._provisioner, az=self._az, + infra_store=self._infra_store, tunnel=self._tunnel, + rebuild_adapter=self._rebuild_adapter, + ) + return None -async def _serve_index(req: web.Request) -> web.Response: - index = _FRONTEND_DIR / "index.html" - if not index.exists(): - return web.Response(status=404, text="Not found") - html = index.read_text() - return web.Response( - text=html, - content_type="text/html", - headers={"Cache-Control": "no-cache, no-store, must-revalidate"}, - ) -async def _serve_spa_or_404(req: web.Request) -> web.Response: - if req.path.startswith("/api/"): - raise web.HTTPNotFound( - text='{"status":"error","message":"Unknown endpoint: ' - f'{req.method} {req.path}"' + '}', - content_type="application/json", - ) - return await _serve_index(req) +# -- CLI entry point ------------------------------------------------------- def main() -> None: + """Launch the server from the command line.""" import argparse + from .middleware import QuietAccessLogger + parser = argparse.ArgumentParser(description="Polyclaw server") parser.add_argument( "--admin-only", @@ -1116,10 +495,8 @@ def main() -> None: # Set the mode via env var so Settings.reload() picks it up. if args.admin_only: - import os os.environ["POLYCLAW_SERVER_MODE"] = "admin" elif args.runtime_only: - import os os.environ["POLYCLAW_SERVER_MODE"] = "runtime" logging.basicConfig( @@ -1138,7 +515,11 @@ def main() -> None: cfg.write_env(ADMIN_SECRET=secrets.token_urlsafe(24)) logger.info("Generated ADMIN_SECRET (persisted to .env)") - display_secret = cfg.admin_secret if cfg.admin_secret and not cfg.admin_secret.startswith("@kv:") else "" + display_secret = ( + cfg.admin_secret + if cfg.admin_secret and not cfg.admin_secret.startswith("@kv:") + else "" + ) if mode == ServerMode.runtime: admin_url = f"http://localhost:{port}" logger.info("Runtime endpoint: %s", admin_url) diff --git a/app/runtime/server/app_routes.py b/app/runtime/server/app_routes.py new file mode 100644 index 0000000..a5c30ee --- /dev/null +++ b/app/runtime/server/app_routes.py @@ -0,0 +1,230 @@ +"""Route registration helpers for AppFactory.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from aiohttp import web + +from ..config.settings import cfg +from .app_static import voice_handler +from .wiring import create_voice_handler + +if TYPE_CHECKING: + from ..services.cloud.azure import AzureCLI + from ..services.cloud.github import GitHubAuth + from ..services.deployment.aca_deployer import AcaDeployer + from ..services.deployment.deployer import BotDeployer + from ..services.deployment.provisioner import Provisioner + from ..services.tunnel import CloudflareTunnel + from ..state.deploy_state import DeployStateStore + from ..state.foundry_iq_config import FoundryIQConfigStore + from ..state.guardrails import GuardrailsConfigStore + from ..state.infra_config import InfraConfigStore + from ..state.mcp_config import McpConfigStore + from ..state.monitoring_config import MonitoringConfigStore + from ..state.proactive import ProactiveStore + from ..state.sandbox_config import SandboxConfigStore + from ..sandbox import SandboxExecutor + from ..messaging.proactive import ConversationReferenceStore + from ..state.session_store import SessionStore + from ..scheduler import Scheduler + from .bot_endpoint import BotEndpoint + + +def register_admin_routes( + router: web.UrlDispatcher, + *, + az: AzureCLI | None, + gh: GitHubAuth | None, + tunnel: CloudflareTunnel | None, + deployer: BotDeployer | None, + rebuild_adapter: Any, + infra_store: InfraConfigStore | None, + provisioner: Provisioner | None, + deploy_store: DeployStateStore | None, + aca_deployer: AcaDeployer | None, + sandbox_store: SandboxConfigStore | None, + sandbox_executor: SandboxExecutor | None, + foundry_iq_store: FoundryIQConfigStore | None, + monitoring_store: MonitoringConfigStore | None, + guardrails_store: GuardrailsConfigStore | None, +) -> None: + """Register routes available only in ``admin`` or ``combined`` mode.""" + from .setup import SetupRoutes, VoiceSetupRoutes + from .workspace import WorkspaceHandler + from .routes.content_safety_routes import ContentSafetyRoutes + from .routes.env_routes import EnvironmentRoutes + from .routes.foundry_iq_routes import FoundryIQRoutes + from .routes.network_routes import NetworkRoutes + from .routes.monitoring_routes import MonitoringRoutes + from .routes.sandbox_routes import SandboxRoutes + + SetupRoutes( + az, gh, tunnel, deployer, + rebuild_adapter, infra_store, + provisioner, deploy_store, + aca_deployer, + ).register(router) + + VoiceSetupRoutes(az, infra_store).register(router) + WorkspaceHandler().register(router) + EnvironmentRoutes(deploy_store, az).register(router) + SandboxRoutes(sandbox_store, sandbox_executor, az, deploy_store).register(router) + FoundryIQRoutes(foundry_iq_store, az, deploy_store).register(router) + NetworkRoutes(tunnel, az, sandbox_store, foundry_iq_store).register(router) + MonitoringRoutes(monitoring_store, az, deploy_store).register(router) + ContentSafetyRoutes(az, guardrails_store).register(router) + + from .routes.identity_routes import IdentityRoutes + + IdentityRoutes(az, guardrails_store).register(router) + + if az: + from .routes.security_preflight_routes import SecurityPreflightRoutes + from ..services.security.security_preflight import SecurityPreflightChecker + + SecurityPreflightRoutes(SecurityPreflightChecker(az)).register(router) + + +def register_runtime_routes( + app: web.Application, + *, + agent: Any, + session_store: SessionStore | None, + sandbox_executor: SandboxExecutor | None, + mcp_store: McpConfigStore | None, + guardrails_store: GuardrailsConfigStore | None, + scheduler: Scheduler | None, + proactive_store: ProactiveStore | None, + adapter: Any, + conv_store: ConversationReferenceStore | None, + bot_ep: BotEndpoint | None, + tunnel: CloudflareTunnel | None, + voice_routes: Any, + handle_reload: Any, +) -> None: + """Register routes available only in ``runtime`` or ``combined`` mode.""" + from ..registries.plugins import get_plugin_registry + from ..registries.skills import get_registry as get_skill_registry + from ..state.plugin_config import PluginConfigStore + from .chat import ChatHandler + from .routes.guardrails_routes import GuardrailsRoutes + from .routes.mcp_routes import McpRoutes + from .routes.plugin_routes import PluginRoutes + from .routes.proactive_routes import ProactiveRoutes + from .routes.profile_routes import ProfileRoutes + from .routes.scheduler_routes import SchedulerRoutes + from .routes.session_routes import SessionRoutes + from .routes.skill_routes import SkillRoutes + from .routes.tool_activity_routes import ToolActivityRoutes + + router = app.router + + router.add_post("/api/internal/reload", handle_reload) + + from .routes.network_routes import NetworkRoutes as _NR + + _nr_instance = _NR(tunnel) + router.add_get("/api/network/endpoints", _nr_instance._endpoints) + + hitl = agent.hitl_interceptor if agent else None + if hitl: + wire_hitl_services(app, hitl, guardrails_store) + + ChatHandler( + agent, + session_store=session_store, + sandbox_interceptor=sandbox_executor, + hitl_interceptor=hitl, + ).register(router) + + bot_ep.register(router) + register_voice_dynamic(app, voice_routes=voice_routes, agent=agent, tunnel=tunnel) + + SchedulerRoutes(scheduler).register(router) + SessionRoutes(session_store).register(router) + SkillRoutes(get_skill_registry()).register(router) + McpRoutes(mcp_store).register(router) + PluginRoutes(get_plugin_registry(), PluginConfigStore()).register(router) + ProfileRoutes().register(router) + GuardrailsRoutes( + guardrails_store, mcp_store, + skills_registry=get_skill_registry(), + ).register(router) + + from ..state.tool_activity_store import get_tool_activity_store + + ToolActivityRoutes(get_tool_activity_store(), session_store).register(router) + + ProactiveRoutes( + proactive_store, + adapter=adapter, + conv_store=conv_store, + app_id=cfg.bot_app_id, + ).register(router) + + +def wire_hitl_services( + app: web.Application, hitl: Any, guardrails_store: Any, +) -> None: + """Wire phone verifier, AITL reviewer, and prompt shield into HITL.""" + from ..agent.aitl import AitlReviewer + from ..agent.phone_verify import PhoneVerifier + from ..services.security.prompt_shield import PromptShieldService + + phone_verifier = PhoneVerifier(app) + hitl.set_phone_verifier(phone_verifier) + app["_phone_verifier"] = phone_verifier + + gcfg = guardrails_store.config + hitl.set_aitl_reviewer( + AitlReviewer(model=gcfg.aitl_model, spotlighting=gcfg.aitl_spotlighting), + ) + hitl.set_prompt_shield( + PromptShieldService( + endpoint=gcfg.content_safety_endpoint, mode=gcfg.filter_mode, + ), + ) + + +def register_voice_dynamic( + app: web.Application, + *, + voice_routes: Any, + agent: Any, + tunnel: Any, +) -> None: + """Register dynamic voice routes that delegate to the current handler.""" + app["_voice_handler"] = voice_routes + + def reinit_voice() -> None: + handler = create_voice_handler(agent, tunnel) + app["_voice_handler"] = handler + app["voice_configured"] = handler is not None + + app["_reinit_voice"] = reinit_voice + + router = app.router + router.add_post("/api/voice/call", voice_handler("_api_call")) + router.add_get("/api/voice/status", voice_handler("_api_status")) + # Legacy routes (kept for backwards compat) + router.add_post("/acs", voice_handler("_acs_callback", log_label="ACS callback")) + router.add_post("/acs/incoming", voice_handler("_acs_incoming", log_label="ACS incoming")) + router.add_get( + "/realtime-acs", + voice_handler("_ws_handler_acs", log_label="ACS media-streaming WS"), + ) + # Routes matching cfg.acs_callback_path / cfg.acs_media_streaming_websocket_path + router.add_post( + "/api/voice/acs-callback", + voice_handler("_acs_callback", log_label="ACS callback"), + ) + router.add_post( + "/api/voice/acs-callback/incoming", + voice_handler("_acs_incoming", log_label="ACS incoming"), + ) + router.add_get( + "/api/voice/media-streaming", + voice_handler("_ws_handler_acs", log_label="ACS media-streaming WS"), + ) diff --git a/app/runtime/server/app_static.py b/app/runtime/server/app_static.py new file mode 100644 index 0000000..70c2449 --- /dev/null +++ b/app/runtime/server/app_static.py @@ -0,0 +1,100 @@ +"""Static / SPA file handlers and voice route delegation.""" + +from __future__ import annotations + +import logging +import mimetypes +from collections.abc import Awaitable, Callable +from pathlib import Path + +from aiohttp import web + +from ..config.settings import cfg +from ..media import EXTENSION_TO_MIME + +logger = logging.getLogger(__name__) + +FRONTEND_DIR = Path(__file__).resolve().parent.parent.parent / "frontend" / "dist" + +# -- Voice dynamic route handler factory ----------------------------------- + +_VOICE_NOT_CONFIGURED = { + "status": "error", + "message": ( + "Voice calling is not configured. Deploy ACS + " + "Azure OpenAI resources in the Voice Call section first." + ), +} + + +def voice_handler( + method_name: str, *, log_label: str = "", +) -> Callable[[web.Request], Awaitable[web.Response]]: + """Create a dynamic voice route handler that delegates to the app's handler.""" + + async def handler(req: web.Request) -> web.Response: + h = req.app.get("_voice_handler") + if log_label: + logger.info( + "%s hit: method=%s path=%s handler=%s", + log_label, req.method, req.path, + "configured" if h else "NONE", + ) + if h is None: + return web.json_response(_VOICE_NOT_CONFIGURED, status=400) + return await getattr(h, method_name)(req) + + return handler + + +# -- Static file handlers ------------------------------------------------- + + +async def serve_media(req: web.Request) -> web.Response: + """Serve an uploaded media file from the outgoing directory.""" + filename = req.match_info["filename"] + if ".." in filename or filename.startswith("/"): + return web.Response(status=403, text="Forbidden") + file_path = cfg.media_outgoing_sent_dir / filename + if not file_path.is_file(): + return web.Response(status=404, text="Not found") + content_type = ( + EXTENSION_TO_MIME.get(file_path.suffix.lower()) + or mimetypes.guess_type(file_path.name)[0] + or "application/octet-stream" + ) + return web.FileResponse(file_path, headers={"Content-Type": content_type}) + + +def make_file_handler(fpath: Path) -> Callable: + """Return a handler that serves a single static file.""" + + async def handler(_req: web.Request) -> web.Response: + ct = mimetypes.guess_type(fpath.name)[0] or "application/octet-stream" + return web.FileResponse(fpath, headers={"Content-Type": ct}) + + return handler + + +async def serve_index(req: web.Request) -> web.Response: + """Serve the frontend index.html with no-cache headers.""" + index = FRONTEND_DIR / "index.html" + if not index.exists(): + return web.Response(status=404, text="Not found") + html = index.read_text() + return web.Response( + text=html, + content_type="text/html", + headers={"Cache-Control": "no-cache, no-store, must-revalidate"}, + ) + + +async def serve_spa_or_404(req: web.Request) -> web.Response: + """Serve the SPA for non-API paths, 404 for unknown /api/ paths.""" + if req.path.startswith("/api/"): + raise web.HTTPNotFound( + text='{"status":"error","message":"Unknown endpoint: ' + f'{req.method} {req.path}"' + '}', + content_type="application/json", + ) + return await serve_index(req) diff --git a/app/runtime/server/chat.py b/app/runtime/server/chat.py index f5f9782..db482b5 100644 --- a/app/runtime/server/chat.py +++ b/app/runtime/server/chat.py @@ -109,22 +109,45 @@ async def get_suggestions(self, _req: web.Request) -> web.Response: async def _dispatch(self, ws: web.WebSocketResponse, data: dict) -> None: action = data.get("action", "") logger.info("[chat.dispatch] action=%s keys=%s", action, list(data.keys())) - if action == "new_session": - await self._agent.new_session() - session_id = str(uuid.uuid4()) - logger.info("[chat.dispatch] new session created: %s", session_id) - self._sessions.start_session(session_id, model=cfg.copilot_model) - await ws.send_json({"type": "session_created", "session_id": session_id}) - elif action == "resume_session": - await self._resume_session(ws, data.get("session_id", "")) - elif action == "send": - await self._send_prompt(ws, data) - elif action == "approve_tool": - await self._handle_tool_approval(ws, data) + + handler = self._ACTION_DISPATCH.get(action) + if handler is not None: + await handler(self, ws, data) else: logger.warning("[chat.dispatch] unknown action: %s", action) await ws.send_json({"type": "error", "content": f"Unknown action: {action}"}) + async def _handle_new_session( + self, ws: web.WebSocketResponse, _data: dict + ) -> None: + await self._agent.new_session() + session_id = str(uuid.uuid4()) + logger.info("[chat.dispatch] new session created: %s", session_id) + self._sessions.start_session(session_id, model=cfg.copilot_model) + await ws.send_json({"type": "session_created", "session_id": session_id}) + + async def _dispatch_resume( + self, ws: web.WebSocketResponse, data: dict + ) -> None: + await self._resume_session(ws, data.get("session_id", "")) + + async def _dispatch_send( + self, ws: web.WebSocketResponse, data: dict + ) -> None: + await self._send_prompt(ws, data) + + async def _dispatch_approve( + self, ws: web.WebSocketResponse, data: dict + ) -> None: + await self._handle_tool_approval(ws, data) + + _ACTION_DISPATCH: dict[str, Any] = { + "new_session": _handle_new_session, + "resume_session": _dispatch_resume, + "send": _dispatch_send, + "approve_tool": _dispatch_approve, + } + async def _send_prompt(self, ws: web.WebSocketResponse, data: dict) -> None: text = (data.get("text") or data.get("message") or "").strip() if not text: @@ -132,16 +155,12 @@ async def _send_prompt(self, ws: web.WebSocketResponse, data: dict) -> None: return session_id = data.get("session_id", "") - logger.info("[chat.send_prompt] text=%r session=%s", text[:80], session_id or "(none)") + logger.info( + "[chat.send_prompt] text=%r session=%s", + text[:80], session_id or "(none)", + ) - # Ensure session store tracks a session -- auto-create one if none is - # active so that messages are always persisted to disk. - if session_id and self._sessions.current_session_id != session_id: - self._sessions.start_session(session_id) - elif not self._sessions.current_session_id: - auto_id = str(uuid.uuid4()) - logger.info("[chat.send_prompt] no active session, auto-creating %s", auto_id) - self._sessions.start_session(auto_id, model=cfg.copilot_model) + self._ensure_active_session(session_id) # Slash command dispatch if text.startswith("/"): @@ -155,6 +174,62 @@ async def _send_prompt(self, ws: web.WebSocketResponse, data: dict) -> None: memory.record("user", text) chunks: list[str] = [] + on_delta, on_event = self._make_event_callbacks(ws, chunks) + self._bind_hitl(ws) + + logger.info("[chat.send_prompt] calling agent.send() ...") + with agent_span( + "chat.agent_turn", + attributes={ + "chat.prompt_length": len(text), + "chat.session_id": session_id or "", + }, + ): + try: + response = await self._agent.send( + text, + on_delta=lambda d: asyncio.ensure_future(on_delta(d)), + on_event=lambda t, d: asyncio.ensure_future( + on_event({"type": t, **d}), + ), + ) + except Exception: + logger.exception("[chat.send_prompt] agent.send() raised") + record_event("agent_error") + await ws.send_json({ + "type": "error", + "content": "Agent error -- check server logs", + }) + return + finally: + self._unbind_hitl() + full_text = "".join(chunks) or response or "" + set_span_attribute("chat.response_length", len(full_text)) + set_span_attribute("chat.chunk_count", len(chunks)) + + await self._finalize_response(ws, full_text, chunks, memory) + + # -- _send_prompt helpers ---------------------------------------------- + + def _ensure_active_session(self, session_id: str) -> None: + """Ensure the session store is tracking an active session.""" + if session_id and self._sessions.current_session_id != session_id: + self._sessions.start_session(session_id) + elif not self._sessions.current_session_id: + auto_id = str(uuid.uuid4()) + logger.info( + "[chat.send_prompt] no active session, auto-creating %s", + auto_id, + ) + self._sessions.start_session(auto_id, model=cfg.copilot_model) + + def _make_event_callbacks( + self, ws: web.WebSocketResponse, chunks: list[str], + ) -> tuple[ + Any, # on_delta coroutine + Any, # on_event coroutine + ]: + """Build the delta and event callback coroutines for agent.send.""" async def on_delta(delta: str) -> None: chunks.append(delta) @@ -163,7 +238,9 @@ async def on_delta(delta: str) -> None: async def on_event(event: dict[str, Any]) -> None: event_type = event.pop("type", "") if event_type == "sandbox_exec" and self._sandbox: - result = await self._sandbox.intercept({"type": event_type, **event}) + result = await self._sandbox.intercept( + {"type": event_type, **event}, + ) if result: await ws.send_json({"type": "sandbox_result", **result}) # Record tool activity for audit @@ -173,7 +250,9 @@ async def on_event(event: dict[str, Any]) -> None: tool_name = event.get("tool", "unknown") interaction_type = "" if self._hitl: - interaction_type = self._hitl.pop_resolved_strategy(tool_name) + interaction_type = ( + self._hitl.pop_resolved_strategy(tool_name) + ) self._tool_activity.record_start( session_id=self._sessions.current_session_id, tool=tool_name, @@ -188,60 +267,68 @@ async def on_event(event: dict[str, Any]) -> None: call_id=event.get("call_id", ""), result=event.get("result", ""), ) - await ws.send_json({"type": "event", "event": event_type, **event}) + await ws.send_json({ + "type": "event", "event": event_type, **event, + }) - # Bind the HITL emitter so approval requests reach the WebSocket - if self._hitl: - def hitl_emit(etype: str, payload: dict[str, Any]) -> None: - logger.info( - "[chat.hitl_emit] sending event=%s payload_keys=%s", - etype, list(payload.keys()), - ) - asyncio.ensure_future(ws.send_json({"type": "event", "event": etype, **payload})) - self._hitl.set_emit(hitl_emit) - self._hitl.set_execution_context("interactive") - self._hitl.set_model(cfg.copilot_model) - self._hitl.set_tool_activity(self._tool_activity) - self._hitl.set_session_id(self._sessions.current_session_id) + return on_delta, on_event + + def _bind_hitl(self, ws: web.WebSocketResponse) -> None: + """Bind the HITL emitter so approval requests reach the WebSocket.""" + if not self._hitl: + logger.info("[chat.send_prompt] no HITL interceptor available") + return + + def hitl_emit(etype: str, payload: dict[str, Any]) -> None: logger.info( - "[chat.send_prompt] HITL emitter bound: model=%s", - cfg.copilot_model, + "[chat.hitl_emit] sending event=%s payload_keys=%s", + etype, list(payload.keys()), + ) + asyncio.ensure_future( + ws.send_json({"type": "event", "event": etype, **payload}), ) - else: - logger.info("[chat.send_prompt] no HITL interceptor available") - logger.info("[chat.send_prompt] calling agent.send() ...") - with agent_span( - "chat.agent_turn", - attributes={"chat.prompt_length": len(text), "chat.session_id": session_id or ""}, - ): - try: - response = await self._agent.send( - text, - on_delta=lambda d: asyncio.ensure_future(on_delta(d)), - on_event=lambda t, d: asyncio.ensure_future(on_event({"type": t, **d})), - ) - except Exception: - logger.exception("[chat.send_prompt] agent.send() raised") - record_event("agent_error") - await ws.send_json({"type": "error", "content": "Agent error -- check server logs"}) - return - finally: - if self._hitl: - self._hitl.clear_emit() - full_text = "".join(chunks) or response or "" - set_span_attribute("chat.response_length", len(full_text)) - set_span_attribute("chat.chunk_count", len(chunks)) - logger.info("[chat.send_prompt] response complete, len=%d, chunks=%d", len(full_text), len(chunks)) + self._hitl.bind_turn( + emit=hitl_emit, + execution_context="interactive", + model=cfg.copilot_model, + tool_activity=self._tool_activity, + session_id=self._sessions.current_session_id, + ) + logger.info( + "[chat.send_prompt] HITL emitter bound: model=%s", + cfg.copilot_model, + ) + + def _unbind_hitl(self) -> None: + """Clear the HITL emitter after a turn completes.""" + if self._hitl: + self._hitl.unbind_turn() + + async def _finalize_response( + self, + ws: web.WebSocketResponse, + full_text: str, + chunks: list[str], + memory: Any, + ) -> None: + """Log, persist, and send the final response artifacts.""" + logger.info( + "[chat.send_prompt] response complete, len=%d, chunks=%d", + len(full_text), len(chunks), + ) if not full_text: - logger.warning("[chat.send_prompt] empty response -- model may have timed out") + logger.warning( + "[chat.send_prompt] empty response -- " + "model may have timed out", + ) await ws.send_json({ "type": "error", "content": ( "The model did not respond. " - "This can happen when the model is overloaded or the session " - "is stale. Please try again." + "This can happen when the model is overloaded or " + "the session is stale. Please try again." ), }) @@ -254,7 +341,10 @@ def hitl_emit(etype: str, payload: dict[str, Any]) -> None: if outgoing: await ws.send_json({"type": "media", "files": outgoing}) if cards: - await ws.send_json({"type": "cards", "cards": [attachment_to_dict(c) for c in cards]}) + await ws.send_json({ + "type": "cards", + "cards": [attachment_to_dict(c) for c in cards], + }) await ws.send_json({"type": "done"}) async def _try_command( diff --git a/app/runtime/server/lifecycle.py b/app/runtime/server/lifecycle.py new file mode 100644 index 0000000..cff0e94 --- /dev/null +++ b/app/runtime/server/lifecycle.py @@ -0,0 +1,371 @@ +"""Application lifecycle -- startup and cleanup hooks.""" + +from __future__ import annotations + +import asyncio +import logging +import os +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +from aiohttp import web + +from ..config.settings import ServerMode, cfg +from .wiring import create_adapter, create_voice_handler + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + +_SCHEDULE_INTERVALS = {"hourly": 3600, "daily": 86400} + + +async def on_startup_runtime( + app: web.Application, + *, + mode: ServerMode, + adapter: object, + bot: object | None, + bot_ep: object | None, + conv_store: object | None, + agent: object | None, + tunnel: object | None, + infra_store: object, + provisioner: object | None, + az: object | None, + monitoring_store: object, + session_store: object | None, + foundry_iq_store: object, + scheduler: object | None, + rebuild_adapter: Callable, + make_notify: Callable[[], Callable[[str], Awaitable[bool]]], +) -> None: + """Start background tasks and bot infrastructure for the runtime.""" + from ..messaging.proactive_loop import proactive_delivery_loop + from ..scheduler import scheduler_loop + from ..services.otel import configure_otel + + # Bootstrap OTel if monitoring is configured + mon = monitoring_store + if mon.is_configured: + configure_otel( + mon.connection_string, + sampling_ratio=mon.config.sampling_ratio, + enable_live_metrics=mon.config.enable_live_metrics, + ) + + rebuild_adapter() + + app["scheduler_task"] = asyncio.create_task(scheduler_loop()) + app["proactive_task"] = asyncio.create_task( + proactive_delivery_loop(make_notify(), session_store=session_store), + ) + app["foundry_iq_task"] = asyncio.create_task( + _foundry_iq_index_loop(foundry_iq_store), + ) + + logger.info( + "[startup.runtime] mode=%s lockdown=%s bot_configured=%s " + "telegram_configured=%s tunnel=%s provisioner=%s az=%s", + mode.value, cfg.lockdown_mode, + infra_store.bot_configured if infra_store else "", + infra_store.telegram_configured if infra_store else "", + tunnel is not None, + provisioner is not None, + az is not None, + ) + + if cfg.lockdown_mode: + logger.info("Lock Down Mode active -- skipping infrastructure provisioning") + return + + bot_endpoint = os.environ.get("BOT_ENDPOINT", "") + + if mode != ServerMode.combined: + github_token = cfg.github_token + if not github_token: + logger.warning( + "[startup.runtime] Setup incomplete -- missing GITHUB_TOKEN. " + "Complete the setup wizard in the admin container, " + "then recreate the agent container.", + ) + return + + needs_bot = ( + infra_store.bot_configured + and infra_store.telegram_configured + ) + + if mode == ServerMode.combined: + if infra_store.bot_configured and provisioner: + from ..util.async_helpers import run_sync + + logger.info("Startup: provisioning infrastructure from config ...") + steps = await run_sync(provisioner.provision) + rebuild_adapter() + for s in steps: + logger.info( + " provision: %s = %s (%s)", + s.get("step"), s.get("status"), s.get("detail", ""), + ) + if needs_bot and tunnel: + await start_tunnel_and_create_bot( + tunnel=tunnel, provisioner=provisioner, az=az, + infra_store=infra_store, rebuild_adapter=rebuild_adapter, + ) + + elif bot_endpoint: + cfg.reload() + rebuild_adapter() + if needs_bot: + logger.info("Static bot endpoint: %s", bot_endpoint) + await recreate_bot( + provisioner=provisioner, az=az, infra_store=infra_store, + tunnel=tunnel, rebuild_adapter=rebuild_adapter, + endpoint_override=bot_endpoint, + ) + else: + logger.info("No messaging channels configured -- skipping bot service") + + else: + if needs_bot and tunnel: + from ..services.deployment.deployer import BotDeployer + + bot_app_id = BotDeployer._env("BOT_APP_ID") + if not bot_app_id: + logger.warning( + "Telegram configured but BOT_APP_ID missing -- " + "run Infrastructure Deploy in the admin wizard first" + ) + else: + await start_tunnel_and_create_bot( + tunnel=tunnel, provisioner=provisioner, az=az, + infra_store=infra_store, rebuild_adapter=rebuild_adapter, + ) + else: + reasons = [] + if not infra_store.bot_configured: + reasons.append("bot not configured") + if not infra_store.telegram_configured: + reasons.append("no channels configured") + if not tunnel: + reasons.append("no tunnel") + logger.info( + "Skipping bot service: %s", + ", ".join(reasons) or "no reason", + ) + + +async def on_startup_admin( + app: web.Application, + *, + az: object | None, + deploy_store: object, + guardrails_store: object, +) -> None: + """Admin startup: reconcile stale deployments and RBAC.""" + if az: + app["reconcile_task"] = asyncio.create_task( + _reconcile_deployments(az, deploy_store), + ) + app["cs_rbac_task"] = asyncio.create_task( + _ensure_content_safety_rbac(az, guardrails_store), + ) + + +async def on_cleanup( + app: web.Application, + *, + mode: ServerMode, + infra_store: object, + provisioner: object | None, + agent: object | None, +) -> None: + """Cancel background tasks and decommission infrastructure on shutdown.""" + for key in ("scheduler_task", "proactive_task", "foundry_iq_task", "reconcile_task"): + task = app.get(key) + if task and not task.done(): + task.cancel() + + if mode == ServerMode.combined: + if cfg.lockdown_mode: + logger.info("Lock Down Mode active -- skipping shutdown decommission") + elif ( + infra_store.bot_configured + and (cfg.env.read("BOT_NAME") or cfg.env.read("BOT_APP_ID")) + and provisioner + ): + from ..util.async_helpers import run_sync + + logger.info("Shutdown: decommissioning infrastructure ...") + steps = await run_sync(provisioner.decommission) + for s in steps: + logger.info( + " decommission: %s = %s (%s)", + s.get("step"), s.get("status"), s.get("detail", ""), + ) + + if agent: + await agent.stop() + + +# -- Bot infrastructure helpers ------------------------------------------- + +async def recreate_bot( + *, + provisioner: object | None, + az: object | None, + infra_store: object, + tunnel: object | None, + rebuild_adapter: Callable, + endpoint_override: str | None = None, +) -> None: + """Recreate the bot service endpoint.""" + from ..util.async_helpers import run_sync + + logger.info( + "[recreate_bot] provisioner=%s az=%s bot_configured=%s endpoint_override=%s", + provisioner is not None, + az is not None, + infra_store.bot_configured if infra_store else "?", + endpoint_override, + ) + if not (provisioner and az and infra_store.bot_configured): + logger.warning( + "[recreate_bot] precondition failed -- provisioner=%s az=%s bot_configured=%s", + provisioner is not None, + az is not None, + infra_store.bot_configured if infra_store else "?", + ) + return + + tunnel_url = endpoint_override or getattr(tunnel, "url", None) + if not tunnel_url: + logger.warning("Bot recreate: no endpoint URL available -- skipping") + return + + endpoint = tunnel_url + logger.info("Bot recreate: endpoint %s", endpoint) + try: + steps = await run_sync(provisioner.recreate_endpoint, endpoint) + rebuild_adapter() + for s in steps: + logger.info( + " recreate: %s = %s (%s)", + s.get("step"), s.get("status"), s.get("detail", ""), + ) + except Exception as exc: + logger.warning("Bot recreate: error -- %s", exc, exc_info=True) + + +async def start_tunnel_and_create_bot( + *, + tunnel: object, + provisioner: object | None, + az: object | None, + infra_store: object, + rebuild_adapter: Callable, +) -> None: + """Start the Cloudflare tunnel and recreate the bot service.""" + from ..util.async_helpers import run_sync + + logger.info("Starting tunnel for bot service endpoint ...") + tunnel_url = tunnel.url + if not tunnel_url and not tunnel.is_active: + max_retries = 5 + for attempt in range(1, max_retries + 1): + result = await run_sync(tunnel.start, cfg.admin_port) + if result: + logger.info("Tunnel started at %s", result.value) + break + if attempt < max_retries: + logger.warning( + "Tunnel failed (attempt %d/%d): %s -- retrying in %ds ...", + attempt, max_retries, + result.message if result else "unknown", + 2 * attempt, + ) + await asyncio.sleep(2 * attempt) + else: + logger.error( + "Tunnel failed after %d attempts: %s", + max_retries, + result.message if result else "unknown", + ) + return + + rebuild_adapter() + await recreate_bot( + provisioner=provisioner, az=az, infra_store=infra_store, + tunnel=tunnel, rebuild_adapter=rebuild_adapter, + ) + + +# -- Background loops ----------------------------------------------------- + +async def _foundry_iq_index_loop(store: object) -> None: + from ..services.foundry_iq import index_memories + from ..state.foundry_iq_config import FoundryIQConfigStore + from ..util.async_helpers import run_sync + + assert isinstance(store, FoundryIQConfigStore) + await asyncio.sleep(60) + while True: + try: + store._load() + schedule = store.config.index_schedule + if store.enabled and store.is_configured and schedule in _SCHEDULE_INTERVALS: + logger.info("Foundry IQ: running scheduled indexing (%s)...", schedule) + result = await run_sync(index_memories, store) + logger.info( + "Foundry IQ indexing: %s (indexed=%s)", + result.get("status"), result.get("indexed", 0), + ) + interval = _SCHEDULE_INTERVALS.get(schedule, 86400) + except asyncio.CancelledError: + return + except Exception as exc: + logger.error("Foundry IQ index loop error: %s", exc, exc_info=True) + interval = 3600 + try: + await asyncio.sleep(interval) + except asyncio.CancelledError: + return + + +async def _reconcile_deployments(az: object, deploy_store: object) -> None: + from ..services.resource_tracker import ResourceTracker + from ..util.async_helpers import run_sync + + try: + tracker = ResourceTracker(az, deploy_store) + cleaned = await run_sync(tracker.reconcile) + if cleaned: + logger.info( + "Startup reconcile: removed %d stale deployment(s): %s", + len(cleaned), ", ".join(c["deploy_id"] for c in cleaned), + ) + except Exception as exc: + logger.warning("Startup reconcile failed (non-fatal): %s", exc) + + +async def _ensure_content_safety_rbac(az: object, guardrails_store: object) -> None: + from .routes.content_safety_routes import ContentSafetyRoutes + + try: + routes = ContentSafetyRoutes( + az=az, + guardrails_store=guardrails_store, + ) + steps = await routes.ensure_rbac() + for s in steps: + logger.info( + "[startup.cs_rbac] %s = %s (%s)", + s.get("step"), s.get("status"), s.get("detail", ""), + ) + except Exception: + logger.warning( + "[startup.cs_rbac] Content Safety RBAC check failed", + exc_info=True, + ) diff --git a/app/runtime/server/middleware.py b/app/runtime/server/middleware.py new file mode 100644 index 0000000..a3e32e7 --- /dev/null +++ b/app/runtime/server/middleware.py @@ -0,0 +1,120 @@ +"""HTTP middleware -- auth, lockdown, and tunnel restrictions.""" + +from __future__ import annotations + +import hmac +import logging + +from aiohttp import web +from aiohttp.abc import AbstractAccessLogger + +from ..config.settings import cfg + +logger = logging.getLogger(__name__) + +_QUIET_PATHS = frozenset({"/api/setup/status", "/health"}) + + +class QuietAccessLogger(AbstractAccessLogger): + """Demotes polling-endpoint and noisy log entries to DEBUG.""" + + def log(self, request: web.BaseRequest, response: web.StreamResponse, time: float) -> None: + status = response.status + if request.path in _QUIET_PATHS or status == 401 or status in (502, 503): + level = logging.DEBUG + else: + level = logging.INFO + self.logger.log( + level, + "%s %s %s %s %.3fs", + request.remote, + request.method, + request.path, + status, + time, + ) + + +_PUBLIC_PREFIXES = ( + "/health", + "/api/messages", + "/acs", + "/realtime-acs", + "/api/voice/acs-callback", + "/api/voice/media-streaming", +) +_PUBLIC_EXACT = ("/api/auth/check",) + +# Tunnel restrictions and lockdown share the same base set as public prefixes; +# lockdown adds one extra path. +_TUNNEL_ALLOWED_PREFIXES = _PUBLIC_PREFIXES +_LOCKDOWN_ALLOWED_PREFIXES = _PUBLIC_PREFIXES + ("/api/setup/lockdown",) + +_CF_HEADERS = ("cf-connecting-ip", "cf-ray", "cf-ipcountry") + + +@web.middleware +async def lockdown_middleware(request: web.Request, handler): # type: ignore[type-arg] + """Block all admin panel routes when lockdown mode is active.""" + if not cfg.lockdown_mode: + return await handler(request) + if any(request.path.startswith(p) for p in _LOCKDOWN_ALLOWED_PREFIXES): + return await handler(request) + return web.json_response( + { + "status": "locked", + "message": ( + "Lock Down Mode is active. The admin panel is disabled. " + "Use /lockdown off via the bot to restore access." + ), + }, + status=403, + ) + + +@web.middleware +async def tunnel_restriction_middleware(request: web.Request, handler): # type: ignore[type-arg] + """Restrict Cloudflare-tunnelled requests to bot-only endpoints.""" + if not cfg.tunnel_restricted: + return await handler(request) + is_tunnel = any(request.headers.get(h) for h in _CF_HEADERS) + if not is_tunnel: + return await handler(request) + if any(request.path.startswith(p) for p in _TUNNEL_ALLOWED_PREFIXES): + return await handler(request) + return web.json_response({"status": "forbidden"}, status=403) + + +@web.middleware +async def auth_middleware(request: web.Request, handler): # type: ignore[type-arg] + """Require Bearer token on ``/api/*`` endpoints (except public ones).""" + secret = cfg.admin_secret + if not secret: + return await handler(request) + + path = request.path + + # Only protect /api/* endpoints (except public ones); frontend assets are public + if not path.startswith("/api/"): + return await handler(request) + + if path in _PUBLIC_EXACT or any(path.startswith(p) for p in _PUBLIC_PREFIXES): + return await handler(request) + + auth = request.headers.get("Authorization", "") + expected = f"Bearer {secret}" + if hmac.compare_digest(auth, expected): + return await handler(request) + + token_param = request.query.get("token", "") + if token_param and hmac.compare_digest(token_param, secret): + return await handler(request) + + secret_param = request.query.get("secret", "") + if secret_param and hmac.compare_digest(secret_param, secret): + return await handler(request) + + return web.json_response( + {"status": "unauthorized", "message": "Invalid or missing admin secret"}, + status=401, + ) diff --git a/app/runtime/server/routes/__init__.py b/app/runtime/server/routes/__init__.py index 16c55ac..ad887cd 100644 --- a/app/runtime/server/routes/__init__.py +++ b/app/runtime/server/routes/__init__.py @@ -2,26 +2,40 @@ from __future__ import annotations +from .content_safety_routes import ContentSafetyRoutes from .env_routes import EnvironmentRoutes from .foundry_iq_routes import FoundryIQRoutes +from .guardrails_routes import GuardrailsRoutes +from .identity_routes import IdentityRoutes from .mcp_routes import McpRoutes +from .monitoring_routes import MonitoringRoutes +from .network_routes import NetworkRoutes from .plugin_routes import PluginRoutes from .proactive_routes import ProactiveRoutes from .profile_routes import ProfileRoutes from .sandbox_routes import SandboxRoutes from .scheduler_routes import SchedulerRoutes +from .security_preflight_routes import SecurityPreflightRoutes from .session_routes import SessionRoutes from .skill_routes import SkillRoutes +from .tool_activity_routes import ToolActivityRoutes __all__ = [ + "ContentSafetyRoutes", "EnvironmentRoutes", "FoundryIQRoutes", + "GuardrailsRoutes", + "IdentityRoutes", "McpRoutes", + "MonitoringRoutes", + "NetworkRoutes", "PluginRoutes", "ProactiveRoutes", "ProfileRoutes", "SandboxRoutes", "SchedulerRoutes", + "SecurityPreflightRoutes", "SessionRoutes", "SkillRoutes", + "ToolActivityRoutes", ] diff --git a/app/runtime/server/routes/_helpers.py b/app/runtime/server/routes/_helpers.py new file mode 100644 index 0000000..dbe9eda --- /dev/null +++ b/app/runtime/server/routes/_helpers.py @@ -0,0 +1,24 @@ +"""Shared helpers for route handlers.""" + +from __future__ import annotations + +from typing import Any + +from aiohttp import web + + +def no_az() -> web.Response: + """Return a standard error when Azure CLI is unavailable.""" + return web.json_response( + {"status": "error", "message": "Azure CLI not available"}, status=500 + ) + + +def fail_response(steps: list[dict[str, Any]]) -> web.Response: + """Return a standard provisioning-failure response with step details.""" + failed = [s for s in steps if s.get("status") == "failed"] + msg = failed[0].get("detail", "Unknown error") if failed else "Unknown error" + return web.json_response( + {"status": "error", "steps": steps, "message": f"Provisioning failed: {msg}"}, + status=500, + ) diff --git a/app/runtime/server/routes/content_safety_routes.py b/app/runtime/server/routes/content_safety_routes.py index 2a00de5..eb5a2d4 100644 --- a/app/runtime/server/routes/content_safety_routes.py +++ b/app/runtime/server/routes/content_safety_routes.py @@ -9,9 +9,9 @@ from aiohttp import web from ...config.settings import cfg -from ...services.azure import AzureCLI -from ...services.prompt_shield import PromptShieldService -from ...state.guardrails_config import GuardrailsConfigStore +from ...services.cloud.azure import AzureCLI +from ...services.security.prompt_shield import PromptShieldService +from ...state.guardrails import GuardrailsConfigStore from ...util.async_helpers import run_sync logger = logging.getLogger(__name__) diff --git a/app/runtime/server/routes/env_routes.py b/app/runtime/server/routes/env_routes.py index 33ec200..c109155 100644 --- a/app/runtime/server/routes/env_routes.py +++ b/app/runtime/server/routes/env_routes.py @@ -6,11 +6,12 @@ from aiohttp import web -from ...services.azure import AzureCLI -from ...services.misconfig_checker import MisconfigChecker +from ...services.cloud.azure import AzureCLI +from ...services.security.misconfig_checker import MisconfigChecker from ...services.resource_tracker import ResourceTracker from ...state.deploy_state import DeployStateStore from ...util.async_helpers import run_sync +from ._helpers import no_az as _no_az logger = logging.getLogger(__name__) @@ -118,8 +119,3 @@ async def _misconfig_check(self, req: web.Request) -> web.Response: result = await run_sync(checker.check_all, resource_groups) return web.json_response(MisconfigChecker.to_dict(result)) - -def _no_az() -> web.Response: - return web.json_response( - {"status": "error", "message": "Azure CLI not available"}, status=500 - ) diff --git a/app/runtime/server/routes/foundry_iq_routes.py b/app/runtime/server/routes/foundry_iq_routes.py index 23ef487..def92be 100644 --- a/app/runtime/server/routes/foundry_iq_routes.py +++ b/app/runtime/server/routes/foundry_iq_routes.py @@ -8,7 +8,7 @@ from aiohttp import web -from ...services.azure import AzureCLI +from ...services.cloud.azure import AzureCLI from ...services.foundry_iq import ( delete_index, ensure_index, @@ -21,6 +21,7 @@ from ...state.deploy_state import DeployStateStore from ...state.foundry_iq_config import FoundryIQConfigStore from ...util.async_helpers import run_sync +from ._helpers import fail_response as _fail_response, no_az as _no_az logger = logging.getLogger(__name__) @@ -436,15 +437,3 @@ async def _deploy_model( }) return deployment_name - -def _no_az() -> web.Response: - return web.json_response( - {"status": "error", "message": "Azure CLI not available"}, status=500 - ) - - -def _fail_response(steps: list[dict[str, Any]]) -> web.Response: - return web.json_response( - {"status": "error", "message": "Provisioning failed", "steps": steps}, - status=500, - ) diff --git a/app/runtime/server/routes/guardrails_routes.py b/app/runtime/server/routes/guardrails_routes.py index a137041..50a1ce1 100644 --- a/app/runtime/server/routes/guardrails_routes.py +++ b/app/runtime/server/routes/guardrails_routes.py @@ -3,31 +3,24 @@ from __future__ import annotations import logging +from collections.abc import Callable from typing import Any from aiohttp import web -from ...agent.tools import get_all_tools from ...registries.skills import SkillRegistry -from ...state.guardrails_config import ( +from ...state.guardrails import ( GuardrailsConfigStore, list_model_tiers, list_presets, ) from ...state.mcp_config import McpConfigStore - -logger = logging.getLogger(__name__) - -_BUILTIN_SDK_TOOLS: list[dict[str, str]] = [ - {"name": "create", "source": "sdk", "description": "Create a new file"}, - {"name": "edit", "source": "sdk", "description": "Edit an existing file"}, - {"name": "view", "source": "sdk", "description": "View file contents"}, - {"name": "grep", "source": "sdk", "description": "Search file contents"}, - {"name": "glob", "source": "sdk", "description": "Find files by pattern"}, - {"name": "run", "source": "sdk", "description": "Run a shell command"}, - {"name": "bash", "source": "sdk", "description": "Run a bash command"}, - {"name": "report_intent", "source": "sdk", "description": "Log agent intent (always auto-approved)"}, -] +from .guardrails_routes_meta import ( + collect_tools, + get_template_handler, + list_contexts_handler, + list_templates_handler, +) class GuardrailsRoutes: @@ -57,81 +50,87 @@ def register(self, router: web.UrlDispatcher) -> None: router.add_post("/api/guardrails/model-columns", self._add_model_column) router.add_delete("/api/guardrails/model-columns/{model}", self._remove_model_column) router.add_put("/api/guardrails/model-policies/{model}/{ctx}/{tool_id}", self._set_model_policy) - router.add_get("/api/guardrails/contexts", self._list_contexts) + router.add_get("/api/guardrails/contexts", list_contexts_handler) router.add_get("/api/guardrails/presets", self._list_presets) router.add_post("/api/guardrails/presets/{preset_id}", self._apply_preset) router.add_post("/api/guardrails/set-all", self._set_all) router.add_post("/api/guardrails/model-defaults", self._apply_model_defaults) router.add_get("/api/guardrails/model-tiers", self._list_model_tiers) - router.add_get("/api/guardrails/templates", self._list_templates) - router.add_get("/api/guardrails/templates/{name}", self._get_template) + router.add_get("/api/guardrails/templates", list_templates_handler) + router.add_get("/api/guardrails/templates/{name}", get_template_handler) router.add_get("/api/guardrails/background-agents", self._list_background_agents) router.add_get("/api/guardrails/policy-yaml", self._get_policy_yaml) router.add_put("/api/guardrails/policy-yaml", self._put_policy_yaml) + @staticmethod + def _apply_validated_field( + data: dict[str, Any], + key: str, + setter: Callable[[Any], None], + ) -> web.Response | None: + """Apply a data field via *setter*, returning a 400 on ValueError.""" + if key not in data: + return None + try: + setter(data[key]) + except ValueError as exc: + return web.json_response( + {"status": "error", "message": str(exc)}, status=400, + ) + return None + async def _get_config(self, _req: web.Request) -> web.Response: return web.json_response({"status": "ok", **self._store.to_dict()}) async def _update_config(self, req: web.Request) -> web.Response: data = await req.json() - # Accept both frontend ('enabled') and backend ('hitl_enabled') field names. + + # Boolean fields (no validation needed) if "enabled" in data: self._store.set_hitl_enabled(bool(data["enabled"])) if "hitl_enabled" in data: self._store.set_hitl_enabled(bool(data["hitl_enabled"])) - # Accept both 'default_strategy' (frontend) and 'default_action' (backend). - if "default_strategy" in data: - try: - self._store.set_default_action(data["default_strategy"]) - except ValueError as exc: - return web.json_response( - {"status": "error", "message": str(exc)}, status=400 - ) - if "default_action" in data: - try: - self._store.set_default_action(data["default_action"]) - except ValueError as exc: - return web.json_response( - {"status": "error", "message": str(exc)}, status=400 - ) - # Accept both 'hitl_channel' (frontend) and 'default_channel' (backend). - if "hitl_channel" in data: - try: - self._store.set_default_channel(data["hitl_channel"]) - except ValueError as exc: - return web.json_response( - {"status": "error", "message": str(exc)}, status=400 - ) - if "default_channel" in data: - try: - self._store.set_default_channel(data["default_channel"]) - except ValueError as exc: - return web.json_response( - {"status": "error", "message": str(exc)}, status=400 - ) + + # Validated fields -- accept both frontend and backend key names + for key in ("default_strategy", "default_action"): + err = self._apply_validated_field( + data, key, self._store.set_default_action, + ) + if err: + return err + for key in ("hitl_channel", "default_channel"): + err = self._apply_validated_field( + data, key, self._store.set_default_channel, + ) + if err: + return err + err = self._apply_validated_field( + data, "filter_mode", self._store.set_filter_mode, + ) + if err: + return err + + # Simple fields (no validation) if "phone_number" in data: self._store.set_phone_number(data["phone_number"]) if "aitl_model" in data: self._store.set_aitl_model(data["aitl_model"]) if "aitl_spotlighting" in data: self._store.set_aitl_spotlighting(bool(data["aitl_spotlighting"])) - if "filter_mode" in data: - try: - self._store.set_filter_mode(data["filter_mode"]) - except ValueError as exc: - return web.json_response( - {"status": "error", "message": str(exc)}, status=400 - ) if "content_safety_endpoint" in data: self._store.set_content_safety_endpoint(data["content_safety_endpoint"]) + + # Context defaults (batch update) if "context_defaults" in data: for ctx, strategy in data["context_defaults"].items(): try: self._store.set_context_default(ctx, strategy) except ValueError as exc: return web.json_response( - {"status": "error", "message": str(exc)}, status=400 + {"status": "error", "message": str(exc)}, + status=400, ) + # Single context_default update (used by Background Agents tab) if "context_default" in data: cd = data["context_default"] @@ -143,10 +142,12 @@ async def _update_config(self, req: web.Request) -> web.Response: self._store.set_context_default(ctx, strategy) except ValueError as exc: return web.json_response( - {"status": "error", "message": str(exc)}, status=400 + {"status": "error", "message": str(exc)}, + status=400, ) else: self._store.remove_context_default(ctx) + return web.json_response({"status": "ok", **self._store.to_dict()}) async def _list_rules(self, _req: web.Request) -> web.Response: @@ -250,7 +251,7 @@ async def _delete_rule(self, req: web.Request) -> web.Response: async def _list_tools(self, _req: web.Request) -> web.Response: """Return all tools and MCP servers available to the agent.""" - tools = self._collect_tools() + tools = collect_tools() mcps = self._collect_mcps() return web.json_response({ "status": "ok", @@ -261,7 +262,7 @@ async def _list_tools(self, _req: web.Request) -> web.Response: async def _list_inventory(self, _req: web.Request) -> web.Response: """Return a unified tool inventory for the policy matrix UI.""" inventory: list[dict[str, Any]] = [] - for t in self._collect_tools(): + for t in collect_tools(): inventory.append({ "id": t["name"], "name": t["name"], @@ -339,22 +340,6 @@ async def _set_model_policy(self, req: web.Request) -> web.Response: ) return web.json_response({"status": "ok"}) - def _collect_tools(self) -> list[dict[str, Any]]: - """Gather custom tools defined via @define_tool + built-in SDK tools.""" - result: list[dict[str, Any]] = [] - for t in get_all_tools(): - name = getattr(t, "name", "") or getattr(t, "__name__", "unknown") - desc = getattr(t, "description", "") or "" - # Avoid using the class-level __doc__ which is the Tool repr - if not desc and hasattr(t, "__doc__") and t.__doc__: - first_line = t.__doc__.strip().split("\n")[0] - if not first_line.startswith("Tool("): - desc = first_line - result.append({"name": name, "source": "custom", "description": desc}) - for entry in _BUILTIN_SDK_TOOLS: - result.append(dict(entry)) - return result - def _collect_mcps(self) -> list[dict[str, Any]]: """Gather configured MCP servers.""" return [ @@ -380,30 +365,6 @@ def _collect_skills(self) -> list[dict[str, Any]]: for s in self._skills.list_installed() ] - async def _list_contexts(self, _req: web.Request) -> web.Response: - """Return available execution contexts, HITL channels, and strategies.""" - return web.json_response({ - "status": "ok", - "contexts": [ - {"id": "interactive", "label": "Interactive", "description": "User is chatting via the web UI or TUI"}, - {"id": "background", "label": "Background", "description": "Scheduled tasks and proactive loop"}, - {"id": "voice", "label": "Voice", "description": "Realtime voice call sessions"}, - {"id": "api", "label": "API", "description": "External API-triggered executions"}, - ], - "channels": [ - {"id": "chat", "label": "Chat", "description": "In-session WebSocket approval prompt"}, - {"id": "phone", "label": "Phone Call", "description": "Outbound phone call verification via ACS"}, - ], - "strategies": [ - {"id": "allow", "label": "Allow", "description": "Pass through without review", "color": "var(--ok)"}, - {"id": "deny", "label": "Deny", "description": "Block immediately", "color": "var(--err)"}, - {"id": "hitl", "label": "HITL", "description": "Human-in-the-loop approval via chat", "color": "var(--blue)"}, - {"id": "pitl", "label": "PITL (Experimental)", "description": "Phone-in-the-loop approval via outbound phone call (experimental)", "color": "var(--cyan, #22d3ee)"}, - {"id": "aitl", "label": "AITL", "description": "AI-in-the-loop: background reviewer agent decides", "color": "var(--gold)"}, - {"id": "filter", "label": "Filter", "description": "Content Safety Prompt Shields injection detection", "color": "var(--purple, #a78bfa)"}, - ], - }) - async def _list_presets(self, _req: web.Request) -> web.Response: """Return available preset definitions with model-tier metadata.""" return web.json_response({ @@ -466,66 +427,9 @@ async def _list_model_tiers(self, _req: web.Request) -> web.Response: "models": list_model_tiers(), }) - async def _list_templates(self, _req: web.Request) -> web.Response: - """Return the list of prompt template names.""" - from pathlib import Path as _Path - - from ...agent.prompt import _TEMPLATES_DIR - - templates: list[dict[str, str]] = [] - if _TEMPLATES_DIR.is_dir(): - for f in sorted(_TEMPLATES_DIR.iterdir()): - if f.suffix == ".md": - templates.append({ - "name": f.name, - "size": str(f.stat().st_size), - }) - # Also include SOUL.md if it exists - from ...config.settings import cfg - - if cfg.soul_path.exists(): - templates.insert(0, { - "name": "SOUL.md", - "size": str(cfg.soul_path.stat().st_size), - }) - return web.json_response({"status": "ok", "templates": templates}) - - async def _get_template(self, req: web.Request) -> web.Response: - """Fetch the content of a single prompt template.""" - name = req.match_info["name"] - if ".." in name or "/" in name: - return web.json_response( - {"status": "error", "message": "invalid name"}, status=400 - ) - # Check SOUL.md first - if name == "SOUL.md": - from ...config.settings import cfg - - if cfg.soul_path.exists(): - return web.json_response({ - "status": "ok", - "name": name, - "content": cfg.soul_path.read_text(), - }) - return web.json_response( - {"status": "error", "message": "not found"}, status=404 - ) - from ...agent.prompt import _TEMPLATES_DIR - - path = _TEMPLATES_DIR / name - if not path.exists() or not path.suffix == ".md": - return web.json_response( - {"status": "error", "message": "not found"}, status=404 - ) - return web.json_response({ - "status": "ok", - "name": name, - "content": path.read_text(), - }) - async def _list_background_agents(self, _req: web.Request) -> web.Response: """Return metadata for all background agents with current policy.""" - from ...state.guardrails_config import list_background_agents + from ...state.guardrails import list_background_agents agents = list_background_agents() config = self._store.config diff --git a/app/runtime/server/routes/guardrails_routes_meta.py b/app/runtime/server/routes/guardrails_routes_meta.py new file mode 100644 index 0000000..15096ec --- /dev/null +++ b/app/runtime/server/routes/guardrails_routes_meta.py @@ -0,0 +1,135 @@ +"""Guardrails metadata handlers -- static context, template, and tool data.""" + +from __future__ import annotations + +from typing import Any + +from aiohttp import web + +BUILTIN_SDK_TOOLS: list[dict[str, str]] = [ + {"name": "create", "source": "sdk", "description": "Create a new file"}, + {"name": "edit", "source": "sdk", "description": "Edit an existing file"}, + {"name": "view", "source": "sdk", "description": "View file contents"}, + {"name": "grep", "source": "sdk", "description": "Search file contents"}, + {"name": "glob", "source": "sdk", "description": "Find files by pattern"}, + {"name": "run", "source": "sdk", "description": "Run a shell command"}, + {"name": "bash", "source": "sdk", "description": "Run a bash command"}, + {"name": "report_intent", "source": "sdk", + "description": "Log agent intent (always auto-approved)"}, +] + + +async def list_contexts_handler(_req: web.Request) -> web.Response: + """Return available execution contexts, HITL channels, and strategies.""" + return web.json_response({ + "status": "ok", + "contexts": [ + {"id": "interactive", "label": "Interactive", + "description": "User is chatting via the web UI or TUI"}, + {"id": "background", "label": "Background", + "description": "Scheduled tasks and proactive loop"}, + {"id": "voice", "label": "Voice", + "description": "Realtime voice call sessions"}, + {"id": "api", "label": "API", + "description": "External API-triggered executions"}, + ], + "channels": [ + {"id": "chat", "label": "Chat", + "description": "In-session WebSocket approval prompt"}, + {"id": "phone", "label": "Phone Call", + "description": "Outbound phone call verification via ACS"}, + ], + "strategies": [ + {"id": "allow", "label": "Allow", + "description": "Pass through without review", "color": "var(--ok)"}, + {"id": "deny", "label": "Deny", + "description": "Block immediately", "color": "var(--err)"}, + {"id": "hitl", "label": "HITL", + "description": "Human-in-the-loop approval via chat", + "color": "var(--blue)"}, + {"id": "pitl", "label": "PITL (Experimental)", + "description": "Phone-in-the-loop approval via outbound phone call" + " (experimental)", + "color": "var(--cyan, #22d3ee)"}, + {"id": "aitl", "label": "AITL", + "description": "AI-in-the-loop: background reviewer agent decides", + "color": "var(--gold)"}, + {"id": "filter", "label": "Filter", + "description": "Content Safety Prompt Shields injection detection", + "color": "var(--purple, #a78bfa)"}, + ], + }) + + +async def list_templates_handler(_req: web.Request) -> web.Response: + """Return the list of prompt template names.""" + from pathlib import Path as _Path + + from ...agent.prompt import TEMPLATES_DIR + from ...config.settings import cfg + + templates: list[dict[str, str]] = [] + if TEMPLATES_DIR.is_dir(): + for f in sorted(TEMPLATES_DIR.iterdir()): + if f.suffix == ".md": + templates.append({ + "name": f.name, + "size": str(f.stat().st_size), + }) + if cfg.soul_path.exists(): + templates.insert(0, { + "name": "SOUL.md", + "size": str(cfg.soul_path.stat().st_size), + }) + return web.json_response({"status": "ok", "templates": templates}) + + +async def get_template_handler(req: web.Request) -> web.Response: + """Fetch the content of a single prompt template.""" + name = req.match_info["name"] + if ".." in name or "/" in name: + return web.json_response( + {"status": "error", "message": "invalid name"}, status=400, + ) + if name == "SOUL.md": + from ...config.settings import cfg + + if cfg.soul_path.exists(): + return web.json_response({ + "status": "ok", + "name": name, + "content": cfg.soul_path.read_text(), + }) + return web.json_response( + {"status": "error", "message": "not found"}, status=404, + ) + from ...agent.prompt import TEMPLATES_DIR + + path = TEMPLATES_DIR / name + if not path.exists() or not path.suffix == ".md": + return web.json_response( + {"status": "error", "message": "not found"}, status=404, + ) + return web.json_response({ + "status": "ok", + "name": name, + "content": path.read_text(), + }) + + +def collect_tools() -> list[dict[str, Any]]: + """Gather custom tools defined via ``@define_tool`` plus built-in SDK tools.""" + from ...agent.tools import get_all_tools + + result: list[dict[str, Any]] = [] + for t in get_all_tools(): + name = getattr(t, "name", "") or getattr(t, "__name__", "unknown") + desc = getattr(t, "description", "") or "" + if not desc and hasattr(t, "__doc__") and t.__doc__: + first_line = t.__doc__.strip().split("\n")[0] + if not first_line.startswith("Tool("): + desc = first_line + result.append({"name": name, "source": "custom", "description": desc}) + for entry in BUILTIN_SDK_TOOLS: + result.append(dict(entry)) + return result diff --git a/app/runtime/server/routes/identity_routes.py b/app/runtime/server/routes/identity_routes.py index 4b715fb..53dd9e1 100644 --- a/app/runtime/server/routes/identity_routes.py +++ b/app/runtime/server/routes/identity_routes.py @@ -9,8 +9,8 @@ from aiohttp import web from ...config.settings import cfg -from ...services.azure import AzureCLI -from ...state.guardrails_config import GuardrailsConfigStore +from ...services.cloud.azure import AzureCLI +from ...state.guardrails import GuardrailsConfigStore from ...state.sandbox_config import SandboxConfigStore from ...util.async_helpers import run_sync @@ -151,17 +151,40 @@ async def _roles(self, _req: web.Request) -> web.Response: "condition": a.get("condition", ""), }) - # Check which required roles are present + # Resolve expected session pool scope for scope-aware checking. + session_pool_scope = self._resolve_session_pool_scope() + + # Check which required roles are present. For the Session + # Executor role we also verify that the assignment scope covers + # the configured session pool -- an assignment on a different + # resource / RG still results in 403. assigned_names = {a.get("roleDefinitionName", "") for a in assignments} checks: list[dict[str, Any]] = [] for req in _REQUIRED_ROLES: - present = req["role"] in assigned_names - checks.append({ - "feature": req["feature"], - "role": req["role"], - "present": present, - "data_action": req.get("data_action", ""), - }) + role_name = req["role"] + if role_name == "Azure ContainerApps Session Executor": + present, detail = self._check_session_executor_scope( + assignments, session_pool_scope, + ) + check: dict[str, Any] = { + "feature": req["feature"], + "role": role_name, + "present": present, + "data_action": req.get("data_action", ""), + } + if detail: + check["detail"] = detail + if session_pool_scope: + check["expected_scope"] = session_pool_scope + checks.append(check) + else: + present = role_name in assigned_names + checks.append({ + "feature": req["feature"], + "role": role_name, + "present": present, + "data_action": req.get("data_action", ""), + }) return web.json_response({ "status": "ok", @@ -250,6 +273,70 @@ async def _fix_roles(self, req: web.Request) -> web.Response: # Helpers # ------------------------------------------------------------------ + def _resolve_session_pool_scope(self) -> str: + """Return the expected ARM scope for the configured session pool, or ``""``.""" + store = self._sandbox_store or SandboxConfigStore() + pool_id = store.pool_id + if pool_id: + return pool_id + endpoint = store.session_pool_endpoint + if endpoint: + # The management-plane endpoint embeds the resource path, e.g. + # https://.dynamicsessions.io/subscriptions/.../sessionPools/ + # Extract the ARM resource id from it. + for prefix in ( + "https://", "http://", + ): + if endpoint.lower().startswith(prefix): + endpoint = endpoint[len(prefix):] + break + parts = endpoint.split("/") + try: + sub_idx = parts.index("subscriptions") + return "/" + "/".join(parts[sub_idx:]) + except ValueError: + pass + return "" + + @staticmethod + def _check_session_executor_scope( + assignments: list[Any], + expected_scope: str, + ) -> tuple[bool, str]: + """Check whether Session Executor is assigned on the right scope. + + Returns ``(present, detail)`` where *detail* explains mismatches. + """ + role_name = "Azure ContainerApps Session Executor" + matching: list[str] = [] + for a in assignments: + if not isinstance(a, dict): + continue + if a.get("roleDefinitionName", "") != role_name: + continue + scope = a.get("scope", "") + matching.append(scope) + + if not matching: + return False, "Role not assigned to this identity" + + if not expected_scope: + # No session pool configured; can't verify scope. + return True, "Role present (session pool scope not configured -- cannot verify)" + + normalised = expected_scope.lower().rstrip("/") + for scope in matching: + if scope.lower().rstrip("/") == normalised: + return True, "" + + # Role exists but on wrong scope + scopes_str = ", ".join(matching) + return False, ( + f"Role assigned on wrong scope. " + f"Expected: {expected_scope} -- " + f"Found: {scopes_str}" + ) + async def _fix_session_pool_role( self, principal_id: str, diff --git a/app/runtime/server/routes/mcp_routes.py b/app/runtime/server/routes/mcp_routes.py index 12a2e0e..46e9d94 100644 --- a/app/runtime/server/routes/mcp_routes.py +++ b/app/runtime/server/routes/mcp_routes.py @@ -167,7 +167,3 @@ async def _registry(self, req: web.Request) -> web.Response: "servers": servers, "source": "github.com/mcp", }) - - -def _error(message: str, status: int = 500) -> web.Response: - return web.json_response({"status": "error", "message": message}, status=status) diff --git a/app/runtime/server/routes/monitoring_routes.py b/app/runtime/server/routes/monitoring_routes.py index 1a645e8..eec581e 100644 --- a/app/runtime/server/routes/monitoring_routes.py +++ b/app/runtime/server/routes/monitoring_routes.py @@ -8,11 +8,12 @@ from aiohttp import web -from ...services.azure import AzureCLI +from ...services.cloud.azure import AzureCLI from ...services.otel import configure_otel, get_status, is_active, shutdown_otel from ...state.deploy_state import DeployStateStore from ...state.monitoring_config import MonitoringConfigStore from ...util.async_helpers import run_sync +from ._helpers import fail_response as _fail_response, no_az as _no_az logger = logging.getLogger(__name__) @@ -436,17 +437,3 @@ async def _create_app_insights( }) return cs - -def _no_az() -> web.Response: - return web.json_response( - {"status": "error", "message": "Azure CLI not available"}, status=500 - ) - - -def _fail_response(steps: list[dict[str, Any]]) -> web.Response: - failed = [s for s in steps if s.get("status") == "failed"] - msg = failed[0].get("detail", "Unknown error") if failed else "Unknown error" - return web.json_response( - {"status": "error", "steps": steps, "message": f"Provisioning failed: {msg}"}, - status=500, - ) diff --git a/app/runtime/server/routes/network_audit.py b/app/runtime/server/routes/network_audit.py new file mode 100644 index 0000000..d598bee --- /dev/null +++ b/app/runtime/server/routes/network_audit.py @@ -0,0 +1,299 @@ +"""Azure resource network audit helpers for the network-info API.""" + +from __future__ import annotations + +from typing import Any + +from ...services.cloud.azure import AzureCLI +from ...state.foundry_iq_config import FoundryIQConfigStore +from ...state.sandbox_config import SandboxConfigStore + +# Maps lowercased Azure resource type prefixes to audit functions. +_RESOURCE_AUDITORS: dict[str, Any] = {} # populated after function definitions + + +def collect_resource_groups( + cfg: Any, + sandbox_store: SandboxConfigStore | None, + foundry_iq_store: FoundryIQConfigStore | None, +) -> list[str]: + """Gather all known resource groups from config stores.""" + rgs: set[str] = set() + + bot_rg = cfg.env.read("RESOURCE_GROUP") or "" + if bot_rg: + rgs.add(bot_rg) + + if sandbox_store: + sb = sandbox_store.config + if sb.resource_group: + rgs.add(sb.resource_group) + + if foundry_iq_store: + fiq = foundry_iq_store.config + if fiq.resource_group: + rgs.add(fiq.resource_group) + + deploy_rg = cfg.env.read("DEPLOY_RESOURCE_GROUP") or "" + if deploy_rg: + rgs.add(deploy_rg) + + voice_rg = cfg.env.read("VOICE_RESOURCE_GROUP") or "" + if voice_rg: + rgs.add(voice_rg) + + return list(rgs) + + +def audit_resource( + az: AzureCLI, rg: str, name: str, rtype: str, +) -> dict[str, Any] | None: + """Return a network audit dict for a single Azure resource.""" + rtype_lower = rtype.lower() + for prefix, auditor in _RESOURCE_AUDITORS.items(): + if prefix in rtype_lower: + return auditor(az=az, rg=rg, name=name) + return None + + +# ------------------------------------------------------------------ +# Per-resource audit functions +# ------------------------------------------------------------------ + + +def _audit_storage( + az: AzureCLI, rg: str, name: str, +) -> dict[str, Any] | None: + info = az.json("storage", "account", "show", "--name", name, "--resource-group", rg) + if not isinstance(info, dict): + return None + props = info.get("properties") or info + net_rules = props.get("networkRuleSet") or props.get("networkAcls") or {} + default_action = net_rules.get("defaultAction") or "Allow" + ip_rules = net_rules.get("ipRules") or [] + vnet_rules = net_rules.get("virtualNetworkRules") or [] + allowed_ips = [r.get("value", r.get("ipAddressOrRange", "")) for r in ip_rules] + allowed_vnets = [r.get("id", "") for r in vnet_rules] + public_blob = props.get("allowBlobPublicAccess", True) + https_only = info.get("enableHttpsTrafficOnly", props.get("supportsHttpsTrafficOnly", True)) + min_tls = props.get("minimumTlsVersion", "TLS1_0") + private_eps = _get_private_endpoints(props) + + return { + "name": name, + "resource_group": rg, + "type": "Storage Account", + "icon": "storage", + "public_access": default_action == "Allow", + "default_action": default_action, + "allowed_ips": allowed_ips, + "allowed_vnets": allowed_vnets, + "private_endpoints": private_eps, + "https_only": https_only, + "min_tls_version": min_tls, + "extra": { + "public_blob_access": public_blob, + }, + } + + +def _audit_keyvault( + az: AzureCLI, rg: str, name: str, +) -> dict[str, Any] | None: + info = az.json("keyvault", "show", "--name", name, "--resource-group", rg) + if not isinstance(info, dict): + return None + props = info.get("properties") or info + net_acls = props.get("networkAcls") or {} + default_action = net_acls.get("defaultAction") or "Allow" + ip_rules = net_acls.get("ipRules") or [] + vnet_rules = net_acls.get("virtualNetworkRules") or [] + allowed_ips = [r.get("value", "") for r in ip_rules] + allowed_vnets = [r.get("id", "") for r in vnet_rules] + public_access = props.get("publicNetworkAccess", "Enabled") + private_eps = _get_private_endpoints(props) + rbac = props.get("enableRbacAuthorization", False) + soft_delete = props.get("enableSoftDelete", False) + purge_protect = props.get("enablePurgeProtection", False) + + return { + "name": name, + "resource_group": rg, + "type": "Key Vault", + "icon": "keyvault", + "public_access": public_access != "Disabled" and default_action == "Allow", + "default_action": default_action, + "allowed_ips": allowed_ips, + "allowed_vnets": allowed_vnets, + "private_endpoints": private_eps, + "extra": { + "public_network_access": public_access, + "rbac_authorization": rbac, + "soft_delete": soft_delete, + "purge_protection": purge_protect, + }, + } + + +def _audit_cognitive( + az: AzureCLI, rg: str, name: str, +) -> dict[str, Any] | None: + """Audit Azure OpenAI / Cognitive Services accounts.""" + info = az.json( + "cognitiveservices", "account", "show", + "--name", name, "--resource-group", rg, + ) + if not isinstance(info, dict): + return None + props = info.get("properties") or info + net_acls = props.get("networkAcls") or {} + default_action = net_acls.get("defaultAction") or "Allow" + ip_rules = net_acls.get("ipRules") or [] + vnet_rules = net_acls.get("virtualNetworkRules") or [] + allowed_ips = [r.get("value", "") for r in ip_rules] + allowed_vnets = [r.get("id", "") for r in vnet_rules] + public_access = props.get("publicNetworkAccess", "Enabled") + private_eps = _get_private_endpoints(props) + kind = info.get("kind", "CognitiveServices") + endpoint = ( + props.get("endpoint") + or (props.get("endpoints") or {}).get("OpenAI Language Model Instance API", "") + ) + + label = "Azure OpenAI" if kind.lower() == "openai" else f"Cognitive Services ({kind})" + + return { + "name": name, + "resource_group": rg, + "type": label, + "icon": "ai", + "public_access": public_access != "Disabled" and default_action == "Allow", + "default_action": default_action, + "allowed_ips": allowed_ips, + "allowed_vnets": allowed_vnets, + "private_endpoints": private_eps, + "extra": { + "public_network_access": public_access, + "kind": kind, + "endpoint": endpoint, + }, + } + + +def _audit_search( + az: AzureCLI, rg: str, name: str, +) -> dict[str, Any] | None: + """Audit Azure AI Search service.""" + info = az.json( + "search", "service", "show", + "--name", name, "--resource-group", rg, + ) + if not isinstance(info, dict): + return None + props = info.get("properties") or info + public_access = props.get("publicNetworkAccess", "enabled") + ip_rules = (props.get("networkRuleSet") or {}).get("ipRules") or [] + allowed_ips = [r.get("value", "") for r in ip_rules] + private_eps = _get_private_endpoints(props) + + return { + "name": name, + "resource_group": rg, + "type": "Azure AI Search", + "icon": "search", + "public_access": public_access.lower() != "disabled", + "default_action": "Allow" if public_access.lower() != "disabled" else "Deny", + "allowed_ips": allowed_ips, + "allowed_vnets": [], + "private_endpoints": private_eps, + "extra": { + "public_network_access": public_access, + "sku": info.get("sku", {}).get("name", ""), + }, + } + + +def _audit_acr( + az: AzureCLI, rg: str, name: str, +) -> dict[str, Any] | None: + info = az.json("acr", "show", "--name", name, "--resource-group", rg) + if not isinstance(info, dict): + return None + public_access = info.get("publicNetworkAccess", "Enabled") + net_rules = info.get("networkRuleSet") or {} + default_action = net_rules.get("defaultAction") or "Allow" + ip_rules = net_rules.get("ipRules") or [] + allowed_ips = [r.get("value", "") for r in ip_rules] + admin_enabled = info.get("adminUserEnabled", False) + + return { + "name": name, + "resource_group": rg, + "type": "Container Registry", + "icon": "acr", + "public_access": public_access == "Enabled", + "default_action": default_action, + "allowed_ips": allowed_ips, + "allowed_vnets": [], + "private_endpoints": [], + "extra": { + "admin_user_enabled": admin_enabled, + "sku": info.get("sku", {}).get("name", ""), + }, + } + + +def _audit_session_pool(rg: str, name: str, **_kw: Any) -> dict[str, Any]: + """Audit Azure Container Apps session pool.""" + return { + "name": name, + "resource_group": rg, + "type": "Session Pool", + "icon": "sandbox", + "public_access": True, + "default_action": "Allow", + "allowed_ips": [], + "allowed_vnets": [], + "private_endpoints": [], + "extra": {}, + } + + +def _audit_acs(rg: str, name: str, **_kw: Any) -> dict[str, Any]: + """Audit Azure Communication Services.""" + return { + "name": name, + "resource_group": rg, + "type": "Communication Services", + "icon": "communication", + "public_access": True, + "default_action": "Allow", + "allowed_ips": [], + "allowed_vnets": [], + "private_endpoints": [], + "extra": {}, + } + + +def _get_private_endpoints(props: dict[str, Any]) -> list[str]: + """Extract private endpoint names from a resource's properties.""" + pe_conns = props.get("privateEndpointConnections", []) + results: list[str] = [] + for pec in pe_conns: + pe = pec.get("privateEndpoint", {}) + pe_id = pe.get("id", "") + if pe_id: + results.append(pe_id.rsplit("/", 1)[-1]) + return results + + +# Populate the dispatch table now that all audit functions are defined. +_RESOURCE_AUDITORS.update({ + "microsoft.storage/storageaccounts": _audit_storage, + "microsoft.keyvault/vaults": _audit_keyvault, + "microsoft.cognitiveservices/accounts": _audit_cognitive, + "microsoft.search/searchservices": _audit_search, + "microsoft.containerregistry/registries": _audit_acr, + "microsoft.app/sessionpools": _audit_session_pool, + "microsoft.communication/communicationservices": _audit_acs, +}) diff --git a/app/runtime/server/routes/network_routes.py b/app/runtime/server/routes/network_routes.py index 8cea8df..6cfa3cf 100644 --- a/app/runtime/server/routes/network_routes.py +++ b/app/runtime/server/routes/network_routes.py @@ -10,9 +10,11 @@ from aiohttp import ClientSession, ClientTimeout, web from ...config.settings import cfg -from ...services.azure import AzureCLI +from ...services.cloud.azure import AzureCLI from ...state.foundry_iq_config import FoundryIQConfigStore from ...state.sandbox_config import SandboxConfigStore +from .network_audit import audit_resource, collect_resource_groups +from .network_topology import build_components, build_containers logger = logging.getLogger(__name__) @@ -164,10 +166,10 @@ async def _info(self, req: web.Request) -> web.Response: tunnel_info = await resolve_tunnel_info(self._tunnel, self._az) # Build component info (what network connections are configured) - components = self._build_components(deploy_mode, tunnel_info) + components = build_components(deploy_mode, self._tunnel, tunnel_info) # Build container topology for dual-container deployments - containers = self._build_containers(deploy_mode, server_mode, admin_port) + containers = build_containers(deploy_mode, server_mode, admin_port) return web.json_response({ "deploy_mode": deploy_mode, @@ -455,157 +457,6 @@ def _collect_endpoints(self, app: web.Application) -> list[dict[str, Any]]: results.sort(key=lambda e: (e["category"], e["path"], e["method"])) return results - def _build_containers( - self, - deploy_mode: str, - server_mode: str, - admin_port: int, - ) -> list[dict[str, Any]]: - """Build container topology for the network diagram. - - Only states facts that can be read from the current environment - or configuration. Identity and volume claims are intentionally - omitted -- those are verified by the probe endpoint. - """ - if deploy_mode == "docker": - runtime_port = int(os.getenv("RUNTIME_PORT", "8080")) - runtime_url = os.getenv("RUNTIME_URL", "http://runtime:8080") - # Parse actual port from RUNTIME_URL if set - if ":" in runtime_url.rsplit("/", 1)[-1]: - try: - runtime_port = int(runtime_url.rsplit(":", 1)[-1].rstrip("/")) - except ValueError: - pass - return [ - { - "role": "admin", - "label": "Admin Container", - "port": admin_port, - "host": "127.0.0.1", - "exposure": "localhost-only", - }, - { - "role": "runtime", - "label": "Agent Container", - "port": runtime_port, - "host": "runtime", - "exposure": "tunnel (Cloudflare)", - }, - ] - if deploy_mode == "aca": - aca_name = os.getenv("ACA_ENV_NAME", "polyclaw") - runtime_port = int(os.getenv("RUNTIME_PORT", "8080")) - return [ - { - "role": "admin", - "label": "Admin Container", - "port": admin_port, - "host": "internal", - "exposure": "internal-only", - }, - { - "role": "runtime", - "label": "Agent Container", - "port": runtime_port, - "host": aca_name, - "exposure": "ACA ingress", - }, - ] - # local / combined -- single process - return [ - { - "role": "combined", - "label": "Polyclaw Server", - "port": admin_port, - "host": "localhost", - "exposure": "localhost", - }, - ] - - def _build_components( - self, deploy_mode: str, tunnel_info: dict[str, Any] | None = None, - ) -> list[dict[str, Any]]: - """Build the list of network-connected components.""" - components: list[dict[str, Any]] = [] - - # Azure OpenAI / Foundry - aoai_endpoint = cfg.azure_openai_endpoint - if aoai_endpoint: - components.append({ - "name": "Azure OpenAI", - "type": "ai", - "endpoint": aoai_endpoint, - "deployment": cfg.azure_openai_realtime_deployment, - "status": "configured", - }) - - # GitHub Copilot (model backend) - if cfg.github_token: - components.append({ - "name": "GitHub Copilot", - "type": "ai", - "endpoint": "https://api.githubcopilot.com", - "model": cfg.copilot_model, - "status": "configured", - }) - - # ACS (Communication Services) - if cfg.acs_connection_string: - components.append({ - "name": "Azure Communication Services", - "type": "communication", - "status": "configured", - "source_number": cfg.acs_source_number or None, - }) - - # Cloudflare Tunnel -- use pre-resolved tunnel_info when available - if tunnel_info is not None: - components.append({ - "name": "Cloudflare Tunnel", - "type": "tunnel", - "status": "active" if tunnel_info["active"] else "inactive", - "url": tunnel_info["url"], - "restricted": tunnel_info["restricted"], - }) - else: - components.append({ - "name": "Cloudflare Tunnel", - "type": "tunnel", - "status": "active" if getattr(self._tunnel, "is_active", False) else "inactive", - "url": getattr(self._tunnel, "url", None), - "restricted": cfg.tunnel_restricted, - }) - - # Azure Bot Service - if cfg.bot_app_id: - components.append({ - "name": "Azure Bot Service", - "type": "bot", - "status": "configured", - "app_id": cfg.bot_app_id[:12] + "..." if cfg.bot_app_id else None, - }) - - # Foundry IQ / AI Search (check env for search endpoint) - search_endpoint = cfg.env.read("SEARCH_ENDPOINT") or "" - if search_endpoint: - components.append({ - "name": "Azure AI Search", - "type": "search", - "endpoint": search_endpoint, - "status": "configured", - }) - - # Storage / Data directory - components.append({ - "name": "Local Data Store", - "type": "storage", - "path": str(cfg.data_dir), - "status": "active", - "deploy_mode": deploy_mode, - }) - - return components - # ------------------------------------------------------------------ # Resource network audit # ------------------------------------------------------------------ @@ -619,7 +470,9 @@ async def _resource_audit(self, req: web.Request) -> web.Response: if not self._az: return web.json_response({"resources": [], "error": "Azure CLI not available"}) - resource_groups = self._collect_resource_groups() + resource_groups = collect_resource_groups( + cfg, self._sandbox_store, self._foundry_iq_store, + ) if not resource_groups: return web.json_response({"resources": []}) @@ -631,266 +484,8 @@ async def _resource_audit(self, req: web.Request) -> web.Response: for r in raw: rtype = (r.get("type") or "").lower() rname = r.get("name", "") - audit = self._audit_resource(rg, rname, rtype) + audit = audit_resource(self._az, rg, rname, rtype) if audit: resources.append(audit) return web.json_response({"resources": resources}) - - def _collect_resource_groups(self) -> list[str]: - """Gather all known resource groups from config stores.""" - rgs: set[str] = set() - - # Main bot / infra resource group - bot_rg = cfg.env.read("RESOURCE_GROUP") or "" - if bot_rg: - rgs.add(bot_rg) - - # Sandbox - if self._sandbox_store: - sb = self._sandbox_store.config - if sb.resource_group: - rgs.add(sb.resource_group) - - # Foundry IQ - if self._foundry_iq_store: - fiq = self._foundry_iq_store.config - if fiq.resource_group: - rgs.add(fiq.resource_group) - - # Deploy state resource group - deploy_rg = cfg.env.read("DEPLOY_RESOURCE_GROUP") or "" - if deploy_rg: - rgs.add(deploy_rg) - - # Voice resource group - voice_rg = cfg.env.read("VOICE_RESOURCE_GROUP") or "" - if voice_rg: - rgs.add(voice_rg) - - return list(rgs) - - def _audit_resource(self, rg: str, name: str, rtype: str) -> dict[str, Any] | None: - """Return a network audit dict for a single Azure resource.""" - if "microsoft.storage/storageaccounts" in rtype: - return self._audit_storage(rg, name) - if "microsoft.keyvault/vaults" in rtype: - return self._audit_keyvault(rg, name) - if "microsoft.cognitiveservices/accounts" in rtype: - return self._audit_cognitive(rg, name) - if "microsoft.search/searchservices" in rtype: - return self._audit_search(rg, name) - if "microsoft.containerregistry/registries" in rtype: - return self._audit_acr(rg, name) - if "microsoft.app/sessionpools" in rtype: - return self._audit_session_pool(rg, name) - if "microsoft.communication/communicationservices" in rtype: - return self._audit_acs(rg, name) - return None - - def _audit_storage(self, rg: str, name: str) -> dict[str, Any] | None: - info = self._az.json("storage", "account", "show", "--name", name, "--resource-group", rg) - if not isinstance(info, dict): - return None - props = info.get("properties") or info - net_rules = props.get("networkRuleSet") or props.get("networkAcls") or {} - default_action = (net_rules.get("defaultAction") or "Allow") - ip_rules = net_rules.get("ipRules") or [] - vnet_rules = net_rules.get("virtualNetworkRules") or [] - allowed_ips = [r.get("value", r.get("ipAddressOrRange", "")) for r in ip_rules] - allowed_vnets = [r.get("id", "") for r in vnet_rules] - public_blob = props.get("allowBlobPublicAccess", True) - https_only = info.get("enableHttpsTrafficOnly", props.get("supportsHttpsTrafficOnly", True)) - min_tls = props.get("minimumTlsVersion", "TLS1_0") - private_eps = self._get_private_endpoints(props) - - return { - "name": name, - "resource_group": rg, - "type": "Storage Account", - "icon": "storage", - "public_access": default_action == "Allow", - "default_action": default_action, - "allowed_ips": allowed_ips, - "allowed_vnets": allowed_vnets, - "private_endpoints": private_eps, - "https_only": https_only, - "min_tls_version": min_tls, - "extra": { - "public_blob_access": public_blob, - }, - } - - def _audit_keyvault(self, rg: str, name: str) -> dict[str, Any] | None: - info = self._az.json("keyvault", "show", "--name", name, "--resource-group", rg) - if not isinstance(info, dict): - return None - props = info.get("properties") or info - net_acls = props.get("networkAcls") or {} - default_action = (net_acls.get("defaultAction") or "Allow") - ip_rules = net_acls.get("ipRules") or [] - vnet_rules = net_acls.get("virtualNetworkRules") or [] - allowed_ips = [r.get("value", "") for r in ip_rules] - allowed_vnets = [r.get("id", "") for r in vnet_rules] - public_access = props.get("publicNetworkAccess", "Enabled") - private_eps = self._get_private_endpoints(props) - rbac = props.get("enableRbacAuthorization", False) - soft_delete = props.get("enableSoftDelete", False) - purge_protect = props.get("enablePurgeProtection", False) - - return { - "name": name, - "resource_group": rg, - "type": "Key Vault", - "icon": "keyvault", - "public_access": public_access != "Disabled" and default_action == "Allow", - "default_action": default_action, - "allowed_ips": allowed_ips, - "allowed_vnets": allowed_vnets, - "private_endpoints": private_eps, - "extra": { - "public_network_access": public_access, - "rbac_authorization": rbac, - "soft_delete": soft_delete, - "purge_protection": purge_protect, - }, - } - - def _audit_cognitive(self, rg: str, name: str) -> dict[str, Any] | None: - """Audit Azure OpenAI / Cognitive Services accounts.""" - info = self._az.json( - "cognitiveservices", "account", "show", - "--name", name, "--resource-group", rg, - ) - if not isinstance(info, dict): - return None - props = info.get("properties") or info - net_acls = props.get("networkAcls") or {} - default_action = (net_acls.get("defaultAction") or "Allow") - ip_rules = net_acls.get("ipRules") or [] - vnet_rules = net_acls.get("virtualNetworkRules") or [] - allowed_ips = [r.get("value", "") for r in ip_rules] - allowed_vnets = [r.get("id", "") for r in vnet_rules] - public_access = props.get("publicNetworkAccess", "Enabled") - private_eps = self._get_private_endpoints(props) - kind = info.get("kind", "CognitiveServices") - endpoint = props.get("endpoint") or (props.get("endpoints") or {}).get("OpenAI Language Model Instance API", "") - - label = "Azure OpenAI" if kind.lower() == "openai" else f"Cognitive Services ({kind})" - - return { - "name": name, - "resource_group": rg, - "type": label, - "icon": "ai", - "public_access": public_access != "Disabled" and default_action == "Allow", - "default_action": default_action, - "allowed_ips": allowed_ips, - "allowed_vnets": allowed_vnets, - "private_endpoints": private_eps, - "extra": { - "public_network_access": public_access, - "kind": kind, - "endpoint": endpoint, - }, - } - - def _audit_search(self, rg: str, name: str) -> dict[str, Any] | None: - """Audit Azure AI Search service.""" - info = self._az.json( - "search", "service", "show", - "--name", name, "--resource-group", rg, - ) - if not isinstance(info, dict): - return None - props = info.get("properties") or info - public_access = props.get("publicNetworkAccess", "enabled") - ip_rules = (props.get("networkRuleSet") or {}).get("ipRules") or [] - allowed_ips = [r.get("value", "") for r in ip_rules] - private_eps = self._get_private_endpoints(props) - - return { - "name": name, - "resource_group": rg, - "type": "Azure AI Search", - "icon": "search", - "public_access": public_access.lower() != "disabled", - "default_action": "Allow" if public_access.lower() != "disabled" else "Deny", - "allowed_ips": allowed_ips, - "allowed_vnets": [], - "private_endpoints": private_eps, - "extra": { - "public_network_access": public_access, - "sku": info.get("sku", {}).get("name", ""), - }, - } - - def _audit_acr(self, rg: str, name: str) -> dict[str, Any] | None: - info = self._az.json("acr", "show", "--name", name, "--resource-group", rg) - if not isinstance(info, dict): - return None - public_access = info.get("publicNetworkAccess", "Enabled") - net_rules = info.get("networkRuleSet") or {} - default_action = (net_rules.get("defaultAction") or "Allow") - ip_rules = net_rules.get("ipRules") or [] - allowed_ips = [r.get("value", "") for r in ip_rules] - admin_enabled = info.get("adminUserEnabled", False) - - return { - "name": name, - "resource_group": rg, - "type": "Container Registry", - "icon": "acr", - "public_access": public_access == "Enabled", - "default_action": default_action, - "allowed_ips": allowed_ips, - "allowed_vnets": [], - "private_endpoints": [], - "extra": { - "admin_user_enabled": admin_enabled, - "sku": info.get("sku", {}).get("name", ""), - }, - } - - def _audit_session_pool(self, rg: str, name: str) -> dict[str, Any] | None: - """Audit Azure Container Apps session pool.""" - return { - "name": name, - "resource_group": rg, - "type": "Session Pool", - "icon": "sandbox", - "public_access": True, - "default_action": "Allow", - "allowed_ips": [], - "allowed_vnets": [], - "private_endpoints": [], - "extra": {}, - } - - def _audit_acs(self, rg: str, name: str) -> dict[str, Any] | None: - """Audit Azure Communication Services.""" - return { - "name": name, - "resource_group": rg, - "type": "Communication Services", - "icon": "communication", - "public_access": True, - "default_action": "Allow", - "allowed_ips": [], - "allowed_vnets": [], - "private_endpoints": [], - "extra": {}, - } - - @staticmethod - def _get_private_endpoints(props: dict[str, Any]) -> list[str]: - """Extract private endpoint names from a resource's properties.""" - pe_conns = props.get("privateEndpointConnections", []) - results: list[str] = [] - for pec in pe_conns: - pe = pec.get("privateEndpoint", {}) - pe_id = pe.get("id", "") - if pe_id: - # Extract just the endpoint name from the full resource ID - results.append(pe_id.rsplit("/", 1)[-1]) - return results diff --git a/app/runtime/server/routes/network_topology.py b/app/runtime/server/routes/network_topology.py new file mode 100644 index 0000000..1c2157e --- /dev/null +++ b/app/runtime/server/routes/network_topology.py @@ -0,0 +1,162 @@ +"""Network topology builders for the network-info API.""" + +from __future__ import annotations + +import os +from typing import Any + +from ...config.settings import cfg + + +def build_containers( + deploy_mode: str, + server_mode: str, + admin_port: int, +) -> list[dict[str, Any]]: + """Build container topology for the network diagram. + + Only states facts that can be read from the current environment + or configuration. Identity and volume claims are intentionally + omitted -- those are verified by the probe endpoint. + """ + if deploy_mode == "docker": + runtime_port = int(os.getenv("RUNTIME_PORT", "8080")) + runtime_url = os.getenv("RUNTIME_URL", "http://runtime:8080") + # Parse actual port from RUNTIME_URL if set + if ":" in runtime_url.rsplit("/", 1)[-1]: + try: + runtime_port = int(runtime_url.rsplit(":", 1)[-1].rstrip("/")) + except ValueError: + pass + return [ + { + "role": "admin", + "label": "Admin Container", + "port": admin_port, + "host": "127.0.0.1", + "exposure": "localhost-only", + }, + { + "role": "runtime", + "label": "Agent Container", + "port": runtime_port, + "host": "runtime", + "exposure": "tunnel (Cloudflare)", + }, + ] + if deploy_mode == "aca": + aca_name = os.getenv("ACA_ENV_NAME", "polyclaw") + runtime_port = int(os.getenv("RUNTIME_PORT", "8080")) + return [ + { + "role": "admin", + "label": "Admin Container", + "port": admin_port, + "host": "internal", + "exposure": "internal-only", + }, + { + "role": "runtime", + "label": "Agent Container", + "port": runtime_port, + "host": aca_name, + "exposure": "ACA ingress", + }, + ] + # local / combined -- single process + return [ + { + "role": "combined", + "label": "Polyclaw Server", + "port": admin_port, + "host": "localhost", + "exposure": "localhost", + }, + ] + + +def build_components( + deploy_mode: str, + tunnel: object | None, + tunnel_info: dict[str, Any] | None = None, +) -> list[dict[str, Any]]: + """Build the list of network-connected components.""" + components: list[dict[str, Any]] = [] + + # Azure OpenAI / Foundry + aoai_endpoint = cfg.azure_openai_endpoint + if aoai_endpoint: + components.append({ + "name": "Azure OpenAI", + "type": "ai", + "endpoint": aoai_endpoint, + "deployment": cfg.azure_openai_realtime_deployment, + "status": "configured", + }) + + # GitHub Copilot (model backend) + if cfg.github_token: + components.append({ + "name": "GitHub Copilot", + "type": "ai", + "endpoint": "https://api.githubcopilot.com", + "model": cfg.copilot_model, + "status": "configured", + }) + + # ACS (Communication Services) + if cfg.acs_connection_string: + components.append({ + "name": "Azure Communication Services", + "type": "communication", + "status": "configured", + "source_number": cfg.acs_source_number or None, + }) + + # Cloudflare Tunnel -- use pre-resolved tunnel_info when available + if tunnel_info is not None: + components.append({ + "name": "Cloudflare Tunnel", + "type": "tunnel", + "status": "active" if tunnel_info["active"] else "inactive", + "url": tunnel_info["url"], + "restricted": tunnel_info["restricted"], + }) + else: + components.append({ + "name": "Cloudflare Tunnel", + "type": "tunnel", + "status": "active" if getattr(tunnel, "is_active", False) else "inactive", + "url": getattr(tunnel, "url", None), + "restricted": cfg.tunnel_restricted, + }) + + # Azure Bot Service + if cfg.bot_app_id: + components.append({ + "name": "Azure Bot Service", + "type": "bot", + "status": "configured", + "app_id": cfg.bot_app_id[:12] + "..." if cfg.bot_app_id else None, + }) + + # Foundry IQ / AI Search (check env for search endpoint) + search_endpoint = cfg.env.read("SEARCH_ENDPOINT") or "" + if search_endpoint: + components.append({ + "name": "Azure AI Search", + "type": "search", + "endpoint": search_endpoint, + "status": "configured", + }) + + # Storage / Data directory + components.append({ + "name": "Local Data Store", + "type": "storage", + "path": str(cfg.data_dir), + "status": "active", + "deploy_mode": deploy_mode, + }) + + return components diff --git a/app/runtime/server/routes/sandbox_routes.py b/app/runtime/server/routes/sandbox_routes.py index 3282005..65dc998 100644 --- a/app/runtime/server/routes/sandbox_routes.py +++ b/app/runtime/server/routes/sandbox_routes.py @@ -9,10 +9,11 @@ from aiohttp import web from ...sandbox import SandboxExecutor -from ...services.azure import AzureCLI +from ...services.cloud.azure import AzureCLI from ...state.deploy_state import DeployStateStore from ...state.sandbox_config import BLACKLIST, DEFAULT_WHITELIST, SandboxConfigStore from ...util.async_helpers import run_sync +from ._helpers import fail_response as _fail_response, no_az as _no_az logger = logging.getLogger(__name__) @@ -314,17 +315,3 @@ async def _create_pool( return endpoint, pool_id - -def _no_az() -> web.Response: - return web.json_response( - {"status": "error", "message": "Azure CLI not available"}, status=500 - ) - - -def _fail_response(steps: list[dict[str, Any]]) -> web.Response: - failed = [s for s in steps if s.get("status") == "failed"] - msg = failed[0].get("detail", "Unknown error") if failed else "Unknown error" - return web.json_response( - {"status": "error", "steps": steps, "message": f"Provisioning failed: {msg}"}, - status=500, - ) diff --git a/app/runtime/server/routes/security_preflight_routes.py b/app/runtime/server/routes/security_preflight_routes.py index 1b675cf..8c3c753 100644 --- a/app/runtime/server/routes/security_preflight_routes.py +++ b/app/runtime/server/routes/security_preflight_routes.py @@ -6,7 +6,7 @@ from aiohttp import web -from ...services.security_preflight import SecurityPreflightChecker +from ...services.security.security_preflight import SecurityPreflightChecker from ...util.async_helpers import run_sync logger = logging.getLogger(__name__) diff --git a/app/runtime/server/runtime_proxy.py b/app/runtime/server/runtime_proxy.py index 11ef3d2..518b412 100644 --- a/app/runtime/server/runtime_proxy.py +++ b/app/runtime/server/runtime_proxy.py @@ -65,8 +65,11 @@ async def _proxy_http( body=response_body, headers=resp_headers, ) + except (aiohttp.ClientConnectorError, aiohttp.ClientOSError, OSError): + logger.debug("[proxy.http] runtime unreachable: %s", target_url) + raise web.HTTPBadGateway(text="Runtime container unreachable") except Exception: - logger.warning("[proxy.http] runtime unreachable: %s", target_url, exc_info=True) + logger.warning("[proxy.http] runtime proxy error: %s", target_url, exc_info=True) raise web.HTTPBadGateway(text="Runtime container unreachable") diff --git a/app/runtime/server/setup/__init__.py b/app/runtime/server/setup/__init__.py new file mode 100644 index 0000000..f10a8b5 --- /dev/null +++ b/app/runtime/server/setup/__init__.py @@ -0,0 +1,19 @@ +"""Setup wizard -- Azure, deployment, voice, prerequisites, and preflight.""" + +from __future__ import annotations + +from ._routes import SetupRoutes +from .azure import AzureSetupRoutes +from .deploy import DeploymentRoutes +from .preflight import PreflightRoutes +from .prerequisites import PrerequisitesRoutes +from .voice import VoiceSetupRoutes + +__all__ = [ + "AzureSetupRoutes", + "DeploymentRoutes", + "PreflightRoutes", + "PrerequisitesRoutes", + "SetupRoutes", + "VoiceSetupRoutes", +] diff --git a/app/runtime/server/setup/_helpers.py b/app/runtime/server/setup/_helpers.py new file mode 100644 index 0000000..6a0fe99 --- /dev/null +++ b/app/runtime/server/setup/_helpers.py @@ -0,0 +1,15 @@ +"""Shared helpers for setup route handlers.""" + +from __future__ import annotations + +from aiohttp import web + + +def ok_response(message: str) -> web.Response: + """Return a standard success response.""" + return web.json_response({"status": "ok", "message": message}) + + +def error_response(message: str, status: int = 500) -> web.Response: + """Return a standard error response.""" + return web.json_response({"status": "error", "message": message}, status=status) diff --git a/app/runtime/server/setup.py b/app/runtime/server/setup/_routes.py similarity index 50% rename from app/runtime/server/setup.py rename to app/runtime/server/setup/_routes.py index 98c05ad..8565db3 100644 --- a/app/runtime/server/setup.py +++ b/app/runtime/server/setup/_routes.py @@ -9,20 +9,22 @@ import aiohttp as _aiohttp from aiohttp import web -from ..config.settings import SECRET_ENV_KEYS, ServerMode, cfg -from ..services.aca_deployer import AcaDeployer, AcaDeployRequest -from ..services.azure import AzureCLI -from ..services.deployer import BotDeployer -from ..services.github import GitHubAuth -from ..services.provisioner import Provisioner -from ..services.runtime_identity import RuntimeIdentityProvisioner -from ..state.deploy_state import DeployStateStore -from ..state.infra_config import InfraConfigStore -from ..util.async_helpers import run_sync -from .setup_preflight import PreflightRoutes -from .setup_prerequisites import PrerequisitesRoutes -from .setup_voice import VoiceSetupRoutes -from .smoke_test import SmokeTestRunner +from ...config.settings import SECRET_ENV_KEYS, ServerMode, cfg +from ...services.cloud.azure import AzureCLI +from ...services.cloud.github import GitHubAuth +from ...services.deployment.aca_deployer import AcaDeployer +from ...services.deployment.deployer import BotDeployer +from ...services.deployment.provisioner import Provisioner +from ...state.deploy_state import DeployStateStore +from ...state.infra_config import InfraConfigStore +from ...util.async_helpers import run_sync +from .azure import AzureSetupRoutes +from ._helpers import error_response as _error, ok_response as _ok +from .deploy import DeploymentRoutes +from .preflight import PreflightRoutes +from .prerequisites import PrerequisitesRoutes +from .voice import VoiceSetupRoutes +from ..smoke_test import SmokeTestRunner logger = logging.getLogger(__name__) @@ -51,20 +53,24 @@ def __init__( self._provisioner = provisioner self._deploy_store = deploy_store self._aca_deployer = aca_deployer + self._azure_routes = AzureSetupRoutes(az) self._voice_routes = VoiceSetupRoutes(az, infra_store) self._prerequisites_routes = PrerequisitesRoutes(az, infra_store, deploy_store) self._preflight_routes = PreflightRoutes(tunnel, infra_store, az=az) - self._runtime_identity = RuntimeIdentityProvisioner(az) + self._deployment_routes = DeploymentRoutes( + az=az, + provisioner=provisioner, + rebuild_adapter=rebuild_adapter, + restart_runtime=self._restart_runtime, + infra_store=infra_store, + deploy_store=deploy_store, + aca_deployer=aca_deployer, + ) def register(self, router: web.UrlDispatcher) -> None: r = router r.add_get("/api/setup/status", self.status) - r.add_post("/api/setup/azure/login", self.azure_login) - r.add_get("/api/setup/azure/check", self.azure_check) - r.add_post("/api/setup/azure/logout", self.azure_logout) - r.add_get("/api/setup/azure/subscriptions", self.list_subscriptions) - r.add_post("/api/setup/azure/subscription", self.set_subscription) - r.add_get("/api/setup/azure/resource-groups", self.list_resource_groups) + self._azure_routes.register(r) r.add_get("/api/setup/copilot/status", self.copilot_status) r.add_post("/api/setup/copilot/login", self.copilot_login) r.add_post("/api/setup/copilot/token", self.copilot_set_token) @@ -78,28 +84,17 @@ def register(self, router: web.UrlDispatcher) -> None: r.add_post("/api/setup/channels/telegram/config", self.save_telegram_config) r.add_post("/api/setup/channels/telegram/remove", self.remove_telegram_config) r.add_post("/api/setup/configuration/save", self.save_configuration) - r.add_get("/api/setup/infra/status", self.infra_status) - r.add_post("/api/setup/infra/deploy", self.infra_deploy) - r.add_post("/api/setup/infra/decommission", self.infra_decommission) self._prerequisites_routes.register(r) self._voice_routes.register(r) r.add_get("/api/setup/config", self.get_config) r.add_post("/api/setup/config", self.save_config) self._preflight_routes.register(r) - r.add_get("/api/setup/lockdown", self.lockdown_status) - r.add_post("/api/setup/lockdown", self.lockdown_toggle) - r.add_get("/api/setup/runtime-identity", self.runtime_identity_status) - r.add_post("/api/setup/runtime-identity/provision", self.runtime_identity_provision) - r.add_post("/api/setup/runtime-identity/revoke", self.runtime_identity_revoke) - r.add_get("/api/setup/aca/status", self.aca_status) - r.add_post("/api/setup/aca/deploy", self.aca_deploy) - r.add_post("/api/setup/aca/destroy", self.aca_destroy) - r.add_post("/api/setup/container/restart", self.container_restart) + self._deployment_routes.register(r) # -- Status -- async def status(self, _req: web.Request) -> web.Response: - from .tunnel_status import resolve_tunnel_info + from ..tunnel_status import resolve_tunnel_info account = self._az.account_info() copilot = self._gh.status() @@ -126,69 +121,10 @@ async def status(self, _req: web.Request) -> web.Response: "data_dir": str(cfg.data_dir), }) - # -- Azure -- - - async def azure_login(self, _req: web.Request) -> web.Response: - account = self._az.account_info() - if account: - return web.json_response({ - "status": "already_logged_in", - "user": account.get("user", {}).get("name"), - "subscription": account.get("name"), - }) - info = self._az.login_device_code() - return web.json_response({"status": "device_code_pending", **info}) - - async def azure_check(self, _req: web.Request) -> web.Response: - account = self._az.account_info() - if account: - return web.json_response({ - "status": "logged_in", - "user": account.get("user", {}).get("name"), - "subscription": account.get("name"), - }) - return web.json_response({"status": "pending"}) - - async def azure_logout(self, _req: web.Request) -> web.Response: - ok, msg = self._az.ok("logout") - self._az.invalidate_cache("account", "show") - return _ok(msg) if ok else _error(msg) - - async def list_subscriptions(self, _req: web.Request) -> web.Response: - subs = self._az.json("account", "list") or [] - return web.json_response([ - { - "id": s.get("id", ""), - "name": s.get("name", ""), - "is_default": s.get("isDefault", False), - "state": s.get("state", ""), - } - for s in (subs if isinstance(subs, list) else []) - ]) - - async def set_subscription(self, req: web.Request) -> web.Response: - body = await req.json() - sub_id = body.get("subscription_id", "").strip() - if not sub_id: - return _error("subscription_id is required", 400) - ok, msg = self._az.ok("account", "set", "--subscription", sub_id) - self._az.invalidate_cache("account", "show") - return _ok(f"Subscription set to {sub_id}") if ok else _error(f"Failed: {msg}") - - async def list_resource_groups(self, _req: web.Request) -> web.Response: - groups = self._az.json("group", "list") or [] - return web.json_response([ - {"name": g["name"], "location": g["location"]} - for g in (groups if isinstance(groups, list) else []) - ]) - # -- Copilot -- async def copilot_status(self, _req: web.Request) -> web.Response: info = self._gh.status() - # Auto-persist: if gh CLI is authenticated but no GITHUB_TOKEN in - # .env yet, extract the token and write it so the runtime container - # picks it up from the shared volume. if info.get("authenticated") and not cfg.github_token: token = self._gh.extract_token() if token: @@ -213,14 +149,7 @@ async def copilot_set_token(self, req: web.Request) -> web.Response: return _ok("GitHub token saved") async def _restart_runtime(self) -> None: - """Signal the runtime container to reload configuration. - - In two-container mode the admin container calls the runtime's - ``/api/internal/reload`` endpoint so it picks up new settings - from the shared volume without a full container restart. - - In combined mode this is a no-op (changes are already in-process). - """ + """Signal the runtime container to reload configuration.""" runtime_url = os.getenv("RUNTIME_URL", "") if not runtime_url or cfg.server_mode == ServerMode.combined: return @@ -295,11 +224,6 @@ async def toggle_tunnel_restriction(self, req: web.Request) -> web.Response: state = "enabled" if restricted else "disabled" logger.info("Tunnel restriction %s", state) - # Detect whether a container redeploy is needed for the change to - # take effect (ACA / Docker deployments where the runtime container - # reads env vars at startup). - import os - deploy_mode = "local" if os.getenv("POLYCLAW_USE_MI"): deploy_mode = "aca" @@ -426,37 +350,6 @@ async def save_configuration(self, req: web.Request) -> web.Response: "message": "Configuration saved securely", }) - # -- Infrastructure -- - - async def infra_status(self, _req: web.Request) -> web.Response: - result = await run_sync(self._provisioner.status) - return web.json_response(result) - - async def infra_deploy(self, _req: web.Request) -> web.Response: - decomm_steps = await run_sync(self._provisioner.decommission) - prov_steps = await run_sync(self._provisioner.provision) - self._rebuild() - - all_steps = decomm_steps + prov_steps - prov_failed = any(s.get("status") == "failed" for s in prov_steps) - if not prov_failed: - await self._restart_runtime() - return web.json_response({ - "status": "error" if prov_failed else "ok", - "message": "Deploy completed with errors" if prov_failed else "Deployed", - "steps": all_steps, - }, status=500 if prov_failed else 200) - - async def infra_decommission(self, _req: web.Request) -> web.Response: - steps = await run_sync(self._provisioner.decommission) - self._rebuild() - failed = any(s.get("status") == "failed" for s in steps) - return web.json_response({ - "status": "error" if failed else "ok", - "message": "Errors during decommission" if failed else "Decommissioned", - "steps": steps, - }, status=500 if failed else 200) - # -- Runtime config -- async def get_config(self, _req: web.Request) -> web.Response: @@ -483,157 +376,3 @@ async def save_config(self, req: web.Request) -> web.Response: return _error(f"Disallowed config keys: {', '.join(sorted(invalid))}", 400) cfg.write_env(**body) return _ok("Config saved") - - # -- Lock Down Mode -- - - async def lockdown_status(self, _req: web.Request) -> web.Response: - return web.json_response({ - "lockdown_mode": cfg.lockdown_mode, - "tunnel_restricted": cfg.tunnel_restricted, - }) - - async def lockdown_toggle(self, req: web.Request) -> web.Response: - body = await req.json() - enabled = bool(body.get("enabled", False)) - - if enabled: - if cfg.lockdown_mode: - return _ok("Already enabled") - cfg.write_env(LOCKDOWN_MODE="1", TUNNEL_RESTRICTED="1") - try: - self._az.ok("logout") - self._az.invalidate_cache("account", "show") - except Exception: - pass - return web.json_response({ - "status": "ok", "lockdown_mode": True, - "message": "Lock Down Mode enabled.", - }) - else: - if not cfg.lockdown_mode: - return _ok("Already disabled") - cfg.write_env(LOCKDOWN_MODE="", TUNNEL_RESTRICTED="") - return web.json_response({ - "status": "ok", "lockdown_mode": False, - "message": "Lock Down Mode disabled.", - }) - - # -- Runtime Identity -- - - async def runtime_identity_status(self, _req: web.Request) -> web.Response: - return web.json_response(self._runtime_identity.status()) - - async def runtime_identity_provision(self, req: web.Request) -> web.Response: - body = await req.json() - rg = body.get("resource_group") or cfg.env.read("BOT_RESOURCE_GROUP") - if not rg: - return _error("resource_group is required (or set BOT_RESOURCE_GROUP)", 400) - result = await run_sync(self._runtime_identity.provision, rg) - if result.get("ok"): - await self._restart_runtime() - status_code = 200 if result.get("ok") else 500 - return web.json_response(result, status=status_code) - - async def runtime_identity_revoke(self, _req: web.Request) -> web.Response: - result = await run_sync(self._runtime_identity.revoke) - return web.json_response(result) - - # -- ACA Deployment -- - - async def aca_status(self, _req: web.Request) -> web.Response: - if not self._aca_deployer: - return _error("ACA deployer not available", 500) - return web.json_response(self._aca_deployer.status()) - - async def aca_deploy(self, req: web.Request) -> web.Response: - if not self._aca_deployer: - return _error("ACA deployer not available", 500) - body = await req.json() - aca_req = AcaDeployRequest( - resource_group=body.get("resource_group", self._store.bot.resource_group), - location=body.get("location", self._store.bot.location), - bot_display_name=body.get("display_name", self._store.bot.display_name), - bot_handle=body.get("bot_handle", self._store.bot.bot_handle), - admin_port=int(body.get("admin_port", 9090)), - runtime_port=int(body.get("runtime_port", 8080)), - image_tag=body.get("image_tag", "latest"), - acr_name=body.get("acr_name", ""), - env_name=body.get("env_name", ""), - ) - result = await run_sync(self._aca_deployer.deploy, aca_req) - status_code = 200 if result.ok else 500 - return web.json_response({ - "status": "ok" if result.ok else "error", - "message": "ACA deployment complete" if result.ok else result.error, - "steps": result.steps, - "runtime_fqdn": result.runtime_fqdn, - "deploy_id": result.deploy_id, - }, status=status_code) - - async def aca_destroy(self, req: web.Request) -> web.Response: - if not self._aca_deployer: - return _error("ACA deployer not available", 500) - body = await req.json() if req.can_read_body else {} - deploy_id = body.get("deploy_id") - result = await run_sync(self._aca_deployer.destroy, deploy_id) - return web.json_response({ - "status": "ok" if result.ok else "error", - "steps": result.steps, - }) - - async def container_restart(self, _req: web.Request) -> web.Response: - """Restart the agent container (Docker or ACA) to pick up config changes.""" - import subprocess - - deploy_mode = "local" - if os.getenv("POLYCLAW_USE_MI"): - deploy_mode = "aca" - elif os.getenv("POLYCLAW_CONTAINER") == "1": - deploy_mode = "docker" - - if deploy_mode == "aca": - if not self._aca_deployer: - return _error("ACA deployer not available", 500) - result = await run_sync(self._aca_deployer.restart) - status_code = 200 if result["ok"] else 500 - return web.json_response({ - "status": "ok" if result["ok"] else "error", - "message": "ACA containers restarted" if result["ok"] else "Some containers failed to restart", - "deploy_mode": "aca", - "results": result["results"], - }, status=status_code) - - if deploy_mode == "docker": - try: - proc = subprocess.run( - ["docker", "restart", "polyclaw-runtime"], - capture_output=True, text=True, timeout=60, - ) - ok = proc.returncode == 0 - return web.json_response({ - "status": "ok" if ok else "error", - "message": "Docker runtime container restarted" if ok else proc.stderr.strip(), - "deploy_mode": "docker", - }, status=200 if ok else 500) - except Exception as exc: - logger.warning( - "[setup.container_restart] docker restart failed: %s", - exc, exc_info=True, - ) - return _error(f"Docker restart failed: {exc}") - - # Local / combined mode -- reload config in-process - await self._restart_runtime() - return web.json_response({ - "status": "ok", - "message": "Configuration reloaded", - "deploy_mode": "local", - }) - - -def _ok(message: str) -> web.Response: - return web.json_response({"status": "ok", "message": message}) - - -def _error(message: str, status: int = 500) -> web.Response: - return web.json_response({"status": "error", "message": message}, status=status) diff --git a/app/runtime/server/setup/azure.py b/app/runtime/server/setup/azure.py new file mode 100644 index 0000000..12a1e6a --- /dev/null +++ b/app/runtime/server/setup/azure.py @@ -0,0 +1,81 @@ +"""Azure authentication and subscription routes -- /api/setup/azure/*.""" + +from __future__ import annotations + +import logging + +from aiohttp import web + +from ...services.cloud.azure import AzureCLI +from ._helpers import error_response as _error, ok_response as _ok + +logger = logging.getLogger(__name__) + + +class AzureSetupRoutes: + """Handles Azure CLI login, logout, subscription listing.""" + + def __init__(self, az: AzureCLI) -> None: + self._az = az + + def register(self, router: web.UrlDispatcher) -> None: + router.add_post("/api/setup/azure/login", self.azure_login) + router.add_get("/api/setup/azure/check", self.azure_check) + router.add_post("/api/setup/azure/logout", self.azure_logout) + router.add_get("/api/setup/azure/subscriptions", self.list_subscriptions) + router.add_post("/api/setup/azure/subscription", self.set_subscription) + router.add_get("/api/setup/azure/resource-groups", self.list_resource_groups) + + async def azure_login(self, _req: web.Request) -> web.Response: + account = self._az.account_info() + if account: + return web.json_response({ + "status": "already_logged_in", + "user": account.get("user", {}).get("name"), + "subscription": account.get("name"), + }) + info = self._az.login_device_code() + return web.json_response({"status": "device_code_pending", **info}) + + async def azure_check(self, _req: web.Request) -> web.Response: + account = self._az.account_info() + if account: + return web.json_response({ + "status": "logged_in", + "user": account.get("user", {}).get("name"), + "subscription": account.get("name"), + }) + return web.json_response({"status": "pending"}) + + async def azure_logout(self, _req: web.Request) -> web.Response: + ok, msg = self._az.ok("logout") + self._az.invalidate_cache("account", "show") + return _ok(msg) if ok else _error(msg) + + async def list_subscriptions(self, _req: web.Request) -> web.Response: + subs = self._az.json("account", "list") or [] + return web.json_response([ + { + "id": s.get("id", ""), + "name": s.get("name", ""), + "is_default": s.get("isDefault", False), + "state": s.get("state", ""), + } + for s in (subs if isinstance(subs, list) else []) + ]) + + async def set_subscription(self, req: web.Request) -> web.Response: + body = await req.json() + sub_id = body.get("subscription_id", "").strip() + if not sub_id: + return _error("subscription_id is required", 400) + ok, msg = self._az.ok("account", "set", "--subscription", sub_id) + self._az.invalidate_cache("account", "show") + return _ok(f"Subscription set to {sub_id}") if ok else _error(f"Failed: {msg}") + + async def list_resource_groups(self, _req: web.Request) -> web.Response: + groups = self._az.json("group", "list") or [] + return web.json_response([ + {"name": g["name"], "location": g["location"]} + for g in (groups if isinstance(groups, list) else []) + ]) diff --git a/app/runtime/server/setup/deploy.py b/app/runtime/server/setup/deploy.py new file mode 100644 index 0000000..9b7ccb5 --- /dev/null +++ b/app/runtime/server/setup/deploy.py @@ -0,0 +1,241 @@ +"""Deployment and infrastructure routes -- /api/setup/infra/*, /api/setup/aca/*.""" + +from __future__ import annotations + +import logging +import os +import subprocess +from collections.abc import Callable, Coroutine +from typing import Any + +from aiohttp import web + +from ...config.settings import cfg +from ...services.cloud.azure import AzureCLI +from ...services.cloud.runtime_identity import RuntimeIdentityProvisioner +from ...services.deployment.aca_deployer import AcaDeployer, AcaDeployRequest +from ...services.deployment.provisioner import Provisioner +from ...state.deploy_state import DeployStateStore +from ...state.infra_config import InfraConfigStore +from ...util.async_helpers import run_sync +from ._helpers import error_response as _error, ok_response as _ok + +logger = logging.getLogger(__name__) + + +class DeploymentRoutes: + """Handles infrastructure provisioning, ACA deployment, runtime identity, lockdown.""" + + def __init__( + self, + az: AzureCLI, + provisioner: Provisioner, + rebuild_adapter: Callable, + restart_runtime: Callable[[], Coroutine[Any, Any, None]], + infra_store: InfraConfigStore, + deploy_store: DeployStateStore | None = None, + aca_deployer: AcaDeployer | None = None, + ) -> None: + self._az = az + self._provisioner = provisioner + self._rebuild = rebuild_adapter + self._restart_runtime = restart_runtime + self._store = infra_store + self._deploy_store = deploy_store + self._aca_deployer = aca_deployer + self._runtime_identity = RuntimeIdentityProvisioner(az) + + def register(self, router: web.UrlDispatcher) -> None: + router.add_get("/api/setup/infra/status", self.infra_status) + router.add_post("/api/setup/infra/deploy", self.infra_deploy) + router.add_post("/api/setup/infra/decommission", self.infra_decommission) + router.add_get("/api/setup/lockdown", self.lockdown_status) + router.add_post("/api/setup/lockdown", self.lockdown_toggle) + router.add_get("/api/setup/runtime-identity", self.runtime_identity_status) + router.add_post("/api/setup/runtime-identity/provision", self.runtime_identity_provision) + router.add_post("/api/setup/runtime-identity/revoke", self.runtime_identity_revoke) + router.add_get("/api/setup/aca/status", self.aca_status) + router.add_post("/api/setup/aca/deploy", self.aca_deploy) + router.add_post("/api/setup/aca/destroy", self.aca_destroy) + router.add_post("/api/setup/container/restart", self.container_restart) + + # -- Infrastructure -- + + async def infra_status(self, _req: web.Request) -> web.Response: + result = await run_sync(self._provisioner.status) + return web.json_response(result) + + async def infra_deploy(self, _req: web.Request) -> web.Response: + decomm_steps = await run_sync(self._provisioner.decommission) + prov_steps = await run_sync(self._provisioner.provision) + self._rebuild() + + all_steps = decomm_steps + prov_steps + prov_failed = any(s.get("status") == "failed" for s in prov_steps) + if not prov_failed: + await self._restart_runtime() + return web.json_response({ + "status": "error" if prov_failed else "ok", + "message": "Deploy completed with errors" if prov_failed else "Deployed", + "steps": all_steps, + }, status=500 if prov_failed else 200) + + async def infra_decommission(self, _req: web.Request) -> web.Response: + steps = await run_sync(self._provisioner.decommission) + self._rebuild() + failed = any(s.get("status") == "failed" for s in steps) + return web.json_response({ + "status": "error" if failed else "ok", + "message": "Errors during decommission" if failed else "Decommissioned", + "steps": steps, + }, status=500 if failed else 200) + + # -- Lock Down Mode -- + + async def lockdown_status(self, _req: web.Request) -> web.Response: + return web.json_response({ + "lockdown_mode": cfg.lockdown_mode, + "tunnel_restricted": cfg.tunnel_restricted, + }) + + async def lockdown_toggle(self, req: web.Request) -> web.Response: + body = await req.json() + enabled = bool(body.get("enabled", False)) + + if enabled: + if cfg.lockdown_mode: + return _ok("Already enabled") + cfg.write_env(LOCKDOWN_MODE="1", TUNNEL_RESTRICTED="1") + try: + self._az.ok("logout") + self._az.invalidate_cache("account", "show") + except Exception: + pass + return web.json_response({ + "status": "ok", "lockdown_mode": True, + "message": "Lock Down Mode enabled.", + }) + else: + if not cfg.lockdown_mode: + return _ok("Already disabled") + cfg.write_env(LOCKDOWN_MODE="", TUNNEL_RESTRICTED="") + return web.json_response({ + "status": "ok", "lockdown_mode": False, + "message": "Lock Down Mode disabled.", + }) + + # -- Runtime Identity -- + + async def runtime_identity_status(self, _req: web.Request) -> web.Response: + return web.json_response(self._runtime_identity.status()) + + async def runtime_identity_provision(self, req: web.Request) -> web.Response: + body = await req.json() + rg = body.get("resource_group") or cfg.env.read("BOT_RESOURCE_GROUP") + if not rg: + return _error("resource_group is required (or set BOT_RESOURCE_GROUP)", 400) + result = await run_sync(self._runtime_identity.provision, rg) + if result.get("ok"): + await self._restart_runtime() + status_code = 200 if result.get("ok") else 500 + return web.json_response(result, status=status_code) + + async def runtime_identity_revoke(self, _req: web.Request) -> web.Response: + result = await run_sync(self._runtime_identity.revoke) + return web.json_response(result) + + # -- ACA Deployment -- + + async def aca_status(self, _req: web.Request) -> web.Response: + if not self._aca_deployer: + return _error("ACA deployer not available", 500) + return web.json_response(self._aca_deployer.status()) + + async def aca_deploy(self, req: web.Request) -> web.Response: + if not self._aca_deployer: + return _error("ACA deployer not available", 500) + body = await req.json() + aca_req = AcaDeployRequest( + resource_group=body.get("resource_group", self._store.bot.resource_group), + location=body.get("location", self._store.bot.location), + bot_display_name=body.get("display_name", self._store.bot.display_name), + bot_handle=body.get("bot_handle", self._store.bot.bot_handle), + admin_port=int(body.get("admin_port", 9090)), + runtime_port=int(body.get("runtime_port", 8080)), + image_tag=body.get("image_tag", "latest"), + acr_name=body.get("acr_name", ""), + env_name=body.get("env_name", ""), + ) + result = await run_sync(self._aca_deployer.deploy, aca_req) + status_code = 200 if result.ok else 500 + return web.json_response({ + "status": "ok" if result.ok else "error", + "message": "ACA deployment complete" if result.ok else result.error, + "steps": result.steps, + "runtime_fqdn": result.runtime_fqdn, + "deploy_id": result.deploy_id, + }, status=status_code) + + async def aca_destroy(self, req: web.Request) -> web.Response: + if not self._aca_deployer: + return _error("ACA deployer not available", 500) + body = await req.json() if req.can_read_body else {} + deploy_id = body.get("deploy_id") + result = await run_sync(self._aca_deployer.destroy, deploy_id) + return web.json_response({ + "status": "ok" if result.ok else "error", + "steps": result.steps, + }) + + async def container_restart(self, _req: web.Request) -> web.Response: + """Restart the agent container (Docker or ACA) to pick up config changes.""" + deploy_mode = "local" + if os.getenv("POLYCLAW_USE_MI"): + deploy_mode = "aca" + elif os.getenv("POLYCLAW_CONTAINER") == "1": + deploy_mode = "docker" + + if deploy_mode == "aca": + if not self._aca_deployer: + return _error("ACA deployer not available", 500) + result = await run_sync(self._aca_deployer.restart) + status_code = 200 if result["ok"] else 500 + return web.json_response({ + "status": "ok" if result["ok"] else "error", + "message": ( + "ACA containers restarted" if result["ok"] + else "Some containers failed to restart" + ), + "deploy_mode": "aca", + "results": result["results"], + }, status=status_code) + + if deploy_mode == "docker": + try: + proc = subprocess.run( + ["docker", "restart", "polyclaw-runtime"], + capture_output=True, text=True, timeout=60, + ) + ok = proc.returncode == 0 + return web.json_response({ + "status": "ok" if ok else "error", + "message": ( + "Docker runtime container restarted" if ok + else proc.stderr.strip() + ), + "deploy_mode": "docker", + }, status=200 if ok else 500) + except Exception as exc: + logger.warning( + "[setup.container_restart] docker restart failed: %s", + exc, exc_info=True, + ) + return _error(f"Docker restart failed: {exc}") + + # Local / combined mode -- reload config in-process + await self._restart_runtime() + return web.json_response({ + "status": "ok", + "message": "Configuration reloaded", + "deploy_mode": "local", + }) diff --git a/app/runtime/server/setup_preflight.py b/app/runtime/server/setup/preflight.py similarity index 98% rename from app/runtime/server/setup_preflight.py rename to app/runtime/server/setup/preflight.py index d1489c5..fcd9830 100644 --- a/app/runtime/server/setup_preflight.py +++ b/app/runtime/server/setup/preflight.py @@ -8,9 +8,9 @@ import aiohttp from aiohttp import web -from ..config.settings import cfg -from ..services.azure import AzureCLI -from ..state.infra_config import InfraConfigStore +from ...config.settings import cfg +from ...services.cloud.azure import AzureCLI +from ...state.infra_config import InfraConfigStore logger = logging.getLogger(__name__) @@ -54,7 +54,7 @@ async def _preflight(self, req: web.Request) -> web.Response: jwt_ok, jwt_detail = await self._check_jwt_validation() checks.append({"check": "jwt_validation", "ok": jwt_ok, "detail": jwt_detail}) - from .tunnel_status import resolve_tunnel_info + from ..tunnel_status import resolve_tunnel_info tunnel_info = await resolve_tunnel_info(self._tunnel, self._az) tunnel_ok = tunnel_info["active"] @@ -252,7 +252,7 @@ async def _check_acs_callback_security( }) if voice_configured or voice_routes_active: - from .tunnel_status import resolve_tunnel_info + from ..tunnel_status import resolve_tunnel_info t_info = await resolve_tunnel_info(self._tunnel, self._az) tunnel_active = t_info["active"] diff --git a/app/runtime/server/setup_prerequisites.py b/app/runtime/server/setup/prerequisites.py similarity index 97% rename from app/runtime/server/setup_prerequisites.py rename to app/runtime/server/setup/prerequisites.py index 86b53e2..5b2a1fd 100644 --- a/app/runtime/server/setup_prerequisites.py +++ b/app/runtime/server/setup/prerequisites.py @@ -9,13 +9,13 @@ from aiohttp import web -from ..config.settings import SECRET_ENV_KEYS, cfg -from ..services.azure import AzureCLI -from ..services.keyvault import env_key_to_secret_name, is_kv_ref -from ..services.keyvault import kv as _kv -from ..state.deploy_state import DeployStateStore -from ..state.infra_config import InfraConfigStore -from ..util.async_helpers import run_sync +from ...config.settings import SECRET_ENV_KEYS, cfg +from ...services.cloud.azure import AzureCLI +from ...services.keyvault import env_key_to_secret_name, is_kv_ref +from ...services.keyvault import kv as _kv +from ...state.deploy_state import DeployStateStore +from ...state.infra_config import InfraConfigStore +from ...util.async_helpers import run_sync logger = logging.getLogger(__name__) diff --git a/app/runtime/server/setup/voice.py b/app/runtime/server/setup/voice.py new file mode 100644 index 0000000..8d8be63 --- /dev/null +++ b/app/runtime/server/setup/voice.py @@ -0,0 +1,470 @@ +"""Voice setup routes -- ``/api/setup/voice/*``.""" + +from __future__ import annotations + +import logging + +from aiohttp import web + +from ...config.settings import cfg +from ...services.cloud.azure import AzureCLI +from ...state.infra_config import InfraConfigStore +from ...util.async_helpers import run_sync +from .voice_provision import ( + create_acs, + create_aoai, + ensure_rbac, + ensure_rg, + persist_config, +) +from ._helpers import error_response as _error, ok_response as _ok + +logger = logging.getLogger(__name__) + + +class VoiceSetupRoutes: + """ACS + Azure OpenAI provisioning, phone config, and decommissioning.""" + + def __init__(self, az: AzureCLI, store: InfraConfigStore) -> None: + self._az = az + self._store = store + + def register(self, router: web.UrlDispatcher) -> None: + router.add_get("/api/setup/voice/config", self.get_config) + router.add_post("/api/setup/voice/deploy", self.deploy) + router.add_post("/api/setup/voice/connect", self.connect_existing) + router.add_post("/api/setup/voice/phone", self.save_phone) + router.add_post("/api/setup/voice/decommission", self.decommission) + router.add_get("/api/setup/voice/aoai/list", self.list_aoai) + router.add_get("/api/setup/voice/aoai/deployments", self.list_aoai_deployments) + router.add_post("/api/setup/voice/aoai/validate", self.validate_aoai) + router.add_get("/api/setup/voice/acs/list", self.list_acs) + router.add_get("/api/setup/voice/acs/phones", self.list_acs_phones) + + # ------------------------------------------------------------------ + # Config + # ------------------------------------------------------------------ + + async def get_config(self, _req: web.Request) -> web.Response: + vc = self._store.to_safe_dict().get("channels", {}).get("voice_call", {}) + if vc.get("acs_resource_name"): + rg = vc.get("voice_resource_group") or vc.get("resource_group") + if rg: + account = self._az.account_info() + sub_id = account.get("id", "") if account else "" + if sub_id: + vc["portal_phone_url"] = ( + f"https://portal.azure.com/#@/resource/subscriptions/{sub_id}" + f"/resourceGroups/{rg}" + f"/providers/Microsoft.Communication" + f"/CommunicationServices/{vc['acs_resource_name']}" + f"/phonenumbers" + ) + return web.json_response(vc) + + # ------------------------------------------------------------------ + # Deploy + # ------------------------------------------------------------------ + + async def deploy(self, req: web.Request) -> web.Response: + body = await req.json() + location = body.get("location", "swedencentral").strip() + voice_rg = body.get("voice_resource_group", "").strip() or "polyclaw-voice-rg" + logger.info("Voice deploy started: voice_rg=%s, location=%s", voice_rg, location) + + steps: list[dict] = [] + + if not await ensure_rg(self._az, voice_rg, location, steps): + return _voice_fail(steps) + + acs_name, conn_str = await create_acs(self._az, voice_rg, steps) + if not conn_str: + return _voice_fail(steps) + + aoai_name, aoai_endpoint, aoai_key, deployment_name = await create_aoai( + self._az, voice_rg, location, steps + ) + if not aoai_endpoint: + return _voice_fail(steps) + + if not aoai_key: + await ensure_rbac(self._az, aoai_name, voice_rg, steps) + + persist_config( + self._store, voice_rg, location, acs_name, conn_str, + aoai_name, aoai_endpoint, aoai_key, deployment_name, steps, + ) + logger.info("Voice deploy completed: acs=%s, aoai=%s", acs_name, aoai_name) + + reinit = req.app.get("_reinit_voice") + if reinit: + reinit() + + return web.json_response({ + "status": "ok", + "steps": steps, + "message": ( + "Voice infrastructure deployed." + " Now purchase a phone number in the Azure Portal." + ), + }) + + # ------------------------------------------------------------------ + # Phone + # ------------------------------------------------------------------ + + async def save_phone(self, req: web.Request) -> web.Response: + body = await req.json() + phone = body.get("phone_number", "").strip() + target = body.get("target_number", "").strip() + + updates: dict[str, str] = {} + env_updates: dict[str, str] = {} + + if phone: + if not phone.startswith("+"): + return _error("Source phone number must be in E.164 format (e.g. +14155551234)", 400) + updates["acs_source_number"] = phone + env_updates["ACS_SOURCE_NUMBER"] = phone + + if target: + if not target.startswith("+"): + return _error("Target phone number must be in E.164 format (e.g. +41781234567)", 400) + updates["voice_target_number"] = target + env_updates["VOICE_TARGET_NUMBER"] = target + + if not updates: + return _error("At least one phone number is required", 400) + + self._store.save_voice_call(**updates) + cfg.write_env(**env_updates) + + reinit = req.app.get("_reinit_voice") + if reinit: + reinit() + + return _ok("Phone number(s) saved") + + # ------------------------------------------------------------------ + # Decommission + # ------------------------------------------------------------------ + + async def decommission(self, req: web.Request) -> web.Response: + vc = self._store.channels.voice_call + voice_rg = vc.voice_resource_group or vc.resource_group + steps: list[dict] = [] + + if voice_rg: + rg_exists = await run_sync(self._az.json, "group", "show", "--name", voice_rg) + if rg_exists: + ok, msg = await run_sync( + self._az.ok, "group", "delete", "--name", voice_rg, "--yes", "--no-wait", + ) + steps.append({ + "step": "voice_rg_delete", + "status": "ok" if ok else "failed", + "name": voice_rg, + "detail": f"Deleting {voice_rg}" if ok else msg, + }) + else: + steps.append({"step": "voice_rg_delete", "status": "skip", "detail": "RG not found"}) + else: + rg = vc.resource_group + if vc.acs_resource_name and rg: + ok, _ = await run_sync( + self._az.ok, "communication", "delete", + "--name", vc.acs_resource_name, "--resource-group", rg, "--yes", + ) + steps.append({ + "step": "acs_resource", + "status": "ok" if ok else "failed", + "name": vc.acs_resource_name, + }) + + if vc.azure_openai_resource_name and rg: + ok, _ = await run_sync( + self._az.ok, "cognitiveservices", "account", "delete", + "--name", vc.azure_openai_resource_name, "--resource-group", rg, "--yes", + ) + steps.append({ + "step": "aoai_resource", + "status": "ok" if ok else "failed", + "name": vc.azure_openai_resource_name, + }) + + self._store.clear_voice_call() + cfg.write_env( + ACS_CONNECTION_STRING="", + ACS_SOURCE_NUMBER="", + VOICE_TARGET_NUMBER="", + AZURE_OPENAI_ENDPOINT="", + AZURE_OPENAI_API_KEY="", + AZURE_OPENAI_REALTIME_DEPLOYMENT="", + ACS_CALLBACK_TOKEN="", + ) + + return web.json_response({ + "status": "ok", + "steps": steps, + "message": "Voice infrastructure decommissioned", + }) + + # ------------------------------------------------------------------ + # Discovery: AOAI + # ------------------------------------------------------------------ + + async def list_aoai(self, _req: web.Request) -> web.Response: + resources = await run_sync( + self._az.json, "resource", "list", + "--resource-type", "Microsoft.CognitiveServices/accounts", + ) + if not isinstance(resources, list): + return web.json_response([]) + + return web.json_response([ + { + "name": r.get("name", ""), + "resource_group": r.get("resourceGroup", ""), + "location": r.get("location", ""), + } + for r in resources + if r.get("kind") == "OpenAI" + ]) + + async def list_aoai_deployments(self, req: web.Request) -> web.Response: + name = req.query.get("name", "").strip() + rg = req.query.get("resource_group", "").strip() + if not name or not rg: + return _error("name and resource_group are required", 400) + + deployments = await run_sync( + self._az.json, "cognitiveservices", "account", "deployment", "list", + "--name", name, "--resource-group", rg, + ) + if not isinstance(deployments, list): + return web.json_response([]) + + return web.json_response([ + { + "deployment_name": d.get("name", ""), + "model_name": d.get("properties", {}).get("model", {}).get("name", ""), + "model_version": d.get("properties", {}).get("model", {}).get("version", ""), + "model_format": d.get("properties", {}).get("model", {}).get("format", ""), + } + for d in deployments + ]) + + async def validate_aoai(self, req: web.Request) -> web.Response: + body = await req.json() + name = body.get("name", "").strip() + rg = body.get("resource_group", "").strip() + if not name or not rg: + return _error("name and resource_group are required", 400) + + deployments = await run_sync( + self._az.json, "cognitiveservices", "account", "deployment", "list", + "--name", name, "--resource-group", rg, + ) + if not isinstance(deployments, list): + return web.json_response({ + "valid": False, + "message": f"Cannot list deployments for {name}", + "deployments": [], + }) + + realtime_models = { + "gpt-4o-realtime-preview", + "gpt-realtime-mini", + "gpt-4o-mini-realtime-preview", + } + found = [] + for d in deployments: + model = d.get("properties", {}).get("model", {}) + model_name = model.get("name", "") + found.append({ + "deployment_name": d.get("name", ""), + "model_name": model_name, + "model_version": model.get("version", ""), + "is_realtime": model_name in realtime_models, + }) + + has_realtime = any(f["is_realtime"] for f in found) + return web.json_response({ + "valid": has_realtime, + "message": ( + "Realtime model deployment found" + if has_realtime + else "No realtime model deployment found. Deploy gpt-realtime-mini or gpt-4o-realtime-preview." + ), + "deployments": found, + }) + + # ------------------------------------------------------------------ + # Discovery: ACS + # ------------------------------------------------------------------ + + async def list_acs(self, _req: web.Request) -> web.Response: + resources = await run_sync(self._az.json, "communication", "list") + if not isinstance(resources, list): + return web.json_response([]) + + return web.json_response([ + { + "name": r.get("name", ""), + "resource_group": r.get("resourceGroup", ""), + "location": r.get("location", ""), + } + for r in resources + ]) + + async def list_acs_phones(self, req: web.Request) -> web.Response: + name = req.query.get("name", "").strip() + rg = req.query.get("resource_group", "").strip() + if not name or not rg: + return _error("name and resource_group are required", 400) + + keys = await run_sync( + self._az.json, "communication", "list-key", + "--name", name, "--resource-group", rg, + ) + conn_str = keys.get("primaryConnectionString", "") if isinstance(keys, dict) else "" + if not conn_str: + return web.json_response([]) + + phones = await run_sync( + self._az.json, "communication", "phonenumber", "list", + "--connection-string", conn_str, + ) + if not isinstance(phones, list): + return web.json_response([]) + + return web.json_response([ + {"phone_number": p.get("phoneNumber", "")} + for p in phones + if p.get("phoneNumber") + ]) + + # ------------------------------------------------------------------ + # Connect existing + # ------------------------------------------------------------------ + + async def connect_existing(self, req: web.Request) -> web.Response: + body = await req.json() + steps: list[dict] = [] + + aoai_name = body.get("aoai_name", "").strip() + aoai_rg = body.get("aoai_resource_group", "").strip() + aoai_deployment = body.get("aoai_deployment", "").strip() or "gpt-realtime-mini" + + if not aoai_name or not aoai_rg: + return _error("aoai_name and aoai_resource_group are required", 400) + + aoai_info = await run_sync( + self._az.json, "cognitiveservices", "account", "show", + "--name", aoai_name, "--resource-group", aoai_rg, + ) + if not isinstance(aoai_info, dict): + return _error(f"Azure OpenAI resource '{aoai_name}' not found in RG '{aoai_rg}'", 404) + + aoai_endpoint = aoai_info.get("properties", {}).get("endpoint", "") + steps.append({"step": "aoai_resource", "status": "ok", "name": f"{aoai_name} (existing)"}) + + deployments = await run_sync( + self._az.json, "cognitiveservices", "account", "deployment", "list", + "--name", aoai_name, "--resource-group", aoai_rg, + ) + dep_found = isinstance(deployments, list) and any( + d.get("name") == aoai_deployment for d in deployments + ) + if not dep_found: + steps.append({ + "step": "aoai_deployment", "status": "failed", + "name": aoai_deployment, + "detail": f"Deployment '{aoai_deployment}' not found on {aoai_name}", + }) + return _voice_fail(steps) + + steps.append({"step": "aoai_deployment", "status": "ok", "name": f"{aoai_deployment} (verified)"}) + + aoai_keys = await run_sync( + self._az.json, "cognitiveservices", "account", "keys", "list", + "--name", aoai_name, "--resource-group", aoai_rg, + ) + aoai_key = aoai_keys.get("key1", "") if isinstance(aoai_keys, dict) else "" + if aoai_key: + steps.append({"step": "aoai_keys", "status": "ok"}) + else: + logger.info("AOAI key retrieval skipped (disableLocalAuth likely true)") + steps.append({ + "step": "aoai_keys", "status": "ok", + "detail": "Key-based auth disabled; will use Entra ID (DefaultAzureCredential)", + }) + + acs_name = body.get("acs_name", "").strip() + acs_rg = body.get("acs_resource_group", "").strip() + conn_str = "" + voice_rg = aoai_rg + + if acs_name and acs_rg: + keys = await run_sync( + self._az.json, "communication", "list-key", + "--name", acs_name, "--resource-group", acs_rg, + ) + conn_str = keys.get("primaryConnectionString", "") if isinstance(keys, dict) else "" + if not conn_str: + steps.append({ + "step": "acs_resource", "status": "failed", + "name": acs_name, "detail": "Cannot retrieve connection string", + }) + return _voice_fail(steps) + steps.append({"step": "acs_resource", "status": "ok", "name": f"{acs_name} (existing)"}) + voice_rg = acs_rg + else: + voice_rg = aoai_rg + if not await ensure_rg(self._az, voice_rg, "Global", steps): + return _voice_fail(steps) + acs_name, conn_str = await create_acs(self._az, voice_rg, steps) + if not conn_str: + return _voice_fail(steps) + + location = aoai_info.get("location", "swedencentral") + + if not aoai_key: + await ensure_rbac(self._az, aoai_name, aoai_rg, steps) + + persist_config( + self._store, voice_rg, location, acs_name, conn_str, + aoai_name, aoai_endpoint, aoai_key, aoai_deployment, steps, + ) + + phone = body.get("phone_number", "").strip() + if phone: + self._store.save_voice_call(acs_source_number=phone) + cfg.write_env(ACS_SOURCE_NUMBER=phone) + steps.append({"step": "phone_number", "status": "ok", "name": phone}) + + target = body.get("target_number", "").strip() + if target: + self._store.save_voice_call(voice_target_number=target) + cfg.write_env(VOICE_TARGET_NUMBER=target) + steps.append({"step": "target_number", "status": "ok", "name": target}) + + logger.info("Voice connect completed: acs=%s, aoai=%s", acs_name, aoai_name) + + reinit = req.app.get("_reinit_voice") + if reinit: + reinit() + + return web.json_response({ + "status": "ok", + "steps": steps, + "message": "Connected to existing Azure resources.", + }) + + +def _voice_fail(steps: list[dict]) -> web.Response: + failed = [s for s in steps if s.get("status") == "failed"] + msg = failed[0].get("name", "Unknown step") if failed else "Unknown error" + return web.json_response( + {"status": "error", "steps": steps, "message": f"Voice deploy failed at: {msg}"}, + ) diff --git a/app/runtime/server/setup/voice_provision.py b/app/runtime/server/setup/voice_provision.py new file mode 100644 index 0000000..f5e1a35 --- /dev/null +++ b/app/runtime/server/setup/voice_provision.py @@ -0,0 +1,237 @@ +"""Voice infrastructure provisioning helpers. + +Standalone functions extracted from ``VoiceSetupRoutes`` for ACS + AOAI +resource creation, RBAC assignment, and configuration persistence. +""" + +from __future__ import annotations + +import functools +import logging +import secrets + +from ...config.settings import cfg +from ...services.cloud.azure import AzureCLI +from ...state.infra_config import InfraConfigStore +from ...util.async_helpers import run_sync + +logger = logging.getLogger(__name__) + + +async def ensure_rbac( + az: AzureCLI, aoai_name: str, rg: str, steps: list[dict], +) -> None: + """Assign *Cognitive Services OpenAI User* role to the current principal.""" + account = az.account_info() + if not account: + steps.append({ + "step": "rbac_assign", "status": "skip", + "detail": "Cannot determine current principal (az account show failed)", + }) + return + + principal_id = "" + principal_type = "User" + + user_info = await run_sync( + functools.partial(az.json, "ad", "signed-in-user", "show", quiet=True), + ) + if isinstance(user_info, dict) and user_info.get("id"): + principal_id = user_info["id"] + else: + sp_id = account.get("user", {}).get("name", "") + if sp_id: + sp_info = await run_sync( + functools.partial(az.json, "ad", "sp", "show", "--id", sp_id, quiet=True), + ) + if isinstance(sp_info, dict) and sp_info.get("id"): + principal_id = sp_info["id"] + principal_type = "ServicePrincipal" + + if not principal_id: + steps.append({ + "step": "rbac_assign", "status": "skip", + "detail": "Cannot determine principal ID for RBAC assignment", + }) + return + + aoai_info = await run_sync( + az.json, "cognitiveservices", "account", "show", + "--name", aoai_name, "--resource-group", rg, + ) + scope = aoai_info.get("id", "") if isinstance(aoai_info, dict) else "" + if not scope: + steps.append({ + "step": "rbac_assign", "status": "skip", + "detail": "Cannot resolve resource ID for %s" % aoai_name, + }) + return + + role = "5e0bd9bd-7b93-4f28-af87-19fc36ad61bd" + logger.info("Assigning Cognitive Services OpenAI User role: principal=%s", principal_id) + ok, msg = await run_sync( + az.ok, "role", "assignment", "create", + "--assignee-object-id", principal_id, + "--assignee-principal-type", principal_type, + "--role", role, "--scope", scope, + ) + if ok: + steps.append({"step": "rbac_assign", "status": "ok", + "detail": "Cognitive Services OpenAI User"}) + elif "already exists" in (msg or "").lower() or "conflict" in (msg or "").lower(): + steps.append({"step": "rbac_assign", "status": "ok", "detail": "Already assigned"}) + else: + steps.append({ + "step": "rbac_assign", "status": "warning", + "detail": "Role assignment failed (non-fatal): %s" % msg, + }) + logger.warning("RBAC role assignment failed (non-fatal): %s", msg) + + +async def ensure_rg( + az: AzureCLI, rg: str, location: str, steps: list[dict], +) -> bool: + """Ensure a resource group exists, creating it if necessary.""" + existing = await run_sync(az.json, "group", "show", "--name", rg) + if existing: + steps.append({"step": "resource_group", "status": "ok", + "name": "%s (existing)" % rg}) + return True + + result = await run_sync( + az.json, "group", "create", "--name", rg, "--location", location, + ) + steps.append({"step": "resource_group", + "status": "ok" if result else "failed", "name": rg}) + if not result: + logger.error("Voice deploy FAILED at resource group creation: %s", az.last_stderr) + return bool(result) + + +async def create_acs( + az: AzureCLI, rg: str, steps: list[dict], +) -> tuple[str, str]: + """Create an ACS resource and retrieve its connection string.""" + acs_name = "polyclaw-acs-%s" % secrets.token_hex(4) + acs = await run_sync( + az.json, "communication", "create", + "--name", acs_name, "--location", "Global", + "--data-location", "United States", "--resource-group", rg, + ) + steps.append({"step": "acs_resource", + "status": "ok" if acs else "failed", "name": acs_name}) + if not acs: + logger.error("Voice deploy FAILED at ACS creation: %s", az.last_stderr) + return "", "" + + keys = await run_sync( + az.json, "communication", "list-key", + "--name", acs_name, "--resource-group", rg, + ) + conn_str = keys.get("primaryConnectionString", "") if isinstance(keys, dict) else "" + steps.append({"step": "acs_keys", "status": "ok" if conn_str else "failed"}) + if not conn_str: + logger.error("Voice deploy FAILED retrieving ACS keys: %s", az.last_stderr) + return acs_name, "" + return acs_name, conn_str + + +async def create_aoai( + az: AzureCLI, rg: str, location: str, steps: list[dict], +) -> tuple[str, str, str, str]: + """Create an Azure OpenAI resource with a realtime model deployment. + + Returns ``(name, endpoint, key, deployment_name)``. + """ + aoai_name = "polyclaw-aoai-%s" % secrets.token_hex(4) + deployment_name = "gpt-realtime-mini" + + aoai = await run_sync( + az.json, "cognitiveservices", "account", "create", + "--name", aoai_name, "--resource-group", rg, + "--location", location, "--kind", "OpenAI", + "--sku", "S0", "--custom-domain", aoai_name, + ) + steps.append({"step": "aoai_resource", + "status": "ok" if aoai else "failed", "name": aoai_name}) + if not aoai: + logger.error("Voice deploy FAILED at AOAI creation: %s", az.last_stderr) + return "", "", "", "" + + dep = await run_sync( + az.json, "cognitiveservices", "account", "deployment", "create", + "--name", aoai_name, "--resource-group", rg, + "--deployment-name", deployment_name, + "--model-name", "gpt-realtime-mini", + "--model-version", "2025-10-06", + "--model-format", "OpenAI", + "--sku-capacity", "1", "--sku-name", "GlobalStandard", + ) + steps.append({"step": "aoai_deployment", + "status": "ok" if dep else "failed", "name": deployment_name}) + if not dep: + logger.error("Voice deploy FAILED at model deployment: %s", az.last_stderr) + return aoai_name, "", "", "" + + aoai_info = await run_sync( + az.json, "cognitiveservices", "account", "show", + "--name", aoai_name, "--resource-group", rg, + ) + aoai_endpoint = "" + if isinstance(aoai_info, dict): + aoai_endpoint = aoai_info.get("properties", {}).get("endpoint", "") + + aoai_keys = await run_sync( + az.json, "cognitiveservices", "account", "keys", "list", + "--name", aoai_name, "--resource-group", rg, + ) + aoai_key = aoai_keys.get("key1", "") if isinstance(aoai_keys, dict) else "" + + if not aoai_endpoint: + steps.append({"step": "aoai_keys", "status": "failed"}) + logger.error("Voice deploy FAILED retrieving AOAI endpoint") + return aoai_name, "", "", "" + + if aoai_key: + steps.append({"step": "aoai_keys", "status": "ok"}) + else: + steps.append({"step": "aoai_keys", "status": "ok", + "detail": "Using Entra ID auth"}) + + return aoai_name, aoai_endpoint, aoai_key, deployment_name + + +def persist_config( + store: InfraConfigStore, + voice_rg: str, + location: str, + acs_name: str, + conn_str: str, + aoai_name: str, + aoai_endpoint: str, + aoai_key: str, + deployment_name: str, + steps: list[dict], +) -> None: + """Write voice configuration to the infra config store and ``.env``.""" + store.save_voice_call( + acs_resource_name=acs_name, + acs_connection_string=conn_str, + azure_openai_resource_name=aoai_name, + azure_openai_endpoint=aoai_endpoint, + azure_openai_api_key=aoai_key, + azure_openai_realtime_deployment=deployment_name, + resource_group=voice_rg, + voice_resource_group=voice_rg, + location=location, + ) + callback_token = cfg.acs_callback_token + cfg.write_env( + ACS_CONNECTION_STRING=conn_str, + ACS_SOURCE_NUMBER="", + AZURE_OPENAI_ENDPOINT=aoai_endpoint, + AZURE_OPENAI_API_KEY=aoai_key, + AZURE_OPENAI_REALTIME_DEPLOYMENT=deployment_name, + ACS_CALLBACK_TOKEN=callback_token, + ) + steps.append({"step": "persist_config", "status": "ok"}) diff --git a/app/runtime/server/setup_voice.py b/app/runtime/server/setup_voice.py index ac832a7..f416a7a 100644 --- a/app/runtime/server/setup_voice.py +++ b/app/runtime/server/setup_voice.py @@ -2,16 +2,21 @@ from __future__ import annotations -import functools import logging -import secrets from aiohttp import web from ..config.settings import cfg -from ..services.azure import AzureCLI +from ..services.cloud.azure import AzureCLI from ..state.infra_config import InfraConfigStore from ..util.async_helpers import run_sync +from .voice_provision import ( + create_acs, + create_aoai, + ensure_rbac, + ensure_rg, + persist_config, +) logger = logging.getLogger(__name__) @@ -68,24 +73,24 @@ async def deploy(self, req: web.Request) -> web.Response: steps: list[dict] = [] - if not await self._ensure_rg(voice_rg, location, steps): + if not await ensure_rg(self._az, voice_rg, location, steps): return _voice_fail(steps) - acs_name, conn_str = await self._create_acs(voice_rg, steps) + acs_name, conn_str = await create_acs(self._az, voice_rg, steps) if not conn_str: return _voice_fail(steps) - aoai_name, aoai_endpoint, aoai_key, deployment_name = await self._create_aoai( - voice_rg, location, steps + aoai_name, aoai_endpoint, aoai_key, deployment_name = await create_aoai( + self._az, voice_rg, location, steps ) if not aoai_endpoint: return _voice_fail(steps) if not aoai_key: - await self._ensure_rbac(aoai_name, voice_rg, steps) + await ensure_rbac(self._az, aoai_name, voice_rg, steps) - self._persist_config( - voice_rg, location, acs_name, conn_str, + persist_config( + self._store, voice_rg, location, acs_name, conn_str, aoai_name, aoai_endpoint, aoai_key, deployment_name, steps, ) logger.info("Voice deploy completed: acs=%s, aoai=%s", acs_name, aoai_name) @@ -415,19 +420,19 @@ async def connect_existing(self, req: web.Request) -> web.Response: voice_rg = acs_rg else: voice_rg = aoai_rg - if not await self._ensure_rg(voice_rg, "Global", steps): + if not await ensure_rg(self._az, voice_rg, "Global", steps): return _voice_fail(steps) - acs_name, conn_str = await self._create_acs(voice_rg, steps) + acs_name, conn_str = await create_acs(self._az, voice_rg, steps) if not conn_str: return _voice_fail(steps) location = aoai_info.get("location", "swedencentral") if not aoai_key: - await self._ensure_rbac(aoai_name, aoai_rg, steps) + await ensure_rbac(self._az, aoai_name, aoai_rg, steps) - self._persist_config( - voice_rg, location, acs_name, conn_str, + persist_config( + self._store, voice_rg, location, acs_name, conn_str, aoai_name, aoai_endpoint, aoai_key, aoai_deployment, steps, ) @@ -455,205 +460,6 @@ async def connect_existing(self, req: web.Request) -> web.Response: "message": "Connected to existing Azure resources.", }) - # ------------------------------------------------------------------ - # Internal helpers - # ------------------------------------------------------------------ - - async def _ensure_rbac( - self, aoai_name: str, rg: str, steps: list[dict], - ) -> None: - account = self._az.account_info() - if not account: - steps.append({ - "step": "rbac_assign", "status": "skip", - "detail": "Cannot determine current principal (az account show failed)", - }) - return - - principal_id = "" - principal_type = "User" - - user_info = await run_sync( - functools.partial(self._az.json, "ad", "signed-in-user", "show", quiet=True), - ) - if isinstance(user_info, dict) and user_info.get("id"): - principal_id = user_info["id"] - else: - sp_id = account.get("user", {}).get("name", "") - if sp_id: - sp_info = await run_sync( - functools.partial(self._az.json, "ad", "sp", "show", "--id", sp_id, quiet=True), - ) - if isinstance(sp_info, dict) and sp_info.get("id"): - principal_id = sp_info["id"] - principal_type = "ServicePrincipal" - - if not principal_id: - steps.append({ - "step": "rbac_assign", "status": "skip", - "detail": "Cannot determine principal ID for RBAC assignment", - }) - return - - aoai_info = await run_sync( - self._az.json, "cognitiveservices", "account", "show", - "--name", aoai_name, "--resource-group", rg, - ) - scope = aoai_info.get("id", "") if isinstance(aoai_info, dict) else "" - if not scope: - steps.append({ - "step": "rbac_assign", "status": "skip", - "detail": f"Cannot resolve resource ID for {aoai_name}", - }) - return - - role = "5e0bd9bd-7b93-4f28-af87-19fc36ad61bd" - logger.info("Assigning Cognitive Services OpenAI User role: principal=%s", principal_id) - ok, msg = await run_sync( - self._az.ok, "role", "assignment", "create", - "--assignee-object-id", principal_id, - "--assignee-principal-type", principal_type, - "--role", role, "--scope", scope, - ) - if ok: - steps.append({"step": "rbac_assign", "status": "ok", "detail": "Cognitive Services OpenAI User"}) - elif "already exists" in (msg or "").lower() or "conflict" in (msg or "").lower(): - steps.append({"step": "rbac_assign", "status": "ok", "detail": "Already assigned"}) - else: - steps.append({ - "step": "rbac_assign", "status": "warning", - "detail": f"Role assignment failed (non-fatal): {msg}", - }) - logger.warning("RBAC role assignment failed (non-fatal): %s", msg) - - async def _ensure_rg(self, rg: str, location: str, steps: list[dict]) -> bool: - existing = await run_sync(self._az.json, "group", "show", "--name", rg) - if existing: - steps.append({"step": "resource_group", "status": "ok", "name": f"{rg} (existing)"}) - return True - - result = await run_sync( - self._az.json, "group", "create", "--name", rg, "--location", location, - ) - steps.append({"step": "resource_group", "status": "ok" if result else "failed", "name": rg}) - if not result: - logger.error("Voice deploy FAILED at resource group creation: %s", self._az.last_stderr) - return bool(result) - - async def _create_acs(self, rg: str, steps: list[dict]) -> tuple[str, str]: - acs_name = f"polyclaw-acs-{secrets.token_hex(4)}" - acs = await run_sync( - self._az.json, "communication", "create", - "--name", acs_name, "--location", "Global", - "--data-location", "United States", "--resource-group", rg, - ) - steps.append({"step": "acs_resource", "status": "ok" if acs else "failed", "name": acs_name}) - if not acs: - logger.error("Voice deploy FAILED at ACS creation: %s", self._az.last_stderr) - return "", "" - - keys = await run_sync( - self._az.json, "communication", "list-key", - "--name", acs_name, "--resource-group", rg, - ) - conn_str = keys.get("primaryConnectionString", "") if isinstance(keys, dict) else "" - steps.append({"step": "acs_keys", "status": "ok" if conn_str else "failed"}) - if not conn_str: - logger.error("Voice deploy FAILED retrieving ACS keys: %s", self._az.last_stderr) - return acs_name, "" - return acs_name, conn_str - - async def _create_aoai( - self, rg: str, location: str, steps: list[dict], - ) -> tuple[str, str, str, str]: - aoai_name = f"polyclaw-aoai-{secrets.token_hex(4)}" - deployment_name = "gpt-realtime-mini" - - aoai = await run_sync( - self._az.json, "cognitiveservices", "account", "create", - "--name", aoai_name, "--resource-group", rg, - "--location", location, "--kind", "OpenAI", - "--sku", "S0", "--custom-domain", aoai_name, - ) - steps.append({"step": "aoai_resource", "status": "ok" if aoai else "failed", "name": aoai_name}) - if not aoai: - logger.error("Voice deploy FAILED at AOAI creation: %s", self._az.last_stderr) - return "", "", "", "" - - dep = await run_sync( - self._az.json, "cognitiveservices", "account", "deployment", "create", - "--name", aoai_name, "--resource-group", rg, - "--deployment-name", deployment_name, - "--model-name", "gpt-realtime-mini", - "--model-version", "2025-10-06", - "--model-format", "OpenAI", - "--sku-capacity", "1", "--sku-name", "GlobalStandard", - ) - steps.append({"step": "aoai_deployment", "status": "ok" if dep else "failed", "name": deployment_name}) - if not dep: - logger.error("Voice deploy FAILED at model deployment: %s", self._az.last_stderr) - return aoai_name, "", "", "" - - aoai_info = await run_sync( - self._az.json, "cognitiveservices", "account", "show", - "--name", aoai_name, "--resource-group", rg, - ) - aoai_endpoint = "" - if isinstance(aoai_info, dict): - aoai_endpoint = aoai_info.get("properties", {}).get("endpoint", "") - - aoai_keys = await run_sync( - self._az.json, "cognitiveservices", "account", "keys", "list", - "--name", aoai_name, "--resource-group", rg, - ) - aoai_key = aoai_keys.get("key1", "") if isinstance(aoai_keys, dict) else "" - - if not aoai_endpoint: - steps.append({"step": "aoai_keys", "status": "failed"}) - logger.error("Voice deploy FAILED retrieving AOAI endpoint") - return aoai_name, "", "", "" - - if aoai_key: - steps.append({"step": "aoai_keys", "status": "ok"}) - else: - steps.append({"step": "aoai_keys", "status": "ok", "detail": "Using Entra ID auth"}) - - return aoai_name, aoai_endpoint, aoai_key, deployment_name - - def _persist_config( - self, - voice_rg: str, - location: str, - acs_name: str, - conn_str: str, - aoai_name: str, - aoai_endpoint: str, - aoai_key: str, - deployment_name: str, - steps: list[dict], - ) -> None: - self._store.save_voice_call( - acs_resource_name=acs_name, - acs_connection_string=conn_str, - azure_openai_resource_name=aoai_name, - azure_openai_endpoint=aoai_endpoint, - azure_openai_api_key=aoai_key, - azure_openai_realtime_deployment=deployment_name, - resource_group=voice_rg, - voice_resource_group=voice_rg, - location=location, - ) - callback_token = cfg.acs_callback_token - cfg.write_env( - ACS_CONNECTION_STRING=conn_str, - ACS_SOURCE_NUMBER="", - AZURE_OPENAI_ENDPOINT=aoai_endpoint, - AZURE_OPENAI_API_KEY=aoai_key, - AZURE_OPENAI_REALTIME_DEPLOYMENT=deployment_name, - ACS_CALLBACK_TOKEN=callback_token, - ) - steps.append({"step": "persist_config", "status": "ok"}) - def _ok(message: str) -> web.Response: return web.json_response({"status": "ok", "message": message}) diff --git a/app/runtime/server/smoke_test.py b/app/runtime/server/smoke_test.py index d39bea1..92d8fc6 100644 --- a/app/runtime/server/smoke_test.py +++ b/app/runtime/server/smoke_test.py @@ -12,7 +12,7 @@ from ..agent.agent import Agent from ..config.settings import cfg -from ..services.github import GitHubAuth +from ..services.cloud.github import GitHubAuth logger = logging.getLogger(__name__) diff --git a/app/runtime/server/tunnel_status.py b/app/runtime/server/tunnel_status.py index 156370d..35571ac 100644 --- a/app/runtime/server/tunnel_status.py +++ b/app/runtime/server/tunnel_status.py @@ -4,12 +4,14 @@ import logging import time +from dataclasses import dataclass, field from typing import Any import aiohttp -from ..services.azure import AzureCLI +from ..services.cloud.azure import AzureCLI from ..util.async_helpers import run_sync +from ..util.singletons import register_singleton logger = logging.getLogger(__name__) @@ -20,11 +22,28 @@ # Cache the probe result for 15 s (health check is fast but frequent). _PROBE_CACHE_TTL = 15.0 -_cached_endpoint: str | None = None -_cached_endpoint_ts: float = 0.0 -_cached_probe: bool = False -_cached_probe_url: str | None = None -_cached_probe_ts: float = 0.0 + +@dataclass +class _TunnelCache: + """Mutable cache for tunnel endpoint and probe results.""" + + endpoint: str | None = None + endpoint_ts: float = 0.0 + probe: bool = False + probe_url: str | None = None + probe_ts: float = 0.0 + + +_cache = _TunnelCache() + + +def _reset_tunnel_cache() -> None: + """Reset tunnel cache to default values (for test isolation).""" + global _cache # noqa: PLW0603 + _cache = _TunnelCache() + + +register_singleton(_reset_tunnel_cache) async def resolve_tunnel_info( @@ -88,33 +107,29 @@ def _endpoint_to_tunnel_url(endpoint: str) -> str: async def _get_bot_endpoint_cached(az: AzureCLI) -> str | None: """Return the bot messaging endpoint, cached for ``_ENDPOINT_CACHE_TTL`` s.""" - global _cached_endpoint, _cached_endpoint_ts # noqa: PLW0603 - now = time.monotonic() - if _cached_endpoint is not None and (now - _cached_endpoint_ts) < _ENDPOINT_CACHE_TTL: - return _cached_endpoint + if _cache.endpoint is not None and (now - _cache.endpoint_ts) < _ENDPOINT_CACHE_TTL: + return _cache.endpoint endpoint = await run_sync(az.get_bot_endpoint) - _cached_endpoint = endpoint - _cached_endpoint_ts = now + _cache.endpoint = endpoint + _cache.endpoint_ts = now return endpoint async def _probe_tunnel_cached(url: str) -> bool: """Probe the tunnel with a short TTL cache to avoid hammering.""" - global _cached_probe, _cached_probe_url, _cached_probe_ts # noqa: PLW0603 - now = time.monotonic() if ( - _cached_probe_url == url - and (now - _cached_probe_ts) < _PROBE_CACHE_TTL + _cache.probe_url == url + and (now - _cache.probe_ts) < _PROBE_CACHE_TTL ): - return _cached_probe + return _cache.probe active = await _probe_tunnel(url) - _cached_probe = active - _cached_probe_url = url - _cached_probe_ts = now + _cache.probe = active + _cache.probe_url = url + _cache.probe_ts = now return active diff --git a/app/runtime/server/wiring.py b/app/runtime/server/wiring.py new file mode 100644 index 0000000..94d63d3 --- /dev/null +++ b/app/runtime/server/wiring.py @@ -0,0 +1,265 @@ +"""Service wiring -- initialises core components, stores, and external services.""" + +from __future__ import annotations + +import logging +from typing import Any + +from ..config.settings import ServerMode, cfg + +logger = logging.getLogger(__name__) + + +def create_adapter() -> object: + """Create a BotFrameworkAdapter with the current cfg credentials.""" + from botbuilder.core import BotFrameworkAdapter, BotFrameworkAdapterSettings, TurnContext + from botbuilder.schema import Activity, ActivityTypes + + settings = BotFrameworkAdapterSettings( + app_id=cfg.bot_app_id or None, + app_password=cfg.bot_app_password or None, + channel_auth_tenant=cfg.bot_app_tenant_id or None, + ) + adapter = BotFrameworkAdapter(settings) + + async def on_error(context: TurnContext, error: Exception) -> None: + logger.error("Bot turn error: %s", error, exc_info=True) + try: + activity = Activity(type=ActivityTypes.message, text="An error occurred.") + if (context.activity.channel_id or "").lower() == "telegram": + activity.text_format = "plain" + await context.send_activity(activity) + except Exception: + pass + + adapter.on_turn_error = on_error + return adapter + + +def _append_token(url: str, token: str) -> str: + sep = "&" if "?" in url else "?" + return f"{url}{sep}token={token}" + + +def create_voice_handler(agent: object, tunnel: object | None = None) -> object | None: + """Instantiate the ACS + Realtime voice handler, or ``None`` if not configured.""" + cfg.reload() + if not (cfg.acs_connection_string and cfg.acs_source_number and cfg.azure_openai_endpoint): + logger.info("Voice call not configured (ACS/AOAI settings missing)") + return None + + from azure.core.credentials import AzureKeyCredential as _AKC + + from ..realtime import AcsCaller, RealtimeMiddleTier, RealtimeRoutes + + def _resolve_acs_urls() -> tuple[str, str]: + token = cfg.acs_callback_token + cb_path = cfg.acs_callback_path + ws_path = cfg.acs_media_streaming_websocket_path + + logger.debug( + "_resolve_acs_urls: cb_path=%r, ws_path=%r, token=%s", + cb_path, ws_path, "set" if token else "empty", + ) + + # If both paths are already absolute URLs, use them directly + cb_is_absolute = cb_path.startswith("https://") + ws_is_absolute = ws_path.startswith("wss://") + if cb_is_absolute and ws_is_absolute: + resolved = _append_token(cb_path, token), _append_token(ws_path, token) + logger.info("ACS URLs (absolute): callback=%s, ws=%s", resolved[0], resolved[1]) + return resolved + + # Otherwise, resolve relative paths against the tunnel URL + tunnel_url = (getattr(tunnel, "url", None) or "").rstrip("/") + if tunnel_url: + cb = cb_path if cb_is_absolute else f"{tunnel_url}{cb_path or '/api/voice/acs-callback'}" + ws = ws_path if ws_is_absolute else ( + tunnel_url.replace("https://", "wss://").replace("http://", "ws://") + + (ws_path or "/api/voice/media-streaming") + ) + resolved = _append_token(cb, token), _append_token(ws, token) + logger.info("ACS URLs (tunnel): callback=%s, ws=%s", resolved[0], resolved[1]) + return resolved + logger.warning("ACS URLs fallback to localhost -- calls will fail") + return ( + cb_path or f"http://localhost:{cfg.admin_port}/api/voice/acs-callback", + ws_path or f"ws://localhost:{cfg.admin_port}/api/voice/media-streaming", + ) + + caller = AcsCaller( + source_number=cfg.acs_source_number, + acs_connection_string=cfg.acs_connection_string, + resolve_urls=_resolve_acs_urls, + resolve_source_number=lambda: cfg.acs_source_number, + ) + + realtime_credential: _AKC | object + if cfg.azure_openai_api_key: + realtime_credential = _AKC(cfg.azure_openai_api_key) + else: + from azure.identity import DefaultAzureCredential as _DAC + + realtime_credential = _DAC() + + rt_middleware = RealtimeMiddleTier( + endpoint=cfg.azure_openai_endpoint, + deployment=cfg.azure_openai_realtime_deployment, + credential=realtime_credential, + agent=agent, + ) + handler = RealtimeRoutes( + caller, + rt_middleware, + callback_token=cfg.acs_callback_token, + acs_resource_id=cfg.acs_resource_id, + ) + logger.info("Voice call (ACS + Realtime) enabled: source=%s", cfg.acs_source_number) + return handler + + +async def init_core(mode: ServerMode) -> dict[str, Any]: + """Initialise the agent, adapter, bot, and session store. + + Returns a dict of component references keyed by name. + """ + result: dict[str, Any] = { + "agent": None, + "adapter": None, + "conv_store": None, + "session_store": None, + "bot": None, + "bot_ep": None, + } + is_runtime = mode in (ServerMode.runtime, ServerMode.combined) + is_admin = mode in (ServerMode.admin, ServerMode.combined) + + if is_runtime: + from ..agent.agent import Agent + from ..messaging.bot import Bot + from ..messaging.proactive import ConversationReferenceStore + from ..state.session_store import SessionStore + from .bot_endpoint import BotEndpoint + + logger.info("[init_core] creating Agent ...") + agent = Agent() + logger.info("[init_core] starting Agent (Copilot CLI) ...") + await agent.start() + logger.info("[init_core] Agent started successfully") + + adapter = create_adapter() + conv_store = ConversationReferenceStore() + session_store = SessionStore() + + hitl = agent.hitl_interceptor + bot = Bot(agent, conv_store, hitl=hitl) + bot.session_store = session_store + bot.adapter = adapter + bot_ep = BotEndpoint(adapter, bot) + logger.info("[init_core] core initialization complete") + + result.update( + agent=agent, adapter=adapter, conv_store=conv_store, + session_store=session_store, bot=bot, bot_ep=bot_ep, + ) + + if is_admin and not is_runtime: + from ..state.session_store import SessionStore + + result["session_store"] = SessionStore() + logger.info("[init_core] admin-only initialization complete") + + return result + + +def init_services(mode: ServerMode) -> dict[str, Any]: + """Initialise state stores, cloud services, and background processors. + + Returns a dict of service/store references keyed by name. + """ + from ..state.deploy_state import DeployStateStore + from ..state.foundry_iq_config import FoundryIQConfigStore + from ..state.guardrails import GuardrailsConfigStore + from ..state.infra_config import InfraConfigStore + from ..state.mcp_config import McpConfigStore + from ..state.monitoring_config import MonitoringConfigStore + from ..state.sandbox_config import SandboxConfigStore + + is_admin = mode in (ServerMode.admin, ServerMode.combined) + is_runtime = mode in (ServerMode.runtime, ServerMode.combined) + + result: dict[str, Any] = { + "tunnel": None, + "deploy_store": DeployStateStore(), + "infra_store": InfraConfigStore(), + "mcp_store": McpConfigStore(), + "sandbox_store": SandboxConfigStore(), + "foundry_iq_store": FoundryIQConfigStore(), + "guardrails_store": GuardrailsConfigStore(), + "monitoring_store": MonitoringConfigStore(), + "az": None, + "gh": None, + "deployer": None, + "provisioner": None, + "aca_deployer": None, + "scheduler": None, + "proactive_store": None, + "sandbox_executor": None, + } + + if is_runtime: + from ..services.tunnel import CloudflareTunnel + + result["tunnel"] = CloudflareTunnel() + + # Admin-side services + if is_admin: + from ..services.cloud.azure import AzureCLI + from ..services.cloud.github import GitHubAuth + from ..services.deployment.aca_deployer import AcaDeployer + from ..services.deployment.deployer import BotDeployer + from ..services.deployment.provisioner import Provisioner + + az = AzureCLI() + deployer = BotDeployer(az, result["deploy_store"]) + result.update( + az=az, + gh=GitHubAuth(), + deployer=deployer, + provisioner=Provisioner( + az, deployer, + result["infra_store"], result["deploy_store"], + tunnel=result["tunnel"], + ), + aca_deployer=AcaDeployer(az, result["deploy_store"]), + ) + elif is_runtime: + from ..services.cloud.azure import AzureCLI + from ..services.deployment.deployer import BotDeployer + from ..services.deployment.provisioner import Provisioner + + az = AzureCLI() + deployer = BotDeployer(az, result["deploy_store"]) + result.update( + az=az, + deployer=deployer, + provisioner=Provisioner( + az, deployer, + result["infra_store"], result["deploy_store"], + tunnel=result["tunnel"], + ), + ) + + # Runtime-side services + if is_runtime: + from ..sandbox import SandboxExecutor + from ..scheduler import get_scheduler + from ..state.proactive import get_proactive_store + + result.update( + scheduler=get_scheduler(), + proactive_store=get_proactive_store(), + sandbox_executor=SandboxExecutor(result["sandbox_store"]), + ) + + return result diff --git a/app/runtime/services/__init__.py b/app/runtime/services/__init__.py index 6a959ae..6d54069 100644 --- a/app/runtime/services/__init__.py +++ b/app/runtime/services/__init__.py @@ -1,12 +1,38 @@ -"""External service integrations.""" - -__all__ = [ - "AzureCLI", - "BotDeployer", - "CloudflareTunnel", - "GitHubAuth", - "KeyVaultClient", - "MisconfigChecker", - "Provisioner", - "ResourceTracker", -] +"""External service integrations. + +Re-exports are lazy to avoid a circular import: the ``Settings`` singleton +imports ``services.keyvault`` during init, and the cloud / deployment / +security sub-packages import ``cfg`` at module level. Deferring those +imports via ``__getattr__`` means only ``keyvault`` is loaded while +``Settings()`` runs; the heavier sub-packages load on first access, by +which time ``cfg`` is assigned. +""" + +from __future__ import annotations + +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "AzureCLI": (".cloud", "AzureCLI"), + "GitHubAuth": (".cloud", "GitHubAuth"), + "AcaDeployer": (".deployment", "AcaDeployer"), + "BotDeployer": (".deployment", "BotDeployer"), + "Provisioner": (".deployment", "Provisioner"), + "MisconfigChecker": (".security", "MisconfigChecker"), + "PromptShieldService": (".security", "PromptShieldService"), + "SecurityPreflightChecker": (".security", "SecurityPreflightChecker"), + "CloudflareTunnel": (".tunnel", "CloudflareTunnel"), +} + + +def __getattr__(name: str) -> object: + if name in _LAZY_IMPORTS: + import importlib + + subpkg, attr = _LAZY_IMPORTS[name] + mod = importlib.import_module(subpkg, __name__) + val = getattr(mod, attr) + globals()[name] = val # cache for subsequent access + return val + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = list(_LAZY_IMPORTS) diff --git a/app/runtime/services/aca_deployer.py b/app/runtime/services/aca_deployer.py deleted file mode 100644 index 6238775..0000000 --- a/app/runtime/services/aca_deployer.py +++ /dev/null @@ -1,843 +0,0 @@ -"""Azure Container Apps deployer.""" - -from __future__ import annotations - -import logging -import os -import secrets -import subprocess -import time -from dataclasses import dataclass, field -from typing import Any - -from ..config.settings import cfg -from ..state.deploy_state import DeploymentRecord, DeployStateStore -from ..state.sandbox_config import SandboxConfigStore -from .azure import AzureCLI - -logger = logging.getLogger(__name__) - -_IMAGE_NAME = "polyclaw" -_LOCAL_IMAGE = "polyclaw:latest" -_MI_NAME = "polyclaw-runtime-mi" -_ENV_NAME_PREFIX = "polyclaw-env" -_BOT_CONTRIBUTOR_ROLE = "Azure Bot Service Contributor Role" -_RG_READER_ROLE = "Reader" -_SESSION_EXECUTOR_ROLE = "Azure ContainerApps Session Executor" - - -@dataclass -class AcaDeployRequest: - - resource_group: str = "polyclaw-rg" - location: str = "eastus" - bot_display_name: str = "polyclaw" - bot_handle: str = "" - admin_port: int = 9090 - runtime_port: int = 8080 - image_tag: str = "latest" - acr_name: str = "" - env_name: str = "" - - -@dataclass -class AcaDeployResult: - - ok: bool = False - steps: list[dict[str, Any]] = field(default_factory=list) - error: str = "" - runtime_fqdn: str = "" - acr_name: str = "" - deploy_id: str = "" - - -class AcaDeployer: - - def __init__(self, az: AzureCLI, deploy_store: DeployStateStore | None = None) -> None: - self._az = az - self._deploy_store = deploy_store - - def deploy(self, req: AcaDeployRequest) -> AcaDeployResult: - steps: list[dict[str, Any]] = [] - result = AcaDeployResult(steps=steps) - - logger.info("[aca] Starting ACA deployment: rg=%s, location=%s", req.resource_group, req.location) - - rec = DeploymentRecord.new(kind="aca") - result.deploy_id = rec.deploy_id - if self._deploy_store: - self._deploy_store.register(rec) - - try: - self._cleanup_stale_resources(req, steps) - - if not self._ensure_resource_group(req, steps, rec): - result.error = "Resource group creation failed" - return result - - env_vars = self._load_env_vars(steps) - - acr_name = self._ensure_acr(req, steps, rec) - if not acr_name: - result.error = "Container registry creation failed" - return result - result.acr_name = acr_name - - if not self._push_image(acr_name, req.image_tag, steps): - result.error = "Image push failed" - return result - - acr_user, acr_pass = self._get_acr_credentials(acr_name) - if not acr_user: - result.error = "Could not retrieve ACR admin credentials" - return result - - mi_id, mi_client_id = self._ensure_managed_identity(req, steps, rec) - if not mi_id: - result.error = "Managed identity creation failed" - return result - - self._assign_rbac(mi_client_id, req.resource_group, steps) - - env_name, env_id = self._ensure_aca_environment(req, steps, rec) - if not env_name: - result.error = "Container Apps environment creation failed" - return result - - runtime_fqdn = self._ensure_runtime_app( - req, env_id, acr_name, mi_id, mi_client_id, - acr_user, acr_pass, env_vars, steps, rec, - ) - if not runtime_fqdn: - result.error = "Runtime container app creation failed" - return result - result.runtime_fqdn = runtime_fqdn - - ip_steps = self._configure_ip_whitelist(req, steps) - steps.extend(ip_steps) - - runtime_url = f"https://{runtime_fqdn}" - cfg.write_env( - ACA_RUNTIME_FQDN=runtime_fqdn, - ACA_ACR_NAME=acr_name, - ACA_ENV_NAME=env_name, - ACA_MI_RESOURCE_ID=mi_id, - ACA_MI_CLIENT_ID=mi_client_id, - RUNTIME_URL=runtime_url, - ) - os.environ["RUNTIME_URL"] = runtime_url - logger.info("[aca] RUNTIME_URL set to %s", runtime_url) - steps.append({"step": "write_aca_config", "status": "ok"}) - - result.ok = True - logger.info("[aca] Deployment complete: runtime=%s", runtime_fqdn) - - except Exception as exc: - logger.error("[aca] Deployment failed: %s", exc, exc_info=True) - result.error = str(exc) - steps.append({"step": "unexpected_error", "status": "failed", "detail": str(exc)}) - - if self._deploy_store and rec: - if result.ok: - rec.config = { - "runtime_fqdn": result.runtime_fqdn, - "acr_name": result.acr_name, - } - else: - rec.mark_stopped() - self._deploy_store.update(rec) - - return result - - def destroy(self, deploy_id: str | None = None) -> AcaDeployResult: - steps: list[dict[str, Any]] = [] - result = AcaDeployResult(steps=steps) - - rec = None - if deploy_id and self._deploy_store: - rec = self._deploy_store.get(deploy_id) - elif self._deploy_store: - rec = self._deploy_store.current_aca() - - rg = ( - cfg.env.read("BOT_RESOURCE_GROUP") - or (rec.resource_groups[0] if rec and rec.resource_groups else "") - ) - - if rg: - cleaned = self._delete_aca_resources(rg, steps, step_label="destroy") - if cleaned: - logger.info("[aca] Destroyed %d resource(s): %s", - len(cleaned), ", ".join(cleaned)) - else: - logger.info("[aca] No ACA resources found to destroy in %s", rg) - - cfg.write_env( - ACA_RUNTIME_FQDN="", - ACA_ACR_NAME="", ACA_ENV_NAME="", - ACA_MI_RESOURCE_ID="", - ACA_MI_CLIENT_ID="", - RUNTIME_URL="", - ) - steps.append({"step": "clear_aca_config", "status": "ok"}) - - if rec and self._deploy_store: - rec.mark_destroyed() - self._deploy_store.update(rec) - - result.ok = True - return result - - def status(self) -> dict[str, Any]: - runtime_fqdn = cfg.env.read("ACA_RUNTIME_FQDN") - return { - "deployed": bool(runtime_fqdn), - "runtime_fqdn": runtime_fqdn or None, - "acr_name": cfg.env.read("ACA_ACR_NAME") or None, - "env_name": cfg.env.read("ACA_ENV_NAME") or None, - "mi_client_id": cfg.env.read("ACA_MI_CLIENT_ID") or None, - } - - def restart(self) -> dict[str, Any]: - rg = cfg.env.read("BOT_RESOURCE_GROUP") or "polyclaw-rg" - app_name = "polyclaw-runtime" - - revisions = self._az.json( - "containerapp", "revision", "list", - "--name", app_name, - "--resource-group", rg, - quiet=True, - ) - if not revisions or not isinstance(revisions, list): - ok, msg = self._az.ok( - "containerapp", "update", - "--name", app_name, - "--resource-group", rg, - "--set-env-vars", f"RESTART_TS={int(time.time())}", - ) - result_detail = { - "app": app_name, - "status": "ok" if ok else "failed", - "method": "update", - "detail": msg if not ok else "forced new revision", - } - logger.info("[aca.restart] result=%r", result_detail) - return {"ok": ok, "results": [result_detail]} - - active = next( - (r["name"] for r in revisions if r.get("properties", {}).get("active")), - revisions[0].get("name") if revisions else None, - ) - if not active: - result_detail = { - "app": app_name, "status": "failed", - "method": "revision_restart", - "detail": "no active revision found", - } - logger.info("[aca.restart] result=%r", result_detail) - return {"ok": False, "results": [result_detail]} - - ok, msg = self._az.ok( - "containerapp", "revision", "restart", - "--name", app_name, - "--resource-group", rg, - "--revision", active, - ) - result_detail = { - "app": app_name, - "status": "ok" if ok else "failed", - "method": "revision_restart", - "detail": active if ok else msg, - } - logger.info("[aca.restart] result=%r", result_detail) - return {"ok": ok, "results": [result_detail]} - - def _delete_aca_resources( - self, rg: str, steps: list[dict], *, step_label: str = "cleanup", - ) -> list[str]: - rg_exists = self._az.json("group", "show", "--name", rg, quiet=True) - if not isinstance(rg_exists, dict): - logger.info("[aca] Resource group %s does not exist -- nothing to clean", rg) - return [] - - cleaned: list[str] = [] - - apps = self._az.json( - "containerapp", "list", - "--resource-group", rg, quiet=True, - ) - for app in (apps if isinstance(apps, list) else []): - name = app.get("name", "") - if not name: - continue - logger.info("[aca] Deleting container app: %s (waiting)", name) - ok, _ = self._az.ok( - "containerapp", "delete", "--name", name, - "--resource-group", rg, "--yes", - ) - if ok: - cleaned.append(f"containerapp/{name}") - steps.append({"step": f"{step_label}/containerapp/{name}", - "status": "ok" if ok else "failed"}) - - identities = self._az.json( - "identity", "list", - "--resource-group", rg, quiet=True, - ) - for mi in (identities if isinstance(identities, list) else []): - name = mi.get("name", "") - if not name: - continue - logger.info("[aca] Deleting managed identity: %s (waiting)", name) - ok, _ = self._az.ok( - "identity", "delete", "--name", name, - "--resource-group", rg, - ) - if ok: - cleaned.append(f"identity/{name}") - steps.append({"step": f"{step_label}/identity/{name}", - "status": "ok" if ok else "failed"}) - - envs = self._az.json( - "containerapp", "env", "list", - "--resource-group", rg, quiet=True, - ) - for env in (envs if isinstance(envs, list) else []): - name = env.get("name", "") - if not name: - continue - logger.info("[aca] Deleting ACA environment: %s (no-wait)", name) - ok, _ = self._az.ok( - "containerapp", "env", "delete", "--name", name, - "--resource-group", rg, "--yes", "--no-wait", - ) - if ok: - cleaned.append(f"aca-env/{name}") - steps.append({"step": f"{step_label}/aca-env/{name}", - "status": "ok" if ok else "failed"}) - - acrs = self._az.json( - "acr", "list", - "--resource-group", rg, quiet=True, - ) - for acr in (acrs if isinstance(acrs, list) else []): - name = acr.get("name", "") - if not name: - continue - logger.info("[aca] Deleting ACR: %s", name) - ok, _ = self._az.ok( - "acr", "delete", "--name", name, - "--resource-group", rg, "--yes", - ) - if ok: - cleaned.append(f"acr/{name}") - steps.append({"step": f"{step_label}/acr/{name}", - "status": "ok" if ok else "failed"}) - - workspaces = self._az.json( - "monitor", "log-analytics", "workspace", "list", - "--resource-group", rg, quiet=True, - ) - for ws in (workspaces if isinstance(workspaces, list) else []): - name = ws.get("name", "") - if not name: - continue - logger.info("[aca] Deleting Log Analytics workspace: %s", name) - ok, _ = self._az.ok( - "monitor", "log-analytics", "workspace", "delete", - "--workspace-name", name, - "--resource-group", rg, "--yes", "--force", - ) - if ok: - cleaned.append(f"log-analytics/{name}") - steps.append({"step": f"{step_label}/log-analytics/{name}", - "status": "ok" if ok else "failed"}) - - storage_accounts = self._az.json( - "storage", "account", "list", - "--resource-group", rg, quiet=True, - ) - for sa in (storage_accounts if isinstance(storage_accounts, list) else []): - name = sa.get("name", "") - if not name: - continue - tags = sa.get("tags", {}) or {} - kind = sa.get("kind", "") - if "polyclaw_deploy" in tags or kind == "StorageV2": - logger.info("[aca] Deleting storage account: %s", name) - ok, _ = self._az.ok( - "storage", "account", "delete", "--name", name, - "--resource-group", rg, "--yes", - ) - if ok: - cleaned.append(f"storage/{name}") - steps.append({"step": f"{step_label}/storage/{name}", - "status": "ok" if ok else "failed"}) - - return cleaned - - def _cleanup_stale_resources( - self, req: AcaDeployRequest, steps: list[dict], - ) -> None: - logger.info("[aca] Pre-flight: cleaning all ACA resources in %s ...", req.resource_group) - cleaned = self._delete_aca_resources(req.resource_group, steps, step_label="cleanup") - if cleaned: - detail = ", ".join(cleaned) - logger.info("[aca] Cleaned %d resource(s): %s", len(cleaned), detail) - else: - logger.info("[aca] No resources to clean") - steps.append({"step": "cleanup", "status": "ok", "detail": "nothing to clean"}) - - def _ensure_resource_group( - self, req: AcaDeployRequest, steps: list[dict], rec: DeploymentRecord, - ) -> bool: - logger.info("[aca] Step 1/10: Ensuring resource group %s ...", req.resource_group) - tag_args = ["--tags", f"polyclaw_deploy={rec.tag}"] - result = self._az.json( - "group", "create", "--name", req.resource_group, - "--location", req.location, *tag_args, - ) - if result: - steps.append({"step": "resource_group", "status": "ok", "detail": req.resource_group}) - if req.resource_group not in rec.resource_groups: - rec.resource_groups.append(req.resource_group) - return True - steps.append({"step": "resource_group", "status": "failed", "detail": self._az.last_stderr}) - return False - - def _load_env_vars(self, steps: list[dict]) -> dict[str, str]: - from .keyvault import is_kv_ref, kv - - env_map = cfg.env.read_all() - _DEPLOYER_KEYS = frozenset({ - "ACA_RUNTIME_FQDN", "ACA_ACR_NAME", "ACA_ENV_NAME", - "ACA_STORAGE_ACCOUNT", "ACA_MI_RESOURCE_ID", "ACA_MI_CLIENT_ID", - "RUNTIME_URL", - }) - filtered = {k: v for k, v in env_map.items() if k not in _DEPLOYER_KEYS and v} - - resolved_count = 0 - for key, value in list(filtered.items()): - if is_kv_ref(value): - try: - plaintext = kv.resolve_value(value) - if plaintext: - filtered[key] = plaintext - resolved_count += 1 - logger.info("[aca] Resolved @kv: ref for %s", key) - else: - logger.warning( - "[aca] @kv: ref for %s resolved to empty -- removing", key, - ) - del filtered[key] - except Exception: - logger.error( - "[aca] Failed to resolve @kv: ref for %s -- removing", - key, exc_info=True, - ) - del filtered[key] - - count = len(filtered) - logger.info( - "[aca] Step 2/10: Loaded %d env var(s) from local .env " - "(%d @kv: references resolved)", - count, resolved_count, - ) - steps.append({"step": "load_env_vars", "status": "ok", - "detail": f"{count} variable(s), {resolved_count} @kv: resolved"}) - return filtered - - def _ensure_acr( - self, req: AcaDeployRequest, steps: list[dict], rec: DeploymentRecord, - ) -> str: - logger.info("[aca] Step 3/10: Creating container registry ...") - acr_name = "polyclaw" + secrets.token_hex(4) - acr_name = acr_name[:50].replace("-", "") - - result = self._az.json( - "acr", "create", - "--resource-group", req.resource_group, - "--name", acr_name, - "--sku", "Basic", - "--admin-enabled", "true", - "--location", req.location, - ) - if not result: - steps.append({ - "step": "acr_create", "status": "failed", - "detail": self._az.last_stderr, - }) - return "" - steps.append({"step": "acr_create", "status": "ok", "detail": acr_name}) - rec.add_resource("acr", req.resource_group, acr_name, "Container registry") - return acr_name - - def _get_acr_credentials(self, acr_name: str) -> tuple[str, str]: - creds = self._az.json("acr", "credential", "show", "--name", acr_name) - if not isinstance(creds, dict): - return "", "" - username = creds.get("username", "") - passwords = creds.get("passwords", []) - password = passwords[0].get("value", "") if passwords else "" - return username, password - - def _push_image( - self, acr_name: str, tag: str, steps: list[dict], - ) -> bool: - logger.info("[aca] Step 4/10: Pushing pre-built image to ACR ...") - local_image = f"{_IMAGE_NAME}:{tag}" - remote_image = f"{acr_name}.azurecr.io/{_IMAGE_NAME}:{tag}" - - check = subprocess.run( - ["docker", "image", "inspect", local_image], - capture_output=True, text=True, - ) - if check.returncode != 0: - detail = ( - f"Local image '{local_image}' not found. " - "Build it first with: docker build --platform linux/amd64 " - f"-t {local_image} ." - ) - logger.error("[aca] %s", detail) - steps.append({"step": "image_push", "status": "failed", "detail": detail}) - return False - - logger.info("[aca] Logging in to ACR %s ...", acr_name) - ok, msg = self._az.ok("acr", "login", "--name", acr_name) - if not ok: - detail = f"ACR login failed: {msg or self._az.last_stderr}" - logger.error("[aca] %s", detail) - steps.append({"step": "image_push", "status": "failed", "detail": detail}) - return False - - logger.info("[aca] Tagging %s -> %s", local_image, remote_image) - tag_result = subprocess.run( - ["docker", "tag", local_image, remote_image], - capture_output=True, text=True, - ) - if tag_result.returncode != 0: - detail = f"docker tag failed: {tag_result.stderr.strip()}" - logger.error("[aca] %s", detail) - steps.append({"step": "image_push", "status": "failed", "detail": detail}) - return False - - logger.info("[aca] Pushing %s (this may take 1-2 minutes) ...", remote_image) - push_result = subprocess.run( - ["docker", "push", remote_image], - capture_output=True, text=True, timeout=600, - ) - if push_result.returncode != 0: - detail = f"docker push failed: {push_result.stderr.strip()[:500]}" - logger.error("[aca] %s", detail) - steps.append({"step": "image_push", "status": "failed", "detail": detail}) - return False - - logger.info("[aca] Image pushed: %s", remote_image) - steps.append({"step": "image_push", "status": "ok", "detail": remote_image}) - return True - - def _ensure_managed_identity( - self, req: AcaDeployRequest, steps: list[dict], rec: DeploymentRecord, - ) -> tuple[str, str]: - logger.info("[aca] Step 5/10: Creating managed identity ...") - result = self._az.json( - "identity", "create", - "--name", _MI_NAME, - "--resource-group", req.resource_group, - "--location", req.location, - ) - if not isinstance(result, dict): - steps.append({"step": "managed_identity", "status": "failed", - "detail": self._az.last_stderr}) - return "", "" - - mi_id = result.get("id", "") - client_id = result.get("clientId", "") - steps.append({"step": "managed_identity", "status": "ok", "detail": _MI_NAME}) - rec.add_resource("managed_identity", req.resource_group, _MI_NAME, - "Runtime scoped identity") - return mi_id, client_id - - def _assign_rbac( - self, - mi_principal_id: str, - resource_group: str, - steps: list[dict], - ) -> None: - logger.info("[aca] Step 6/10: Assigning RBAC ...") - account = self._az.account_info() - sub_id = account.get("id", "") if account else "" - rg_scope = f"/subscriptions/{sub_id}/resourceGroups/{resource_group}" - - for role in (_BOT_CONTRIBUTOR_ROLE, _RG_READER_ROLE): - label = role.lower().replace(" ", "_") - assigned = False - for attempt in range(4): - if attempt: - delay = 10 * attempt - logger.info( - "[aca] RBAC retry %d/3 for %s in %ds ...", - attempt, label, delay, - ) - time.sleep(delay) - ok, _msg = self._az.ok( - "role", "assignment", "create", - "--assignee", mi_principal_id, - "--role", role, - "--scope", rg_scope, - ) - if ok or "already exists" in (self._az.last_stderr or "").lower(): - assigned = True - break - if assigned: - steps.append({"step": f"rbac_{label}", "status": "ok", - "detail": f"{role} on {resource_group}"}) - else: - steps.append({"step": f"rbac_{label}", "status": "failed", - "detail": self._az.last_stderr}) - - session_scope = self._session_pool_scope(sub_id) - if session_scope: - label = _SESSION_EXECUTOR_ROLE.lower().replace(" ", "_") - assigned = False - for attempt in range(4): - if attempt: - delay = 10 * attempt - logger.info( - "[aca] RBAC retry %d/3 for %s in %ds ...", - attempt, label, delay, - ) - time.sleep(delay) - ok, _msg = self._az.ok( - "role", "assignment", "create", - "--assignee", mi_principal_id, - "--role", _SESSION_EXECUTOR_ROLE, - "--scope", session_scope, - ) - if ok or "already exists" in (self._az.last_stderr or "").lower(): - assigned = True - break - if assigned: - steps.append({"step": f"rbac_{label}", "status": "ok", - "detail": f"{_SESSION_EXECUTOR_ROLE} on session pool"}) - else: - steps.append({"step": f"rbac_{label}", "status": "failed", - "detail": self._az.last_stderr}) - - def _session_pool_scope(self, subscription_id: str) -> str | None: - try: - store = SandboxConfigStore() - pool_id = store.pool_id - if pool_id: - return pool_id - rg = store.resource_group - name = store.pool_name - if rg and name: - return ( - f"/subscriptions/{subscription_id}/resourceGroups/{rg}" - f"/providers/Microsoft.App/sessionPools/{name}" - ) - except Exception as exc: - logger.debug("Could not resolve session pool scope: %s", exc) - return None - - def _ensure_aca_environment( - self, - req: AcaDeployRequest, - steps: list[dict], - rec: DeploymentRecord, - ) -> tuple[str, str]: - logger.info("[aca] Step 7/10: Creating ACA environment ...") - env_name = f"{_ENV_NAME_PREFIX}-{secrets.token_hex(4)}" - - result = self._az.json( - "containerapp", "env", "create", - "--name", env_name, - "--resource-group", req.resource_group, - "--location", req.location, - ) - if not isinstance(result, dict): - steps.append({ - "step": "aca_environment", "status": "failed", - "detail": self._az.last_stderr, - }) - return "", "" - - env_id = result.get("id", "") - steps.append({"step": "aca_environment", "status": "ok", "detail": env_name}) - rec.add_resource("aca_environment", req.resource_group, env_name, - "Container Apps environment") - return env_name, env_id - - def _ensure_runtime_app( - self, - req: AcaDeployRequest, - env_id: str, - acr_name: str, - mi_id: str, - mi_client_id: str, - acr_user: str, - acr_pass: str, - env_vars: dict[str, str], - steps: list[dict], - rec: DeploymentRecord, - ) -> str: - app_name = "polyclaw-runtime" - admin_secret = cfg.admin_secret or secrets.token_urlsafe(24) - image = f"{acr_name}.azurecr.io/{_IMAGE_NAME}:{req.image_tag}" - - logger.info("[aca] Step 8/10: Creating runtime container app ...") - - _SECRET_ENV_KEYS = frozenset({ - "RUNTIME_SP_PASSWORD", "ACS_CALLBACK_TOKEN", - "GITHUB_TOKEN", "BOT_APP_PASSWORD", - "ACS_CONNECTION_STRING", "AZURE_OPENAI_API_KEY", - }) - _SKIP = frozenset({ - "POLYCLAW_MODE", "POLYCLAW_DATA_DIR", "ADMIN_PORT", - "ADMIN_SECRET", "POLYCLAW_CONTAINER", "POLYCLAW_USE_MI", - "AZURE_CLIENT_ID", - }) | _SECRET_ENV_KEYS - aca_secrets: dict[str, str] = { - "admin-secret": admin_secret, - } - for env_key in _SECRET_ENV_KEYS: - secret_name = env_key.lower().replace("_", "-") - value = env_vars.get(env_key, "") - if value: - aca_secrets[secret_name] = value - - env_pairs = [ - "POLYCLAW_MODE=runtime", - f"ADMIN_PORT={req.runtime_port}", - "ADMIN_SECRET=secretref:admin-secret", - "POLYCLAW_CONTAINER=1", - "POLYCLAW_USE_MI=1", - f"AZURE_CLIENT_ID={mi_client_id}", - ] - for env_key in sorted(_SECRET_ENV_KEYS): - secret_name = env_key.lower().replace("_", "-") - if secret_name in aca_secrets: - env_pairs.append(f"{env_key}=secretref:{secret_name}") - - for key, value in sorted(env_vars.items()): - if key not in _SKIP and value: - env_pairs.append(f"{key}={value}") - - logger.info("[aca] Container env vars: %d total (%d via ACA secrets)", - len(env_pairs), len(aca_secrets)) - - secret_pairs = [f"{name}={value}" for name, value in sorted(aca_secrets.items())] - - create_args: list[str] = [ - "containerapp", "create", - "--name", app_name, - "--resource-group", req.resource_group, - "--environment", env_id, - "--image", image, - "--cpu", "2", "--memory", "4Gi", - "--min-replicas", "1", "--max-replicas", "1", - "--ingress", "external", - "--target-port", str(req.runtime_port), - "--registry-server", f"{acr_name}.azurecr.io", - "--registry-username", acr_user, - "--registry-password", acr_pass, - "--secrets", *secret_pairs, - "--env-vars", *env_pairs, - ] - - result = self._az.json(*create_args) - if not isinstance(result, dict): - detail = self._az.last_stderr - logger.error("[aca] containerapp create failed: %s", detail[:1000]) - steps.append({ - "step": "runtime_container_app", "status": "failed", - "detail": detail[:500], - }) - return "" - - logger.info("[aca] Assigning managed identity to container app ...") - id_ok, id_msg = self._az.ok( - "containerapp", "identity", "assign", - "--name", app_name, - "--resource-group", req.resource_group, - "--user-assigned", mi_id, - ) - if not id_ok: - logger.warning("[aca] MI assignment failed (non-fatal): %s", id_msg) - - fqdn = result.get("properties", {}).get("configuration", {}).get( - "ingress", {} - ).get("fqdn", "") - - if fqdn: - bot_endpoint = f"https://{fqdn}/api/messages" - self._az.ok( - "containerapp", "update", - "--name", app_name, - "--resource-group", req.resource_group, - "--set-env-vars", f"BOT_ENDPOINT={bot_endpoint}", - ) - - steps.append({"step": "runtime_container_app", "status": "ok", "detail": fqdn}) - rec.add_resource("container_app", req.resource_group, app_name, - "Runtime data plane (MI-scoped)") - return fqdn - - def _configure_ip_whitelist( - self, - req: AcaDeployRequest, - steps: list[dict], - ) -> list[dict[str, Any]]: - ip_steps: list[dict[str, Any]] = [] - - public_ip = self._detect_public_ip() - if not public_ip: - ip_steps.append({ - "step": "ip_whitelist", - "status": "skipped", - "detail": "Could not detect public IP -- runtime ingress unrestricted", - }) - return ip_steps - - ok, msg = self._az.ok( - "containerapp", "ingress", "access-restriction", "set", - "--name", "polyclaw-runtime", - "--resource-group", req.resource_group, - "--rule-name", "allow-deployer", - "--ip-address", f"{public_ip}/32", - "--action", "Allow", - "--description", "Allow deployer IP", - ) - if ok: - ip_steps.append({ - "step": "ip_whitelist", - "status": "ok", - "detail": f"Runtime restricted to {public_ip}/32", - }) - else: - ip_steps.append({ - "step": "ip_whitelist", - "status": "warning", - "detail": f"Could not set IP restriction: {msg}", - }) - - return ip_steps - - @staticmethod - def _detect_public_ip() -> str: - import urllib.request - - for url in ( - "https://api.ipify.org", - "https://ifconfig.me/ip", - "https://checkip.amazonaws.com", - ): - try: - with urllib.request.urlopen(url, timeout=10) as resp: - ip = resp.read().decode().strip() - if ip and "." in ip: - return ip - except Exception: - continue - return "" diff --git a/app/runtime/services/cloud/__init__.py b/app/runtime/services/cloud/__init__.py new file mode 100644 index 0000000..ae853a4 --- /dev/null +++ b/app/runtime/services/cloud/__init__.py @@ -0,0 +1,9 @@ +"""Cloud identity and CLI integrations.""" + +from __future__ import annotations + +from .azure import AzureCLI +from .github import GitHubAuth +from .runtime_identity import RuntimeIdentityProvisioner + +__all__ = ["AzureCLI", "GitHubAuth", "RuntimeIdentityProvisioner"] diff --git a/app/runtime/services/cloud/_azure_rbac.py b/app/runtime/services/cloud/_azure_rbac.py new file mode 100644 index 0000000..b89cb06 --- /dev/null +++ b/app/runtime/services/cloud/_azure_rbac.py @@ -0,0 +1,43 @@ +"""Shared Azure RBAC constants and helpers used across deployment and identity modules.""" + +from __future__ import annotations + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + +# Common identity / image names. +MI_NAME = "polyclaw-runtime-mi" +IMAGE_NAME = "polyclaw" + +# RBAC role names used for runtime identity scoping. +BOT_CONTRIBUTOR_ROLE = "Azure Bot Service Contributor Role" +RG_READER_ROLE = "Reader" +KV_SECRETS_ROLE = "Key Vault Secrets Officer" +SESSION_EXECUTOR_ROLE = "Azure ContainerApps Session Executor" + + +def session_pool_scope(subscription_id: str) -> str | None: + """Return the ARM resource scope for the ACA session pool, or ``None``. + + The session pool id is stored in ``sandbox.json`` after provisioning. + Shared between ``runtime_identity`` and ``aca_provision``. + """ + from ...state.sandbox_config import SandboxConfigStore + + try: + store = SandboxConfigStore() + pool_id = store.pool_id + if pool_id: + return pool_id + rg = store.resource_group + name = store.pool_name + if rg and name: + return ( + f"/subscriptions/{subscription_id}/resourceGroups/{rg}" + f"/providers/Microsoft.App/sessionPools/{name}" + ) + except Exception as exc: + logger.debug("Could not resolve session pool scope: %s", exc) + return None diff --git a/app/runtime/services/azure.py b/app/runtime/services/cloud/azure.py similarity index 98% rename from app/runtime/services/azure.py rename to app/runtime/services/cloud/azure.py index 2e0946e..6fde317 100644 --- a/app/runtime/services/azure.py +++ b/app/runtime/services/cloud/azure.py @@ -13,8 +13,7 @@ from time import time as _time from typing import Any -from ..config.settings import cfg -from ..util.result import Result +from ...util.result import Result logger = logging.getLogger(__name__) @@ -164,6 +163,8 @@ def login_device_code(self) -> dict[str, Any]: def get_bot_endpoint(self) -> str | None: """Read the messaging endpoint URL from the deployed Azure Bot Service.""" + from ...config.settings import cfg + rg = cfg.env.read("BOT_RESOURCE_GROUP") name = cfg.env.read("BOT_NAME") if not (rg and name): @@ -175,6 +176,8 @@ def get_bot_endpoint(self) -> str | None: return endpoint or None def update_endpoint(self, endpoint: str) -> Result: + from ...config.settings import cfg + rg = cfg.env.read("BOT_RESOURCE_GROUP") name = cfg.env.read("BOT_NAME") if not (rg and name): @@ -198,6 +201,8 @@ def update_endpoint(self, endpoint: str) -> Result: return Result.ok("Endpoint updated") def get_channels(self) -> dict[str, bool]: + from ...config.settings import cfg + rg = cfg.env.read("BOT_RESOURCE_GROUP") name = cfg.env.read("BOT_NAME") if not (rg and name): @@ -277,6 +282,8 @@ def configure_telegram(self, token: str, *, validated_name: str = "") -> Result: return Result.fail(f"Invalid Telegram token: {tok_result.message}") display = tok_result.message logger.info("Telegram token validated: %s", display) + from ...config.settings import cfg + rg = cfg.env.read("BOT_RESOURCE_GROUP") name = cfg.env.read("BOT_NAME") if not (rg and name): @@ -289,6 +296,8 @@ def configure_telegram(self, token: str, *, validated_name: str = "") -> Result: return Result.ok(f"Telegram configured ({display})") if result else result def remove_channel(self, channel: str) -> Result: + from ...config.settings import cfg + rg = cfg.env.read("BOT_RESOURCE_GROUP") name = cfg.env.read("BOT_NAME") if not (rg and name): diff --git a/app/runtime/services/github.py b/app/runtime/services/cloud/github.py similarity index 97% rename from app/runtime/services/github.py rename to app/runtime/services/cloud/github.py index 76e9d49..cf55537 100644 --- a/app/runtime/services/github.py +++ b/app/runtime/services/cloud/github.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging import os import re import selectors @@ -9,7 +10,7 @@ from time import time as _time from typing import Any -from ..config.settings import cfg +logger = logging.getLogger(__name__) class GitHubAuth: @@ -19,6 +20,8 @@ def __init__(self) -> None: self._login_proc: subprocess.Popen | None = None def status(self) -> dict[str, Any]: + from ...config.settings import cfg + if cfg.github_token: return {"authenticated": True, "details": "Using GITHUB_TOKEN from environment"} try: diff --git a/app/runtime/services/runtime_identity.py b/app/runtime/services/cloud/runtime_identity.py similarity index 90% rename from app/runtime/services/runtime_identity.py rename to app/runtime/services/cloud/runtime_identity.py index 183e2ee..62d6493 100644 --- a/app/runtime/services/runtime_identity.py +++ b/app/runtime/services/cloud/runtime_identity.py @@ -22,28 +22,20 @@ import logging from typing import Any -from ..config.settings import cfg -from ..state.sandbox_config import SandboxConfigStore +from ...state.sandbox_config import SandboxConfigStore +from ._azure_rbac import ( + BOT_CONTRIBUTOR_ROLE as _BOT_CONTRIBUTOR_ROLE, + KV_SECRETS_ROLE as _KV_SECRETS_ROLE, + MI_NAME as _MI_NAME, + RG_READER_ROLE as _RG_READER_ROLE, + SESSION_EXECUTOR_ROLE as _SESSION_EXECUTOR_ROLE, + session_pool_scope as _session_pool_scope_fn, +) from .azure import AzureCLI logger = logging.getLogger(__name__) _SP_DISPLAY_NAME = "polyclaw-runtime" -_MI_NAME = "polyclaw-runtime-mi" -_BOT_CONTRIBUTOR_ROLE = "Azure Bot Service Contributor Role" - -# Additional role so the SP can create / update resource groups tagged for -# the runtime (the RG itself must already exist or the admin must create it -# before handing off). -_RG_READER_ROLE = "Reader" - -# The runtime needs to read and write Key Vault secrets (e.g. MicrosoftAppId, -# MicrosoftAppPassword) that are required for bot channel registration. -_KV_SECRETS_ROLE = "Key Vault Secrets Officer" - -# Required for interacting with ACA Dynamic Sessions (code interpreter / -# shell sandbox). Must be scoped to the session pool resource. -_SESSION_EXECUTOR_ROLE = "Azure ContainerApps Session Executor" class RuntimeIdentityProvisioner: @@ -147,6 +139,8 @@ def provision(self, resource_group: str) -> dict[str, Any]: self._assign_role(app_id, _SESSION_EXECUTOR_ROLE, session_scope, steps) # 7. Write the SP credentials to the shared .env + from ...config.settings import cfg + cfg.write_env( RUNTIME_SP_APP_ID=app_id, RUNTIME_SP_PASSWORD=password, @@ -170,6 +164,8 @@ def provision(self, resource_group: str) -> dict[str, Any]: def revoke(self) -> dict[str, Any]: """Delete the runtime SP and clear env vars.""" + from ...config.settings import cfg + steps: list[dict[str, str]] = [] app_id = cfg.env.read("RUNTIME_SP_APP_ID") @@ -253,6 +249,8 @@ def provision_managed_identity( self._assign_role(principal_id, _SESSION_EXECUTOR_ROLE, session_scope, steps) # Write MI config to .env so the ACA deployer can reference it + from ...config.settings import cfg + cfg.write_env( ACA_MI_RESOURCE_ID=mi_id, ACA_MI_CLIENT_ID=client_id, @@ -274,6 +272,8 @@ def provision_managed_identity( def revoke_managed_identity(self, resource_group: str) -> dict[str, Any]: """Delete the managed identity.""" + from ...config.settings import cfg + steps: list[dict[str, str]] = [] mi_id = cfg.env.read("ACA_MI_RESOURCE_ID") if not mi_id: @@ -288,6 +288,8 @@ def revoke_managed_identity(self, resource_group: str) -> dict[str, Any]: def status(self) -> dict[str, Any]: """Return current runtime identity state.""" + from ...config.settings import cfg + app_id = cfg.env.read("RUNTIME_SP_APP_ID") mi_client_id = cfg.env.read("ACA_MI_CLIENT_ID") return { @@ -300,26 +302,8 @@ def status(self) -> dict[str, Any]: } def _session_pool_scope(self, subscription_id: str) -> str | None: - """Return the ARM resource scope for the ACA session pool, or ``None``. - - The session pool id is stored in ``sandbox.json`` after provisioning. - """ - try: - store = SandboxConfigStore() - pool_id = store.pool_id - if pool_id: - return pool_id - # Fall back to constructing the scope from individual fields - rg = store.resource_group - name = store.pool_name - if rg and name: - return ( - f"/subscriptions/{subscription_id}/resourceGroups/{rg}" - f"/providers/Microsoft.App/sessionPools/{name}" - ) - except Exception as exc: - logger.debug("Could not resolve session pool scope: %s", exc) - return None + """Return the ARM resource scope for the ACA session pool, or ``None``.""" + return _session_pool_scope_fn(subscription_id) def _keyvault_scope(self, subscription_id: str) -> str | None: """Return the ARM resource scope for the Key Vault, or ``None``. @@ -329,6 +313,8 @@ def _keyvault_scope(self, subscription_id: str) -> str | None: scope ensures the KV Secrets Officer role grants access regardless of which RG the vault is in. """ + from ...config.settings import cfg + kv_name = cfg.env.read("KEY_VAULT_NAME") or "" kv_rg = cfg.env.read("KEY_VAULT_RG") or "" if kv_name and kv_rg: diff --git a/app/runtime/services/deployment/__init__.py b/app/runtime/services/deployment/__init__.py new file mode 100644 index 0000000..c5a29a1 --- /dev/null +++ b/app/runtime/services/deployment/__init__.py @@ -0,0 +1,9 @@ +"""Deployment and infrastructure provisioning.""" + +from __future__ import annotations + +from .aca_deployer import AcaDeployer +from .deployer import BotDeployer +from .provisioner import Provisioner + +__all__ = ["AcaDeployer", "BotDeployer", "Provisioner"] diff --git a/app/runtime/services/deployment/aca_deployer.py b/app/runtime/services/deployment/aca_deployer.py new file mode 100644 index 0000000..6cd44e3 --- /dev/null +++ b/app/runtime/services/deployment/aca_deployer.py @@ -0,0 +1,455 @@ +"""Azure Container Apps deployer.""" + +from __future__ import annotations + +import logging +import os +import time +from dataclasses import dataclass, field +from typing import Any + +from ...config.settings import cfg +from ...state.deploy_state import DeploymentRecord, DeployStateStore +from ..cloud.azure import AzureCLI +from ..cloud._azure_rbac import IMAGE_NAME as _IMAGE_NAME +from .aca_provision import ( + assign_rbac, + configure_ip_whitelist, + ensure_acr, + ensure_aca_environment, + ensure_managed_identity, + ensure_runtime_app, + get_acr_credentials, + push_image, +) + +logger = logging.getLogger(__name__) + +_LOCAL_IMAGE = "polyclaw:latest" + + +@dataclass +class AcaDeployRequest: + + resource_group: str = "polyclaw-rg" + location: str = "eastus" + bot_display_name: str = "polyclaw" + bot_handle: str = "" + admin_port: int = 9090 + runtime_port: int = 8080 + image_tag: str = "latest" + acr_name: str = "" + env_name: str = "" + + +@dataclass +class AcaDeployResult: + + ok: bool = False + steps: list[dict[str, Any]] = field(default_factory=list) + error: str = "" + runtime_fqdn: str = "" + acr_name: str = "" + deploy_id: str = "" + + +class AcaDeployer: + + def __init__(self, az: AzureCLI, deploy_store: DeployStateStore | None = None) -> None: + self._az = az + self._deploy_store = deploy_store + + def deploy(self, req: AcaDeployRequest) -> AcaDeployResult: + steps: list[dict[str, Any]] = [] + result = AcaDeployResult(steps=steps) + + logger.info("[aca] Starting ACA deployment: rg=%s, location=%s", req.resource_group, req.location) + + rec = DeploymentRecord.new(kind="aca") + result.deploy_id = rec.deploy_id + if self._deploy_store: + self._deploy_store.register(rec) + + try: + self._cleanup_stale_resources(req, steps) + + if not self._ensure_resource_group(req, steps, rec): + result.error = "Resource group creation failed" + return result + + env_vars = self._load_env_vars(steps) + + acr_name = ensure_acr(self._az, req.resource_group, req.location, steps, rec) + if not acr_name: + result.error = "Container registry creation failed" + return result + result.acr_name = acr_name + + if not push_image(self._az, acr_name, req.image_tag, steps): + result.error = "Image push failed" + return result + + acr_user, acr_pass = get_acr_credentials(self._az, acr_name) + if not acr_user: + result.error = "Could not retrieve ACR admin credentials" + return result + + mi_id, mi_client_id = ensure_managed_identity( + self._az, req.resource_group, req.location, steps, rec, + ) + if not mi_id: + result.error = "Managed identity creation failed" + return result + + assign_rbac(self._az, mi_client_id, req.resource_group, steps) + + env_name, env_id = ensure_aca_environment( + self._az, req.resource_group, req.location, steps, rec, + ) + if not env_name: + result.error = "Container Apps environment creation failed" + return result + + runtime_fqdn = ensure_runtime_app( + self._az, req.resource_group, env_id, acr_name, + mi_id, mi_client_id, acr_user, acr_pass, + env_vars, req.image_tag, req.runtime_port, steps, rec, + ) + if not runtime_fqdn: + result.error = "Runtime container app creation failed" + return result + result.runtime_fqdn = runtime_fqdn + + ip_steps = configure_ip_whitelist(self._az, req.resource_group) + steps.extend(ip_steps) + + runtime_url = f"https://{runtime_fqdn}" + cfg.write_env( + ACA_RUNTIME_FQDN=runtime_fqdn, + ACA_ACR_NAME=acr_name, + ACA_ENV_NAME=env_name, + ACA_MI_RESOURCE_ID=mi_id, + ACA_MI_CLIENT_ID=mi_client_id, + RUNTIME_URL=runtime_url, + ) + os.environ["RUNTIME_URL"] = runtime_url + logger.info("[aca] RUNTIME_URL set to %s", runtime_url) + steps.append({"step": "write_aca_config", "status": "ok"}) + + result.ok = True + logger.info("[aca] Deployment complete: runtime=%s", runtime_fqdn) + + except Exception as exc: + logger.error("[aca] Deployment failed: %s", exc, exc_info=True) + result.error = str(exc) + steps.append({"step": "unexpected_error", "status": "failed", "detail": str(exc)}) + + if self._deploy_store and rec: + if result.ok: + rec.config = { + "runtime_fqdn": result.runtime_fqdn, + "acr_name": result.acr_name, + } + else: + rec.mark_stopped() + self._deploy_store.update(rec) + + return result + + def destroy(self, deploy_id: str | None = None) -> AcaDeployResult: + steps: list[dict[str, Any]] = [] + result = AcaDeployResult(steps=steps) + + rec = None + if deploy_id and self._deploy_store: + rec = self._deploy_store.get(deploy_id) + elif self._deploy_store: + rec = self._deploy_store.current_aca() + + rg = ( + cfg.env.read("BOT_RESOURCE_GROUP") + or (rec.resource_groups[0] if rec and rec.resource_groups else "") + ) + + if rg: + cleaned = self._delete_aca_resources(rg, steps, step_label="destroy") + if cleaned: + logger.info("[aca] Destroyed %d resource(s): %s", + len(cleaned), ", ".join(cleaned)) + else: + logger.info("[aca] No ACA resources found to destroy in %s", rg) + + cfg.write_env( + ACA_RUNTIME_FQDN="", + ACA_ACR_NAME="", ACA_ENV_NAME="", + ACA_MI_RESOURCE_ID="", + ACA_MI_CLIENT_ID="", + RUNTIME_URL="", + ) + steps.append({"step": "clear_aca_config", "status": "ok"}) + + if rec and self._deploy_store: + rec.mark_destroyed() + self._deploy_store.update(rec) + + result.ok = True + return result + + def status(self) -> dict[str, Any]: + runtime_fqdn = cfg.env.read("ACA_RUNTIME_FQDN") + return { + "deployed": bool(runtime_fqdn), + "runtime_fqdn": runtime_fqdn or None, + "acr_name": cfg.env.read("ACA_ACR_NAME") or None, + "env_name": cfg.env.read("ACA_ENV_NAME") or None, + "mi_client_id": cfg.env.read("ACA_MI_CLIENT_ID") or None, + } + + def restart(self) -> dict[str, Any]: + rg = cfg.env.read("BOT_RESOURCE_GROUP") or "polyclaw-rg" + app_name = "polyclaw-runtime" + + revisions = self._az.json( + "containerapp", "revision", "list", + "--name", app_name, + "--resource-group", rg, + quiet=True, + ) + if not revisions or not isinstance(revisions, list): + ok, msg = self._az.ok( + "containerapp", "update", + "--name", app_name, + "--resource-group", rg, + "--set-env-vars", f"RESTART_TS={int(time.time())}", + ) + result_detail = { + "app": app_name, + "status": "ok" if ok else "failed", + "method": "update", + "detail": msg if not ok else "forced new revision", + } + logger.info("[aca.restart] result=%r", result_detail) + return {"ok": ok, "results": [result_detail]} + + active = next( + (r["name"] for r in revisions if r.get("properties", {}).get("active")), + revisions[0].get("name") if revisions else None, + ) + if not active: + result_detail = { + "app": app_name, "status": "failed", + "method": "revision_restart", + "detail": "no active revision found", + } + logger.info("[aca.restart] result=%r", result_detail) + return {"ok": False, "results": [result_detail]} + + ok, msg = self._az.ok( + "containerapp", "revision", "restart", + "--name", app_name, + "--resource-group", rg, + "--revision", active, + ) + result_detail = { + "app": app_name, + "status": "ok" if ok else "failed", + "method": "revision_restart", + "detail": active if ok else msg, + } + logger.info("[aca.restart] result=%r", result_detail) + return {"ok": ok, "results": [result_detail]} + + def _delete_aca_resources( + self, rg: str, steps: list[dict], *, step_label: str = "cleanup", + ) -> list[str]: + rg_exists = self._az.json("group", "show", "--name", rg, quiet=True) + if not isinstance(rg_exists, dict): + logger.info("[aca] Resource group %s does not exist -- nothing to clean", rg) + return [] + + cleaned: list[str] = [] + + apps = self._az.json( + "containerapp", "list", + "--resource-group", rg, quiet=True, + ) + for app in (apps if isinstance(apps, list) else []): + name = app.get("name", "") + if not name: + continue + logger.info("[aca] Deleting container app: %s (waiting)", name) + ok, _ = self._az.ok( + "containerapp", "delete", "--name", name, + "--resource-group", rg, "--yes", + ) + if ok: + cleaned.append(f"containerapp/{name}") + steps.append({"step": f"{step_label}/containerapp/{name}", + "status": "ok" if ok else "failed"}) + + identities = self._az.json( + "identity", "list", + "--resource-group", rg, quiet=True, + ) + for mi in (identities if isinstance(identities, list) else []): + name = mi.get("name", "") + if not name: + continue + logger.info("[aca] Deleting managed identity: %s (waiting)", name) + ok, _ = self._az.ok( + "identity", "delete", "--name", name, + "--resource-group", rg, + ) + if ok: + cleaned.append(f"identity/{name}") + steps.append({"step": f"{step_label}/identity/{name}", + "status": "ok" if ok else "failed"}) + + envs = self._az.json( + "containerapp", "env", "list", + "--resource-group", rg, quiet=True, + ) + for env in (envs if isinstance(envs, list) else []): + name = env.get("name", "") + if not name: + continue + logger.info("[aca] Deleting ACA environment: %s (no-wait)", name) + ok, _ = self._az.ok( + "containerapp", "env", "delete", "--name", name, + "--resource-group", rg, "--yes", "--no-wait", + ) + if ok: + cleaned.append(f"aca-env/{name}") + steps.append({"step": f"{step_label}/aca-env/{name}", + "status": "ok" if ok else "failed"}) + + acrs = self._az.json( + "acr", "list", + "--resource-group", rg, quiet=True, + ) + for acr in (acrs if isinstance(acrs, list) else []): + name = acr.get("name", "") + if not name: + continue + logger.info("[aca] Deleting ACR: %s", name) + ok, _ = self._az.ok( + "acr", "delete", "--name", name, + "--resource-group", rg, "--yes", + ) + if ok: + cleaned.append(f"acr/{name}") + steps.append({"step": f"{step_label}/acr/{name}", + "status": "ok" if ok else "failed"}) + + workspaces = self._az.json( + "monitor", "log-analytics", "workspace", "list", + "--resource-group", rg, quiet=True, + ) + for ws in (workspaces if isinstance(workspaces, list) else []): + name = ws.get("name", "") + if not name: + continue + logger.info("[aca] Deleting Log Analytics workspace: %s", name) + ok, _ = self._az.ok( + "monitor", "log-analytics", "workspace", "delete", + "--workspace-name", name, + "--resource-group", rg, "--yes", "--force", + ) + if ok: + cleaned.append(f"log-analytics/{name}") + steps.append({"step": f"{step_label}/log-analytics/{name}", + "status": "ok" if ok else "failed"}) + + storage_accounts = self._az.json( + "storage", "account", "list", + "--resource-group", rg, quiet=True, + ) + for sa in (storage_accounts if isinstance(storage_accounts, list) else []): + name = sa.get("name", "") + if not name: + continue + tags = sa.get("tags", {}) or {} + kind = sa.get("kind", "") + if "polyclaw_deploy" in tags or kind == "StorageV2": + logger.info("[aca] Deleting storage account: %s", name) + ok, _ = self._az.ok( + "storage", "account", "delete", "--name", name, + "--resource-group", rg, "--yes", + ) + if ok: + cleaned.append(f"storage/{name}") + steps.append({"step": f"{step_label}/storage/{name}", + "status": "ok" if ok else "failed"}) + + return cleaned + + def _cleanup_stale_resources( + self, req: AcaDeployRequest, steps: list[dict], + ) -> None: + logger.info("[aca] Pre-flight: cleaning all ACA resources in %s ...", req.resource_group) + cleaned = self._delete_aca_resources(req.resource_group, steps, step_label="cleanup") + if cleaned: + detail = ", ".join(cleaned) + logger.info("[aca] Cleaned %d resource(s): %s", len(cleaned), detail) + else: + logger.info("[aca] No resources to clean") + steps.append({"step": "cleanup", "status": "ok", "detail": "nothing to clean"}) + + def _ensure_resource_group( + self, req: AcaDeployRequest, steps: list[dict], rec: DeploymentRecord, + ) -> bool: + logger.info("[aca] Step 1/10: Ensuring resource group %s ...", req.resource_group) + tag_args = ["--tags", f"polyclaw_deploy={rec.tag}"] + result = self._az.json( + "group", "create", "--name", req.resource_group, + "--location", req.location, *tag_args, + ) + if result: + steps.append({"step": "resource_group", "status": "ok", "detail": req.resource_group}) + if req.resource_group not in rec.resource_groups: + rec.resource_groups.append(req.resource_group) + return True + steps.append({"step": "resource_group", "status": "failed", "detail": self._az.last_stderr}) + return False + + def _load_env_vars(self, steps: list[dict]) -> dict[str, str]: + from ..keyvault import is_kv_ref, kv + + env_map = cfg.env.read_all() + _DEPLOYER_KEYS = frozenset({ + "ACA_RUNTIME_FQDN", "ACA_ACR_NAME", "ACA_ENV_NAME", + "ACA_STORAGE_ACCOUNT", "ACA_MI_RESOURCE_ID", "ACA_MI_CLIENT_ID", + "RUNTIME_URL", + }) + filtered = {k: v for k, v in env_map.items() if k not in _DEPLOYER_KEYS and v} + + resolved_count = 0 + for key, value in list(filtered.items()): + if is_kv_ref(value): + try: + plaintext = kv.resolve_value(value) + if plaintext: + filtered[key] = plaintext + resolved_count += 1 + logger.info("[aca] Resolved @kv: ref for %s", key) + else: + logger.warning( + "[aca] @kv: ref for %s resolved to empty -- removing", key, + ) + del filtered[key] + except Exception: + logger.error( + "[aca] Failed to resolve @kv: ref for %s -- removing", + key, exc_info=True, + ) + del filtered[key] + + count = len(filtered) + logger.info( + "[aca] Step 2/10: Loaded %d env var(s) from local .env " + "(%d @kv: references resolved)", + count, resolved_count, + ) + steps.append({"step": "load_env_vars", "status": "ok", + "detail": f"{count} variable(s), {resolved_count} @kv: resolved"}) + return filtered diff --git a/app/runtime/services/deployment/aca_provision.py b/app/runtime/services/deployment/aca_provision.py new file mode 100644 index 0000000..00c0ed5 --- /dev/null +++ b/app/runtime/services/deployment/aca_provision.py @@ -0,0 +1,433 @@ +"""ACA resource provisioning helpers for the deployer.""" + +from __future__ import annotations + +import logging +import secrets +import subprocess +import time +from typing import Any + +from ...config.settings import cfg +from ...state.deploy_state import DeploymentRecord +from ..cloud.azure import AzureCLI +from ..cloud._azure_rbac import ( + BOT_CONTRIBUTOR_ROLE as _BOT_CONTRIBUTOR_ROLE, + IMAGE_NAME as _IMAGE_NAME, + MI_NAME as _MI_NAME, + RG_READER_ROLE as _RG_READER_ROLE, + SESSION_EXECUTOR_ROLE as _SESSION_EXECUTOR_ROLE, + session_pool_scope as _session_pool_scope, +) + +logger = logging.getLogger(__name__) + +_ENV_NAME_PREFIX = "polyclaw-env" + + +def ensure_acr( + az: AzureCLI, + resource_group: str, + location: str, + steps: list[dict], + rec: DeploymentRecord, +) -> str: + """Create a container registry. Returns the ACR name, or ``""`` on failure.""" + logger.info("[aca] Step 3/10: Creating container registry ...") + acr_name = "polyclaw" + secrets.token_hex(4) + acr_name = acr_name[:50].replace("-", "") + + result = az.json( + "acr", "create", + "--resource-group", resource_group, + "--name", acr_name, + "--sku", "Basic", + "--admin-enabled", "true", + "--location", location, + ) + if not result: + steps.append({ + "step": "acr_create", "status": "failed", + "detail": az.last_stderr, + }) + return "" + steps.append({"step": "acr_create", "status": "ok", "detail": acr_name}) + rec.add_resource("acr", resource_group, acr_name, "Container registry") + return acr_name + + +def get_acr_credentials(az: AzureCLI, acr_name: str) -> tuple[str, str]: + """Return ``(username, password)`` for the ACR admin account.""" + creds = az.json("acr", "credential", "show", "--name", acr_name) + if not isinstance(creds, dict): + return "", "" + username = creds.get("username", "") + passwords = creds.get("passwords", []) + password = passwords[0].get("value", "") if passwords else "" + return username, password + + +def push_image( + az: AzureCLI, + acr_name: str, + tag: str, + steps: list[dict], +) -> bool: + """Build, tag, and push the local Docker image to ACR.""" + logger.info("[aca] Step 4/10: Pushing pre-built image to ACR ...") + local_image = f"{_IMAGE_NAME}:{tag}" + remote_image = f"{acr_name}.azurecr.io/{_IMAGE_NAME}:{tag}" + + check = subprocess.run( + ["docker", "image", "inspect", local_image], + capture_output=True, text=True, + ) + if check.returncode != 0: + detail = ( + f"Local image '{local_image}' not found. " + "Build it first with: docker build --platform linux/amd64 " + f"-t {local_image} ." + ) + logger.error("[aca] %s", detail) + steps.append({"step": "image_push", "status": "failed", "detail": detail}) + return False + + logger.info("[aca] Logging in to ACR %s ...", acr_name) + ok, msg = az.ok("acr", "login", "--name", acr_name) + if not ok: + detail = f"ACR login failed: {msg or az.last_stderr}" + logger.error("[aca] %s", detail) + steps.append({"step": "image_push", "status": "failed", "detail": detail}) + return False + + logger.info("[aca] Tagging %s -> %s", local_image, remote_image) + tag_result = subprocess.run( + ["docker", "tag", local_image, remote_image], + capture_output=True, text=True, + ) + if tag_result.returncode != 0: + detail = f"docker tag failed: {tag_result.stderr.strip()}" + logger.error("[aca] %s", detail) + steps.append({"step": "image_push", "status": "failed", "detail": detail}) + return False + + logger.info("[aca] Pushing %s (this may take 1-2 minutes) ...", remote_image) + push_result = subprocess.run( + ["docker", "push", remote_image], + capture_output=True, text=True, timeout=600, + ) + if push_result.returncode != 0: + detail = f"docker push failed: {push_result.stderr.strip()[:500]}" + logger.error("[aca] %s", detail) + steps.append({"step": "image_push", "status": "failed", "detail": detail}) + return False + + logger.info("[aca] Image pushed: %s", remote_image) + steps.append({"step": "image_push", "status": "ok", "detail": remote_image}) + return True + + +def ensure_managed_identity( + az: AzureCLI, + resource_group: str, + location: str, + steps: list[dict], + rec: DeploymentRecord, +) -> tuple[str, str]: + """Create a user-assigned managed identity. Returns ``(id, client_id)``.""" + logger.info("[aca] Step 5/10: Creating managed identity ...") + result = az.json( + "identity", "create", + "--name", _MI_NAME, + "--resource-group", resource_group, + "--location", location, + ) + if not isinstance(result, dict): + steps.append({"step": "managed_identity", "status": "failed", + "detail": az.last_stderr}) + return "", "" + + mi_id = result.get("id", "") + client_id = result.get("clientId", "") + steps.append({"step": "managed_identity", "status": "ok", "detail": _MI_NAME}) + rec.add_resource("managed_identity", resource_group, _MI_NAME, + "Runtime scoped identity") + return mi_id, client_id + + +def assign_rbac( + az: AzureCLI, + mi_principal_id: str, + resource_group: str, + steps: list[dict], +) -> None: + """Assign RBAC roles to the managed identity.""" + logger.info("[aca] Step 6/10: Assigning RBAC ...") + account = az.account_info() + sub_id = account.get("id", "") if account else "" + rg_scope = f"/subscriptions/{sub_id}/resourceGroups/{resource_group}" + + for role in (_BOT_CONTRIBUTOR_ROLE, _RG_READER_ROLE): + label = role.lower().replace(" ", "_") + assigned = False + for attempt in range(4): + if attempt: + delay = 10 * attempt + logger.info( + "[aca] RBAC retry %d/3 for %s in %ds ...", + attempt, label, delay, + ) + time.sleep(delay) + ok, _msg = az.ok( + "role", "assignment", "create", + "--assignee", mi_principal_id, + "--role", role, + "--scope", rg_scope, + ) + if ok or "already exists" in (az.last_stderr or "").lower(): + assigned = True + break + if assigned: + steps.append({"step": f"rbac_{label}", "status": "ok", + "detail": f"{role} on {resource_group}"}) + else: + steps.append({"step": f"rbac_{label}", "status": "failed", + "detail": az.last_stderr}) + + session_scope = _session_pool_scope(sub_id) + if session_scope: + label = _SESSION_EXECUTOR_ROLE.lower().replace(" ", "_") + assigned = False + for attempt in range(4): + if attempt: + delay = 10 * attempt + logger.info( + "[aca] RBAC retry %d/3 for %s in %ds ...", + attempt, label, delay, + ) + time.sleep(delay) + ok, _msg = az.ok( + "role", "assignment", "create", + "--assignee", mi_principal_id, + "--role", _SESSION_EXECUTOR_ROLE, + "--scope", session_scope, + ) + if ok or "already exists" in (az.last_stderr or "").lower(): + assigned = True + break + if assigned: + steps.append({"step": f"rbac_{label}", "status": "ok", + "detail": f"{_SESSION_EXECUTOR_ROLE} on session pool"}) + else: + steps.append({"step": f"rbac_{label}", "status": "failed", + "detail": az.last_stderr}) + + +def ensure_aca_environment( + az: AzureCLI, + resource_group: str, + location: str, + steps: list[dict], + rec: DeploymentRecord, +) -> tuple[str, str]: + """Create an ACA environment. Returns ``(env_name, env_id)``.""" + logger.info("[aca] Step 7/10: Creating ACA environment ...") + env_name = f"{_ENV_NAME_PREFIX}-{secrets.token_hex(4)}" + + result = az.json( + "containerapp", "env", "create", + "--name", env_name, + "--resource-group", resource_group, + "--location", location, + ) + if not isinstance(result, dict): + steps.append({ + "step": "aca_environment", "status": "failed", + "detail": az.last_stderr, + }) + return "", "" + + env_id = result.get("id", "") + steps.append({"step": "aca_environment", "status": "ok", "detail": env_name}) + rec.add_resource("aca_environment", resource_group, env_name, + "Container Apps environment") + return env_name, env_id + + +def ensure_runtime_app( + az: AzureCLI, + resource_group: str, + env_id: str, + acr_name: str, + mi_id: str, + mi_client_id: str, + acr_user: str, + acr_pass: str, + env_vars: dict[str, str], + image_tag: str, + runtime_port: int, + steps: list[dict], + rec: DeploymentRecord, +) -> str: + """Create the runtime container app. Returns the FQDN, or ``""`` on failure.""" + app_name = "polyclaw-runtime" + admin_secret = cfg.admin_secret or secrets.token_urlsafe(24) + image = f"{acr_name}.azurecr.io/{_IMAGE_NAME}:{image_tag}" + + logger.info("[aca] Step 8/10: Creating runtime container app ...") + + _SECRET_ENV_KEYS = frozenset({ + "RUNTIME_SP_PASSWORD", "ACS_CALLBACK_TOKEN", + "GITHUB_TOKEN", "BOT_APP_PASSWORD", + "ACS_CONNECTION_STRING", "AZURE_OPENAI_API_KEY", + }) + _SKIP = frozenset({ + "POLYCLAW_MODE", "POLYCLAW_DATA_DIR", "ADMIN_PORT", + "ADMIN_SECRET", "POLYCLAW_CONTAINER", "POLYCLAW_USE_MI", + "AZURE_CLIENT_ID", + }) | _SECRET_ENV_KEYS + aca_secrets: dict[str, str] = { + "admin-secret": admin_secret, + } + for env_key in _SECRET_ENV_KEYS: + secret_name = env_key.lower().replace("_", "-") + value = env_vars.get(env_key, "") + if value: + aca_secrets[secret_name] = value + + env_pairs = [ + "POLYCLAW_MODE=runtime", + f"ADMIN_PORT={runtime_port}", + "ADMIN_SECRET=secretref:admin-secret", + "POLYCLAW_CONTAINER=1", + "POLYCLAW_USE_MI=1", + f"AZURE_CLIENT_ID={mi_client_id}", + ] + for env_key in sorted(_SECRET_ENV_KEYS): + secret_name = env_key.lower().replace("_", "-") + if secret_name in aca_secrets: + env_pairs.append(f"{env_key}=secretref:{secret_name}") + + for key, value in sorted(env_vars.items()): + if key not in _SKIP and value: + env_pairs.append(f"{key}={value}") + + logger.info("[aca] Container env vars: %d total (%d via ACA secrets)", + len(env_pairs), len(aca_secrets)) + + secret_pairs = [f"{name}={value}" for name, value in sorted(aca_secrets.items())] + + create_args: list[str] = [ + "containerapp", "create", + "--name", app_name, + "--resource-group", resource_group, + "--environment", env_id, + "--image", image, + "--cpu", "2", "--memory", "4Gi", + "--min-replicas", "1", "--max-replicas", "1", + "--ingress", "external", + "--target-port", str(runtime_port), + "--registry-server", f"{acr_name}.azurecr.io", + "--registry-username", acr_user, + "--registry-password", acr_pass, + "--secrets", *secret_pairs, + "--env-vars", *env_pairs, + ] + + result = az.json(*create_args) + if not isinstance(result, dict): + detail = az.last_stderr + logger.error("[aca] containerapp create failed: %s", detail[:1000]) + steps.append({ + "step": "runtime_container_app", "status": "failed", + "detail": detail[:500], + }) + return "" + + logger.info("[aca] Assigning managed identity to container app ...") + id_ok, id_msg = az.ok( + "containerapp", "identity", "assign", + "--name", app_name, + "--resource-group", resource_group, + "--user-assigned", mi_id, + ) + if not id_ok: + logger.warning("[aca] MI assignment failed (non-fatal): %s", id_msg) + + fqdn = result.get("properties", {}).get("configuration", {}).get( + "ingress", {} + ).get("fqdn", "") + + if fqdn: + bot_endpoint = f"https://{fqdn}/api/messages" + az.ok( + "containerapp", "update", + "--name", app_name, + "--resource-group", resource_group, + "--set-env-vars", f"BOT_ENDPOINT={bot_endpoint}", + ) + + steps.append({"step": "runtime_container_app", "status": "ok", "detail": fqdn}) + rec.add_resource("container_app", resource_group, app_name, + "Runtime data plane (MI-scoped)") + return fqdn + + +def configure_ip_whitelist( + az: AzureCLI, + resource_group: str, +) -> list[dict[str, Any]]: + """Restrict the runtime container's ingress to the deployer's IP.""" + ip_steps: list[dict[str, Any]] = [] + + public_ip = detect_public_ip() + if not public_ip: + ip_steps.append({ + "step": "ip_whitelist", + "status": "skipped", + "detail": "Could not detect public IP -- runtime ingress unrestricted", + }) + return ip_steps + + ok, msg = az.ok( + "containerapp", "ingress", "access-restriction", "set", + "--name", "polyclaw-runtime", + "--resource-group", resource_group, + "--rule-name", "allow-deployer", + "--ip-address", f"{public_ip}/32", + "--action", "Allow", + "--description", "Allow deployer IP", + ) + if ok: + ip_steps.append({ + "step": "ip_whitelist", + "status": "ok", + "detail": f"Runtime restricted to {public_ip}/32", + }) + else: + ip_steps.append({ + "step": "ip_whitelist", + "status": "warning", + "detail": f"Could not set IP restriction: {msg}", + }) + + return ip_steps + + +def detect_public_ip() -> str: + """Return the deployer's public IP address, or ``""`` if unavailable.""" + import urllib.request + + for url in ( + "https://api.ipify.org", + "https://ifconfig.me/ip", + "https://checkip.amazonaws.com", + ): + try: + with urllib.request.urlopen(url, timeout=10) as resp: + ip = resp.read().decode().strip() + if ip and "." in ip: + return ip + except Exception: + continue + return "" diff --git a/app/runtime/services/deployer.py b/app/runtime/services/deployment/deployer.py similarity index 99% rename from app/runtime/services/deployer.py rename to app/runtime/services/deployment/deployer.py index a83518e..035f576 100644 --- a/app/runtime/services/deployer.py +++ b/app/runtime/services/deployment/deployer.py @@ -9,9 +9,9 @@ from dataclasses import dataclass from typing import Any -from ..config.settings import cfg -from ..state.deploy_state import DeployStateStore -from .azure import AzureCLI +from ...config.settings import cfg +from ...state.deploy_state import DeployStateStore +from ..cloud.azure import AzureCLI logger = logging.getLogger(__name__) diff --git a/app/runtime/services/provisioner.py b/app/runtime/services/deployment/provisioner.py similarity index 97% rename from app/runtime/services/provisioner.py rename to app/runtime/services/deployment/provisioner.py index 02c2153..d83fd30 100644 --- a/app/runtime/services/provisioner.py +++ b/app/runtime/services/deployment/provisioner.py @@ -5,12 +5,12 @@ import logging from typing import Any -from ..config.settings import cfg -from ..state.deploy_state import DeploymentRecord, DeployStateStore -from ..state.infra_config import InfraConfigStore -from .azure import AzureCLI +from ...config.settings import cfg +from ...state.deploy_state import DeploymentRecord, DeployStateStore +from ...state.infra_config import InfraConfigStore +from ..cloud.azure import AzureCLI from .deployer import BotDeployer, DeployRequest -from .runtime_identity import RuntimeIdentityProvisioner +from ..cloud.runtime_identity import RuntimeIdentityProvisioner logger = logging.getLogger(__name__) diff --git a/app/runtime/services/keyvault.py b/app/runtime/services/keyvault.py index a88d714..a64fab1 100644 --- a/app/runtime/services/keyvault.py +++ b/app/runtime/services/keyvault.py @@ -205,6 +205,16 @@ def _allow_current_ip(self) -> bool: kv = KeyVaultClient() +def _reset_kv() -> None: + """Reset the module-level Key Vault singleton (for test isolation).""" + kv.reinit() + + +from ..util.singletons import register_singleton # noqa: E402 + +register_singleton(_reset_kv) + + def resolve_if_kv_ref(value: str) -> str: """Resolve a ``@kv:secret-name`` reference, returning the original value if not a ref. diff --git a/app/runtime/services/otel.py b/app/runtime/services/otel.py index eaf9260..4997579 100644 --- a/app/runtime/services/otel.py +++ b/app/runtime/services/otel.py @@ -31,6 +31,11 @@ def _reset_otel_state() -> None: _otel_active = False +from ..util.singletons import register_singleton # noqa: E402 + +register_singleton(_reset_otel_state) + + def configure_otel( connection_string: str, *, diff --git a/app/runtime/services/resource_tracker.py b/app/runtime/services/resource_tracker.py index 50cda5d..aab7819 100644 --- a/app/runtime/services/resource_tracker.py +++ b/app/runtime/services/resource_tracker.py @@ -7,7 +7,7 @@ from typing import Any from ..state.deploy_state import TAG_PREFIX, DeployStateStore -from .azure import AzureCLI +from .cloud.azure import AzureCLI logger = logging.getLogger(__name__) diff --git a/app/runtime/services/security/__init__.py b/app/runtime/services/security/__init__.py new file mode 100644 index 0000000..b43820d --- /dev/null +++ b/app/runtime/services/security/__init__.py @@ -0,0 +1,9 @@ +"""Security, content safety, and misconfiguration checking.""" + +from __future__ import annotations + +from .misconfig_checker import MisconfigChecker +from .prompt_shield import PromptShieldService +from .security_preflight import SecurityPreflightChecker + +__all__ = ["MisconfigChecker", "PromptShieldService", "SecurityPreflightChecker"] diff --git a/app/runtime/services/misconfig_checker.py b/app/runtime/services/security/misconfig_checker.py similarity index 99% rename from app/runtime/services/misconfig_checker.py rename to app/runtime/services/security/misconfig_checker.py index 660e62e..9e7cdb7 100644 --- a/app/runtime/services/misconfig_checker.py +++ b/app/runtime/services/security/misconfig_checker.py @@ -6,7 +6,7 @@ from dataclasses import asdict, dataclass, field from typing import Any, Literal -from .azure import AzureCLI +from ..cloud.azure import AzureCLI logger = logging.getLogger(__name__) diff --git a/app/runtime/services/security/preflight_identity.py b/app/runtime/services/security/preflight_identity.py new file mode 100644 index 0000000..4cb8c5a --- /dev/null +++ b/app/runtime/services/security/preflight_identity.py @@ -0,0 +1,240 @@ +"""Identity preflight checks (login gate, identity config, validity, credential expiry).""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from ...config.settings import cfg +from ..cloud.azure import AzureCLI +from .security_preflight import ( + IdentityInfo, + PreflightCheck, + PreflightResult, + add_check as _add, +) + + +# -- Azure login gate --------------------------------------------------- + +def check_azure_logged_in(az: AzureCLI, result: PreflightResult) -> bool: + cmd = "az account show" + account = az.json("account", "show", quiet=True) + if isinstance(account, dict) and account.get("id"): + sub = account.get("name", account.get("id", "?")) + _add( + result, id="azure_logged_in", category="identity", + name="Azure CLI Authenticated", + status="pass", + detail=f"Logged in to subscription: {sub}", + evidence=f"subscription={sub}\ntenantId={account.get('tenantId', '?')}", + command=cmd, + ) + return True + _add( + result, id="azure_logged_in", category="identity", + name="Azure CLI Authenticated", + status="fail", + detail="Not logged in -- RBAC and identity checks require Azure CLI auth", + evidence=az.last_stderr or "No response", + command=cmd, + ) + return False + + +def skip_azure_checks(result: PreflightResult) -> None: + for check_id, name, cat in [ + ("identity_configured", "Runtime Identity Configured", "identity"), + ("identity_valid", "Identity Exists in Azure AD", "identity"), + ("identity_credential_expiry", "Credential Expiry", "identity"), + ("rbac_assignments_list", "RBAC Assignments", "rbac"), + ("rbac_bot_contributor", "Azure Bot Service Contributor Role", "rbac"), + ("rbac_reader", "Reader Role", "rbac"), + ("rbac_kv_access", "Key Vault Access Role", "rbac"), + ("rbac_session_pool", "Session Pool Executor", "rbac"), + ("rbac_no_elevated", "No Elevated Roles", "rbac"), + ("rbac_scope_contained", "Scope Limited to Resource Group", "rbac"), + ]: + _add( + result, id=check_id, category=cat, name=name, + status="skip", + detail="Skipped -- Azure CLI not authenticated", + command="", + ) + + +# -- Identity checks ---------------------------------------------------- + +def check_identity_configured( + az: AzureCLI, result: PreflightResult, +) -> IdentityInfo | None: + sp_app_id = cfg.env.read("RUNTIME_SP_APP_ID") + mi_client_id = cfg.env.read("ACA_MI_CLIENT_ID") + mi_resource_id = cfg.env.read("ACA_MI_RESOURCE_ID") + + if mi_client_id: + _add( + result, id="identity_configured", category="identity", + name="Runtime Identity Configured", + status="pass", + detail=f"User-assigned managed identity: client_id={mi_client_id}", + evidence=( + f"ACA_MI_CLIENT_ID={mi_client_id}\n" + f"ACA_MI_RESOURCE_ID={mi_resource_id}" + ), + command="env: ACA_MI_CLIENT_ID, ACA_MI_RESOURCE_ID", + ) + return { + "strategy": "managed_identity", + "client_id": mi_client_id, + "resource_id": mi_resource_id, + "assignee": mi_client_id, + } + + if sp_app_id: + sp_tenant = cfg.env.read("RUNTIME_SP_TENANT") + has_pw = bool(cfg.env.read("RUNTIME_SP_PASSWORD")) + _add( + result, id="identity_configured", category="identity", + name="Runtime Identity Configured", + status="pass", + detail=f"Scoped service principal: app_id={sp_app_id}", + evidence=( + f"RUNTIME_SP_APP_ID={sp_app_id}\n" + f"RUNTIME_SP_TENANT={sp_tenant}\n" + f"RUNTIME_SP_PASSWORD={'***' if has_pw else 'MISSING'}" + ), + command="env: RUNTIME_SP_APP_ID, RUNTIME_SP_TENANT, RUNTIME_SP_PASSWORD", + ) + return { + "strategy": "sp", + "app_id": sp_app_id, + "tenant": sp_tenant, + "assignee": sp_app_id, + } + + _add( + result, id="identity_configured", category="identity", + name="Runtime Identity Configured", + status="skip", + detail="No runtime identity configured (RUNTIME_SP_* and ACA_MI_* absent)", + evidence="RUNTIME_SP_APP_ID=(empty)\nACA_MI_CLIENT_ID=(empty)", + command="env: RUNTIME_SP_APP_ID, ACA_MI_CLIENT_ID", + ) + return None + + +def check_identity_valid( + az: AzureCLI, result: PreflightResult, info: IdentityInfo, +) -> None: + if info["strategy"] == "sp": + app_id = info["app_id"] + cmd = f"az ad sp show --id {app_id}" + sp = az.json("ad", "sp", "show", "--id", app_id) + if isinstance(sp, dict) and sp.get("appId"): + display = sp.get("displayName", "?") + _add( + result, id="identity_valid", category="identity", + name="Service Principal Exists in Azure AD", + status="pass", + detail=f"{display} ({app_id})", + evidence=( + f"displayName={display}\n" + f"appId={app_id}\n" + f"objectId={sp.get('id', '?')}" + ), + command=cmd, + ) + else: + _add( + result, id="identity_valid", category="identity", + name="Service Principal Exists in Azure AD", + status="fail", + detail=f"SP not found: {app_id}", + evidence=az.last_stderr or "No response", + command=cmd, + ) + else: + resource_id = info.get("resource_id", "") + if not resource_id: + _add( + result, id="identity_valid", category="identity", + name="Managed Identity Exists", + status="skip", detail="No MI resource ID configured", + command="", + ) + return + cmd = f"az identity show --ids {resource_id}" + mi = az.json("identity", "show", "--ids", resource_id) + if isinstance(mi, dict) and mi.get("clientId"): + _add( + result, id="identity_valid", category="identity", + name="Managed Identity Exists", + status="pass", + detail=f"{mi.get('name', '?')} (client={mi.get('clientId', '?')})", + evidence=( + f"name={mi.get('name', '?')}\n" + f"clientId={mi.get('clientId', '?')}\n" + f"principalId={mi.get('principalId', '?')}" + ), + command=cmd, + ) + else: + _add( + result, id="identity_valid", category="identity", + name="Managed Identity Exists", + status="fail", + detail=f"MI not found: {resource_id}", + evidence=az.last_stderr or "No response", + command=cmd, + ) + + +def check_credential_expiry( + az: AzureCLI, result: PreflightResult, info: IdentityInfo, +) -> None: + if info["strategy"] != "sp": + _add( + result, id="identity_credential_expiry", category="identity", + name="Credential Expiry", + status="pass", + detail="Managed identities do not have expiring credentials", + command="(not applicable for MI)", + ) + return + + app_id = info["app_id"] + cmd = f"az ad app credential list --id {app_id}" + creds = az.json("ad", "app", "credential", "list", "--id", app_id) + if not isinstance(creds, list) or not creds: + _add( + result, id="identity_credential_expiry", category="identity", + name="Credential Expiry", + status="warn", + detail="Could not retrieve credential list", + evidence=az.last_stderr or "Empty response", + command=cmd, + ) + return + + latest = max(creds, key=lambda c: c.get("endDateTime", "")) + end = latest.get("endDateTime", "") + now = datetime.now(timezone.utc).isoformat() + + if end and end > now: + _add( + result, id="identity_credential_expiry", category="identity", + name="Credential Expiry", + status="pass", + detail=f"Valid until {end}", + evidence=f"endDateTime={end}\nnow={now}\ncredentials_count={len(creds)}", + command=cmd, + ) + else: + _add( + result, id="identity_credential_expiry", category="identity", + name="Credential Expiry", + status="fail", + detail=f"Credential EXPIRED: {end}", + evidence=f"endDateTime={end}\nnow={now}", + command=cmd, + ) diff --git a/app/runtime/services/security/preflight_rbac.py b/app/runtime/services/security/preflight_rbac.py new file mode 100644 index 0000000..e1f610a --- /dev/null +++ b/app/runtime/services/security/preflight_rbac.py @@ -0,0 +1,276 @@ +"""RBAC preflight checks.""" + +from __future__ import annotations + +from typing import Any + +from ..cloud.azure import AzureCLI +from .security_preflight import ( + IdentityInfo, + PreflightCheck, + PreflightResult, + _ELEVATED_ROLES, + add_check as _add, +) + + +def check_rbac_list( + az: AzureCLI, result: PreflightResult, info: IdentityInfo, +) -> list[dict[str, Any]] | None: + assignee = info.get("assignee", "") + if not assignee: + return None + + cmd = f"az role assignment list --assignee {assignee} --all" + assignments = az.json( + "role", "assignment", "list", "--assignee", assignee, "--all", + ) + if not isinstance(assignments, list): + _add( + result, id="rbac_assignments_list", category="rbac", + name="RBAC Assignments Retrieved", + status="fail", + detail="Could not list RBAC assignments", + evidence=az.last_stderr or "No response", + command=cmd, + ) + return None + + summary = ", ".join( + f"{a.get('roleDefinitionName', '?')} @ " + f"{a.get('scope', '?').rsplit('/', 1)[-1]}" + for a in assignments + ) + _add( + result, id="rbac_assignments_list", category="rbac", + name="RBAC Assignments Retrieved", + status="pass", + detail=f"{len(assignments)} assignment(s): {summary}", + evidence="\n".join( + f"- {a.get('roleDefinitionName', '?')} on {a.get('scope', '?')}" + for a in assignments + ), + command=cmd, + ) + return assignments + + +def check_rbac_has_role( + result: PreflightResult, + assignments: list[dict[str, Any]], + role_name: str, + check_id: str, + check_name: str, + bot_rg: str, + *, + missing_severity: str = "fail", + missing_detail: str = "", +) -> None: + matching = [ + a for a in assignments + if a.get("roleDefinitionName") == role_name + ] + if matching: + scopes = [a.get("scope", "") for a in matching] + _add( + result, id=check_id, category="rbac", + name=check_name, + status="pass", + detail=f"{role_name} assigned ({len(matching)} assignment(s))", + evidence="\n".join(f"scope={s}" for s in scopes), + command="Filtered from role assignment list", + ) + else: + detail = missing_detail or f"{role_name} NOT found in assignments" + _add( + result, id=check_id, category="rbac", + name=check_name, + status=missing_severity, + detail=detail, + evidence=( + f"Expected '{role_name}' but not present " + f"in {len(assignments)} assignment(s)" + ), + command="Filtered from role assignment list", + ) + + +def check_rbac_kv_access( + result: PreflightResult, + assignments: list[dict[str, Any]], + info: IdentityInfo, +) -> None: + kv_roles = [ + a for a in assignments + if "key vault" in (a.get("roleDefinitionName") or "").lower() + ] + + if not kv_roles: + _add( + result, id="rbac_kv_access", category="rbac", + name="Key Vault Access Role", + status="warn", + detail="No Key Vault role assignment found", + evidence=f"Checked {len(assignments)} assignments for 'Key Vault' roles", + command="Filtered from role assignment list", + ) + return + + role_names = [a.get("roleDefinitionName", "?") for a in kv_roles] + has_officer = "Key Vault Secrets Officer" in role_names + has_user = "Key Vault Secrets User" in role_names + + if info["strategy"] == "managed_identity": + if has_user and not has_officer: + status = "pass" + detail = "Key Vault Secrets User (read-only) -- correct for MI" + elif has_officer: + status = "warn" + detail = ( + "Key Vault Secrets Officer (read+write) -- " + "consider restricting to Secrets User for runtime" + ) + else: + status = "pass" + detail = f"Key Vault role: {', '.join(role_names)}" + else: + status = "pass" + detail = f"Key Vault role: {', '.join(role_names)}" + + _add( + result, id="rbac_kv_access", category="rbac", + name="Key Vault Access Role", + status=status, + detail=detail, + evidence="\n".join( + f"- {a.get('roleDefinitionName', '?')} on {a.get('scope', '?')}" + for a in kv_roles + ), + command="Filtered from role assignment list", + ) + + +def check_rbac_session_pool( + result: PreflightResult, assignments: list[dict[str, Any]], +) -> None: + from ...state.sandbox_config import SandboxConfigStore + + try: + sandbox_store = SandboxConfigStore() + sandbox_enabled = sandbox_store.enabled + sandbox_configured = sandbox_store.is_provisioned + except Exception: + sandbox_enabled = False + sandbox_configured = False + + matching = [ + a for a in assignments + if "session" in (a.get("roleDefinitionName") or "").lower() + ] + if matching: + names = [a.get("roleDefinitionName", "?") for a in matching] + _add( + result, id="rbac_session_pool", category="rbac", + name="Session Pool Executor", + status="pass", + detail=f"Session role: {', '.join(names)}", + evidence="\n".join( + f"scope={a.get('scope', '?')}" for a in matching + ), + command="Filtered from role assignment list", + ) + elif sandbox_enabled or sandbox_configured: + _add( + result, id="rbac_session_pool", category="rbac", + name="Session Pool Executor", + status="fail", + detail=( + "Azure ContainerApps Session Executor NOT found -- " + "required for sandbox (HTTP 403 on file upload/execute)" + ), + evidence=f"Not present in {len(assignments)} assignment(s)", + command="Filtered from role assignment list", + ) + else: + _add( + result, id="rbac_session_pool", category="rbac", + name="Session Pool Executor", + status="warn", + detail="ContainerApps Session Executor NOT found (needed if sandbox is enabled)", + evidence=f"Not present in {len(assignments)} assignment(s)", + command="Filtered from role assignment list", + ) + + +def check_rbac_no_elevated( + result: PreflightResult, assignments: list[dict[str, Any]], +) -> None: + elevated = [ + a for a in assignments + if a.get("roleDefinitionName") in _ELEVATED_ROLES + ] + if not elevated: + _add( + result, id="rbac_no_elevated", category="rbac", + name="No Elevated Roles", + status="pass", + detail="No Owner, Contributor, or User Access Administrator roles", + evidence=( + f"Checked {len(assignments)} assignment(s) against: " + f"{', '.join(sorted(_ELEVATED_ROLES))}" + ), + command="Filtered from role assignment list", + ) + else: + _add( + result, id="rbac_no_elevated", category="rbac", + name="No Elevated Roles", + status="fail", + detail=( + f"ELEVATED roles found: " + f"{', '.join(a.get('roleDefinitionName', '?') for a in elevated)}" + ), + evidence="\n".join( + f"- {a.get('roleDefinitionName', '?')} on {a.get('scope', '?')}" + for a in elevated + ), + command="Filtered from role assignment list", + ) + + +def check_rbac_scope_contained( + result: PreflightResult, assignments: list[dict[str, Any]], +) -> None: + out_of_scope = [ + a for a in assignments + if "/resourcegroups/" not in (a.get("scope") or "").lower() + ] + if not out_of_scope: + _add( + result, id="rbac_scope_contained", category="rbac", + name="Scope Limited to Resource Group", + status="pass", + detail=( + f"All {len(assignments)} assignment(s) scoped to " + f"resource group level or below" + ), + evidence="\n".join( + f"- {a.get('scope', '?')}" for a in assignments + ) if assignments else "No assignments", + command="Scope analysis from role assignment list", + ) + else: + _add( + result, id="rbac_scope_contained", category="rbac", + name="Scope Limited to Resource Group", + status="fail", + detail=( + f"{len(out_of_scope)} assignment(s) at subscription or management " + f"group level" + ), + evidence="\n".join( + f"- {a.get('roleDefinitionName', '?')} at {a.get('scope', '?')}" + for a in out_of_scope + ), + command="Scope analysis from role assignment list", + ) diff --git a/app/runtime/services/security/preflight_secrets.py b/app/runtime/services/security/preflight_secrets.py new file mode 100644 index 0000000..152be4f --- /dev/null +++ b/app/runtime/services/security/preflight_secrets.py @@ -0,0 +1,301 @@ +"""Secret-isolation preflight checks.""" + +from __future__ import annotations + +import os +from pathlib import Path + +from ...config.settings import cfg +from .security_preflight import PreflightCheck, PreflightResult, add_check as _add + + +def run_secret_checks(result: PreflightResult) -> None: + """Execute all secret-isolation checks.""" + check_admin_cli_isolated(result) + check_no_github_in_runtime(result) + check_bot_credentials(result) + check_admin_secret(result) + check_kv_reachable(result) + check_acs_credential(result) + check_aoai_credential(result) + check_sp_creds_written(result) + + +def check_admin_cli_isolated(result: PreflightResult) -> None: + admin_home = os.environ.get("POLYCLAW_ADMIN_HOME", "/admin-home") + azure_dir = Path(admin_home) / ".azure" + mode = cfg.server_mode.value + + if mode == "admin": + exists = azure_dir.exists() + _add( + result, id="secret_admin_cli_isolated", category="secrets", + name="Admin CLI Session Isolated", + status="pass" if exists else "warn", + detail=( + f"Azure CLI config at {azure_dir}: " + f"{'present' if exists else 'not found'}" + ), + evidence=( + f"HOME={os.environ.get('HOME', '?')}\n" + f"AZURE_CONFIG_DIR={os.environ.get('AZURE_CONFIG_DIR', '?')}\n" + f"exists={exists}" + ), + command=f"os.path.exists({azure_dir})", + ) + elif mode == "runtime": + exists = azure_dir.exists() + _add( + result, id="secret_admin_cli_isolated", category="secrets", + name="Admin CLI Session Isolated", + status="pass" if not exists else "fail", + detail=( + "Admin CLI config not accessible from runtime" + if not exists + else f"RISK: Admin CLI config accessible at {azure_dir}" + ), + evidence=( + f"HOME={os.environ.get('HOME', '?')}\n" + f"{azure_dir} exists={exists}" + ), + command=f"os.path.exists({azure_dir})", + ) + else: + _add( + result, id="secret_admin_cli_isolated", category="secrets", + name="Admin CLI Session Isolated", + status="warn", + detail=( + "Combined mode -- admin and runtime share the same " + "container (no credential isolation)" + ), + evidence=f"POLYCLAW_SERVER_MODE={mode}", + command="cfg.server_mode", + ) + + +def check_no_github_in_runtime(result: PreflightResult) -> None: + env_data = cfg.env.read_all() + gh_token = env_data.get("GITHUB_TOKEN", "") + gh2 = env_data.get("GH_TOKEN", "") + mode = cfg.server_mode.value + + if mode == "runtime": + has = bool(gh_token or gh2) + _add( + result, id="secret_no_github_runtime", category="secrets", + name="No GitHub Token in Runtime", + status="fail" if has else "pass", + detail=( + "GitHub token NOT present in runtime environment" + if not has + else "RISK: GitHub token accessible in runtime env" + ), + evidence=( + f"GITHUB_TOKEN={'set (' + str(len(gh_token)) + ' chars)' if gh_token else 'empty'}\n" + f"GH_TOKEN={'set' if gh2 else 'empty'}" + ), + command="env: GITHUB_TOKEN, GH_TOKEN", + ) + elif mode == "admin": + has = bool(gh_token or gh2) + _add( + result, id="secret_no_github_runtime", category="secrets", + name="GitHub Token (Admin Only)", + status="pass", + detail=f"GitHub token on admin: {'present' if has else 'not configured'}", + evidence=( + f"GITHUB_TOKEN={'set' if gh_token else 'empty'}\n" + f"GH_TOKEN={'set' if gh2 else 'empty'}" + ), + command="env: GITHUB_TOKEN, GH_TOKEN", + ) + else: + _add( + result, id="secret_no_github_runtime", category="secrets", + name="GitHub Token Isolation", + status="warn", + detail="Combined mode -- GitHub token shared with agent runtime", + evidence=f"POLYCLAW_SERVER_MODE={mode}", + command="cfg.server_mode + env", + ) + + +def check_bot_credentials(result: PreflightResult) -> None: + env_data = cfg.env.read_all() + app_id = env_data.get("BOT_APP_ID", "") + app_pw = env_data.get("BOT_APP_PASSWORD", "") + both = bool(app_id and app_pw) + + _add( + result, id="secret_bot_creds", category="secrets", + name="Bot Credentials Present", + status="pass" if both else ("warn" if app_id else "skip"), + detail=( + f"BOT_APP_ID={'set' if app_id else 'missing'}, " + f"BOT_APP_PASSWORD={'set' if app_pw else 'missing'}" + ), + evidence=( + f"BOT_APP_ID={app_id[:12] + '...' if app_id else '(empty)'}\n" + f"BOT_APP_PASSWORD={'***' if app_pw else '(empty)'}" + ), + command="env: BOT_APP_ID, BOT_APP_PASSWORD", + ) + + +def check_admin_secret(result: PreflightResult) -> None: + secret = cfg.admin_secret + _add( + result, id="secret_admin_secret", category="secrets", + name="Admin Secret Configured", + status="pass" if secret else "fail", + detail=( + f"ADMIN_SECRET set ({len(secret)} chars)" + if secret + else "ADMIN_SECRET MISSING" + ), + evidence=f"ADMIN_SECRET={'***' if secret else '(empty)'}\nlength={len(secret) if secret else 0}", + command="env: ADMIN_SECRET", + ) + + +def check_kv_reachable(result: PreflightResult) -> None: + from ..keyvault import kv as _kv + + if not _kv.enabled: + _add( + result, id="secret_kv_reachable", category="secrets", + name="Key Vault Reachable", + status="skip", + detail="Key Vault not configured", + evidence=f"KEY_VAULT_URL={cfg.env.read('KEY_VAULT_URL') or '(empty)'}", + command="keyvault.enabled", + ) + return + + try: + secrets_list = _kv.list_secrets() + _add( + result, id="secret_kv_reachable", category="secrets", + name="Key Vault Reachable", + status="pass", + detail=f"Key Vault accessible, {len(secrets_list)} secret(s) readable", + evidence=f"url={_kv.url}\nsecrets_count={len(secrets_list)}", + command="keyvault.list_secrets()", + ) + except Exception as exc: + _add( + result, id="secret_kv_reachable", category="secrets", + name="Key Vault Reachable", + status="fail", + detail=f"Key Vault NOT reachable: {exc}", + evidence=f"url={_kv.url}\nerror={exc}", + command="keyvault.list_secrets()", + ) + + +def check_acs_credential(result: PreflightResult) -> None: + conn = cfg.acs_connection_string + if conn: + parts = { + k.strip().lower(): v.strip() + for k, _, v in (seg.partition("=") for seg in conn.split(";") if "=" in seg) + } + has_ep = bool(parts.get("endpoint")) + _add( + result, id="secret_acs_present", category="secrets", + name="ACS Connection String", + status="pass" if has_ep else "warn", + detail=( + f"ACS connection string " + f"{'well-formed' if has_ep else 'malformed (missing endpoint)'}" + ), + evidence=f"ACS_CONNECTION_STRING=***({len(conn)} chars)\nhas_endpoint={has_ep}", + command="env: ACS_CONNECTION_STRING", + ) + else: + _add( + result, id="secret_acs_present", category="secrets", + name="ACS Connection String", + status="skip", + detail="ACS not configured", + evidence="ACS_CONNECTION_STRING=(empty)", + command="env: ACS_CONNECTION_STRING", + ) + + +def check_aoai_credential(result: PreflightResult) -> None: + endpoint = cfg.azure_openai_endpoint + key = cfg.azure_openai_api_key + + if endpoint: + _add( + result, id="secret_aoai_present", category="secrets", + name="Azure OpenAI Configuration", + status="pass", + detail=f"Endpoint configured, {'API key' if key else 'identity auth'} mode", + evidence=( + f"AZURE_OPENAI_ENDPOINT={endpoint}\n" + f"AZURE_OPENAI_API_KEY={'***' if key else '(identity-auth)'}" + ), + command="env: AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY", + ) + else: + _add( + result, id="secret_aoai_present", category="secrets", + name="Azure OpenAI Configuration", + status="skip", + detail="Azure OpenAI not configured", + evidence="AZURE_OPENAI_ENDPOINT=(empty)", + command="env: AZURE_OPENAI_ENDPOINT", + ) + + +def check_sp_creds_written(result: PreflightResult) -> None: + env_data = cfg.env.read_all() + sp_id = env_data.get("RUNTIME_SP_APP_ID", "") + sp_pw = env_data.get("RUNTIME_SP_PASSWORD", "") + sp_tenant = env_data.get("RUNTIME_SP_TENANT", "") + + if not sp_id: + mi_id = env_data.get("ACA_MI_CLIENT_ID", "") + if mi_id: + _add( + result, id="secret_identity_creds", category="secrets", + name="Runtime Identity Credentials in .env", + status="pass", + detail="Managed identity credentials written to .env", + evidence=( + f"ACA_MI_CLIENT_ID={mi_id}\n" + f"ACA_MI_RESOURCE_ID={env_data.get('ACA_MI_RESOURCE_ID', '?')}" + ), + command="env: ACA_MI_CLIENT_ID, ACA_MI_RESOURCE_ID", + ) + else: + _add( + result, id="secret_identity_creds", category="secrets", + name="Runtime Identity Credentials in .env", + status="skip", + detail="No runtime identity credentials in .env", + evidence="RUNTIME_SP_APP_ID=(empty)\nACA_MI_CLIENT_ID=(empty)", + command="env: RUNTIME_SP_APP_ID, ACA_MI_CLIENT_ID", + ) + return + + all_set = bool(sp_id and sp_pw and sp_tenant) + _add( + result, id="secret_identity_creds", category="secrets", + name="SP Credentials in .env", + status="pass" if all_set else "fail", + detail=( + f"app_id={'set' if sp_id else 'MISSING'}, " + f"password={'set' if sp_pw else 'MISSING'}, " + f"tenant={'set' if sp_tenant else 'MISSING'}" + ), + evidence=( + f"RUNTIME_SP_APP_ID={sp_id[:12] + '...' if sp_id else '(empty)'}\n" + f"RUNTIME_SP_PASSWORD={'***' if sp_pw else '(empty)'}\n" + f"RUNTIME_SP_TENANT={sp_tenant or '(empty)'}" + ), + command="env: RUNTIME_SP_APP_ID, RUNTIME_SP_PASSWORD, RUNTIME_SP_TENANT", + ) diff --git a/app/runtime/services/prompt_shield.py b/app/runtime/services/security/prompt_shield.py similarity index 99% rename from app/runtime/services/prompt_shield.py rename to app/runtime/services/security/prompt_shield.py index 8bbc1bd..b5494e7 100644 --- a/app/runtime/services/prompt_shield.py +++ b/app/runtime/services/security/prompt_shield.py @@ -234,8 +234,6 @@ def __init__(self) -> None: def get_token(self) -> str: """Return a valid bearer token, refreshing if necessary.""" - import time - # Return cached token if still valid (with 5-min buffer) if self._cached_token and time.time() < self._expires_on - 300: return self._cached_token diff --git a/app/runtime/services/security/security_preflight.py b/app/runtime/services/security/security_preflight.py new file mode 100644 index 0000000..be6e82b --- /dev/null +++ b/app/runtime/services/security/security_preflight.py @@ -0,0 +1,154 @@ +"""Security preflight checker -- verifiable runtime identity and secret isolation checks. + +Every check runs a real command or environment inspection and reports evidence. +No static claims -- every assertion is verified at runtime. + +Identity and RBAC checks live in ``preflight_identity``. +Secret-isolation checks live in ``preflight_secrets``. +""" + +from __future__ import annotations + +import logging +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from typing import Any + +from ...config.settings import cfg +from ..cloud.azure import AzureCLI + +logger = logging.getLogger(__name__) + +# Elevated RBAC roles that the runtime identity should never hold. +_ELEVATED_ROLES = frozenset({ + "Owner", + "Contributor", + "User Access Administrator", + "Role Based Access Control Administrator", +}) + +# Type alias for the identity dict passed between preflight check modules. +IdentityInfo = dict[str, Any] + + +@dataclass +class PreflightCheck: + """Result of a single security preflight check.""" + + id: str + category: str + name: str + status: str = "pending" # pending | pass | fail | warn | skip + detail: str = "" + evidence: str = "" + command: str = "" + + +@dataclass +class PreflightResult: + """Aggregated result of all preflight checks.""" + + checks: list[PreflightCheck] = field(default_factory=list) + run_at: str = "" + passed: int = 0 + failed: int = 0 + warnings: int = 0 + skipped: int = 0 + + +def add_check(result: PreflightResult, **kwargs: Any) -> PreflightCheck: + """Create a :class:`PreflightCheck`, append it to *result*, and return it.""" + check = PreflightCheck(**kwargs) + result.checks.append(check) + return check + + +class SecurityPreflightChecker: + """Run verifiable security checks against the runtime identity and secrets.""" + + def __init__(self, az: AzureCLI) -> None: + self._az = az + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def run_all(self) -> PreflightResult: + """Execute all security preflight checks and return evidence.""" + from . import preflight_identity as _id + from . import preflight_rbac as _rbac + from . import preflight_secrets as _sec + + result = PreflightResult(run_at=datetime.now(timezone.utc).isoformat()) + + # Gate: is Azure CLI logged in? + if not _id.check_azure_logged_in(self._az, result): + _id.skip_azure_checks(result) + _sec.run_secret_checks(result) + self._tally(result) + return result + + # Identity verification + identity = _id.check_identity_configured(self._az, result) + if identity: + _id.check_identity_valid(self._az, result, identity) + _id.check_credential_expiry(self._az, result, identity) + + # RBAC verification + assignments = _rbac.check_rbac_list(self._az, result, identity) + if assignments is not None: + bot_rg = cfg.env.read("BOT_RESOURCE_GROUP") or "" + _rbac.check_rbac_has_role( + result, assignments, "Azure Bot Service Contributor Role", + "rbac_bot_contributor", "Azure Bot Service Contributor Role", bot_rg, + ) + _rbac.check_rbac_has_role( + result, assignments, "Reader", + "rbac_reader", "Reader Role", bot_rg, + ) + _rbac.check_rbac_kv_access(result, assignments, identity) + if identity.get("strategy") == "managed_identity": + _rbac.check_rbac_has_role( + result, assignments, "Cognitive Services OpenAI User", + "rbac_aoai_user", "Azure OpenAI Access", + "", + missing_severity="warn", + missing_detail="Needed for identity-auth voice", + ) + _rbac.check_rbac_session_pool(result, assignments) + _rbac.check_rbac_no_elevated(result, assignments) + _rbac.check_rbac_scope_contained(result, assignments) + + # Secret isolation + _sec.run_secret_checks(result) + + self._tally(result) + return result + + @staticmethod + def to_dict(result: PreflightResult) -> dict[str, Any]: + """Serialize a *PreflightResult* to a JSON-safe dict.""" + return { + "checks": [asdict(c) for c in result.checks], + "run_at": result.run_at, + "passed": result.passed, + "failed": result.failed, + "warnings": result.warnings, + "skipped": result.skipped, + } + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _tally(result: PreflightResult) -> None: + for c in result.checks: + if c.status == "pass": + result.passed += 1 + elif c.status == "fail": + result.failed += 1 + elif c.status == "warn": + result.warnings += 1 + elif c.status == "skip": + result.skipped += 1 diff --git a/app/runtime/services/security_preflight.py b/app/runtime/services/security_preflight.py deleted file mode 100644 index 643d902..0000000 --- a/app/runtime/services/security_preflight.py +++ /dev/null @@ -1,918 +0,0 @@ -"""Security preflight checker -- verifiable runtime identity and secret isolation checks. - -Every check runs a real command or environment inspection and reports evidence. -No static claims -- every assertion is verified at runtime. -""" - -from __future__ import annotations - -import logging -import os -from dataclasses import asdict, dataclass, field -from datetime import datetime, timezone -from pathlib import Path -from typing import Any - -from ..config.settings import cfg -from .azure import AzureCLI - -logger = logging.getLogger(__name__) - -# Elevated RBAC roles that the runtime identity should never hold. -_ELEVATED_ROLES = frozenset({ - "Owner", - "Contributor", - "User Access Administrator", - "Role Based Access Control Administrator", -}) - - -@dataclass -class PreflightCheck: - """Result of a single security preflight check.""" - - id: str - category: str - name: str - status: str = "pending" # pending | pass | fail | warn | skip - detail: str = "" - evidence: str = "" - command: str = "" - - -@dataclass -class PreflightResult: - """Aggregated result of all preflight checks.""" - - checks: list[PreflightCheck] = field(default_factory=list) - run_at: str = "" - passed: int = 0 - failed: int = 0 - warnings: int = 0 - skipped: int = 0 - - -class SecurityPreflightChecker: - """Run verifiable security checks against the runtime identity and secrets.""" - - def __init__(self, az: AzureCLI) -> None: - self._az = az - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - def run_all(self) -> PreflightResult: - """Execute all security preflight checks and return evidence.""" - result = PreflightResult(run_at=datetime.now(timezone.utc).isoformat()) - - # Gate: is Azure CLI logged in? - if not self._check_azure_logged_in(result): - self._skip_azure_checks(result) - self._run_secret_checks(result) - self._tally(result) - return result - - # Identity verification - identity = self._check_identity_configured(result) - if identity: - self._check_identity_valid(result, identity) - self._check_credential_expiry(result, identity) - - # RBAC verification - assignments = self._check_rbac_list(result, identity) - if assignments is not None: - bot_rg = cfg.env.read("BOT_RESOURCE_GROUP") or "" - self._check_rbac_has_role( - result, assignments, "Azure Bot Service Contributor Role", - "rbac_bot_contributor", "Azure Bot Service Contributor Role", bot_rg, - ) - self._check_rbac_has_role( - result, assignments, "Reader", - "rbac_reader", "Reader Role", bot_rg, - ) - self._check_rbac_kv_access(result, assignments, identity) - if identity.get("strategy") == "managed_identity": - self._check_rbac_has_role( - result, assignments, "Cognitive Services OpenAI User", - "rbac_aoai_user", "Azure OpenAI Access", - "", - missing_severity="warn", - missing_detail="Needed for identity-auth voice", - ) - self._check_rbac_session_pool(result, assignments) - self._check_rbac_no_elevated(result, assignments) - self._check_rbac_scope_contained(result, assignments) - - # Secret isolation - self._run_secret_checks(result) - - self._tally(result) - return result - - @staticmethod - def to_dict(result: PreflightResult) -> dict[str, Any]: - """Serialize a *PreflightResult* to a JSON-safe dict.""" - return { - "checks": [asdict(c) for c in result.checks], - "run_at": result.run_at, - "passed": result.passed, - "failed": result.failed, - "warnings": result.warnings, - "skipped": result.skipped, - } - - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - - @staticmethod - def _tally(result: PreflightResult) -> None: - for c in result.checks: - if c.status == "pass": - result.passed += 1 - elif c.status == "fail": - result.failed += 1 - elif c.status == "warn": - result.warnings += 1 - elif c.status == "skip": - result.skipped += 1 - - def _add(self, result: PreflightResult, **kwargs: Any) -> PreflightCheck: - check = PreflightCheck(**kwargs) - result.checks.append(check) - return check - - # ------------------------------------------------------------------ - # Azure login gate - # ------------------------------------------------------------------ - - def _check_azure_logged_in(self, result: PreflightResult) -> bool: - cmd = "az account show" - account = self._az.json("account", "show", quiet=True) - if isinstance(account, dict) and account.get("id"): - sub = account.get("name", account.get("id", "?")) - self._add( - result, id="azure_logged_in", category="identity", - name="Azure CLI Authenticated", - status="pass", - detail=f"Logged in to subscription: {sub}", - evidence=f"subscription={sub}\ntenantId={account.get('tenantId', '?')}", - command=cmd, - ) - return True - self._add( - result, id="azure_logged_in", category="identity", - name="Azure CLI Authenticated", - status="fail", - detail="Not logged in -- RBAC and identity checks require Azure CLI auth", - evidence=self._az.last_stderr or "No response", - command=cmd, - ) - return False - - def _skip_azure_checks(self, result: PreflightResult) -> None: - for check_id, name, cat in [ - ("identity_configured", "Runtime Identity Configured", "identity"), - ("identity_valid", "Identity Exists in Azure AD", "identity"), - ("identity_credential_expiry", "Credential Expiry", "identity"), - ("rbac_assignments_list", "RBAC Assignments", "rbac"), - ("rbac_bot_contributor", "Azure Bot Service Contributor Role", "rbac"), - ("rbac_reader", "Reader Role", "rbac"), - ("rbac_kv_access", "Key Vault Access Role", "rbac"), - ("rbac_session_pool", "Session Pool Executor", "rbac"), - ("rbac_no_elevated", "No Elevated Roles", "rbac"), - ("rbac_scope_contained", "Scope Limited to Resource Group", "rbac"), - ]: - self._add( - result, id=check_id, category=cat, name=name, - status="skip", - detail="Skipped -- Azure CLI not authenticated", - command="", - ) - - # ------------------------------------------------------------------ - # Identity checks - # ------------------------------------------------------------------ - - def _check_identity_configured( - self, result: PreflightResult, - ) -> dict[str, Any] | None: - sp_app_id = cfg.env.read("RUNTIME_SP_APP_ID") - mi_client_id = cfg.env.read("ACA_MI_CLIENT_ID") - mi_resource_id = cfg.env.read("ACA_MI_RESOURCE_ID") - - if mi_client_id: - self._add( - result, id="identity_configured", category="identity", - name="Runtime Identity Configured", - status="pass", - detail=f"User-assigned managed identity: client_id={mi_client_id}", - evidence=( - f"ACA_MI_CLIENT_ID={mi_client_id}\n" - f"ACA_MI_RESOURCE_ID={mi_resource_id}" - ), - command="env: ACA_MI_CLIENT_ID, ACA_MI_RESOURCE_ID", - ) - return { - "strategy": "managed_identity", - "client_id": mi_client_id, - "resource_id": mi_resource_id, - "assignee": mi_client_id, - } - - if sp_app_id: - sp_tenant = cfg.env.read("RUNTIME_SP_TENANT") - has_pw = bool(cfg.env.read("RUNTIME_SP_PASSWORD")) - self._add( - result, id="identity_configured", category="identity", - name="Runtime Identity Configured", - status="pass", - detail=f"Scoped service principal: app_id={sp_app_id}", - evidence=( - f"RUNTIME_SP_APP_ID={sp_app_id}\n" - f"RUNTIME_SP_TENANT={sp_tenant}\n" - f"RUNTIME_SP_PASSWORD={'***' if has_pw else 'MISSING'}" - ), - command="env: RUNTIME_SP_APP_ID, RUNTIME_SP_TENANT, RUNTIME_SP_PASSWORD", - ) - return { - "strategy": "sp", - "app_id": sp_app_id, - "tenant": sp_tenant, - "assignee": sp_app_id, - } - - self._add( - result, id="identity_configured", category="identity", - name="Runtime Identity Configured", - status="skip", - detail="No runtime identity configured (RUNTIME_SP_* and ACA_MI_* absent)", - evidence="RUNTIME_SP_APP_ID=(empty)\nACA_MI_CLIENT_ID=(empty)", - command="env: RUNTIME_SP_APP_ID, ACA_MI_CLIENT_ID", - ) - return None - - def _check_identity_valid( - self, result: PreflightResult, info: dict[str, Any], - ) -> None: - if info["strategy"] == "sp": - app_id = info["app_id"] - cmd = f"az ad sp show --id {app_id}" - sp = self._az.json("ad", "sp", "show", "--id", app_id) - if isinstance(sp, dict) and sp.get("appId"): - display = sp.get("displayName", "?") - self._add( - result, id="identity_valid", category="identity", - name="Service Principal Exists in Azure AD", - status="pass", - detail=f"{display} ({app_id})", - evidence=( - f"displayName={display}\n" - f"appId={app_id}\n" - f"objectId={sp.get('id', '?')}" - ), - command=cmd, - ) - else: - self._add( - result, id="identity_valid", category="identity", - name="Service Principal Exists in Azure AD", - status="fail", - detail=f"SP not found: {app_id}", - evidence=self._az.last_stderr or "No response", - command=cmd, - ) - else: - resource_id = info.get("resource_id", "") - if not resource_id: - self._add( - result, id="identity_valid", category="identity", - name="Managed Identity Exists", - status="skip", detail="No MI resource ID configured", - command="", - ) - return - cmd = f"az identity show --ids {resource_id}" - mi = self._az.json("identity", "show", "--ids", resource_id) - if isinstance(mi, dict) and mi.get("clientId"): - self._add( - result, id="identity_valid", category="identity", - name="Managed Identity Exists", - status="pass", - detail=f"{mi.get('name', '?')} (client={mi.get('clientId', '?')})", - evidence=( - f"name={mi.get('name', '?')}\n" - f"clientId={mi.get('clientId', '?')}\n" - f"principalId={mi.get('principalId', '?')}" - ), - command=cmd, - ) - else: - self._add( - result, id="identity_valid", category="identity", - name="Managed Identity Exists", - status="fail", - detail=f"MI not found: {resource_id}", - evidence=self._az.last_stderr or "No response", - command=cmd, - ) - - def _check_credential_expiry( - self, result: PreflightResult, info: dict[str, Any], - ) -> None: - if info["strategy"] != "sp": - self._add( - result, id="identity_credential_expiry", category="identity", - name="Credential Expiry", - status="pass", - detail="Managed identities do not have expiring credentials", - command="(not applicable for MI)", - ) - return - - app_id = info["app_id"] - cmd = f"az ad app credential list --id {app_id}" - creds = self._az.json("ad", "app", "credential", "list", "--id", app_id) - if not isinstance(creds, list) or not creds: - self._add( - result, id="identity_credential_expiry", category="identity", - name="Credential Expiry", - status="warn", - detail="Could not retrieve credential list", - evidence=self._az.last_stderr or "Empty response", - command=cmd, - ) - return - - latest = max(creds, key=lambda c: c.get("endDateTime", "")) - end = latest.get("endDateTime", "") - now = datetime.now(timezone.utc).isoformat() - - if end and end > now: - self._add( - result, id="identity_credential_expiry", category="identity", - name="Credential Expiry", - status="pass", - detail=f"Valid until {end}", - evidence=f"endDateTime={end}\nnow={now}\ncredentials_count={len(creds)}", - command=cmd, - ) - else: - self._add( - result, id="identity_credential_expiry", category="identity", - name="Credential Expiry", - status="fail", - detail=f"Credential EXPIRED: {end}", - evidence=f"endDateTime={end}\nnow={now}", - command=cmd, - ) - - # ------------------------------------------------------------------ - # RBAC checks - # ------------------------------------------------------------------ - - def _check_rbac_list( - self, result: PreflightResult, info: dict[str, Any], - ) -> list[dict[str, Any]] | None: - assignee = info.get("assignee", "") - if not assignee: - return None - - cmd = f"az role assignment list --assignee {assignee} --all" - assignments = self._az.json( - "role", "assignment", "list", "--assignee", assignee, "--all", - ) - if not isinstance(assignments, list): - self._add( - result, id="rbac_assignments_list", category="rbac", - name="RBAC Assignments Retrieved", - status="fail", - detail="Could not list RBAC assignments", - evidence=self._az.last_stderr or "No response", - command=cmd, - ) - return None - - summary = ", ".join( - f"{a.get('roleDefinitionName', '?')} @ " - f"{a.get('scope', '?').rsplit('/', 1)[-1]}" - for a in assignments - ) - self._add( - result, id="rbac_assignments_list", category="rbac", - name="RBAC Assignments Retrieved", - status="pass", - detail=f"{len(assignments)} assignment(s): {summary}", - evidence="\n".join( - f"- {a.get('roleDefinitionName', '?')} on {a.get('scope', '?')}" - for a in assignments - ), - command=cmd, - ) - return assignments - - def _check_rbac_has_role( - self, - result: PreflightResult, - assignments: list[dict[str, Any]], - role_name: str, - check_id: str, - check_name: str, - bot_rg: str, - *, - missing_severity: str = "fail", - missing_detail: str = "", - ) -> None: - matching = [ - a for a in assignments - if a.get("roleDefinitionName") == role_name - ] - if matching: - scopes = [a.get("scope", "") for a in matching] - self._add( - result, id=check_id, category="rbac", - name=check_name, - status="pass", - detail=f"{role_name} assigned ({len(matching)} assignment(s))", - evidence="\n".join(f"scope={s}" for s in scopes), - command=f"Filtered from role assignment list", - ) - else: - detail = missing_detail or f"{role_name} NOT found in assignments" - self._add( - result, id=check_id, category="rbac", - name=check_name, - status=missing_severity, - detail=detail, - evidence=( - f"Expected '{role_name}' but not present " - f"in {len(assignments)} assignment(s)" - ), - command=f"Filtered from role assignment list", - ) - - def _check_rbac_kv_access( - self, - result: PreflightResult, - assignments: list[dict[str, Any]], - info: dict[str, Any], - ) -> None: - kv_roles = [ - a for a in assignments - if "key vault" in (a.get("roleDefinitionName") or "").lower() - ] - - if not kv_roles: - self._add( - result, id="rbac_kv_access", category="rbac", - name="Key Vault Access Role", - status="warn", - detail="No Key Vault role assignment found", - evidence=f"Checked {len(assignments)} assignments for 'Key Vault' roles", - command="Filtered from role assignment list", - ) - return - - role_names = [a.get("roleDefinitionName", "?") for a in kv_roles] - has_officer = "Key Vault Secrets Officer" in role_names - has_user = "Key Vault Secrets User" in role_names - - if info["strategy"] == "managed_identity": - if has_user and not has_officer: - status = "pass" - detail = "Key Vault Secrets User (read-only) -- correct for MI" - elif has_officer: - status = "warn" - detail = ( - "Key Vault Secrets Officer (read+write) -- " - "consider restricting to Secrets User for runtime" - ) - else: - status = "pass" - detail = f"Key Vault role: {', '.join(role_names)}" - else: - # SP may legitimately have Officer (used during provisioning) - status = "pass" - detail = f"Key Vault role: {', '.join(role_names)}" - - self._add( - result, id="rbac_kv_access", category="rbac", - name="Key Vault Access Role", - status=status, - detail=detail, - evidence="\n".join( - f"- {a.get('roleDefinitionName', '?')} on {a.get('scope', '?')}" - for a in kv_roles - ), - command="Filtered from role assignment list", - ) - - def _check_rbac_session_pool( - self, result: PreflightResult, assignments: list[dict[str, Any]], - ) -> None: - from ..state.sandbox_config import SandboxConfigStore - - try: - sandbox_store = SandboxConfigStore() - sandbox_enabled = sandbox_store.enabled - sandbox_configured = sandbox_store.is_provisioned - except Exception: - sandbox_enabled = False - sandbox_configured = False - - matching = [ - a for a in assignments - if "session" in (a.get("roleDefinitionName") or "").lower() - ] - if matching: - names = [a.get("roleDefinitionName", "?") for a in matching] - self._add( - result, id="rbac_session_pool", category="rbac", - name="Session Pool Executor", - status="pass", - detail=f"Session role: {', '.join(names)}", - evidence="\n".join( - f"scope={a.get('scope', '?')}" for a in matching - ), - command="Filtered from role assignment list", - ) - elif sandbox_enabled or sandbox_configured: - self._add( - result, id="rbac_session_pool", category="rbac", - name="Session Pool Executor", - status="fail", - detail=( - "Azure ContainerApps Session Executor NOT found -- " - "required for sandbox (HTTP 403 on file upload/execute)" - ), - evidence=f"Not present in {len(assignments)} assignment(s)", - command="Filtered from role assignment list", - ) - else: - self._add( - result, id="rbac_session_pool", category="rbac", - name="Session Pool Executor", - status="warn", - detail="ContainerApps Session Executor NOT found (needed if sandbox is enabled)", - evidence=f"Not present in {len(assignments)} assignment(s)", - command="Filtered from role assignment list", - ) - - def _check_rbac_no_elevated( - self, result: PreflightResult, assignments: list[dict[str, Any]], - ) -> None: - elevated = [ - a for a in assignments - if a.get("roleDefinitionName") in _ELEVATED_ROLES - ] - if not elevated: - self._add( - result, id="rbac_no_elevated", category="rbac", - name="No Elevated Roles", - status="pass", - detail="No Owner, Contributor, or User Access Administrator roles", - evidence=( - f"Checked {len(assignments)} assignment(s) against: " - f"{', '.join(sorted(_ELEVATED_ROLES))}" - ), - command="Filtered from role assignment list", - ) - else: - self._add( - result, id="rbac_no_elevated", category="rbac", - name="No Elevated Roles", - status="fail", - detail=( - f"ELEVATED roles found: " - f"{', '.join(a.get('roleDefinitionName', '?') for a in elevated)}" - ), - evidence="\n".join( - f"- {a.get('roleDefinitionName', '?')} on {a.get('scope', '?')}" - for a in elevated - ), - command="Filtered from role assignment list", - ) - - def _check_rbac_scope_contained( - self, result: PreflightResult, assignments: list[dict[str, Any]], - ) -> None: - out_of_scope = [ - a for a in assignments - if "/resourcegroups/" not in (a.get("scope") or "").lower() - ] - if not out_of_scope: - self._add( - result, id="rbac_scope_contained", category="rbac", - name="Scope Limited to Resource Group", - status="pass", - detail=( - f"All {len(assignments)} assignment(s) scoped to " - f"resource group level or below" - ), - evidence="\n".join( - f"- {a.get('scope', '?')}" for a in assignments - ) if assignments else "No assignments", - command="Scope analysis from role assignment list", - ) - else: - self._add( - result, id="rbac_scope_contained", category="rbac", - name="Scope Limited to Resource Group", - status="fail", - detail=( - f"{len(out_of_scope)} assignment(s) at subscription or management " - f"group level" - ), - evidence="\n".join( - f"- {a.get('roleDefinitionName', '?')} at {a.get('scope', '?')}" - for a in out_of_scope - ), - command="Scope analysis from role assignment list", - ) - - # ------------------------------------------------------------------ - # Secret isolation checks - # ------------------------------------------------------------------ - - def _run_secret_checks(self, result: PreflightResult) -> None: - self._check_admin_cli_isolated(result) - self._check_no_github_in_runtime(result) - self._check_bot_credentials(result) - self._check_admin_secret(result) - self._check_kv_reachable(result) - self._check_acs_credential(result) - self._check_aoai_credential(result) - self._check_sp_creds_written(result) - - def _check_admin_cli_isolated(self, result: PreflightResult) -> None: - admin_home = os.environ.get("POLYCLAW_ADMIN_HOME", "/admin-home") - azure_dir = Path(admin_home) / ".azure" - mode = cfg.server_mode.value - - if mode == "admin": - exists = azure_dir.exists() - self._add( - result, id="secret_admin_cli_isolated", category="secrets", - name="Admin CLI Session Isolated", - status="pass" if exists else "warn", - detail=( - f"Azure CLI config at {azure_dir}: " - f"{'present' if exists else 'not found'}" - ), - evidence=( - f"HOME={os.environ.get('HOME', '?')}\n" - f"AZURE_CONFIG_DIR={os.environ.get('AZURE_CONFIG_DIR', '?')}\n" - f"exists={exists}" - ), - command=f"os.path.exists({azure_dir})", - ) - elif mode == "runtime": - exists = azure_dir.exists() - self._add( - result, id="secret_admin_cli_isolated", category="secrets", - name="Admin CLI Session Isolated", - status="pass" if not exists else "fail", - detail=( - "Admin CLI config not accessible from runtime" - if not exists - else f"RISK: Admin CLI config accessible at {azure_dir}" - ), - evidence=( - f"HOME={os.environ.get('HOME', '?')}\n" - f"{azure_dir} exists={exists}" - ), - command=f"os.path.exists({azure_dir})", - ) - else: - self._add( - result, id="secret_admin_cli_isolated", category="secrets", - name="Admin CLI Session Isolated", - status="warn", - detail=( - "Combined mode -- admin and runtime share the same " - "container (no credential isolation)" - ), - evidence=f"POLYCLAW_SERVER_MODE={mode}", - command="cfg.server_mode", - ) - - def _check_no_github_in_runtime(self, result: PreflightResult) -> None: - env_data = cfg.env.read_all() - gh_token = env_data.get("GITHUB_TOKEN", "") - gh2 = env_data.get("GH_TOKEN", "") - mode = cfg.server_mode.value - - if mode == "runtime": - has = bool(gh_token or gh2) - self._add( - result, id="secret_no_github_runtime", category="secrets", - name="No GitHub Token in Runtime", - status="fail" if has else "pass", - detail=( - "GitHub token NOT present in runtime environment" - if not has - else "RISK: GitHub token accessible in runtime env" - ), - evidence=( - f"GITHUB_TOKEN={'set (' + str(len(gh_token)) + ' chars)' if gh_token else 'empty'}\n" - f"GH_TOKEN={'set' if gh2 else 'empty'}" - ), - command="env: GITHUB_TOKEN, GH_TOKEN", - ) - elif mode == "admin": - has = bool(gh_token or gh2) - self._add( - result, id="secret_no_github_runtime", category="secrets", - name="GitHub Token (Admin Only)", - status="pass", - detail=f"GitHub token on admin: {'present' if has else 'not configured'}", - evidence=( - f"GITHUB_TOKEN={'set' if gh_token else 'empty'}\n" - f"GH_TOKEN={'set' if gh2 else 'empty'}" - ), - command="env: GITHUB_TOKEN, GH_TOKEN", - ) - else: - self._add( - result, id="secret_no_github_runtime", category="secrets", - name="GitHub Token Isolation", - status="warn", - detail="Combined mode -- GitHub token shared with agent runtime", - evidence=f"POLYCLAW_SERVER_MODE={mode}", - command="cfg.server_mode + env", - ) - - def _check_bot_credentials(self, result: PreflightResult) -> None: - env_data = cfg.env.read_all() - app_id = env_data.get("BOT_APP_ID", "") - app_pw = env_data.get("BOT_APP_PASSWORD", "") - both = bool(app_id and app_pw) - - self._add( - result, id="secret_bot_creds", category="secrets", - name="Bot Credentials Present", - status="pass" if both else ("warn" if app_id else "skip"), - detail=( - f"BOT_APP_ID={'set' if app_id else 'missing'}, " - f"BOT_APP_PASSWORD={'set' if app_pw else 'missing'}" - ), - evidence=( - f"BOT_APP_ID={app_id[:12] + '...' if app_id else '(empty)'}\n" - f"BOT_APP_PASSWORD={'***' if app_pw else '(empty)'}" - ), - command="env: BOT_APP_ID, BOT_APP_PASSWORD", - ) - - def _check_admin_secret(self, result: PreflightResult) -> None: - secret = cfg.admin_secret - self._add( - result, id="secret_admin_secret", category="secrets", - name="Admin Secret Configured", - status="pass" if secret else "fail", - detail=( - f"ADMIN_SECRET set ({len(secret)} chars)" - if secret - else "ADMIN_SECRET MISSING" - ), - evidence=f"ADMIN_SECRET={'***' if secret else '(empty)'}\nlength={len(secret) if secret else 0}", - command="env: ADMIN_SECRET", - ) - - def _check_kv_reachable(self, result: PreflightResult) -> None: - from ..services.keyvault import kv as _kv - - if not _kv.enabled: - self._add( - result, id="secret_kv_reachable", category="secrets", - name="Key Vault Reachable", - status="skip", - detail="Key Vault not configured", - evidence=f"KEY_VAULT_URL={cfg.env.read('KEY_VAULT_URL') or '(empty)'}", - command="keyvault.enabled", - ) - return - - try: - secrets_list = _kv.list_secrets() - self._add( - result, id="secret_kv_reachable", category="secrets", - name="Key Vault Reachable", - status="pass", - detail=f"Key Vault accessible, {len(secrets_list)} secret(s) readable", - evidence=f"url={_kv.url}\nsecrets_count={len(secrets_list)}", - command="keyvault.list_secrets()", - ) - except Exception as exc: - self._add( - result, id="secret_kv_reachable", category="secrets", - name="Key Vault Reachable", - status="fail", - detail=f"Key Vault NOT reachable: {exc}", - evidence=f"url={_kv.url}\nerror={exc}", - command="keyvault.list_secrets()", - ) - - def _check_acs_credential(self, result: PreflightResult) -> None: - conn = cfg.acs_connection_string - if conn: - parts = { - k.strip().lower(): v.strip() - for k, _, v in (seg.partition("=") for seg in conn.split(";") if "=" in seg) - } - has_ep = bool(parts.get("endpoint")) - self._add( - result, id="secret_acs_present", category="secrets", - name="ACS Connection String", - status="pass" if has_ep else "warn", - detail=( - f"ACS connection string " - f"{'well-formed' if has_ep else 'malformed (missing endpoint)'}" - ), - evidence=f"ACS_CONNECTION_STRING=***({len(conn)} chars)\nhas_endpoint={has_ep}", - command="env: ACS_CONNECTION_STRING", - ) - else: - self._add( - result, id="secret_acs_present", category="secrets", - name="ACS Connection String", - status="skip", - detail="ACS not configured", - evidence="ACS_CONNECTION_STRING=(empty)", - command="env: ACS_CONNECTION_STRING", - ) - - def _check_aoai_credential(self, result: PreflightResult) -> None: - endpoint = cfg.azure_openai_endpoint - key = cfg.azure_openai_api_key - - if endpoint: - self._add( - result, id="secret_aoai_present", category="secrets", - name="Azure OpenAI Configuration", - status="pass", - detail=f"Endpoint configured, {'API key' if key else 'identity auth'} mode", - evidence=( - f"AZURE_OPENAI_ENDPOINT={endpoint}\n" - f"AZURE_OPENAI_API_KEY={'***' if key else '(identity-auth)'}" - ), - command="env: AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY", - ) - else: - self._add( - result, id="secret_aoai_present", category="secrets", - name="Azure OpenAI Configuration", - status="skip", - detail="Azure OpenAI not configured", - evidence="AZURE_OPENAI_ENDPOINT=(empty)", - command="env: AZURE_OPENAI_ENDPOINT", - ) - - def _check_sp_creds_written(self, result: PreflightResult) -> None: - env_data = cfg.env.read_all() - sp_id = env_data.get("RUNTIME_SP_APP_ID", "") - sp_pw = env_data.get("RUNTIME_SP_PASSWORD", "") - sp_tenant = env_data.get("RUNTIME_SP_TENANT", "") - - if not sp_id: - mi_id = env_data.get("ACA_MI_CLIENT_ID", "") - if mi_id: - self._add( - result, id="secret_identity_creds", category="secrets", - name="Runtime Identity Credentials in .env", - status="pass", - detail="Managed identity credentials written to .env", - evidence=( - f"ACA_MI_CLIENT_ID={mi_id}\n" - f"ACA_MI_RESOURCE_ID={env_data.get('ACA_MI_RESOURCE_ID', '?')}" - ), - command="env: ACA_MI_CLIENT_ID, ACA_MI_RESOURCE_ID", - ) - else: - self._add( - result, id="secret_identity_creds", category="secrets", - name="Runtime Identity Credentials in .env", - status="skip", - detail="No runtime identity credentials in .env", - evidence="RUNTIME_SP_APP_ID=(empty)\nACA_MI_CLIENT_ID=(empty)", - command="env: RUNTIME_SP_APP_ID, ACA_MI_CLIENT_ID", - ) - return - - all_set = bool(sp_id and sp_pw and sp_tenant) - self._add( - result, id="secret_identity_creds", category="secrets", - name="SP Credentials in .env", - status="pass" if all_set else "fail", - detail=( - f"app_id={'set' if sp_id else 'MISSING'}, " - f"password={'set' if sp_pw else 'MISSING'}, " - f"tenant={'set' if sp_tenant else 'MISSING'}" - ), - evidence=( - f"RUNTIME_SP_APP_ID={sp_id[:12] + '...' if sp_id else '(empty)'}\n" - f"RUNTIME_SP_PASSWORD={'***' if sp_pw else '(empty)'}\n" - f"RUNTIME_SP_TENANT={sp_tenant or '(empty)'}" - ), - command="env: RUNTIME_SP_APP_ID, RUNTIME_SP_PASSWORD, RUNTIME_SP_TENANT", - ) diff --git a/app/runtime/state/__init__.py b/app/runtime/state/__init__.py index 65e4418..08e6ce6 100644 --- a/app/runtime/state/__init__.py +++ b/app/runtime/state/__init__.py @@ -1,5 +1,8 @@ """Persistent state stores backed by JSON files.""" +from __future__ import annotations + +from ._base import BaseConfigStore from .deploy_state import DeployStateStore, DeploymentRecord from .foundry_iq_config import FoundryIQConfigStore from .infra_config import InfraConfigStore @@ -13,6 +16,7 @@ from .tool_activity_store import ToolActivityStore, get_tool_activity_store __all__ = [ + "BaseConfigStore", "DeployStateStore", "DeploymentRecord", "FoundryIQConfigStore", diff --git a/app/runtime/state/_base.py b/app/runtime/state/_base.py new file mode 100644 index 0000000..21df52b --- /dev/null +++ b/app/runtime/state/_base.py @@ -0,0 +1,127 @@ +"""Base class for dataclass-backed JSON config stores.""" + +from __future__ import annotations + +import json +import logging +from dataclasses import asdict +from pathlib import Path +from typing import Any, Generic, TypeVar + +from ..config.settings import cfg + +logger = logging.getLogger(__name__) + +C = TypeVar("C") + + +class BaseConfigStore(Generic[C]): + """JSON-file-backed config store using a dataclass for schema. + + Subclasses must set class variables: + + - ``_config_type``: the dataclass class used for config schema + - ``_default_filename``: default JSON filename inside ``cfg.data_dir`` + + Optional class variables: + + - ``_log_label``: human label used in warning messages (defaults to filename) + + Override ``_apply_raw`` to customise how JSON fields are mapped onto the + config dataclass (e.g. secret resolution). Override ``_save_data`` to + customise the dict that is serialised to disk (e.g. secret storage). + """ + + _config_type: type[C] + _default_filename: str + _log_label: str = "" + _SECRET_FIELDS: frozenset[str] = frozenset() + _secret_prefix: str = "" + + def __init__(self, path: Path | None = None) -> None: + self._path = path or (cfg.data_dir / self._default_filename) + self._config: C = self._config_type() + self._load() + + @property + def path(self) -> Path: + return self._path + + @property + def config(self) -> C: + return self._config + + def to_dict(self) -> dict[str, Any]: + """Return the config as a plain dict.""" + return asdict(self._config) + + # -- persistence ------------------------------------------------------- + + def _load(self) -> None: + if not self._path.exists(): + return + try: + raw = json.loads(self._path.read_text()) + self._apply_raw(raw) + except Exception as exc: + label = self._log_label or self._default_filename + logger.warning( + "Failed to load %s from %s: %s", label, self._path, exc, exc_info=True, + ) + + def _apply_raw(self, raw: dict[str, Any]) -> None: + """Populate config fields from a raw JSON dict. + + Default implementation sets every dataclass field found in *raw*. + Override for custom deserialisation (e.g. secret resolution). + """ + for field_name in self._config_type.__dataclass_fields__: + if field_name in raw: + setattr(self._config, field_name, raw[field_name]) + + def _save(self) -> None: + self._path.parent.mkdir(parents=True, exist_ok=True) + self._path.write_text(json.dumps(self._save_data(), indent=2) + "\n") + + def _save_data(self) -> dict[str, Any]: + """Return the data dict to serialise. + + Default implementation returns ``dataclasses.asdict(self._config)``. + Override for custom serialisation (e.g. secret storage). + """ + return asdict(self._config) + + # -- secret helpers ---------------------------------------------------- + + def _store_secrets(self, data: dict[str, Any]) -> dict[str, Any]: + """Replace secret fields with Key Vault references before persisting. + + Only operates when ``_SECRET_FIELDS`` is non-empty and Key Vault is + enabled. Uses ``_secret_prefix`` to namespace the secret names. + """ + from ..services.keyvault import env_key_to_secret_name, is_kv_ref, kv + + result = dict(data) + if not kv.enabled or not self._SECRET_FIELDS: + return result + prefix = self._secret_prefix + for k in self._SECRET_FIELDS: + val = result.get(k, "") + if val and not is_kv_ref(val): + try: + ref = kv.store(env_key_to_secret_name(f"{prefix}{k}"), val) + result[k] = ref + except Exception as exc: + logger.warning( + "Failed to store secret %s in KV: %s", k, exc, exc_info=True, + ) + return result + + @staticmethod + def _resolve_secret(value: Any) -> Any: + """Resolve a possible Key Vault reference back to its plaintext.""" + if not isinstance(value, str): + return value + from ..services.keyvault import resolve_if_kv_ref + + return resolve_if_kv_ref(value) diff --git a/app/runtime/state/_json_store.py b/app/runtime/state/_json_store.py index be448af..575c79b 100644 --- a/app/runtime/state/_json_store.py +++ b/app/runtime/state/_json_store.py @@ -29,7 +29,7 @@ def load(self) -> Any: try: return json.loads(self._path.read_text()) except (json.JSONDecodeError, OSError) as exc: - logger.warning("Failed to load %s: %s", self._path, exc) + logger.warning("Failed to load %s: %s", self._path, exc, exc_info=True) return self._default_copy() def save(self, data: Any) -> None: diff --git a/app/runtime/state/deploy_state.py b/app/runtime/state/deploy_state.py index 992df9b..2e53003 100644 --- a/app/runtime/state/deploy_state.py +++ b/app/runtime/state/deploy_state.py @@ -129,17 +129,26 @@ def by_kind(self, kind: str) -> list[DeploymentRecord]: return [d for d in self._deployments.values() if d.kind == kind] def current_local(self) -> DeploymentRecord | None: - local = [d for d in self._deployments.values() if d.kind == "local" and d.status == "active"] + local = [ + d for d in self._deployments.values() + if d.kind == "local" and d.status == "active" + ] return max(local, key=lambda d: d.updated_at) if local else None def current_aca(self) -> DeploymentRecord | None: - aca = [d for d in self._deployments.values() if d.kind == "aca" and d.status == "active"] + aca = [ + d for d in self._deployments.values() + if d.kind == "aca" and d.status == "active" + ] return max(aca, key=lambda d: d.updated_at) if aca else None def register(self, record: DeploymentRecord) -> None: self._deployments[record.deploy_id] = record self._save() - logger.info("Registered deployment %s (kind=%s, tag=%s)", record.deploy_id, record.kind, record.tag) + logger.info( + "Registered deployment %s (kind=%s, tag=%s)", + record.deploy_id, record.kind, record.tag, + ) def update(self, record: DeploymentRecord) -> None: record.touch() @@ -193,7 +202,9 @@ def _load(self) -> None: rec.resources = resources self._deployments[did] = rec except Exception as exc: - logger.warning("Failed to load deploy state from %s: %s", self._path, exc) + logger.warning( + "Failed to load deploy state from %s: %s", self._path, exc, exc_info=True, + ) def _save(self) -> None: self._path.parent.mkdir(parents=True, exist_ok=True) diff --git a/app/runtime/state/foundry_iq_config.py b/app/runtime/state/foundry_iq_config.py index 53a4244..d3df172 100644 --- a/app/runtime/state/foundry_iq_config.py +++ b/app/runtime/state/foundry_iq_config.py @@ -2,13 +2,11 @@ from __future__ import annotations -import json import logging from dataclasses import asdict, dataclass -from pathlib import Path from typing import Any -from ..config.settings import cfg +from ._base import BaseConfigStore logger = logging.getLogger(__name__) @@ -33,23 +31,14 @@ class FoundryIQConfig: provisioned: bool = False -class FoundryIQConfigStore: +class FoundryIQConfigStore(BaseConfigStore[FoundryIQConfig]): """JSON-file-backed Foundry IQ configuration.""" + _config_type = FoundryIQConfig + _default_filename = "foundry_iq.json" + _log_label = "Foundry IQ config" _SECRET_FIELDS = frozenset({"search_api_key", "embedding_api_key"}) - - def __init__(self, path: Path | None = None) -> None: - self._path = path or (cfg.data_dir / "foundry_iq.json") - self._config = FoundryIQConfig() - self._load() - - @property - def path(self) -> Path: - return self._path - - @property - def config(self) -> FoundryIQConfig: - return self._config + _secret_prefix = "foundryiq-" @property def enabled(self) -> bool: @@ -105,46 +94,17 @@ def clear_provisioning(self) -> None: self._config.enabled = False self._save() - def _load(self) -> None: - if not self._path.exists(): - return - try: - raw = json.loads(self._path.read_text()) - for k in FoundryIQConfig.__dataclass_fields__: - if k in raw: - value = raw[k] - if k in self._SECRET_FIELDS and isinstance(value, str): - value = self._resolve_secret(value) - setattr(self._config, k, value) - except Exception as exc: - logger.warning("Failed to load Foundry IQ config from %s: %s", self._path, exc) - - def _save(self) -> None: - self._path.parent.mkdir(parents=True, exist_ok=True) + def _apply_raw(self, raw: dict[str, Any]) -> None: + for k in FoundryIQConfig.__dataclass_fields__: + if k in raw: + value = raw[k] + if k in self._SECRET_FIELDS and isinstance(value, str): + value = self._resolve_secret(value) + setattr(self._config, k, value) + + def _save_data(self) -> dict[str, Any]: data = asdict(self._config) - data = self._store_secrets(data) - self._path.write_text(json.dumps(data, indent=2) + "\n") - - def _store_secrets(self, d: dict[str, Any]) -> dict[str, Any]: - from ..services.keyvault import kv, env_key_to_secret_name, is_kv_ref - - result = dict(d) - if not kv.enabled: - return result - for k in self._SECRET_FIELDS: - val = result.get(k, "") - if val and not is_kv_ref(val): - try: - ref = kv.store(env_key_to_secret_name(f"foundryiq-{k}"), val) - result[k] = ref - except Exception as exc: - logger.warning("Failed to store secret %s in KV: %s", k, exc) - return result - - @staticmethod - def _resolve_secret(value: str) -> str: - from ..services.keyvault import resolve_if_kv_ref - return resolve_if_kv_ref(value) + return self._store_secrets(data) # -- singleton ------------------------------------------------------------- diff --git a/app/runtime/state/guardrails/__init__.py b/app/runtime/state/guardrails/__init__.py new file mode 100644 index 0000000..0b8d481 --- /dev/null +++ b/app/runtime/state/guardrails/__init__.py @@ -0,0 +1,49 @@ +"""Guardrails -- policy engine, presets, risk tiers, and bulk operations.""" + +from __future__ import annotations + +from .config import ( + GuardrailsConfigStore, + get_guardrails_config, +) +from .models import ( + GuardrailRule, + GuardrailsConfig, + _VALID_STRATEGIES, +) +from .presets import ( + PRESET_BALANCED, + PRESET_PERMISSIVE, + PRESET_RESTRICTIVE, + _ALL_PRESET_TOOL_IDS, + _build_preset_policies, + list_background_agents, + list_presets, +) +from .risk import ( + _MODEL_TIERS, + _risk_of, + get_model_tier, + get_preset_for_model, + list_model_tiers, +) + +__all__ = [ + "GuardrailRule", + "GuardrailsConfig", + "GuardrailsConfigStore", + "PRESET_BALANCED", + "PRESET_PERMISSIVE", + "PRESET_RESTRICTIVE", + "_ALL_PRESET_TOOL_IDS", + "_MODEL_TIERS", + "_VALID_STRATEGIES", + "_build_preset_policies", + "_risk_of", + "get_guardrails_config", + "get_model_tier", + "get_preset_for_model", + "list_background_agents", + "list_model_tiers", + "list_presets", +] diff --git a/app/runtime/state/guardrails/bulk.py b/app/runtime/state/guardrails/bulk.py new file mode 100644 index 0000000..257b5f5 --- /dev/null +++ b/app/runtime/state/guardrails/bulk.py @@ -0,0 +1,123 @@ +"""Bulk guardrails operations -- presets, strategies, and model defaults.""" + +from __future__ import annotations + +from .models import GuardrailsConfig, _VALID_STRATEGIES +from .presets import ( + PRESET_BALANCED, + PRESET_PERMISSIVE, + PRESET_RESTRICTIVE, + _ALL_PRESET_TOOL_IDS, + _EFFECTIVE_MODEL_PRESET, + _PRESET_MATRIX, + _PRESET_OVERRIDES, + _build_preset_policies, + list_presets, +) +from .risk import ( + _MODEL_TIERS, + _risk_of, + get_model_tier, + get_preset_for_model, +) + + +def apply_preset_to_config( + config: GuardrailsConfig, preset: str, *, auto_models: bool = True, +) -> None: + """Apply a named preset to *config* in place. + + Overwrites ``context_defaults`` and ``tool_policies``. When + *auto_models* is ``True``, recommended models are added as model + columns with tier-appropriate policies and all existing model + columns are refreshed. + """ + valid = {PRESET_RESTRICTIVE, PRESET_BALANCED, PRESET_PERMISSIVE} + if preset not in valid: + raise ValueError("preset must be one of: %s" % ", ".join(sorted(valid))) + policies = _build_preset_policies(preset) + config.context_defaults = policies["context_defaults"] + config.tool_policies = policies["tool_policies"] + config.hitl_enabled = True + if auto_models: + preset_meta = next((p for p in list_presets() if p["id"] == preset), None) + if preset_meta: + new_models = [ + m for m in preset_meta["recommended_for"] + if m not in config.model_columns + ] + if new_models: + apply_model_defaults_to_config(config, new_models, preset=preset) + if config.model_columns: + apply_model_defaults_to_config(config, preset=preset) + + +def set_all_strategies_on_config(config: GuardrailsConfig, strategy: str) -> None: + """Set every tool policy and context default on *config* to *strategy*. + + All tools in ``_ALL_PRESET_TOOL_IDS`` across interactive and background + contexts are set to the given strategy. All known models from + ``_MODEL_TIERS`` are added as model columns with the same strategy. + """ + if strategy not in _VALID_STRATEGIES: + raise ValueError( + "strategy must be one of: %s" % ", ".join(sorted(_VALID_STRATEGIES)) + ) + policies: dict[str, dict[str, str]] = {"interactive": {}, "background": {}} + for tool_id in _ALL_PRESET_TOOL_IDS: + for ctx in ("interactive", "background"): + policies[ctx][tool_id] = strategy + config.context_defaults = { + "interactive": strategy, + "background": strategy, + } + config.tool_policies = policies + config.model_columns = sorted(_MODEL_TIERS.keys()) + model_policies: dict[str, dict[str, dict[str, str]]] = {} + for model in config.model_columns: + per_ctx: dict[str, dict[str, str]] = {} + for ctx in ("interactive", "background"): + per_ctx[ctx] = {tool_id: strategy for tool_id in _ALL_PRESET_TOOL_IDS} + model_policies[model] = per_ctx + config.model_policies = model_policies + config.hitl_enabled = True + + +def apply_model_defaults_to_config( + config: GuardrailsConfig, + models: list[str] | None = None, + *, + preset: str | None = None, +) -> None: + """Auto-populate model columns on *config* with tier-appropriate policies. + + For each model, determines the effective preset via the + ``_EFFECTIVE_MODEL_PRESET`` cross-reference of *preset* (the + user-selected risk posture) and the model's inherent tier. + + If *models* is ``None``, uses the existing ``model_columns``. + If *preset* is ``None``, falls back to the model's own tier preset. + """ + target_models = models if models is not None else list(config.model_columns) + for model in target_models: + if model not in config.model_columns: + config.model_columns.append(model) + tier = get_model_tier(model) + if preset: + effective = _EFFECTIVE_MODEL_PRESET.get( + (preset, tier), + get_preset_for_model(model), + ) + else: + effective = get_preset_for_model(model) + matrix = _PRESET_MATRIX.get(effective, _PRESET_MATRIX[PRESET_RESTRICTIVE]) + overrides = _PRESET_OVERRIDES.get(effective, {}) + per_ctx: dict[str, dict[str, str]] = {} + for ctx in ("interactive", "background"): + ctx_overrides = overrides.get(ctx, {}) + ctx_policies: dict[str, str] = {} + for tool_id in _ALL_PRESET_TOOL_IDS: + risk = _risk_of(tool_id) + ctx_policies[tool_id] = ctx_overrides.get(tool_id, matrix[ctx][risk]) + per_ctx[ctx] = ctx_policies + config.model_policies[model] = per_ctx diff --git a/app/runtime/state/guardrails/config.py b/app/runtime/state/guardrails/config.py new file mode 100644 index 0000000..69b1bc0 --- /dev/null +++ b/app/runtime/state/guardrails/config.py @@ -0,0 +1,505 @@ +"""Guardrails configuration -- HITL approval rules for tools and MCP servers.""" + +from __future__ import annotations + +import json +import logging +from dataclasses import asdict +from pathlib import Path +from typing import Any + +from ...agent.policy_bridge import ( + build_engine, + config_to_yaml, + make_eval_context, + validate_yaml, + yaml_to_config, +) +from ...config.settings import cfg + +from .bulk import ( + apply_model_defaults_to_config, + apply_preset_to_config, + set_all_strategies_on_config, +) + +# Re-export public symbols so existing imports keep working. +from .models import GuardrailRule, GuardrailsConfig, _VALID_STRATEGIES # noqa: F401 +from .presets import ( # noqa: F401 + PRESET_BALANCED, + PRESET_PERMISSIVE, + PRESET_RESTRICTIVE, + _ALL_PRESET_TOOL_IDS, + _build_preset_policies, + list_background_agents, + list_presets, +) +from .risk import ( # noqa: F401 + _MODEL_TIERS, + _risk_of, + get_model_tier, + get_preset_for_model, + list_model_tiers, +) + +logger = logging.getLogger(__name__) + +_instance: GuardrailsConfigStore | None = None + + +class GuardrailsConfigStore: + """JSON-file-backed guardrails configuration. + + The store maintains both a JSON file (UI state, phone numbers, AITL + config, etc.) and a YAML policy file consumed by the agent-policy-guard + ``PolicyEngine``. Every mutation regenerates the YAML and rebuilds + the engine so that ``resolve_action()`` always reflects the latest + configuration. + """ + + def __init__(self, path: Path | None = None) -> None: + self._path = path or (cfg.data_dir / "guardrails.json") + self._policy_path = self._path.with_name("policy.yaml") + self._config = GuardrailsConfig() + self._engine = build_engine(self._generate_yaml()) + self._load() + + @property + def path(self) -> Path: + return self._path + + @property + def config(self) -> GuardrailsConfig: + return self._config + + @property + def hitl_enabled(self) -> bool: + return self._config.hitl_enabled + + @property + def default_action(self) -> str: + return self._config.default_action + + @property + def rules(self) -> list[GuardrailRule]: + return list(self._config.rules) + + def set_hitl_enabled(self, enabled: bool) -> None: + self._config.hitl_enabled = enabled + self._save() + + def set_default_action(self, action: str) -> None: + if action not in _VALID_STRATEGIES: + raise ValueError("action must be one of: %s" % ", ".join(sorted(_VALID_STRATEGIES))) + self._config.default_action = action + self._save() + + @property + def default_channel(self) -> str: + return self._config.default_channel + + @property + def phone_number(self) -> str: + return self._config.phone_number + + def set_default_channel(self, channel: str) -> None: + if channel not in ("chat", "phone"): + raise ValueError("channel must be 'chat' or 'phone'") + self._config.default_channel = channel + self._save() + + def set_phone_number(self, number: str) -> None: + self._config.phone_number = number + self._save() + + def set_aitl_model(self, model: str) -> None: + self._config.aitl_model = model + self._save() + + def set_aitl_spotlighting(self, enabled: bool) -> None: + self._config.aitl_spotlighting = enabled + self._save() + + def set_filter_mode(self, mode: str) -> None: + if mode != "prompt_shields": + raise ValueError("filter_mode must be 'prompt_shields'") + self._config.filter_mode = mode + self._save() + + def set_content_safety_endpoint(self, endpoint: str) -> None: + self._config.content_safety_endpoint = endpoint + self._save() + + def set_content_safety_key(self, key: str) -> None: + self._config.content_safety_key = key + self._save() + + def set_context_default(self, context: str, strategy: str) -> None: + if strategy not in _VALID_STRATEGIES: + raise ValueError("strategy must be one of: %s" % ", ".join(sorted(_VALID_STRATEGIES))) + self._config.context_defaults[context] = strategy + self._save() + + def remove_context_default(self, context: str) -> bool: + """Remove a context-level default, reverting to fallback resolution.""" + if context in self._config.context_defaults: + del self._config.context_defaults[context] + self._save() + return True + return False + + def set_tool_policy( + self, context: str, tool_id: str, strategy: str, + ) -> None: + if strategy not in _VALID_STRATEGIES: + raise ValueError("strategy must be one of: %s" % ", ".join(sorted(_VALID_STRATEGIES))) + if context not in self._config.tool_policies: + self._config.tool_policies[context] = {} + self._config.tool_policies[context][tool_id] = strategy + self._save() + + def remove_tool_policy(self, context: str, tool_id: str) -> bool: + policies = self._config.tool_policies.get(context, {}) + if tool_id in policies: + del policies[tool_id] + self._save() + return True + return False + + def add_model_column(self, model: str) -> None: + if model not in self._config.model_columns: + self._config.model_columns.append(model) + self._save() + + def remove_model_column(self, model: str) -> bool: + if model in self._config.model_columns: + self._config.model_columns.remove(model) + self._config.model_policies.pop(model, None) + self._save() + return True + return False + + def set_model_policy( + self, model: str, tool_id: str, strategy: str, context: str = "interactive", + ) -> None: + if strategy not in _VALID_STRATEGIES: + raise ValueError("strategy must be one of: %s" % ", ".join(sorted(_VALID_STRATEGIES))) + if model not in self._config.model_policies: + self._config.model_policies[model] = {} + if context not in self._config.model_policies[model]: + self._config.model_policies[model][context] = {} + self._config.model_policies[model][context][tool_id] = strategy + self._save() + + def remove_model_policy( + self, model: str, tool_id: str, context: str = "interactive", + ) -> bool: + ctx_policies = self._config.model_policies.get(model, {}).get(context, {}) + if tool_id in ctx_policies: + del ctx_policies[tool_id] + self._save() + return True + return False + + def apply_preset(self, preset: str, *, auto_models: bool = True) -> None: + """Apply a named preset to context_defaults and tool_policies.""" + apply_preset_to_config(self._config, preset, auto_models=auto_models) + self._save() + + def set_all_strategies(self, strategy: str) -> None: + """Set every tool policy and context default to *strategy*.""" + set_all_strategies_on_config(self._config, strategy) + self._save() + + def apply_model_defaults( + self, + models: list[str] | None = None, + *, + preset: str | None = None, + ) -> None: + """Auto-populate model columns with tier-appropriate policies.""" + apply_model_defaults_to_config(self._config, models, preset=preset) + self._save() + + def add_rule( + self, + *, + name: str, + pattern: str, + scope: str = "tool", + action: str = "ask", + enabled: bool = True, + description: str = "", + contexts: list[str] | None = None, + models: list[str] | None = None, + hitl_channel: str = "chat", + ) -> GuardrailRule: + if scope not in ("tool", "mcp"): + raise ValueError("scope must be 'tool' or 'mcp'") + if action not in _VALID_STRATEGIES: + raise ValueError("action must be one of: %s" % ", ".join(sorted(_VALID_STRATEGIES))) + if hitl_channel not in ("chat", "phone"): + raise ValueError("hitl_channel must be 'chat' or 'phone'") + rule = GuardrailRule( + name=name, + pattern=pattern, + scope=scope, + action=action, + enabled=enabled, + description=description, + contexts=contexts or [], + models=models or [], + hitl_channel=hitl_channel, + ) + self._config.rules.append(rule) + self._save() + return rule + + def update_rule(self, rule_id: str, **kwargs: Any) -> GuardrailRule | None: + for rule in self._config.rules: + if rule.id == rule_id: + for k, v in kwargs.items(): + if k == "id": + continue + if hasattr(rule, k): + setattr(rule, k, v) + self._save() + return rule + return None + + def remove_rule(self, rule_id: str) -> bool: + before = len(self._config.rules) + self._config.rules = [r for r in self._config.rules if r.id != rule_id] + if len(self._config.rules) < before: + self._save() + return True + return False + + def get_rule(self, rule_id: str) -> GuardrailRule | None: + for rule in self._config.rules: + if rule.id == rule_id: + return rule + return None + + def resolve_action( + self, + tool_name: str, + mcp_server: str | None = None, + execution_context: str = "", + model: str = "", + ) -> str: + """Determine the strategy for a given tool invocation. + + Delegates to the agent-policy-guard ``PolicyEngine`` which evaluates + the generated YAML policy set. The YAML encodes all context defaults, + tool policies, model policies, legacy rules, and background-agent + fallbacks. + + When ``hitl_enabled`` is ``False`` the engine already has ``allow`` + as its default effect and no policies are generated, so it returns + ``"allow"`` for every call. + """ + ctx = make_eval_context( + tool_name=tool_name, + mcp_server=mcp_server, + execution_context=execution_context, + model=model, + ) + result = self._engine.resolve(ctx) + logger.debug( + "[guardrails.resolve] engine result: tool=%s ctx=%s model=%s -> %s", + tool_name, execution_context, model, result, + ) + return result + + def resolve_channel( + self, + tool_name: str, + mcp_server: str | None = None, + execution_context: str = "", + model: str = "", + ) -> str: + """Determine the HITL channel for a tool invocation. + + Returns the ``hitl_channel`` of the first matching rule, or the + store-level ``default_channel``. + """ + if not self._config.hitl_enabled: + return "chat" + + for rule in self._config.rules: + if not rule.enabled: + continue + if rule.contexts and execution_context and execution_context not in rule.contexts: + continue + if rule.models and model: + if not any(self._matches(m, model) for m in rule.models): + continue + if rule.scope == "tool" and self._matches(rule.pattern, tool_name): + return rule.hitl_channel + if rule.scope == "mcp" and mcp_server and self._matches(rule.pattern, mcp_server): + return rule.hitl_channel + + return self._config.default_channel + + def to_dict(self) -> dict[str, Any]: + return { + # Frontend-canonical fields + "enabled": self._config.hitl_enabled, + "default_strategy": self._config.default_action, + "hitl_channel": self._config.default_channel, + "context_defaults": dict(self._config.context_defaults), + "tool_policies": { + ctx: dict(policies) + for ctx, policies in self._config.tool_policies.items() + }, + "model_columns": list(self._config.model_columns), + "model_policies": { + model: { + ctx: dict(tool_map) + for ctx, tool_map in ctx_policies.items() + } + for model, ctx_policies in self._config.model_policies.items() + }, + # Backend / legacy fields + "hitl_enabled": self._config.hitl_enabled, + "default_action": self._config.default_action, + "default_channel": self._config.default_channel, + "phone_number": self._config.phone_number, + "aitl_model": self._config.aitl_model, + "aitl_spotlighting": self._config.aitl_spotlighting, + "filter_mode": self._config.filter_mode, + "content_safety_endpoint": self._config.content_safety_endpoint, + "rules": [asdict(r) for r in self._config.rules], + } + + @staticmethod + def _matches(pattern: str, name: str) -> bool: + """Simple glob-style matching: '*' matches everything, prefix* matches prefix.""" + if pattern == "*": + return True + if pattern.endswith("*"): + return name.startswith(pattern[:-1]) + return pattern == name + + def _load(self) -> None: + if not self._path.exists(): + self._rebuild_engine() + return + try: + raw = json.loads(self._path.read_text()) + self._config = GuardrailsConfig( + hitl_enabled=raw.get("enabled", raw.get("hitl_enabled", False)), + default_action=raw.get("default_strategy", raw.get("default_action", "allow")), + default_channel=raw.get("hitl_channel", raw.get("default_channel", "chat")), + phone_number=raw.get("phone_number", ""), + aitl_model=raw.get("aitl_model", "gpt-4.1"), + aitl_spotlighting=raw.get("aitl_spotlighting", True), + filter_mode=raw.get("filter_mode", "prompt_shields"), + content_safety_endpoint=raw.get("content_safety_endpoint", ""), + content_safety_key=raw.get("content_safety_key", ""), + rules=[ + GuardrailRule(**{ + k: v for k, v in r.items() + if k in GuardrailRule.__dataclass_fields__ + }) + for r in raw.get("rules", []) + ], + context_defaults=raw.get("context_defaults", {}), + tool_policies=raw.get("tool_policies", {}), + model_columns=raw.get("model_columns", []), + model_policies=raw.get("model_policies", {}), + ) + self._rebuild_engine() + except Exception as exc: + logger.warning( + "Failed to load guardrails config from %s: %s", + self._path, exc, exc_info=True, + ) + + @property + def policy_path(self) -> Path: + """Path to the generated policy YAML file.""" + return self._policy_path + + def get_policy_yaml(self) -> str: + """Return the current policy as a YAML string.""" + return self._generate_yaml() + + def set_policy_yaml(self, yaml_text: str) -> str | None: + """Apply a raw YAML policy, updating the config to match. + + Returns ``None`` on success or an error message string. + """ + error = validate_yaml(yaml_text) + if error: + return error + try: + parsed = yaml_to_config(yaml_text) + self._config.default_action = parsed["default_action"] + self._config.default_channel = parsed["default_channel"] + self._config.context_defaults = parsed["context_defaults"] + self._config.tool_policies = parsed["tool_policies"] + self._config.model_columns = parsed["model_columns"] + self._config.model_policies = parsed["model_policies"] + if parsed.get("rules"): + self._config.rules = [ + GuardrailRule(**{ + k: v for k, v in r.items() + if k in GuardrailRule.__dataclass_fields__ + }) + for r in parsed["rules"] + ] + self._save() + return None + except Exception as exc: + logger.warning("[guardrails] failed to apply YAML: %s", exc, exc_info=True) + return str(exc) + + def _generate_yaml(self) -> str: + """Generate a policy YAML string from the current config.""" + return config_to_yaml( + hitl_enabled=self._config.hitl_enabled, + default_action=self._config.default_action, + default_channel=self._config.default_channel, + context_defaults=self._config.context_defaults, + tool_policies=self._config.tool_policies, + model_columns=self._config.model_columns, + model_policies=self._config.model_policies, + rules=[asdict(r) for r in self._config.rules], + ) + + def _rebuild_engine(self) -> None: + """Rebuild the PolicyEngine from the current config.""" + yaml_text = self._generate_yaml() + self._engine = build_engine(yaml_text) + # Write the YAML file alongside the JSON for reference / expert mode + try: + self._policy_path.write_text(yaml_text) + except Exception as exc: + logger.warning( + "[guardrails] failed to write policy.yaml: %s", exc, exc_info=True, + ) + + def _save(self) -> None: + self._path.parent.mkdir(parents=True, exist_ok=True) + self._path.write_text(json.dumps(self.to_dict(), indent=2) + "\n") + self._rebuild_engine() + + +def get_guardrails_config(path: Path | None = None) -> GuardrailsConfigStore: + """Module-level singleton accessor.""" + global _instance + if _instance is None: + _instance = GuardrailsConfigStore(path) + return _instance + + +def _reset_guardrails_config() -> None: + global _instance + _instance = None + + +from ...util.singletons import register_singleton # noqa: E402 + +register_singleton(_reset_guardrails_config) diff --git a/app/runtime/state/guardrails/models.py b/app/runtime/state/guardrails/models.py new file mode 100644 index 0000000..f0a0a55 --- /dev/null +++ b/app/runtime/state/guardrails/models.py @@ -0,0 +1,52 @@ +"""Guardrails data models -- dataclasses and shared constants.""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field + +_VALID_STRATEGIES = frozenset({"allow", "deny", "hitl", "pitl", "aitl", "filter", "ask"}) + + +@dataclass +class GuardrailRule: + """A single approval rule for a tool or MCP server.""" + + id: str = "" + name: str = "" + pattern: str = "" + scope: str = "tool" # "tool" | "mcp" + action: str = "allow" # "allow" | "deny" | "ask" + enabled: bool = True + description: str = "" + # Context-aware policy fields + contexts: list[str] = field(default_factory=list) # [] = all contexts + models: list[str] = field(default_factory=list) # [] = all models + hitl_channel: str = "chat" # "chat" | "phone" + + def __post_init__(self) -> None: + if not self.id: + self.id = str(uuid.uuid4())[:8] + + +@dataclass +class GuardrailsConfig: + """Top-level guardrails configuration.""" + + hitl_enabled: bool = False + default_action: str = "allow" # "allow" | "deny" | "hitl" | "pitl" | "aitl" | "filter" + default_channel: str = "chat" # "chat" | "phone" + phone_number: str = "" # E.164 number for phone verification + aitl_model: str = "gpt-4.1" # Model used by the AITL reviewer agent + aitl_spotlighting: bool = True # Spotlight untrusted content in AITL prompts + filter_mode: str = "prompt_shields" # always "prompt_shields" + content_safety_endpoint: str = "" # Azure Content Safety endpoint URL + content_safety_key: str = "" # Azure Content Safety API key + rules: list[GuardrailRule] = field(default_factory=list) + # Policy matrix fields (frontend-driven) + context_defaults: dict[str, str] = field(default_factory=dict) + tool_policies: dict[str, dict[str, str]] = field(default_factory=dict) + # Model-specific columns: user-defined model identifiers + model_columns: list[str] = field(default_factory=list) + # Model-scoped policies: model -> context -> tool -> strategy + model_policies: dict[str, dict[str, dict[str, str]]] = field(default_factory=dict) diff --git a/app/runtime/state/guardrails/presets.py b/app/runtime/state/guardrails/presets.py new file mode 100644 index 0000000..248a217 --- /dev/null +++ b/app/runtime/state/guardrails/presets.py @@ -0,0 +1,269 @@ +"""Preset definitions and background agent metadata for guardrails.""" + +from __future__ import annotations + +from typing import Any + +from .risk import _risk_of + +# ── Background agent metadata ─────────────────────────────────────────── + +BACKGROUND_AGENTS: tuple[dict[str, Any], ...] = ( + { + "id": "scheduler", + "name": "Scheduler", + "description": ( + "Runs scheduled tasks on a cron schedule. Has full tool access " + "including file operations, terminal, and MCP servers." + ), + "has_tools": True, + "default_policy": "background", + "risk_note": ( + "Changing the policy for the scheduler may cause scheduled tasks " + "to hang waiting for approval or fail silently." + ), + }, + { + "id": "bot_processor", + "name": "Bot Message Processor", + "description": ( + "Processes messages from Teams, Telegram, and other bot channels. " + "Shares the full tool set with the interactive agent." + ), + "has_tools": True, + "default_policy": "background", + "risk_note": ( + "Changing the policy for the bot processor may cause channel " + "messages to hang or tools to be blocked for bot users." + ), + }, + { + "id": "proactive_loop", + "name": "Proactive Loop", + "description": ( + "Generates proactive messages and notifications. Text-only -- " + "has no tool access." + ), + "has_tools": False, + "default_policy": "allow", + "risk_note": ( + "This agent has no tool access. Guardrail changes have no effect." + ), + }, + { + "id": "memory_formation", + "name": "Memory Formation", + "description": ( + "Post-processes conversations to extract and store memories. " + "Text-only -- has no tool access." + ), + "has_tools": False, + "default_policy": "allow", + "risk_note": ( + "This agent has no tool access. Guardrail changes have no effect." + ), + }, + { + "id": "aitl_reviewer", + "name": "AITL Reviewer", + "description": ( + "AI reviewer that evaluates tool calls for safety. Uses one " + "internal decision tool (submit_decision)." + ), + "has_tools": True, + "default_policy": "allow", + "risk_note": ( + "The AITL reviewer IS the guardrail. Restricting it will " + "prevent it from functioning and break AITL-based approvals." + ), + }, + { + "id": "realtime", + "name": "Realtime Voice Agent", + "description": ( + "Bridges the Realtime voice model to the Copilot SDK agent. " + "Spawns one-shot sessions to execute tool-based tasks requested " + "via voice calls." + ), + "has_tools": True, + "default_policy": "background", + "risk_note": ( + "Changing the policy for the realtime agent may cause voice " + "call tool invocations to hang or be blocked." + ), + }, +) + +_BACKGROUND_AGENT_IDS: frozenset[str] = frozenset( + a["id"] for a in BACKGROUND_AGENTS +) + + +def list_background_agents() -> list[dict[str, Any]]: + """Return metadata for all background agents.""" + return list(BACKGROUND_AGENTS) + + +# ── Preset constants ──────────────────────────────────────────────────── + +PRESET_RESTRICTIVE = "restrictive" +PRESET_BALANCED = "balanced" +PRESET_PERMISSIVE = "permissive" + +_TIER_TO_PRESET: dict[int, str] = { + 1: PRESET_PERMISSIVE, + 2: PRESET_BALANCED, + 3: PRESET_RESTRICTIVE, +} + +# Cross-reference: (selected_preset, model_tier) -> effective preset for model-column +# policies. This ensures that switching presets actually changes per-model rules +# while still respecting each model's inherent safety tier. +_EFFECTIVE_MODEL_PRESET: dict[tuple[str, int], str] = { + # Permissive preset: strong/standard models get permissive, cautious gets balanced + (PRESET_PERMISSIVE, 1): PRESET_PERMISSIVE, + (PRESET_PERMISSIVE, 2): PRESET_PERMISSIVE, + (PRESET_PERMISSIVE, 3): PRESET_BALANCED, + # Balanced preset: strong gets permissive, standard balanced, cautious balanced + (PRESET_BALANCED, 1): PRESET_PERMISSIVE, + (PRESET_BALANCED, 2): PRESET_BALANCED, + (PRESET_BALANCED, 3): PRESET_BALANCED, + # Restrictive preset: strong gets balanced, standard/cautious get restrictive + (PRESET_RESTRICTIVE, 1): PRESET_BALANCED, + (PRESET_RESTRICTIVE, 2): PRESET_RESTRICTIVE, + (PRESET_RESTRICTIVE, 3): PRESET_RESTRICTIVE, +} + +# Strategy lookup: (preset, context, risk) -> strategy +_PRESET_MATRIX: dict[str, dict[str, dict[str, str]]] = { + PRESET_PERMISSIVE: { + "interactive": {"low": "filter", "medium": "filter", "high": "filter"}, + "background": {"low": "filter", "medium": "filter", "high": "hitl"}, + }, + PRESET_BALANCED: { + "interactive": {"low": "filter", "medium": "filter", "high": "hitl"}, + "background": {"low": "filter", "medium": "hitl", "high": "deny"}, + }, + PRESET_RESTRICTIVE: { + "interactive": {"low": "filter", "medium": "hitl", "high": "hitl"}, + "background": {"low": "filter", "medium": "deny", "high": "deny"}, + }, +} + +# Per-preset tool overrides applied *after* the risk matrix. +_PRESET_OVERRIDES: dict[str, dict[str, dict[str, str]]] = { + PRESET_BALANCED: { + "background": { + "create": "aitl", + "edit": "aitl", + "run": "aitl", + "bash": "aitl", + "make_voice_call": "aitl", + }, + }, +} + +# Every tool/MCP/skill that presets should populate explicitly. +_ALL_PRESET_TOOL_IDS: list[str] = [ + # SDK + "create", "edit", "view", "grep", "glob", "run", "bash", + # Custom agent tools + "schedule_task", "cancel_task", "list_scheduled_tasks", "make_voice_call", + "send_adaptive_card", "send_hero_card", "send_thumbnail_card", "send_card_carousel", + "search_memories_tool", + # MCP + "mcp:microsoft-learn", "mcp:playwright", "mcp:github-mcp-server", "mcp:azure-mcp-server", + # Skills (builtin) + "skill:web-search", "skill:summarize-url", "skill:note-taking", "skill:daily-briefing", +] + +# Restrictiveness ranking for merging model policies across contexts. +_STRATEGY_RANK: dict[str, int] = { + "allow": 0, + "filter": 1, + "aitl": 2, + "hitl": 3, + "pitl": 4, + "ask": 4, + "deny": 5, +} + + +def _strategy_rank(strategy: str) -> int: + return _STRATEGY_RANK.get(strategy, 3) + + +def _build_preset_policies(preset: str) -> dict[str, Any]: + """Return context_defaults and tool_policies for a given preset name. + + Uses the ``_PRESET_MATRIX`` to map (preset, context, risk) -> strategy + for every known tool/MCP/skill. + """ + matrix = _PRESET_MATRIX.get(preset, _PRESET_MATRIX[PRESET_RESTRICTIVE]) + overrides = _PRESET_OVERRIDES.get(preset, {}) + policies: dict[str, dict[str, str]] = {"interactive": {}, "background": {}} + for tool_id in _ALL_PRESET_TOOL_IDS: + risk = _risk_of(tool_id) + for ctx in ("interactive", "background"): + policies[ctx][tool_id] = matrix[ctx][risk] + # Apply per-tool overrides after the matrix + for ctx, tool_map in overrides.items(): + for tool_id, strategy in tool_map.items(): + policies[ctx][tool_id] = strategy + # Context-level defaults (for tools not explicitly listed) + ctx_defaults = { + ctx: matrix[ctx]["medium"] for ctx in ("interactive", "background") + } + return { + "context_defaults": ctx_defaults, + "tool_policies": policies, + } + + +def list_presets() -> list[dict[str, Any]]: + """Return metadata for all available presets.""" + from .risk import _MODEL_TIERS + + return [ + { + "id": PRESET_RESTRICTIVE, + "name": "Restrictive", + "description": ( + "For smaller or older models. Read-only tools allowed; " + "file edits and browser require HITL in interactive; " + "terminal, GitHub, Azure, and all MCP denied in background." + ), + "tier": 3, + "recommended_for": sorted( + m for m, t in _MODEL_TIERS.items() if t == 3 + ), + }, + { + "id": PRESET_BALANCED, + "name": "Balanced", + "description": ( + "For standard models. Low-risk tools allowed everywhere; " + "terminal and GitHub/Azure require HITL in interactive; " + "file operations, terminal, and voice calls use AITL in " + "background; high-risk MCP denied in background. " + "MS Learn allowed." + ), + "tier": 2, + "recommended_for": sorted( + m for m, t in _MODEL_TIERS.items() if t == 2 + ), + }, + { + "id": PRESET_PERMISSIVE, + "name": "Permissive", + "description": ( + "For strong frontier models. All tools allowed in interactive. " + "Terminal, GitHub, Azure still require HITL in background. " + "MS Learn, file operations, and browser allowed everywhere." + ), + "tier": 1, + "recommended_for": sorted( + m for m, t in _MODEL_TIERS.items() if t == 1 + ), + }, + ] diff --git a/app/runtime/state/guardrails/risk.py b/app/runtime/state/guardrails/risk.py new file mode 100644 index 0000000..3485696 --- /dev/null +++ b/app/runtime/state/guardrails/risk.py @@ -0,0 +1,118 @@ +"""Risk classification and model tier definitions for guardrails.""" + +from __future__ import annotations + +from typing import Any + +# ── Model tiers ────────────────────────────────────────────────────────── +# Tier 1 (cautious): large frontier models -- most access, highest risk posture +# Tier 2 (standard): capable mid-range models +# Tier 3 (safe): smaller / older models -- least access, lowest risk posture + +_MODEL_TIERS: dict[str, int] = { + # Tier 1 -- cautious (most permissive, highest risk) + "gpt-5.3-codex": 1, + "claude-opus-4.6": 1, + "claude-opus-4.6-fast": 1, + # Tier 2 -- standard + "claude-sonnet-4.6": 2, + "gpt-5.2": 2, + "gemini-3-pro-preview": 2, + # Tier 3 -- safe (most restrictive, lowest risk) + "gpt-5-mini": 3, + "gpt-4.1": 3, +} + +_DEFAULT_TIER = 3 # Unknown models get the most restrictive tier + +# ── MCP / tool risk classification ────────────────────────────────────── +# Risk levels: low (read-only / public), medium (browser / scheduling), +# high (code repos, infra, phone calls). + +_MCP_RISK: dict[str, str] = { + "mcp:microsoft-learn": "low", # read-only public docs + "mcp:playwright": "medium", # browser automation, can navigate sites + "mcp:github-mcp-server": "high", # create repos, PRs, push code + "mcp:azure-mcp-server": "high", # create/delete Azure resources +} + +_SKILL_RISK: dict[str, str] = { + "skill:daily-briefing": "low", # read-only from local memory + "skill:wiki-search": "low", # read-only, public API + "skill:wiki-summary": "low", + "skill:wiki-deep-dive": "low", + "skill:gh-status-check": "low", # read-only, public API + "skill:gh-incidents": "low", + "skill:gh-maintenance": "low", + "skill:web-search": "medium", # browser-based + "skill:summarize-url": "medium", # browser-based + "skill:note-taking": "medium", # filesystem writes + "skill:daily-rollover": "medium", # M365 reads + file writes + "skill:end-day": "medium", + "skill:weekly-review": "medium", + "skill:monthly-review": "medium", + "skill:setup-foundry": "high", # provisions Azure infra + "skill:foundry-agent-chat": "high", # creates cloud agents + "skill:foundry-code-interpreter": "high", + "skill:setup-workiq": "medium", + "skill:setup-wikipedia": "low", +} + +_CUSTOM_TOOL_RISK: dict[str, str] = { + "schedule_task": "medium", + "cancel_task": "medium", + "list_scheduled_tasks": "low", + "make_voice_call": "high", + "search_memories_tool": "low", + "send_adaptive_card": "low", + "send_hero_card": "low", + "send_thumbnail_card": "low", + "send_card_carousel": "low", +} + + +def _risk_of(tool_id: str) -> str: + """Return the risk level for any tool/MCP/skill id.""" + if tool_id in _MCP_RISK: + return _MCP_RISK[tool_id] + if tool_id in _SKILL_RISK: + return _SKILL_RISK[tool_id] + if tool_id in _CUSTOM_TOOL_RISK: + return _CUSTOM_TOOL_RISK[tool_id] + # SDK tools + if tool_id in ("view", "grep", "glob"): + return "low" + if tool_id in ("create", "edit"): + return "medium" + if tool_id in ("run", "bash"): + return "high" + # Unknown MCP or skill -- default to high for safety + if tool_id.startswith("mcp:") or tool_id.startswith("skill:"): + return "high" + return "medium" + + +def get_model_tier(model: str) -> int: + """Return the security tier for a model (1=cautious, 2=standard, 3=safe).""" + return _MODEL_TIERS.get(model, _DEFAULT_TIER) + + +def get_preset_for_model(model: str) -> str: + """Return the recommended preset name for a model.""" + from .presets import _TIER_TO_PRESET, PRESET_RESTRICTIVE + + return _TIER_TO_PRESET.get(get_model_tier(model), PRESET_RESTRICTIVE) + + +def list_model_tiers() -> list[dict[str, Any]]: + """Return all known models with their tier and recommended preset.""" + result: list[dict[str, Any]] = [] + _TIER_LABELS = {1: "Strong", 2: "Standard", 3: "Cautious"} + for model, tier in sorted(_MODEL_TIERS.items(), key=lambda x: (x[1], x[0])): + result.append({ + "model": model, + "tier": tier, + "tier_label": _TIER_LABELS.get(tier, "Unknown"), + "preset": get_preset_for_model(model), + }) + return result diff --git a/app/runtime/state/guardrails_config.py b/app/runtime/state/guardrails_config.py index a47d07b..f47356f 100644 --- a/app/runtime/state/guardrails_config.py +++ b/app/runtime/state/guardrails_config.py @@ -4,8 +4,7 @@ import json import logging -import uuid -from dataclasses import asdict, dataclass, field +from dataclasses import asdict from pathlib import Path from typing import Any @@ -18,439 +17,34 @@ ) from ..config.settings import cfg -logger = logging.getLogger(__name__) - -_instance: GuardrailsConfigStore | None = None - - -@dataclass -class GuardrailRule: - """A single approval rule for a tool or MCP server.""" - - id: str = "" - name: str = "" - pattern: str = "" - scope: str = "tool" # "tool" | "mcp" - action: str = "allow" # "allow" | "deny" | "ask" - enabled: bool = True - description: str = "" - # Context-aware policy fields - contexts: list[str] = field(default_factory=list) # [] = all contexts - models: list[str] = field(default_factory=list) # [] = all models - hitl_channel: str = "chat" # "chat" | "phone" - - def __post_init__(self) -> None: - if not self.id: - self.id = str(uuid.uuid4())[:8] - - -_VALID_STRATEGIES = frozenset({"allow", "deny", "hitl", "pitl", "aitl", "filter", "ask"}) - -# ── Background agent metadata ─────────────────────────────────────────── -# Each background agent gets its own execution context so policy can be -# set per-agent. ``resolve_action`` falls back from the agent-specific -# context to ``"background"`` when no override exists. - -BACKGROUND_AGENTS: tuple[dict[str, Any], ...] = ( - { - "id": "scheduler", - "name": "Scheduler", - "description": ( - "Runs scheduled tasks on a cron schedule. Has full tool access " - "including file operations, terminal, and MCP servers." - ), - "has_tools": True, - "default_policy": "background", - "risk_note": ( - "Changing the policy for the scheduler may cause scheduled tasks " - "to hang waiting for approval or fail silently." - ), - }, - { - "id": "bot_processor", - "name": "Bot Message Processor", - "description": ( - "Processes messages from Teams, Telegram, and other bot channels. " - "Shares the full tool set with the interactive agent." - ), - "has_tools": True, - "default_policy": "background", - "risk_note": ( - "Changing the policy for the bot processor may cause channel " - "messages to hang or tools to be blocked for bot users." - ), - }, - { - "id": "proactive_loop", - "name": "Proactive Loop", - "description": ( - "Generates proactive messages and notifications. Text-only -- " - "has no tool access." - ), - "has_tools": False, - "default_policy": "allow", - "risk_note": ( - "This agent has no tool access. Guardrail changes have no effect." - ), - }, - { - "id": "memory_formation", - "name": "Memory Formation", - "description": ( - "Post-processes conversations to extract and store memories. " - "Text-only -- has no tool access." - ), - "has_tools": False, - "default_policy": "allow", - "risk_note": ( - "This agent has no tool access. Guardrail changes have no effect." - ), - }, - { - "id": "aitl_reviewer", - "name": "AITL Reviewer", - "description": ( - "AI reviewer that evaluates tool calls for safety. Uses one " - "internal decision tool (submit_decision)." - ), - "has_tools": True, - "default_policy": "allow", - "risk_note": ( - "The AITL reviewer IS the guardrail. Restricting it will " - "prevent it from functioning and break AITL-based approvals." - ), - }, - { - "id": "realtime", - "name": "Realtime Voice Agent", - "description": ( - "Bridges the Realtime voice model to the Copilot SDK agent. " - "Spawns one-shot sessions to execute tool-based tasks requested " - "via voice calls." - ), - "has_tools": True, - "default_policy": "background", - "risk_note": ( - "Changing the policy for the realtime agent may cause voice " - "call tool invocations to hang or be blocked." - ), - }, +from .guardrails_bulk import ( + apply_model_defaults_to_config, + apply_preset_to_config, + set_all_strategies_on_config, ) -# Set of agent context IDs for fast lookup in resolve_action fallback. -_BACKGROUND_AGENT_IDS: frozenset[str] = frozenset( - a["id"] for a in BACKGROUND_AGENTS +# Re-export public symbols so existing imports keep working. +from .guardrails_models import GuardrailRule, GuardrailsConfig, _VALID_STRATEGIES +from .guardrails_presets import ( + PRESET_BALANCED, + PRESET_PERMISSIVE, + PRESET_RESTRICTIVE, + _ALL_PRESET_TOOL_IDS, + _build_preset_policies, + list_background_agents, + list_presets, +) +from .guardrails_risk import ( + _MODEL_TIERS, + _risk_of, + get_model_tier, + get_preset_for_model, + list_model_tiers, ) +logger = logging.getLogger(__name__) -def list_background_agents() -> list[dict[str, Any]]: - """Return metadata for all background agents.""" - return list(BACKGROUND_AGENTS) - -# ── Model tiers ────────────────────────────────────────────────────────── -# Tier 1 (cautious): large frontier models -- most access, highest risk posture -# Tier 2 (standard): capable mid-range models -# Tier 3 (safe): smaller / older models -- least access, lowest risk posture - -_MODEL_TIERS: dict[str, int] = { - # Tier 1 -- cautious (most permissive, highest risk) - "gpt-5.3-codex": 1, - "claude-opus-4.6": 1, - "claude-opus-4.6-fast": 1, - # Tier 2 -- standard - "claude-sonnet-4.6": 2, - "gpt-5.2": 2, - "gemini-3-pro-preview": 2, - # Tier 3 -- safe (most restrictive, lowest risk) - "gpt-5-mini": 3, - "gpt-4.1": 3, -} - -_DEFAULT_TIER = 3 # Unknown models get the most restrictive tier - -# SDK tool categories used by presets -_FILE_TOOLS = frozenset({"create", "edit", "view", "grep", "glob"}) -_TERMINAL_TOOLS = frozenset({"run", "bash"}) - -# ── MCP / tool risk classification ────────────────────────────────────── -# Risk levels: low (read-only / public), medium (browser / scheduling), -# high (code repos, infra, phone calls). - -_MCP_RISK: dict[str, str] = { - "mcp:microsoft-learn": "low", # read-only public docs - "mcp:playwright": "medium", # browser automation, can navigate sites - "mcp:github-mcp-server": "high", # create repos, PRs, push code - "mcp:azure-mcp-server": "high", # create/delete Azure resources -} - -_SKILL_RISK: dict[str, str] = { - "skill:daily-briefing": "low", # read-only from local memory - "skill:wiki-search": "low", # read-only, public API - "skill:wiki-summary": "low", - "skill:wiki-deep-dive": "low", - "skill:gh-status-check": "low", # read-only, public API - "skill:gh-incidents": "low", - "skill:gh-maintenance": "low", - "skill:web-search": "medium", # browser-based - "skill:summarize-url": "medium", # browser-based - "skill:note-taking": "medium", # filesystem writes - "skill:daily-rollover": "medium", # M365 reads + file writes - "skill:end-day": "medium", - "skill:weekly-review": "medium", - "skill:monthly-review": "medium", - "skill:setup-foundry": "high", # provisions Azure infra - "skill:foundry-agent-chat": "high", # creates cloud agents - "skill:foundry-code-interpreter": "high", - "skill:setup-workiq": "medium", - "skill:setup-wikipedia": "low", -} - -_CUSTOM_TOOL_RISK: dict[str, str] = { - "schedule_task": "medium", - "cancel_task": "medium", - "list_scheduled_tasks": "low", - "make_voice_call": "high", - "search_memories_tool": "low", - "send_adaptive_card": "low", - "send_hero_card": "low", - "send_thumbnail_card": "low", - "send_card_carousel": "low", -} - - -def _risk_of(tool_id: str) -> str: - """Return the risk level for any tool/MCP/skill id.""" - if tool_id in _MCP_RISK: - return _MCP_RISK[tool_id] - if tool_id in _SKILL_RISK: - return _SKILL_RISK[tool_id] - if tool_id in _CUSTOM_TOOL_RISK: - return _CUSTOM_TOOL_RISK[tool_id] - # SDK tools - if tool_id in ("view", "grep", "glob"): - return "low" - if tool_id in ("create", "edit"): - return "medium" - if tool_id in ("run", "bash"): - return "high" - # Unknown MCP or skill -- default to high for safety - if tool_id.startswith("mcp:") or tool_id.startswith("skill:"): - return "high" - return "medium" - - -# ── Preset definitions ────────────────────────────────────────────────── - -PRESET_MINIMAL = "minimal" -PRESET_SUPERVISED = "supervised" -PRESET_RESTRICTIVE = "restrictive" -PRESET_BALANCED = "balanced" -PRESET_PERMISSIVE = "permissive" - -_TIER_TO_PRESET: dict[int, str] = { - 1: PRESET_PERMISSIVE, - 2: PRESET_BALANCED, - 3: PRESET_RESTRICTIVE, -} - -# Cross-reference: (selected_preset, model_tier) -> effective preset for model-column -# policies. This ensures that switching presets actually changes per-model rules -# while still respecting each model's inherent safety tier. -_EFFECTIVE_MODEL_PRESET: dict[tuple[str, int], str] = { - # Permissive preset: strong/standard models get permissive, cautious gets balanced - (PRESET_PERMISSIVE, 1): PRESET_PERMISSIVE, - (PRESET_PERMISSIVE, 2): PRESET_PERMISSIVE, - (PRESET_PERMISSIVE, 3): PRESET_BALANCED, - # Balanced preset: strong gets permissive, standard balanced, cautious restrictive - (PRESET_BALANCED, 1): PRESET_PERMISSIVE, - (PRESET_BALANCED, 2): PRESET_BALANCED, - (PRESET_BALANCED, 3): PRESET_RESTRICTIVE, - # Restrictive preset: strong gets balanced, standard/cautious get restrictive - (PRESET_RESTRICTIVE, 1): PRESET_BALANCED, - (PRESET_RESTRICTIVE, 2): PRESET_RESTRICTIVE, - (PRESET_RESTRICTIVE, 3): PRESET_RESTRICTIVE, -} - -# Strategy lookup: (preset, context, risk) -> strategy -# Rows: risk low / medium / high -# Columns: interactive / background - -_PRESET_MATRIX: dict[str, dict[str, dict[str, str]]] = { - PRESET_PERMISSIVE: { - "interactive": {"low": "filter", "medium": "filter", "high": "filter"}, - "background": {"low": "filter", "medium": "filter", "high": "hitl"}, - }, - PRESET_BALANCED: { - "interactive": {"low": "filter", "medium": "filter", "high": "hitl"}, - "background": {"low": "filter", "medium": "hitl", "high": "deny"}, - }, - PRESET_RESTRICTIVE: { - "interactive": {"low": "filter", "medium": "hitl", "high": "hitl"}, - "background": {"low": "filter", "medium": "deny", "high": "deny"}, - }, -} - -# Per-preset tool overrides applied *after* the risk matrix. -# Format: {preset: {context: {tool_id: strategy}}} -_PRESET_OVERRIDES: dict[str, dict[str, dict[str, str]]] = { - PRESET_BALANCED: { - "background": { - "create": "filter", - "edit": "filter", - }, - }, -} - -# Every tool/MCP/skill that presets should populate explicitly. -_ALL_PRESET_TOOL_IDS: list[str] = [ - # SDK - "create", "edit", "view", "grep", "glob", "run", "bash", - # Custom agent tools - "schedule_task", "cancel_task", "list_scheduled_tasks", "make_voice_call", - "send_adaptive_card", "send_hero_card", "send_thumbnail_card", "send_card_carousel", - "search_memories_tool", - # MCP - "mcp:microsoft-learn", "mcp:playwright", "mcp:github-mcp-server", "mcp:azure-mcp-server", - # Skills (builtin) - "skill:web-search", "skill:summarize-url", "skill:note-taking", "skill:daily-briefing", -] - - -def _build_preset_policies(preset: str) -> dict[str, Any]: - """Return context_defaults and tool_policies for a given preset name. - - Uses the ``_PRESET_MATRIX`` to map (preset, context, risk) -> strategy - for every known tool/MCP/skill. - """ - matrix = _PRESET_MATRIX.get(preset, _PRESET_MATRIX[PRESET_RESTRICTIVE]) - overrides = _PRESET_OVERRIDES.get(preset, {}) - policies: dict[str, dict[str, str]] = {"interactive": {}, "background": {}} - for tool_id in _ALL_PRESET_TOOL_IDS: - risk = _risk_of(tool_id) - for ctx in ("interactive", "background"): - policies[ctx][tool_id] = matrix[ctx][risk] - # Apply per-tool overrides after the matrix - for ctx, tool_map in overrides.items(): - for tool_id, strategy in tool_map.items(): - policies[ctx][tool_id] = strategy - # Context-level defaults (for tools not explicitly listed) - ctx_defaults = { - ctx: matrix[ctx]["medium"] for ctx in ("interactive", "background") - } - return { - "context_defaults": ctx_defaults, - "tool_policies": policies, - } - - -def get_model_tier(model: str) -> int: - """Return the security tier for a model (1=cautious, 2=standard, 3=safe).""" - return _MODEL_TIERS.get(model, _DEFAULT_TIER) - - -def get_preset_for_model(model: str) -> str: - """Return the recommended preset name for a model.""" - return _TIER_TO_PRESET.get(get_model_tier(model), PRESET_RESTRICTIVE) - - -def list_model_tiers() -> list[dict[str, Any]]: - """Return all known models with their tier and recommended preset.""" - result: list[dict[str, Any]] = [] - _TIER_LABELS = {1: "Strong", 2: "Standard", 3: "Cautious"} - for model, tier in sorted(_MODEL_TIERS.items(), key=lambda x: (x[1], x[0])): - result.append({ - "model": model, - "tier": tier, - "tier_label": _TIER_LABELS.get(tier, "Unknown"), - "preset": get_preset_for_model(model), - }) - return result - - -def list_presets() -> list[dict[str, Any]]: - """Return metadata for all available presets.""" - return [ - { - "id": PRESET_RESTRICTIVE, - "name": "Restrictive", - "description": ( - "For smaller or older models. Read-only tools allowed; " - "file edits and browser require HITL in interactive; " - "terminal, GitHub, Azure, and all MCP denied in background." - ), - "tier": 3, - "recommended_for": sorted( - m for m, t in _MODEL_TIERS.items() if t == 3 - ), - }, - { - "id": PRESET_BALANCED, - "name": "Balanced", - "description": ( - "For standard models. Low-risk tools allowed everywhere; " - "terminal and GitHub/Azure require HITL in interactive; " - "browser and schedules HITL in background; high-risk denied " - "in background. MS Learn allowed." - ), - "tier": 2, - "recommended_for": sorted( - m for m, t in _MODEL_TIERS.items() if t == 2 - ), - }, - { - "id": PRESET_PERMISSIVE, - "name": "Permissive", - "description": ( - "For strong frontier models. All tools allowed in interactive. " - "Terminal, GitHub, Azure still require HITL in background. " - "MS Learn, file operations, and browser allowed everywhere." - ), - "tier": 1, - "recommended_for": sorted( - m for m, t in _MODEL_TIERS.items() if t == 1 - ), - }, - ] - - -# Restrictiveness ranking for merging model policies across contexts. -# Higher rank = more restrictive. -_STRATEGY_RANK: dict[str, int] = { - "allow": 0, - "filter": 1, - "aitl": 2, - "hitl": 3, - "pitl": 4, - "ask": 4, - "deny": 5, -} - - -def _strategy_rank(strategy: str) -> int: - return _STRATEGY_RANK.get(strategy, 3) - - -@dataclass -class GuardrailsConfig: - """Top-level guardrails configuration.""" - - hitl_enabled: bool = False - default_action: str = "allow" # "allow" | "deny" | "hitl" | "pitl" | "aitl" | "filter" - default_channel: str = "chat" # "chat" | "phone" - phone_number: str = "" # E.164 number for phone verification - aitl_model: str = "gpt-4.1" # Model used by the AITL reviewer agent - aitl_spotlighting: bool = True # Spotlight untrusted content in AITL prompts - filter_mode: str = "prompt_shields" # always "prompt_shields" - content_safety_endpoint: str = "" # Azure Content Safety endpoint URL - content_safety_key: str = "" # Azure Content Safety API key - rules: list[GuardrailRule] = field(default_factory=list) - # Policy matrix fields (frontend-driven) - context_defaults: dict[str, str] = field(default_factory=dict) - tool_policies: dict[str, dict[str, str]] = field(default_factory=dict) - # Model-specific columns: user-defined model identifiers - model_columns: list[str] = field(default_factory=list) - # Model-scoped policies: model -> context -> tool -> strategy - model_policies: dict[str, dict[str, dict[str, str]]] = field(default_factory=dict) +_instance: GuardrailsConfigStore | None = None class GuardrailsConfigStore: @@ -572,10 +166,6 @@ def remove_tool_policy(self, context: str, tool_id: str) -> bool: return True return False - # ------------------------------------------------------------------ - # Model columns - # ------------------------------------------------------------------ - def add_model_column(self, model: str) -> None: if model not in self._config.model_columns: self._config.model_columns.append(model) @@ -611,74 +201,14 @@ def remove_model_policy( return True return False - # ------------------------------------------------------------------ - # Presets - # ------------------------------------------------------------------ - def apply_preset(self, preset: str, *, auto_models: bool = True) -> None: - """Apply a named preset to context_defaults and tool_policies. - - This overwrites the existing context_defaults and tool_policies. - When *auto_models* is ``True`` (default), the preset's recommended - models are added as model columns with tier-appropriate policies. - All existing model columns are also refreshed to reflect the new - preset's risk posture. - """ - valid = {PRESET_RESTRICTIVE, PRESET_BALANCED, PRESET_PERMISSIVE} - if preset not in valid: - raise ValueError("preset must be one of: %s" % ", ".join(sorted(valid))) - policies = _build_preset_policies(preset) - self._config.context_defaults = policies["context_defaults"] - self._config.tool_policies = policies["tool_policies"] - self._config.hitl_enabled = True - if auto_models: - # Add recommended models for this preset tier as model columns - preset_meta = next((p for p in list_presets() if p["id"] == preset), None) - if preset_meta: - new_models = [ - m for m in preset_meta["recommended_for"] - if m not in self._config.model_columns - ] - if new_models: - self.apply_model_defaults(new_models, preset=preset) - # Refresh ALL existing model columns with the new preset's posture - if self._config.model_columns: - self.apply_model_defaults(preset=preset) + """Apply a named preset to context_defaults and tool_policies.""" + apply_preset_to_config(self._config, preset, auto_models=auto_models) self._save() def set_all_strategies(self, strategy: str) -> None: - """Set every tool policy and context default to *strategy*. - - This is a bulk operation: all tools in ``_ALL_PRESET_TOOL_IDS`` - across interactive and background contexts are set to the given - strategy, and both context defaults are also set. All known - models from ``_MODEL_TIERS`` are added as model columns with - the same strategy applied to every tool across both contexts. - Guardrails are enabled. - """ - if strategy not in _VALID_STRATEGIES: - raise ValueError( - "strategy must be one of: %s" % ", ".join(sorted(_VALID_STRATEGIES)) - ) - policies: dict[str, dict[str, str]] = {"interactive": {}, "background": {}} - for tool_id in _ALL_PRESET_TOOL_IDS: - for ctx in ("interactive", "background"): - policies[ctx][tool_id] = strategy - self._config.context_defaults = { - "interactive": strategy, - "background": strategy, - } - self._config.tool_policies = policies - # Populate all known models with the same strategy - self._config.model_columns = sorted(_MODEL_TIERS.keys()) - model_policies: dict[str, dict[str, dict[str, str]]] = {} - for model in self._config.model_columns: - per_ctx: dict[str, dict[str, str]] = {} - for ctx in ("interactive", "background"): - per_ctx[ctx] = {tool_id: strategy for tool_id in _ALL_PRESET_TOOL_IDS} - model_policies[model] = per_ctx - self._config.model_policies = model_policies - self._config.hitl_enabled = True + """Set every tool policy and context default to *strategy*.""" + set_all_strategies_on_config(self._config, strategy) self._save() def apply_model_defaults( @@ -687,41 +217,8 @@ def apply_model_defaults( *, preset: str | None = None, ) -> None: - """Auto-populate model columns with tier-appropriate policies. - - For each model, determines the effective preset via the - ``_EFFECTIVE_MODEL_PRESET`` cross-reference of *preset* (the - user-selected risk posture) and the model's inherent tier. - Each model gets separate per-context policies (interactive - and background). - - If *models* is ``None``, uses the existing ``model_columns``. - If *preset* is ``None``, falls back to the model's own tier - preset. - """ - target_models = models if models is not None else list(self._config.model_columns) - for model in target_models: - if model not in self._config.model_columns: - self._config.model_columns.append(model) - tier = get_model_tier(model) - if preset: - effective = _EFFECTIVE_MODEL_PRESET.get( - (preset, tier), - get_preset_for_model(model), - ) - else: - effective = get_preset_for_model(model) - matrix = _PRESET_MATRIX.get(effective, _PRESET_MATRIX[PRESET_RESTRICTIVE]) - overrides = _PRESET_OVERRIDES.get(effective, {}) - per_ctx: dict[str, dict[str, str]] = {} - for ctx in ("interactive", "background"): - ctx_overrides = overrides.get(ctx, {}) - ctx_policies: dict[str, str] = {} - for tool_id in _ALL_PRESET_TOOL_IDS: - risk = _risk_of(tool_id) - ctx_policies[tool_id] = ctx_overrides.get(tool_id, matrix[ctx][risk]) - per_ctx[ctx] = ctx_policies - self._config.model_policies[model] = per_ctx + """Auto-populate model columns with tier-appropriate policies.""" + apply_model_defaults_to_config(self._config, models, preset=preset) self._save() def add_rule( @@ -917,10 +414,6 @@ def _load(self) -> None: except Exception as exc: logger.warning("Failed to load guardrails config from %s: %s", self._path, exc) - # ------------------------------------------------------------------ - # Policy YAML management - # ------------------------------------------------------------------ - @property def policy_path(self) -> Path: """Path to the generated policy YAML file.""" diff --git a/app/runtime/state/infra_config.py b/app/runtime/state/infra_config.py index 49d3735..7ee2a17 100644 --- a/app/runtime/state/infra_config.py +++ b/app/runtime/state/infra_config.py @@ -2,13 +2,11 @@ from __future__ import annotations -import json import logging from dataclasses import asdict, dataclass, field -from pathlib import Path from typing import Any -from ..config.settings import cfg +from ._base import BaseConfigStore logger = logging.getLogger(__name__) @@ -48,16 +46,30 @@ class ChannelsConfig: voice_call: VoiceCallConfig = field(default_factory=VoiceCallConfig) -class InfraConfigStore: +@dataclass +class InfraConfig: + """Top-level config dataclass wrapping bot and channel configs.""" + + bot: BotInfraConfig = field(default_factory=BotInfraConfig) + channels: ChannelsConfig = field(default_factory=ChannelsConfig) + + +class InfraConfigStore(BaseConfigStore[InfraConfig]): """Persists infrastructure configuration to ``infra.json``.""" - _SECRET_FIELDS = {"token", "acs_connection_string", "azure_openai_api_key"} + _config_type = InfraConfig + _default_filename = "infra.json" + _log_label = "infra config" + _SECRET_FIELDS = frozenset({"token", "acs_connection_string", "azure_openai_api_key"}) + _secret_prefix = "infra-" + + @property + def bot(self) -> BotInfraConfig: + return self._config.bot - def __init__(self, path: Path | None = None) -> None: - self._path = path or (cfg.data_dir / "infra.json") - self.bot = BotInfraConfig() - self.channels = ChannelsConfig() - self._load() + @property + def channels(self) -> ChannelsConfig: + return self._config.channels @property def bot_configured(self) -> bool: @@ -71,108 +83,75 @@ def telegram_configured(self) -> bool: def voice_call_configured(self) -> bool: return bool(self.channels.voice_call.acs_connection_string) - def _load(self) -> None: - if not self._path.exists(): - return - try: - data = json.loads(self._path.read_text()) - except (json.JSONDecodeError, OSError): - return - bot_data = data.get("bot", {}) + def _apply_raw(self, raw: dict[str, Any]) -> None: + bot_data = raw.get("bot", {}) for k, v in bot_data.items(): - if hasattr(self.bot, k): + if hasattr(self._config.bot, k): try: - setattr(self.bot, k, self._resolve_secret(v)) + setattr(self._config.bot, k, self._resolve_secret(v)) except Exception: logger.warning("Failed to resolve bot.%s -- skipping", k, exc_info=True) - tg_data = data.get("channels", {}).get("telegram", {}) + tg_data = raw.get("channels", {}).get("telegram", {}) for k, v in tg_data.items(): - if hasattr(self.channels.telegram, k): + if hasattr(self._config.channels.telegram, k): try: - setattr(self.channels.telegram, k, self._resolve_secret(v)) + setattr(self._config.channels.telegram, k, self._resolve_secret(v)) except Exception: logger.warning("Failed to resolve telegram.%s -- skipping", k, exc_info=True) - vc_data = data.get("channels", {}).get("voice_call", {}) + vc_data = raw.get("channels", {}).get("voice_call", {}) for k, v in vc_data.items(): - if hasattr(self.channels.voice_call, k): + if hasattr(self._config.channels.voice_call, k): try: - setattr(self.channels.voice_call, k, self._resolve_secret(v)) + setattr(self._config.channels.voice_call, k, self._resolve_secret(v)) except Exception: logger.warning("Failed to resolve voice_call.%s -- skipping", k, exc_info=True) - def _save(self) -> None: - data = { - "bot": asdict(self.bot), + def _save_data(self) -> dict[str, Any]: + return { + "bot": asdict(self._config.bot), "channels": { - "telegram": self._store_secrets(asdict(self.channels.telegram)), - "voice_call": self._store_secrets(asdict(self.channels.voice_call)), + "telegram": self._store_secrets(asdict(self._config.channels.telegram)), + "voice_call": self._store_secrets(asdict(self._config.channels.voice_call)), }, } - self._path.parent.mkdir(parents=True, exist_ok=True) - self._path.write_text(json.dumps(data, indent=2) + "\n") def save_bot(self, **kwargs: str) -> None: for k, v in kwargs.items(): - if hasattr(self.bot, k): - setattr(self.bot, k, v) + if hasattr(self._config.bot, k): + setattr(self._config.bot, k, v) self._save() def save_telegram(self, **kwargs: str) -> None: for k, v in kwargs.items(): - if hasattr(self.channels.telegram, k): - setattr(self.channels.telegram, k, v) + if hasattr(self._config.channels.telegram, k): + setattr(self._config.channels.telegram, k, v) self._save() def clear_telegram(self) -> None: - self.channels.telegram = TelegramChannelConfig() + self._config.channels.telegram = TelegramChannelConfig() self._save() def save_voice_call(self, **kwargs: str) -> None: for k, v in kwargs.items(): - if hasattr(self.channels.voice_call, k): - setattr(self.channels.voice_call, k, v) + if hasattr(self._config.channels.voice_call, k): + setattr(self._config.channels.voice_call, k, v) self._save() def clear_voice_call(self) -> None: - self.channels.voice_call = VoiceCallConfig() + self._config.channels.voice_call = VoiceCallConfig() self._save() def to_safe_dict(self) -> dict[str, Any]: - data = { - "bot": asdict(self.bot), + return { + "bot": asdict(self._config.bot), "channels": { - "telegram": self._mask_secrets(asdict(self.channels.telegram)), - "voice_call": self._mask_secrets(asdict(self.channels.voice_call)), + "telegram": self._mask_secrets(asdict(self._config.channels.telegram)), + "voice_call": self._mask_secrets(asdict(self._config.channels.voice_call)), }, } - return data def _mask_secrets(self, d: dict[str, Any]) -> dict[str, Any]: return { k: ("****" if k in self._SECRET_FIELDS and v else v) for k, v in d.items() } - - def _store_secrets(self, d: dict[str, Any]) -> dict[str, Any]: - from ..services.keyvault import kv, env_key_to_secret_name, is_kv_ref - - result = dict(d) - if not kv.enabled: - return result - for k in self._SECRET_FIELDS: - val = result.get(k, "") - if val and not is_kv_ref(val): - try: - ref = kv.store(env_key_to_secret_name(f"infra-{k}"), val) - result[k] = ref - except Exception as exc: - logger.warning("Failed to store secret %s in KV: %s", k, exc) - return result - - @staticmethod - def _resolve_secret(value: Any) -> Any: - if not isinstance(value, str): - return value - from ..services.keyvault import resolve_if_kv_ref - - return resolve_if_kv_ref(value) diff --git a/app/runtime/state/mcp_config.py b/app/runtime/state/mcp_config.py index d053be1..717ed77 100644 --- a/app/runtime/state/mcp_config.py +++ b/app/runtime/state/mcp_config.py @@ -195,7 +195,9 @@ def _load(self) -> None: if dirty: self._save() except Exception as exc: - logger.warning("Failed to load MCP config from %s: %s", self._path, exc) + logger.warning( + "Failed to load MCP config from %s: %s", self._path, exc, exc_info=True, + ) def _save(self) -> None: self._path.parent.mkdir(parents=True, exist_ok=True) diff --git a/app/runtime/state/memory.py b/app/runtime/state/memory.py index cd90e71..f74d05d 100644 --- a/app/runtime/state/memory.py +++ b/app/runtime/state/memory.py @@ -152,7 +152,7 @@ def _format_transcript(entries: list[_ChatEntry]) -> str: @staticmethod def _build_system_message() -> str: - from .profile import _profile_path, _usage_path as _skill_usage_path + from .profile import profile_path, _usage_path as _skill_usage_path template = (_TEMPLATES_DIR / "memory_prompt.md").read_text() @@ -163,7 +163,7 @@ def _build_system_message() -> str: return template.format( memory_daily_dir=cfg.memory_daily_dir, memory_topics_dir=cfg.memory_topics_dir, - profile_path=_profile_path(), + profile_path=profile_path(), skill_usage_path=_skill_usage_path(), suggestions_path=cfg.data_dir / "suggestions.txt", data_dir=cfg.data_dir, @@ -173,8 +173,6 @@ def _build_system_message() -> str: @staticmethod def _build_proactive_section() -> str: - from .session_store import SessionStore - proactive_template = (_TEMPLATES_DIR / "proactive_prompt_section.md").read_text() store = get_proactive_store() @@ -321,14 +319,17 @@ async def _process_proactive_followup(self) -> None: hours_since = store.hours_since_last_sent() if hours_since is not None and hours_since < prefs.min_gap_hours: - logger.info("Proactive gap too short (%.1fh < %dh), skipping.", hours_since, prefs.min_gap_hours) + logger.info( + "Proactive gap too short (%.1fh < %dh), skipping.", + hours_since, prefs.min_gap_hours, + ) return store.schedule_followup(message=message, deliver_at=deliver_at, context=context) self._last_proactive_scheduled = True logger.info("Proactive follow-up scheduled: %s at %s", message[:50], deliver_at) except (json.JSONDecodeError, OSError) as exc: - logger.warning("Failed to process proactive follow-up: %s", exc) + logger.warning("Failed to process proactive follow-up: %s", exc, exc_info=True) try: followup_path.unlink(missing_ok=True) except OSError: @@ -361,7 +362,7 @@ def _process_proactive_reaction() -> None: store.update_preferences(avoided_topics=list(prefs.avoided_topics) + [detail]) logger.info("Added avoided topic from negative reaction: %s", detail) except (json.JSONDecodeError, OSError) as exc: - logger.warning("Failed to process proactive reaction: %s", exc) + logger.warning("Failed to process proactive reaction: %s", exc, exc_info=True) try: reaction_path.unlink(missing_ok=True) except OSError: diff --git a/app/runtime/state/monitoring_config.py b/app/runtime/state/monitoring_config.py index 8d0fb57..627e59a 100644 --- a/app/runtime/state/monitoring_config.py +++ b/app/runtime/state/monitoring_config.py @@ -2,15 +2,10 @@ from __future__ import annotations -import json -import logging from dataclasses import asdict, dataclass, field -from pathlib import Path from typing import Any -from ..config.settings import cfg - -logger = logging.getLogger(__name__) +from ._base import BaseConfigStore def _url_encode_slashes(path: str) -> str: @@ -34,21 +29,12 @@ class MonitoringConfig: subscription_id: str = "" -class MonitoringConfigStore: +class MonitoringConfigStore(BaseConfigStore[MonitoringConfig]): """JSON-file-backed monitoring / OTel configuration.""" - def __init__(self, path: Path | None = None) -> None: - self._path = path or (cfg.data_dir / "monitoring.json") - self._config = MonitoringConfig() - self._load() - - @property - def path(self) -> Path: - return self._path - - @property - def config(self) -> MonitoringConfig: - return self._config + _config_type = MonitoringConfig + _default_filename = "monitoring.json" + _log_label = "monitoring config" @property def enabled(self) -> bool: @@ -157,27 +143,4 @@ def to_dict_full(self) -> dict[str, Any]: """Return the full config including secrets -- internal use only.""" return asdict(self._config) - def _load(self) -> None: - if not self._path.exists(): - return - try: - raw = json.loads(self._path.read_text()) - self._config = MonitoringConfig( - enabled=raw.get("enabled", False), - connection_string=raw.get("connection_string", ""), - sampling_ratio=raw.get("sampling_ratio", 1.0), - enable_live_metrics=raw.get("enable_live_metrics", False), - instrumentation_options=raw.get("instrumentation_options", {}), - provisioned=raw.get("provisioned", False), - app_insights_name=raw.get("app_insights_name", ""), - workspace_name=raw.get("workspace_name", ""), - resource_group=raw.get("resource_group", ""), - location=raw.get("location", ""), - subscription_id=raw.get("subscription_id", ""), - ) - except Exception as exc: - logger.warning("Failed to load monitoring config from %s: %s", self._path, exc) - def _save(self) -> None: - self._path.parent.mkdir(parents=True, exist_ok=True) - self._path.write_text(json.dumps(asdict(self._config), indent=2) + "\n") diff --git a/app/runtime/state/plugin_config.py b/app/runtime/state/plugin_config.py index 4ebcbba..0daf743 100644 --- a/app/runtime/state/plugin_config.py +++ b/app/runtime/state/plugin_config.py @@ -15,6 +15,12 @@ logger = logging.getLogger(__name__) +_DEFAULT_STATE: dict[str, Any] = { + "enabled": False, + "setup_completed": False, + "installed_at": None, +} + class PluginConfigStore: """JSON-file-backed plugin state store.""" @@ -29,22 +35,14 @@ def path(self) -> Path: return self._path def get_state(self, plugin_id: str) -> dict[str, Any]: - return self._plugins.get(plugin_id, { - "enabled": False, - "setup_completed": False, - "installed_at": None, - }) + return self._plugins.get(plugin_id, dict(_DEFAULT_STATE)) def list_states(self) -> dict[str, dict[str, Any]]: return dict(self._plugins) def set_enabled(self, plugin_id: str, enabled: bool) -> None: if plugin_id not in self._plugins: - self._plugins[plugin_id] = { - "enabled": False, - "setup_completed": False, - "installed_at": None, - } + self._plugins[plugin_id] = dict(_DEFAULT_STATE) self._plugins[plugin_id]["enabled"] = enabled if enabled and not self._plugins[plugin_id].get("installed_at"): self._plugins[plugin_id]["installed_at"] = datetime.now(UTC).isoformat() @@ -52,20 +50,12 @@ def set_enabled(self, plugin_id: str, enabled: bool) -> None: def mark_setup_completed(self, plugin_id: str) -> None: if plugin_id not in self._plugins: - self._plugins[plugin_id] = { - "enabled": False, - "setup_completed": False, - "installed_at": None, - } + self._plugins[plugin_id] = dict(_DEFAULT_STATE) self._plugins[plugin_id]["setup_completed"] = True self._save() def reset(self, plugin_id: str) -> None: - self._plugins[plugin_id] = { - "enabled": False, - "setup_completed": False, - "installed_at": None, - } + self._plugins[plugin_id] = dict(_DEFAULT_STATE) self._save() def _load(self) -> None: @@ -75,7 +65,9 @@ def _load(self) -> None: raw = json.loads(self._path.read_text()) self._plugins = raw.get("plugins", {}) except Exception as exc: - logger.warning("Failed to load plugin config from %s: %s", self._path, exc) + logger.warning( + "Failed to load plugin config from %s: %s", self._path, exc, exc_info=True, + ) def _save(self) -> None: self._path.parent.mkdir(parents=True, exist_ok=True) diff --git a/app/runtime/state/proactive.py b/app/runtime/state/proactive.py index a41185d..185dfcc 100644 --- a/app/runtime/state/proactive.py +++ b/app/runtime/state/proactive.py @@ -63,7 +63,7 @@ def _load(self) -> None: try: self._data = json.loads(self._path.read_text()) except (json.JSONDecodeError, OSError) as exc: - logger.warning("Failed to load proactive state: %s", exc) + logger.warning("Failed to load proactive state: %s", exc, exc_info=True) self._data = {} env_default = cfg.proactive_enabled if not self._path.exists() else False self._data.setdefault("enabled", env_default) diff --git a/app/runtime/state/profile.py b/app/runtime/state/profile.py index 53f2c76..355b85c 100644 --- a/app/runtime/state/profile.py +++ b/app/runtime/state/profile.py @@ -3,7 +3,6 @@ from __future__ import annotations import json -import logging import time from pathlib import Path from typing import Any @@ -11,8 +10,6 @@ from ..config.settings import cfg from ..util.singletons import register_singleton -logger = logging.getLogger(__name__) - _DEFAULT_PROFILE: dict[str, Any] = { "name": "polyclaw", "emoji": "", @@ -26,6 +23,11 @@ def _profile_path() -> Path: return cfg.data_dir / "agent_profile.json" +def profile_path() -> Path: + """Return the path to the agent profile JSON file.""" + return _profile_path() + + def _usage_path() -> Path: return cfg.data_dir / "skill_usage.json" diff --git a/app/runtime/state/sandbox_config.py b/app/runtime/state/sandbox_config.py index 0b4f5b2..4113bae 100644 --- a/app/runtime/state/sandbox_config.py +++ b/app/runtime/state/sandbox_config.py @@ -2,15 +2,10 @@ from __future__ import annotations -import json -import logging -from dataclasses import asdict, dataclass, field -from pathlib import Path +from dataclasses import dataclass, field from typing import Any -from ..config.settings import cfg - -logger = logging.getLogger(__name__) +from ._base import BaseConfigStore DEFAULT_WHITELIST: list[str] = [ "media", "memory", "notes", "sessions", "skills", @@ -38,21 +33,12 @@ class SandboxConfig: pool_id: str = "" -class SandboxConfigStore: +class SandboxConfigStore(BaseConfigStore[SandboxConfig]): """JSON-file-backed sandbox configuration.""" - def __init__(self, path: Path | None = None) -> None: - self._path = path or (cfg.data_dir / "sandbox.json") - self._config = SandboxConfig() - self._load() - - @property - def path(self) -> Path: - return self._path - - @property - def config(self) -> SandboxConfig: - return self._config + _config_type = SandboxConfig + _default_filename = "sandbox.json" + _log_label = "sandbox config" @property def enabled(self) -> bool: @@ -155,27 +141,4 @@ def update(self, **kwargs: Any) -> None: setattr(self._config, k, v) self._save() - def to_dict(self) -> dict[str, Any]: - return asdict(self._config) - - def _load(self) -> None: - if not self._path.exists(): - return - try: - raw = json.loads(self._path.read_text()) - self._config = SandboxConfig( - enabled=raw.get("enabled", False), - sync_data=raw.get("sync_data", True), - session_pool_endpoint=raw.get("session_pool_endpoint", ""), - whitelist=raw.get("whitelist", list(DEFAULT_WHITELIST)), - resource_group=raw.get("resource_group", ""), - location=raw.get("location", ""), - pool_name=raw.get("pool_name", ""), - pool_id=raw.get("pool_id", ""), - ) - except Exception as exc: - logger.warning("Failed to load sandbox config from %s: %s", self._path, exc) - - def _save(self) -> None: - self._path.parent.mkdir(parents=True, exist_ok=True) - self._path.write_text(json.dumps(asdict(self._config), indent=2) + "\n") + diff --git a/app/runtime/state/tool_activity_models.py b/app/runtime/state/tool_activity_models.py new file mode 100644 index 0000000..373f5e1 --- /dev/null +++ b/app/runtime/state/tool_activity_models.py @@ -0,0 +1,91 @@ +"""Data models and risk-scoring helpers for tool activity tracking.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass +class ToolActivityEntry: + """A single recorded tool invocation.""" + + id: str = "" + session_id: str = "" + tool: str = "" + call_id: str = "" + category: str = "" # sdk | custom | mcp | skill + arguments: str = "" + result: str = "" + status: str = "" # started | completed | denied | error + timestamp: float = 0.0 + duration_ms: float | None = None + flagged: bool = False + flag_reason: str = "" + risk_score: int = 0 # 0-100 computed risk score + risk_factors: list[str] = field(default_factory=list) + model: str = "" # which LLM model initiated this tool call + interaction_type: str = "" # "" | hitl | aitl | pitl | filter | deny + shield_result: str = "" # "" | clean | attack | error | not_configured + shield_detail: str = "" # human-readable detail from Content Safety API + shield_elapsed_ms: float | None = None # round-trip time for the shield call + + +_SUSPICIOUS_PATTERNS: list[tuple[str, int, str]] = [ + # (pattern, severity 1-100, description) + ("rm -rf", 90, "Recursive forced deletion"), + ("rm -r /", 100, "Root filesystem deletion"), + ("DROP TABLE", 85, "SQL table drop"), + ("DELETE FROM", 60, "SQL mass deletion"), + ("curl.*|.*sh", 80, "Remote code execution via curl"), + ("wget.*|.*sh", 80, "Remote code execution via wget"), + ("eval(", 75, "Dynamic code evaluation"), + ("exec(", 75, "Dynamic code execution"), + ("os.system", 70, "Shell command execution"), + ("subprocess", 50, "Subprocess invocation"), + ("chmod 777", 65, "World-writable permissions"), + ("passwd", 55, "Password file access"), + ("/etc/shadow", 90, "Shadow password file access"), + ("env | grep", 45, "Environment variable enumeration"), + ("printenv", 45, "Environment variable dump"), + ("base64 -d", 60, "Base64 decode (potential obfuscation)"), + (".ssh/", 70, "SSH directory access"), + ("id_rsa", 85, "SSH private key access"), + ("PRIVATE KEY", 95, "Private key exposure"), + ("API_KEY", 50, "API key in arguments"), + ("SECRET", 55, "Secret value in arguments"), + ("TOKEN", 45, "Token value in arguments"), + ("password", 50, "Password in arguments"), + ("credentials", 55, "Credentials reference"), + ("sudo ", 60, "Privilege escalation"), + ("nc -l", 70, "Netcat listener (reverse shell)"), + (">&/dev/tcp", 90, "Bash reverse shell"), + ("/dev/tcp", 85, "Network device access"), + ("mkfifo", 65, "Named pipe creation"), + ("nmap", 55, "Network scanning"), + ("sqlmap", 80, "SQL injection tool"), + (".env", 40, "Environment file access"), + ("aws configure", 50, "Cloud credential configuration"), + ("gcloud auth", 50, "Cloud credential configuration"), + ("az login", 40, "Azure CLI login"), + ("docker run", 45, "Container execution"), + ("kubectl exec", 55, "Kubernetes pod execution"), +] + + +def check_suspicious(arguments: str, result: str) -> tuple[bool, str, int, list[str]]: + """Check if a tool call looks suspicious based on arguments/result. + + Returns (flagged, primary_reason, risk_score, risk_factors). + """ + text = f"{arguments} {result}".lower() + factors: list[str] = [] + max_severity = 0 + primary_reason = "" + for pattern, severity, description in _SUSPICIOUS_PATTERNS: + if pattern.lower() in text: + factors.append(description) + if severity > max_severity: + max_severity = severity + primary_reason = f"Suspicious pattern: {pattern}" + flagged = max_severity >= 40 + return flagged, primary_reason, max_severity, factors diff --git a/app/runtime/state/tool_activity_store.py b/app/runtime/state/tool_activity_store.py index 23134b9..93f309d 100644 --- a/app/runtime/state/tool_activity_store.py +++ b/app/runtime/state/tool_activity_store.py @@ -8,102 +8,17 @@ import logging import threading import time -from dataclasses import asdict, dataclass, field +from dataclasses import asdict from pathlib import Path from typing import Any from ..config.settings import cfg from ..util.singletons import register_singleton +from .tool_activity_models import ToolActivityEntry, check_suspicious logger = logging.getLogger(__name__) -@dataclass -class ToolActivityEntry: - """A single recorded tool invocation.""" - - id: str = "" - session_id: str = "" - tool: str = "" - call_id: str = "" - category: str = "" # sdk | custom | mcp | skill - arguments: str = "" - result: str = "" - status: str = "" # started | completed | denied | error - timestamp: float = 0.0 - duration_ms: float | None = None - flagged: bool = False - flag_reason: str = "" - risk_score: int = 0 # 0-100 computed risk score - risk_factors: list[str] = field(default_factory=list) - model: str = "" # which LLM model initiated this tool call - interaction_type: str = "" # "" | hitl | aitl | pitl | filter | deny - shield_result: str = "" # "" | clean | attack | error | not_configured - shield_detail: str = "" # human-readable detail from Content Safety API - shield_elapsed_ms: float | None = None # round-trip time for the shield call - - -_SUSPICIOUS_PATTERNS: list[tuple[str, int, str]] = [ - # (pattern, severity 1-100, description) - ("rm -rf", 90, "Recursive forced deletion"), - ("rm -r /", 100, "Root filesystem deletion"), - ("DROP TABLE", 85, "SQL table drop"), - ("DELETE FROM", 60, "SQL mass deletion"), - ("curl.*|.*sh", 80, "Remote code execution via curl"), - ("wget.*|.*sh", 80, "Remote code execution via wget"), - ("eval(", 75, "Dynamic code evaluation"), - ("exec(", 75, "Dynamic code execution"), - ("os.system", 70, "Shell command execution"), - ("subprocess", 50, "Subprocess invocation"), - ("chmod 777", 65, "World-writable permissions"), - ("passwd", 55, "Password file access"), - ("/etc/shadow", 90, "Shadow password file access"), - ("env | grep", 45, "Environment variable enumeration"), - ("printenv", 45, "Environment variable dump"), - ("base64 -d", 60, "Base64 decode (potential obfuscation)"), - (".ssh/", 70, "SSH directory access"), - ("id_rsa", 85, "SSH private key access"), - ("PRIVATE KEY", 95, "Private key exposure"), - ("API_KEY", 50, "API key in arguments"), - ("SECRET", 55, "Secret value in arguments"), - ("TOKEN", 45, "Token value in arguments"), - ("password", 50, "Password in arguments"), - ("credentials", 55, "Credentials reference"), - ("sudo ", 60, "Privilege escalation"), - ("nc -l", 70, "Netcat listener (reverse shell)"), - (">&/dev/tcp", 90, "Bash reverse shell"), - ("/dev/tcp", 85, "Network device access"), - ("mkfifo", 65, "Named pipe creation"), - ("nmap", 55, "Network scanning"), - ("sqlmap", 80, "SQL injection tool"), - (".env", 40, "Environment file access"), - ("aws configure", 50, "Cloud credential configuration"), - ("gcloud auth", 50, "Cloud credential configuration"), - ("az login", 40, "Azure CLI login"), - ("docker run", 45, "Container execution"), - ("kubectl exec", 55, "Kubernetes pod execution"), -] - - -def _check_suspicious(arguments: str, result: str) -> tuple[bool, str, int, list[str]]: - """Check if a tool call looks suspicious based on arguments/result. - - Returns (flagged, primary_reason, risk_score, risk_factors). - """ - text = f"{arguments} {result}".lower() - factors: list[str] = [] - max_severity = 0 - primary_reason = "" - for pattern, severity, description in _SUSPICIOUS_PATTERNS: - if pattern.lower() in text: - factors.append(description) - if severity > max_severity: - max_severity = severity - primary_reason = f"Suspicious pattern: {pattern}" - flagged = max_severity >= 40 - return flagged, primary_reason, max_severity, factors - - class ToolActivityStore: """Append-only log of tool invocations for audit and review. @@ -120,6 +35,13 @@ def __init__(self, path: Path | None = None) -> None: self._counter = 0 self._load() + def _deduplicated(self) -> list[ToolActivityEntry]: + """Return entries deduplicated by id, keeping the latest version.""" + by_id: dict[str, ToolActivityEntry] = {} + for e in self._entries: + by_id[e.id] = e + return list(by_id.values()) + def _load(self) -> None: if not self._path.exists(): return @@ -139,7 +61,7 @@ def _load(self) -> None: self._counter = max(self._counter, int(entry.id.split("-")[-1] or "0")) self._entries = list(by_id.values()) except (json.JSONDecodeError, OSError) as exc: - logger.warning("[tool_activity] failed to load: %s", exc) + logger.warning("[tool_activity] failed to load: %s", exc, exc_info=True) def _next_id(self) -> str: self._counter += 1 @@ -174,7 +96,7 @@ def record_start( model=model, interaction_type=interaction_type, ) - flagged, reason, risk, factors = _check_suspicious(arguments, "") + flagged, reason, risk, factors = check_suspicious(arguments, "") entry.flagged = flagged entry.flag_reason = reason entry.risk_score = risk @@ -209,7 +131,7 @@ def record_complete( pending.result = result[:2000] if result else "" pending.status = status pending.duration_ms = (time.time() - pending.timestamp) * 1000 - flagged, reason, risk, factors = _check_suspicious(pending.arguments, result) + flagged, reason, risk, factors = check_suspicious(pending.arguments, result) if flagged and not pending.flagged: pending.flagged = True pending.flag_reason = reason @@ -264,11 +186,9 @@ def query( ) -> dict[str, Any]: """Query tool activity with filters.""" with self._lock: - # Deduplicate: keep the latest version of each entry id - by_id: dict[str, ToolActivityEntry] = {} - for e in self._entries: - by_id[e.id] = e - entries = sorted(by_id.values(), key=lambda e: e.timestamp, reverse=True) + entries = sorted( + self._deduplicated(), key=lambda e: e.timestamp, reverse=True, + ) # Apply filters if session_id: @@ -301,10 +221,7 @@ def query( def get_summary(self) -> dict[str, Any]: """Get aggregate statistics about tool activity.""" with self._lock: - by_id: dict[str, ToolActivityEntry] = {} - for e in self._entries: - by_id[e.id] = e - entries = list(by_id.values()) + entries = self._deduplicated() total = len(entries) flagged = sum(1 for e in entries if e.flagged) @@ -406,10 +323,7 @@ def get_timeline( ) -> list[dict[str, Any]]: """Return tool call counts bucketed by time interval.""" with self._lock: - by_id: dict[str, ToolActivityEntry] = {} - for e in self._entries: - by_id[e.id] = e - entries = list(by_id.values()) + entries = self._deduplicated() if not entries: return [] @@ -427,7 +341,10 @@ def get_timeline( continue bucket_ts = int(e.timestamp // bucket_secs) * bucket_secs if bucket_ts not in buckets: - buckets[bucket_ts] = {"total": 0, "flagged": 0, "sdk": 0, "mcp": 0, "custom": 0, "skill": 0} + buckets[bucket_ts] = { + "total": 0, "flagged": 0, + "sdk": 0, "mcp": 0, "custom": 0, "skill": 0, + } buckets[bucket_ts]["total"] += 1 if e.flagged: buckets[bucket_ts]["flagged"] += 1 @@ -442,10 +359,7 @@ def get_timeline( def get_session_breakdown(self) -> list[dict[str, Any]]: """Return per-session aggregation for the session-level audit view.""" with self._lock: - by_id: dict[str, ToolActivityEntry] = {} - for e in self._entries: - by_id[e.id] = e - entries = list(by_id.values()) + entries = self._deduplicated() sessions: dict[str, dict[str, Any]] = {} for e in entries: @@ -570,7 +484,7 @@ def import_from_sessions(self, session_store: object) -> int: status="completed", timestamp=msg.get("timestamp", 0), ) - flagged, reason, risk, factors = _check_suspicious(entry.arguments, entry.result) + flagged, reason, risk, factors = check_suspicious(entry.arguments, entry.result) entry.flagged = flagged entry.flag_reason = reason entry.risk_score = risk diff --git a/app/runtime/tests/test_agent_tools.py b/app/runtime/tests/test_agent_tools.py index 7eb2ee0..9f84994 100644 --- a/app/runtime/tests/test_agent_tools.py +++ b/app/runtime/tests/test_agent_tools.py @@ -150,7 +150,7 @@ def test_list_with_tasks(self): class TestMakeVoiceCallTool: - @patch("app.runtime.agent.tools.cfg") + @patch("app.runtime.agent.tools.voice.cfg") def test_no_target_number(self, mock_cfg): from app.runtime.agent.tools import make_voice_call @@ -158,8 +158,8 @@ def test_no_target_number(self, mock_cfg): result = _call_tool(make_voice_call, {"prompt": "hi"}) assert result["status"] == "error" - @patch("app.runtime.agent.tools.threading.Thread") - @patch("app.runtime.agent.tools.cfg") + @patch("app.runtime.agent.tools.voice.threading.Thread") + @patch("app.runtime.agent.tools.voice.cfg") def test_with_target_number(self, mock_cfg, mock_thread): from app.runtime.agent.tools import make_voice_call diff --git a/app/runtime/tests/test_azure_cli.py b/app/runtime/tests/test_azure_cli.py index b071401..88201dc 100644 --- a/app/runtime/tests/test_azure_cli.py +++ b/app/runtime/tests/test_azure_cli.py @@ -10,7 +10,7 @@ import pytest -from app.runtime.services.azure import AzureCLI +from app.runtime.services.cloud.azure import AzureCLI from app.runtime.util.result import Result @@ -156,7 +156,7 @@ def test_network_error(self, mock_urlopen) -> None: assert result.success is False assert "Cannot reach" in result.message - @patch("app.runtime.services.azure.sleep", return_value=None) + @patch("app.runtime.services.cloud.azure.sleep", return_value=None) @patch("urllib.request.urlopen") def test_404_not_retried(self, mock_urlopen, _mock_sleep) -> None: """A 404 means the bot doesn't exist -- it should NOT be retried.""" @@ -170,7 +170,7 @@ def test_404_not_retried(self, mock_urlopen, _mock_sleep) -> None: # Only one attempt -- no retries on 404. assert mock_urlopen.call_count == 1 - @patch("app.runtime.services.azure.sleep", return_value=None) + @patch("app.runtime.services.cloud.azure.sleep", return_value=None) @patch("urllib.request.urlopen") def test_retries_on_transient_502(self, mock_urlopen, _mock_sleep) -> None: """A transient 502 should be retried and succeed on the next attempt.""" @@ -202,7 +202,7 @@ class TestAzureCLIGetChannels: @patch.object(AzureCLI, "json") def test_no_config(self, mock_json) -> None: az = AzureCLI() - with patch("app.runtime.services.azure.cfg") as mock_cfg: + with patch("app.runtime.config.settings.cfg") as mock_cfg: mock_cfg.env = MagicMock() mock_cfg.env.read.return_value = "" result = az.get_channels() @@ -214,7 +214,7 @@ def test_with_telegram(self, mock_json) -> None: "properties": {"configuredChannels": ["webchat", "telegram"]} } az = AzureCLI() - with patch("app.runtime.services.azure.cfg") as mock_cfg: + with patch("app.runtime.config.settings.cfg") as mock_cfg: mock_cfg.env = MagicMock() mock_cfg.env.read.side_effect = lambda k: "rg" if k == "BOT_RESOURCE_GROUP" else "bot" result = az.get_channels() @@ -225,7 +225,7 @@ class TestAzureCLIUpdateEndpoint: @patch.object(AzureCLI, "json") def test_not_configured(self, mock_json) -> None: az = AzureCLI() - with patch("app.runtime.services.azure.cfg") as mock_cfg: + with patch("app.runtime.config.settings.cfg") as mock_cfg: mock_cfg.env = MagicMock() mock_cfg.env.read.return_value = "" result = az.update_endpoint("https://example.com/api/messages") diff --git a/app/runtime/tests/test_content_safety_routes.py b/app/runtime/tests/test_content_safety_routes.py index 6156854..eea9fff 100644 --- a/app/runtime/tests/test_content_safety_routes.py +++ b/app/runtime/tests/test_content_safety_routes.py @@ -9,8 +9,8 @@ from aiohttp.test_utils import TestClient, TestServer from app.runtime.server.routes.content_safety_routes import ContentSafetyRoutes -from app.runtime.services.prompt_shield import ShieldResult -from app.runtime.state.guardrails_config import GuardrailsConfigStore +from app.runtime.services.security.prompt_shield import ShieldResult +from app.runtime.state.guardrails import GuardrailsConfigStore def _build_app(routes: ContentSafetyRoutes) -> web.Application: diff --git a/app/runtime/tests/test_extract_media.py b/app/runtime/tests/test_extract_media.py index 0fb0fa3..b4b0a4f 100644 --- a/app/runtime/tests/test_extract_media.py +++ b/app/runtime/tests/test_extract_media.py @@ -1,4 +1,4 @@ -"""Tests for incoming media (extract_outgoing_attachments, download_attachment).""" +"""Tests for media extraction (extract_outgoing_attachments, download_attachment).""" from __future__ import annotations @@ -8,7 +8,8 @@ import pytest -from app.runtime.media.incoming import build_media_prompt, extract_outgoing_attachments +from app.runtime.media.incoming import build_media_prompt +from app.runtime.media.outgoing import extract_outgoing_attachments class TestExtractOutgoingAttachments: diff --git a/app/runtime/tests/test_guardrails_policy_validation.py b/app/runtime/tests/test_guardrails_policy_validation.py index b3bac83..c0ed8a3 100644 --- a/app/runtime/tests/test_guardrails_policy_validation.py +++ b/app/runtime/tests/test_guardrails_policy_validation.py @@ -9,7 +9,7 @@ import pytest -from app.runtime.state.guardrails_config import ( +from app.runtime.state.guardrails import ( GuardrailsConfigStore, PRESET_BALANCED, PRESET_PERMISSIVE, @@ -151,16 +151,19 @@ def setup_store(self, tmp_path) -> None: self.s = _store(tmp_path) self.s.apply_preset(PRESET_BALANCED, auto_models=False) - def test_file_ops_filtered_everywhere(self) -> None: - for ctx in ("interactive", "background"): - assert self.s.resolve_action("create", execution_context=ctx) == "filter" - assert self.s.resolve_action("edit", execution_context=ctx) == "filter" + def test_file_ops_filtered_interactive(self) -> None: + assert self.s.resolve_action("create", execution_context="interactive") == "filter" + assert self.s.resolve_action("edit", execution_context="interactive") == "filter" + + def test_file_ops_aitl_background(self) -> None: + assert self.s.resolve_action("create", execution_context="background") == "aitl" + assert self.s.resolve_action("edit", execution_context="background") == "aitl" def test_terminal_hitl_interactive(self) -> None: assert self.s.resolve_action("run", execution_context="interactive") == "hitl" - def test_terminal_denied_background(self) -> None: - assert self.s.resolve_action("run", execution_context="background") == "deny" + def test_terminal_aitl_background(self) -> None: + assert self.s.resolve_action("run", execution_context="background") == "aitl" def test_playwright_hitl_background(self) -> None: assert self.s.resolve_action( @@ -469,7 +472,7 @@ def setup_store(self, tmp_path) -> None: ) def test_voice_call_denied_by_rule(self) -> None: - # make_voice_call is in preset tool_policies (high risk -> hitl interactive / deny bg) + # make_voice_call is in preset tool_policies (high risk -> hitl interactive / aitl bg) # Since it's in tool_policies, the rule doesn't override it for preset contexts. # But the tool_policies entry comes first. # Interactive: make_voice_call = hitl (preset balanced: high risk interactive) @@ -477,10 +480,10 @@ def test_voice_call_denied_by_rule(self) -> None: "make_voice_call", execution_context="interactive", ) == "hitl" - def test_voice_call_denied_background(self) -> None: + def test_voice_call_aitl_background(self) -> None: assert self.s.resolve_action( "make_voice_call", execution_context="background", - ) == "deny" + ) == "aitl" def test_strong_model_create_files_filtered(self) -> None: assert self.s.resolve_action("create", model="gpt-5.3-codex") == "filter" diff --git a/app/runtime/tests/test_guardrails_presets.py b/app/runtime/tests/test_guardrails_presets.py index dc68f36..96e6b6e 100644 --- a/app/runtime/tests/test_guardrails_presets.py +++ b/app/runtime/tests/test_guardrails_presets.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.runtime.state.guardrails_config import ( +from app.runtime.state.guardrails import ( _ALL_PRESET_TOOL_IDS, _MODEL_TIERS, PRESET_BALANCED, @@ -136,10 +136,11 @@ def test_balanced_hitl_terminal_in_interactive(self) -> None: assert p["tool_policies"]["interactive"]["run"] == "hitl" assert p["tool_policies"]["interactive"]["bash"] == "hitl" - def test_balanced_denies_high_risk_in_background(self) -> None: + def test_balanced_aitl_terminal_voice_in_background(self) -> None: p = _build_preset_policies(PRESET_BALANCED) - assert p["tool_policies"]["background"]["run"] == "deny" - assert p["tool_policies"]["background"]["bash"] == "deny" + assert p["tool_policies"]["background"]["run"] == "aitl" + assert p["tool_policies"]["background"]["bash"] == "aitl" + assert p["tool_policies"]["background"]["make_voice_call"] == "aitl" assert p["tool_policies"]["background"]["mcp:github-mcp-server"] == "deny" assert p["tool_policies"]["background"]["mcp:azure-mcp-server"] == "deny" @@ -152,10 +153,10 @@ def test_balanced_filters_file_ops_in_interactive(self) -> None: assert p["tool_policies"]["interactive"]["create"] == "filter" assert p["tool_policies"]["interactive"]["edit"] == "filter" - def test_balanced_filters_file_ops_in_background(self) -> None: + def test_balanced_aitl_file_ops_in_background(self) -> None: p = _build_preset_policies(PRESET_BALANCED) - assert p["tool_policies"]["background"]["create"] == "filter" - assert p["tool_policies"]["background"]["edit"] == "filter" + assert p["tool_policies"]["background"]["create"] == "aitl" + assert p["tool_policies"]["background"]["edit"] == "aitl" # ── Permissive ── def test_permissive_filters_most_in_interactive(self) -> None: @@ -230,10 +231,10 @@ def test_apply_model_defaults_differentiates_tiers(self, tmp_path) -> None: assert strong["background"]["run"] == "hitl" assert strong["interactive"]["mcp:microsoft-learn"] == "filter" - # Standard (balanced): view filtered, run hitl interactive / deny bg + # Standard (balanced): view filtered, run hitl interactive / aitl bg (override) assert standard["interactive"]["view"] == "filter" assert standard["interactive"]["run"] == "hitl" - assert standard["background"]["run"] == "deny" + assert standard["background"]["run"] == "aitl" assert standard["interactive"]["mcp:microsoft-learn"] == "filter" # Cautious (restrictive): run hitl interactive / deny bg, github deny bg @@ -267,10 +268,10 @@ def test_cautious_model_mcp_risk_differentiation(self, tmp_path) -> None: # MS Learn (low risk) -> filter everywhere assert cautious["interactive"]["mcp:microsoft-learn"] == "filter" assert cautious["background"]["mcp:microsoft-learn"] == "filter" - # Playwright (medium risk) -> hitl interactive / deny background + # Playwright (medium risk) -> hitl interactive / deny background (own-tier restrictive) assert cautious["interactive"]["mcp:playwright"] == "hitl" assert cautious["background"]["mcp:playwright"] == "deny" - # GitHub/Azure (high risk) -> hitl interactive / deny background + # GitHub/Azure (high risk) -> hitl interactive / deny background (own-tier restrictive) assert cautious["interactive"]["mcp:github-mcp-server"] == "hitl" assert cautious["background"]["mcp:github-mcp-server"] == "deny" assert cautious["interactive"]["mcp:azure-mcp-server"] == "hitl" @@ -365,7 +366,7 @@ class TestBackgroundAgents: """Background agent metadata and resolve_action fallback.""" def test_list_background_agents(self) -> None: - from app.runtime.state.guardrails_config import list_background_agents + from app.runtime.state.guardrails import list_background_agents agents = list_background_agents() ids = [a["id"] for a in agents] diff --git a/app/runtime/tests/test_hitl.py b/app/runtime/tests/test_hitl.py index 34ce194..8b3762a 100644 --- a/app/runtime/tests/test_hitl.py +++ b/app/runtime/tests/test_hitl.py @@ -8,7 +8,7 @@ import pytest from app.runtime.agent.hitl import HitlInterceptor -from app.runtime.state.guardrails_config import GuardrailsConfigStore +from app.runtime.state.guardrails import GuardrailsConfigStore @pytest.fixture() @@ -22,7 +22,7 @@ def guardrails(tmp_path) -> GuardrailsConfigStore: store._path = tmp_path / "guardrails.json" store._policy_path = tmp_path / "policy.yaml" store._lock = __import__("threading").Lock() - from app.runtime.state.guardrails_config import GuardrailsConfig + from app.runtime.state.guardrails import GuardrailsConfig store._config = GuardrailsConfig(hitl_enabled=True, default_action="ask") store._rebuild_engine() @@ -45,8 +45,10 @@ class TestWebChatApproval: async def test_ask_chat_emits_approval_requested(self, hitl): events: list[tuple[str, dict]] = [] - hitl.set_emit(lambda t, d: events.append((t, d))) - hitl.set_execution_context("interactive") + hitl.bind_turn( + emit=lambda t, d: events.append((t, d)), + execution_context="interactive", + ) async def approve_later(): await asyncio.sleep(0.05) @@ -63,8 +65,7 @@ async def approve_later(): assert "approval_request" in event_types async def test_ask_chat_deny(self, hitl): - hitl.set_emit(lambda t, d: None) - hitl.set_execution_context("interactive") + hitl.bind_turn(emit=lambda t, d: None, execution_context="interactive") async def deny_later(): await asyncio.sleep(0.05) @@ -84,8 +85,7 @@ class TestBotChannelApproval: async def test_ask_bot_sends_confirmation_text(self, hitl): bot_reply = AsyncMock() - hitl.set_bot_reply_fn(bot_reply) - hitl.set_execution_context("background") + hitl.bind_turn(bot_reply_fn=bot_reply, execution_context="background") async def approve_later(): await asyncio.sleep(0.05) @@ -106,8 +106,7 @@ async def approve_later(): async def test_ask_bot_deny_with_no(self, hitl): bot_reply = AsyncMock() - hitl.set_bot_reply_fn(bot_reply) - hitl.set_execution_context("background") + hitl.bind_turn(bot_reply_fn=bot_reply, execution_context="background") async def deny_later(): await asyncio.sleep(0.05) @@ -123,8 +122,7 @@ async def deny_later(): async def test_ask_bot_deny_with_arbitrary_text(self, hitl): bot_reply = AsyncMock() - hitl.set_bot_reply_fn(bot_reply) - hitl.set_execution_context("background") + hitl.bind_turn(bot_reply_fn=bot_reply, execution_context="background") async def reply_later(): await asyncio.sleep(0.05) @@ -140,8 +138,7 @@ async def reply_later(): async def test_ask_bot_approve_yes_case_insensitive(self, hitl): bot_reply = AsyncMock() - hitl.set_bot_reply_fn(bot_reply) - hitl.set_execution_context("background") + hitl.bind_turn(bot_reply_fn=bot_reply, execution_context="background") async def approve_later(): await asyncio.sleep(0.05) @@ -195,7 +192,7 @@ async def test_deny_blocks(self, hitl, guardrails): guardrails._config.default_action = "deny" guardrails._rebuild_engine() events: list[tuple[str, dict]] = [] - hitl.set_emit(lambda t, d: events.append((t, d))) + hitl.bind_turn(emit=lambda t, d: events.append((t, d))) result = await hitl.on_pre_tool_use( {"toolCallId": "c2", "toolName": "run", "input": "rm -rf /"}, @@ -207,16 +204,16 @@ async def test_deny_blocks(self, hitl, guardrails): class TestClearCallbacks: - """Tests for clearing channel callbacks.""" + """Tests for bind_turn / unbind_turn lifecycle.""" def test_clear_emit(self, hitl): - hitl.set_emit(lambda t, d: None) - hitl.clear_emit() + hitl.bind_turn(emit=lambda t, d: None) + hitl.unbind_turn() assert hitl._emit is None def test_clear_bot_reply_fn(self, hitl): - hitl.set_bot_reply_fn(AsyncMock()) - hitl.clear_bot_reply_fn() + hitl.bind_turn(bot_reply_fn=AsyncMock()) + hitl.unbind_turn() assert hitl._bot_reply_fn is None @@ -225,10 +222,8 @@ class TestNoApprovalChannel: async def test_deny_when_no_channel_available(self, hitl): """HITL strategy with no approval channel must deny immediately.""" - # Ensure no callbacks are set - hitl.clear_emit() - hitl.clear_bot_reply_fn() - hitl.set_execution_context("background") + # Bind turn with no emit or bot_reply_fn + hitl.bind_turn(execution_context="background") result = await hitl.on_pre_tool_use( {"toolCallId": "no-ch-1", "toolName": "bash", "input": "date"}, @@ -241,9 +236,7 @@ async def test_deny_when_no_channel_available(self, hitl): async def test_deny_when_no_channel_does_not_block(self, hitl): """Ensure denial returns in <1s, not the 300s timeout.""" - hitl.clear_emit() - hitl.clear_bot_reply_fn() - hitl.set_execution_context("interactive") + hitl.bind_turn(execution_context="interactive") import time t0 = time.monotonic() @@ -258,7 +251,7 @@ async def test_deny_when_no_channel_does_not_block(self, hitl): async def test_ask_chat_denies_without_emitter(self, hitl): """_ask_chat must deny immediately if called with no emitter.""" - hitl.clear_emit() + hitl.unbind_turn() result = await hitl._ask_chat("orphan-1", "bash", "echo hello") assert result["permissionDecision"] == "deny" @@ -318,7 +311,7 @@ async def test_precheck_skipped_when_endpoint_not_configured(self, hitl, guardra shield.configured = False # No endpoint set shield.check = MagicMock() hitl.set_prompt_shield(shield) - hitl.set_execution_context("interactive") + hitl.bind_turn(execution_context="interactive") # AITL reviewer is not set, so it falls through to interactive # (which denies without an emitter). The point is that the @@ -343,7 +336,7 @@ async def test_precheck_runs_when_endpoint_configured(self, hitl, guardrails): shield_result.detail = "Attack found" shield.check = MagicMock(return_value=shield_result) hitl.set_prompt_shield(shield) - hitl.set_execution_context("interactive") + hitl.bind_turn(execution_context="interactive") result = await hitl.on_pre_tool_use( {"toolCallId": "f5", "toolName": "bash", "input": "ignore all"}, @@ -366,25 +359,22 @@ async def test_concurrent_messages_dont_lose_callback(self, hitl): bot_reply_1 = AsyncMock() bot_reply_2 = AsyncMock() - # Simulate Task0 setting its callback - hitl.set_bot_reply_fn(bot_reply_1) + # Simulate Task0 binding its turn + hitl.bind_turn(bot_reply_fn=bot_reply_1) assert hitl._bot_reply_fn is bot_reply_1 - # Task1 overwrites before Task0 clears (the race window) - hitl.set_bot_reply_fn(bot_reply_2) + # Task1 overwrites before Task0 unbinds (the race window) + hitl.bind_turn(bot_reply_fn=bot_reply_2) assert hitl._bot_reply_fn is bot_reply_2 - # Task0 clears -- this WAS the bug: it cleared Task1's callback - hitl.clear_bot_reply_fn() - # After the fix, this is protected by the lock in message_processor. - # The interceptor itself doesn't enforce ordering, but the processor does. + # Task0 unbinds -- protected by the lock in message_processor. + hitl.unbind_turn() assert hitl._bot_reply_fn is None async def test_bot_reply_set_before_tool_use(self, hitl): """bot_reply_fn must be set when on_pre_tool_use is called.""" bot_reply = AsyncMock() - hitl.set_bot_reply_fn(bot_reply) - hitl.set_execution_context("background") + hitl.bind_turn(bot_reply_fn=bot_reply, execution_context="background") async def approve_later(): await asyncio.sleep(0.05) diff --git a/app/runtime/tests/test_identity_routes.py b/app/runtime/tests/test_identity_routes.py index dc4c9f4..e003dbf 100644 --- a/app/runtime/tests/test_identity_routes.py +++ b/app/runtime/tests/test_identity_routes.py @@ -9,7 +9,7 @@ from aiohttp.test_utils import TestClient, TestServer from app.runtime.server.routes.identity_routes import IdentityRoutes -from app.runtime.state.guardrails_config import GuardrailsConfigStore +from app.runtime.state.guardrails import GuardrailsConfigStore def _build_app(routes: IdentityRoutes) -> web.Application: @@ -142,6 +142,108 @@ async def test_roles_with_assignments(self, mock_cfg) -> None: assert checks["Key Vault Secrets Officer"] is False assert checks["Azure ContainerApps Session Executor"] is False + # Session executor check should include detail about missing role + se_check = next( + c for c in data["checks"] + if c["role"] == "Azure ContainerApps Session Executor" + ) + assert se_check["detail"] == "Role not assigned to this identity" + + @pytest.mark.asyncio + @patch("app.runtime.server.routes.identity_routes.cfg") + async def test_roles_session_executor_wrong_scope(self, mock_cfg, tmp_path) -> None: + """Session Executor on wrong scope should report as not present.""" + mock_cfg.runtime_sp_app_id = "app-id" + mock_cfg.aca_mi_client_id = "" + mock_cfg.runtime_sp_tenant = "" + + az = MagicMock() + az.json.side_effect = [ + {"id": "obj-id-resolved"}, # _sp_show + [ + { + "roleDefinitionName": "Azure ContainerApps Session Executor", + "scope": "/subscriptions/sub1/resourceGroups/wrong-rg" + "/providers/Microsoft.App/sessionPools/wrong-pool", + "condition": "", + }, + ], + ] + + from app.runtime.state.sandbox_config import SandboxConfigStore + + sandbox_store = SandboxConfigStore(tmp_path / "sandbox.json") + sandbox_store.set_pool_metadata( + resource_group="polyclaw-sandbox-rg", + location="eastus", + pool_name="my-pool", + pool_id="/subscriptions/sub1/resourceGroups/polyclaw-sandbox-rg" + "/providers/Microsoft.App/sessionPools/my-pool", + endpoint="https://eastus.dynamicsessions.io", + ) + + routes = IdentityRoutes(az=az, sandbox_store=sandbox_store) + app = _build_app(routes) + async with TestClient(TestServer(app)) as client: + resp = await client.get("/api/identity/roles") + assert resp.status == 200 + data = await resp.json() + se_check = next( + c for c in data["checks"] + if c["role"] == "Azure ContainerApps Session Executor" + ) + assert se_check["present"] is False + assert "wrong scope" in se_check["detail"].lower() + assert "expected_scope" in se_check + + @pytest.mark.asyncio + @patch("app.runtime.server.routes.identity_routes.cfg") + async def test_roles_session_executor_correct_scope(self, mock_cfg, tmp_path) -> None: + """Session Executor on correct scope should report as present.""" + mock_cfg.runtime_sp_app_id = "app-id" + mock_cfg.aca_mi_client_id = "" + mock_cfg.runtime_sp_tenant = "" + + pool_scope = ( + "/subscriptions/sub1/resourceGroups/polyclaw-sandbox-rg" + "/providers/Microsoft.App/sessionPools/my-pool" + ) + + az = MagicMock() + az.json.side_effect = [ + {"id": "obj-id-resolved"}, # _sp_show + [ + { + "roleDefinitionName": "Azure ContainerApps Session Executor", + "scope": pool_scope, + "condition": "", + }, + ], + ] + + from app.runtime.state.sandbox_config import SandboxConfigStore + + sandbox_store = SandboxConfigStore(tmp_path / "sandbox.json") + sandbox_store.set_pool_metadata( + resource_group="polyclaw-sandbox-rg", + location="eastus", + pool_name="my-pool", + pool_id=pool_scope, + endpoint="https://eastus.dynamicsessions.io", + ) + + routes = IdentityRoutes(az=az, sandbox_store=sandbox_store) + app = _build_app(routes) + async with TestClient(TestServer(app)) as client: + resp = await client.get("/api/identity/roles") + assert resp.status == 200 + data = await resp.json() + se_check = next( + c for c in data["checks"] + if c["role"] == "Azure ContainerApps Session Executor" + ) + assert se_check["present"] is True + @pytest.mark.asyncio @patch("app.runtime.server.routes.identity_routes.cfg") async def test_roles_sp_show_fails_uses_app_id(self, mock_cfg) -> None: diff --git a/app/runtime/tests/test_incoming_media.py b/app/runtime/tests/test_incoming_media.py index ce70a8f..d8b888b 100644 --- a/app/runtime/tests/test_incoming_media.py +++ b/app/runtime/tests/test_incoming_media.py @@ -5,10 +5,8 @@ import re from pathlib import Path -from app.runtime.media.incoming import ( - _FILE_PATH_RE, - build_media_prompt, -) +from app.runtime.media.incoming import build_media_prompt +from app.runtime.media.outgoing import _FILE_PATH_RE class TestBuildMediaPrompt: diff --git a/app/runtime/tests/test_misconfig_checker.py b/app/runtime/tests/test_misconfig_checker.py index 49633a3..f9a9120 100644 --- a/app/runtime/tests/test_misconfig_checker.py +++ b/app/runtime/tests/test_misconfig_checker.py @@ -2,7 +2,7 @@ from __future__ import annotations -from app.runtime.services.misconfig_checker import CheckResult, Finding, MisconfigChecker +from app.runtime.services.security.misconfig_checker import CheckResult, Finding, MisconfigChecker class _FakeAzureCLI: diff --git a/app/runtime/tests/test_prerequisites.py b/app/runtime/tests/test_prerequisites.py index ed831c4..a80f3c7 100644 --- a/app/runtime/tests/test_prerequisites.py +++ b/app/runtime/tests/test_prerequisites.py @@ -6,7 +6,7 @@ from unittest.mock import MagicMock from app.runtime.config.settings import cfg -from app.runtime.server.setup_prerequisites import PrerequisitesRoutes +from app.runtime.server.setup.prerequisites import PrerequisitesRoutes from app.runtime.state.deploy_state import DeployStateStore, DeploymentRecord from app.runtime.state.infra_config import InfraConfigStore diff --git a/app/runtime/tests/test_prompt_shield.py b/app/runtime/tests/test_prompt_shield.py index b5f3eeb..7310ffc 100644 --- a/app/runtime/tests/test_prompt_shield.py +++ b/app/runtime/tests/test_prompt_shield.py @@ -8,7 +8,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch -from app.runtime.services.prompt_shield import ( +from app.runtime.services.security.prompt_shield import ( PromptShieldService, _BearerTokenProvider, ) diff --git a/app/runtime/tests/test_provisioner.py b/app/runtime/tests/test_provisioner.py index 59a3b11..68bbe5e 100644 --- a/app/runtime/tests/test_provisioner.py +++ b/app/runtime/tests/test_provisioner.py @@ -7,7 +7,7 @@ import pytest from app.runtime.config.settings import cfg -from app.runtime.services.provisioner import Provisioner +from app.runtime.services.deployment.provisioner import Provisioner from app.runtime.state.deploy_state import DeployStateStore from app.runtime.state.infra_config import InfraConfigStore from app.runtime.util.result import Result @@ -86,11 +86,7 @@ class TestProvision: """Full provision flow -- registers Entra app + runtime identity, no bot service.""" def test_skips_when_not_configured(self, provisioner, store): - store.bot = MagicMock() - store.bot.resource_group = "" - store.bot.location = "" - - # bot_configured returns False when rg/location are empty + # bot_configured returns False when rg/location are empty (default) with patch.object(type(store), "bot_configured", new_callable=lambda: property(lambda self: False)): steps = provisioner.provision() assert any(s["step"] == "bot_config" and s["status"] == "skip" for s in steps) diff --git a/app/runtime/tests/test_realtime_tools.py b/app/runtime/tests/test_realtime_tools.py index a18db4d..2343c17 100644 --- a/app/runtime/tests/test_realtime_tools.py +++ b/app/runtime/tests/test_realtime_tools.py @@ -128,7 +128,7 @@ async def test_not_found(self) -> None: class TestMakeRealtimeHook: """Verify that _make_realtime_hook creates a properly configured interceptor.""" - @patch("app.runtime.state.guardrails_config.get_guardrails_config") + @patch("app.runtime.state.guardrails.config.get_guardrails_config") def test_hook_sets_execution_context(self, mock_get_cfg: MagicMock) -> None: mock_store = MagicMock() mock_store.hitl_enabled = True @@ -140,7 +140,7 @@ def test_hook_sets_execution_context(self, mock_get_cfg: MagicMock) -> None: hook = _make_realtime_hook(agent) assert callable(hook) - @patch("app.runtime.state.guardrails_config.get_guardrails_config") + @patch("app.runtime.state.guardrails.config.get_guardrails_config") def test_hook_forwards_aitl_from_shared(self, mock_get_cfg: MagicMock) -> None: mock_store = MagicMock() mock_store.hitl_enabled = True @@ -157,7 +157,7 @@ def test_hook_forwards_aitl_from_shared(self, mock_get_cfg: MagicMock) -> None: hook = _make_realtime_hook(agent) assert callable(hook) - @patch("app.runtime.state.guardrails_config.get_guardrails_config") + @patch("app.runtime.state.guardrails.config.get_guardrails_config") def test_hook_works_without_shared_hitl(self, mock_get_cfg: MagicMock) -> None: mock_store = MagicMock() mock_store.hitl_enabled = True diff --git a/app/runtime/tests/test_sandbox_executor.py b/app/runtime/tests/test_sandbox_executor.py index f3ca1c2..c66f401 100644 --- a/app/runtime/tests/test_sandbox_executor.py +++ b/app/runtime/tests/test_sandbox_executor.py @@ -156,7 +156,7 @@ def test_create_data_zip_empty(self, tmp_path: Path) -> None: store = MagicMock() store.whitelist = ["nonexistent"] executor = SandboxExecutor(config_store=store) - with patch("app.runtime.sandbox.cfg") as mock_cfg: + with patch("app.runtime.sandbox.executor.cfg") as mock_cfg: mock_cfg.data_dir = tmp_path result = executor._create_data_zip() assert result is None @@ -166,7 +166,7 @@ def test_create_data_zip_with_file(self, tmp_path: Path) -> None: store = MagicMock() store.whitelist = ["test.json"] executor = SandboxExecutor(config_store=store) - with patch("app.runtime.sandbox.cfg") as mock_cfg: + with patch("app.runtime.sandbox.executor.cfg") as mock_cfg: mock_cfg.data_dir = tmp_path result = executor._create_data_zip() assert result is not None @@ -180,7 +180,7 @@ def test_create_data_zip_with_dir(self, tmp_path: Path) -> None: store = MagicMock() store.whitelist = ["subdir"] executor = SandboxExecutor(config_store=store) - with patch("app.runtime.sandbox.cfg") as mock_cfg: + with patch("app.runtime.sandbox.executor.cfg") as mock_cfg: mock_cfg.data_dir = tmp_path result = executor._create_data_zip() assert result is not None @@ -195,7 +195,7 @@ def test_merge_result_zip(self, tmp_path: Path) -> None: store = MagicMock() store.whitelist = ["allowed"] executor = SandboxExecutor(config_store=store) - with patch("app.runtime.sandbox.cfg") as mock_cfg: + with patch("app.runtime.sandbox.executor.cfg") as mock_cfg: mock_cfg.data_dir = tmp_path count = executor._merge_result_zip(buf.getvalue()) assert count == 1 @@ -210,7 +210,7 @@ def test_merge_result_zip_blocks_path_traversal(self, tmp_path: Path) -> None: store = MagicMock() store.whitelist = [] executor = SandboxExecutor(config_store=store) - with patch("app.runtime.sandbox.cfg") as mock_cfg: + with patch("app.runtime.sandbox.executor.cfg") as mock_cfg: mock_cfg.data_dir = tmp_path count = executor._merge_result_zip(buf.getvalue()) assert count == 0 @@ -344,7 +344,7 @@ class TestUploadBytesRetry: """Tests for _upload_bytes retry logic with exponential backoff.""" @pytest.mark.asyncio - @patch("app.runtime.sandbox._UPLOAD_BACKOFF_BASE", 0.0) + @patch("app.runtime.sandbox.executor._UPLOAD_BACKOFF_BASE", 0.0) async def test_upload_succeeds_on_first_attempt(self) -> None: store = MagicMock() executor = SandboxExecutor(config_store=store) @@ -359,11 +359,11 @@ async def test_upload_succeeds_on_first_attempt(self) -> None: result = await executor._upload_bytes( http, "https://endpoint", "sess-1", "file.zip", b"data", {}, ) - assert result is True + assert result == "" assert http.post.call_count == 1 @pytest.mark.asyncio - @patch("app.runtime.sandbox._UPLOAD_BACKOFF_BASE", 0.0) + @patch("app.runtime.sandbox.executor._UPLOAD_BACKOFF_BASE", 0.0) async def test_upload_retries_on_http_error_then_succeeds(self) -> None: store = MagicMock() executor = SandboxExecutor(config_store=store) @@ -385,11 +385,11 @@ async def test_upload_retries_on_http_error_then_succeeds(self) -> None: result = await executor._upload_bytes( http, "https://endpoint", "sess-1", "file.zip", b"data", {}, ) - assert result is True + assert result == "" assert http.post.call_count == 2 @pytest.mark.asyncio - @patch("app.runtime.sandbox._UPLOAD_BACKOFF_BASE", 0.0) + @patch("app.runtime.sandbox.executor._UPLOAD_BACKOFF_BASE", 0.0) async def test_upload_retries_on_exception_then_succeeds(self) -> None: store = MagicMock() executor = SandboxExecutor(config_store=store) @@ -407,11 +407,11 @@ async def test_upload_retries_on_exception_then_succeeds(self) -> None: result = await executor._upload_bytes( http, "https://endpoint", "sess-1", "data.zip", b"data", {}, ) - assert result is True + assert result == "" assert http.post.call_count == 2 @pytest.mark.asyncio - @patch("app.runtime.sandbox._UPLOAD_BACKOFF_BASE", 0.0) + @patch("app.runtime.sandbox.executor._UPLOAD_BACKOFF_BASE", 0.0) async def test_upload_fails_after_all_retries(self) -> None: store = MagicMock() executor = SandboxExecutor(config_store=store) @@ -428,5 +428,6 @@ async def test_upload_fails_after_all_retries(self) -> None: result = await executor._upload_bytes( http, "https://endpoint", "sess-1", "file.zip", b"data", {}, ) - assert result is False + assert result != "" + assert "503" in result assert http.post.call_count == 3 diff --git a/app/runtime/tests/test_spotlight.py b/app/runtime/tests/test_spotlight.py index 4584104..d519f71 100644 --- a/app/runtime/tests/test_spotlight.py +++ b/app/runtime/tests/test_spotlight.py @@ -6,7 +6,7 @@ from unittest.mock import patch from app.runtime.util.spotlight import datamark, delimit, spotlight -from app.runtime.state.guardrails_config import GuardrailsConfigStore +from app.runtime.state.guardrails import GuardrailsConfigStore import pytest @@ -95,13 +95,13 @@ class TestGuardrailsSpotlightingConfig: """Guardrails store persists and exposes the aitl_spotlighting toggle.""" def test_default_enabled(self, tmp_path: Path) -> None: - with patch("app.runtime.state.guardrails_config.cfg") as mock_cfg: + with patch("app.runtime.state.guardrails.config.cfg") as mock_cfg: mock_cfg.data_dir = tmp_path store = GuardrailsConfigStore(tmp_path / "guardrails.json") assert store.config.aitl_spotlighting is True def test_toggle_off_and_persist(self, tmp_path: Path) -> None: - with patch("app.runtime.state.guardrails_config.cfg") as mock_cfg: + with patch("app.runtime.state.guardrails.config.cfg") as mock_cfg: mock_cfg.data_dir = tmp_path store = GuardrailsConfigStore(tmp_path / "guardrails.json") store.set_aitl_spotlighting(False) @@ -112,7 +112,7 @@ def test_toggle_off_and_persist(self, tmp_path: Path) -> None: assert store2.config.aitl_spotlighting is False def test_toggle_on_and_persist(self, tmp_path: Path) -> None: - with patch("app.runtime.state.guardrails_config.cfg") as mock_cfg: + with patch("app.runtime.state.guardrails.config.cfg") as mock_cfg: mock_cfg.data_dir = tmp_path store = GuardrailsConfigStore(tmp_path / "guardrails.json") store.set_aitl_spotlighting(False) @@ -120,7 +120,7 @@ def test_toggle_on_and_persist(self, tmp_path: Path) -> None: assert store.config.aitl_spotlighting is True def test_to_dict_includes_spotlighting(self, tmp_path: Path) -> None: - with patch("app.runtime.state.guardrails_config.cfg") as mock_cfg: + with patch("app.runtime.state.guardrails.config.cfg") as mock_cfg: mock_cfg.data_dir = tmp_path store = GuardrailsConfigStore(tmp_path / "guardrails.json") d = store.to_dict() diff --git a/app/runtime/tests/test_tunnel_restriction.py b/app/runtime/tests/test_tunnel_restriction.py index 25486c8..b12ba12 100644 --- a/app/runtime/tests/test_tunnel_restriction.py +++ b/app/runtime/tests/test_tunnel_restriction.py @@ -18,7 +18,7 @@ from aiohttp.test_utils import TestClient, TestServer from app.runtime.config import cfg -from app.runtime.server.app import ( +from app.runtime.server.middleware import ( _CF_HEADERS, _PUBLIC_EXACT, _PUBLIC_PREFIXES, diff --git a/docs/content/architecture/state.md b/docs/content/architecture/state.md index f3d91fc..8fb528b 100644 --- a/docs/content/architecture/state.md +++ b/docs/content/architecture/state.md @@ -20,7 +20,7 @@ Each chat session is persisted as a separate JSON file containing message histor | One file per session | Easy inspection and backup | | Archival policies | `24h`, `7d`, `30d`, `never` | | Session resume | Last 20 messages loaded as context | -| Metadata | Model used, message count, timestamps | +| Metadata | Model, title, message count, created/updated timestamps | ### Memory Store @@ -35,7 +35,7 @@ Memory formation is triggered after `MEMORY_IDLE_MINUTES` (default: 5) of inacti ### Profile Store -**File**: `profile.json` +**File**: `agent_profile.json` Tracks the agent's identity and behavioral state: @@ -46,9 +46,15 @@ Tracks the agent's identity and behavioral state: | `location` | Timezone context | | `emotional_state` | Current mood (affects responses) | | `preferences` | Communication style preferences | -| `skill_usage` | Usage counts per skill | -| `interaction_log` | Recent interaction timestamps | -| `contribution_heatmap` | Activity by hour/day | + +Related data is stored in separate files: + +| File | Description | +|---|---| +| `skill_usage.json` | Usage counts per skill | +| `interactions.json` | Recent interaction log (last 1 000 entries) | + +The `get_full_profile()` helper merges the profile with skill usage, per-day contribution counts, and activity statistics. ### MCP Config @@ -73,9 +79,28 @@ Manages autonomous proactive messaging: |---|---| | `enabled` | Whether proactive messaging is active | | `pending` | Single pending message awaiting delivery | -| `sent` | Last 100 sent messages | +| `history` | Last 100 delivered messages with reactions | | `preferences` | Timing, frequency, and topic constraints | -| `daily_count` | Messages sent today | + +The `messages_sent_today()` and `hours_since_last_sent()` methods compute daily counts and gap tracking from the history rather than persisting them as separate fields. + +### Guardrails Config + +**Files**: `guardrails.json`, `policy.yaml` + +Stores human-in-the-loop (HITL) approval rules, tool-level and context-level policies, model-specific overrides, and Content Safety settings. A YAML policy file is generated alongside the JSON and consumed by the `PolicyEngine` at runtime. + +### Monitoring Config + +**File**: `monitoring.json` + +Stores OpenTelemetry and Application Insights configuration including connection strings, sampling ratio, live metrics toggle, and provisioning metadata. + +### Tool Activity Store + +**File**: `tool_activity.jsonl` + +Append-only JSON-lines log of every tool invocation. Each entry records tool name, arguments, result, duration, risk score, and Content Safety shield results. Supports query, timeline, CSV export, and session-level breakdowns for audit. ### Other State Files @@ -83,16 +108,16 @@ Manages autonomous proactive messaging: |---|---| | `SOUL.md` | Agent personality definition | | `scheduler.json` | Scheduled task definitions | -| `deploy_state.json` | Deployment records | -| `infra_config.json` | Infrastructure configuration | -| `plugin_config.json` | Plugin enabled/disabled state | -| `sandbox_config.json` | Sandbox configuration | -| `foundry_iq_config.json` | Azure AI Foundry IQ settings | -| `conversation_references.json` | Bot Framework conversation references | +| `deployments.json` | Deployment records | +| `infra.json` | Infrastructure configuration (bot, channels, voice) | +| `plugins.json` | Plugin enabled/disabled state | +| `sandbox.json` | Sandbox configuration and session pool metadata | +| `foundry_iq.json` | Azure AI Foundry IQ / Search settings | +| `conversation_refs.json` | Bot Framework conversation references | ## Design Principles - **No database required** -- everything is flat files for simplicity and portability -- **Human-readable** -- JSON and Markdown files can be inspected and edited manually +- **Human-readable** -- JSON, JSONL, and Markdown files can be inspected and edited manually - **Docker-friendly** -- mount `~/.polyclaw` as a volume for persistence -- **Atomic writes** -- state modules use write-then-rename for crash safety +- **Thread-safe I/O** -- shared stores use `threading.Lock` for concurrent access diff --git a/docs/content/configuration/_index.md b/docs/content/configuration/_index.md index 938b4fc..8c0dc62 100644 --- a/docs/content/configuration/_index.md +++ b/docs/content/configuration/_index.md @@ -10,12 +10,13 @@ Polyclaw is configured through environment variables loaded from a `.env` file o | Variable | Default | Description | |---|---|---| | `GITHUB_TOKEN` | -- | GitHub PAT with Copilot access. Supports `@kv:` prefix. | -| `COPILOT_MODEL` | `claude-sonnet-4-20250514` | Default LLM model for conversations | +| `COPILOT_MODEL` | `claude-sonnet-4.6` | Default LLM model for conversations | | `COPILOT_AGENT` | -- | Optional Copilot agent name | -| `ADMIN_PORT` | `8000` | Admin server listen port | +| `ADMIN_PORT` | `9090` | Admin server listen port | | `ADMIN_SECRET` | -- | Bearer token for API authentication. Supports `@kv:` prefix. | | `POLYCLAW_DATA_DIR` | `~/.polyclaw` | Root directory for all persistent data | | `DOTENV_PATH` | -- | Custom path to `.env` file | +| `POLYCLAW_SERVER_MODE` | `combined` | Server mode: `combined`, `admin`, or `runtime` | ## Bot Framework @@ -42,7 +43,7 @@ Polyclaw is configured through environment variables loaded from a `.env` file o | Variable | Default | Description | |---|---|---| -| `MEMORY_MODEL` | `claude-sonnet-4-20250514` | Model used for memory consolidation | +| `MEMORY_MODEL` | `claude-sonnet-4.6` | Model used for memory consolidation | | `MEMORY_IDLE_MINUTES` | `5` | Minutes of inactivity before memory formation triggers | ## Proactive Messaging diff --git a/docs/content/features/monitoring.md b/docs/content/features/monitoring.md index 0d9f4bc..4be0079 100644 --- a/docs/content/features/monitoring.md +++ b/docs/content/features/monitoring.md @@ -84,11 +84,24 @@ When enabled, the Live Metrics stream provides a real-time view of request rate, --- +## Agent Framework Dashboard in Application Insights + +Polyclaw emits telemetry in the format expected by the Microsoft Agent Framework. This means the built-in **Agent Framework dashboard** available in Application Insights works out of the box -- no extra configuration or custom queries required. + +The dashboard surfaces agent-level metrics such as session counts, tool invocation success rates, model call latency, and failure breakdowns, all derived from the spans and semantic attributes that Polyclaw already produces. + +![Agent Framework dashboard in Application Insights](/screenshots/mafpreview-appinsights.png) + +Because the telemetry follows the Agent Framework conventions (`gen_ai.system`, `gen_ai.request.model`, operation names, dependency types), the dashboard can correlate traces across the full request lifecycle -- from incoming user messages through model calls and tool executions to the final response. + +--- + ## Dashboard Links After provisioning, the monitoring configuration provides direct links to: - **Azure Portal** -- Application Insights overview, transaction search, failures, and performance views +- **Agent Framework Dashboard** -- the built-in agent monitoring view in Application Insights, powered by the telemetry format described above - **Grafana Agent Dashboard** -- pre-built dashboard URL for Azure Managed Grafana (if configured) --- diff --git a/docs/themes/polyclaw/static/screenshots/mafpreview-appinsights.png b/docs/themes/polyclaw/static/screenshots/mafpreview-appinsights.png new file mode 100644 index 0000000..5aa43d5 Binary files /dev/null and b/docs/themes/polyclaw/static/screenshots/mafpreview-appinsights.png differ