Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 118 additions & 78 deletions src/openarmature/graph/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,18 @@
from openarmature.observability.correlation import (
_reset_active_dispatch,
_reset_active_observers,
_reset_attempt_index,
_reset_correlation_id,
_reset_fan_out_index,
_reset_invocation_id,
_reset_namespace_prefix,
_set_active_dispatch,
_set_active_observers,
_set_attempt_index,
_set_correlation_id,
_set_fan_out_index,
_set_invocation_id,
_set_namespace_prefix,
)

from .edges import END, ConditionalEdge, EndSentinel, StaticEdge
Expand Down Expand Up @@ -558,51 +564,59 @@ async def innermost(s: Any) -> Mapping[str, Any]:
attempt_index = attempt_counter[0]
attempt_counter[0] += 1

self._dispatch_started(context, current, namespace, step, s, attempt_index=attempt_index)

# Calling-node identity for capability backends emitting
# from inside this attempt's scope (e.g., LLM provider's
# span hook). Per-attempt scope so retry middleware that
# re-enters innermost bumps the visible attempt_index.
attempt_token = _set_attempt_index(attempt_index)
try:
partial = await node.run(s)
except Exception as e:
wrapped = NodeException(node_name=current, cause=e, recoverable_state=s)
self._dispatch_completed(
context,
current,
namespace,
step,
s,
error=wrapped,
attempt_index=attempt_index,
)
raise
self._dispatch_started(context, current, namespace, step, s, attempt_index=attempt_index)

try:
partial = await node.run(s)
except Exception as e:
wrapped = NodeException(node_name=current, cause=e, recoverable_state=s)
self._dispatch_completed(
context,
current,
namespace,
step,
s,
error=wrapped,
attempt_index=attempt_index,
)
raise

try:
merged = _merge_partial(s, partial, self.reducers, current)
except (ReducerError, StateValidationError) as e:
self._dispatch_completed(
context,
current,
namespace,
step,
s,
error=e,
attempt_index=attempt_index,
)
raise

try:
merged = _merge_partial(s, partial, self.reducers, current)
except (ReducerError, StateValidationError) as e:
self._dispatch_completed(
context,
current,
namespace,
step,
s,
error=e,
post_state=merged,
attempt_index=attempt_index,
)
raise

self._dispatch_completed(
context,
current,
namespace,
step,
s,
post_state=merged,
attempt_index=attempt_index,
)
# Return the partial (not the merged state) so middleware sees
# the partial-update shape per pipeline-utilities §2. The
# engine's canonical merge against the original state happens
# below, after the chain returns.
return partial
# Return the partial (not the merged state) so middleware sees
# the partial-update shape per pipeline-utilities §2. The
# engine's canonical merge against the original state happens
# below, after the chain returns.
return partial
finally:
_reset_attempt_index(attempt_token)

chain: ChainCall = compose_chain(
list(self.middleware) + list(node.middleware),
Expand All @@ -612,12 +626,17 @@ async def innermost(s: Any) -> Mapping[str, Any]:
# Spec observability §3 / Phase 6 LLM-span hook: capability
# backends emitting from inside a node body (the
# llm-provider span instrumentation in OpenAIProvider) need
# to find the observers active for THIS invocation. Set the
# ContextVar around the chain invocation; reset in
# ``try/finally`` so an exception escaping the chain still
# restores the prior value.
# to find the observers active for THIS invocation, which
# node is calling, and which fan-out instance (if any) the
# call belongs to. ``namespace_prefix`` and ``fan_out_index``
# are set in this outer scope (per-node, not per-attempt);
# ``attempt_index`` is set inside ``innermost`` per attempt.
# All four reset in ``try/finally`` so an exception escaping
# the chain still restores the prior values.
observers_token = _set_active_observers(context.full_observers())
dispatch_token = _set_active_dispatch(lambda event: _dispatch(context, event))
namespace_token = _set_namespace_prefix(namespace)
fan_out_token = _set_fan_out_index(context.fan_out_index)
try:
try:
final_partial = await chain(state)
Expand All @@ -628,6 +647,8 @@ async def innermost(s: Any) -> Mapping[str, Any]:
# the chain unrecovered. Wrap as NodeException per §4.
raise NodeException(node_name=current, cause=e, recoverable_state=state) from e
finally:
_reset_fan_out_index(fan_out_token)
_reset_namespace_prefix(namespace_token)
_reset_active_dispatch(dispatch_token)
_reset_active_observers(observers_token)
# Engine's canonical merge uses the ORIGINAL state per §2: "the
Expand Down Expand Up @@ -686,13 +707,18 @@ async def innermost(s: Any) -> Mapping[str, Any]:
list(self.middleware) + list(node.middleware),
innermost,
)
# Same active-observers scope as _step_function_node — parent
# middleware running before the descent should see the parent's
# observer set; the inner _invoke (called via ``node.run``)
# descends into its own context and sets a new scope from
# there.
# Same active-observers + calling-node scope as
# ``_step_function_node`` — parent middleware running before
# the descent should see the wrapper node's namespace +
# fan_out_index for any LLM-provider hook emissions.
# ``attempt_index`` defaults to 0 from the ContextVar; the
# subgraph wrapper has no engine-managed attempt counter
# (inner ``_step_function_node`` calls own their own).
namespace = context.namespace_prefix + (current,)
observers_token = _set_active_observers(context.full_observers())
dispatch_token = _set_active_dispatch(lambda event: _dispatch(context, event))
namespace_token = _set_namespace_prefix(namespace)
fan_out_token = _set_fan_out_index(context.fan_out_index)

try:
try:
Expand All @@ -707,6 +733,8 @@ async def innermost(s: Any) -> Mapping[str, Any]:
# preserved.
raise NodeException(node_name=current, cause=e, recoverable_state=state) from e
finally:
_reset_fan_out_index(fan_out_token)
_reset_namespace_prefix(namespace_token)
_reset_active_dispatch(dispatch_token)
_reset_active_observers(observers_token)
return _merge_partial(state, final_partial, self.reducers, current)
Expand Down Expand Up @@ -743,58 +771,68 @@ async def _step_fan_out_node(
async def innermost(s: Any) -> Mapping[str, Any]:
attempt_index = attempt_counter[0]
attempt_counter[0] += 1
self._dispatch_started(context, current, namespace, step, s, attempt_index=attempt_index)
attempt_token = _set_attempt_index(attempt_index)
try:
partial = await node.run_with_context(s, context)
except RuntimeGraphError as e:
self._dispatch_completed(
context, current, namespace, step, s, error=e, attempt_index=attempt_index
)
raise
except Exception as e:
wrapped = NodeException(node_name=current, cause=e, recoverable_state=s)
self._dispatch_started(context, current, namespace, step, s, attempt_index=attempt_index)
try:
partial = await node.run_with_context(s, context)
except RuntimeGraphError as e:
self._dispatch_completed(
context, current, namespace, step, s, error=e, attempt_index=attempt_index
)
raise
except Exception as e:
wrapped = NodeException(node_name=current, cause=e, recoverable_state=s)
self._dispatch_completed(
context,
current,
namespace,
step,
s,
error=wrapped,
attempt_index=attempt_index,
)
raise wrapped from e

try:
merged = _merge_partial(s, partial, self.reducers, current)
except (ReducerError, StateValidationError) as e:
self._dispatch_completed(
context, current, namespace, step, s, error=e, attempt_index=attempt_index
)
raise

self._dispatch_completed(
context,
current,
namespace,
step,
s,
error=wrapped,
post_state=merged,
attempt_index=attempt_index,
)
raise wrapped from e

try:
merged = _merge_partial(s, partial, self.reducers, current)
except (ReducerError, StateValidationError) as e:
self._dispatch_completed(
context, current, namespace, step, s, error=e, attempt_index=attempt_index
)
raise

self._dispatch_completed(
context,
current,
namespace,
step,
s,
post_state=merged,
attempt_index=attempt_index,
)
return partial
return partial
finally:
_reset_attempt_index(attempt_token)

chain: ChainCall = compose_chain(
list(self.middleware) + list(node.middleware),
innermost,
)

# Same observability §3 / LLM-span hook contract as
# _step_function_node: set the active observer set in scope
# around the chain invocation so capability backends emitting
# from inside the fan-out's parent dispatch (or any code
# running on its call stack) can find the observers.
# _step_function_node: set the active observer set, calling
# node identity, and dispatch scope around the chain
# invocation so capability backends emitting from inside the
# fan-out's parent dispatch (or any code running on its call
# stack) can find them. ``fan_out_index`` here is the parent
# context's view (the fan-out node from outside); per-instance
# values get set when the inner subgraph descends with the
# instance's index in its own context.
observers_token = _set_active_observers(context.full_observers())
dispatch_token = _set_active_dispatch(lambda event: _dispatch(context, event))
namespace_token = _set_namespace_prefix(namespace)
fan_out_token = _set_fan_out_index(context.fan_out_index)
try:
try:
final_partial = await chain(state)
Expand All @@ -803,6 +841,8 @@ async def innermost(s: Any) -> Mapping[str, Any]:
except Exception as e:
raise NodeException(node_name=current, cause=e, recoverable_state=state) from e
finally:
_reset_fan_out_index(fan_out_token)
_reset_namespace_prefix(namespace_token)
_reset_active_dispatch(dispatch_token)
_reset_active_observers(observers_token)
merged_outer = _merge_partial(state, final_partial, self.reducers, current)
Expand Down
27 changes: 26 additions & 1 deletion src/openarmature/llm/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@

from openarmature.graph.events import NodeEvent
from openarmature.graph.state import State
from openarmature.observability.correlation import current_dispatch
from openarmature.observability.correlation import (
current_attempt_index,
current_dispatch,
current_fan_out_index,
current_namespace_prefix,
)

from ..errors import (
ProviderAuthentication,
Expand Down Expand Up @@ -554,6 +559,15 @@ class _LlmEventState(State):
LLM-span maps by it so concurrent ``complete()`` calls (e.g.,
fan-out instances each calling the provider) don't collide on
a single sentinel-namespace key.

``calling_namespace_prefix``, ``calling_attempt_index``, and
``calling_fan_out_index`` carry the calling node's identity so
the OTel observer can resolve the §5.5 "parent under calling
node" contract correctly under concurrent fan-out and retry.
Populated from the engine's ContextVars (set in
``_step_*_node`` around node-body execution); fall back to
sentinel defaults (empty tuple, 0, ``None``) when the LLM
provider is called outside any node body.
"""

call_id: str
Expand All @@ -571,6 +585,14 @@ class _LlmEventState(State):
error_type: str | None = None
error_message: str | None = None
error_category: str | None = None
# Calling-node identity captured at dispatch time. The OTel
# observer reads these to look up the calling node's span in
# its (now-invocation_id-scoped) ``_open_spans`` map without relying on
# the OTel current-span context (which under concurrent fan-out
# can yield a sibling instance's span).
calling_namespace_prefix: tuple[str, ...] = ()
calling_attempt_index: int = 0
calling_fan_out_index: int | None = None


def _make_llm_event(
Expand Down Expand Up @@ -611,6 +633,9 @@ def _make_llm_event(
error_type=error_type,
error_message=error_message,
error_category=error_category,
calling_namespace_prefix=current_namespace_prefix(),
calling_attempt_index=current_attempt_index(),
calling_fan_out_index=current_fan_out_index(),
)
return NodeEvent(
node_name="openarmature.llm.complete",
Expand Down
6 changes: 6 additions & 0 deletions src/openarmature/observability/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,20 @@

from .correlation import (
current_active_observers,
current_attempt_index,
current_correlation_id,
current_dispatch,
current_fan_out_index,
current_invocation_id,
current_namespace_prefix,
)

__all__ = [
"current_active_observers",
"current_attempt_index",
"current_correlation_id",
"current_dispatch",
"current_fan_out_index",
"current_invocation_id",
"current_namespace_prefix",
]
Loading
Loading