Skip to content

Commit 0184cfc

Browse files
lapc506claude
andauthored
feat(hooks): add hook system with PreToolUse, PostToolUse, Stop, OnError events (#61) (#79)
HookRegistry with priority-ordered hooks that can BLOCK execution (unlike middleware which wraps). Hooks fire at granular points: before/after tool calls, on session stop, on error. First BLOCK verdict wins; failing hooks are logged but don't block. Modified args propagated through ALLOW results. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ec013af commit 0184cfc

File tree

2 files changed

+329
-0
lines changed

2 files changed

+329
-0
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""Hook system for granular execution control.
2+
3+
Hooks complement middleware: while middleware wraps the full request lifecycle,
4+
hooks fire at specific points (before/after tool calls, on session stop, on error)
5+
and can BLOCK execution by returning HookVerdict.BLOCK.
6+
7+
Usage:
8+
registry = HookRegistry()
9+
10+
async def validate_ph_range(ctx: HookContext) -> HookResult:
11+
if ctx.tool_name == "set_ph" and ctx.tool_args.get("value", 7) < 5.5:
12+
return HookResult(verdict=HookVerdict.BLOCK, reason="pH too low")
13+
return HookResult()
14+
15+
registry.register(HookEvent.PRE_TOOL_USE, validate_ph_range)
16+
17+
result = await registry.run(HookContext(
18+
event=HookEvent.PRE_TOOL_USE,
19+
tool_name="set_ph",
20+
tool_args={"value": 4.0},
21+
))
22+
assert result.verdict == HookVerdict.BLOCK
23+
"""
24+
25+
from __future__ import annotations
26+
27+
import logging
28+
from collections import defaultdict
29+
from collections.abc import Awaitable, Callable
30+
from enum import Enum
31+
from typing import Any
32+
33+
from pydantic import BaseModel
34+
35+
logger = logging.getLogger(__name__)
36+
37+
38+
class HookEvent(str, Enum):
39+
PRE_TOOL_USE = "pre_tool_use"
40+
POST_TOOL_USE = "post_tool_use"
41+
SESSION_STOP = "session_stop"
42+
ON_ERROR = "on_error"
43+
44+
45+
class HookVerdict(str, Enum):
46+
ALLOW = "allow"
47+
BLOCK = "block"
48+
49+
50+
class HookContext(BaseModel):
51+
event: HookEvent
52+
session_id: str | None = None
53+
persona_id: str | None = None
54+
tool_name: str | None = None
55+
tool_args: dict[str, Any] = {}
56+
tool_result_output: str | None = None
57+
tool_result_success: bool | None = None
58+
error: str | None = None
59+
60+
61+
class HookResult(BaseModel):
62+
verdict: HookVerdict = HookVerdict.ALLOW
63+
reason: str | None = None
64+
modified_args: dict[str, Any] | None = None
65+
66+
67+
Hook = Callable[[HookContext], Awaitable[HookResult]]
68+
69+
70+
class HookRegistry:
71+
"""Registry for execution hooks. First BLOCK verdict wins."""
72+
73+
def __init__(self) -> None:
74+
self._hooks: dict[HookEvent, list[tuple[int, Hook]]] = defaultdict(list)
75+
76+
def register(
77+
self, event: HookEvent, hook: Hook, *, priority: int = 0,
78+
) -> None:
79+
self._hooks[event].append((priority, hook))
80+
self._hooks[event].sort(key=lambda x: x[0])
81+
82+
def unregister(self, event: HookEvent, hook: Hook) -> None:
83+
self._hooks[event] = [
84+
(p, h) for p, h in self._hooks[event] if h is not hook
85+
]
86+
87+
async def run(self, context: HookContext) -> HookResult:
88+
hooks = self._hooks.get(context.event, [])
89+
last_modified_args: dict[str, Any] | None = None
90+
for _priority, hook in hooks:
91+
try:
92+
result = await hook(context)
93+
if result.verdict == HookVerdict.BLOCK:
94+
logger.info(
95+
"Hook blocked %s: %s", context.event.value, result.reason,
96+
)
97+
return result
98+
if result.modified_args is not None:
99+
last_modified_args = result.modified_args
100+
except Exception:
101+
logger.exception("Hook failed for %s", context.event.value)
102+
return HookResult(modified_args=last_modified_args)
103+
104+
def count(self, event: HookEvent | None = None) -> int:
105+
if event is not None:
106+
return len(self._hooks[event])
107+
return sum(len(hooks) for hooks in self._hooks.values())

tests/unit/test_hooks.py

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
from __future__ import annotations
2+
3+
from agentic_core.application.hooks import (
4+
Hook,
5+
HookContext,
6+
HookEvent,
7+
HookRegistry,
8+
HookResult,
9+
HookVerdict,
10+
)
11+
12+
13+
async def _allow_hook(ctx: HookContext) -> HookResult:
14+
return HookResult()
15+
16+
17+
async def _block_hook(ctx: HookContext) -> HookResult:
18+
return HookResult(verdict=HookVerdict.BLOCK, reason="blocked by test")
19+
20+
21+
async def _modify_args_hook(ctx: HookContext) -> HookResult:
22+
return HookResult(modified_args={"sanitized": True})
23+
24+
25+
async def _failing_hook(ctx: HookContext) -> HookResult:
26+
raise RuntimeError("hook crashed")
27+
28+
29+
# --- Registry basics ---
30+
31+
32+
def test_empty_registry_count():
33+
reg = HookRegistry()
34+
assert reg.count() == 0
35+
assert reg.count(HookEvent.PRE_TOOL_USE) == 0
36+
37+
38+
def test_register_and_count():
39+
reg = HookRegistry()
40+
reg.register(HookEvent.PRE_TOOL_USE, _allow_hook)
41+
reg.register(HookEvent.POST_TOOL_USE, _allow_hook)
42+
assert reg.count() == 2
43+
assert reg.count(HookEvent.PRE_TOOL_USE) == 1
44+
45+
46+
def test_unregister():
47+
reg = HookRegistry()
48+
reg.register(HookEvent.PRE_TOOL_USE, _allow_hook)
49+
reg.unregister(HookEvent.PRE_TOOL_USE, _allow_hook)
50+
assert reg.count(HookEvent.PRE_TOOL_USE) == 0
51+
52+
53+
# --- Run behavior ---
54+
55+
56+
async def test_run_no_hooks_allows():
57+
reg = HookRegistry()
58+
ctx = HookContext(event=HookEvent.PRE_TOOL_USE, tool_name="test")
59+
result = await reg.run(ctx)
60+
assert result.verdict == HookVerdict.ALLOW
61+
62+
63+
async def test_run_allow_hook():
64+
reg = HookRegistry()
65+
reg.register(HookEvent.PRE_TOOL_USE, _allow_hook)
66+
ctx = HookContext(event=HookEvent.PRE_TOOL_USE, tool_name="test")
67+
result = await reg.run(ctx)
68+
assert result.verdict == HookVerdict.ALLOW
69+
70+
71+
async def test_run_block_hook():
72+
reg = HookRegistry()
73+
reg.register(HookEvent.PRE_TOOL_USE, _block_hook)
74+
ctx = HookContext(event=HookEvent.PRE_TOOL_USE, tool_name="set_ph")
75+
result = await reg.run(ctx)
76+
assert result.verdict == HookVerdict.BLOCK
77+
assert result.reason == "blocked by test"
78+
79+
80+
async def test_block_stops_chain():
81+
"""First BLOCK wins; subsequent hooks don't run."""
82+
call_count = 0
83+
84+
async def counting_hook(ctx: HookContext) -> HookResult:
85+
nonlocal call_count
86+
call_count += 1
87+
return HookResult()
88+
89+
reg = HookRegistry()
90+
reg.register(HookEvent.PRE_TOOL_USE, _block_hook, priority=0)
91+
reg.register(HookEvent.PRE_TOOL_USE, counting_hook, priority=1)
92+
ctx = HookContext(event=HookEvent.PRE_TOOL_USE)
93+
result = await reg.run(ctx)
94+
assert result.verdict == HookVerdict.BLOCK
95+
assert call_count == 0
96+
97+
98+
async def test_allow_then_block():
99+
reg = HookRegistry()
100+
reg.register(HookEvent.PRE_TOOL_USE, _allow_hook, priority=0)
101+
reg.register(HookEvent.PRE_TOOL_USE, _block_hook, priority=1)
102+
ctx = HookContext(event=HookEvent.PRE_TOOL_USE)
103+
result = await reg.run(ctx)
104+
assert result.verdict == HookVerdict.BLOCK
105+
106+
107+
async def test_priority_ordering():
108+
"""Lower priority number runs first."""
109+
order: list[str] = []
110+
111+
async def hook_a(ctx: HookContext) -> HookResult:
112+
order.append("a")
113+
return HookResult()
114+
115+
async def hook_b(ctx: HookContext) -> HookResult:
116+
order.append("b")
117+
return HookResult()
118+
119+
reg = HookRegistry()
120+
reg.register(HookEvent.PRE_TOOL_USE, hook_b, priority=10)
121+
reg.register(HookEvent.PRE_TOOL_USE, hook_a, priority=1)
122+
ctx = HookContext(event=HookEvent.PRE_TOOL_USE)
123+
await reg.run(ctx)
124+
assert order == ["a", "b"]
125+
126+
127+
async def test_failing_hook_does_not_block():
128+
reg = HookRegistry()
129+
reg.register(HookEvent.PRE_TOOL_USE, _failing_hook)
130+
ctx = HookContext(event=HookEvent.PRE_TOOL_USE)
131+
result = await reg.run(ctx)
132+
assert result.verdict == HookVerdict.ALLOW
133+
134+
135+
async def test_modify_args_returned():
136+
reg = HookRegistry()
137+
reg.register(HookEvent.PRE_TOOL_USE, _modify_args_hook)
138+
ctx = HookContext(event=HookEvent.PRE_TOOL_USE, tool_args={"raw": True})
139+
result = await reg.run(ctx)
140+
assert result.verdict == HookVerdict.ALLOW
141+
assert result.modified_args == {"sanitized": True}
142+
143+
144+
# --- Hook events ---
145+
146+
147+
async def test_post_tool_use_event():
148+
reg = HookRegistry()
149+
reg.register(HookEvent.POST_TOOL_USE, _allow_hook)
150+
ctx = HookContext(
151+
event=HookEvent.POST_TOOL_USE,
152+
tool_name="search",
153+
tool_result_success=True,
154+
tool_result_output="found 3 results",
155+
)
156+
result = await reg.run(ctx)
157+
assert result.verdict == HookVerdict.ALLOW
158+
159+
160+
async def test_session_stop_event():
161+
reg = HookRegistry()
162+
reg.register(HookEvent.SESSION_STOP, _allow_hook)
163+
ctx = HookContext(
164+
event=HookEvent.SESSION_STOP,
165+
session_id="sess_123",
166+
)
167+
result = await reg.run(ctx)
168+
assert result.verdict == HookVerdict.ALLOW
169+
170+
171+
async def test_on_error_event():
172+
reg = HookRegistry()
173+
reg.register(HookEvent.ON_ERROR, _allow_hook)
174+
ctx = HookContext(
175+
event=HookEvent.ON_ERROR,
176+
error="Connection timeout",
177+
)
178+
result = await reg.run(ctx)
179+
assert result.verdict == HookVerdict.ALLOW
180+
181+
182+
async def test_hooks_scoped_to_event():
183+
"""Hooks registered for one event don't fire for another."""
184+
reg = HookRegistry()
185+
reg.register(HookEvent.PRE_TOOL_USE, _block_hook)
186+
ctx = HookContext(event=HookEvent.POST_TOOL_USE)
187+
result = await reg.run(ctx)
188+
assert result.verdict == HookVerdict.ALLOW
189+
190+
191+
# --- HookContext ---
192+
193+
194+
def test_hook_context_defaults():
195+
ctx = HookContext(event=HookEvent.PRE_TOOL_USE)
196+
assert ctx.session_id is None
197+
assert ctx.persona_id is None
198+
assert ctx.tool_name is None
199+
assert ctx.tool_args == {}
200+
assert ctx.error is None
201+
202+
203+
def test_hook_context_full():
204+
ctx = HookContext(
205+
event=HookEvent.PRE_TOOL_USE,
206+
session_id="s1",
207+
persona_id="support-agent",
208+
tool_name="mcp_stripe_create_charge",
209+
tool_args={"amount": 1000},
210+
)
211+
assert ctx.tool_name == "mcp_stripe_create_charge"
212+
assert ctx.tool_args["amount"] == 1000
213+
214+
215+
# --- HookResult ---
216+
217+
218+
def test_hook_result_defaults():
219+
r = HookResult()
220+
assert r.verdict == HookVerdict.ALLOW
221+
assert r.reason is None
222+
assert r.modified_args is None

0 commit comments

Comments
 (0)