Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Gradata/src/gradata/brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,14 @@ def memory(self):
return self._memory_manager

def close(self):
"""Cleanup: re-encrypt database if encryption is enabled."""
"""Cleanup: drain EventBus and re-encrypt database if encryption is enabled."""
bus = getattr(self, "bus", None)
if bus is not None:
try:
bus.close()
except Exception:
import logging as _l
_l.getLogger(__name__).exception("EventBus close failed")
if self._encryption_key:
from gradata._encryption import close_encrypted_db

Expand Down
84 changes: 81 additions & 3 deletions Gradata/src/gradata/events_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@

from __future__ import annotations

import atexit
import logging
import threading
import weakref
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any
Expand All @@ -26,18 +28,60 @@
MAX_LISTENERS_PER_EVENT = 50
MAX_ASYNC_WORKERS = 4

# Track all live EventBus instances so atexit can drain them. WeakSet so
# normal close() / GC doesn't keep them alive forever.
_LIVE_BUSES: weakref.WeakSet[EventBus] = weakref.WeakSet()
_ATEXIT_REGISTERED = False
_ATEXIT_LOCK = threading.Lock()


def _drain_all_buses_atexit() -> None:
"""Close any EventBus that survived to interpreter shutdown."""
for bus in list(_LIVE_BUSES):
try:
bus.close(timeout=2.0)
except Exception:
logger.exception("EventBus atexit drain failed for %r", bus)


def _ensure_atexit_registered() -> None:
global _ATEXIT_REGISTERED
if _ATEXIT_REGISTERED:
return
with _ATEXIT_LOCK:
if _ATEXIT_REGISTERED:
return
atexit.register(_drain_all_buses_atexit)
_ATEXIT_REGISTERED = True


class EventBus:
"""In-process event bus with bounded listeners, thread safety, and thread pool."""
"""In-process event bus with bounded listeners, thread safety, and thread pool.

Lifecycle:
bus = EventBus()
bus.on("evt", handler)
bus.emit("evt", payload)
bus.close() # explicit shutdown — drains async work, rejects new

Workers are atexit-registered so background threads cannot outlive the
process even if a caller forgets to close().
"""

def __init__(self) -> None:
self.listeners: dict[str, list[tuple[Callable, bool]]] = defaultdict(list)
self._pool = ThreadPoolExecutor(max_workers=MAX_ASYNC_WORKERS, thread_name_prefix="gradata-bus")
self._lock = threading.Lock()
self._closed = False
_LIVE_BUSES.add(self)
_ensure_atexit_registered()

def on(self, event: str, handler: Callable, async_handler: bool = False) -> None:
"""Subscribe *handler* to *event*. Deduplicates and bounds per event."""
with self._lock:
if self._closed:
logger.warning("EventBus.on() on closed bus; ignoring %r", event)
return
entries = self.listeners[event]
if any(h is handler for h, _ in entries):
return
Expand All @@ -53,15 +97,49 @@ def off(self, event: str, handler: Callable) -> None:
self.listeners[event] = [(h, a) for h, a in entries if h is not handler]

def emit(self, event: str, payload: Any = None) -> None:
"""Emit *event* with *payload*. Errors are logged, never raised."""
"""Emit *event* with *payload*. Errors are logged, never raised.

After close(), emit() is a no-op (logged at DEBUG). This prevents
late-shutdown handlers from raising RuntimeError on the dead pool.
"""
with self._lock:
if self._closed:
logger.debug("EventBus.emit(%r) after close — dropped", event)
return
handlers = list(self.listeners.get(event, []))
for handler, is_async in handlers:
if is_async:
self._pool.submit(self._safe_call, handler, payload)
try:
self._pool.submit(self._safe_call, handler, payload)
except RuntimeError:
# Pool was shut down between the lock check and submit.
logger.debug("EventBus async submit after shutdown — dropped")
else:
self._safe_call(handler, payload)

def close(self, timeout: float | None = None) -> None:
"""Drain async handlers and reject further work. Idempotent.

Subsequent emit() / on() calls become no-ops.
"""
with self._lock:
if self._closed:
return
self._closed = True
self.listeners.clear()
pool = self._pool
pool.shutdown(wait=True, cancel_futures=False)
if timeout is not None:
# Best-effort: ThreadPoolExecutor has no per-call timeout, but
# workers should already be drained by shutdown(wait=True). If
# any thread is still alive after the wait, log it.
for t in threading.enumerate():
if t.name.startswith("gradata-bus") and t.is_alive():
t.join(timeout=timeout)

# Backwards compat alias.
shutdown = close

@staticmethod
def _safe_call(handler: Callable, payload: Any) -> None:
try:
Expand Down
95 changes: 95 additions & 0 deletions Gradata/tests/test_eventbus_lifecycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""EventBus lifecycle and thread-safety regressions."""

from __future__ import annotations

import gc
import os
import threading
import time

from gradata import Brain
from gradata.events_bus import EventBus


def _bus_threads() -> list[threading.Thread]:
return [t for t in threading.enumerate() if t.name.startswith("gradata-bus")]


def test_subscribe_unsubscribe_under_concurrent_load() -> None:
bus = EventBus()
calls = 0
calls_lock = threading.Lock()

def handler(payload: object) -> None:
nonlocal calls
with calls_lock:
calls += 1

def worker() -> None:
for _ in range(200):
bus.on("evt", handler)
bus.emit("evt", {})
bus.off("evt", handler)

threads = [threading.Thread(target=worker) for _ in range(16)]
for thread in threads:
thread.start()
for thread in threads:
thread.join(timeout=5)

assert all(not thread.is_alive() for thread in threads)
assert calls >= 1
bus.close()
assert _bus_threads() == []


def test_eventbus_close_waits_for_executor_and_rejects_late_work() -> None:
bus = EventBus()
finished = threading.Event()
late_calls = 0

def async_handler(payload: object) -> None:
time.sleep(0.01)
finished.set()

def late_handler(payload: object) -> None:
nonlocal late_calls
late_calls += 1

bus.on("evt", async_handler, async_handler=True)
bus.emit("evt", {})
bus.close()
bus.on("evt", late_handler)
bus.emit("evt", {})

assert finished.is_set()
assert late_calls == 0
assert _bus_threads() == []


def test_brain_close_cleans_eventbus_executor_across_many_cycles(tmp_path) -> None:
before = threading.active_count()

for idx in range(100):
brain_dir = tmp_path / f"brain-{idx}"
os.environ["BRAIN_DIR"] = str(brain_dir)
brain = Brain.init(
brain_dir,
name=f"Lifecycle {idx}",
domain="Testing",
embedding="local",
interactive=False,
)
done = threading.Event()
brain.bus.on("evt", lambda payload, done=done: done.set(), async_handler=True)
brain.bus.emit("evt", {})
brain.close()
assert done.is_set()

gc.collect()
deadline = time.time() + 5
while _bus_threads() and time.time() < deadline:
time.sleep(0.01)

assert _bus_threads() == []
assert threading.active_count() <= before + 2
120 changes: 120 additions & 0 deletions Gradata/tests/test_layer_enforcement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Architecture-layer import guard for src/gradata."""

from __future__ import annotations

import ast
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1] / "src" / "gradata"

LAYER_2_ROOT = {"brain.py", "cli.py", "daemon.py", "mcp_server.py"}
LAYER_2_DIRS = {"middleware", "integrations"}
LAYER_1_DIRS = {"enhancements", "rules"}

ALLOWED_UPWARD_IMPORTS = {
("__init__.py", "gradata.brain"): "PUBLIC BARREL: documented top-level Brain export.",
("__init__.py", "gradata.enhancements.self_improvement"): "PUBLIC BARREL: graduate / parse_lessons / format_lessons / update_confidence are documented public helpers.",
("_core.py", "gradata.enhancements.behavioral_extractor"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.causal_chains"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.dedup"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.diff_engine"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.edit_classifier"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.instruction_cache"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.meta_rules"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.meta_rules_storage"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.metrics"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.pattern_extractor"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.pattern_integration"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.rule_canary"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.self_healing"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.self_improvement"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.self_improvement._confidence"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_core.py", "gradata.enhancements.similarity"): "DEFERRED: _core is delegated Brain behavior; moving it is >50 lines.",
("_mine_transcripts.py", "gradata.brain"): "LAZY-IMPORT-OK: CLI commit path opens Brain only when writing mined events.",
("_mine_transcripts.py", "gradata.enhancements.meta_rules_storage"): "LAZY-IMPORT-OK: CLI graduation bridge imports storage only on commit.",
("_scoped_brain.py", "gradata.rules.rule_engine"): "LAZY-IMPORT-OK: scoped rule injection imports ranking only when injecting.",
("contrib/patterns/evaluator.py", "gradata.rules.rule_context"): "LAZY-IMPORT-OK: graduated-rule adapter imports context on demand.",
("contrib/patterns/guardrails.py", "gradata.rules.rule_context"): "LAZY-IMPORT-OK: graduated-rule adapter imports context on demand.",
("contrib/patterns/orchestrator.py", "gradata.rules.scope"): "LAZY-IMPORT-OK: request classification imports scope on demand.",
("contrib/patterns/reflection.py", "gradata.rules.rule_context"): "LAZY-IMPORT-OK: graduated-rule adapter imports context on demand.",
}


def _layer_for(path: Path) -> int | None:
rel = path.relative_to(ROOT)
parts = rel.parts
if len(parts) == 1 and parts[0] in LAYER_2_ROOT:
return 2
if parts[0] in LAYER_2_DIRS:
return 2
if parts[0] in LAYER_1_DIRS:
return 1
if len(parts) >= 2 and parts[:2] == ("contrib", "patterns"):
return 0
if len(parts) == 1 and parts[0].startswith("_"):
return 0
return None


def _module_path(module: str) -> Path | None:
if not module.startswith("gradata"):
return None
parts = module.split(".")[1:]
if not parts:
return ROOT / "__init__.py"
module_file = ROOT.joinpath(*parts).with_suffix(".py")
if module_file.exists():
return module_file
package_init = ROOT.joinpath(*parts) / "__init__.py"
if package_init.exists():
return package_init
return None


def _inside_type_checking(node: ast.AST) -> bool:
parent = getattr(node, "_parent", None)
while parent is not None:
if isinstance(parent, ast.If) and ast.unparse(parent.test) == "TYPE_CHECKING":
return True
parent = getattr(parent, "_parent", None)
return False


def _absolute_imports(tree: ast.AST) -> list[tuple[int, str]]:
imports: list[tuple[int, str]] = []
for node in ast.walk(tree):
if _inside_type_checking(node):
continue
if isinstance(node, ast.Import):
imports.extend((node.lineno, alias.name) for alias in node.names)
elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
imports.append((node.lineno, node.module))
return imports


def test_no_unclassified_upward_layer_imports() -> None:
failures: list[str] = []

for path in sorted(ROOT.rglob("*.py")):
source_layer = _layer_for(path)
if source_layer is None:
continue

tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
for parent in ast.walk(tree):
for child in ast.iter_child_nodes(parent):
child._parent = parent # type: ignore[attr-defined]

for line, module in _absolute_imports(tree):
target = _module_path(module)
if target is None:
continue
target_layer = _layer_for(target)
if target_layer is None or target_layer <= source_layer:
continue

rel = path.relative_to(ROOT).as_posix()
if (rel, module) not in ALLOWED_UPWARD_IMPORTS:
failures.append(f"{rel}:{line} L{source_layer}->L{target_layer} imports {module}")

assert failures == []
Loading
Loading