diff --git a/evolve_server/engines/workflow.py b/evolve_server/engines/workflow.py index 601728d..86a6026 100644 --- a/evolve_server/engines/workflow.py +++ b/evolve_server/engines/workflow.py @@ -984,8 +984,34 @@ async def run_once(self) -> dict: queued_candidates, elapsed, ) + if uploaded_skills > 0: + await self._notify_proxy_reload() return summary + async def _notify_proxy_reload(self) -> None: + mode = str(getattr(self.config, "skill_reload_mode", "") or "poll").strip().lower() + url = str(getattr(self.config, "proxy_reload_url", "") or "").strip().rstrip("/") + if mode != "callback" or not url: + return + headers: dict[str, str] = {} + api_key = str(getattr(self.config, "proxy_reload_api_key", "") or "") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + import httpx + + for attempt in range(3): + try: + async with httpx.AsyncClient(timeout=5.0) as client: + resp = await client.post(f"{url}/internal/reload-skills", headers=headers) + resp.raise_for_status() + logger.info("[EvolveServer] notified proxy to reload skills: %s", url) + return + except Exception as exc: + if attempt < 2: + await asyncio.sleep(1.0 * (attempt + 1)) + else: + logger.warning("[EvolveServer] proxy reload notify failed after 3 attempts: %s", exc) + async def run_periodic(self) -> None: self._running = True logger.info("[EvolveServer] periodic mode: interval=%ds", self.config.interval_seconds) diff --git a/skillclaw/api_server.py b/skillclaw/api_server.py index 283724e..2057167 100644 --- a/skillclaw/api_server.py +++ b/skillclaw/api_server.py @@ -42,7 +42,14 @@ _CYAN = "\033[36m" _RESET = "\033[0m" -_NON_STANDARD_BODY_KEYS = {"session_id", "session_done", "turn_type"} +_NON_STANDARD_BODY_KEYS = { + "session_id", + "session_done", + "turn_type", + "_skillclaw_protocol", +} +_PROTOCOL_ANTHROPIC_MESSAGES = "anthropic_messages" +_PROTOCOL_RESPONSES_COMPAT = "responses_compat" # ------------------------------------------------------------------ # @@ -205,6 +212,35 @@ def _resolve_session_done( return str(candidate).strip().lower() in _TRUE_STRINGS +def _looks_like_session_title_response(content: str) -> bool: + """Return True for Claude Code's internal generate_session_title response.""" + text = str(content or "").strip() + if not text or len(text) > 500: + return False + try: + parsed = json.loads(text) + except Exception: + return False + if not isinstance(parsed, dict) or set(parsed.keys()) - {"title"}: + return False + title = parsed.get("title") + return isinstance(title, str) and bool(title.strip()) + + +def _classify_raw_turn_kind(protocol: str, content: str, tool_calls: list[dict]) -> str: + """Classify recorded raw/main turns for user-turn cadence decisions.""" + if tool_calls: + return "tool_use" + if protocol == _PROTOCOL_ANTHROPIC_MESSAGES and _looks_like_session_title_response(content): + return "session_title" + return "final" + + +def _is_user_turn_boundary(raw_turn_kind: str) -> bool: + """Only final assistant responses advance the user-visible turn counter.""" + return raw_turn_kind == "final" + + def _normalize_tool_name(raw_name: str, args_raw: str) -> str: """ Normalize tool names from model output. @@ -1418,6 +1454,7 @@ def __init__( # State machines self._turn_counts: dict[str, int] = {} + self._user_turn_counts: dict[str, int] = {} self._pending_turn_data: dict[str, dict[int, dict]] = {} # session → {turn → data} self._prm_tasks: dict[str, dict[int, asyncio.Task]] = {} # session → {turn → task} self._pending_records: dict[str, dict] = {} # for record logging @@ -1428,6 +1465,7 @@ def __init__( self._background_tasks: set[asyncio.Task] = set() # transient async tasks (upload, submit) self._responses_store: dict[str, dict[str, Any]] = {} # response_id -> stored response/history self._session_sweeper_task: Optional[asyncio.Task] = None + self._skill_reload_task: Optional[asyncio.Task] = None self._session_idle_close_seconds = max( 0, int(getattr(config, "session_idle_close_seconds", _SESSION_IDLE_CLOSE_SECONDS)), @@ -1440,6 +1478,10 @@ def __init__( 1, int(getattr(config, "shutdown_drain_timeout_seconds", _SHUTDOWN_DRAIN_TIMEOUT_SECONDS)), ) + self._skill_reload_interval_seconds = max( + 5, + int(getattr(config, "sharing_skill_reload_interval_seconds", 30) or 30), + ) # Session boundary detection for non-OpenClaw agents (QwenPaw, IronClaw, etc.) # Maps pseudo-session key (e.g. "tui-model") to tracking metadata. @@ -1478,6 +1520,7 @@ def _build_app(self) -> FastAPI: async def lifespan(_app: FastAPI): owner._ready_event.set() owner._start_session_idle_sweeper() + owner._start_skill_reload_polling() try: yield finally: @@ -1491,6 +1534,17 @@ async def lifespan(_app: FastAPI): async def healthz(): return {"ok": True} + @app.post("/internal/reload-skills") + async def reload_skills( + request: Request, + authorization: Optional[str] = Header(default=None), + ): + owner: SkillClawAPIServer = request.app.state.owner + await owner._check_auth(authorization) + await owner._pull_skills_from_cloud() + skill_count = len(owner.skill_manager.get_all_skills()) if owner.skill_manager else 0 + return {"ok": True, "skills": skill_count} + @app.get("/v1/models") async def list_models( request: Request, @@ -1576,19 +1630,42 @@ async def responses( body = await request.json() if owner._responses_native_enabled(): + record_body = copy.deepcopy(body) turn_type = _resolve_turn_type(x_turn_type, body.get("turn_type"), default="main") - body = owner._prepare_native_responses_body(body, turn_type=turn_type) + injected_skills = owner._prepare_native_responses_body_inplace(body, turn_type=turn_type) + _raw_sid = x_session_id or codex_session_id or body.get("session_id") or "" + session_id = _raw_sid or await owner._resolve_tui_session( + body.get("model", owner._served_model), + len(body.get("input", []) if isinstance(body.get("input"), list) else []), + ) + session_done = _resolve_session_done(x_session_done, body.get("session_done")) if bool(body.get("stream", False)): return StreamingResponse( - owner._stream_llm_responses(body), + owner._stream_and_track_responses( + body, + record_body=record_body, + session_id=session_id, + turn_type=turn_type, + injected_skills=injected_skills, + session_done=session_done, + ), media_type="text/event-stream", ) response_payload = await owner._forward_to_llm_responses(body) + owner._record_responses_turn( + session_id, + record_body, + response_payload, + turn_type=turn_type, + injected_skills=injected_skills, + session_done=session_done, + ) return JSONResponse(content=response_payload) previous_response_id = str(body.get("previous_response_id") or "").strip() store_response = bool(body.get("store", True)) openai_body = _responses_to_openai_body(body, owner._served_model) + openai_body["_skillclaw_protocol"] = _PROTOCOL_RESPONSES_COMPAT if previous_response_id: stored = owner._responses_store.get(previous_response_id) if stored is None: @@ -1710,6 +1787,7 @@ async def anthropic_messages( stream = bool(raw_body.get("stream", False)) tool_names = _anthropic_request_tool_names(raw_body) openai_body = _anthropic_to_openai_body(raw_body) + openai_body["_skillclaw_protocol"] = _PROTOCOL_ANTHROPIC_MESSAGES model = raw_body.get("model") or owner._served_model incoming_messages = openai_body.get("messages", []) @@ -1909,6 +1987,42 @@ async def _await_background_tasks(self, timeout_seconds: float) -> None: else: logger.info("[OpenClaw] background drain complete (%d task(s))", len(done)) + def _start_skill_reload_polling(self) -> None: + if not self.config.sharing_enabled: + return + mode = str(getattr(self.config, "sharing_skill_reload_mode", "") or "poll").strip().lower() + if mode != "poll": + return + if self._skill_reload_task is not None and not self._skill_reload_task.done(): + return + self._skill_reload_task = asyncio.create_task(self._skill_reload_poll_loop()) + self._skill_reload_task.add_done_callback(self._task_done_cb) + logger.info( + "[SkillHub] skill reload polling enabled interval=%ds", + self._skill_reload_interval_seconds, + ) + + async def _skill_reload_poll_loop(self) -> None: + consecutive_failures = 0 + try: + while True: + jitter = random.uniform(0, self._skill_reload_interval_seconds * 0.1) + backoff = min(consecutive_failures * 5.0, 60.0) + await asyncio.sleep(self._skill_reload_interval_seconds + jitter + backoff) + try: + await self._pull_skills_from_cloud() + consecutive_failures = 0 + except Exception as exc: + consecutive_failures += 1 + logger.warning( + "[SkillHub] skill reload poll failed (streak=%d): %s", + consecutive_failures, + exc, + ) + except asyncio.CancelledError: + logger.info("[SkillHub] skill reload polling stopped") + raise + async def _drain_active_sessions(self, reason: str) -> None: active_ids = self._collect_active_session_ids() if not active_ids: @@ -1918,6 +2032,10 @@ async def _drain_active_sessions(self, reason: str) -> None: await self._close_session(sid, reason=reason) async def _shutdown_cleanup(self) -> None: + if self._skill_reload_task is not None: + self._skill_reload_task.cancel() + await asyncio.gather(self._skill_reload_task, return_exceptions=True) + self._skill_reload_task = None if self._session_sweeper_task is not None: self._session_sweeper_task.cancel() await asyncio.gather(self._session_sweeper_task, return_exceptions=True) @@ -1980,6 +2098,7 @@ async def _close_session(self, session_id: str, reason: str = "explicit") -> Non ) eff = self._session_scored_turns.pop(session_id, 0) self._turn_counts.pop(session_id, None) + self._user_turn_counts.pop(session_id, None) self._pending_turn_data.pop(session_id, None) prm_tasks = self._prm_tasks.pop(session_id, {}) for task in prm_tasks.values(): @@ -2192,6 +2311,7 @@ async def _handle_request( turn_type: str, session_done: bool, ) -> dict[str, Any]: + protocol = str(body.get("_skillclaw_protocol") or "").strip() messages = body.get("messages") if not isinstance(messages, list) or not messages: raise HTTPException(status_code=400, detail="messages must be a non-empty list") @@ -2371,32 +2491,38 @@ def _prompt_len(msgs): ) response_text = content or (json.dumps(tool_calls, ensure_ascii=False) if tool_calls else "") self._buffer_record(session_id, turn_num, messages, prompt_text, response_text, tool_calls) - self._session_turns.setdefault(session_id, []).append( - { - "turn_num": turn_num, - "prompt_text": user_instruction, - "response_text": response_text, - "reasoning_content": reasoning or None, - "tool_calls": tool_calls, - "read_skills": read_skills, - "modified_skills": modified_skills, - "tool_results": tool_summaries, - "tool_results_raw": [], - "tool_observations": [], - "tool_errors": [], - "injected_skills": injected_skills, - "prm_score": None, - } - ) - self._maybe_upload_session_snapshot(session_id, turn_num) + raw_turn_kind = _classify_raw_turn_kind(protocol, content, tool_calls) + turn_record = { + "turn_num": turn_num, + "raw_turn_kind": raw_turn_kind, + "prompt_text": user_instruction, + "response_text": response_text, + "reasoning_content": reasoning or None, + "tool_calls": tool_calls, + "read_skills": read_skills, + "modified_skills": modified_skills, + "tool_results": tool_summaries, + "tool_results_raw": [], + "tool_observations": [], + "tool_errors": [], + "injected_skills": injected_skills, + "prm_score": None, + } + self._session_turns.setdefault(session_id, []).append(turn_record) + if _is_user_turn_boundary(raw_turn_kind): + user_turn_num = self._next_user_turn_num(session_id) + turn_record["user_turn_num"] = user_turn_num + self._maybe_upload_session_snapshot(session_id, user_turn_num) self._pending_turn_data.setdefault(session_id, {})[turn_num] = { "prompt_text": prompt_text, "response_text": response_text, } logger.info( - "[OpenClaw] MAIN session=%s turn=%d prompt_est_tokens=%d response_chars=%d", + "[OpenClaw] MAIN session=%s turn=%d user_turn=%s kind=%s prompt_est_tokens=%d response_chars=%d", session_id, turn_num, + turn_record.get("user_turn_num", "-"), + raw_turn_kind, _estimate_openai_body_input_tokens({"messages": messages, "tools": tools}), len(response_text), ) @@ -2460,8 +2586,13 @@ def _prepare_responses_forward( def _prepare_native_responses_body(self, body: dict[str, Any], *, turn_type: str) -> dict[str, Any]: """Apply non-destructive SkillClaw hooks before native Responses forwarding.""" prepared = dict(body) + self._prepare_native_responses_body_inplace(prepared, turn_type=turn_type) + return prepared + + def _prepare_native_responses_body_inplace(self, body: dict[str, Any], *, turn_type: str) -> list[str]: + """Inject skills into a Responses body in-place. Returns injected skill names.""" if not self.skill_manager or turn_type != "main": - return prepared + return [] try: self.skill_manager.refresh_if_changed() @@ -2472,7 +2603,7 @@ def _prepare_native_responses_body(self, body: dict[str, Any], *, turn_type: str max_chars=getattr(self.config, "max_skills_prompt_chars", 30_000), ) if not skill_text: - return prepared + return [] all_skills = self.skill_manager.get_all_skills() skill_names = [s.get("name", "unknown_skill") for s in all_skills if isinstance(s, dict)] @@ -2483,9 +2614,76 @@ def _prepare_native_responses_body(self, body: dict[str, Any], *, turn_type: str ) self.skill_manager.record_injection(skill_names) - existing = _normalize_responses_content(prepared.get("instructions", "")) - prepared["instructions"] = (existing + "\n\n" + skill_text).strip() if existing else skill_text - return prepared + existing = _normalize_responses_content(body.get("instructions", "")) + body["instructions"] = (existing + "\n\n" + skill_text).strip() if existing else skill_text + return skill_names + + def _record_responses_turn( + self, + session_id: str, + request_body: dict[str, Any], + response_payload: dict[str, Any], + *, + turn_type: str, + injected_skills: list[str], + session_done: bool, + ) -> None: + """Record a Responses API turn into the session tracking system.""" + if not session_id: + return + self._touch_session(session_id) + prompt_text = _normalize_responses_content(request_body.get("instructions", "")) + inp = request_body.get("input") + if isinstance(inp, str): + prompt_text = (prompt_text + "\n" + inp).strip() if prompt_text else inp + elif isinstance(inp, list): + user_parts = [] + for item in inp: + if isinstance(item, dict) and item.get("role") == "user": + user_parts.append(_normalize_responses_content(item.get("content", ""))) + if user_parts: + joined = " ".join(user_parts) + prompt_text = (prompt_text + "\n" + joined).strip() if prompt_text else joined + response_parts = [] + for item in response_payload.get("output", []): + if not isinstance(item, dict): + continue + if item.get("type") == "message": + for part in item.get("content", []): + if isinstance(part, dict) and part.get("type") == "output_text": + response_parts.append(part.get("text", "")) + elif item.get("type") == "function_call": + name = item.get("name", "") + args = str(item.get("arguments", ""))[:500] + response_parts.append(f"[tool:{name}] {args}") + response_text = "\n".join(response_parts) + turns = self._session_turns.setdefault(session_id, []) + turn_num = len(turns) + 1 + turn_record = { + "turn_num": turn_num, + "raw_turn_kind": "final" if turn_type == "main" else "side", + "prompt_text": prompt_text[:2000], + "response_text": response_text[:2000], + "injected_skills": injected_skills, + "prm_score": None, + } + turns.append(turn_record) + if turn_type == "main": + user_turn_num = self._next_user_turn_num(session_id) + turn_record["user_turn_num"] = user_turn_num + self._maybe_upload_session_snapshot(session_id, user_turn_num) + logger.info( + "[Codex] %s session=%s turn=%d user_turn=%s prompt=%d chars response=%d chars skills=%s", + turn_type, + session_id, + turn_num, + turn_record.get("user_turn_num", "-"), + len(prompt_text), + len(response_text), + ",".join(injected_skills) if injected_skills else "(none)", + ) + if session_done: + self._safe_create_task(self._close_session(session_id, reason="codex_session_done")) async def _forward_to_llm_responses(self, body: dict[str, Any]) -> dict[str, Any]: """Forward a Codex Responses payload to an upstream Responses API.""" @@ -2535,6 +2733,118 @@ async def _forward_to_llm_responses(self, body: dict[str, Any]) -> dict[str, Any logger.error("[OpenClaw] Responses forward failed: %s", e, exc_info=True) raise HTTPException(status_code=502, detail=f"Responses forward error: {e}") from e + async def _stream_and_track_responses( + self, + body: dict[str, Any], + *, + record_body: dict[str, Any] | None = None, + session_id: str, + turn_type: str, + injected_skills: list[str], + session_done: bool, + ): + """Wrap _stream_llm_responses: passthrough SSE + parse response.completed inline.""" + tracked = False + buf = "" + output_items: dict[int, dict[str, Any]] = {} + output_text_parts: dict[tuple[int, int], str] = {} + + def ensure_message_item(output_index: int) -> dict[str, Any]: + item = output_items.setdefault( + output_index, + { + "type": "message", + "role": "assistant", + "status": "completed", + "content": [], + }, + ) + content = item.setdefault("content", []) + if not isinstance(content, list): + item["content"] = [] + return item + + def apply_output_text(output_index: int, content_index: int, text: str) -> None: + item = ensure_message_item(output_index) + content = item.setdefault("content", []) + while len(content) <= content_index: + content.append({"type": "output_text", "text": "", "annotations": []}) + part = content[content_index] + if isinstance(part, dict): + part["type"] = part.get("type") or "output_text" + part["text"] = text + part.setdefault("annotations", []) + + def parse_responses_stream_event(data: dict[str, Any]) -> dict[str, Any] | None: + event_type = data.get("type") + output_index = int(data.get("output_index", 0) or 0) + content_index = int(data.get("content_index", 0) or 0) + + if event_type == "response.output_item.added": + item = data.get("item") + if isinstance(item, dict): + output_items[output_index] = item + elif event_type == "response.output_item.done": + item = data.get("item") + if isinstance(item, dict): + output_items[output_index] = item + elif event_type == "response.output_text.delta": + key = (output_index, content_index) + output_text_parts[key] = output_text_parts.get(key, "") + str(data.get("delta") or "") + apply_output_text(output_index, content_index, output_text_parts[key]) + elif event_type == "response.output_text.done": + text = str(data.get("text") or output_text_parts.get((output_index, content_index), "")) + output_text_parts[(output_index, content_index)] = text + apply_output_text(output_index, content_index, text) + elif event_type == "response.content_part.done": + part = data.get("part") + if isinstance(part, dict) and part.get("type") == "output_text": + text = str(part.get("text") or output_text_parts.get((output_index, content_index), "")) + output_text_parts[(output_index, content_index)] = text + apply_output_text(output_index, content_index, text) + elif event_type == "response.completed": + response_payload = data.get("response") if isinstance(data.get("response"), dict) else dict(data) + if output_items and not response_payload.get("output"): + response_payload = { + **response_payload, + "output": [item for _, item in sorted(output_items.items())], + } + return response_payload + return None + + async for chunk in self._stream_llm_responses(body): + if not tracked: + try: + text = chunk.decode("utf-8", errors="ignore") if isinstance(chunk, bytes) else chunk + buf += text + while "\n" in buf: + line, buf = buf.split("\n", 1) + stripped = line.strip() + if not stripped.startswith("data: "): + continue + raw = stripped[6:] + if raw == "[DONE]": + continue + try: + data = json.loads(raw) + except Exception: + continue + response_payload = parse_responses_stream_event(data) if isinstance(data, dict) else None + if response_payload is not None: + self._record_responses_turn( + session_id, + record_body or body, + response_payload, + turn_type=turn_type, + injected_skills=injected_skills, + session_done=session_done, + ) + tracked = True + break + except Exception: + pass + yield chunk + async def _stream_llm_responses(self, body: dict[str, Any]): """Passthrough upstream Responses SSE without aggregating or rewriting events.""" import httpx @@ -2544,7 +2854,7 @@ async def _stream_llm_responses(self, body: dict[str, Any]): async with httpx.AsyncClient(timeout=_llm_request_timeout_seconds()) as client: async with client.stream("POST", url, json=send_body, headers=headers) as resp: resp.raise_for_status() - async for chunk in resp.aiter_raw(): + async for chunk in resp.aiter_bytes(): if chunk: yield chunk except httpx.HTTPStatusError as e: @@ -2778,11 +3088,21 @@ async def _upload_session_data( logger.warning("[SkillHub] session upload failed: %s", e) return False - def _maybe_upload_session_snapshot(self, session_id: str, turn_num: int) -> None: + def _next_user_turn_num(self, session_id: str) -> int: + self._user_turn_counts[session_id] = self._user_turn_counts.get(session_id, 0) + 1 + return self._user_turn_counts[session_id] + + def _advance_user_turn_and_maybe_upload(self, session_id: str) -> int: + user_turn_num = self._next_user_turn_num(session_id) + self._maybe_upload_session_snapshot(session_id, user_turn_num) + return user_turn_num + + def _maybe_upload_session_snapshot(self, session_id: str, user_turn_num: int) -> None: + """Queue a session snapshot when the user-visible turn cadence is reached.""" interval = max(0, int(getattr(self.config, "sharing_session_upload_interval", 0) or 0)) if not self.config.sharing_enabled or interval <= 0: return - if turn_num <= 0 or turn_num % interval != 0: + if user_turn_num <= 0 or user_turn_num % interval != 0: return turns = copy.deepcopy(self._session_turns.get(session_id, [])) if not turns: @@ -2798,14 +3118,23 @@ async def _trigger_evolve(self) -> None: url = str(getattr(self.config, "evolve_server_url", "") or "").strip().rstrip("/") if not url: return - try: - import httpx + import httpx - async with httpx.AsyncClient(timeout=5.0) as client: - await client.post(f"{url}/trigger") - logger.info("[SkillHub] triggered evolve server: %s", url) - except Exception as e: - logger.warning("[SkillHub] evolve trigger failed: %s", e) + for attempt in range(3): + try: + async with httpx.AsyncClient(timeout=300.0) as client: + resp = await client.post(f"{url}/trigger") + resp.raise_for_status() + result = resp.json() + logger.info("[SkillHub] triggered evolve server: %s", url) + if isinstance(result, dict) and int(result.get("uploaded_skills") or 0) > 0: + await self._pull_skills_from_cloud() + return + except Exception as e: + if attempt < 2: + await asyncio.sleep(1.0 * (attempt + 1)) + else: + logger.warning("[SkillHub] evolve trigger failed after 3 attempts: %s", e) # ------------------------------------------------------------------ # # Skill pull (cloud -> local) # diff --git a/skillclaw/claw_adapter.py b/skillclaw/claw_adapter.py index b0636a1..b054c94 100644 --- a/skillclaw/claw_adapter.py +++ b/skillclaw/claw_adapter.py @@ -45,6 +45,7 @@ _HERMES_BACKUP_DIR = Path.home() / ".skillclaw" / "backups" / "hermes" _CODEX_HOME = Path.home() / ".codex" _CODEX_CONFIG_PATH = _CODEX_HOME / "config.toml" +_CODEX_PROFILE_CONFIG_PATH = _CODEX_HOME / "skillclaw.config.toml" _CODEX_SKILLS_DIR = _CODEX_HOME / "skills" _CODEX_BACKUP_DIR = Path.home() / ".skillclaw" / "backups" / "codex" _CLAUDE_HOME = Path.home() / ".claude" @@ -656,7 +657,6 @@ def _build_codex_provider_block(base_url: str, api_key: str) -> str: def _build_codex_profile_block(model_id: str) -> str: lines = [ - "[profiles.skillclaw]", f"model = {_format_toml_value(model_id)}", 'model_provider = "skillclaw"', ] @@ -673,6 +673,7 @@ def _configure_codex(cfg: "SkillClawConfig") -> None: api_key = cfg.proxy_api_key or "skillclaw" base_url = f"http://127.0.0.1:{cfg.proxy_port}/v1" config_path = _CODEX_CONFIG_PATH + profile_config_path = _CODEX_PROFILE_CONFIG_PATH _prepare_external_skills_dir(_CODEX_SKILLS_DIR, "Codex") existing_text = "" @@ -687,16 +688,17 @@ def _configure_codex(cfg: "SkillClawConfig") -> None: updated = _remove_top_level_toml_keys(updated, {"model", "model_provider"}) updated = _remove_toml_table(updated, "model_providers.skillclaw").rstrip() + "\n\n" updated = _remove_toml_table(updated, "profiles.skillclaw").rstrip() + "\n\n" - updated += _build_codex_provider_block(base_url, api_key) - updated += "\n" + _build_codex_profile_block(model_id) + profile_text = _build_codex_profile_block(model_id) + "\n" + _build_codex_provider_block(base_url, api_key) _backup_codex_config_if_changed(config_path, updated) _write_text_atomic(config_path, updated, "Codex config") + _write_text_atomic(profile_config_path, profile_text, "Codex SkillClaw profile config") def inspect_codex_config(cfg: "SkillClawConfig") -> dict[str, object]: """Return a diagnostic snapshot of the local Codex integration state.""" config_path = _CODEX_CONFIG_PATH + profile_config_path = _CODEX_PROFILE_CONFIG_PATH expected_model = cfg.served_model_name or cfg.llm_model_id or "skillclaw-model" expected_base_url = f"http://127.0.0.1:{cfg.proxy_port}/v1" expected_api_key = cfg.proxy_api_key or "skillclaw" @@ -711,16 +713,21 @@ def inspect_codex_config(cfg: "SkillClawConfig") -> dict[str, object]: text = config_path.read_text(encoding="utf-8") except Exception as e: logger.warning("[ClawAdapter] Failed to read Codex config %s: %s", config_path, e) + profile_text = "" + if profile_config_path.exists(): + try: + profile_text = profile_config_path.read_text(encoding="utf-8") + except Exception as e: + logger.warning("[ClawAdapter] Failed to read Codex profile config %s: %s", profile_config_path, e) configured_model = str(_extract_top_level_toml_value(text, "model") or "") configured_provider = str(_extract_top_level_toml_value(text, "model_provider") or "") - provider_cfg = _extract_toml_table(text, "model_providers.skillclaw") + provider_cfg = _extract_toml_table(profile_text, "model_providers.skillclaw") configured_base_url = str(provider_cfg.get("base_url") or "") configured_wire_api = str(provider_cfg.get("wire_api") or "") configured_token = str(provider_cfg.get("experimental_bearer_token") or "") - profile_cfg = _extract_toml_table(text, "profiles.skillclaw") - configured_profile_model = str(profile_cfg.get("model") or "") - configured_profile_provider = str(profile_cfg.get("model_provider") or "") + configured_profile_model = str(_extract_top_level_toml_value(profile_text, "model") or "") + configured_profile_provider = str(_extract_top_level_toml_value(profile_text, "model_provider") or "") proxy_match = ( configured_profile_model == expected_model @@ -743,9 +750,13 @@ def inspect_codex_config(cfg: "SkillClawConfig") -> dict[str, object]: if not config_path.exists(): issues.append("Codex config is missing: ~/.codex/config.toml") + if not profile_config_path.exists(): + issues.append("Codex SkillClaw profile config is missing: ~/.codex/skillclaw.config.toml") if not proxy_match: issues.append("Codex SkillClaw profile is missing or not pointing at the local SkillClaw proxy.") - next_steps.append("Start SkillClaw once with `claw_type=codex` so it can register ~/.codex/config.toml.") + next_steps.append( + "Start SkillClaw once with `claw_type=codex` so it can register ~/.codex/skillclaw.config.toml." + ) if configured_provider == "skillclaw": issues.append("Codex global model_provider still points at SkillClaw; normal Codex runs may be intercepted.") next_steps.append('Remove top-level `model_provider = "skillclaw"` or run `skillclaw restore codex`.') @@ -797,7 +808,12 @@ def restore_codex_config(backup_path: Path | None = None) -> dict[str, str]: text = source.read_text(encoding="utf-8") target = _CODEX_CONFIG_PATH _write_text_atomic(target, text, "Codex config restore") - return {"source": str(source), "target": str(target)} + profile_target = _CODEX_PROFILE_CONFIG_PATH + removed_profile = False + if profile_target.exists(): + profile_target.unlink() + removed_profile = True + return {"source": str(source), "target": str(target), "removed_profile": str(removed_profile)} # ------------------------------------------------------------------ # diff --git a/skillclaw/cli.py b/skillclaw/cli.py index 89c99a0..529c702 100644 --- a/skillclaw/cli.py +++ b/skillclaw/cli.py @@ -502,6 +502,8 @@ def restore_codex(backup_path: str | None): raise click.ClickException(str(exc)) from None click.echo(f"Restored Codex config: {result['target']} <- {result['source']}") + if result.get("removed_profile") == "True": + click.echo("Removed Codex SkillClaw profile config: ~/.codex/skillclaw.config.toml") @restore.command(name="claude") diff --git a/tests/test_codex_profile_integration.py b/tests/test_codex_profile_integration.py index 832c365..6fcbcbb 100644 --- a/tests/test_codex_profile_integration.py +++ b/tests/test_codex_profile_integration.py @@ -27,12 +27,14 @@ def record_injection(self, names: list[str]) -> None: def test_configure_codex_registers_profile_without_replacing_global_defaults(monkeypatch, tmp_path: Path) -> None: config_path = tmp_path / ".codex" / "config.toml" + profile_config_path = tmp_path / ".codex" / "skillclaw.config.toml" config_path.parent.mkdir(parents=True) config_path.write_text( 'model = "gpt-5.5"\nmodel_provider = "openai"\n\n[profiles.default]\nmodel = "gpt-5.5"\n', encoding="utf-8", ) monkeypatch.setattr(claw_adapter, "_CODEX_CONFIG_PATH", config_path) + monkeypatch.setattr(claw_adapter, "_CODEX_PROFILE_CONFIG_PATH", profile_config_path) monkeypatch.setattr(claw_adapter, "_CODEX_SKILLS_DIR", tmp_path / ".codex" / "skills") monkeypatch.setattr(claw_adapter, "_CODEX_BACKUP_DIR", tmp_path / "backups") @@ -47,24 +49,38 @@ def test_configure_codex_registers_profile_without_replacing_global_defaults(mon text = config_path.read_text(encoding="utf-8") assert 'model = "gpt-5.5"' in text assert 'model_provider = "openai"' in text - assert "[model_providers.skillclaw]" in text - assert 'base_url = "http://127.0.0.1:31000/v1"' in text - assert 'wire_api = "responses"' in text - assert 'experimental_bearer_token = "skillclaw-key"' in text - assert "[profiles.skillclaw]" in text - assert 'model = "skillclaw-model"' in text - assert 'model_provider = "skillclaw"' in text + assert "[profiles.skillclaw]" not in text + assert "[model_providers.skillclaw]" not in text + profile_text = profile_config_path.read_text(encoding="utf-8") + assert 'model = "skillclaw-model"' in profile_text + assert 'model_provider = "skillclaw"' in profile_text + assert "[model_providers.skillclaw]" in profile_text + assert 'base_url = "http://127.0.0.1:31000/v1"' in profile_text + assert 'wire_api = "responses"' in profile_text + assert 'experimental_bearer_token = "skillclaw-key"' in profile_text assert (tmp_path / ".codex" / "skills").is_dir() def test_configure_codex_removes_legacy_global_skillclaw_defaults(monkeypatch, tmp_path: Path) -> None: config_path = tmp_path / ".codex" / "config.toml" + profile_config_path = tmp_path / ".codex" / "skillclaw.config.toml" config_path.parent.mkdir(parents=True) config_path.write_text( - 'model = "skillclaw-model"\nmodel_provider = "skillclaw"\n\n[profiles.default]\nmodel = "gpt-5.5"\n', + ( + 'model = "skillclaw-model"\n' + 'model_provider = "skillclaw"\n\n' + "[model_providers.skillclaw]\n" + 'base_url = "http://127.0.0.1:30000/v1"\n\n' + "[profiles.skillclaw]\n" + 'model = "skillclaw-model"\n' + 'model_provider = "skillclaw"\n\n' + "[profiles.default]\n" + 'model = "gpt-5.5"\n' + ), encoding="utf-8", ) monkeypatch.setattr(claw_adapter, "_CODEX_CONFIG_PATH", config_path) + monkeypatch.setattr(claw_adapter, "_CODEX_PROFILE_CONFIG_PATH", profile_config_path) monkeypatch.setattr(claw_adapter, "_CODEX_SKILLS_DIR", tmp_path / ".codex" / "skills") monkeypatch.setattr(claw_adapter, "_CODEX_BACKUP_DIR", tmp_path / "backups") @@ -73,7 +89,70 @@ def test_configure_codex_removes_legacy_global_skillclaw_defaults(monkeypatch, t top_level = config_path.read_text(encoding="utf-8").split("[", 1)[0] assert "model_provider" not in top_level assert "model =" not in top_level - assert "[profiles.skillclaw]" in config_path.read_text(encoding="utf-8") + text = config_path.read_text(encoding="utf-8") + assert "[profiles.skillclaw]" not in text + assert "[model_providers.skillclaw]" not in text + assert "[profiles.default]" in text + assert "[model_providers.skillclaw]" in profile_config_path.read_text(encoding="utf-8") + + +def test_inspect_codex_config_reads_split_profile_config(monkeypatch, tmp_path: Path) -> None: + config_path = tmp_path / ".codex" / "config.toml" + profile_config_path = tmp_path / ".codex" / "skillclaw.config.toml" + skills_dir = tmp_path / ".codex" / "skills" + config_path.parent.mkdir(parents=True) + skills_dir.mkdir() + config_path.write_text('model = "gpt-5.5"\nmodel_provider = "openai"\n', encoding="utf-8") + profile_config_path.write_text( + ( + 'model = "skillclaw-model"\n' + 'model_provider = "skillclaw"\n\n' + "[model_providers.skillclaw]\n" + 'name = "SkillClaw"\n' + 'base_url = "http://127.0.0.1:31000/v1"\n' + 'wire_api = "responses"\n' + 'experimental_bearer_token = "skillclaw-key"\n' + ), + encoding="utf-8", + ) + monkeypatch.setattr(claw_adapter, "_CODEX_CONFIG_PATH", config_path) + monkeypatch.setattr(claw_adapter, "_CODEX_PROFILE_CONFIG_PATH", profile_config_path) + monkeypatch.setattr(claw_adapter, "_CODEX_SKILLS_DIR", skills_dir) + monkeypatch.setattr(claw_adapter, "_CODEX_BACKUP_DIR", tmp_path / "backups") + + report = claw_adapter.inspect_codex_config( + SkillClawConfig( + served_model_name="skillclaw-model", + proxy_api_key="skillclaw-key", + proxy_port=31000, + skills_dir=str(skills_dir), + ) + ) + + assert report["status"] == "ok" + assert report["proxy_match"] is True + assert report["configured_profile_model"] == "skillclaw-model" + assert report["configured_base_url"] == "http://127.0.0.1:31000/v1" + + +def test_restore_codex_config_removes_split_profile_config(monkeypatch, tmp_path: Path) -> None: + config_path = tmp_path / ".codex" / "config.toml" + profile_config_path = tmp_path / ".codex" / "skillclaw.config.toml" + backup_path = tmp_path / "backups" / "config.latest.toml" + config_path.parent.mkdir(parents=True) + backup_path.parent.mkdir(parents=True) + config_path.write_text('model_provider = "skillclaw"\n', encoding="utf-8") + profile_config_path.write_text('model_provider = "skillclaw"\n', encoding="utf-8") + backup_path.write_text('model_provider = "openai"\n', encoding="utf-8") + monkeypatch.setattr(claw_adapter, "_CODEX_CONFIG_PATH", config_path) + monkeypatch.setattr(claw_adapter, "_CODEX_PROFILE_CONFIG_PATH", profile_config_path) + monkeypatch.setattr(claw_adapter, "_CODEX_BACKUP_DIR", backup_path.parent) + + result = claw_adapter.restore_codex_config() + + assert config_path.read_text(encoding="utf-8") == 'model_provider = "openai"\n' + assert not profile_config_path.exists() + assert result["removed_profile"] == "True" def test_codex_config_defaults_to_responses_mode_and_codex_skills(tmp_path: Path) -> None: diff --git a/tests/test_evolve_proxy_reload.py b/tests/test_evolve_proxy_reload.py new file mode 100644 index 0000000..ff1bcfb --- /dev/null +++ b/tests/test_evolve_proxy_reload.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import types + +import pytest + +from evolve_server.core.config import EvolveServerConfig +from evolve_server.engines.workflow import EvolveServer + + +@pytest.mark.anyio +async def test_notify_proxy_reload_posts_callback_with_auth(monkeypatch) -> None: + server = EvolveServer.__new__(EvolveServer) + server.config = EvolveServerConfig( + skill_reload_mode="callback", + proxy_reload_url="http://proxy.test/", + proxy_reload_api_key="secret", + ) + captured = {} + + class FakeAsyncClient: + def __init__(self, *args, **kwargs): + captured["timeout"] = kwargs.get("timeout") + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url, headers): + captured["url"] = url + captured["headers"] = headers + return types.SimpleNamespace(raise_for_status=lambda: None) + + fake_httpx = types.SimpleNamespace(AsyncClient=FakeAsyncClient) + monkeypatch.setitem(__import__("sys").modules, "httpx", fake_httpx) + + await server._notify_proxy_reload() + + assert captured == { + "timeout": 5.0, + "url": "http://proxy.test/internal/reload-skills", + "headers": {"Authorization": "Bearer secret"}, + } + + +@pytest.mark.anyio +async def test_notify_proxy_reload_retries_on_http_error(monkeypatch) -> None: + server = EvolveServer.__new__(EvolveServer) + server.config = EvolveServerConfig( + skill_reload_mode="callback", + proxy_reload_url="http://proxy.test", + proxy_reload_api_key="secret", + ) + attempts = {"count": 0} + + class FakeResponse: + def raise_for_status(self): + raise RuntimeError("401 Unauthorized") + + class FakeAsyncClient: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url, headers): + attempts["count"] += 1 + return FakeResponse() + + async def fake_sleep(_delay): + return None + + fake_httpx = types.SimpleNamespace(AsyncClient=FakeAsyncClient) + monkeypatch.setitem(__import__("sys").modules, "httpx", fake_httpx) + monkeypatch.setattr("evolve_server.engines.workflow.asyncio.sleep", fake_sleep) + + await server._notify_proxy_reload() + + assert attempts == {"count": 3} + + +@pytest.mark.anyio +async def test_notify_proxy_reload_skips_non_callback_modes(monkeypatch) -> None: + server = EvolveServer.__new__(EvolveServer) + server.config = EvolveServerConfig( + skill_reload_mode="poll", + proxy_reload_url="http://proxy.test", + proxy_reload_api_key="secret", + ) + called = {"http": False} + + class FakeAsyncClient: + def __init__(self, *args, **kwargs): + called["http"] = True + + fake_httpx = types.SimpleNamespace(AsyncClient=FakeAsyncClient) + monkeypatch.setitem(__import__("sys").modules, "httpx", fake_httpx) + + await server._notify_proxy_reload() + + assert called == {"http": False} diff --git a/tests/test_responses_native.py b/tests/test_responses_native.py index b22140f..5cf52b3 100644 --- a/tests/test_responses_native.py +++ b/tests/test_responses_native.py @@ -1,3 +1,6 @@ +import gzip +import json + import httpx import pytest @@ -122,6 +125,124 @@ async def fake_forward(body): assert seen["body"]["tools"] == [{"type": "custom", "name": "js_repl"}] +@pytest.mark.asyncio +async def test_native_responses_records_original_prompt_before_skill_injection(): + class FakeSkillManager: + def refresh_if_changed(self): + return None + + def build_injection_prompt(self, *, max_chars): + return "\n" + ("catalog filler " * 300) + "\n" + + def get_all_skills(self): + return [{"name": "demo-skill"}] + + def record_injection(self, names): + self.names = names + + server = SkillClawAPIServer( + SkillClawConfig( + llm_api_mode="responses", + llm_api_base="http://upstream.test/v1", + llm_model_id="upstream-model", + proxy_api_key="skillclaw", + record_enabled=False, + ), + skill_manager=FakeSkillManager(), + ) + seen = {} + + async def fake_forward(body): + seen["instructions"] = body.get("instructions", "") + return { + "id": "resp_native", + "object": "response", + "created_at": 0, + "status": "completed", + "model": "upstream-model", + "output": [ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": [{"type": "output_text", "text": "native ok", "annotations": []}], + } + ], + } + + server._forward_to_llm_responses = fake_forward + client = httpx.AsyncClient(transport=httpx.ASGITransport(app=server.app), base_url="http://test") + try: + response = await client.post( + "/v1/responses", + headers={"Authorization": "Bearer skillclaw", "Session_id": "codex-session-1"}, + json={ + "model": "skillclaw-model", + "instructions": "original instructions", + "input": "actual user task", + "stream": False, + }, + ) + finally: + await client.aclose() + + assert response.status_code == 200 + assert "" in seen["instructions"] + turn = server._session_turns["codex-session-1"][0] + assert turn["prompt_text"] == "original instructions\nactual user task" + assert "" not in turn["prompt_text"] + assert turn["injected_skills"] == ["demo-skill"] + assert turn["raw_turn_kind"] == "final" + assert turn["user_turn_num"] == 1 + + +def test_native_responses_upload_cadence_uses_user_turn_counter() -> None: + server = object.__new__(SkillClawAPIServer) + server.config = SkillClawConfig(sharing_enabled=True, sharing_session_upload_interval=2) + server._session_turns = {} + server._user_turn_counts = {} + server._session_last_active = {} + queued = [] + + def fake_create_task(coro): + queued.append(coro) + return None + + server._safe_create_task = fake_create_task + + response_payload = { + "output": [ + { + "type": "message", + "content": [{"type": "output_text", "text": "ok"}], + } + ] + } + server._record_responses_turn( + "codex-session-1", + {"input": "first"}, + response_payload, + turn_type="main", + injected_skills=[], + session_done=False, + ) + server._record_responses_turn( + "codex-session-1", + {"input": "second"}, + response_payload, + turn_type="main", + injected_skills=[], + session_done=False, + ) + + assert [turn["turn_num"] for turn in server._session_turns["codex-session-1"]] == [1, 2] + assert [turn["user_turn_num"] for turn in server._session_turns["codex-session-1"]] == [1, 2] + assert len(queued) == 1 + + queued[0].close() + + @pytest.mark.asyncio async def test_forward_to_llm_responses_stream_preserves_upstream_sse(monkeypatch): captured = {} @@ -130,7 +251,7 @@ class FakeStreamResponse: def raise_for_status(self): return None - async def aiter_raw(self): + async def aiter_bytes(self): yield b'data: {"type":"response.created"}\n\n' yield b'data: {"type":"response.completed"}\n\n' yield b"data: [DONE]\n\n" @@ -191,6 +312,202 @@ def stream(self, method, url, json, headers): assert captured["json"]["tools"] == body["tools"] +@pytest.mark.asyncio +async def test_forward_to_llm_responses_stream_decodes_compressed_upstream_sse(monkeypatch): + raw_sse = b'data: {"type":"response.created"}\n\ndata: {"type":"response.completed"}\n\ndata: [DONE]\n\n' + + class FakeStreamResponse: + def raise_for_status(self): + return None + + def aiter_bytes(self): + response = httpx.Response( + 200, + headers={"content-encoding": "gzip"}, + stream=httpx.ByteStream(gzip.compress(raw_sse)), + ) + return response.aiter_bytes() + + class FakeStreamContext: + async def __aenter__(self): + return FakeStreamResponse() + + async def __aexit__(self, exc_type, exc, tb): + return False + + class FakeAsyncClient: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + def stream(self, method, url, json, headers): + return FakeStreamContext() + + monkeypatch.setattr(httpx, "AsyncClient", FakeAsyncClient) + server = object.__new__(SkillClawAPIServer) + server.config = SkillClawConfig( + llm_api_base="http://upstream.test/v1", + llm_api_key="upstream-key", + llm_model_id="upstream-model", + llm_api_mode="responses", + ) + + chunks = [] + async for chunk in server._stream_llm_responses({"model": "skillclaw-model", "input": "hi", "stream": True}): + chunks.append(chunk) + + assert b"".join(chunks) == raw_sse + + +@pytest.mark.asyncio +async def test_stream_and_track_responses_records_before_completed_chunk_is_consumed(): + server = object.__new__(SkillClawAPIServer) + server._session_turns = {} + server._safe_create_task = lambda coro: None + recorded = {} + + completed_event = { + "type": "response.completed", + "response": { + "id": "resp_native", + "object": "response", + "status": "completed", + "output": [ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": [{"type": "output_text", "text": "tracked ok", "annotations": []}], + } + ], + }, + } + + async def fake_stream(_body): + payload = ("data: " + json.dumps(completed_event) + "\n\n").encode() + yield payload[:20] + yield payload[20:] + yield b"data: [DONE]\n\n" + + def fake_record(session_id, request_body, response_payload, *, turn_type, injected_skills, session_done): + recorded["session_id"] = session_id + recorded["request_body"] = request_body + recorded["response_text"] = response_payload["output"][0]["content"][0]["text"] + recorded["turn_type"] = turn_type + recorded["injected_skills"] = injected_skills + recorded["session_done"] = session_done + + server._stream_llm_responses = fake_stream + server._record_responses_turn = fake_record + + stream = server._stream_and_track_responses( + {"model": "skillclaw-model", "instructions": "catalog", "stream": True}, + record_body={"model": "skillclaw-model", "instructions": "original instructions", "stream": True}, + session_id="codex-session-1", + turn_type="main", + injected_skills=["demo"], + session_done=False, + ) + first = await stream.__anext__() + second = await stream.__anext__() + assert first + second == ("data: " + json.dumps(completed_event) + "\n\n").encode() + assert recorded == { + "session_id": "codex-session-1", + "request_body": {"model": "skillclaw-model", "instructions": "original instructions", "stream": True}, + "response_text": "tracked ok", + "turn_type": "main", + "injected_skills": ["demo"], + "session_done": False, + } + await stream.aclose() + + +@pytest.mark.asyncio +async def test_stream_and_track_responses_records_output_from_stream_events_when_completed_lacks_output(): + server = object.__new__(SkillClawAPIServer) + recorded = {} + + events = [ + { + "type": "response.output_item.added", + "output_index": 0, + "item": { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "in_progress", + "content": [], + }, + }, + { + "type": "response.output_text.delta", + "output_index": 0, + "content_index": 0, + "item_id": "msg_1", + "delta": "real ", + }, + { + "type": "response.output_text.delta", + "output_index": 0, + "content_index": 0, + "item_id": "msg_1", + "delta": "ok", + }, + { + "type": "response.output_text.done", + "output_index": 0, + "content_index": 0, + "item_id": "msg_1", + "text": "real ok", + }, + { + "type": "response.completed", + "response": { + "id": "resp_1", + "object": "response", + "status": "completed", + "model": "gpt-5.5", + }, + }, + ] + + async def fake_stream(_body): + for event in events: + yield ("event: " + event["type"] + "\n").encode() + yield ("data: " + json.dumps(event) + "\n\n").encode() + + def fake_record(session_id, request_body, response_payload, *, turn_type, injected_skills, session_done): + recorded["session_id"] = session_id + recorded["response_text"] = response_payload["output"][0]["content"][0]["text"] + recorded["turn_type"] = turn_type + + server._stream_llm_responses = fake_stream + server._record_responses_turn = fake_record + + chunks = [] + async for chunk in server._stream_and_track_responses( + {"model": "skillclaw-model", "input": "hi", "stream": True}, + session_id="codex-session-1", + turn_type="main", + injected_skills=[], + session_done=False, + ): + chunks.append(chunk) + + assert b"".join(chunks).startswith(b"event: response.output_item.added\n") + assert recorded == { + "session_id": "codex-session-1", + "response_text": "real ok", + "turn_type": "main", + } + + @pytest.mark.asyncio async def test_responses_endpoint_passthroughs_native_stream(): server = SkillClawAPIServer( diff --git a/tests/test_session_upload_trigger.py b/tests/test_session_upload_trigger.py index 11daf73..7265cb7 100644 --- a/tests/test_session_upload_trigger.py +++ b/tests/test_session_upload_trigger.py @@ -1,8 +1,14 @@ from __future__ import annotations +import httpx import pytest -from skillclaw.api_server import SkillClawAPIServer +from skillclaw.api_server import ( + _PROTOCOL_ANTHROPIC_MESSAGES, + SkillClawAPIServer, + _classify_raw_turn_kind, + _is_user_turn_boundary, +) from skillclaw.config import SkillClawConfig @@ -13,6 +19,7 @@ def _server_for_snapshot_tests() -> SkillClawAPIServer: sharing_session_upload_interval=2, evolve_server_url="http://evolve.test", ) + server._user_turn_counts = {} server._session_turns = { "session-a": [ {"turn_num": 1, "prompt_text": "one"}, @@ -92,3 +99,177 @@ async def fake_trigger(): await server._upload_session_snapshot_and_trigger("session-a", [{"turn_num": 2}]) assert calls == {"upload": 2, "trigger": 1} + + +def test_user_turn_cadence_counts_visible_turns_not_raw_turns() -> None: + server = _server_for_snapshot_tests() + queued = [] + + def fake_create_task(coro): + queued.append(coro) + return None + + server._safe_create_task = fake_create_task + server._session_turns["session-a"].append({"turn_num": 3, "raw_turn_kind": "tool_use"}) + + assert server._advance_user_turn_and_maybe_upload("session-a") == 1 + assert queued == [] + + assert server._advance_user_turn_and_maybe_upload("session-a") == 2 + assert len(queued) == 1 + + queued[0].close() + + +def test_claude_internal_turns_do_not_advance_user_turn_boundary() -> None: + assert ( + _classify_raw_turn_kind( + _PROTOCOL_ANTHROPIC_MESSAGES, + '{"title":"Weekly report"}', + [], + ) + == "session_title" + ) + assert ( + _classify_raw_turn_kind( + _PROTOCOL_ANTHROPIC_MESSAGES, + "", + [{"id": "call_1", "type": "function", "function": {"name": "Read", "arguments": "{}"}}], + ) + == "tool_use" + ) + assert not _is_user_turn_boundary("session_title") + assert not _is_user_turn_boundary("tool_use") + assert _is_user_turn_boundary(_classify_raw_turn_kind(_PROTOCOL_ANTHROPIC_MESSAGES, "final answer", [])) + + +def test_openclaw_and_hermes_main_turns_use_user_turn_upload_cadence() -> None: + server = _server_for_snapshot_tests() + captured = {} + + class DummyCoro: + def close(self): + return None + + def fake_create_task(coro): + return None + + def fake_upload_snapshot(session_id, turns): + captured["session_id"] = session_id + captured["turns"] = turns + return DummyCoro() + + server._safe_create_task = fake_create_task + server._upload_session_snapshot_and_trigger = fake_upload_snapshot + server._session_turns["session-a"] = [] + + for raw_turn_num in range(1, 4): + turn_record = { + "turn_num": raw_turn_num, + "raw_turn_kind": _classify_raw_turn_kind("", f"answer {raw_turn_num}", []), + "prompt_text": f"user {raw_turn_num}", + } + server._session_turns["session-a"].append(turn_record) + user_turn_num = server._next_user_turn_num("session-a") + turn_record["user_turn_num"] = user_turn_num + server._maybe_upload_session_snapshot("session-a", user_turn_num) + + assert server._user_turn_counts["session-a"] == 3 + assert captured["session_id"] == "session-a" + assert [t["turn_num"] for t in captured["turns"]] == [1, 2] + assert [t["user_turn_num"] for t in captured["turns"]] == [1, 2] + + +def test_skill_reload_polling_starts_only_in_poll_mode(monkeypatch) -> None: + server = object.__new__(SkillClawAPIServer) + server.config = SkillClawConfig(sharing_enabled=True, sharing_skill_reload_mode="poll") + server._skill_reload_task = None + server._skill_reload_interval_seconds = 30 + created = [] + + class FakeTask: + def done(self): + return False + + def add_done_callback(self, _callback): + return None + + def fake_create_task(coro): + created.append(coro) + return FakeTask() + + import asyncio + + monkeypatch.setattr(asyncio, "create_task", fake_create_task) + server._start_skill_reload_polling() + + assert len(created) == 1 + created[0].close() + + +def test_skill_reload_polling_does_not_start_when_disabled_or_callback(monkeypatch) -> None: + created = [] + + class FakeTask: + def done(self): + return False + + def fake_create_task(coro): + created.append(coro) + return FakeTask() + + import asyncio + + monkeypatch.setattr(asyncio, "create_task", fake_create_task) + for mode in ("off", "callback"): + server = object.__new__(SkillClawAPIServer) + server.config = SkillClawConfig(sharing_enabled=True, sharing_skill_reload_mode=mode) + server._skill_reload_task = None + server._skill_reload_interval_seconds = 30 + server._start_skill_reload_polling() + + server = object.__new__(SkillClawAPIServer) + server.config = SkillClawConfig(sharing_enabled=False, sharing_skill_reload_mode="poll") + server._skill_reload_task = None + server._skill_reload_interval_seconds = 30 + server._start_skill_reload_polling() + + assert created == [] + + +@pytest.mark.anyio +async def test_internal_reload_skills_endpoint_requires_auth_and_pulls(tmp_path) -> None: + server = SkillClawAPIServer( + SkillClawConfig( + proxy_api_key="secret", + record_enabled=False, + record_dir=str(tmp_path), + ) + ) + calls = {"pull": 0} + + async def fake_pull(skip_names=None): + assert skip_names is None + calls["pull"] += 1 + + class FakeSkillManager: + def get_all_skills(self): + return [{"name": "weekly-report"}, {"name": "demo"}] + + server._pull_skills_from_cloud = fake_pull + server.skill_manager = FakeSkillManager() + + client = httpx.AsyncClient(transport=httpx.ASGITransport(app=server.app), base_url="http://test") + try: + unauthorized = await client.post("/internal/reload-skills") + authorized = await client.post( + "/internal/reload-skills", + headers={"Authorization": "Bearer secret"}, + ) + finally: + await client.aclose() + + assert unauthorized.status_code == 401 + assert authorized.status_code == 200 + assert authorized.json() == {"ok": True, "skills": 2} + assert calls == {"pull": 1}