diff --git a/src/openarmature/graph/compiled.py b/src/openarmature/graph/compiled.py index 44d7463..73a023d 100644 --- a/src/openarmature/graph/compiled.py +++ b/src/openarmature/graph/compiled.py @@ -44,12 +44,18 @@ from openarmature.observability.correlation import ( _reset_active_dispatch, _reset_active_observers, + _reset_attempt_index, _reset_correlation_id, + _reset_fan_out_index, _reset_invocation_id, + _reset_namespace_prefix, _set_active_dispatch, _set_active_observers, + _set_attempt_index, _set_correlation_id, + _set_fan_out_index, _set_invocation_id, + _set_namespace_prefix, ) from .edges import END, ConditionalEdge, EndSentinel, StaticEdge @@ -558,51 +564,59 @@ async def innermost(s: Any) -> Mapping[str, Any]: attempt_index = attempt_counter[0] attempt_counter[0] += 1 - self._dispatch_started(context, current, namespace, step, s, attempt_index=attempt_index) - + # Calling-node identity for capability backends emitting + # from inside this attempt's scope (e.g., LLM provider's + # span hook). Per-attempt scope so retry middleware that + # re-enters innermost bumps the visible attempt_index. + attempt_token = _set_attempt_index(attempt_index) try: - partial = await node.run(s) - except Exception as e: - wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) - self._dispatch_completed( - context, - current, - namespace, - step, - s, - error=wrapped, - attempt_index=attempt_index, - ) - raise + self._dispatch_started(context, current, namespace, step, s, attempt_index=attempt_index) + + try: + partial = await node.run(s) + except Exception as e: + wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) + self._dispatch_completed( + context, + current, + namespace, + step, + s, + error=wrapped, + attempt_index=attempt_index, + ) + raise + + try: + merged = _merge_partial(s, partial, self.reducers, current) + except (ReducerError, StateValidationError) as e: + self._dispatch_completed( + context, + current, + namespace, + step, + s, + error=e, + attempt_index=attempt_index, + ) + raise - try: - merged = _merge_partial(s, partial, self.reducers, current) - except (ReducerError, StateValidationError) as e: self._dispatch_completed( context, current, namespace, step, s, - error=e, + post_state=merged, attempt_index=attempt_index, ) - raise - - self._dispatch_completed( - context, - current, - namespace, - step, - s, - post_state=merged, - attempt_index=attempt_index, - ) - # Return the partial (not the merged state) so middleware sees - # the partial-update shape per pipeline-utilities §2. The - # engine's canonical merge against the original state happens - # below, after the chain returns. - return partial + # Return the partial (not the merged state) so middleware sees + # the partial-update shape per pipeline-utilities §2. The + # engine's canonical merge against the original state happens + # below, after the chain returns. + return partial + finally: + _reset_attempt_index(attempt_token) chain: ChainCall = compose_chain( list(self.middleware) + list(node.middleware), @@ -612,12 +626,17 @@ async def innermost(s: Any) -> Mapping[str, Any]: # Spec observability §3 / Phase 6 LLM-span hook: capability # backends emitting from inside a node body (the # llm-provider span instrumentation in OpenAIProvider) need - # to find the observers active for THIS invocation. Set the - # ContextVar around the chain invocation; reset in - # ``try/finally`` so an exception escaping the chain still - # restores the prior value. + # to find the observers active for THIS invocation, which + # node is calling, and which fan-out instance (if any) the + # call belongs to. ``namespace_prefix`` and ``fan_out_index`` + # are set in this outer scope (per-node, not per-attempt); + # ``attempt_index`` is set inside ``innermost`` per attempt. + # All four reset in ``try/finally`` so an exception escaping + # the chain still restores the prior values. observers_token = _set_active_observers(context.full_observers()) dispatch_token = _set_active_dispatch(lambda event: _dispatch(context, event)) + namespace_token = _set_namespace_prefix(namespace) + fan_out_token = _set_fan_out_index(context.fan_out_index) try: try: final_partial = await chain(state) @@ -628,6 +647,8 @@ async def innermost(s: Any) -> Mapping[str, Any]: # the chain unrecovered. Wrap as NodeException per §4. raise NodeException(node_name=current, cause=e, recoverable_state=state) from e finally: + _reset_fan_out_index(fan_out_token) + _reset_namespace_prefix(namespace_token) _reset_active_dispatch(dispatch_token) _reset_active_observers(observers_token) # Engine's canonical merge uses the ORIGINAL state per §2: "the @@ -686,13 +707,18 @@ async def innermost(s: Any) -> Mapping[str, Any]: list(self.middleware) + list(node.middleware), innermost, ) - # Same active-observers scope as _step_function_node — parent - # middleware running before the descent should see the parent's - # observer set; the inner _invoke (called via ``node.run``) - # descends into its own context and sets a new scope from - # there. + # Same active-observers + calling-node scope as + # ``_step_function_node`` — parent middleware running before + # the descent should see the wrapper node's namespace + + # fan_out_index for any LLM-provider hook emissions. + # ``attempt_index`` defaults to 0 from the ContextVar; the + # subgraph wrapper has no engine-managed attempt counter + # (inner ``_step_function_node`` calls own their own). + namespace = context.namespace_prefix + (current,) observers_token = _set_active_observers(context.full_observers()) dispatch_token = _set_active_dispatch(lambda event: _dispatch(context, event)) + namespace_token = _set_namespace_prefix(namespace) + fan_out_token = _set_fan_out_index(context.fan_out_index) try: try: @@ -707,6 +733,8 @@ async def innermost(s: Any) -> Mapping[str, Any]: # preserved. raise NodeException(node_name=current, cause=e, recoverable_state=state) from e finally: + _reset_fan_out_index(fan_out_token) + _reset_namespace_prefix(namespace_token) _reset_active_dispatch(dispatch_token) _reset_active_observers(observers_token) return _merge_partial(state, final_partial, self.reducers, current) @@ -743,45 +771,49 @@ async def _step_fan_out_node( async def innermost(s: Any) -> Mapping[str, Any]: attempt_index = attempt_counter[0] attempt_counter[0] += 1 - self._dispatch_started(context, current, namespace, step, s, attempt_index=attempt_index) + attempt_token = _set_attempt_index(attempt_index) try: - partial = await node.run_with_context(s, context) - except RuntimeGraphError as e: - self._dispatch_completed( - context, current, namespace, step, s, error=e, attempt_index=attempt_index - ) - raise - except Exception as e: - wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) + self._dispatch_started(context, current, namespace, step, s, attempt_index=attempt_index) + try: + partial = await node.run_with_context(s, context) + except RuntimeGraphError as e: + self._dispatch_completed( + context, current, namespace, step, s, error=e, attempt_index=attempt_index + ) + raise + except Exception as e: + wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) + self._dispatch_completed( + context, + current, + namespace, + step, + s, + error=wrapped, + attempt_index=attempt_index, + ) + raise wrapped from e + + try: + merged = _merge_partial(s, partial, self.reducers, current) + except (ReducerError, StateValidationError) as e: + self._dispatch_completed( + context, current, namespace, step, s, error=e, attempt_index=attempt_index + ) + raise + self._dispatch_completed( context, current, namespace, step, s, - error=wrapped, + post_state=merged, attempt_index=attempt_index, ) - raise wrapped from e - - try: - merged = _merge_partial(s, partial, self.reducers, current) - except (ReducerError, StateValidationError) as e: - self._dispatch_completed( - context, current, namespace, step, s, error=e, attempt_index=attempt_index - ) - raise - - self._dispatch_completed( - context, - current, - namespace, - step, - s, - post_state=merged, - attempt_index=attempt_index, - ) - return partial + return partial + finally: + _reset_attempt_index(attempt_token) chain: ChainCall = compose_chain( list(self.middleware) + list(node.middleware), @@ -789,12 +821,18 @@ async def innermost(s: Any) -> Mapping[str, Any]: ) # Same observability §3 / LLM-span hook contract as - # _step_function_node: set the active observer set in scope - # around the chain invocation so capability backends emitting - # from inside the fan-out's parent dispatch (or any code - # running on its call stack) can find the observers. + # _step_function_node: set the active observer set, calling + # node identity, and dispatch scope around the chain + # invocation so capability backends emitting from inside the + # fan-out's parent dispatch (or any code running on its call + # stack) can find them. ``fan_out_index`` here is the parent + # context's view (the fan-out node from outside); per-instance + # values get set when the inner subgraph descends with the + # instance's index in its own context. observers_token = _set_active_observers(context.full_observers()) dispatch_token = _set_active_dispatch(lambda event: _dispatch(context, event)) + namespace_token = _set_namespace_prefix(namespace) + fan_out_token = _set_fan_out_index(context.fan_out_index) try: try: final_partial = await chain(state) @@ -803,6 +841,8 @@ async def innermost(s: Any) -> Mapping[str, Any]: except Exception as e: raise NodeException(node_name=current, cause=e, recoverable_state=state) from e finally: + _reset_fan_out_index(fan_out_token) + _reset_namespace_prefix(namespace_token) _reset_active_dispatch(dispatch_token) _reset_active_observers(observers_token) merged_outer = _merge_partial(state, final_partial, self.reducers, current) diff --git a/src/openarmature/llm/providers/openai.py b/src/openarmature/llm/providers/openai.py index 5f8fbd9..a4bcd49 100644 --- a/src/openarmature/llm/providers/openai.py +++ b/src/openarmature/llm/providers/openai.py @@ -47,7 +47,12 @@ from openarmature.graph.events import NodeEvent from openarmature.graph.state import State -from openarmature.observability.correlation import current_dispatch +from openarmature.observability.correlation import ( + current_attempt_index, + current_dispatch, + current_fan_out_index, + current_namespace_prefix, +) from ..errors import ( ProviderAuthentication, @@ -554,6 +559,15 @@ class _LlmEventState(State): LLM-span maps by it so concurrent ``complete()`` calls (e.g., fan-out instances each calling the provider) don't collide on a single sentinel-namespace key. + + ``calling_namespace_prefix``, ``calling_attempt_index``, and + ``calling_fan_out_index`` carry the calling node's identity so + the OTel observer can resolve the §5.5 "parent under calling + node" contract correctly under concurrent fan-out and retry. + Populated from the engine's ContextVars (set in + ``_step_*_node`` around node-body execution); fall back to + sentinel defaults (empty tuple, 0, ``None``) when the LLM + provider is called outside any node body. """ call_id: str @@ -571,6 +585,14 @@ class _LlmEventState(State): error_type: str | None = None error_message: str | None = None error_category: str | None = None + # Calling-node identity captured at dispatch time. The OTel + # observer reads these to look up the calling node's span in + # its (now-invocation_id-scoped) ``_open_spans`` map without relying on + # the OTel current-span context (which under concurrent fan-out + # can yield a sibling instance's span). + calling_namespace_prefix: tuple[str, ...] = () + calling_attempt_index: int = 0 + calling_fan_out_index: int | None = None def _make_llm_event( @@ -611,6 +633,9 @@ def _make_llm_event( error_type=error_type, error_message=error_message, error_category=error_category, + calling_namespace_prefix=current_namespace_prefix(), + calling_attempt_index=current_attempt_index(), + calling_fan_out_index=current_fan_out_index(), ) return NodeEvent( node_name="openarmature.llm.complete", diff --git a/src/openarmature/observability/__init__.py b/src/openarmature/observability/__init__.py index 5d22f2f..6a53ca4 100644 --- a/src/openarmature/observability/__init__.py +++ b/src/openarmature/observability/__init__.py @@ -22,14 +22,20 @@ from .correlation import ( current_active_observers, + current_attempt_index, current_correlation_id, current_dispatch, + current_fan_out_index, current_invocation_id, + current_namespace_prefix, ) __all__ = [ "current_active_observers", + "current_attempt_index", "current_correlation_id", "current_dispatch", + "current_fan_out_index", "current_invocation_id", + "current_namespace_prefix", ] diff --git a/src/openarmature/observability/correlation.py b/src/openarmature/observability/correlation.py index 45f8709..9b505a5 100644 --- a/src/openarmature/observability/correlation.py +++ b/src/openarmature/observability/correlation.py @@ -205,22 +205,116 @@ def _reset_active_dispatch( _active_dispatch_var.reset(token) +# --------------------------------------------------------------------------- +# Calling-node identity — for the OTel observer's §5.5 LLM-span parent +# attribution under concurrent fan-out + retry. The engine sets these +# ContextVars around node-body execution in ``_step_*_node``; capability +# code emitting ``NodeEvent``s from inside a node body (the LLM provider +# span hook) reads them to record which node the event originated from. +# +# Without these, the OTel observer falls back to ``opentelemetry.trace``'s +# current-span context to resolve the parent, which under concurrent +# fan-out can yield a sibling instance's span rather than the actual +# calling node. The §5.5 contract states the *outcome* (LLM span parents +# under the calling node); these ContextVars provide the *mechanism*. +# +# Defaults are baked into ContextVar construction so readers outside any +# node body (e.g., LLM ``complete`` called from a top-level harness) +# return the sentinel values directly without engine-side initialization. +# --------------------------------------------------------------------------- + + +_namespace_prefix_var: ContextVar[tuple[str, ...]] = ContextVar("openarmature.namespace_prefix", default=()) + + +def current_namespace_prefix() -> tuple[str, ...]: + """Return the namespace prefix of the node currently executing, + or the empty tuple outside any node body. + + The empty-tuple default makes top-level (outside-invocation) and + between-nodes (e.g., middleware bodies) calls fall back to + invocation-level parenting cleanly. + """ + return _namespace_prefix_var.get() + + +def _set_namespace_prefix(value: tuple[str, ...]) -> Token[tuple[str, ...]]: + """Set the calling node's namespace prefix. Internal — + engine-only; called inside ``_step_*_node`` around node-body + execution.""" + return _namespace_prefix_var.set(value) + + +def _reset_namespace_prefix(token: Token[tuple[str, ...]]) -> None: + _namespace_prefix_var.reset(token) + + +_fan_out_index_var: ContextVar[int | None] = ContextVar("openarmature.fan_out_index", default=None) + + +def current_fan_out_index() -> int | None: + """Return the fan_out_index of the node currently executing, or + ``None`` outside any fan-out instance body (top-level nodes, + subgraph dispatch, between nodes). + """ + return _fan_out_index_var.get() + + +def _set_fan_out_index(value: int | None) -> Token[int | None]: + """Set the calling node's fan_out_index. Internal — engine-only.""" + return _fan_out_index_var.set(value) + + +def _reset_fan_out_index(token: Token[int | None]) -> None: + _fan_out_index_var.reset(token) + + +_attempt_index_var: ContextVar[int] = ContextVar("openarmature.attempt_index", default=0) + + +def current_attempt_index() -> int: + """Return the attempt_index of the node currently executing, or + ``0`` outside any node body. Retry middleware bumps this per + attempt; the OTel observer uses it to disambiguate per-attempt + spans when an LLM call happens inside a retried node body. + """ + return _attempt_index_var.get() + + +def _set_attempt_index(value: int) -> Token[int]: + """Set the calling node's attempt_index. Internal — engine-only.""" + return _attempt_index_var.set(value) + + +def _reset_attempt_index(token: Token[int]) -> None: + _attempt_index_var.reset(token) + + __all__ = [ # Public surface — readable from anywhere within an invocation. "current_active_observers", + "current_attempt_index", "current_correlation_id", "current_dispatch", + "current_fan_out_index", "current_invocation_id", + "current_namespace_prefix", # Engine-internal lifecycle helpers — exported so the engine in # ``openarmature.graph.compiled`` can drive set/reset without # pyright's strict ``reportUnusedFunction`` flagging them as # dead. Underscore-prefixed; not part of the user-facing API. "_reset_active_dispatch", "_reset_active_observers", + "_reset_attempt_index", "_reset_correlation_id", + "_reset_fan_out_index", "_reset_invocation_id", + "_reset_namespace_prefix", "_set_active_dispatch", "_set_active_observers", + "_set_attempt_index", "_set_correlation_id", + "_set_fan_out_index", "_set_invocation_id", + "_set_namespace_prefix", ] diff --git a/src/openarmature/observability/otel/observer.py b/src/openarmature/observability/otel/observer.py index 5fc317e..d79bc4e 100644 --- a/src/openarmature/observability/otel/observer.py +++ b/src/openarmature/observability/otel/observer.py @@ -8,22 +8,42 @@ ``completed`` event it pops the span, applies §4.2 status mapping, and closes it. +**Per-invocation state isolation.** All internal span maps are +outer-keyed by ``invocation_id`` (per spec §5.1: each invocation has +a fresh framework-minted UUIDv4). A single observer can be safely +shared across concurrent invocations (e.g., an ASGI service running +``asyncio.gather([invoke(), invoke()])`` on one observer); each +invocation's spans live in their own sub-dict, lazy-allocated on +first event. The ``correlation_id`` is the cross-run join key (spec +§3.1) and is set as the ``openarmature.correlation_id`` attribute on +every span — it is *not* the state-scoping key, because resume runs +preserve the correlation_id and would (incorrectly) cause the +resumed run's spans to inherit the prior invocation's trace. + +**No cross-event OTel context tokens.** Parent spans are resolved from +the observer's own internal maps within a single event handler's +scope — never from ``opentelemetry.context.get_current()``. Spans are +opened with ``context=set_span_in_context(parent_span)`` directly +rather than ``attach()``-ing tokens that would have to be ``detach()`` +-ed on the matching completed event. This eliminates LIFO-violation +hazards under interleaved fan-out events and makes the observer +robust to dispatch ordering. + Subtree isolation lives in dedicated dicts rather than the leaf-span key: -- ``_subgraph_spans`` — synthetic subgraph dispatch spans (the - engine wrapper is transparent per fixture 013, but observability - §4.5 mandates a span). Keyed by namespace prefix. Open lazily on - the first deeper-namespace event, close when subsequent events - leave the prefix. -- ``_detached_roots`` — root spans for detached subgraphs (§4.4) - and per-instance detached fan-out roots. Each lives in its own - fresh ``trace_id``; the parent's dispatch span carries an OTel +- ``subgraph_spans`` — synthetic subgraph dispatch spans (the engine + wrapper is transparent per fixture 013, but observability §4.5 + mandates a span). Keyed by namespace prefix. Open lazily on the + first deeper-namespace event, close when subsequent events leave + the prefix. +- ``detached_roots`` — root spans for detached subgraphs (§4.4) and + per-instance detached fan-out roots. Each lives in its own fresh + ``trace_id``; the parent's dispatch span carries an OTel :class:`Link` to the detached trace. - ``_invocation_span`` — root invocation span keyed by - ``correlation_id``. Closed eagerly when a new correlation_id's - first event arrives (or explicitly via - :meth:`close_invocation` / :meth:`shutdown`). + ``invocation_id``. Closed via :meth:`close_invocation` / + :meth:`shutdown`. Spans are emitted through a **private** :class:`TracerProvider` constructed by this observer — never the OTel global. Per spec §6 @@ -35,11 +55,7 @@ Detached trace mode (§4.4) is implemented by minting a fresh :class:`SpanContext` with a new ``trace_id`` when entering a configured-detached subgraph or fan-out; the parent's dispatch span -carries an OTel :class:`Link` to the detached trace. Inner-event -parent resolution checks the per-fan-out-instance key -(``namespace[:1] + (str(fan_out_index),)``) before the generic -prefix scan, so per-instance detached roots win without depending -on attach-then-resolve ordering. +carries an OTel :class:`Link` to the detached trace. """ from __future__ import annotations @@ -49,7 +65,6 @@ from opentelemetry import context as otel_context from opentelemetry import trace as otel_trace -from opentelemetry.context import attach, detach from opentelemetry.sdk.trace import SpanProcessor, TracerProvider from opentelemetry.sdk.trace.id_generator import RandomIdGenerator from opentelemetry.trace import ( @@ -70,13 +85,7 @@ # Span-stack key shape: ``(namespace, attempt_index, fan_out_index)`` # — these three fields uniquely identify any node attempt within an -# invocation. Trace_id was previously included in the key but that -# created a registration/lookup mismatch when the OTel current-span -# context changed between a span's started and completed events -# (e.g., detached fan-out instances opening detached roots in -# between). Detached sub-tree spans live in ``_detached_roots`` / -# ``_subgraph_spans`` separately, so the namespace alone doesn't -# collide here. +# invocation. _StackKey = tuple[tuple[str, ...], int, int | None] @@ -96,22 +105,32 @@ def _read_spec_version() -> str: def _empty_str_frozenset() -> frozenset[str]: """Typed empty frozenset factory for ``detached_subgraphs`` / - ``detached_fan_outs`` defaults. ``default_factory=frozenset`` - alone produces ``frozenset[Unknown]`` under pyright strict - mode; a named factory with the explicit return annotation - preserves the ``frozenset[str]`` typing without falling back to - a lambda.""" + ``detached_fan_outs`` defaults.""" return frozenset() @dataclass class _OpenSpan: - """An in-flight span paired with the OTel context token that pinned - its scope. The token is ``detach``ed when the span closes so the - OTel current-span context unwinds correctly.""" + """An in-flight span. No OTel context token: the new architecture + resolves parents from the observer's internal maps within a + single event handler's scope, so no token needs to live across + events.""" span: Span - token: object + + +@dataclass +class _InvState: + """Per-invocation span state. One instance per concurrent + invocation — the outer ``OTelObserver`` keys these by + ``invocation_id`` so concurrent invocations (and resumed runs of + the same correlation_id) don't collide.""" + + open_spans: dict[_StackKey, _OpenSpan] = field(default_factory=dict[_StackKey, _OpenSpan]) + open_llm_spans: dict[str, _OpenSpan] = field(default_factory=dict[str, _OpenSpan]) + subgraph_spans: dict[tuple[str, ...], _OpenSpan] = field(default_factory=dict[tuple[str, ...], _OpenSpan]) + detached_roots: dict[tuple[str, ...], _OpenSpan] = field(default_factory=dict[tuple[str, ...], _OpenSpan]) + fan_out_instance_root_prefixes: set[tuple[str, ...]] = field(default_factory=set[tuple[str, ...]]) @dataclass @@ -141,20 +160,10 @@ class OTelObserver: - ``spec_version`` — string surfaced as ``openarmature.graph.spec_version`` on the invocation span. - **Concurrency model.** A single ``OTelObserver`` instance is - SAFE for sequential invocations on the same graph (one - ``invoke()`` followed by another, with the observer reused - between). It is NOT safe to share an instance across CONCURRENT - invocations — the internal span state (``_open_spans``, - ``_subgraph_spans``, ``_detached_roots``, ``_invocation_span``) - is keyed without per-invocation scoping, so overlapping - namespaces collide and the close-prior-correlation_id logic in - ``_handle_started`` would close another in-flight invocation's - span. Recommended pattern: one observer per ``CompiledGraph`` - instance for sequential workloads; for ASGI / batch / parallel - invocation services, attach a fresh observer per invocation - (via ``invoke(observers=[...])``) until the Phase 6.1 - correlation_id-scoped state lands. + Safe to share across concurrent invocations and across + resumes of the same correlation_id — every internal span map is + outer-keyed by ``invocation_id``, and parent resolution stays + within a single event handler's scope. """ span_processor: SpanProcessor @@ -170,56 +179,16 @@ class OTelObserver: # Internal state, populated in __post_init__ and during invocation. _provider: TracerProvider = field(init=False, repr=False) _tracer: otel_trace.Tracer = field(init=False, repr=False) - _open_spans: dict[_StackKey, _OpenSpan] = field( - init=False, repr=False, default_factory=dict[_StackKey, _OpenSpan] - ) - # The invocation root span, opened on the first event of an - # invocation and closed when the matching outermost completed - # event arrives (or, in practice, when the engine's queue drains - # — the invocation span has no started/completed pair of its - # own, so we open it lazily and close it on a sentinel). + # Per-invocation_id span state — concurrent invocations on a + # shared observer each get their own ``_InvState`` so internal + # maps never collide. + _inv_states: dict[str, _InvState] = field(init=False, repr=False, default_factory=dict[str, _InvState]) + # Root invocation spans, keyed by invocation_id. Opened lazily on + # the first event for a new invocation_id; closed via + # ``close_invocation`` / ``shutdown``. _invocation_span: dict[str, _OpenSpan] = field( init=False, repr=False, default_factory=dict[str, _OpenSpan] ) - # Per-LLM-call span tracking, keyed by ``call_id`` (UUIDv4 - # minted in ``OpenAIProvider.complete`` per call). Concurrent - # ``complete()`` calls under fan-out instances would collide on - # a constant ``_LLM_NAMESPACE`` key in ``_open_spans``; the - # call_id-keyed dict disambiguates them. The deeper - # parent-attribution issue under concurrency (calling node's - # context not threaded through) is Phase 6.1 work. - _open_llm_spans: dict[str, _OpenSpan] = field( - init=False, repr=False, default_factory=dict[str, _OpenSpan] - ) - # Synthetic subgraph dispatch spans: the engine wrapper for - # ``add_subgraph_node`` is transparent (graph-engine fixture 013 - # — no started/completed events of its own), but observability - # §4.5 mandates a subgraph span. The OTel observer synthesizes - # one by detecting deeper-namespace events and opening an - # ancestor span for each new prefix; closes when subsequent - # events leave that prefix. - _subgraph_spans: dict[tuple[str, ...], _OpenSpan] = field( - init=False, repr=False, default_factory=dict[tuple[str, ...], _OpenSpan] - ) - # Per-namespace-prefix detached trace tracking. When a detached - # subgraph or fan-out instance enters, we mint a fresh trace and - # store the root span here so subsequent inner events at that - # prefix find the right parent. Keyed by namespace prefix - # (subgraph) or namespace_prefix + (str(fan_out_index),) (fan-out - # instance). The fan-out node's own span (in the parent trace) - # collects Links to each detached instance trace. - _detached_roots: dict[tuple[str, ...], _OpenSpan] = field( - init=False, repr=False, default_factory=dict[tuple[str, ...], _OpenSpan] - ) - # Subset of ``_detached_roots`` keys that represent per-instance - # fan-out roots — they're closed by ``_handle_completed`` on the - # fan-out node's own completion, NOT by ``_sync_subgraph_spans``. - # Using an explicit set rather than parsing the key (e.g., - # checking ``prefix[-1].isdigit()``) so node names that happen - # to be pure digits don't get misclassified. - _fan_out_instance_root_prefixes: set[tuple[str, ...]] = field( - init=False, repr=False, default_factory=set[tuple[str, ...]] - ) def __post_init__(self) -> None: # Private provider per spec §6 TracerProvider isolation — @@ -228,6 +197,18 @@ def __post_init__(self) -> None: self._provider.add_span_processor(self.span_processor) self._tracer = self._provider.get_tracer("openarmature") + # ------------------------------------------------------------------ + # Per-invocation state lookup + # ------------------------------------------------------------------ + + def _inv_state_for(self, invocation_id: str) -> _InvState: + """Get-or-create the state container for an invocation_id.""" + state = self._inv_states.get(invocation_id) + if state is None: + state = _InvState() + self._inv_states[invocation_id] = state + return state + # ------------------------------------------------------------------ # Observer protocol — async callable accepting a NodeEvent # ------------------------------------------------------------------ @@ -253,55 +234,60 @@ async def __call__(self, event: NodeEvent) -> None: def _handle_started(self, event: NodeEvent) -> None: """Open a span for this attempt, push onto the in-flight map.""" - from openarmature.observability.correlation import current_correlation_id + from openarmature.observability.correlation import ( + current_correlation_id, + current_invocation_id, + ) - # Lazily open the invocation span on the first event we see - # for this invocation. Detect "first event" by matching the - # correlation_id; the invocation span lives until either - # ``shutdown()`` runs OR a new correlation_id arrives (i.e., - # a new invocation starts on the same long-lived observer). - # The latter close path matters for shared observers reused - # across many invocations — without it the - # ``_invocation_span`` dict grows unbounded. + invocation_id = current_invocation_id() + if invocation_id is None: + return correlation_id = current_correlation_id() - if correlation_id is not None: - # New correlation_id → close prior invocation spans. - for prior_cid in list(self._invocation_span.keys()): - if prior_cid != correlation_id: - self._close_invocation_span(prior_cid) - if correlation_id not in self._invocation_span: - self._open_invocation_span(correlation_id, event) + inv_state = self._inv_state_for(invocation_id) + + # Lazily open the invocation span on the first event we see + # for this invocation_id. Per-invocation_id scoping means + # resumed runs of the same correlation_id (each with a fresh + # invocation_id per §5.1) get their own invocation span and + # therefore their own trace_id. + if invocation_id not in self._invocation_span: + self._open_invocation_span(invocation_id, correlation_id, event) # Synthesize subgraph dispatch spans for any ancestor namespace # prefix that doesn't have one yet (per observability §4.5). # Also closes subgraph spans we've left. - self._sync_subgraph_spans(event) + self._sync_subgraph_spans(inv_state, invocation_id, correlation_id, event) - parent_ctx = self._resolve_parent_context(event) + parent_ctx = self._resolve_parent_context(inv_state, invocation_id, event) span = self._tracer.start_span( name=event.node_name, context=cast("Any", parent_ctx), kind=SpanKind.INTERNAL, - attributes=self._node_attrs(event), + attributes=self._node_attrs(event, correlation_id), ) - token = attach(set_span_in_context(span)) - key = self._key_for(event) - self._open_spans[key] = _OpenSpan(span=span, token=token) + inv_state.open_spans[self._key_for(event)] = _OpenSpan(span=span) def _handle_completed(self, event: NodeEvent) -> None: """Close the matching span, applying §4.2 status mapping.""" + from openarmature.observability.correlation import current_invocation_id + + invocation_id = current_invocation_id() + if invocation_id is None: + return + inv_state = self._inv_states.get(invocation_id) + if inv_state is None: + return + # If this is the fan-out node's own completion AND the # fan-out is configured detached, close all per-instance # detached roots that this fan-out spawned. Done BEFORE the - # regular pop so the OTel current-span context is restored - # to the fan-out span's parent (otherwise inner instance - # roots would still be attached). + # regular pop so the close ordering is parents-after-children. if event.fan_out_index is None and event.namespace and event.namespace[0] in self.detached_fan_outs: - for key in list(self._detached_roots.keys()): + for key in list(inv_state.detached_roots.keys()): if len(key) > len(event.namespace) and key[: len(event.namespace)] == event.namespace: - self._close_detached_root(key) + self._close_detached_root(inv_state, key) key = self._key_for(event) - open_span = self._open_spans.pop(key, None) + open_span = inv_state.open_spans.pop(key, None) if open_span is None: # Started event was never delivered (e.g., observer was # attached mid-invocation). Nothing to close. @@ -314,24 +300,9 @@ def _handle_completed(self, event: NodeEvent) -> None: else: span.set_status(Status(StatusCode.OK)) span.end() - token = open_span.token - if token is not None: - try: - detach(cast("Any", token)) - except ValueError: - # Out-of-LIFO detach: under interleaved - # fan-out-instance events, ``completed`` for an - # earlier-started span can arrive while a - # later-started span's token is on top of the OTel - # context stack. The proper fix is Phase 6.1's - # "don't hold attach tokens across event - # boundaries"; the guard here keeps Phase 6.0 - # robust to the interleave without corrupting - # subsequent span attribution. - pass - # If this was a detached root, drop the root entry so a + # If this was a detached root prefix, drop the root entry so a # subsequent re-entry mints a fresh trace. - self._detached_roots.pop(event.namespace, None) + inv_state.detached_roots.pop(event.namespace, None) # ------------------------------------------------------------------ # Special-event paths @@ -342,9 +313,18 @@ def _emit_checkpoint_save_span(self, event: NodeEvent) -> None: zero-duration ``openarmature.checkpoint.save`` span attached to the most-recently-opened node span (the node whose completed event triggered the save).""" - parent_ctx = self._resolve_parent_context(event) - from openarmature.observability.correlation import current_correlation_id + from openarmature.observability.correlation import ( + current_correlation_id, + current_invocation_id, + ) + invocation_id = current_invocation_id() + if invocation_id is None: + return + inv_state = self._inv_states.get(invocation_id) + if inv_state is None: + return + parent_ctx = self._resolve_parent_context(inv_state, invocation_id, event) attrs: dict[str, Any] = { "openarmature.checkpoint.save_node": event.node_name, } @@ -361,28 +341,30 @@ def _emit_checkpoint_save_span(self, event: NodeEvent) -> None: span.end() def _handle_llm_event(self, event: NodeEvent) -> None: - """LLM provider span per spec §5.5 — parented to the node - span that invoked the provider.""" - # ``pre_state`` is a typed ``_LlmEventState`` Pydantic - # subclass — see - # ``openarmature.llm.providers.openai._LlmEventState``. We - # read attributes directly rather than treating the State - # as a dict, preserving the ``NodeEvent.pre_state: State`` - # contract for any other observers that consume the event. - # Lazy import to avoid the otel→llm package dependency at - # module load time. + """LLM provider span per spec §5.5 — parented to the calling + node's span via the calling-node identity carried on the + ``_LlmEventState`` payload (namespace_prefix + attempt_index + + fan_out_index). Lookup hits the per-invocation_id + ``open_spans`` so concurrent fan-out instances each find + their own calling node, not a sibling's.""" from openarmature.llm.providers.openai import _LlmEventState + from openarmature.observability.correlation import ( + current_correlation_id, + current_invocation_id, + ) if not isinstance(event.pre_state, _LlmEventState): # Defensive — callers other than the OpenAIProvider hook # shouldn't dispatch through the LLM_NAMESPACE sentinel. return + invocation_id = current_invocation_id() + if invocation_id is None: + return + inv_state = self._inv_state_for(invocation_id) payload = event.pre_state if event.phase == "started": - parent_ctx = self._current_span_context() + parent_ctx = self._resolve_llm_parent(inv_state, invocation_id, payload) attrs: dict[str, Any] = {"openarmature.llm.model": payload.model} - from openarmature.observability.correlation import current_correlation_id - cid = current_correlation_id() if cid is not None: attrs["openarmature.correlation_id"] = cid @@ -392,10 +374,9 @@ def _handle_llm_event(self, event: NodeEvent) -> None: kind=SpanKind.CLIENT, attributes=attrs, ) - token = attach(set_span_in_context(span)) - self._open_llm_spans[payload.call_id] = _OpenSpan(span=span, token=token) + inv_state.open_llm_spans[payload.call_id] = _OpenSpan(span=span) elif event.phase == "completed": - open_span = self._open_llm_spans.pop(payload.call_id, None) + open_span = inv_state.open_llm_spans.pop(payload.call_id, None) if open_span is None: return span = open_span.span @@ -419,88 +400,121 @@ def _handle_llm_event(self, event: NodeEvent) -> None: else: span.set_status(Status(StatusCode.OK)) span.end() - try: - detach(cast("Any", open_span.token)) - except ValueError: - # See ``_handle_completed`` for the rationale — - # out-of-LIFO detach under concurrent fan-out - # instance events is a known Phase 6.0 limitation; - # the proper fix is Phase 6.1's - # "don't hold attach tokens across event - # boundaries." - pass + + def _resolve_llm_parent( + self, + inv_state: _InvState, + invocation_id: str, + payload: Any, + ) -> object: + """Look up the calling node's span using the calling-node + identity carried on the LLM event payload, fall back through + subgraph dispatch / invocation span.""" + # 1. Direct match on the calling node's ``_StackKey``. + calling_key: _StackKey = ( + payload.calling_namespace_prefix, + payload.calling_attempt_index, + payload.calling_fan_out_index, + ) + calling = inv_state.open_spans.get(calling_key) + if calling is not None: + return set_span_in_context(calling.span) + # 2. Walk up the calling namespace prefix for a synthetic + # subgraph dispatch span at any ancestor — covers LLM + # calls from inside subgraph wrapper middleware. + prefix = payload.calling_namespace_prefix + for plen in range(len(prefix), 0, -1): + ancestor = prefix[:plen] + sg = inv_state.subgraph_spans.get(ancestor) + if sg is not None: + return set_span_in_context(sg.span) + dr = inv_state.detached_roots.get(ancestor) + if dr is not None: + return set_span_in_context(dr.span) + # 3. Invocation span — ``complete()`` called outside any + # node body but inside an ``invoke()``. + inv = self._invocation_span.get(invocation_id) + if inv is not None: + return set_span_in_context(inv.span) + # 4. No invocation in scope — return a fresh empty Context. + # The span will live in its own trace. + return otel_context.Context() # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ - def _open_invocation_span(self, correlation_id: str, event: NodeEvent) -> None: + def _open_invocation_span( + self, + invocation_id: str, + correlation_id: str | None, + event: NodeEvent, + ) -> None: """Open the root invocation span for a new invocation.""" - from openarmature.observability.correlation import current_invocation_id - - # The first event we receive carries the entry node's - # name; treat it as the invocation's entry_node attribute. - # ``invocation_id`` is set by the engine in ``invoke()`` via - # the ``current_invocation_id`` ContextVar; if the observer - # somehow runs outside an engine invocation (None) we omit - # the attribute rather than emitting a misleading sentinel. attrs: dict[str, Any] = { "openarmature.graph.entry_node": event.node_name, "openarmature.graph.spec_version": self.spec_version, - "openarmature.correlation_id": correlation_id, + "openarmature.invocation_id": invocation_id, } - invocation_id = current_invocation_id() - if invocation_id is not None: - attrs["openarmature.invocation_id"] = invocation_id + if correlation_id is not None: + attrs["openarmature.correlation_id"] = correlation_id span = self._tracer.start_span( name="openarmature.invocation", kind=SpanKind.INTERNAL, attributes=attrs, ) - token = attach(set_span_in_context(span)) - self._invocation_span[correlation_id] = _OpenSpan(span=span, token=token) + self._invocation_span[invocation_id] = _OpenSpan(span=span) def _key_for(self, event: NodeEvent) -> _StackKey: return (event.namespace, event.attempt_index, event.fan_out_index) - def _resolve_parent_context(self, event: NodeEvent) -> object: + def _resolve_parent_context( + self, + inv_state: _InvState, + invocation_id: str, + event: NodeEvent, + ) -> object: """Return the OTel context to use as the parent for this event's span. Walks namespace ancestors finding the - innermost-open subgraph or detached root span.""" + innermost-open subgraph or detached root span; falls back to + the invocation span.""" # 1a. Detached fan-out instance root — keyed by # ``namespace[:1] + (str(fan_out_index),)`` per # ``_open_detached_fan_out_instance_root``. Checked - # explicitly before the generic prefix scan so the - # parent attribution doesn't depend on the - # attach-then-resolve ordering of the surrounding - # ``_sync_subgraph_spans`` call. + # explicitly before the generic prefix scan. if event.fan_out_index is not None and event.namespace: instance_key = event.namespace[:1] + (str(event.fan_out_index),) - if instance_key in self._detached_roots: - root = self._detached_roots[instance_key] + root = inv_state.detached_roots.get(instance_key) + if root is not None: return set_span_in_context(root.span) # 1b. Detached subgraph root at any matching prefix wins # (highest precedence — events inside a detached subtree # always parent under the detached root, never bleed up). for prefix_len in range(len(event.namespace) - 1, -1, -1): prefix = event.namespace[:prefix_len] - if prefix in self._detached_roots: - root = self._detached_roots[prefix] + root = inv_state.detached_roots.get(prefix) + if root is not None: return set_span_in_context(root.span) # 2. Innermost synthetic subgraph span at any prefix. for prefix_len in range(len(event.namespace) - 1, 0, -1): prefix = event.namespace[:prefix_len] - if prefix in self._subgraph_spans: - sg = self._subgraph_spans[prefix] + sg = inv_state.subgraph_spans.get(prefix) + if sg is not None: return set_span_in_context(sg.span) - # 3. Otherwise, current OTel context (typically invocation span). - return self._current_span_context() - - def _current_span_context(self) -> object: - """Return the current OTel context.""" - return otel_context.get_current() - - def _sync_subgraph_spans(self, event: NodeEvent) -> None: + # 3. Otherwise, parent under the invocation span. + inv = self._invocation_span.get(invocation_id) + if inv is not None: + return set_span_in_context(inv.span) + # 4. No invocation in scope — fresh empty Context. + return otel_context.Context() + + def _sync_subgraph_spans( + self, + inv_state: _InvState, + invocation_id: str, + correlation_id: str | None, + event: NodeEvent, + ) -> None: """Open any synthetic subgraph dispatch spans we need (per observability §4.5: subgraph wrapper MUST emit a span); close any subgraph spans whose prefix is no longer an ancestor of @@ -514,13 +528,13 @@ def _sync_subgraph_spans(self, event: NodeEvent) -> None: namespace = event.namespace # 1. Close any open subgraph spans that aren't ancestors of # the current namespace — we've left those subgraphs. - for prefix in list(self._subgraph_spans.keys()): + for prefix in list(inv_state.subgraph_spans.keys()): if not (len(prefix) < len(namespace) and namespace[: len(prefix)] == prefix): - self._close_subgraph_span(prefix) + self._close_subgraph_span(inv_state, prefix) # 2. Same for detached subgraph roots — close ones we've # left. (Detached fan-out instance roots are NOT closed # here; they close on the fan-out's own completion.) - for prefix in list(self._detached_roots.keys()): + for prefix in list(inv_state.detached_roots.keys()): if ( len(prefix) < len(namespace) and namespace[: len(prefix)] == prefix @@ -531,101 +545,90 @@ def _sync_subgraph_spans(self, event: NodeEvent) -> None: # Detached fan-out instance roots: keyed by namespace + # (str(fan_out_index),); leave those alone here, they're # closed when the fan-out parent dispatch completes. - if prefix in self._fan_out_instance_root_prefixes: - # Closed by ``_handle_completed`` on the fan-out - # node's own completion, not here. + if prefix in inv_state.fan_out_instance_root_prefixes: continue if not (len(prefix) < len(namespace) and namespace[: len(prefix)] == prefix): - self._close_detached_root(prefix) + self._close_detached_root(inv_state, prefix) # 3. Open ancestor subgraph spans for any prefix that doesn't # have one yet. for depth in range(1, len(namespace)): prefix = namespace[:depth] - if prefix in self._subgraph_spans: + if prefix in inv_state.subgraph_spans: continue - if prefix in self._detached_roots: + if prefix in inv_state.detached_roots: continue # If this prefix's first segment is configured as a # detached subgraph, mint a fresh trace. if depth == 1 and prefix[0] in self.detached_subgraphs: - self._open_detached_subgraph_root(prefix) + self._open_detached_subgraph_root(inv_state, invocation_id, correlation_id, prefix) continue # If this is a fan-out instance namespace (event.fan_out_index # populated, prefix == namespace[:1]), and the fan-out # node is detached, open a per-instance detached root. if depth == 1 and event.fan_out_index is not None and prefix[0] in self.detached_fan_outs: - self._open_detached_fan_out_instance_root(prefix, event) + self._open_detached_fan_out_instance_root(inv_state, correlation_id, prefix, event) continue - self._open_subgraph_span(prefix) - - def _open_subgraph_span(self, prefix: tuple[str, ...]) -> None: + self._open_subgraph_span(inv_state, invocation_id, correlation_id, prefix) + + def _open_subgraph_span( + self, + inv_state: _InvState, + invocation_id: str, + correlation_id: str | None, + prefix: tuple[str, ...], + ) -> None: """Open a synthetic subgraph dispatch span for the given namespace prefix. Parent is the next-outer subgraph span (or the invocation span if depth-1).""" - from openarmature.observability.correlation import current_correlation_id - - parent_ctx = self._current_span_context() # Walk up looking for the nearest enclosing subgraph or # detached root. + parent_ctx: object = otel_context.Context() for plen in range(len(prefix) - 1, 0, -1): outer = prefix[:plen] - if outer in self._subgraph_spans: - parent_ctx = set_span_in_context(self._subgraph_spans[outer].span) + sg = inv_state.subgraph_spans.get(outer) + if sg is not None: + parent_ctx = set_span_in_context(sg.span) break - if outer in self._detached_roots: - parent_ctx = set_span_in_context(self._detached_roots[outer].span) + dr = inv_state.detached_roots.get(outer) + if dr is not None: + parent_ctx = set_span_in_context(dr.span) break + else: + inv = self._invocation_span.get(invocation_id) + if inv is not None: + parent_ctx = set_span_in_context(inv.span) attrs: dict[str, Any] = { "openarmature.node.name": prefix[-1], "openarmature.subgraph.name": prefix[-1], } - cid = current_correlation_id() - if cid is not None: - attrs["openarmature.correlation_id"] = cid + if correlation_id is not None: + attrs["openarmature.correlation_id"] = correlation_id span = self._tracer.start_span( name=prefix[-1], context=cast("Any", parent_ctx), kind=SpanKind.INTERNAL, attributes=attrs, ) - token = attach(set_span_in_context(span)) - self._subgraph_spans[prefix] = _OpenSpan(span=span, token=token) + inv_state.subgraph_spans[prefix] = _OpenSpan(span=span) - def _close_subgraph_span(self, prefix: tuple[str, ...]) -> None: - open_span = self._subgraph_spans.pop(prefix, None) + def _close_subgraph_span(self, inv_state: _InvState, prefix: tuple[str, ...]) -> None: + open_span = inv_state.subgraph_spans.pop(prefix, None) if open_span is None: return open_span.span.set_status(Status(StatusCode.OK)) open_span.span.end() - # Mirror ``_close_detached_root``: the attach token MUST be - # detached or the OTel current-span context stays pinned to - # a closed span and corrupts parent/child for subsequent - # spans. The cross-context guard handles the case where the - # token was created in a different OTel context (e.g., - # subgraph open/close straddles a detached descent). - if open_span.token is not None: - try: - detach(cast("Any", open_span.token)) - except ValueError: - # Token was created in a different OTel context — - # cross-context detach raises here. The span has - # ended; the leaked context entry is cosmetic and - # unwinds when the worker task exits. - pass - - def _open_detached_subgraph_root(self, prefix: tuple[str, ...]) -> None: + + def _open_detached_subgraph_root( + self, + inv_state: _InvState, + invocation_id: str, + correlation_id: str | None, + prefix: tuple[str, ...], + ) -> None: """Mint a fresh trace for a detached subgraph entry. The detached root span lives in the new trace; the parent trace's dispatch span (synthesized at the same prefix BUT in the - parent trace) carries an OTel Link to this root. - - Implementation: we open BOTH a parent-trace dispatch span - (with the Link) AND a detached-trace root span (the actual - parent for inner events). The dispatch span ends at sync - time when we leave the prefix; the root span ends when its - children finish.""" - from openarmature.observability.correlation import current_correlation_id - + parent trace) carries an OTel Link to this root.""" # 1. Mint the new trace_id + root span_id NOW so the # parent's Link target matches the detached root's # SpanContext exactly. @@ -639,25 +642,27 @@ def _open_detached_subgraph_root(self, prefix: tuple[str, ...]) -> None: trace_flags=TraceFlags(TraceFlags.SAMPLED), ) - # 2. Open the dispatch span in the parent trace. Carries a - # Link pointing at the detached root's SpanContext. - cid = current_correlation_id() + # 2. Open the dispatch span in the parent trace. Parent of + # the dispatch span is the invocation span (or whatever + # was already in scope) per the per-invocation map. + parent_ctx_for_dispatch: object = otel_context.Context() + inv = self._invocation_span.get(invocation_id) + if inv is not None: + parent_ctx_for_dispatch = set_span_in_context(inv.span) attrs_parent: dict[str, Any] = { "openarmature.node.name": prefix[-1], "openarmature.subgraph.name": prefix[-1], } - if cid is not None: - attrs_parent["openarmature.correlation_id"] = cid + if correlation_id is not None: + attrs_parent["openarmature.correlation_id"] = correlation_id parent_dispatch = self._tracer.start_span( name=prefix[-1], - context=cast("Any", self._current_span_context()), + context=cast("Any", parent_ctx_for_dispatch), kind=SpanKind.INTERNAL, links=[Link(detached_sc)], attributes=attrs_parent, ) - # Track in _subgraph_spans so the sync routine closes it on - # leaving the prefix. - self._subgraph_spans[prefix] = _OpenSpan(span=parent_dispatch, token=None) + inv_state.subgraph_spans[prefix] = _OpenSpan(span=parent_dispatch) # 3. Open the detached root span — parented to the synthetic # detached SpanContext so OTel uses the new trace_id. @@ -672,17 +677,20 @@ def _open_detached_subgraph_root(self, prefix: tuple[str, ...]) -> None: kind=SpanKind.INTERNAL, attributes=attrs_root, ) - token = attach(set_span_in_context(detached_root)) - self._detached_roots[prefix] = _OpenSpan(span=detached_root, token=token) - - def _open_detached_fan_out_instance_root(self, prefix: tuple[str, ...], event: NodeEvent) -> None: + inv_state.detached_roots[prefix] = _OpenSpan(span=detached_root) + + def _open_detached_fan_out_instance_root( + self, + inv_state: _InvState, + correlation_id: str | None, + prefix: tuple[str, ...], + event: NodeEvent, + ) -> None: """Per-instance detached root for a configured-detached fan-out. Each instance gets its own trace_id; the fan-out node's span (in the parent trace, already open via the engine's started event) accumulates Links — one per instance.""" - from openarmature.observability.correlation import current_correlation_id - gen = RandomIdGenerator() detached_trace_id = gen.generate_trace_id() detached_root_span_id = gen.generate_span_id() @@ -694,9 +702,14 @@ def _open_detached_fan_out_instance_root(self, prefix: tuple[str, ...], event: N ) # Find the fan-out node's already-open span in the parent - # trace and add a Link to the detached root. - fan_out_key = self._fan_out_node_span_key(prefix) - fan_out_open = self._open_spans.get(fan_out_key) + # trace and add a Link to the detached root. Retry middleware + # wrapping the fan-out bumps its attempt_index, so the span + # sits at ``(prefix, N, None)`` for the in-flight attempt N + # — scan for any entry at ``prefix`` with + # ``fan_out_index is None`` rather than hardcoding the key. + # Only one such entry is open at a time (retry opens and + # closes within a single attempt's lifecycle). + fan_out_open = self._find_fan_out_node_span(inv_state, prefix) if fan_out_open is not None: fan_out_open.span.add_link(detached_sc) @@ -704,75 +717,57 @@ def _open_detached_fan_out_instance_root(self, prefix: tuple[str, ...], event: N detached_parent_ctx = otel_trace.set_span_in_context( NonRecordingSpan(detached_sc), otel_context.Context() ) - cid = current_correlation_id() attrs: dict[str, Any] = { "openarmature.node.name": prefix[-1], "openarmature.fan_out.parent_node_name": prefix[-1], "openarmature.node.fan_out_index": event.fan_out_index, } - if cid is not None: - attrs["openarmature.correlation_id"] = cid + if correlation_id is not None: + attrs["openarmature.correlation_id"] = correlation_id instance_root = self._tracer.start_span( name=prefix[-1], context=cast("Any", detached_parent_ctx), kind=SpanKind.INTERNAL, attributes=attrs, ) - token = attach(set_span_in_context(instance_root)) # Key by prefix + (str(fan_out_index),) so per-instance - # roots stay distinct. Track separately in - # ``_fan_out_instance_root_prefixes`` so - # ``_sync_subgraph_spans`` can identify them by membership - # (rather than by parsing the key shape). + # roots stay distinct. instance_key = prefix + (str(event.fan_out_index),) - self._detached_roots[instance_key] = _OpenSpan(span=instance_root, token=token) - self._fan_out_instance_root_prefixes.add(instance_key) + inv_state.detached_roots[instance_key] = _OpenSpan(span=instance_root) + inv_state.fan_out_instance_root_prefixes.add(instance_key) - def _close_detached_root(self, prefix: tuple[str, ...]) -> None: - self._fan_out_instance_root_prefixes.discard(prefix) - open_span = self._detached_roots.pop(prefix, None) + def _close_detached_root(self, inv_state: _InvState, prefix: tuple[str, ...]) -> None: + inv_state.fan_out_instance_root_prefixes.discard(prefix) + open_span = inv_state.detached_roots.pop(prefix, None) if open_span is None: return open_span.span.set_status(Status(StatusCode.OK)) open_span.span.end() - if open_span.token is not None: - try: - detach(cast("Any", open_span.token)) - except ValueError: - # Cross-context detach (token created in a different - # OTel context) — ignore. The span has ended; the - # context entry leaks cosmetically. - pass @staticmethod def _drain_open_span(open_span: _OpenSpan) -> None: """Close an open span as an orphan during shutdown: OK - status, end, and try-detach the token. No paired completed - event will arrive, so we don't have an error category to - record. Cross-context detach is swallowed (the worker that - owns the context eventually unwinds).""" + status, end. No paired completed event will arrive, so we + don't have an error category to record.""" open_span.span.set_status(Status(StatusCode.OK)) open_span.span.end() - if open_span.token is not None: - try: - detach(cast("Any", open_span.token)) - except ValueError: - # Cross-context detach during shutdown — token - # created in a different OTel context. Span has - # ended; ignore. - pass - - def _fan_out_node_span_key(self, prefix: tuple[str, ...]) -> _StackKey: - """Build the lookup key for a fan-out node's own span (the - parent dispatch span). Fan-out node has no attempt_index ≠ 0 - and no fan_out_index — those fields belong to its inner - instances.""" - return (prefix, 0, None) - - def _node_attrs(self, event: NodeEvent) -> dict[str, Any]: - """Build the §5 attribute set for a node span.""" - from openarmature.observability.correlation import current_correlation_id + def _find_fan_out_node_span(self, inv_state: _InvState, prefix: tuple[str, ...]) -> _OpenSpan | None: + """Find the currently-open fan-out node's parent dispatch + span at ``prefix`` regardless of ``attempt_index``. Under + retry middleware wrapping the fan-out, the in-flight + attempt's span lives at ``(prefix, attempt_index, None)``; + only one such entry is open at a time (retry opens and + closes within each attempt's lifecycle), so a scan finds it + unambiguously.""" + for key, open_span in inv_state.open_spans.items(): + ns, _attempt, fan_idx = key + if ns == prefix and fan_idx is None: + return open_span + return None + + def _node_attrs(self, event: NodeEvent, correlation_id: str | None) -> dict[str, Any]: + """Build the §5 attribute set for a node span.""" attrs: dict[str, Any] = { "openarmature.node.name": event.node_name, "openarmature.node.namespace": list(event.namespace), @@ -781,41 +776,67 @@ def _node_attrs(self, event: NodeEvent) -> dict[str, Any]: } if event.fan_out_index is not None: attrs["openarmature.node.fan_out_index"] = event.fan_out_index - cid = current_correlation_id() - if cid is not None: - attrs["openarmature.correlation_id"] = cid + if correlation_id is not None: + attrs["openarmature.correlation_id"] = correlation_id return attrs # ------------------------------------------------------------------ # Lifecycle # ------------------------------------------------------------------ - def close_invocation(self, correlation_id: str) -> None: - """Public lifecycle hook: close the invocation span for - ``correlation_id``. Idempotent — calling twice (or for a - correlation_id with no open span) is a no-op. - - Long-lived observers shared across many invocations - automatically close prior invocation spans on the first - event of a new invocation (see ``_handle_started``). This - method is for callers who want explicit control without - driving a follow-on invocation, e.g.:: - - await graph.invoke(state, correlation_id=cid) - await graph.drain() - otel_observer.close_invocation(cid) + def close_invocation(self, invocation_id: str) -> None: + """Close the invocation span for ``invocation_id`` and drain + the per-invocation state. Idempotent — calling twice (or for + an invocation_id with no open span) is a no-op. + + Drains any still-open spans in the per-invocation state in + child→parent order (LLM spans → leaf spans → detached roots + → subgraph dispatch → invocation). + + **Sourcing the invocation_id.** ``CompiledGraph.invoke()`` + does not currently return the invocation_id, and the + ``current_invocation_id`` ContextVar is reset before control + returns to the caller. The practical use case for this + method is test code that captures the invocation_id from + inside a node body (or middleware / observer callback), + debugging scenarios, and integration code that has the id + from a checkpoint record's ``invocation_id`` field. + + For typical production lifecycle on long-lived observers, + prefer :meth:`shutdown` — it drains every in-flight + invocation in one call without needing to track ids + externally. A first-class engine-level signal that lets + observers auto-drain per-invocation state on completion is + tracked as Phase 6.1+ follow-up work in + ``openarmature-coord/docs/phase-6-1-conformance-fillin.md``. """ - self._close_invocation_span(correlation_id) - - def _close_invocation_span(self, correlation_id: str) -> None: - """End and remove the invocation span for ``correlation_id``. - - The OTel context token captured when the span opened was - created in the worker task's context — we don't try to - ``detach()`` it (cross-context detach raises ValueError; - the leaked context entry is cosmetic since the worker - eventually exits).""" - open_span = self._invocation_span.pop(correlation_id, None) + inv_state = self._inv_states.pop(invocation_id, None) + if inv_state is not None: + self._drain_inv_state(inv_state) + self._close_invocation_span(invocation_id) + + def _drain_inv_state(self, inv_state: _InvState) -> None: + """Close any still-open spans in a per-invocation state + container in child→parent order. LLM spans (deepest leaves) + → leaf node spans (sorted deepest-first by namespace) → + detached roots → subgraph dispatch spans. Matches the + ordering used in ``shutdown``.""" + for call_id in list(inv_state.open_llm_spans.keys()): + open_span = inv_state.open_llm_spans.pop(call_id, None) + if open_span is not None: + self._drain_open_span(open_span) + for key in sorted(inv_state.open_spans.keys(), key=lambda k: -len(k[0])): + open_span = inv_state.open_spans.pop(key, None) + if open_span is not None: + self._drain_open_span(open_span) + for prefix in sorted(inv_state.detached_roots.keys(), key=lambda k: -len(k)): + self._close_detached_root(inv_state, prefix) + for prefix in sorted(inv_state.subgraph_spans.keys(), key=lambda k: -len(k)): + self._close_subgraph_span(inv_state, prefix) + + def _close_invocation_span(self, invocation_id: str) -> None: + """End and remove the invocation span for ``invocation_id``.""" + open_span = self._invocation_span.pop(invocation_id, None) if open_span is None: return # Status defaults to OK for completed invocations; if the @@ -826,30 +847,16 @@ def _close_invocation_span(self, correlation_id: str) -> None: open_span.span.end() def shutdown(self) -> None: - """Close any still-open spans and shut down the underlying - provider. Walks state maps in child→parent order — LLM - spans (deepest leaves) before leaf node spans before - subgraph dispatch / detached roots before invocation spans - — and within each depth-bearing map sorts by namespace - depth (deepest first) so children end before their - parents. ``_open_llm_spans`` (call_id keys) and - ``_invocation_span`` (one entry per correlation_id) carry - no internal nesting and drain in insertion order. - Idempotent.""" - for call_id in list(self._open_llm_spans.keys()): - open_span = self._open_llm_spans.pop(call_id, None) - if open_span is not None: - self._drain_open_span(open_span) - for key in sorted(self._open_spans.keys(), key=lambda k: -len(k[0])): - open_span = self._open_spans.pop(key, None) - if open_span is not None: - self._drain_open_span(open_span) - for prefix in sorted(self._detached_roots.keys(), key=lambda k: -len(k)): - self._close_detached_root(prefix) - for prefix in sorted(self._subgraph_spans.keys(), key=lambda k: -len(k)): - self._close_subgraph_span(prefix) - for cid in list(self._invocation_span.keys()): - self._close_invocation_span(cid) + """Close any still-open spans across all in-flight + invocations and shut down the underlying provider. Each + per-invocation state is drained in child→parent order (LLM + spans → leaf spans → detached roots → subgraph dispatch); + invocation spans drain last. Idempotent.""" + for invocation_id in list(self._inv_states.keys()): + inv_state = self._inv_states.pop(invocation_id) + self._drain_inv_state(inv_state) + for invocation_id in list(self._invocation_span.keys()): + self._close_invocation_span(invocation_id) self._provider.shutdown() diff --git a/tests/unit/test_observability_otel.py b/tests/unit/test_observability_otel.py index 4cf474c..19b53a5 100644 --- a/tests/unit/test_observability_otel.py +++ b/tests/unit/test_observability_otel.py @@ -28,8 +28,10 @@ # Skip the entire module if otel extras aren't installed. pytest.importorskip("opentelemetry.sdk.trace") +from typing import cast + from opentelemetry import trace as otel_trace -from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace import ReadableSpan, TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( InMemorySpanExporter, @@ -365,3 +367,390 @@ def test_install_log_bridge_is_idempotent() -> None: # not leak state into others. root.handlers[:] = prior_handlers root.filters[:] = prior_filters + + +# --------------------------------------------------------------------------- +# Phase 6.1: concurrency-safe state scoping + §5.5 calling-node attribution +# --------------------------------------------------------------------------- + + +async def test_shared_observer_concurrent_invocations_dont_collide() -> None: + """A single observer shared across concurrent invocations MUST + keep their span trees isolated. Per spec §5.1 each invocation + has its own ``invocation_id`` and therefore its own + ``trace_id``; with shared internal state keyed by + ``invocation_id`` the observer no longer collides on overlapping + namespaces, no longer closes another in-flight invocation's span + on a new event, and produces N distinct trace_ids for N + concurrent invocations on the same compiled graph.""" + import asyncio + + exporter = InMemorySpanExporter() + observer = OTelObserver(span_processor=SimpleSpanProcessor(exporter)) + g = ( + GraphBuilder(_LinearState) + .add_node("node_a", _node_a) + .add_node("node_b", _node_b) + .add_edge("node_a", "node_b") + .add_edge("node_b", END) + .set_entry("node_a") + .compile() + ) + g.attach_observer(observer) + + n = 5 + results = await asyncio.gather(*[g.invoke(_LinearState()) for _ in range(n)]) + await g.drain() + observer.shutdown() + assert len(results) == n + + spans = exporter.get_finished_spans() + invocation_spans = [s for s in spans if s.name == "openarmature.invocation"] + assert len(invocation_spans) == n, ( + f"expected one invocation span per concurrent invocation; got {len(invocation_spans)}" + ) + # Each invocation has its own trace_id. + trace_ids: set[int] = set() + for s in invocation_spans: + assert s.context is not None + trace_ids.add(s.context.trace_id) + assert len(trace_ids) == n, ( + f"each concurrent invocation MUST have its own trace_id; got {len(trace_ids)} for {n} invocations" + ) + # Every span in the export belongs to one of those trace_ids + # (no orphans pointing at a stale trace). + for s in spans: + assert s.context is not None + assert s.context.trace_id in trace_ids, ( + f"span {s.name!r} carries unknown trace_id {s.context.trace_id}" + ) + # Each trace has the expected node count: one invocation span + + # node_a + node_b = 3 spans. + by_trace: dict[int, list[str]] = {tid: [] for tid in trace_ids} + for s in spans: + assert s.context is not None + by_trace[s.context.trace_id].append(s.name) + for tid, names_list in by_trace.items(): + names = sorted(names_list) + assert names == ["node_a", "node_b", "openarmature.invocation"], ( + f"trace {tid:x} span set MUST be exactly the invocation + node_a + node_b; got {names}" + ) + + +async def test_concurrent_fan_out_no_lifo_violation() -> None: + """Regression check: under fan-out with multiple concurrent + instances, started/completed events for different instances + interleave on the observer's call queue. The Phase 6.0 + architecture used cross-event ``opentelemetry.context.attach`` + tokens that produced LIFO violations on out-of-order detach + (suppressed by try/except guards in round-4 / round-7). Phase + 6.1 derives parents from internal maps within a single event + handler — no tokens cross event boundaries — so the underlying + hazard goes away. This test drives a fan-out with three + instances and asserts the run completes without the warnings + that the suppressed guards would have produced.""" + import warnings + + class _ParentState(State): + items: list[int] = [] + results: list[int] = [] + + class _ChildState(State): + item: int = 0 + out: int = 0 + + async def _double(s: _ChildState) -> dict[str, int]: + # Yield to give other instances a chance to interleave their + # started/completed events on the observer queue. + import asyncio + + await asyncio.sleep(0) + return {"out": s.item * 2} + + inner = ( + GraphBuilder(_ChildState) + .add_node("double", _double) + .add_edge("double", END) + .set_entry("double") + .compile() + ) + parent = ( + GraphBuilder(_ParentState) + .add_fan_out_node( + "fan", + subgraph=inner, + collect_field="out", + target_field="results", + items_field="items", + item_field="item", + concurrency=3, + ) + .add_edge("fan", END) + .set_entry("fan") + ) + exporter = InMemorySpanExporter() + observer = OTelObserver(span_processor=SimpleSpanProcessor(exporter)) + compiled = parent.compile() + compiled.attach_observer(observer) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + result = await compiled.invoke(_ParentState(items=[1, 2, 3, 4, 5])) + await compiled.drain() + observer.shutdown() + + assert result.results == [2, 4, 6, 8, 10] + # Sanity: per-instance node spans landed (one ``double`` span + # per item, all sharing the same trace_id since the fan-out is + # not configured detached). + spans = exporter.get_finished_spans() + double_spans = [s for s in spans if s.name == "double"] + assert len(double_spans) == 5, f"expected 5 per-instance node spans; got {len(double_spans)}" + + +async def test_concurrent_fan_out_llm_spans_parent_under_calling_instance() -> None: + """Spec §5.5 under concurrent fan-out: each instance's + ``openarmature.llm.complete`` span MUST parent under that + instance's calling node, not a sibling instance's. The Phase 6.1 + calling-node identity (namespace_prefix + attempt_index + + fan_out_index threaded via ContextVar onto the LLM event + payload) is what makes this attribution correct.""" + import asyncio + + import httpx + + from openarmature.llm.messages import UserMessage + from openarmature.llm.providers.openai import OpenAIProvider + + def _ok(_req: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={ + "id": "x", + "object": "chat.completion", + "created": 0, + "model": "m", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "ok"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + }, + ) + + provider = OpenAIProvider( + base_url="http://test", + model="m", + api_key="k", + transport=httpx.MockTransport(_ok), + ) + + class _ParentState(State): + items: list[int] = [] + outs: list[str] = [] + + class _ChildState(State): + item: int = 0 + out: str = "" + + async def _ask(s: _ChildState) -> dict[str, str]: + # Yield first so peer instances can interleave. + await asyncio.sleep(0) + resp = await provider.complete([UserMessage(content=str(s.item))]) + return {"out": str(resp.message.content or "")} + + inner = GraphBuilder(_ChildState).add_node("ask", _ask).add_edge("ask", END).set_entry("ask").compile() + parent = ( + GraphBuilder(_ParentState) + .add_fan_out_node( + "fan", + subgraph=inner, + collect_field="out", + target_field="outs", + items_field="items", + item_field="item", + concurrency=4, + ) + .add_edge("fan", END) + .set_entry("fan") + ) + exporter = InMemorySpanExporter() + observer = OTelObserver(span_processor=SimpleSpanProcessor(exporter)) + compiled = parent.compile() + compiled.attach_observer(observer) + + n = 4 + try: + await compiled.invoke(_ParentState(items=list(range(n)))) + await compiled.drain() + finally: + await provider.aclose() + observer.shutdown() + + spans = exporter.get_finished_spans() + by_id: dict[int, ReadableSpan] = {} + for s in spans: + assert s.context is not None + by_id[s.context.span_id] = s + llm_spans = [s for s in spans if s.name == "openarmature.llm.complete"] + ask_spans = [s for s in spans if s.name == "ask"] + assert len(llm_spans) == n, f"expected one LLM span per instance; got {len(llm_spans)}" + assert len(ask_spans) == n, f"expected one ``ask`` span per instance; got {len(ask_spans)}" + + # Build a map from fan_out_index → ask span_id (each instance's + # node carries its own ``openarmature.node.fan_out_index`` attribute). + ask_by_index: dict[int, int] = {} + for s in ask_spans: + assert s.context is not None and s.attributes is not None + idx_attr = s.attributes["openarmature.node.fan_out_index"] + assert isinstance(idx_attr, int) + ask_by_index[idx_attr] = s.context.span_id + assert set(ask_by_index.keys()) == set(range(n)) + + # For each LLM span, confirm the parent span_id is one of the + # ``ask`` spans (calling instance's node), not a sibling + # fan-out instance's span. + parented_ask_ids: set[int] = set() + for llm in llm_spans: + assert llm.parent is not None, "LLM span MUST have a parent" + parent_span = by_id.get(llm.parent.span_id) + assert parent_span is not None, f"LLM span parent_id {llm.parent.span_id} not in exported set" + assert parent_span.name == "ask", ( + f"LLM span MUST parent under ``ask`` (the calling node), got {parent_span.name!r}" + ) + parented_ask_ids.add(llm.parent.span_id) + + # Every LLM span parents under a UNIQUE ``ask`` span — i.e., no + # collision where two LLM calls attributed to the same instance. + assert len(parented_ask_ids) == n, ( + f"each LLM call MUST parent under its own calling instance; " + f"got {len(parented_ask_ids)} distinct parents for {n} calls" + ) + + +async def test_llm_call_inside_retried_node_parents_per_attempt() -> None: + """Spec §5.5 under retry: when an LLM ``complete()`` call + happens inside a node body wrapped with retry middleware, each + attempt's LLM span MUST parent under THAT attempt's node span, + not a hardcoded ``attempt_index=0``. Phase 6.1's + ``current_attempt_index`` ContextVar (set inside the per-attempt + ``innermost`` scope) is what makes this work.""" + import httpx + + from openarmature.graph.middleware import RetryMiddleware + from openarmature.llm.errors import ProviderRateLimit + from openarmature.llm.messages import UserMessage + from openarmature.llm.providers.openai import OpenAIProvider + + def _ok(_req: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={ + "id": "x", + "object": "chat.completion", + "created": 0, + "model": "m", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "ok"}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + }, + ) + + provider = OpenAIProvider( + base_url="http://test", + model="m", + api_key="k", + transport=httpx.MockTransport(_ok), + ) + + class _S(State): + attempts: int = 0 + + # Mutable counter so the node body can observe its own attempt + # index and decide whether to fail. Two failures + one success. + flaky_state = {"calls": 0} + + async def _flaky(s: _S) -> dict[str, int]: + flaky_state["calls"] += 1 + # Always issue an LLM call BEFORE the conditional raise so a + # span fires for every attempt, including the failing ones. + await provider.complete([UserMessage(content="hi")]) + if flaky_state["calls"] < 3: + raise ProviderRateLimit("transient") + return {"attempts": flaky_state["calls"]} + + g = ( + GraphBuilder(_S) + .add_node("flaky", _flaky, middleware=[RetryMiddleware(max_attempts=3, backoff=lambda _i: 0.0)]) + .add_edge("flaky", END) + .set_entry("flaky") + .compile() + ) + exporter = InMemorySpanExporter() + observer = OTelObserver(span_processor=SimpleSpanProcessor(exporter)) + g.attach_observer(observer) + + try: + result = await g.invoke(_S()) + await g.drain() + finally: + await provider.aclose() + observer.shutdown() + + assert result.attempts == 3 + spans = exporter.get_finished_spans() + by_id: dict[int, ReadableSpan] = {} + for s in spans: + assert s.context is not None + by_id[s.context.span_id] = s + + # Three ``flaky`` spans (one per attempt), three LLM spans. + flaky_spans = [s for s in spans if s.name == "flaky"] + llm_spans = [s for s in spans if s.name == "openarmature.llm.complete"] + assert len(flaky_spans) == 3, f"expected 3 attempt spans; got {len(flaky_spans)}" + assert len(llm_spans) == 3, f"expected 3 LLM spans; got {len(llm_spans)}" + + # Map attempt_index → flaky span_id. + flaky_by_attempt: dict[int, int] = {} + for s in flaky_spans: + assert s.context is not None and s.attributes is not None + idx = s.attributes["openarmature.node.attempt_index"] + assert isinstance(idx, int) + flaky_by_attempt[idx] = s.context.span_id + assert set(flaky_by_attempt.keys()) == {0, 1, 2} + + # Every LLM span MUST parent under one of the ``flaky`` spans + # (NOT under the invocation span, which would mean + # attempt_index=0 was hardcoded and the lookup fell through). + flaky_span_ids = set(flaky_by_attempt.values()) + parented_under: set[int] = set() + for llm in llm_spans: + assert llm.parent is not None, "LLM span MUST have a parent" + parented_under.add(llm.parent.span_id) + assert parented_under <= flaky_span_ids, ( + f"every LLM span MUST parent under an attempt's ``flaky`` span; " + f"got LLM parents {parented_under} not all in flaky set {flaky_span_ids}" + ) + # And the THREE LLM spans parent under THREE DISTINCT ``flaky`` + # spans — one per attempt — proving the calling_attempt_index + # threading actually disambiguates per-attempt. + assert len(parented_under) == 3, ( + f"each attempt's LLM call MUST parent under its OWN attempt's span; " + f"got {len(parented_under)} distinct parents for 3 LLM calls" + ) + # Spot-check: every attempt is represented. + parented_attempts: set[int] = set() + for pid in parented_under: + attrs = by_id[pid].attributes + assert attrs is not None + idx = cast("int", attrs["openarmature.node.attempt_index"]) + parented_attempts.add(idx) + assert parented_attempts == {0, 1, 2}