Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 126 additions & 44 deletions src/openarmature/graph/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@
_reset_invocation_id,
_reset_namespace_prefix,
_set_active_dispatch,
_set_active_observer_span,
_set_active_observers,
_set_attempt_index,
_set_correlation_id,
_set_fan_out_index,
_set_invocation_id,
_set_namespace_prefix,
current_active_observer_span,
)

from .edges import END, ConditionalEdge, EndSentinel, StaticEdge
Expand Down Expand Up @@ -99,6 +101,54 @@
from .state import State
from .subgraph import SubgraphNode

# Try-import OpenTelemetry attach primitives so the engine can splice an
# observer-published span into the OTel context for the duration of a
# node body. The engine treats the span value opaquely (writes by an
# observer's ``prepare_sync``, reads via ``current_active_observer_span``)
# and only touches OTel when both: (a) the extras are installed, and
# (b) an observer actually published a span. Installs without ``[otel]``
# get a no-op attach/detach pair; the observer ContextVar stays
# ``None`` and nothing changes.
#
# The names are bound to ``None`` in the except branch so pyright
# narrows correctly at call sites (``if _otel_attach is None: ...``)
# rather than flagging "possibly unbound."
try:
from opentelemetry.context import attach as _otel_attach
from opentelemetry.context import detach as _otel_detach
from opentelemetry.trace.propagation import set_span_in_context as _otel_set_span_in_context
except ImportError: # pragma: no cover — exercised only in non-otel installs
_otel_attach = None # type: ignore[assignment]
_otel_detach = None # type: ignore[assignment]
_otel_set_span_in_context = None # type: ignore[assignment]


def _attach_active_observer_span() -> object | None:
"""Read ``current_active_observer_span``; if an observer published
one and OTel is installed, attach the span into the OTel context
so that any logs emitted from the next user-code scope (a node
body) pick up the right ``trace_id``/``span_id`` via OTel's
``LoggingHandler``.

Returns the OTel context token to hand back to
:func:`_detach_active_observer_span` in ``finally``, or ``None``
if no attach happened (no observer, no OTel, or both).
"""
if _otel_attach is None or _otel_set_span_in_context is None:
return None
span = current_active_observer_span()
if span is None:
return None
return _otel_attach(_otel_set_span_in_context(cast("Any", span)))


def _detach_active_observer_span(token: object | None) -> None:
"""Pair to :func:`_attach_active_observer_span`. No-op when no
attach was performed (token is ``None``)."""
if token is None or _otel_detach is None:
return
_otel_detach(cast("Any", token))


def _merge_partial[StateT: State](
prior: StateT,
Expand Down Expand Up @@ -690,20 +740,35 @@ async def innermost(s: Any) -> Mapping[str, Any]:
try:
self._dispatch_started(context, current, namespace, step, s, attempt_index=attempt_index)

# Splice the observer-published span (if any) into the
# OTel context so logs emitted from the FIRST line of
# the node body — before any ``await`` — pick up the
# right trace_id/span_id via OTel's LoggingHandler.
# Detach in ``finally`` so retries / merge / completed
# dispatch don't run with the span still active, and
# clear ``current_active_observer_span`` to ``None`` so
# the next dispatch that raises or early-returns from
# ``prepare_sync`` can't reveal this node's span as a
# stale value to the engine's read.
otel_token = _attach_active_observer_span()
try:
partial = await node.run(s)
except Exception as e:
wrapped = NodeException(node_name=current, cause=e, recoverable_state=s)
self._dispatch_completed(
context,
current,
namespace,
step,
s,
error=wrapped,
attempt_index=attempt_index,
)
raise
try:
partial = await node.run(s)
except Exception as e:
wrapped = NodeException(node_name=current, cause=e, recoverable_state=s)
self._dispatch_completed(
context,
current,
namespace,
step,
s,
error=wrapped,
attempt_index=attempt_index,
)
raise
finally:
_detach_active_observer_span(otel_token)
Comment thread
chris-colinsky marked this conversation as resolved.
_set_active_observer_span(None)

try:
merged = _merge_partial(s, partial, self.reducers, current)
Expand Down Expand Up @@ -1045,38 +1110,55 @@ async def innermost(s: Any) -> Mapping[str, Any]:
attempt_index=attempt_index,
fan_out_config=fan_out_event_config,
)
# Same OTel attach pattern as ``_step_function_node``'s
# ``innermost`` — splice the observer-published span
# into the OTel context so logs emitted from inside
# the fan-out node's own scope (middleware bodies,
# the dispatch machinery) carry the right
# trace_id/span_id. Per-instance bodies get their own
# attach inside their ``_step_function_node``
# innermost when the recursive invocation hits leaf
# nodes. ``finally`` clears the ContextVar so a later
# dispatch whose ``prepare_sync`` raises or early-
# returns can't reveal this fan-out's span as a stale
# value to the engine's read.
otel_token = _attach_active_observer_span()
try:
partial = await node.run_with_context(
s,
context,
pre_resolved_count=item_count,
pre_resolved_concurrency=(concurrency_resolved,),
)
except RuntimeGraphError as e:
self._dispatch_completed(
context,
current,
namespace,
step,
s,
error=e,
attempt_index=attempt_index,
fan_out_config=fan_out_event_config,
)
raise
except Exception as e:
wrapped = NodeException(node_name=current, cause=e, recoverable_state=s)
self._dispatch_completed(
context,
current,
namespace,
step,
s,
error=wrapped,
attempt_index=attempt_index,
fan_out_config=fan_out_event_config,
)
raise wrapped from e
try:
partial = await node.run_with_context(
s,
context,
pre_resolved_count=item_count,
pre_resolved_concurrency=(concurrency_resolved,),
)
except RuntimeGraphError as e:
self._dispatch_completed(
context,
current,
namespace,
step,
s,
error=e,
attempt_index=attempt_index,
fan_out_config=fan_out_event_config,
)
raise
except Exception as e:
wrapped = NodeException(node_name=current, cause=e, recoverable_state=s)
self._dispatch_completed(
context,
current,
namespace,
step,
s,
error=wrapped,
attempt_index=attempt_index,
fan_out_config=fan_out_event_config,
)
raise wrapped from e
finally:
_detach_active_observer_span(otel_token)
Comment thread
chris-colinsky marked this conversation as resolved.
_set_active_observer_span(None)

try:
merged = _merge_partial(s, partial, self.reducers, current)
Expand Down
93 changes: 93 additions & 0 deletions src/openarmature/graph/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from __future__ import annotations

import asyncio
import inspect
import warnings
from collections.abc import Iterable
from dataclasses import dataclass, field
Expand Down Expand Up @@ -60,6 +61,25 @@ async def log_observer(event: NodeEvent) -> None:
The event parameter is positional-only (`event, /`) so structural
conformance doesn't pin you to that name — any of `event`, `_event`,
`e`, etc. matches.

Optional ``prepare_sync`` extension
-----------------------------------
An observer MAY additionally define a synchronous method::

def prepare_sync(self, event: NodeEvent, /) -> None: ...

that the engine calls IN THE ENGINE TASK, BEFORE queueing the
event for the async ``__call__``. This exists for observers that
need to set up state — e.g., open a span and stash a handle in
a ContextVar — that the engine itself must read synchronously
before running the node body (otherwise logs emitted on the
first line of the body wouldn't see the right span).

``prepare_sync`` is **opt-in via ``hasattr``** — no subclass or
Protocol method required. Observers that don't define it skip
the synchronous prep entirely; observers that do define it run
only for ``"started"``-phase events, errors warned not propagated
(same isolation contract as the async path per spec §6).
"""

async def __call__(self, event: NodeEvent, /) -> None: ...
Expand Down Expand Up @@ -344,12 +364,85 @@ def take_step(self) -> int:
def _dispatch(context: _InvocationContext, event: NodeEvent) -> None:
"""Enqueue a node event for the delivery worker.

For ``"started"``-phase events, also call any subscribed observer's
optional ``prepare_sync(event)`` synchronously — in the engine task,
BEFORE queueing — so observers that need to publish per-event state
the engine itself reads in the same engine-task scope (e.g., the
OTel observer setting ``current_active_observer_span`` for the
engine to attach into the OTel context) can do so before the node
body runs.

Phase-gated forwarding: ``prepare_sync`` only fires when ``"started"``
is in the subscribed observer's ``phases`` set, mirroring how the
async ``deliver_loop`` filters dispatch. A user who explicitly
subscribes only to ``{"completed"}`` doesn't get the synchronous
prep — the wrapper acts as a uniform phase shield across both
sync prep and async dispatch.

Errors from ``prepare_sync`` follow the same isolation contract as
the async path per spec §6: don't propagate, don't break siblings,
don't block the queueing or subsequent events. Reported via
``warnings.warn``.

No-op when no observers exist for this depth — avoids paying the queue
overhead for graphs that don't observe anything.
"""
observers = context.full_observers()
if not observers:
return
if event.phase == "started":
for subscribed in observers:
if "started" not in subscribed.phases:
continue
prepare_sync = getattr(subscribed.observer, "prepare_sync", None)
if prepare_sync is None:
continue
try:
result = prepare_sync(event)
except Exception as e:
warnings.warn(
f"observer prepare_sync raised {type(e).__name__}: {e}",
stacklevel=2,
)
continue
if inspect.isawaitable(result):
# ``prepare_sync`` is opt-in via ``hasattr`` (not a
# Protocol method) so pyright can't catch a user's
# ``async def prepare_sync`` signature drift up front.
# The call here would silently return an unawaited
# coroutine — the prep work wouldn't run AND Python
# would emit a delayed "coroutine was never awaited"
# warning at GC time. Close the awaitable to suppress
# that secondary noise and surface the misconfiguration
# via our own explicit warn so it fails loudly at the
# call site. ``getattr`` rather than ``hasattr``+method
# access keeps pyright's strict-mode happy on the
# ``Awaitable`` type (``.close`` lives on
# ``Coroutine``, not the broader ``Awaitable``).
close_method = getattr(result, "close", None)
if close_method is not None:
try:
close_method()
except Exception as close_error:
# Cleanup is best-effort: a raise here MUST NOT
# propagate or block sibling observers. Surface
# via ``warnings.warn`` so the swallow is at
# least observable if it ever fires (CodeQL
# py/empty-except clears on this surface too).
warnings.warn(
f"observer prepare_sync close cleanup raised "
f"{type(close_error).__name__}: {close_error}",
stacklevel=2,
)
warnings.warn(
f"observer prepare_sync returned an awaitable "
f"({type(result).__name__}); prepare_sync MUST be sync "
f"(define as `def`, not `async def`). The returned "
f"awaitable will not be awaited and is NOT guaranteed "
f"to complete before the node body starts; log "
f"correlation may miss this node's span.",
stacklevel=2,
)
context.queue.put_nowait(_QueuedItem(event=event, observers=observers))


Expand Down
Loading
Loading