diff --git a/src/openarmature/graph/compiled.py b/src/openarmature/graph/compiled.py index c830e16..2d24567 100644 --- a/src/openarmature/graph/compiled.py +++ b/src/openarmature/graph/compiled.py @@ -64,12 +64,14 @@ _reset_invocation_id, _reset_namespace_prefix, _set_active_dispatch, + _set_active_observer_span, _set_active_observers, _set_attempt_index, _set_correlation_id, _set_fan_out_index, _set_invocation_id, _set_namespace_prefix, + current_active_observer_span, ) from .edges import END, ConditionalEdge, EndSentinel, StaticEdge @@ -99,6 +101,54 @@ from .state import State from .subgraph import SubgraphNode +# Try-import OpenTelemetry attach primitives so the engine can splice an +# observer-published span into the OTel context for the duration of a +# node body. The engine treats the span value opaquely (writes by an +# observer's ``prepare_sync``, reads via ``current_active_observer_span``) +# and only touches OTel when both: (a) the extras are installed, and +# (b) an observer actually published a span. Installs without ``[otel]`` +# get a no-op attach/detach pair; the observer ContextVar stays +# ``None`` and nothing changes. +# +# The names are bound to ``None`` in the except branch so pyright +# narrows correctly at call sites (``if _otel_attach is None: ...``) +# rather than flagging "possibly unbound." +try: + from opentelemetry.context import attach as _otel_attach + from opentelemetry.context import detach as _otel_detach + from opentelemetry.trace.propagation import set_span_in_context as _otel_set_span_in_context +except ImportError: # pragma: no cover — exercised only in non-otel installs + _otel_attach = None # type: ignore[assignment] + _otel_detach = None # type: ignore[assignment] + _otel_set_span_in_context = None # type: ignore[assignment] + + +def _attach_active_observer_span() -> object | None: + """Read ``current_active_observer_span``; if an observer published + one and OTel is installed, attach the span into the OTel context + so that any logs emitted from the next user-code scope (a node + body) pick up the right ``trace_id``/``span_id`` via OTel's + ``LoggingHandler``. + + Returns the OTel context token to hand back to + :func:`_detach_active_observer_span` in ``finally``, or ``None`` + if no attach happened (no observer, no OTel, or both). + """ + if _otel_attach is None or _otel_set_span_in_context is None: + return None + span = current_active_observer_span() + if span is None: + return None + return _otel_attach(_otel_set_span_in_context(cast("Any", span))) + + +def _detach_active_observer_span(token: object | None) -> None: + """Pair to :func:`_attach_active_observer_span`. No-op when no + attach was performed (token is ``None``).""" + if token is None or _otel_detach is None: + return + _otel_detach(cast("Any", token)) + def _merge_partial[StateT: State]( prior: StateT, @@ -690,20 +740,35 @@ async def innermost(s: Any) -> Mapping[str, Any]: try: self._dispatch_started(context, current, namespace, step, s, attempt_index=attempt_index) + # Splice the observer-published span (if any) into the + # OTel context so logs emitted from the FIRST line of + # the node body — before any ``await`` — pick up the + # right trace_id/span_id via OTel's LoggingHandler. + # Detach in ``finally`` so retries / merge / completed + # dispatch don't run with the span still active, and + # clear ``current_active_observer_span`` to ``None`` so + # the next dispatch that raises or early-returns from + # ``prepare_sync`` can't reveal this node's span as a + # stale value to the engine's read. + otel_token = _attach_active_observer_span() try: - partial = await node.run(s) - except Exception as e: - wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) - self._dispatch_completed( - context, - current, - namespace, - step, - s, - error=wrapped, - attempt_index=attempt_index, - ) - raise + try: + partial = await node.run(s) + except Exception as e: + wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) + self._dispatch_completed( + context, + current, + namespace, + step, + s, + error=wrapped, + attempt_index=attempt_index, + ) + raise + finally: + _detach_active_observer_span(otel_token) + _set_active_observer_span(None) try: merged = _merge_partial(s, partial, self.reducers, current) @@ -1045,38 +1110,55 @@ async def innermost(s: Any) -> Mapping[str, Any]: attempt_index=attempt_index, fan_out_config=fan_out_event_config, ) + # Same OTel attach pattern as ``_step_function_node``'s + # ``innermost`` — splice the observer-published span + # into the OTel context so logs emitted from inside + # the fan-out node's own scope (middleware bodies, + # the dispatch machinery) carry the right + # trace_id/span_id. Per-instance bodies get their own + # attach inside their ``_step_function_node`` + # innermost when the recursive invocation hits leaf + # nodes. ``finally`` clears the ContextVar so a later + # dispatch whose ``prepare_sync`` raises or early- + # returns can't reveal this fan-out's span as a stale + # value to the engine's read. + otel_token = _attach_active_observer_span() try: - partial = await node.run_with_context( - s, - context, - pre_resolved_count=item_count, - pre_resolved_concurrency=(concurrency_resolved,), - ) - except RuntimeGraphError as e: - self._dispatch_completed( - context, - current, - namespace, - step, - s, - error=e, - attempt_index=attempt_index, - fan_out_config=fan_out_event_config, - ) - raise - except Exception as e: - wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) - self._dispatch_completed( - context, - current, - namespace, - step, - s, - error=wrapped, - attempt_index=attempt_index, - fan_out_config=fan_out_event_config, - ) - raise wrapped from e + try: + partial = await node.run_with_context( + s, + context, + pre_resolved_count=item_count, + pre_resolved_concurrency=(concurrency_resolved,), + ) + except RuntimeGraphError as e: + self._dispatch_completed( + context, + current, + namespace, + step, + s, + error=e, + attempt_index=attempt_index, + fan_out_config=fan_out_event_config, + ) + raise + except Exception as e: + wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) + self._dispatch_completed( + context, + current, + namespace, + step, + s, + error=wrapped, + attempt_index=attempt_index, + fan_out_config=fan_out_event_config, + ) + raise wrapped from e + finally: + _detach_active_observer_span(otel_token) + _set_active_observer_span(None) try: merged = _merge_partial(s, partial, self.reducers, current) diff --git a/src/openarmature/graph/observer.py b/src/openarmature/graph/observer.py index cbcf9ea..3cb1924 100644 --- a/src/openarmature/graph/observer.py +++ b/src/openarmature/graph/observer.py @@ -28,6 +28,7 @@ from __future__ import annotations import asyncio +import inspect import warnings from collections.abc import Iterable from dataclasses import dataclass, field @@ -60,6 +61,25 @@ async def log_observer(event: NodeEvent) -> None: The event parameter is positional-only (`event, /`) so structural conformance doesn't pin you to that name — any of `event`, `_event`, `e`, etc. matches. + + Optional ``prepare_sync`` extension + ----------------------------------- + An observer MAY additionally define a synchronous method:: + + def prepare_sync(self, event: NodeEvent, /) -> None: ... + + that the engine calls IN THE ENGINE TASK, BEFORE queueing the + event for the async ``__call__``. This exists for observers that + need to set up state — e.g., open a span and stash a handle in + a ContextVar — that the engine itself must read synchronously + before running the node body (otherwise logs emitted on the + first line of the body wouldn't see the right span). + + ``prepare_sync`` is **opt-in via ``hasattr``** — no subclass or + Protocol method required. Observers that don't define it skip + the synchronous prep entirely; observers that do define it run + only for ``"started"``-phase events, errors warned not propagated + (same isolation contract as the async path per spec §6). """ async def __call__(self, event: NodeEvent, /) -> None: ... @@ -344,12 +364,85 @@ def take_step(self) -> int: def _dispatch(context: _InvocationContext, event: NodeEvent) -> None: """Enqueue a node event for the delivery worker. + For ``"started"``-phase events, also call any subscribed observer's + optional ``prepare_sync(event)`` synchronously — in the engine task, + BEFORE queueing — so observers that need to publish per-event state + the engine itself reads in the same engine-task scope (e.g., the + OTel observer setting ``current_active_observer_span`` for the + engine to attach into the OTel context) can do so before the node + body runs. + + Phase-gated forwarding: ``prepare_sync`` only fires when ``"started"`` + is in the subscribed observer's ``phases`` set, mirroring how the + async ``deliver_loop`` filters dispatch. A user who explicitly + subscribes only to ``{"completed"}`` doesn't get the synchronous + prep — the wrapper acts as a uniform phase shield across both + sync prep and async dispatch. + + Errors from ``prepare_sync`` follow the same isolation contract as + the async path per spec §6: don't propagate, don't break siblings, + don't block the queueing or subsequent events. Reported via + ``warnings.warn``. + No-op when no observers exist for this depth — avoids paying the queue overhead for graphs that don't observe anything. """ observers = context.full_observers() if not observers: return + if event.phase == "started": + for subscribed in observers: + if "started" not in subscribed.phases: + continue + prepare_sync = getattr(subscribed.observer, "prepare_sync", None) + if prepare_sync is None: + continue + try: + result = prepare_sync(event) + except Exception as e: + warnings.warn( + f"observer prepare_sync raised {type(e).__name__}: {e}", + stacklevel=2, + ) + continue + if inspect.isawaitable(result): + # ``prepare_sync`` is opt-in via ``hasattr`` (not a + # Protocol method) so pyright can't catch a user's + # ``async def prepare_sync`` signature drift up front. + # The call here would silently return an unawaited + # coroutine — the prep work wouldn't run AND Python + # would emit a delayed "coroutine was never awaited" + # warning at GC time. Close the awaitable to suppress + # that secondary noise and surface the misconfiguration + # via our own explicit warn so it fails loudly at the + # call site. ``getattr`` rather than ``hasattr``+method + # access keeps pyright's strict-mode happy on the + # ``Awaitable`` type (``.close`` lives on + # ``Coroutine``, not the broader ``Awaitable``). + close_method = getattr(result, "close", None) + if close_method is not None: + try: + close_method() + except Exception as close_error: + # Cleanup is best-effort: a raise here MUST NOT + # propagate or block sibling observers. Surface + # via ``warnings.warn`` so the swallow is at + # least observable if it ever fires (CodeQL + # py/empty-except clears on this surface too). + warnings.warn( + f"observer prepare_sync close cleanup raised " + f"{type(close_error).__name__}: {close_error}", + stacklevel=2, + ) + warnings.warn( + f"observer prepare_sync returned an awaitable " + f"({type(result).__name__}); prepare_sync MUST be sync " + f"(define as `def`, not `async def`). The returned " + f"awaitable will not be awaited and is NOT guaranteed " + f"to complete before the node body starts; log " + f"correlation may miss this node's span.", + stacklevel=2, + ) context.queue.put_nowait(_QueuedItem(event=event, observers=observers)) diff --git a/src/openarmature/observability/correlation.py b/src/openarmature/observability/correlation.py index 9b505a5..d8a714b 100644 --- a/src/openarmature/observability/correlation.py +++ b/src/openarmature/observability/correlation.py @@ -290,8 +290,75 @@ def _reset_attempt_index(token: Token[int]) -> None: _attempt_index_var.reset(token) +# --------------------------------------------------------------------------- +# Active observer span — for engine-side OTel context attach inside +# ``innermost``. Populated synchronously by an observer's ``prepare_sync`` +# hook BEFORE the engine queues the started event; read by ``innermost`` +# AFTER ``_dispatch_started`` returns to attach the span into the OTel +# context for the duration of the node-body chain. +# +# Inverted directionality vs. the engine→observer ContextVars above: +# this one flows observer→engine. The producer (an opt-in observer's +# ``prepare_sync``) and the consumer (``innermost``) both run in the +# engine task, so the same-task ContextVar contract holds — last-writer- +# wins is fine in practice (charter §6 says "one OTelObserver per +# private provider," so multi-OTelObserver attach is rare). +# +# Typed as ``object | None`` rather than ``Span | None`` so the base +# package stays free of an OpenTelemetry import. The OTel observer +# writes ``Span`` instances; the engine treats the value opaquely and +# delegates the actual attach to a try-imported OTel helper. +# --------------------------------------------------------------------------- + + +_active_observer_span_var: ContextVar[object | None] = ContextVar( + "openarmature.active_observer_span", default=None +) + + +def current_active_observer_span() -> object | None: + """Return the active observer-side span for the current node body, + or ``None`` if no observer published one (no opt-in observer with + ``prepare_sync`` is attached, or this is being called outside a + node-body scope). + + Engine-readable handle to the span an opt-in observer's + ``prepare_sync`` created synchronously during dispatch. The engine's + ``innermost`` reads this AFTER ``_dispatch_started`` returns and + attaches the span into the OTel context (via a try-imported OTel + helper) so that any logs emitted FROM INSIDE the node body — even + on the first line, before any ``await`` — pick up the span's + trace_id/span_id via OTel's ``LoggingHandler``. + + Lifecycle: the value is ``None`` outside a node-body scope (between + dispatches, during merge, during completed-event dispatch). The + engine's ``innermost`` clears it to ``None`` in its ``finally`` + block right after the OTel detach — so a subsequent + ``prepare_sync`` that raises or early-returns can't reveal a stale + span from a previous node when the engine reads. + + Backend coupling note: typed as ``object | None`` so this primitive + works in installs without the ``[otel]`` extras. OTel observers + write OpenTelemetry ``Span`` instances; the engine treats the + value opaquely. + """ + return _active_observer_span_var.get() + + +def _set_active_observer_span(value: object | None) -> Token[object | None]: + """Set the active observer span. Internal — observers' ``prepare_sync`` + implementations call this synchronously before returning so the + engine's ``innermost`` reads the right value when it attaches.""" + return _active_observer_span_var.set(value) + + +def _reset_active_observer_span(token: Token[object | None]) -> None: + _active_observer_span_var.reset(token) + + __all__ = [ # Public surface — readable from anywhere within an invocation. + "current_active_observer_span", "current_active_observers", "current_attempt_index", "current_correlation_id", @@ -304,6 +371,7 @@ def _reset_attempt_index(token: Token[int]) -> None: # pyright's strict ``reportUnusedFunction`` flagging them as # dead. Underscore-prefixed; not part of the user-facing API. "_reset_active_dispatch", + "_reset_active_observer_span", "_reset_active_observers", "_reset_attempt_index", "_reset_correlation_id", @@ -311,6 +379,7 @@ def _reset_attempt_index(token: Token[int]) -> None: "_reset_invocation_id", "_reset_namespace_prefix", "_set_active_dispatch", + "_set_active_observer_span", "_set_active_observers", "_set_attempt_index", "_set_correlation_id", diff --git a/src/openarmature/observability/otel/observer.py b/src/openarmature/observability/otel/observer.py index 1d65fa0..a1487fb 100644 --- a/src/openarmature/observability/otel/observer.py +++ b/src/openarmature/observability/otel/observer.py @@ -243,7 +243,12 @@ async def __call__(self, event: NodeEvent) -> None: self._emit_checkpoint_save_span(event) return if event.phase == "started": - self._handle_started(event) + # Idempotent — short-circuits inside ``_open_started_span`` + # if ``prepare_sync`` already opened the span synchronously + # in the engine task. Falls through (and opens the span) + # for observers attached after the engine started, or for + # test paths that bypass ``prepare_sync``. + self._open_started_span(event) elif event.phase == "completed": self._handle_completed(event) @@ -251,8 +256,54 @@ async def __call__(self, event: NodeEvent) -> None: # Started / completed pairing # ------------------------------------------------------------------ - def _handle_started(self, event: NodeEvent) -> None: - """Open a span for this attempt, push onto the in-flight map.""" + def prepare_sync(self, event: NodeEvent) -> None: + """Synchronous engine-task entry point: open the span for this + attempt AND publish it via ``current_active_observer_span`` so + the engine's ``innermost`` can attach it into the OTel context + before the node body runs. + + Called by ``_dispatch`` BEFORE ``queue.put_nowait`` for + ``"started"``-phase events. The async ``__call__`` later sees + the span already in ``inv_state.open_spans`` and short-circuits. + + Skipped for non-``"started"`` phases and for the LLM sentinel + namespace — only graph-node started events participate in the + engine-side attach. Errors don't leak: ``_dispatch`` wraps this + call in try/except + ``warnings.warn`` matching the async path. + """ + if event.phase != "started" or event.namespace == _LLM_NAMESPACE: + return + from openarmature.observability.correlation import ( + _set_active_observer_span, + current_invocation_id, + ) + + self._open_started_span(event) + invocation_id = current_invocation_id() + if invocation_id is None: + return + inv_state = self._inv_states.get(invocation_id) + if inv_state is None: + return + open_span = inv_state.open_spans.get(self._key_for(event)) + if open_span is None: + return + # Publish the span to the engine via the ContextVar. Discard + # the Token — last-writer-wins is the documented contract + # (next ``prepare_sync`` overwrites; task-local context dies + # with the invocation task). + _set_active_observer_span(open_span.span) + + def _open_started_span(self, event: NodeEvent) -> None: + """Sync core: create the span + mutate ``inv_state.open_spans``. + + Idempotent — short-circuits if a span already exists for this + event's ``_StackKey``. That covers the common case where + ``prepare_sync`` opened the span synchronously in the engine + task and the async ``__call__`` later re-fires for the same + event; the second call becomes a true no-op rather than + opening a duplicate span. + """ from openarmature.observability.correlation import ( current_correlation_id, current_invocation_id, @@ -261,8 +312,13 @@ def _handle_started(self, event: NodeEvent) -> None: invocation_id = current_invocation_id() if invocation_id is None: return - correlation_id = current_correlation_id() inv_state = self._inv_state_for(invocation_id) + # Idempotency: a span already exists for this attempt — likely + # opened by ``prepare_sync`` in the engine task. No-op to avoid + # duplicates. + if self._key_for(event) in inv_state.open_spans: + return + correlation_id = current_correlation_id() # Lazily open the invocation span on the first event we see # for this invocation_id. Per-invocation_id scoping means @@ -579,8 +635,8 @@ def _sync_subgraph_spans( any subgraph spans whose prefix is no longer an ancestor of the current event's namespace. - Called from ``_handle_started`` BEFORE opening the leaf node - span. Detached-mode entries (subgraph or fan-out instance) + Called from ``_open_started_span`` BEFORE opening the leaf + node span. Detached-mode entries (subgraph or fan-out instance) are registered as detached roots so their inner spans live in a fresh trace. """ diff --git a/tests/conformance/test_observability.py b/tests/conformance/test_observability.py index 5b20f54..13539d1 100644 --- a/tests/conformance/test_observability.py +++ b/tests/conformance/test_observability.py @@ -16,24 +16,15 @@ - **009-correlation-id-cross-cutting** (Phase 6.0) — every span carries ``openarmature.correlation_id``; back-to-back invocations get distinct UUIDv4s. +- **010-log-correlation** (PR-C.3) — log records emitted from + inside node bodies pick up the active node span's + ``trace_id``/``span_id`` via the engine-side + ``prepare_sync`` → OTel context attach pipeline; both nested + and detached-trace cases. - **011-determinism** (PR-C) — deterministic span content (hierarchy, names, status, attributes minus the canonical non-deterministic-by-design list) is identical across runs. -Deferred: - -- **004-routing-error-attribution** — needs the proposal-0012 - ordering swap (completed dispatch after edge eval) so the - preceding node's ``completed`` event carries the routing-error - status. Lands in PR-C.1 once v0.9.0 ships. -- **006-fan-out-instance-attribution** — needs non-detached - fan-out per-instance dispatch span synthesis + ``FanOutConfig`` - metadata surfacing. Lands in PR-C.2. -- **010-log-correlation** — needs the synchronous observer prep - hook (``prepare_sync``) so the engine task can attach the - observer's span to OTel context for the duration of node-body - execution. Lands in PR-C.3. - Per-fixture wiring notes live in ``docs/phase-6-1-conformance-fillin.md``. """ @@ -76,19 +67,13 @@ "007-otel-retry-attempt-spans", "008-otel-detached-trace-mode", "009-otel-correlation-id-cross-cutting", + "010-otel-log-correlation", "011-otel-determinism", } ) -_DEFERRED_FIXTURES: dict[str, str] = { - "010-otel-log-correlation": ( - "Needs synchronous observer prep hook (prepare_sync) so the engine task can " - "attach the observer's span to OTel context for the duration of node-body " - "execution — observer span creation runs on the worker task today and isn't " - "available synchronously after _dispatch_started. Lands in PR-C.3." - ), -} +_DEFERRED_FIXTURES: dict[str, str] = {} # UUIDv4 canonical form: xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx (where y in {8,9,a,b}). @@ -143,6 +128,8 @@ async def test_observability_fixture(fixture_path: Path) -> None: await _run_fixture_008(spec) elif fixture_id == "009-otel-correlation-id-cross-cutting": await _run_fixture_009(spec) + elif fixture_id == "010-otel-log-correlation": + await _run_fixture_010(spec) elif fixture_id == "011-otel-determinism": await _run_fixture_011(spec) else: @@ -1427,3 +1414,323 @@ async def delete(self, invocation_id: str) -> None: f"original and resumed runs MUST produce DIFFERENT trace_ids " f"(per §10.4 step 4 + §5.1); got {len(trace_ids)} distinct trace_ids" ) + + +# --------------------------------------------------------------------------- +# Fixture 010 — log correlation (PR-C.3) +# +# Two sub-cases. Both build the graph by hand rather than going through the +# adapter — fixture 010's ``emits_log:`` directive isn't an adapter primitive +# (the adapter recognizes ``update_pure``, ``subgraph``, etc., and silently +# ignores anything else), and the sub-cases are small enough that hand-built +# python is clearer than threading a new directive through the adapter. +# --------------------------------------------------------------------------- + + +def _setup_isolated_log_bridge() -> tuple[Any, Any, Any]: + """Spin up an OTel ``LoggerProvider`` + ``InMemoryLogRecordExporter`` and + install the log bridge against the root logger, snapshotting the prior + log state so the caller can restore it in ``finally`` (the bridge mutates + process-global ``logging`` state — handlers, factory). + + Returns ``(exporter, provider, restore_state)`` where ``restore_state`` + is a snapshot to pass to :func:`_restore_log_state`. + """ + import logging as _logging # noqa: PLC0415 + + from opentelemetry.sdk._logs import LoggerProvider # noqa: PLC0415 + from opentelemetry.sdk._logs.export import ( # noqa: PLC0415 + InMemoryLogRecordExporter, + SimpleLogRecordProcessor, + ) + + from openarmature.observability.otel import install_log_bridge # noqa: PLC0415 + + root = _logging.getLogger() + snapshot = (list(root.handlers), list(root.filters), _logging.getLogRecordFactory()) + + exporter = InMemoryLogRecordExporter() + provider = LoggerProvider() + provider.add_log_record_processor(SimpleLogRecordProcessor(exporter)) + install_log_bridge(provider) + return exporter, provider, snapshot + + +def _restore_log_state(snapshot: Any) -> None: + """Pair to :func:`_setup_isolated_log_bridge` — restores the root logger's + handler list, filters, and ``LogRecord`` factory to the snapshot taken + before ``install_log_bridge`` ran.""" + import logging as _logging # noqa: PLC0415 + + handlers, filters, factory = snapshot + root = _logging.getLogger() + root.handlers[:] = handlers + root.filters[:] = filters + _logging.setLogRecordFactory(factory) + + +def _enable_test_logger_at_info() -> tuple[Any, int]: + """Bring the fixture-010 test logger up to ``INFO`` so YAML's + ``level: INFO`` records actually flow through Python's logger-level + filter to the bridge handler. Returns ``(logger, prior_level)`` to + pair with a restore in ``finally``.""" + import logging as _logging # noqa: PLC0415 + + test_logger = _logging.getLogger("openarmature.test.fixture_010") + prior_level = test_logger.level + test_logger.setLevel(_logging.INFO) + return test_logger, prior_level + + +async def _run_fixture_010(spec: Mapping[str, Any]) -> None: + """Two sub-cases: nested-trace log correlation (single graph, all logs + share the parent trace_id) and detached-subgraph log correlation + (logs across the detached boundary carry distinct trace_ids but the + same correlation_id).""" + cases = cast("list[dict[str, Any]]", spec["cases"]) + for case in cases: + case_name = cast("str", case["name"]) + try: + await _run_fixture_010_case(case) + except AssertionError as e: + raise AssertionError(f"case {case_name!r}: {e}") from e + + +async def _run_fixture_010_case(case: Mapping[str, Any]) -> None: + case_name = cast("str", case["name"]) + if case_name == "log_records_carry_trace_span_correlation_ids": + await _run_fixture_010_nested_trace(case) + elif case_name == "detached_subgraph_log_uses_detached_trace_id_keeps_correlation_id": + await _run_fixture_010_detached(case) + else: + raise AssertionError(f"unknown fixture 010 sub-case: {case_name!r}") + + +async def _run_fixture_010_nested_trace(case: Mapping[str, Any]) -> None: + """Sub-case 1: 2 nodes ``a`` → ``b``, both emit logs from the FIRST line + of their body. The log bridge MUST report all logs in the parent + trace_id, with each log's span_id matching the active node span at + emission, and all carrying the invocation's correlation_id.""" + from openarmature.graph import END, GraphBuilder, State # noqa: PLC0415 + + nodes_spec = cast("dict[str, Any]", case["nodes"]) + correlation_id = cast("str", case["caller_correlation_id"]) + # Spec YAML is the single source of truth for the log bodies; derive + # them up front rather than hard-coding so a fixture rename doesn't + # silently break the driver's record filtering. + node_emit_messages: dict[str, str] = { + name: cast("str", cast("dict[str, Any]", nodes_spec[name])["emits_log"]["message"]) + for name in nodes_spec + } + + class _S(State): + x: int = 0 + + test_logger, prior_level = _enable_test_logger_at_info() + + def _make_body(node_name: str) -> Any: + spec = cast("dict[str, Any]", nodes_spec[node_name]) + emit_msg = cast("str", spec["emits_log"]["message"]) + update = cast("dict[str, Any]", spec["update_pure"]) + + async def body(_s: _S) -> dict[str, Any]: + # FIRST line, before any await — the load-bearing case + # the engine attach via ``prepare_sync`` exists to cover. + test_logger.info(emit_msg) + return dict(update) + + return body + + builder = GraphBuilder(_S) + for node_name in nodes_spec: + builder.add_node(node_name, _make_body(node_name)) + for edge in cast("list[dict[str, Any]]", case["edges"]): + from_node = cast("str", edge["from"]) + to = edge["to"] + builder.add_edge(from_node, END if to == "END" else cast("str", to)) + builder.set_entry(cast("str", case["entry"])) + compiled = builder.compile() + + observer, span_exporter = _build_observer() + log_exporter, log_provider, snapshot = _setup_isolated_log_bridge() + try: + compiled.attach_observer(observer) + await compiled.invoke(_S(), correlation_id=correlation_id) + await compiled.drain() + observer.shutdown() + log_provider.force_flush() + + records = log_exporter.get_finished_logs() + # Filter to OUR test loggers so concurrent test setup noise + # doesn't contaminate the assertions. Expected message set + # comes from the spec YAML, not hard-coded strings. + expected_messages = set(node_emit_messages.values()) + ours = [r for r in records if str(r.log_record.body) in expected_messages] + assert len(ours) == 2, ( + f"expected 2 log records (one per node body); got {len(ours)}: " + f"{[str(r.log_record.body) for r in ours]}" + ) + + # Group by body for predictable lookup, indexing by the spec's + # emit-message values. + by_body = {str(r.log_record.body): r for r in ours} + a_log = by_body[node_emit_messages["a"]] + b_log = by_body[node_emit_messages["b"]] + + # Invariant: all_logs_same_trace_id. + trace_ids = {a_log.log_record.trace_id, b_log.log_record.trace_id} + assert len(trace_ids) == 1, f"all logs MUST share a trace_id (single nested trace); got {trace_ids}" + + # Invariant: log_span_ids_match_active_span_at_emission. + spans = span_exporter.get_finished_spans() + node_span_ids: dict[str, int] = {} + for s in spans: + if s.name in {"a", "b"}: + node_span_ids[s.name] = s.context.span_id + assert a_log.log_record.span_id == node_span_ids["a"], ( + f"node-a log MUST carry node-a span's span_id; " + f"got log span_id={a_log.log_record.span_id}, span={node_span_ids['a']}" + ) + assert b_log.log_record.span_id == node_span_ids["b"], ( + f"node-b log MUST carry node-b span's span_id; " + f"got log span_id={b_log.log_record.span_id}, span={node_span_ids['b']}" + ) + + # Invariant: all_logs_carry_correlation_id. + for r in ours: + attrs = dict(r.log_record.attributes or {}) + assert attrs.get("openarmature.correlation_id") == correlation_id, ( + f"every log MUST carry openarmature.correlation_id={correlation_id!r}; " + f"got {attrs.get('openarmature.correlation_id')!r}" + ) + finally: + test_logger.setLevel(prior_level) + _restore_log_state(snapshot) + + +async def _run_fixture_010_detached(case: Mapping[str, Any]) -> None: + """Sub-case 2: outer invocation has a detached subgraph. Logs emitted + inside the detached subgraph carry the DETACHED trace's trace_id — + NOT the parent's — while the correlation_id flows unchanged across + the boundary.""" + from openarmature.graph import END, GraphBuilder, State # noqa: PLC0415 + + correlation_id = cast("str", case["caller_correlation_id"]) + sub_specs = cast("dict[str, Any]", case["subgraphs"]) + inner_spec = cast("dict[str, Any]", sub_specs["detached_inner"]) + outer_nodes = cast("dict[str, Any]", case["nodes"]) + + # Detached subgraph identity → wrapper-node-name translation, same + # convention as fixture 008. The fixture YAML lists subgraph identities + # in ``detached_subgraphs:``; OTelObserver keys on the wrapper node's + # name in the parent graph. + detached_identities = set(cast("list[str]", case.get("detached_subgraphs") or [])) + wrapper_names: set[str] = set() + for wrapper_name, node_spec in outer_nodes.items(): + sub_id = cast("dict[str, Any]", node_spec).get("subgraph") + if isinstance(sub_id, str) and sub_id in detached_identities: + wrapper_names.add(wrapper_name) + detached_subgraphs = frozenset(wrapper_names) + + test_logger, prior_level = _enable_test_logger_at_info() + + # Inner subgraph (detached_inner): 1 node ``inner`` with + # ``update_pure: {y: 1}`` + ``emits_log: "inside detached subgraph"``. + class _Inner(State): + y: int = 0 + + inner_node_spec = cast("dict[str, Any]", inner_spec["nodes"]["inner"]) + inner_emit = cast("str", inner_node_spec["emits_log"]["message"]) + inner_update = cast("dict[str, Any]", inner_node_spec["update_pure"]) + + async def _inner_body(_s: _Inner) -> dict[str, Any]: + test_logger.info(inner_emit) + return dict(inner_update) + + inner_compiled = ( + GraphBuilder(_Inner) + .add_node("inner", _inner_body) + .add_edge("inner", END) + .set_entry("inner") + .compile() + ) + + # Outer graph: ``outer_dispatch`` is a SubgraphNode wrapper around + # ``inner_compiled`` AND emits a log "before subgraph dispatch". + # SubgraphNode wrappers don't get ``prepare_sync`` per spec — the + # outer log is emitted via per-node middleware that fires inside + # the wrapper's chain. Without an attached span at wrapper scope, + # the outer log's trace_id is OTel's "no active span" sentinel + # (0); the inner log's trace_id is the detached trace's. The + # invariant ``log_trace_ids_differ_when_detached`` holds either + # way. + class _Outer(State): + z: int = 0 + + outer_node_spec = cast("dict[str, Any]", outer_nodes["outer_dispatch"]) + outer_emit = cast("str", outer_node_spec["emits_log"]["message"]) + + async def _outer_log_middleware(s: Any, next_call: Any) -> Mapping[str, Any]: + test_logger.info(outer_emit) + return cast("Mapping[str, Any]", await next_call(s)) + + outer_compiled = ( + GraphBuilder(_Outer) + .add_subgraph_node("outer_dispatch", inner_compiled, middleware=[_outer_log_middleware]) + .add_edge("outer_dispatch", END) + .set_entry("outer_dispatch") + .compile() + ) + + observer, _span_exporter = _build_observer_with_detached(detached_subgraphs) + log_exporter, log_provider, snapshot = _setup_isolated_log_bridge() + try: + outer_compiled.attach_observer(observer) + await outer_compiled.invoke(_Outer(), correlation_id=correlation_id) + await outer_compiled.drain() + observer.shutdown() + log_provider.force_flush() + + records = log_exporter.get_finished_logs() + ours = [r for r in records if str(r.log_record.body) in {outer_emit, inner_emit}] + assert len(ours) == 2, ( + f"expected 2 log records (outer + inner); got {len(ours)}: " + f"{[str(r.log_record.body) for r in ours]}" + ) + + by_body = {str(r.log_record.body): r for r in ours} + outer_log = by_body[outer_emit] + inner_log = by_body[inner_emit] + + # Invariant: log_trace_ids_differ_when_detached. + assert outer_log.log_record.trace_id != inner_log.log_record.trace_id, ( + f"detached-subgraph log MUST carry the detached trace's trace_id, " + f"DIFFERENT from the parent log; both got {outer_log.log_record.trace_id}" + ) + + # Invariant: all_logs_carry_correlation_id. + for r in ours: + attrs = dict(r.log_record.attributes or {}) + assert attrs.get("openarmature.correlation_id") == correlation_id, ( + f"every log MUST carry openarmature.correlation_id={correlation_id!r}; " + f"got {attrs.get('openarmature.correlation_id')!r}" + ) + finally: + test_logger.setLevel(prior_level) + _restore_log_state(snapshot) + + +def _build_observer_with_detached(detached_subgraphs: frozenset[str]) -> tuple[OTelObserver, Any]: + """Variant of :func:`_build_observer` that takes a detached_subgraphs + set — needed for fixture 010 sub-case 2.""" + from opentelemetry.sdk.trace.export import SimpleSpanProcessor # noqa: PLC0415 + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( # noqa: PLC0415 + InMemorySpanExporter, + ) + + exporter = InMemorySpanExporter() + observer = OTelObserver( + span_processor=SimpleSpanProcessor(exporter), + detached_subgraphs=detached_subgraphs, + ) + return observer, exporter diff --git a/tests/unit/test_observability_otel.py b/tests/unit/test_observability_otel.py index 20e1474..0187423 100644 --- a/tests/unit/test_observability_otel.py +++ b/tests/unit/test_observability_otel.py @@ -29,7 +29,7 @@ # Skip the entire module if otel extras aren't installed. pytest.importorskip("opentelemetry.sdk.trace") -from typing import cast +from typing import Any, cast from opentelemetry import trace as otel_trace from opentelemetry.sdk.trace import ReadableSpan, TracerProvider @@ -862,3 +862,106 @@ async def _flaky(s: _S) -> dict[str, int]: idx = cast("int", attrs["openarmature.node.attempt_index"]) parented_attempts.add(idx) assert parented_attempts == {0, 1, 2} + + +async def test_log_on_first_line_of_node_body_carries_node_span() -> None: + """The load-bearing case ``prepare_sync`` exists to fix. + + Without ``prepare_sync``, the engine queues the started event for + async dispatch, then enters the node body — by the time the OTel + observer's ``__call__`` opens the span on the worker task, the + node body has already executed (or is mid-await). A log emitted + on the FIRST line of the body, before any ``await``, would not + see the observer's span via OTel ``get_current()``. + + With ``prepare_sync``, the observer creates the span synchronously + in the engine task BEFORE queueing, publishes it via + ``current_active_observer_span``, and the engine attaches it to + the OTel context around the node body. The first-line log picks + up the right ``trace_id``/``span_id``. + + This test exists in unit/ (not just buried in the conformance + fixture 010 driver) so a failure here jumps straight to + ``prepare_sync``-related changes during a regression hunt. + """ + from opentelemetry.sdk._logs import LoggerProvider + from opentelemetry.sdk._logs.export import ( + InMemoryLogRecordExporter, + SimpleLogRecordProcessor, + ) + + test_logger = logging.getLogger("openarmature.test.first_line_log") + + class _S(State): + x: int = 0 + + async def first_line_log_node(_s: _S) -> dict[str, Any]: + # FIRST line, before any ``await`` — without ``prepare_sync`` + # in the engine task, OTel ``get_current()`` would return an + # invalid span here and the log would have ``trace_id=0`` / + # ``span_id=0``. + test_logger.info("emitted before any await") + return {"x": 1} + + span_exporter = InMemorySpanExporter() + observer = OTelObserver(span_processor=SimpleSpanProcessor(span_exporter)) + log_exporter = InMemoryLogRecordExporter() + log_provider = LoggerProvider() + log_provider.add_log_record_processor(SimpleLogRecordProcessor(log_exporter)) + + # Snapshot prior log state so this test doesn't bleed into others + # — install_log_bridge mutates process-global ``logging`` state. + root = logging.getLogger() + prior_handlers = list(root.handlers) + prior_filters = list(root.filters) + prior_factory = logging.getLogRecordFactory() + prior_test_level = test_logger.level + test_logger.setLevel(logging.INFO) + + try: + install_log_bridge(log_provider) + g = ( + GraphBuilder(_S) + .add_node("node_a", first_line_log_node) + .add_edge("node_a", END) + .set_entry("node_a") + .compile() + ) + g.attach_observer(observer) + await g.invoke(_S(), correlation_id="first-line-test") + await g.drain() + observer.shutdown() + log_provider.force_flush() + + records = log_exporter.get_finished_logs() + ours = [r for r in records if str(r.log_record.body) == "emitted before any await"] + assert len(ours) == 1, ( + f"expected exactly one log record; got {len(ours)}: {[str(r.log_record.body) for r in records]}" + ) + log_record = ours[0].log_record + + spans = span_exporter.get_finished_spans() + node_a_spans = [s for s in spans if s.name == "node_a"] + assert len(node_a_spans) == 1, f"expected one node_a span; got {len(node_a_spans)}" + node_a_span = node_a_spans[0] + assert node_a_span.context is not None + node_span_id = node_a_span.context.span_id + node_trace_id = node_a_span.context.trace_id + + # Load-bearing: the prepare_sync hook attached the observer + # span synchronously so the first-line log saw it via OTel + # ``get_current()``. + assert log_record.span_id == node_span_id, ( + f"first-line log MUST carry node_a span's span_id " + f"(prepare_sync attaches the span synchronously in the engine task); " + f"got log span_id={log_record.span_id}, node span_id={node_span_id}" + ) + assert log_record.trace_id == node_trace_id, ( + f"first-line log MUST carry node_a span's trace_id; " + f"got log trace_id={log_record.trace_id}, node trace_id={node_trace_id}" + ) + finally: + root.handlers[:] = prior_handlers + root.filters[:] = prior_filters + logging.setLogRecordFactory(prior_factory) + test_logger.setLevel(prior_test_level)