diff --git a/autogen/events/base_event.py b/autogen/events/base_event.py index 0569a4f5e7b..4308a5c83c0 100644 --- a/autogen/events/base_event.py +++ b/autogen/events/base_event.py @@ -5,7 +5,7 @@ from abc import ABC from collections.abc import Callable -from typing import Annotated, Any, Literal, Union +from typing import Annotated, Any, ClassVar, Literal, Union from uuid import UUID, uuid4 from pydantic import BaseModel, Field, create_model @@ -31,6 +31,19 @@ def print(self, f: Callable[..., Any] | None = None) -> None: """ ... + _hooks: ClassVar[dict[type["BaseEvent"], list[Callable]]] = {} + + @classmethod + def register_hook(cls, event_type, func): + cls._hooks.setdefault(event_type, []).append(func) + + @classmethod + def trigger_hook(cls, event): + for base, funcs in cls._hooks.items(): + if isinstance(event, base): + for f in funcs: + f(event.content) + def camel2snake(name: str) -> str: return "".join(["_" + i.lower() if i.isupper() else i for i in name]).lstrip("_") diff --git a/test/events/test_base_event.py b/test/events/test_base_event.py index 2f7fb7a8d67..ee2ab8c9b12 100644 --- a/test/events/test_base_event.py +++ b/test/events/test_base_event.py @@ -71,3 +71,62 @@ class TestSingleContentParameterEvent(BaseEvent): model = TestSingleContentParameterEvent(**expected) assert model.model_dump() == expected + + +class TestBaseEventHooks: + def test_register_and_trigger_hook(self, TestEvent: type[BaseEvent], uuid: UUID) -> None: + captured = [] + + def hook(event: TestEvent) -> None: + captured.append(event.content) + + BaseEvent.register_hook(TestEvent, hook) + event = TestEvent(uuid=uuid, sender="alice", receiver="bob", content="hello") + + BaseEvent.trigger_hook(event) + + assert captured == ["hello"] + + def test_multiple_hooks(self, TestEvent: type[BaseEvent], uuid: UUID) -> None: + captured = [] + + def hook1(event: TestEvent) -> None: + captured.append("hook1:" + event.content) + + def hook2(event: TestEvent) -> None: + captured.append("hook2:" + event.content) + + BaseEvent.register_hook(TestEvent, hook1) + BaseEvent.register_hook(TestEvent, hook2) + + event = TestEvent(uuid=uuid, sender="alice", receiver="bob", content="hello") + BaseEvent.trigger_hook(event) + + assert captured == ["hook1:hello", "hook2:hello"] + + def test_hooks_are_isolated_by_event_type(self, uuid: UUID) -> None: + captured_a = [] + captured_b = [] + + @wrap_event + class EventAEvent(BaseEvent): + content: str + + @wrap_event + class EventBEvent(BaseEvent): + content: str + + def hook_a(event: EventAEvent) -> None: + captured_a.append("A:" + event.content) + + def hook_b(event: EventBEvent) -> None: + captured_b.append("B:" + event.content) + + BaseEvent.register_hook(EventAEvent, hook_a) + BaseEvent.register_hook(EventBEvent, hook_b) + + BaseEvent.trigger_hook(EventAEvent(uuid=uuid, content="foo")) + BaseEvent.trigger_hook(EventBEvent(uuid=uuid, content="bar")) + + assert captured_a == ["A:foo"] + assert captured_b == ["B:bar"]