diff --git a/ductor_bot/cli/process_registry.py b/ductor_bot/cli/process_registry.py index e089854e..b5a51f57 100644 --- a/ductor_bot/cli/process_registry.py +++ b/ductor_bot/cli/process_registry.py @@ -36,6 +36,7 @@ class ProcessRegistry: def __init__(self) -> None: self._processes: dict[int, list[TrackedProcess]] = {} self._aborted: set[int] = set() + self._aborted_topics: set[tuple[int, int | None]] = set() self._aborted_labels: set[tuple[int, str]] = set() self._interrupted: set[int] = set() # MED #9: serialize bulk kill operations (kill_for_task / kill_stale / @@ -102,6 +103,30 @@ async def kill_all(self, chat_id: int) -> int: return 0 return await _kill_processes(entries) + async def kill_by_chat_topic(self, chat_id: int, topic_id: int | None) -> int: + """Kill processes belonging to *chat_id* + *topic_id*. Returns count. + + Used by ``/stop`` so a stop in one topic does not affect another + topic in the same chat. ``topic_id=None`` matches processes + registered without a topic (e.g. private chats). + """ + async with self._kill_lock: + entries = self._processes.get(chat_id, []) + targets = [ + t for t in entries if t.topic_id == topic_id and t.process.returncode is None + ] + if not targets: + return 0 + self._aborted_topics.add((chat_id, topic_id)) + remaining = [ + t for t in entries if t.topic_id != topic_id or t.process.returncode is not None + ] + if remaining: + self._processes[chat_id] = remaining + else: + self._processes.pop(chat_id, None) + return await _kill_processes(targets) + async def kill_all_active(self) -> int: """Kill active processes across all chats. Returns total count killed.""" total = 0 @@ -117,6 +142,14 @@ def clear_abort(self, chat_id: int) -> None: """Clear the abort flag for *chat_id*.""" self._aborted.discard(chat_id) + def was_aborted_topic(self, chat_id: int, topic_id: int | None) -> bool: + """Check whether *(chat_id, topic_id)* has been aborted since last clear.""" + return (chat_id, topic_id) in self._aborted_topics + + def clear_topic_abort(self, chat_id: int, topic_id: int | None) -> None: + """Clear the topic-scoped abort flag (called after handling the abort).""" + self._aborted_topics.discard((chat_id, topic_id)) + def was_interrupted(self, chat_id: int) -> bool: """Check whether *chat_id* was soft-interrupted since last clear.""" return chat_id in self._interrupted diff --git a/ductor_bot/cli/service.py b/ductor_bot/cli/service.py index ad53ca86..0ab7020f 100644 --- a/ductor_bot/cli/service.py +++ b/ductor_bot/cli/service.py @@ -221,7 +221,9 @@ async def execute_streaming( # noqa: PLR0913 timeout_seconds=request.timeout_seconds, timeout_controller=request.timeout_controller, ): - if self._process_registry.was_aborted(request.chat_id): + if self._process_registry.was_aborted( + request.chat_id + ) or self._process_registry.was_aborted_topic(request.chat_id, request.topic_id): logger.info("Streaming aborted mid-stream chat=%d", request.chat_id) break text, result = await callbacks.dispatch(event) @@ -283,7 +285,9 @@ async def _handle_stream_fallback( init_session_id: str | None = None, ) -> AgentResponse: """Handle failed or incomplete streaming: use accumulated text or retry.""" - was_aborted = self._process_registry.was_aborted(request.chat_id) + was_aborted = self._process_registry.was_aborted( + request.chat_id + ) or self._process_registry.was_aborted_topic(request.chat_id, request.topic_id) logger.info( "Stream fallback: aborted=%s accumulated=%d init_sid=%s", was_aborted, diff --git a/ductor_bot/messenger/telegram/handlers.py b/ductor_bot/messenger/telegram/handlers.py index f87798db..bd3118e8 100644 --- a/ductor_bot/messenger/telegram/handlers.py +++ b/ductor_bot/messenger/telegram/handlers.py @@ -60,15 +60,16 @@ async def handle_abort( chat_id: int, message: Message, ) -> bool: - """Kill active CLI processes and send feedback. + """Kill active CLI processes in the current topic and send feedback. Returns True if handled, False if orchestrator not ready. """ if orchestrator is None: return False - killed = await orchestrator.abort(chat_id) - logger.info("Abort requested killed=%d", killed) + thread_id = get_thread_id(message) + killed = await orchestrator.abort(chat_id, topic_id=thread_id) + logger.info("Abort requested chat=%d topic=%s killed=%d", chat_id, thread_id, killed) text = stop_text(bool(killed), orchestrator.active_provider_name) await send_rich( bot, diff --git a/ductor_bot/orchestrator/core.py b/ductor_bot/orchestrator/core.py index 7bab3ae3..832e74d3 100644 --- a/ductor_bot/orchestrator/core.py +++ b/ductor_bot/orchestrator/core.py @@ -326,6 +326,7 @@ async def handle_message_streaming( # noqa: PLR0913 async def _handle_message_impl(self, dispatch: _MessageDispatch) -> OrchestratorResult: self._process_registry.clear_abort(dispatch.key.chat_id) + self._process_registry.clear_topic_abort(dispatch.key.chat_id, dispatch.key.topic_id) logger.info("Message received text=%s", dispatch.cmd[:80]) patterns = detect_suspicious_patterns(dispatch.text) @@ -464,8 +465,22 @@ async def reset_active_provider_session(self, key: SessionKey) -> str: logger.info("Active provider session reset provider=%s model=%s", provider, model) return provider - async def abort(self, chat_id: int) -> int: - """Kill all active CLI processes and background tasks for chat_id.""" + async def abort(self, chat_id: int, topic_id: int | None = None) -> int: + """Kill active CLI processes for *chat_id* (optionally scoped to *topic_id*). + + When ``topic_id`` is provided (``/stop`` from a specific topic), + only the foreground CLI processes registered under that + ``(chat_id, topic_id)`` pair are killed. Background tasks and + named sessions are left alone — they are not topic-tagged in + the current model and have their own management surfaces + (``/tasks``, ``/sessions``) so /stop should not double up. + + When ``topic_id`` is ``None`` (legacy callers / ``/stop_all``) + the chat-wide sweep runs as before: every process for the chat + plus every background task and named session. + """ + if topic_id is not None: + return await self._process_registry.kill_by_chat_topic(chat_id, topic_id) killed = await self._process_registry.kill_all(chat_id) if self._observers.background: killed += await self._observers.background.cancel_all(chat_id) diff --git a/ductor_bot/orchestrator/flows.py b/ductor_bot/orchestrator/flows.py index d1f041e9..a8859e9a 100644 --- a/ductor_bot/orchestrator/flows.py +++ b/ductor_bot/orchestrator/flows.py @@ -333,6 +333,7 @@ async def _maybe_recover_session( # noqa: PLR0913 _reg = orch._process_registry if ( _reg.was_aborted(key.chat_id) + or _reg.was_aborted_topic(key.chat_id, key.topic_id) or _reg.was_interrupted(key.chat_id) or not _needs_session_recovery(response) ): @@ -487,7 +488,11 @@ async def normal( # noqa: PLR0911 request, session, response = outcome.request, outcome.session, outcome.response session_recovered = outcome.session_recovered _reg = orch._process_registry - if _reg.was_aborted(key.chat_id) or _reg.was_interrupted(key.chat_id): + if ( + _reg.was_aborted(key.chat_id) + or _reg.was_aborted_topic(key.chat_id, key.topic_id) + or _reg.was_interrupted(key.chat_id) + ): _reg.clear_interrupt(key.chat_id) await _preserve_session_from_response(orch, session, response, reason="abort") logger.info("Normal flow aborted/interrupted by user") @@ -567,7 +572,11 @@ async def _on_compact() -> None: return outcome.failed_result request, session, response = outcome.request, outcome.session, outcome.response _reg = orch._process_registry - if _reg.was_aborted(key.chat_id) or _reg.was_interrupted(key.chat_id): + if ( + _reg.was_aborted(key.chat_id) + or _reg.was_aborted_topic(key.chat_id, key.topic_id) + or _reg.was_interrupted(key.chat_id) + ): _reg.clear_interrupt(key.chat_id) await _preserve_session_from_response(orch, session, response, reason="abort") logger.info("Streaming flow aborted/interrupted by user") @@ -735,7 +744,11 @@ async def named_session_flow( response = await orch._cli_service.execute(request) _reg = orch._process_registry - if _reg.was_aborted(key.chat_id) or _reg.was_interrupted(key.chat_id): + if ( + _reg.was_aborted(key.chat_id) + or _reg.was_aborted_topic(key.chat_id, key.topic_id) + or _reg.was_interrupted(key.chat_id) + ): _reg.clear_interrupt(key.chat_id) ns.status = "idle" return OrchestratorResult(text="") @@ -799,7 +812,11 @@ async def _tagged_text_delta(chunk: str) -> None: ) _reg2 = orch._process_registry - if _reg2.was_aborted(key.chat_id) or _reg2.was_interrupted(key.chat_id): + if ( + _reg2.was_aborted(key.chat_id) + or _reg2.was_aborted_topic(key.chat_id, key.topic_id) + or _reg2.was_interrupted(key.chat_id) + ): _reg2.clear_interrupt(key.chat_id) ns.status = "idle" return OrchestratorResult(text="") diff --git a/tests/cli/test_process_registry.py b/tests/cli/test_process_registry.py index ff5725c8..20c3117a 100644 --- a/tests/cli/test_process_registry.py +++ b/tests/cli/test_process_registry.py @@ -179,6 +179,64 @@ def test_has_active_topic_id_ignores_exited() -> None: assert reg.has_active(1, topic_id=20) is True +class TestKillByChatTopicAbortMarker: + """Topic-scoped /stop marker behavior.""" + + async def test_sets_topic_abort_marker(self) -> None: + reg = ProcessRegistry() + proc = _mock_process(pid=80) + reg.register(chat_id=1, process=proc, label="main", topic_id=10) + + with patch( + "ductor_bot.cli.process_registry._kill_processes", + new_callable=AsyncMock, + return_value=1, + ): + killed = await reg.kill_by_chat_topic(1, 10) + + assert killed == 1 + assert reg.was_aborted_topic(1, 10) is True + assert reg.was_aborted(1) is False + + async def test_no_kill_no_marker(self) -> None: + reg = ProcessRegistry() + proc = _mock_process(pid=81) + reg.register(chat_id=1, process=proc, label="main", topic_id=10) + + killed = await reg.kill_by_chat_topic(1, 20) + + assert killed == 0 + assert reg.was_aborted_topic(1, 20) is False + assert reg.was_aborted_topic(1, 10) is False + + def test_clear_topic_abort_removes_marker(self) -> None: + reg = ProcessRegistry() + reg._aborted_topics.add((1, 10)) + + assert reg.was_aborted_topic(1, 10) is True + reg.clear_topic_abort(1, 10) + assert reg.was_aborted_topic(1, 10) is False + + async def test_topic_a_kill_does_not_kill_topic_b(self) -> None: + reg = ProcessRegistry() + proc_a = _mock_process(pid=82) + proc_b = _mock_process(pid=83) + reg.register(chat_id=1, process=proc_a, label="main", topic_id=10) + reg.register(chat_id=1, process=proc_b, label="main", topic_id=20) + + with patch( + "ductor_bot.cli.process_registry._kill_processes", + new_callable=AsyncMock, + return_value=1, + ): + killed = await reg.kill_by_chat_topic(1, 10) + + assert killed == 1 + assert reg.has_active(1, topic_id=10) is False + assert reg.has_active(1, topic_id=20) is True + assert proc_b.returncode is None + + async def test_kill_stale_handles_already_exited() -> None: reg = ProcessRegistry() proc = _mock_process(pid=40, returncode=0) diff --git a/tests/messenger/telegram/test_handlers.py b/tests/messenger/telegram/test_handlers.py index cde12329..52e4016b 100644 --- a/tests/messenger/telegram/test_handlers.py +++ b/tests/messenger/telegram/test_handlers.py @@ -50,7 +50,7 @@ async def test_abort_kills_processes_and_replies(self) -> None: msg = _make_message(chat_id=42) result = await handle_abort(orchestrator, bot, chat_id=42, message=msg) assert result is True - orchestrator.abort.assert_called_once_with(42) + orchestrator.abort.assert_called_once_with(42, topic_id=None) async def test_abort_no_orchestrator(self) -> None: from ductor_bot.messenger.telegram.handlers import handle_abort @@ -187,6 +187,19 @@ def test_none_username(self) -> None: class TestForumTopicPropagation: """Test that handlers extract and propagate thread_id.""" + @patch("ductor_bot.messenger.telegram.handlers.send_rich", new_callable=AsyncMock) + async def test_abort_entrypoint_passes_topic_id(self, _mock_send: AsyncMock) -> None: + from ductor_bot.messenger.telegram.handlers import handle_abort + + orchestrator = MagicMock() + orchestrator.abort = AsyncMock(return_value=1) + orchestrator.active_provider_name = "claude" + bot = MagicMock() + msg = _make_message(chat_id=42, topic_thread_id=99) + + await handle_abort(orchestrator, bot, chat_id=42, message=msg) + orchestrator.abort.assert_called_once_with(42, topic_id=99) + @patch("ductor_bot.messenger.telegram.handlers.send_rich", new_callable=AsyncMock) async def test_handle_abort_passes_thread_id(self, mock_send: AsyncMock) -> None: from ductor_bot.messenger.telegram.handlers import handle_abort diff --git a/tests/orchestrator/test_flows.py b/tests/orchestrator/test_flows.py index 9072ef64..bdac98b9 100644 --- a/tests/orchestrator/test_flows.py +++ b/tests/orchestrator/test_flows.py @@ -14,6 +14,7 @@ _finish_normal, _strip_ack_token, _update_session, + named_session_flow, normal, normal_streaming, ) @@ -697,6 +698,20 @@ async def test_normal_abort_skips_retry(orch: Orchestrator) -> None: assert mock_execute.call_count == 1 # No retry +async def test_topic_abort_skips_recovery(orch: Orchestrator) -> None: + """Topic-scoped /stop returns empty instead of session recovery.""" + key = SessionKey(chat_id=1, topic_id=42) + mock_execute = AsyncMock( + return_value=_mock_response(is_error=True, result="killed", returncode=-9), + ) + object.__setattr__(orch._cli_service, "execute", mock_execute) + orch._process_registry._aborted_topics.add((1, 42)) + + result = await normal(orch, key, "Hello") + assert result.text == "" + assert mock_execute.call_count == 1 + + async def test_streaming_abort_skips_retry(orch: Orchestrator) -> None: """When process is aborted (via /stop), normal_streaming() returns empty instead of retrying.""" mock_exec = AsyncMock(return_value=_mock_response()) @@ -726,6 +741,19 @@ async def test_normal_abort_discards_successful_response(orch: Orchestrator) -> assert result.text == "" +async def test_named_session_topic_abort(orch: Orchestrator) -> None: + """Named-session follow-up honors topic-scoped abort markers.""" + ns = orch._named_sessions.create(1, "claude", "opus", "Setup") + orch._named_sessions.update_after_response(1, ns.name, "sess-named") + mock_execute = AsyncMock(return_value=_mock_response(result="Agent replied")) + object.__setattr__(orch._cli_service, "execute", mock_execute) + orch._process_registry._aborted_topics.add((1, 42)) + + result = await named_session_flow(orch, SessionKey(chat_id=1, topic_id=42), ns.name, "Hello") + assert result.text == "" + assert ns.status == "idle" + + async def test_streaming_abort_discards_successful_response(orch: Orchestrator) -> None: """Even when streaming CLI responds successfully, abort flag causes empty result.""" mock_streaming = AsyncMock(