Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions ductor_bot/cli/process_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ProcessRegistry:
def __init__(self) -> None:
self._processes: dict[int, list[TrackedProcess]] = {}
self._aborted: set[int] = set()
self._aborted_topics: set[tuple[int, int | None]] = set()
self._aborted_labels: set[tuple[int, str]] = set()
self._interrupted: set[int] = set()
# MED #9: serialize bulk kill operations (kill_for_task / kill_stale /
Expand Down Expand Up @@ -102,6 +103,30 @@ async def kill_all(self, chat_id: int) -> int:
return 0
return await _kill_processes(entries)

async def kill_by_chat_topic(self, chat_id: int, topic_id: int | None) -> int:
"""Kill processes belonging to *chat_id* + *topic_id*. Returns count.

Used by ``/stop`` so a stop in one topic does not affect another
topic in the same chat. ``topic_id=None`` matches processes
registered without a topic (e.g. private chats).
"""
async with self._kill_lock:
entries = self._processes.get(chat_id, [])
targets = [
t for t in entries if t.topic_id == topic_id and t.process.returncode is None
]
if not targets:
return 0
self._aborted_topics.add((chat_id, topic_id))
remaining = [
t for t in entries if t.topic_id != topic_id or t.process.returncode is not None
]
if remaining:
self._processes[chat_id] = remaining
else:
self._processes.pop(chat_id, None)
return await _kill_processes(targets)

async def kill_all_active(self) -> int:
"""Kill active processes across all chats. Returns total count killed."""
total = 0
Expand All @@ -117,6 +142,14 @@ def clear_abort(self, chat_id: int) -> None:
"""Clear the abort flag for *chat_id*."""
self._aborted.discard(chat_id)

def was_aborted_topic(self, chat_id: int, topic_id: int | None) -> bool:
"""Check whether *(chat_id, topic_id)* has been aborted since last clear."""
return (chat_id, topic_id) in self._aborted_topics

def clear_topic_abort(self, chat_id: int, topic_id: int | None) -> None:
"""Clear the topic-scoped abort flag (called after handling the abort)."""
self._aborted_topics.discard((chat_id, topic_id))

def was_interrupted(self, chat_id: int) -> bool:
"""Check whether *chat_id* was soft-interrupted since last clear."""
return chat_id in self._interrupted
Expand Down
8 changes: 6 additions & 2 deletions ductor_bot/cli/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ async def execute_streaming( # noqa: PLR0913
timeout_seconds=request.timeout_seconds,
timeout_controller=request.timeout_controller,
):
if self._process_registry.was_aborted(request.chat_id):
if self._process_registry.was_aborted(
request.chat_id
) or self._process_registry.was_aborted_topic(request.chat_id, request.topic_id):
logger.info("Streaming aborted mid-stream chat=%d", request.chat_id)
break
text, result = await callbacks.dispatch(event)
Expand Down Expand Up @@ -283,7 +285,9 @@ async def _handle_stream_fallback(
init_session_id: str | None = None,
) -> AgentResponse:
"""Handle failed or incomplete streaming: use accumulated text or retry."""
was_aborted = self._process_registry.was_aborted(request.chat_id)
was_aborted = self._process_registry.was_aborted(
request.chat_id
) or self._process_registry.was_aborted_topic(request.chat_id, request.topic_id)
logger.info(
"Stream fallback: aborted=%s accumulated=%d init_sid=%s",
was_aborted,
Expand Down
7 changes: 4 additions & 3 deletions ductor_bot/messenger/telegram/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,16 @@ async def handle_abort(
chat_id: int,
message: Message,
) -> bool:
"""Kill active CLI processes and send feedback.
"""Kill active CLI processes in the current topic and send feedback.

Returns True if handled, False if orchestrator not ready.
"""
if orchestrator is None:
return False

killed = await orchestrator.abort(chat_id)
logger.info("Abort requested killed=%d", killed)
thread_id = get_thread_id(message)
killed = await orchestrator.abort(chat_id, topic_id=thread_id)
logger.info("Abort requested chat=%d topic=%s killed=%d", chat_id, thread_id, killed)
text = stop_text(bool(killed), orchestrator.active_provider_name)
await send_rich(
bot,
Expand Down
19 changes: 17 additions & 2 deletions ductor_bot/orchestrator/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ async def handle_message_streaming( # noqa: PLR0913

async def _handle_message_impl(self, dispatch: _MessageDispatch) -> OrchestratorResult:
self._process_registry.clear_abort(dispatch.key.chat_id)
self._process_registry.clear_topic_abort(dispatch.key.chat_id, dispatch.key.topic_id)
logger.info("Message received text=%s", dispatch.cmd[:80])

patterns = detect_suspicious_patterns(dispatch.text)
Expand Down Expand Up @@ -464,8 +465,22 @@ async def reset_active_provider_session(self, key: SessionKey) -> str:
logger.info("Active provider session reset provider=%s model=%s", provider, model)
return provider

async def abort(self, chat_id: int) -> int:
"""Kill all active CLI processes and background tasks for chat_id."""
async def abort(self, chat_id: int, topic_id: int | None = None) -> int:
"""Kill active CLI processes for *chat_id* (optionally scoped to *topic_id*).

When ``topic_id`` is provided (``/stop`` from a specific topic),
only the foreground CLI processes registered under that
``(chat_id, topic_id)`` pair are killed. Background tasks and
named sessions are left alone — they are not topic-tagged in
the current model and have their own management surfaces
(``/tasks``, ``/sessions``) so /stop should not double up.

When ``topic_id`` is ``None`` (legacy callers / ``/stop_all``)
the chat-wide sweep runs as before: every process for the chat
plus every background task and named session.
"""
if topic_id is not None:
return await self._process_registry.kill_by_chat_topic(chat_id, topic_id)
killed = await self._process_registry.kill_all(chat_id)
if self._observers.background:
killed += await self._observers.background.cancel_all(chat_id)
Expand Down
25 changes: 21 additions & 4 deletions ductor_bot/orchestrator/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ async def _maybe_recover_session( # noqa: PLR0913
_reg = orch._process_registry
if (
_reg.was_aborted(key.chat_id)
or _reg.was_aborted_topic(key.chat_id, key.topic_id)
or _reg.was_interrupted(key.chat_id)
or not _needs_session_recovery(response)
):
Expand Down Expand Up @@ -487,7 +488,11 @@ async def normal( # noqa: PLR0911
request, session, response = outcome.request, outcome.session, outcome.response
session_recovered = outcome.session_recovered
_reg = orch._process_registry
if _reg.was_aborted(key.chat_id) or _reg.was_interrupted(key.chat_id):
if (
_reg.was_aborted(key.chat_id)
or _reg.was_aborted_topic(key.chat_id, key.topic_id)
or _reg.was_interrupted(key.chat_id)
):
_reg.clear_interrupt(key.chat_id)
await _preserve_session_from_response(orch, session, response, reason="abort")
logger.info("Normal flow aborted/interrupted by user")
Expand Down Expand Up @@ -567,7 +572,11 @@ async def _on_compact() -> None:
return outcome.failed_result
request, session, response = outcome.request, outcome.session, outcome.response
_reg = orch._process_registry
if _reg.was_aborted(key.chat_id) or _reg.was_interrupted(key.chat_id):
if (
_reg.was_aborted(key.chat_id)
or _reg.was_aborted_topic(key.chat_id, key.topic_id)
or _reg.was_interrupted(key.chat_id)
):
_reg.clear_interrupt(key.chat_id)
await _preserve_session_from_response(orch, session, response, reason="abort")
logger.info("Streaming flow aborted/interrupted by user")
Expand Down Expand Up @@ -735,7 +744,11 @@ async def named_session_flow(
response = await orch._cli_service.execute(request)

_reg = orch._process_registry
if _reg.was_aborted(key.chat_id) or _reg.was_interrupted(key.chat_id):
if (
_reg.was_aborted(key.chat_id)
or _reg.was_aborted_topic(key.chat_id, key.topic_id)
or _reg.was_interrupted(key.chat_id)
):
_reg.clear_interrupt(key.chat_id)
ns.status = "idle"
return OrchestratorResult(text="")
Expand Down Expand Up @@ -799,7 +812,11 @@ async def _tagged_text_delta(chunk: str) -> None:
)

_reg2 = orch._process_registry
if _reg2.was_aborted(key.chat_id) or _reg2.was_interrupted(key.chat_id):
if (
_reg2.was_aborted(key.chat_id)
or _reg2.was_aborted_topic(key.chat_id, key.topic_id)
or _reg2.was_interrupted(key.chat_id)
):
_reg2.clear_interrupt(key.chat_id)
ns.status = "idle"
return OrchestratorResult(text="")
Expand Down
58 changes: 58 additions & 0 deletions tests/cli/test_process_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,64 @@ def test_has_active_topic_id_ignores_exited() -> None:
assert reg.has_active(1, topic_id=20) is True


class TestKillByChatTopicAbortMarker:
"""Topic-scoped /stop marker behavior."""

async def test_sets_topic_abort_marker(self) -> None:
reg = ProcessRegistry()
proc = _mock_process(pid=80)
reg.register(chat_id=1, process=proc, label="main", topic_id=10)

with patch(
"ductor_bot.cli.process_registry._kill_processes",
new_callable=AsyncMock,
return_value=1,
):
killed = await reg.kill_by_chat_topic(1, 10)

assert killed == 1
assert reg.was_aborted_topic(1, 10) is True
assert reg.was_aborted(1) is False

async def test_no_kill_no_marker(self) -> None:
reg = ProcessRegistry()
proc = _mock_process(pid=81)
reg.register(chat_id=1, process=proc, label="main", topic_id=10)

killed = await reg.kill_by_chat_topic(1, 20)

assert killed == 0
assert reg.was_aborted_topic(1, 20) is False
assert reg.was_aborted_topic(1, 10) is False

def test_clear_topic_abort_removes_marker(self) -> None:
reg = ProcessRegistry()
reg._aborted_topics.add((1, 10))

assert reg.was_aborted_topic(1, 10) is True
reg.clear_topic_abort(1, 10)
assert reg.was_aborted_topic(1, 10) is False

async def test_topic_a_kill_does_not_kill_topic_b(self) -> None:
reg = ProcessRegistry()
proc_a = _mock_process(pid=82)
proc_b = _mock_process(pid=83)
reg.register(chat_id=1, process=proc_a, label="main", topic_id=10)
reg.register(chat_id=1, process=proc_b, label="main", topic_id=20)

with patch(
"ductor_bot.cli.process_registry._kill_processes",
new_callable=AsyncMock,
return_value=1,
):
killed = await reg.kill_by_chat_topic(1, 10)

assert killed == 1
assert reg.has_active(1, topic_id=10) is False
assert reg.has_active(1, topic_id=20) is True
assert proc_b.returncode is None


async def test_kill_stale_handles_already_exited() -> None:
reg = ProcessRegistry()
proc = _mock_process(pid=40, returncode=0)
Expand Down
15 changes: 14 additions & 1 deletion tests/messenger/telegram/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def test_abort_kills_processes_and_replies(self) -> None:
msg = _make_message(chat_id=42)
result = await handle_abort(orchestrator, bot, chat_id=42, message=msg)
assert result is True
orchestrator.abort.assert_called_once_with(42)
orchestrator.abort.assert_called_once_with(42, topic_id=None)

async def test_abort_no_orchestrator(self) -> None:
from ductor_bot.messenger.telegram.handlers import handle_abort
Expand Down Expand Up @@ -187,6 +187,19 @@ def test_none_username(self) -> None:
class TestForumTopicPropagation:
"""Test that handlers extract and propagate thread_id."""

@patch("ductor_bot.messenger.telegram.handlers.send_rich", new_callable=AsyncMock)
async def test_abort_entrypoint_passes_topic_id(self, _mock_send: AsyncMock) -> None:
from ductor_bot.messenger.telegram.handlers import handle_abort

orchestrator = MagicMock()
orchestrator.abort = AsyncMock(return_value=1)
orchestrator.active_provider_name = "claude"
bot = MagicMock()
msg = _make_message(chat_id=42, topic_thread_id=99)

await handle_abort(orchestrator, bot, chat_id=42, message=msg)
orchestrator.abort.assert_called_once_with(42, topic_id=99)

@patch("ductor_bot.messenger.telegram.handlers.send_rich", new_callable=AsyncMock)
async def test_handle_abort_passes_thread_id(self, mock_send: AsyncMock) -> None:
from ductor_bot.messenger.telegram.handlers import handle_abort
Expand Down
28 changes: 28 additions & 0 deletions tests/orchestrator/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_finish_normal,
_strip_ack_token,
_update_session,
named_session_flow,
normal,
normal_streaming,
)
Expand Down Expand Up @@ -697,6 +698,20 @@ async def test_normal_abort_skips_retry(orch: Orchestrator) -> None:
assert mock_execute.call_count == 1 # No retry


async def test_topic_abort_skips_recovery(orch: Orchestrator) -> None:
"""Topic-scoped /stop returns empty instead of session recovery."""
key = SessionKey(chat_id=1, topic_id=42)
mock_execute = AsyncMock(
return_value=_mock_response(is_error=True, result="killed", returncode=-9),
)
object.__setattr__(orch._cli_service, "execute", mock_execute)
orch._process_registry._aborted_topics.add((1, 42))

result = await normal(orch, key, "Hello")
assert result.text == ""
assert mock_execute.call_count == 1


async def test_streaming_abort_skips_retry(orch: Orchestrator) -> None:
"""When process is aborted (via /stop), normal_streaming() returns empty instead of retrying."""
mock_exec = AsyncMock(return_value=_mock_response())
Expand Down Expand Up @@ -726,6 +741,19 @@ async def test_normal_abort_discards_successful_response(orch: Orchestrator) ->
assert result.text == ""


async def test_named_session_topic_abort(orch: Orchestrator) -> None:
"""Named-session follow-up honors topic-scoped abort markers."""
ns = orch._named_sessions.create(1, "claude", "opus", "Setup")
orch._named_sessions.update_after_response(1, ns.name, "sess-named")
mock_execute = AsyncMock(return_value=_mock_response(result="Agent replied"))
object.__setattr__(orch._cli_service, "execute", mock_execute)
orch._process_registry._aborted_topics.add((1, 42))

result = await named_session_flow(orch, SessionKey(chat_id=1, topic_id=42), ns.name, "Hello")
assert result.text == ""
assert ns.status == "idle"


async def test_streaming_abort_discards_successful_response(orch: Orchestrator) -> None:
"""Even when streaming CLI responds successfully, abort flag causes empty result."""
mock_streaming = AsyncMock(
Expand Down