diff --git a/CHANGELOG.md b/CHANGELOG.md index 79b2e51..3385c0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). The ### Added +- **Parallel branches (proposal 0011, introduced in spec v0.11.0; attempt-index propagation clarified in spec v0.16.1).** New `GraphBuilder.add_parallel_branches_node(name, *, branches, error_policy, errors_field, middleware)` surface dispatches M heterogeneous compiled subgraphs concurrently per pipeline-utilities §11. `BranchSpec` (subgraph + inputs/outputs projection + branch middleware) and `ParallelBranchesNode` types exported from `openarmature.graph`. Branch insertion order determines fan-in merge order regardless of completion timing (§11.8). Two error policies: `"fail_fast"` raises `ParallelBranchesBranchFailed` (a `NodeException` subtype) with `branch_name`, original cause as `__cause__`, and `recoverable_state` carrying the parent's pre-dispatch snapshot — no buffered branch contributions are visible (§11.5 buffer-and-apply). `"collect"` records per-branch failures in an optional `errors_field` (each record carries `branch_name` + `category` + implementation-defined extras) and continues. Two new error categories: `ParallelBranchesNoBranches` (compile time, empty branches map) and `ParallelBranchesBranchFailed` (runtime, fail_fast branch raise). +- **`NodeEvent.branch_name: str | None` (proposal 0011 / graph-engine §6).** Populated on events from nodes inside a parallel-branches branch, absent outside. Independent of `fan_out_index` — both may be present simultaneously when a branch contains a fan-out (or a fan-out instance contains a parallel-branches node). The combined `(namespace, branch_name, fan_out_index, attempt_index, phase)` tuple is the event-source uniqueness key. +- **`openarmature.branch_name` OTel span attribute.** Mirrors the existing `openarmature.node.fan_out_index`. 
Emitted on synthesized inner-node spans when `branch_name` is populated on the event. The two attributes coexist on inner nodes of a fan-out-inside-a-branch composition. +- **Attempt-index ContextVar propagation through transitive retry (graph-engine §6 v0.16.1).** Retry middleware now sets the `attempt_index` ContextVar before each `next` call; the engine reads `current_attempt_index()` when emitting events. This makes retry semantics symmetric across direct (per-node middleware) and transitive (instance / branch / fan-out instance_middleware) wrapping — events from inner nodes of a subgraph the retry re-invokes carry the wrapping retry's counter, not a freshly-zeroed inner counter. Innermost-wins precedence falls out of Python's ContextVar set/reset token stack. Pre-existing node-level retry behavior is unchanged. - **State migration for checkpointed graphs (proposal 0014, introduced in spec v0.15.0; refined by proposal 0018 in spec v0.16.0).** Saved checkpoints whose `schema_version` doesn't match the current state class now route through a registered migration chain instead of failing on resume. Surface: `State.schema_version: ClassVar[str] = ""` (declare a non-empty value to opt in), `GraphBuilder.with_state_migration(from_version, to_version, migrate)` and `with_state_migrations(*migrations)` for registration, `StateMigration` and `MigrationRegistry` types exported from `openarmature.checkpoint`. Chain resolution is BFS over the registered edges; the shortest path wins. Three new error categories: `CheckpointStateMigrationChainAmbiguous` (proposal 0018: duplicate `(from, to)` pair at registration time, or multiple distinct shortest paths between the saved and current versions at resume time), `CheckpointStateMigrationMissing` (no chain bridges the versions), and `CheckpointStateMigrationFailed` (a migration function raised). All non-transient. Post-migration deserialization failures still route to `CheckpointRecordInvalid` per §10.12.4. 
The same chain applies to each entry in `parent_states` in lockstep with the outer state per §10.12.2. Routing precedence per §10.10 (v0.16.0): chain-ambiguous → missing → failed → record-invalid. - **`Checkpointer.supports_state_migration` Protocol attribute.** Marks whether a backend can expose the structural intermediate form (a plain dict, JSON tree) the migration registry consumes. `SQLiteCheckpointer(serialization="json")` opts in; `SQLiteCheckpointer(serialization="pickle")` and `InMemoryCheckpointer` opt out. On version mismatch against a non-migration-eligible backend the engine raises `CheckpointRecordInvalid` per spec §10.12.1. - **`openarmature.checkpoint.migrate` OTel span (proposal 0014 §6 cross-ref).** Versioned resumes whose migration chain runs emit a zero-duration `openarmature.checkpoint.migrate` span on the OTel observer, parented under the invocation root span. Attributes: `openarmature.checkpoint.migrate.from_version`, `openarmature.checkpoint.migrate.to_version` (the final target), `openarmature.checkpoint.migrate.chain_length`. The §10.12.3 fast path (versions match, registry not consulted) emits no span. Engine-side: a synthetic `checkpoint_migrated` observer phase carries a `_MigrationSummary` payload from `_migrate_record` through to the OTel observer; the new phase is gated off default subscriptions (observers opt in explicitly via `phases={..., "checkpoint_migrated"}`). @@ -25,13 +29,13 @@ The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). The ### Changed -- **Pinned spec version: 0.10.0 → 0.16.0.** Adopts the skip-ahead governance principle: the submodule jumps across v0.11.0–v0.16.0 (proposals 0009, 0011, 0014, 0015, 0016, 0017, 0018) in one bump. Only the surfaces introduced by proposals 0014–0017 are implemented in the batch's release; fixtures from 0011 are deferred-skip in the conformance suite and unmark with PR-5. 
+- **Pinned spec version: 0.10.0 → 0.16.1.** Adopts the skip-ahead governance principle: the submodule jumps across v0.11.0–v0.16.1 (proposals 0009, 0011, 0014, 0015, 0016, 0017, 0018) in one bump. All five proposals (0011, 0014, 0015, 0016, 0017) are implemented in the batch's release; the v0.16.1 clarification of attempt-index propagation through transitive retry middleware lands with the proposal 0011 implementation. - **`CheckpointRecord.schema_version` semantic shift (proposal 0014).** Previously a backend-internal record-shape version (`CHECKPOINT_SCHEMA_VERSION = "1"` constant), now the user-facing state-schema version per spec §10.2. The framework reads `type(state).schema_version` at save time. Pre-PR-4 records carrying `"1"` are reinterpreted as user-facing v1 identifiers; users with such records either declare `schema_version="1"` on their state class or discard the pre-PR-4 records. `SQLiteCheckpointer` no longer rejects records with non-default `schema_version` at the backend boundary; version-mismatch routing is now an engine concern at resume time. The `CHECKPOINT_SCHEMA_VERSION` module constant is removed; future record-shape evolution can add backend-private metadata fields if needed. - **`NodeEvent.pre_state` typed `Any` (was `State`).** Required by the new `checkpoint_migrated` phase which carries a `_MigrationSummary` payload rather than a `State` instance. Observer authors who type-narrowed `pre_state` to `State` should treat it as `Any` and narrow per-phase (e.g., `if event.phase == "completed": ...`). The `checkpoint_saved` phase already carried a State-flavored shape (not necessarily a typed `State` subclass instance), so this widens the declared type to match runtime reality rather than introducing a new constraint. ### Notes -- **Release gate: do not tag until all of {0011, 0014, 0015, 0016, 0017} are merged.** This batch implements one proposal per PR and lands a consolidated release after the fifth PR. 
Cutting a release tag before the batch is complete would ship a partial spec implementation against the v0.15.0 pin. +- **Release gate cleared with PR-5 (proposal 0011).** All five proposals in the batch ({0011, 0014, 0015, 0016, 0017}) are now implemented. Tag the consolidated release once this PR merges. - **Pre-1.0 MINOR.** Existing free-form callers (no `response_schema`) see no behavior change — the new field defaults to `None`, the wire body omits `response_format`, and `Response.parsed` remains absent. ## [0.5.0] — 2026-05-10 diff --git a/docs/concepts/index.md b/docs/concepts/index.md index aa63ccb..a55a722 100644 --- a/docs/concepts/index.md +++ b/docs/concepts/index.md @@ -12,6 +12,9 @@ the framework, or jump to whichever concept you need. data seam. - [Fan-out](fan-out.md): running the same subgraph many times in parallel, results merged back deterministically. +- [Parallel branches](parallel-branches.md): dispatching M + heterogeneous subgraphs concurrently with per-branch state schemas + and middleware. - [LLMs](llms.md): how LLM calls fit into nodes, structured output, routing on parsed fields, errors at the LLM boundary. - [Observability](observability.md): node-boundary hooks, OTel mapping, diff --git a/docs/concepts/observability.md b/docs/concepts/observability.md index 3b40c4f..f6aa23e 100644 --- a/docs/concepts/observability.md +++ b/docs/concepts/observability.md @@ -78,6 +78,7 @@ class NodeEvent: attempt_index: int = 0 fan_out_index: int | None = None fan_out_config: FanOutEventConfig | None = None + branch_name: str | None = None ``` A walk-through: @@ -130,7 +131,11 @@ A walk-through: `len(parent_states) == len(namespace) - 1`. - **`attempt_index`**: 0-based retry attempt counter. `0` for nodes - not wrapped by retry middleware; `1+` for retries. + not wrapped by retry middleware; `1+` for retries. 
Retry middleware + may wrap transitively — a retry on a [parallel-branches + branch](parallel-branches.md) or fan-out `instance_middleware` + re-runs the whole subgraph; events from inner nodes carry the + wrapping retry's attempt counter. - **`fan_out_index`**: 0-based per-instance index for events inside a fan-out instance; `None` outside. @@ -140,6 +145,17 @@ A walk-through: `item_count` / `concurrency` / `error_policy` / `parent_node_name`. `None` on every other event. +- **`branch_name`**: populated on events from nodes inside a + [parallel-branches branch](parallel-branches.md), carrying the + branch's name as declared on the dispatcher. `None` outside. + Independent of `fan_out_index` — both may be present simultaneously + when a parallel-branches branch contains a fan-out (or a fan-out + instance contains a parallel-branches node). The combination + `(namespace, branch_name, fan_out_index, attempt_index, phase)` + uniquely identifies each event source. On the OTel mapping + side, an `openarmature.branch_name` span attribute is added in + parallel to the existing `openarmature.node.fan_out_index`. + ## Routing errors and the completed event When a conditional edge raises or returns an invalid target: diff --git a/docs/concepts/parallel-branches.md b/docs/concepts/parallel-branches.md new file mode 100644 index 0000000..b6e94f6 --- /dev/null +++ b/docs/concepts/parallel-branches.md @@ -0,0 +1,151 @@ +# Parallel branches + +Dispatch M heterogeneous subgraphs concurrently, projected outputs +merged back into the parent via the parent's reducers in branch +insertion order. + +Sibling to [fan-out](fan-out.md) (same `for each thing, do work in +parallel` shape), but the *thing* is different per branch: a research +subgraph, a categorize subgraph, a sentiment subgraph — each with its +own state schema, its own middleware, its own observer events — +running in parallel and joining their results into one parent state. 
+ +## When to reach for parallel branches + +The signal: a fixed set of named operations, each with its own +behavior and state schema, that don't depend on each other. Three +classifiers running independently against the same input. A research +step, a translate step, and a fact-check step that all want the +parent's prompt. M is known at build time and small (typically 2–6), +and each branch is its own subgraph because each has its own +internal pipeline worth modelling separately. + +Fan-out is the right pick when you have N similar pieces of work, +N depends on runtime state, and the work is the same across instances. +Parallel branches is the right pick when M is a small fixed set of +different operations that happen to run concurrently. + +## The shape + +```python +from openarmature.graph import BranchSpec, GraphBuilder + +builder.add_parallel_branches_node( + "dispatcher", + branches={ + "research": BranchSpec( + subgraph=research_subgraph, # CompiledGraph[ResearchState] + inputs={"question": "prompt"}, # subgraph_field -> parent_field + outputs={"facts": "facts"}, # parent_field -> subgraph_field + ), + "translate": BranchSpec( + subgraph=translate_subgraph, # CompiledGraph[TranslateState] + inputs={"source": "prompt"}, + outputs={"translation": "translated"}, + ), + "fact_check": BranchSpec( + subgraph=fact_check_subgraph, # CompiledGraph[FactCheckState] + inputs={"claim": "prompt"}, + outputs={"verdict": "verdict"}, + ), + }, + error_policy="fail_fast", # or "collect" +) +``` + +Each branch's `subgraph` is a compiled graph; `inputs` and `outputs` +mirror the explicit projection shape from +[composition](composition.md#explicitmapping-declarative). The +branches dict's key is the branch name — used as the branch identity +on observer events (see [observability](observability.md)) and in +the per-branch error records that `error_policy: "collect"` +produces. 
+ +## Per-branch state, inputs and outputs + +Each branch runs its own subgraph against its own state — heterogeneous +schemas are explicit. Subgraph fields named in `inputs` are seeded +from the parent's corresponding field at branch entry; other subgraph +fields take their schema defaults. At branch exit, only the parent +fields named in `outputs` receive contributions; the rest of the +branch's final state is discarded. + +When two branches contribute to the same parent field, the parent's +reducer for that field applies both values in **branch insertion +order** — first the branch declared first in the `branches` dict, +then the next, and so on. This is deterministic regardless of which +branch's inner work finishes first. + +## Error policy + +- **`"fail_fast"`** (default): the first branch failure cancels + the in-flight siblings and propagates as + `ParallelBranchesBranchFailed` (a `NodeException` subtype) carrying + the failing `branch_name` and the original cause as `__cause__`. + `recoverable_state` is the parent's snapshot at the moment the + dispatcher entered — **no buffered branch contributions are + applied**, including those of branches that successfully completed + before the failure. Buffer-and-apply semantics: contributions are + held until every branch finishes, then either all apply (success) + or none apply (fail_fast failure). +- **`"collect"`**: every branch runs to completion. Successful + branches' contributions merge in insertion order; failed branches' + `outputs` projections do NOT fire (their named parent fields stay + at their defaults). If you declare `errors_field` on the dispatcher, + each failed branch produces a record with at minimum + `{"branch_name": ..., "category": ...}` appended to that + parent list field; the implementation may include additional keys + (message, cause_type) and tests should match by the spec-mandated + keys rather than strict equality. 
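The buffer-and-apply and insertion-order rules above can be illustrated with a small standalone model. This is a sketch built on plain `asyncio` and dicts, not the OpenArmature implementation: `dispatch`, `BranchFailed`, the demo branches, and the choice of the exception type name as `category` are all hypothetical stand-ins chosen for illustration.

```python
import asyncio
from typing import Any, Callable, Coroutine

# Hypothetical shapes for this sketch only (not the library's types).
Branch = Callable[[], Coroutine[Any, Any, dict]]
Reducer = Callable[[Any, Any], Any]


class BranchFailed(Exception):
    """Stand-in for ParallelBranchesBranchFailed in this sketch."""

    def __init__(self, branch_name: str) -> None:
        super().__init__(f"branch {branch_name!r} failed")
        self.branch_name = branch_name


async def dispatch(parent, branches, reducers, error_policy="fail_fast", errors_field=None):
    names = list(branches)  # dict insertion order is the merge order
    tasks = {name: asyncio.create_task(branches[name]()) for name in names}
    when = asyncio.FIRST_EXCEPTION if error_policy == "fail_fast" else asyncio.ALL_COMPLETED
    done, pending = await asyncio.wait(list(tasks.values()), return_when=when)

    if error_policy == "fail_fast":
        failed = [n for n in names if tasks[n] in done and tasks[n].exception() is not None]
        if failed:
            # Cancel in-flight siblings and apply nothing: the parent is untouched.
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)
            raise BranchFailed(failed[0]) from tasks[failed[0]].exception()

    # Every branch has finished; apply buffered contributions in insertion order.
    merged = dict(parent)
    errors = []
    for name in names:
        exc = tasks[name].exception()
        if exc is not None:
            # Assumption: category is modelled here as the exception type name.
            errors.append({"branch_name": name, "category": type(exc).__name__})
            continue
        for field, value in tasks[name].result().items():
            merged[field] = reducers[field](merged[field], value)
    if errors_field is not None and errors:
        merged[errors_field] = merged[errors_field] + errors
    return merged


def append(old, new):
    return old + new


async def slow_a():
    await asyncio.sleep(0.02)  # finishes last, but is declared first
    return {"facts": ["a"]}


async def fast_b():
    return {"facts": ["b"]}


async def boom():
    raise RuntimeError("boom")
```

In this model, dispatching `{"a": slow_a, "b": fast_b}` merges `a` before `b` even though `b` completes first, and a `fail_fast` failure leaves the parent dict exactly as it was at dispatch time.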
+ +## Branch middleware + +Each `BranchSpec` accepts a `middleware` tuple — middlewares that +wrap that branch's whole subgraph invocation as a unit. Retry +middleware on a branch retries the **whole branch**: a fresh +subgraph invocation each time, fresh inner-node execution. The +wrapping retry's attempt counter propagates to events emitted from +inner nodes (per graph-engine §6 v0.16.1), so observer events +inside the branch correctly show `attempt_index` ticking across +retries. + +Branch middleware is independent across branches — branch A may +have `[retry, timing]`; branch B may have `[]`; branch C may have +some custom breaker. Each branch's chain composes in isolation. + +## Composition with other constructs + +Parallel branches compose with the rest of the engine the way +subgraphs and fan-outs do: + +- A branch's subgraph can itself contain a fan-out node — inner-node + events inside that fan-out carry **both** `branch_name` (this + branch) and `fan_out_index` (the instance within this branch). + The two fields are independent. +- The parallel-branches node itself can be invoked from inside a + fan-out instance — inner events then carry the outer fan-out's + `fan_out_index` and the inner branch's `branch_name`. +- Per-graph and per-node middleware on the parallel-branches node + wrap the dispatcher as a single unit — one `started` event before + dispatch begins, one `completed` event after all branches finish + and fan-in lands. The parent's retry middleware retries the **whole + parallel-branches node**, not individual branches. + +## Resume semantics + +Parallel-branches nodes use the same **atomic restart** model as +fan-out (per spec §10.7): if a checkpoint resume lands on a +parallel-branches node, all branches re-dispatch from scratch. +Per-branch progress is not individually persisted in v1. 
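The attempt-index propagation described under "Branch middleware" can be sketched with the standard-library `contextvars` module. This is a minimal model under stated assumptions, not the engine's code: `retry`, `inner_node`, and `branch_subgraph` are hypothetical, and `current_attempt_index` here is a local stand-in for the helper of the same name that the engine imports.

```python
import asyncio
from contextvars import ContextVar

# Stand-in for the engine's attempt-index ContextVar (assumed default of 0).
_attempt_index: ContextVar[int] = ContextVar("attempt_index", default=0)


def current_attempt_index() -> int:
    return _attempt_index.get()


async def retry(next_call, state, max_attempts=3):
    """Hypothetical retry middleware: sets the var before each next call."""
    for attempt in range(max_attempts):
        token = _attempt_index.set(attempt)
        try:
            return await next_call(state)
        except Exception:
            if attempt == max_attempts - 1:
                raise
        finally:
            # Resetting the token restores whatever an outer retry had set.
            _attempt_index.reset(token)


events = []
calls = {"n": 0}


async def inner_node(name, state):
    # The "engine" reads the var at emission time instead of keeping its own counter.
    events.append((name, current_attempt_index()))
    return state


async def branch_subgraph(state):
    await inner_node("fetch", state)
    calls["n"] += 1
    if calls["n"] == 1:
        raise RuntimeError("transient failure on the first attempt")
    await inner_node("summarize", state)
    return state
```

Running `asyncio.run(retry(branch_subgraph, {}))` records `fetch` at attempt 0, then `fetch` and `summarize` at attempt 1, which is the behavior the section describes: inner-node events carry the wrapping retry's counter rather than a freshly zeroed one.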
+ +## When parallel branches is NOT the right shape + +- **Not the same as N copies of one subgraph.** If you want "run + this subgraph for each item in a list," reach for + [fan-out](fan-out.md). +- **Not a router.** A router is a conditional-edge pattern — pick + one branch based on state. Parallel branches runs *all* branches + concurrently. +- **Not a coordinator.** Branches don't communicate with each other + during execution; if branch B's work depends on branch A's + output, you want a linear pipeline (A → B), not parallel branches. diff --git a/openarmature-spec b/openarmature-spec index bdfe13a..19b3e0c 160000 --- a/openarmature-spec +++ b/openarmature-spec @@ -1 +1 @@ -Subproject commit bdfe13ad3bbeb83401eedb2b6e9ab51a15619ad3 +Subproject commit 19b3e0c81480c1c974e8520322ddf5ba7abc8286 diff --git a/pyproject.toml b/pyproject.toml index fd33d16..a5c222e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ Repository = "https://github.com/LunarCommand/openarmature-python" Specification = "https://github.com/LunarCommand/openarmature-spec" [tool.openarmature] -spec_version = "0.16.0" +spec_version = "0.16.1" [dependency-groups] dev = [ diff --git a/src/openarmature/__init__.py b/src/openarmature/__init__.py index feef202..2496083 100644 --- a/src/openarmature/__init__.py +++ b/src/openarmature/__init__.py @@ -1,4 +1,4 @@ """OpenArmature — workflow framework for LLM pipelines and tool-calling agents.""" __version__ = "0.5.0" -__spec_version__ = "0.16.0" +__spec_version__ = "0.16.1" diff --git a/src/openarmature/graph/__init__.py b/src/openarmature/graph/__init__.py index 85c2cb0..47b4675 100644 --- a/src/openarmature/graph/__init__.py +++ b/src/openarmature/graph/__init__.py @@ -27,6 +27,8 @@ MultipleOutgoingEdges, NoDeclaredEntry, NodeException, + ParallelBranchesBranchFailed, + ParallelBranchesNoBranches, ReducerError, RoutingError, RuntimeGraphError, @@ -45,6 +47,7 @@ ) from .nodes import FunctionNode, Node from .observer import Observer, 
RemoveHandle, SubscribedObserver +from .parallel_branches import BranchSpec, ParallelBranchesNode from .projection import ExplicitMapping, FieldNameMatching, ProjectionStrategy from .reducers import Reducer, append, last_write_wins, merge from .state import State @@ -79,6 +82,10 @@ "NodeException", "NoDeclaredEntry", "Observer", + "ParallelBranchesBranchFailed", + "ParallelBranchesNoBranches", + "ParallelBranchesNode", + "BranchSpec", "ProjectionStrategy", "Reducer", "ReducerError", diff --git a/src/openarmature/graph/builder.py b/src/openarmature/graph/builder.py index be5b8dd..8fe7129 100644 --- a/src/openarmature/graph/builder.py +++ b/src/openarmature/graph/builder.py @@ -12,7 +12,7 @@ from collections.abc import Awaitable, Callable, Iterable, Mapping from types import GenericAlias, UnionType -from typing import Any, Self, cast, get_args, get_origin +from typing import Any, Literal, Self, cast, get_args, get_origin from openarmature.checkpoint.errors import CheckpointStateMigrationChainAmbiguous from openarmature.checkpoint.migration import MigrationRegistry, StateMigration @@ -28,11 +28,13 @@ MappingReferencesUndeclaredField, MultipleOutgoingEdges, NoDeclaredEntry, + ParallelBranchesNoBranches, UnreachableNode, ) from .fan_out import ConcurrencyResolver, CountResolver, FanOutConfig, FanOutNode from .middleware import Middleware from .nodes import FunctionNode, Node +from .parallel_branches import BranchSpec, ParallelBranchesNode from .projection import FieldNameMatching, ProjectionStrategy from .reducers import Reducer from .state import State, field_reducers, resolve_reducer @@ -244,6 +246,87 @@ def add_fan_out_node[ChildT: State]( self._nodes[name] = fan_out return self + def add_parallel_branches_node( + self, + name: str, + *, + branches: Mapping[str, BranchSpec[Any]], + error_policy: Literal["fail_fast", "collect"] = "fail_fast", + errors_field: str | None = None, + middleware: Iterable[Middleware] | None = None, + ) -> Self: + """Register a 
parallel-branches node per pipeline-utilities §11. + + ``branches`` is a mapping from non-empty branch name to a + :class:`BranchSpec`. Insertion order is preserved and is + the dispatch + merge order per §11.8. + + Validates at registration: + + - ``branches`` non-empty (raises ``ParallelBranchesNoBranches``). + - Each branch name is a non-empty string (raises ``ValueError``). + - Each branch's ``inputs`` / ``outputs`` refer only to declared + fields on the (parent, branch-subgraph) state schemas + (raises ``MappingReferencesUndeclaredField``). + - ``errors_field`` (when set) is a declared parent-state field. + """ + if name in self._nodes: + raise ValueError(f"node {name!r} already declared") + if not branches: + raise ParallelBranchesNoBranches(node_name=name) + + parent_fields = self.state_cls.model_fields + if errors_field is not None and errors_field not in parent_fields: + raise MappingReferencesUndeclaredField( + direction="parallel_branches.errors_field", + side="parent", + field_name=errors_field, + ) + + for branch_name, spec in branches.items(): + if not branch_name: + raise ValueError(f"parallel-branches node {name!r}: branch_name MUST be non-empty") + sub_fields = spec.subgraph.state_cls.model_fields + for sub_field, parent_field in spec.inputs.items(): + if sub_field not in sub_fields: + raise MappingReferencesUndeclaredField( + direction=f"parallel_branches.{branch_name}.inputs", + side="subgraph", + field_name=sub_field, + ) + if parent_field not in parent_fields: + raise MappingReferencesUndeclaredField( + direction=f"parallel_branches.{branch_name}.inputs", + side="parent", + field_name=parent_field, + ) + for parent_field, sub_field in spec.outputs.items(): + if parent_field not in parent_fields: + raise MappingReferencesUndeclaredField( + direction=f"parallel_branches.{branch_name}.outputs", + side="parent", + field_name=parent_field, + ) + if sub_field not in sub_fields: + raise MappingReferencesUndeclaredField( + 
direction=f"parallel_branches.{branch_name}.outputs", + side="subgraph", + field_name=sub_field, + ) + + pb: Node[StateT] = cast( + "Node[StateT]", + ParallelBranchesNode[StateT]( + name=name, + branches=dict(branches), + error_policy=error_policy, + errors_field=errors_field, + middleware=tuple(middleware) if middleware is not None else (), + ), + ) + self._nodes[name] = pb + return self + def with_checkpointer(self, checkpointer: Checkpointer) -> Self: """Register a Checkpointer for the compiled graph. diff --git a/src/openarmature/graph/compiled.py b/src/openarmature/graph/compiled.py index 245af6c..648f952 100644 --- a/src/openarmature/graph/compiled.py +++ b/src/openarmature/graph/compiled.py @@ -62,7 +62,6 @@ from openarmature.observability.correlation import ( _reset_active_dispatch, _reset_active_observers, - _reset_attempt_index, _reset_correlation_id, _reset_fan_out_index, _reset_invocation_id, @@ -70,12 +69,12 @@ _set_active_dispatch, _set_active_observer_span, _set_active_observers, - _set_attempt_index, _set_correlation_id, _set_fan_out_index, _set_invocation_id, _set_namespace_prefix, current_active_observer_span, + current_attempt_index, ) from .edges import END, ConditionalEdge, EndSentinel, StaticEdge @@ -167,6 +166,14 @@ def _merge_partial[StateT: State]( ``ReducerError`` and schema failures as ``StateValidationError``. """ + # Lazy import to avoid a textual cycle (parallel_branches has a + # TYPE_CHECKING back-reference to this module). _MultiContribution + # is the sentinel ParallelBranchesNode uses when multiple branches + # write the same parent field — each value flows through the + # parent's reducer in branch insertion order per spec §11.4 + + # §11.8. 
+ from .parallel_branches import _MultiContribution # noqa: PLC0415 + new_values = prior.model_dump() for field_name, partial_value in partial.items(): reducer = reducers.get(field_name) @@ -175,7 +182,17 @@ def _merge_partial[StateT: State]( new_values[field_name] = partial_value continue try: - new_values[field_name] = reducer(new_values[field_name], partial_value) + if isinstance(partial_value, _MultiContribution): + # Per pipeline-utilities §11.4: multi-branch + # contributions to one parent field apply in branch + # insertion order via the parent's reducer. Fold + # each value in sequence. + acc = new_values[field_name] + for v in partial_value.values: + acc = reducer(acc, v) + new_values[field_name] = acc + else: + new_values[field_name] = reducer(new_values[field_name], partial_value) except Exception as e: raise ReducerError( field_name=field_name, @@ -778,6 +795,7 @@ async def _invoke( # to this module). Function-scope import is cheap once # cached; this branch fires once per fan-out step. from .fan_out import FanOutNode # noqa: PLC0415 + from .parallel_branches import ParallelBranchesNode # noqa: PLC0415 if isinstance(node, FanOutNode): # Fan-out nodes are recognized as a distinct node type @@ -787,6 +805,13 @@ async def _invoke( # concurrency lives inside the FanOutNode itself. fn_node = cast("FanOutNode[StateT, State]", node) step_result = await self._step_fan_out_node(fn_node, current, state, context) + elif isinstance(node, ParallelBranchesNode): + # Parallel-branches nodes are recognized as a distinct + # node type per pipeline-utilities §11. Dispatched + # through ``_step_parallel_branches_node`` which wraps + # the whole dispatch as one parent unit (per §11.6) — + # M heterogeneous subgraphs run concurrently inside. 
+ step_result = await self._step_parallel_branches_node(node, current, state, context) elif isinstance(node, SubgraphNode): # Subgraph wrappers are transparent to the observer protocol # (per fixture 013): no event is dispatched for the wrapper @@ -920,72 +945,72 @@ async def innermost(s: Any) -> Mapping[str, Any]: # the original `category` attribute (timing's # exception_category, retry's classifier). The engine wraps # any exception that escapes the chain, OUTSIDE this layer. - attempt_index = attempt_counter[0] attempt_counter[0] += 1 - # Calling-node identity for capability backends emitting - # from inside this attempt's scope (e.g., LLM provider's - # span hook). Per-attempt scope so retry middleware that - # re-enters innermost bumps the visible attempt_index. - attempt_token = _set_attempt_index(attempt_index) + # Per graph-engine §6 (clarified in v0.16.1): event + # emission reads ``attempt_index`` from the ContextVar set + # by any enclosing retry middleware — direct (per-node + # MW) or transitive (instance / branch MW on a subgraph + # the retry re-invokes). The engine itself no longer + # writes the var; innermost-wins precedence falls out of + # Python's ContextVar token-stack semantics. + attempt_index = current_attempt_index() + + self._dispatch_started(context, current, namespace, step, s, attempt_index=attempt_index) + + # Splice the observer-published span (if any) into the + # OTel context so logs emitted from the FIRST line of + # the node body — before any ``await`` — pick up the + # right trace_id/span_id via OTel's LoggingHandler. + # Detach in ``finally`` so retries / merge / completed + # dispatch don't run with the span still active, and + # clear ``current_active_observer_span`` to ``None`` so + # the next dispatch that raises or early-returns from + # ``prepare_sync`` can't reveal this node's span as a + # stale value to the engine's read. 
+ otel_token = _attach_active_observer_span() try: - self._dispatch_started(context, current, namespace, step, s, attempt_index=attempt_index) - - # Splice the observer-published span (if any) into the - # OTel context so logs emitted from the FIRST line of - # the node body — before any ``await`` — pick up the - # right trace_id/span_id via OTel's LoggingHandler. - # Detach in ``finally`` so retries / merge / completed - # dispatch don't run with the span still active, and - # clear ``current_active_observer_span`` to ``None`` so - # the next dispatch that raises or early-returns from - # ``prepare_sync`` can't reveal this node's span as a - # stale value to the engine's read. - otel_token = _attach_active_observer_span() - try: - try: - partial = await node.run(s) - except Exception as e: - wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) - self._dispatch_completed( - context, - current, - namespace, - step, - s, - error=wrapped, - attempt_index=attempt_index, - ) - raise - finally: - _detach_active_observer_span(otel_token) - _set_active_observer_span(None) - try: - merged = _merge_partial(s, partial, self.reducers, current) - except (ReducerError, StateValidationError) as e: + partial = await node.run(s) + except Exception as e: + wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) self._dispatch_completed( context, current, namespace, step, s, - error=e, + error=wrapped, attempt_index=attempt_index, ) raise - - # Defer the success-case completed dispatch to - # ``finalize_completed`` per proposal-0012; just - # record the info for the outer scope. - deferred_info[0] = (attempt_index, cast("StateT", s), cast("StateT", merged)) - # Return the partial (not the merged state) so middleware sees - # the partial-update shape per pipeline-utilities §2. The - # engine's canonical merge against the original state happens - # below, after the chain returns. 
- return partial finally: - _reset_attempt_index(attempt_token) + _detach_active_observer_span(otel_token) + _set_active_observer_span(None) + + try: + merged = _merge_partial(s, partial, self.reducers, current) + except (ReducerError, StateValidationError) as e: + self._dispatch_completed( + context, + current, + namespace, + step, + s, + error=e, + attempt_index=attempt_index, + ) + raise + + # Defer the success-case completed dispatch to + # ``finalize_completed`` per proposal-0012; just + # record the info for the outer scope. + deferred_info[0] = (attempt_index, cast("StateT", s), cast("StateT", merged)) + # Return the partial (not the merged state) so middleware sees + # the partial-update shape per pipeline-utilities §2. The + # engine's canonical merge against the original state happens + # below, after the chain returns. + return partial chain: ChainCall = compose_chain( list(self.middleware) + list(node.middleware), @@ -1029,23 +1054,27 @@ async def innermost(s: Any) -> Mapping[str, Any]: merged_outer = _merge_partial(state, final_partial, self.reducers, current) # Spec §10.3: save fires once the canonical merge succeeds — # the LAST attempt's index is what gets recorded (retries - # don't multiply saves). attempt_counter[0] is one past the - # final attempt; ``max(0, ... - 1)`` covers the - # short-circuit case where middleware returns a partial - # without ever invoking ``next()`` (counter stays at 0, - # subtracting 1 would yield an invalid -1). + # don't multiply saves). Per graph-engine §6 v0.16.1, the + # recorded value is the wrapping retry MW's attempt counter + # (which the inner-node events also reflected via the + # ContextVar). ``deferred_info[0]`` captures that value at + # the moment of the successful merge, sourced from + # ``current_attempt_index()``. When middleware short- + # circuited without invoking ``next()``, ``deferred_info[0]`` + # is None and the save records attempt_index=0. 
+ info = deferred_info[0] + saved_attempt = info[0] if info is not None else 0 await self._maybe_save_checkpoint( context, node_name=current, namespace=namespace, step=step, - attempt_index=max(0, attempt_counter[0] - 1), + attempt_index=saved_attempt, post_state=merged_outer, ) # Build the deferred-dispatch closure for the success-case # completed event. ``_invoke`` calls this after edge eval. - info = deferred_info[0] if info is None: # Middleware short-circuited without invoking ``next`` — # no started/completed pair fired. Edge errors after this @@ -1288,72 +1317,43 @@ async def _step_fan_out_node( deferred_info: list[tuple[int, StateT, StateT] | None] = [None] async def innermost(s: Any) -> Mapping[str, Any]: - attempt_index = attempt_counter[0] attempt_counter[0] += 1 - attempt_token = _set_attempt_index(attempt_index) - try: - self._dispatch_started( - context, - current, - namespace, - step, - s, - attempt_index=attempt_index, - fan_out_config=fan_out_event_config, - ) - # Same OTel attach pattern as ``_step_function_node``'s - # ``innermost`` — splice the observer-published span - # into the OTel context so logs emitted from inside - # the fan-out node's own scope (middleware bodies, - # the dispatch machinery) carry the right - # trace_id/span_id. Per-instance bodies get their own - # attach inside their ``_step_function_node`` - # innermost when the recursive invocation hits leaf - # nodes. ``finally`` clears the ContextVar so a later - # dispatch whose ``prepare_sync`` raises or early- - # returns can't reveal this fan-out's span as a stale - # value to the engine's read. 
- otel_token = _attach_active_observer_span() - try: - try: - partial = await node.run_with_context( - s, - context, - pre_resolved_count=item_count, - pre_resolved_concurrency=(concurrency_resolved,), - ) - except RuntimeGraphError as e: - self._dispatch_completed( - context, - current, - namespace, - step, - s, - error=e, - attempt_index=attempt_index, - fan_out_config=fan_out_event_config, - ) - raise - except Exception as e: - wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) - self._dispatch_completed( - context, - current, - namespace, - step, - s, - error=wrapped, - attempt_index=attempt_index, - fan_out_config=fan_out_event_config, - ) - raise wrapped from e - finally: - _detach_active_observer_span(otel_token) - _set_active_observer_span(None) + # Read from ContextVar — see ``_step_function_node``'s + # ``innermost`` comment on the v0.16.1 attempt-index + # propagation rule. + attempt_index = current_attempt_index() + self._dispatch_started( + context, + current, + namespace, + step, + s, + attempt_index=attempt_index, + fan_out_config=fan_out_event_config, + ) + # Same OTel attach pattern as ``_step_function_node``'s + # ``innermost`` — splice the observer-published span + # into the OTel context so logs emitted from inside + # the fan-out node's own scope (middleware bodies, + # the dispatch machinery) carry the right + # trace_id/span_id. Per-instance bodies get their own + # attach inside their ``_step_function_node`` + # innermost when the recursive invocation hits leaf + # nodes. ``finally`` clears the ContextVar so a later + # dispatch whose ``prepare_sync`` raises or early- + # returns can't reveal this fan-out's span as a stale + # value to the engine's read. 
+ otel_token = _attach_active_observer_span() + try: try: - merged = _merge_partial(s, partial, self.reducers, current) - except (ReducerError, StateValidationError) as e: + partial = await node.run_with_context( + s, + context, + pre_resolved_count=item_count, + pre_resolved_concurrency=(concurrency_resolved,), + ) + except RuntimeGraphError as e: self._dispatch_completed( context, current, @@ -1365,13 +1365,42 @@ async def innermost(s: Any) -> Mapping[str, Any]: fan_out_config=fan_out_event_config, ) raise - - # Defer the success-case completed dispatch per - # proposal-0012; record the info for the outer scope. - deferred_info[0] = (attempt_index, cast("StateT", s), cast("StateT", merged)) - return partial + except Exception as e: + wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) + self._dispatch_completed( + context, + current, + namespace, + step, + s, + error=wrapped, + attempt_index=attempt_index, + fan_out_config=fan_out_event_config, + ) + raise wrapped from e finally: - _reset_attempt_index(attempt_token) + _detach_active_observer_span(otel_token) + _set_active_observer_span(None) + + try: + merged = _merge_partial(s, partial, self.reducers, current) + except (ReducerError, StateValidationError) as e: + self._dispatch_completed( + context, + current, + namespace, + step, + s, + error=e, + attempt_index=attempt_index, + fan_out_config=fan_out_event_config, + ) + raise + + # Defer the success-case completed dispatch per + # proposal-0012; record the info for the outer scope. + deferred_info[0] = (attempt_index, cast("StateT", s), cast("StateT", merged)) + return partial chain: ChainCall = compose_chain( list(self.middleware) + list(node.middleware), @@ -1408,19 +1437,23 @@ async def innermost(s: Any) -> Mapping[str, Any]: # one record once the fan-out as a whole has finished and # results have merged back. 
Per-instance internal saves are # gated off by the fan-out instance descent setting - # ``checkpointer=None`` on the inner context. ``max(0, ...)`` - # guards against the short-circuit case (middleware returns a - # partial without ever invoking ``next()``). + # ``checkpointer=None`` on the inner context. Per graph-engine + # §6 v0.16.1: the saved attempt_index reflects the wrapping + # retry MW's counter (sourced from ``deferred_info[0]`` which + # captured ``current_attempt_index()`` at the moment of the + # successful merge). Short-circuit case (middleware returned + # without invoking ``next``) records attempt_index=0. + info = deferred_info[0] + saved_attempt = info[0] if info is not None else 0 await self._maybe_save_checkpoint( context, node_name=current, namespace=namespace, step=step, - attempt_index=max(0, attempt_counter[0] - 1), + attempt_index=saved_attempt, post_state=merged_outer, ) - info = deferred_info[0] if info is None: return _StepResult(state=merged_outer, finalize_completed=_no_op_finalize) final_attempt_index, final_pre_state, final_merged = info @@ -1451,6 +1484,159 @@ def finalize_completed(edge_error: RuntimeGraphError | None) -> None: return _StepResult(state=merged_outer, finalize_completed=finalize_completed) + async def _step_parallel_branches_node( + self, + node: Any, # ParallelBranchesNode[StateT] — lazy import keeps the + # textual cycle off the module graph (``parallel_branches`` has a + # TYPE_CHECKING back-reference to this module). + current: str, + state: StateT, + context: _InvocationContext, + ) -> _StepResult[StateT]: + """Run one parallel-branches-as-node step through the parent's + middleware chain. + + Per pipeline-utilities §11.6: the parent's per-graph + + per-node middleware wraps the parallel-branches dispatch + as a SINGLE unit — one started event before dispatch + begins, one completed event after all branches complete + and fan-in is done. 
Per-branch internal events come from + the branches' subgraph executions and carry ``branch_name`` + per graph-engine §6. + + Mirrors ``_step_fan_out_node`` minus the eager + count/concurrency resolution (parallel branches has no + callable resolvers — the branch set is static at compile + time). + """ + step = context.take_step() + namespace = context.namespace_prefix + (current,) + attempt_counter: list[int] = [0] + deferred_info: list[tuple[int, StateT, StateT] | None] = [None] + + async def innermost(s: Any) -> Mapping[str, Any]: + attempt_counter[0] += 1 + # Read from ContextVar — see ``_step_function_node``'s + # ``innermost`` for the v0.16.1 propagation rule. + attempt_index = current_attempt_index() + + self._dispatch_started( + context, + current, + namespace, + step, + s, + attempt_index=attempt_index, + ) + otel_token = _attach_active_observer_span() + try: + try: + partial = await node.run_with_context(s, context) + except RuntimeGraphError as e: + self._dispatch_completed( + context, + current, + namespace, + step, + s, + error=e, + attempt_index=attempt_index, + ) + raise + except Exception as e: + wrapped = NodeException(node_name=current, cause=e, recoverable_state=s) + self._dispatch_completed( + context, + current, + namespace, + step, + s, + error=wrapped, + attempt_index=attempt_index, + ) + raise wrapped from e + finally: + _detach_active_observer_span(otel_token) + _set_active_observer_span(None) + + try: + merged = _merge_partial(s, partial, self.reducers, current) + except (ReducerError, StateValidationError) as e: + self._dispatch_completed( + context, + current, + namespace, + step, + s, + error=e, + attempt_index=attempt_index, + ) + raise + + deferred_info[0] = (attempt_index, cast("StateT", s), cast("StateT", merged)) + return partial + + chain: ChainCall = compose_chain( + list(self.middleware) + list(node.middleware), + innermost, + ) + + observers_token = _set_active_observers(context.full_observers()) + dispatch_token = 
_set_active_dispatch(lambda event: _dispatch(context, event)) + namespace_token = _set_namespace_prefix(namespace) + fan_out_token = _set_fan_out_index(context.fan_out_index) + try: + try: + final_partial = await chain(state) + except RuntimeGraphError: + raise + except Exception as e: + raise NodeException(node_name=current, cause=e, recoverable_state=state) from e + finally: + _reset_fan_out_index(fan_out_token) + _reset_namespace_prefix(namespace_token) + _reset_active_dispatch(dispatch_token) + _reset_active_observers(observers_token) + merged_outer = _merge_partial(state, final_partial, self.reducers, current) + info = deferred_info[0] + saved_attempt = info[0] if info is not None else 0 + await self._maybe_save_checkpoint( + context, + node_name=current, + namespace=namespace, + step=step, + attempt_index=saved_attempt, + post_state=merged_outer, + ) + + if info is None: + return _StepResult(state=merged_outer, finalize_completed=_no_op_finalize) + final_attempt_index, final_pre_state, final_merged = info + + def finalize_completed(edge_error: RuntimeGraphError | None) -> None: + if edge_error is None: + self._dispatch_completed( + context, + current, + namespace, + step, + final_pre_state, + post_state=final_merged, + attempt_index=final_attempt_index, + ) + else: + self._dispatch_completed( + context, + current, + namespace, + step, + final_pre_state, + error=edge_error, + attempt_index=final_attempt_index, + ) + + return _StepResult(state=merged_outer, finalize_completed=finalize_completed) + @staticmethod def _dispatch_started( context: _InvocationContext, @@ -1462,6 +1648,13 @@ def _dispatch_started( attempt_index: int = 0, fan_out_config: FanOutEventConfig | None = None, ) -> None: + # Per graph-engine §6 + pipeline-utilities §11: read the + # active branch_name (set by ParallelBranchesNode inside + # each branch's task ``copy_context``) and stamp it on + # every event emitted from inside the branch. 
Outside any + # branch, current_branch_name() returns None. + from openarmature.observability.correlation import current_branch_name # noqa: PLC0415 + _dispatch( context, NodeEvent( @@ -1476,6 +1669,7 @@ def _dispatch_started( attempt_index=attempt_index, fan_out_index=context.fan_out_index, fan_out_config=fan_out_config, + branch_name=current_branch_name(), ), ) @@ -1492,6 +1686,8 @@ def _dispatch_completed( attempt_index: int = 0, fan_out_config: FanOutEventConfig | None = None, ) -> None: + from openarmature.observability.correlation import current_branch_name # noqa: PLC0415 + _dispatch( context, NodeEvent( @@ -1506,6 +1702,7 @@ def _dispatch_completed( attempt_index=attempt_index, fan_out_index=context.fan_out_index, fan_out_config=fan_out_config, + branch_name=current_branch_name(), ), ) diff --git a/src/openarmature/graph/errors.py b/src/openarmature/graph/errors.py index c9f9b99..ae07894 100644 --- a/src/openarmature/graph/errors.py +++ b/src/openarmature/graph/errors.py @@ -105,6 +105,18 @@ def __init__(self, node_name: str, field_name: str) -> None: self.field_name = field_name +class ParallelBranchesNoBranches(CompileError): + """Raised at registration when a parallel-branches node's + ``branches`` mapping is empty. Per pipeline-utilities §11.9 + / proposal 0011. Non-transient.""" + + category = "parallel_branches_no_branches" + + def __init__(self, node_name: str) -> None: + super().__init__(f"parallel-branches node {node_name!r}: branches mapping is empty") + self.node_name = node_name + + # ===== Runtime errors ===== @@ -127,6 +139,46 @@ def __init__(self, node_name: str, cause: BaseException, recoverable_state: Any) self.__cause__ = cause +class ParallelBranchesBranchFailed(NodeException): + """Raised when a branch's subgraph raises under + ``error_policy: 'fail_fast'``. Per pipeline-utilities §11.9 / + proposal 0011. + + Subtype of :class:`NodeException` (per §11.9: "a + ``node_exception`` subtype attached at the parallel-branches + node's level"). 
The existing NodeException-classifier path + handles transient classification from ``__cause__`` per §6.1 + — non-transient by default, inherits transient classification + from the wrapped exception. + + Carries ``branch_name`` as a structured field per §11.9; the + inner exception rides ``__cause__``. + """ + + category = "parallel_branches_branch_failed" + + branch_name: str + + def __init__( + self, + node_name: str, + cause: BaseException, + recoverable_state: Any, + *, + branch_name: str, + ) -> None: + # NodeException's __init__ formats the message; override + # the message format to surface the branch identity. + super().__init__(node_name, cause, recoverable_state) + self.branch_name = branch_name + # Rewrite the inherited message so trace UIs / logs see + # the branch context up front. + self.args = ( + f"parallel-branches node {node_name!r}: " + f"branch {branch_name!r} raised {type(cause).__name__}: {cause}", + ) + + class EdgeException(RuntimeGraphError): category = "edge_exception" source_node: str diff --git a/src/openarmature/graph/events.py b/src/openarmature/graph/events.py index cc05de9..498cba5 100644 --- a/src/openarmature/graph/events.py +++ b/src/openarmature/graph/events.py @@ -115,6 +115,15 @@ class NodeEvent: - ``fan_out_config`` carries resolved fan-out configuration on events from a fan-out NODE itself. See :class:`FanOutEventConfig`. ``None`` on every other event. + - ``branch_name`` is the non-empty string name of the + parallel-branches branch this event came from. ``None`` for + nodes outside any branch. Per graph-engine §6 / pipeline- + utilities §11, the combination of ``namespace``, + ``branch_name``, ``fan_out_index``, ``attempt_index``, and + ``phase`` jointly uniquely identifies an event source. + ``branch_name`` and ``fan_out_index`` are independent — both + MAY be present when a branch's subgraph contains a fan-out + (or a fan-out instance contains a parallel-branches node). 
Invariants: @@ -172,6 +181,16 @@ class NodeEvent: attempt_index: int = 0 fan_out_index: int | None = None fan_out_config: FanOutEventConfig | None = None + # Per pipeline-utilities §11 / graph-engine §6 (proposal 0011): + # optional non-empty string populated only on events from nodes + # that execute inside a parallel-branches branch. The + # combination of ``namespace``, ``branch_name``, + # ``fan_out_index``, ``attempt_index``, and ``phase`` jointly + # uniquely identifies an event source. ``branch_name`` and + # ``fan_out_index`` are independent; both MAY be present + # simultaneously when a branch's subgraph contains a fan-out + # (and vice versa). + branch_name: str | None = None __all__ = ["FanOutEventConfig", "NodeEvent"] diff --git a/src/openarmature/graph/middleware/retry.py b/src/openarmature/graph/middleware/retry.py index 204804d..2dc9337 100644 --- a/src/openarmature/graph/middleware/retry.py +++ b/src/openarmature/graph/middleware/retry.py @@ -23,6 +23,7 @@ from typing import Any from openarmature.llm.errors import TRANSIENT_CATEGORIES +from openarmature.observability.correlation import _reset_attempt_index, _set_attempt_index from ._core import NextCall @@ -128,18 +129,30 @@ def __init__( async def __call__(self, state: Any, next_: NextCall) -> Mapping[str, Any]: attempt = 0 while True: + # Spec graph-engine §6 (clarified in v0.16.1): the wrapping + # retry's attempt counter MUST propagate to events emitted + # from any inner node the retry re-invokes — including + # nodes inside subgraph / branch / fan-out-instance + # invocations the retry wraps transitively. Set on entry, + # reset on exit; Python's ContextVar token stack gives + # innermost-wins precedence for free when retry middlewares + # nest. 
+ token = _set_attempt_index(attempt) try: - return await next_(state) - except Exception as exc: - # Spec §6.1: cancellation propagates by virtue of - # `CancelledError` extending `BaseException`, not - # `Exception` — it never enters this branch in Python. - if attempt + 1 >= self.max_attempts or not self.classifier(exc, state): - raise - if self.on_retry is not None: - await self.on_retry(exc, attempt) - await asyncio.sleep(self.backoff(attempt)) - attempt += 1 + try: + return await next_(state) + except Exception as exc: + # Spec §6.1: cancellation propagates by virtue of + # `CancelledError` extending `BaseException`, not + # `Exception` — it never enters this branch in Python. + if attempt + 1 >= self.max_attempts or not self.classifier(exc, state): + raise + if self.on_retry is not None: + await self.on_retry(exc, attempt) + await asyncio.sleep(self.backoff(attempt)) + attempt += 1 + finally: + _reset_attempt_index(token) __all__ = [ diff --git a/src/openarmature/graph/observer.py b/src/openarmature/graph/observer.py index f36bf2a..104ad5e 100644 --- a/src/openarmature/graph/observer.py +++ b/src/openarmature/graph/observer.py @@ -360,6 +360,50 @@ def descend_into_fan_out_instance( resume_invocation=self.resume_invocation, ) + def descend_into_parallel_branch( + self, + parallel_branches_node_name: str, + parent_state: State, + sub_attached: tuple[SubscribedObserver, ...], + ) -> _InvocationContext: + """Build the context for one parallel-branches branch's + subgraph invocation. + + Per pipeline-utilities §11.6 the parallel-branches node looks + to outer middleware like a single dispatch; inner-branch + events come from the branch's subgraph execution. Stamps the + namespace prefix with the parallel-branches node name so + inner events nest under it (mirrors + ``descend_into_fan_out_instance``'s namespace stamping). 
+ + Branch identity lives on the + ``observability.correlation._branch_name_var`` ContextVar + rather than on the descend context — set inside the + branch's task closure so ``copy_context`` inherits it + through the subgraph's execution. + + Per §11.9 / §10.7 atomic-restart: drops the checkpointer + and pending_resume_states (a crash mid-dispatch re-runs the + whole parallel-branches node from scratch on resume; the + branches' inner saves wouldn't be useful). + """ + return _InvocationContext( + queue=self.queue, + graph_attached=self.graph_attached + sub_attached, + invocation_scoped=self.invocation_scoped, + step_counter=self.step_counter, + namespace_prefix=self.namespace_prefix + (parallel_branches_node_name,), + parent_states_prefix=self.parent_states_prefix + (parent_state,), + fan_out_index=self.fan_out_index, + invocation_id=self.invocation_id, + correlation_id=self.correlation_id, + checkpointer=None, + completed_positions=self.completed_positions, + resume_skip_set=frozenset(), + pending_resume_states={}, + resume_invocation=self.resume_invocation, + ) + def take_step(self) -> int: """Atomically (single-threaded asyncio) read-and-increment the shared step counter. Returns the value to assign to the just- diff --git a/src/openarmature/graph/parallel_branches.py b/src/openarmature/graph/parallel_branches.py new file mode 100644 index 0000000..3ec9411 --- /dev/null +++ b/src/openarmature/graph/parallel_branches.py @@ -0,0 +1,365 @@ +# Spec: realizes pipeline-utilities §11 (parallel branches). + +"""Parallel branches — concurrent dispatch of M heterogeneous compiled subgraphs. + +Counterpart to :mod:`.fan_out`. Fan-out is data-driven (N items, +one subgraph, instantiated N times); parallel branches is +topology-driven (M heterogeneous compiled subgraphs declared +statically, run concurrently within a single parent invocation). 
+ +Each branch's :class:`BranchSpec` carries its own compiled +subgraph (with potentially different state schema, middleware, +topology), its own ``inputs`` / ``outputs`` projection mappings, +and its own optional ``middleware`` wrapping the whole branch +invocation as a unit (§11.7). + +Buffer-then-apply semantics per §11.4: contributions are +collected during dispatch and merged deterministically once at +node completion, using the parent's reducer for each output +field. Branch insertion order determines both dispatch order +(§11.8) and merge tie-breaking when two branches write the same +parent field. + +Error policies per §11.5: + +- ``fail_fast``: first failure cancels still-running branches; + the buffered contributions are discarded; the parallel-branches + node raises ``ParallelBranchesBranchFailed`` with the failing + branch's exception as ``__cause__``. ``recoverable_state`` + equals the parent state at the moment the node entered. +- ``collect``: all branches run to completion; successful + branches' contributions merge; failed branches' errors land in + the optional ``errors_field``. +""" + +from __future__ import annotations + +import asyncio +import contextvars +import logging +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal + +from openarmature.observability.correlation import ( + _reset_branch_name, + _set_branch_name, +) + +from .errors import ParallelBranchesBranchFailed +from .middleware import ChainCall, Middleware, compose_chain +from .state import State + +if TYPE_CHECKING: + from .compiled import CompiledGraph + from .observer import _InvocationContext + +_log = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class BranchSpec[ChildT: State]: + """One entry in a :class:`ParallelBranchesNode`'s branch mapping. + + Branches are heterogeneous: each spec MAY reference a different + compiled subgraph with a different state schema. 
``inputs`` / + ``outputs`` follow the same shape as subgraph projection + mappings (proposal 0002). + + Validation lives on the builder side + (``GraphBuilder.add_parallel_branches_node``): + ``mapping_references_undeclared_field`` for inputs/outputs + referencing undeclared fields; ``parallel_branches_no_branches`` + for empty ``branches`` maps; ``ValueError`` for empty-string + branch names. + """ + + subgraph: CompiledGraph[ChildT] + inputs: Mapping[str, str] = field(default_factory=dict[str, str]) + outputs: Mapping[str, str] = field(default_factory=dict[str, str]) + middleware: tuple[Middleware, ...] = () + + +@dataclass(frozen=True) +class ParallelBranchesNode[ParentT: State]: + """A node that dispatches M heterogeneous compiled subgraphs + concurrently per spec §11. + + The Node Protocol contract requires ``name``, ``middleware``, + and ``run``. Like :class:`FanOutNode`, the engine recognizes + this type in ``_invoke`` and calls ``run_with_context`` so the + dispatcher has access to the invocation context for + observer-attribution + namespace descent. ``run`` exists for + Protocol conformance only and raises if anyone calls it + directly. + """ + + name: str + branches: dict[str, BranchSpec[Any]] + error_policy: Literal["fail_fast", "collect"] = "fail_fast" + errors_field: str | None = None + middleware: tuple[Middleware, ...] = () + + async def run(self, state: ParentT) -> Mapping[str, Any]: + del state + raise RuntimeError( + "ParallelBranchesNode is dispatched by the graph engine; if you're " + "seeing this, you've likely instantiated it outside an engine " + "context (e.g., calling node.run(state) directly instead of " + "compiled.invoke)." + ) + + async def run_with_context( + self, + state: ParentT, + context: _InvocationContext, + ) -> Mapping[str, Any]: + """Execute the parallel-branches dispatch and return the + merged partial update. 
+ + Snapshot parent state, project per-branch initial states, + dispatch M branches concurrently in insertion order, then + either fail-fast on first error (cancelling the rest) or + run to completion and merge per the configured error policy. + """ + # ``contributions`` is the buffer per §11.4 — keyed by branch + # name, holds each successful branch's projected outputs + # (parent_field -> exit_value mapping) until the dispatch + # completes and the dispatcher applies them in insertion + # order via the parent's reducers. + contributions: dict[str, Mapping[str, Any]] = {} + # ``errors`` collects per-branch failures under ``collect``; + # under ``fail_fast`` the dispatcher raises before this is + # consulted. + errors: list[dict[str, str]] = [] + + async def run_branch(branch_name: str, spec: BranchSpec[Any]) -> Mapping[str, Any]: + # Set the branch_name ContextVar inside the branch's + # task scope so the OTel observer sees it on every + # inner-node span. The task copies the current context + # at spawn time, so this set happens inside the spawned + # task body — not in the dispatcher loop. + token = _set_branch_name(branch_name) + try: + # Per §11.2 projection in: subgraph fields not in + # ``inputs`` use the subgraph's declared defaults; + # named subgraph fields are initialized from the + # corresponding parent field. 
+ parent_dump = state.model_dump() + init: dict[str, Any] = {} + for sub_field, parent_field in spec.inputs.items(): + init[sub_field] = parent_dump[parent_field] + initial = spec.subgraph.state_cls(**init) + + child_context = context.descend_into_parallel_branch( + parallel_branches_node_name=self.name, + parent_state=state, + sub_attached=tuple(spec.subgraph._attached_observers), # noqa: SLF001 + ) + + async def innermost(s: Any) -> Mapping[str, Any]: + final_branch_state = await spec.subgraph._invoke(s, child_context) # noqa: SLF001 + # Per §11.4 projection out: only fields named + # in ``outputs`` contribute back to parent + # state; unnamed subgraph fields are discarded. + return { + parent_field: getattr(final_branch_state, sub_field) + for parent_field, sub_field in spec.outputs.items() + } + + chain: ChainCall = compose_chain(spec.middleware, innermost) + return await chain(initial) + finally: + _reset_branch_name(token) + + # Spawn one task per branch, in insertion order. Per §11.8 + # the dispatch order is the branches dict's insertion order; + # ``started`` events from the inner subgraphs interleave + # arbitrarily but the branch-level dispatch ordering is + # deterministic. + ctx = contextvars.copy_context() + tasks: list[tuple[str, asyncio.Task[Mapping[str, Any]]]] = [] + for branch_name, spec in self.branches.items(): + task = asyncio.create_task( + run_branch(branch_name, spec), + context=ctx.copy(), + ) + tasks.append((branch_name, task)) + + if self.error_policy == "fail_fast": + return await self._fail_fast(state, tasks, contributions) + return await self._collect(state, tasks, contributions, errors) + + async def _fail_fast( + self, + parent_state: Any, + tasks: list[tuple[str, asyncio.Task[Mapping[str, Any]]]], + contributions: dict[str, Mapping[str, Any]], + ) -> Mapping[str, Any]: + """Fail-fast policy per spec §11.5. 
+ + Wait until a branch raises or all finish; on failure, cancel the rest + and raise ``ParallelBranchesBranchFailed`` with the failing + branch's exception as ``__cause__``. Buffered contributions + are discarded (collect-then-apply means they never reached + parent state). ``recoverable_state`` equals the parent + state at the moment the node entered. + """ + task_map = {t: name for name, t in tasks} + try: + done, pending = await asyncio.wait( + [t for _, t in tasks], + return_when=asyncio.FIRST_EXCEPTION, + ) + except BaseException: + # Defensive: if the dispatcher itself is cancelled, + # drain the children before propagating. + for _, t in tasks: + t.cancel() + await asyncio.gather(*(t for _, t in tasks), return_exceptions=True) + raise + + # Pick a task that raised (``done`` is an unordered set). + failed_name: str | None = None + failed_cause: BaseException | None = None + for t in done: + if t.exception() is not None: + failed_name = task_map[t] + failed_cause = t.exception() + break + + if failed_cause is None: + # All tasks finished without raising (with no failure, + # FIRST_EXCEPTION behaves like ALL_COMPLETED). + # Buffer the contributions in branch insertion order. + for name, t in tasks: + contributions[name] = t.result() + return self._merge_contributions(contributions) + + # Cancel the still-pending tasks; drain to absorb + # CancelledError so it doesn't propagate as unhandled. + for t in pending: + t.cancel() + # Subtle case worth flagging: a second task may race past + # the cancellation point with its own raised exception (a + # near-simultaneous failure). ``return_exceptions=True`` + # absorbs both the CancelledErrors AND that second exception + # into the gather's return list. We do not re-raise them: + # the raise is committed to the FIRST failure observed + # above; logging stragglers at DEBUG helps post-mortem + # analysis without muddying the raise contract.
+ drained = await asyncio.gather( + *(t for _, t in tasks if not t.done()), + return_exceptions=True, + ) + for residual in drained: + if isinstance(residual, BaseException) and not isinstance(residual, asyncio.CancelledError): + _log.debug( + "parallel-branches node %r: post-cancellation residual exception " + "(discarded; raise is committed to %r): %r", + self.name, + failed_name, + residual, + ) + + raise ParallelBranchesBranchFailed( + self.name, + failed_cause, + parent_state, + branch_name=failed_name or "", + ) from failed_cause + + async def _collect( + self, + parent_state: Any, + tasks: list[tuple[str, asyncio.Task[Mapping[str, Any]]]], + contributions: dict[str, Mapping[str, Any]], + errors: list[dict[str, str]], + ) -> Mapping[str, Any]: + """Collect policy per spec §11.5. + + All branches run to completion regardless of individual + failures. Successful branches' contributions go to the + buffer; failed branches' errors land in ``errors_field`` + (when configured). The node returns normally. + """ + del parent_state + results = await asyncio.gather( + *(t for _, t in tasks), + return_exceptions=True, + ) + for (name, _task), result in zip(tasks, results, strict=True): + if isinstance(result, BaseException): + errors.append( + { + "branch_name": name, + "category": getattr(result, "category", type(result).__name__), + "message": str(result), + "cause_type": type( + result.__cause__ if result.__cause__ is not None else result + ).__name__, + } + ) + else: + contributions[name] = result + partial = dict(self._merge_contributions(contributions)) + if self.errors_field is not None: + partial[self.errors_field] = errors + return partial + + def _merge_contributions( + self, + contributions: dict[str, Mapping[str, Any]], + ) -> dict[str, Any]: + """Flatten per-branch contributions into a single partial. + + Per §11.4 + §11.8: contributions apply in branch insertion + order, using each parent field's reducer. 
The actual reducer + application happens at ``_merge_partial`` in compiled.py + when the engine merges this partial into parent state. Here + we just flatten the per-branch contributions into a dict + of ``parent_field -> _MultiContribution(values=...)`` + when multiple branches write the same field, OR + ``parent_field -> value`` when only one branch writes it. + + Returning the sentinel lets ``_merge_partial`` route + each value through the parent's reducer in order. + """ + # Group values by parent field, in branch insertion order. + field_contributors: dict[str, list[Any]] = {} + for branch_name in self.branches.keys(): + if branch_name not in contributions: + continue + for parent_field, value in contributions[branch_name].items(): + field_contributors.setdefault(parent_field, []).append(value) + + partial: dict[str, Any] = {} + for parent_field, values in field_contributors.items(): + if len(values) == 1: + partial[parent_field] = values[0] + else: + # Multi-branch contributions to the same field: the + # parent reducer applies in branch insertion order. + # We wrap the values in a _MultiContribution sentinel; + # the engine's _merge_partial recognizes it and routes + # each entry through the parent's reducer, preserving + # that order. + partial[parent_field] = _MultiContribution(values=tuple(values)) + return partial + + +@dataclass(frozen=True) +class _MultiContribution: + """Sentinel for ``_merge_partial`` indicating that multiple + branches contributed to the same parent field. The engine + applies the parent's reducer to each value in sequence, + preserving branch insertion order per §11.8. + """ + + values: tuple[Any, ...]
+ + +__all__ = [ + "BranchSpec", + "ParallelBranchesNode", +] diff --git a/src/openarmature/observability/correlation.py b/src/openarmature/observability/correlation.py index 3b35881..5bbf078 100644 --- a/src/openarmature/observability/correlation.py +++ b/src/openarmature/observability/correlation.py @@ -272,6 +272,34 @@ def _reset_fan_out_index(token: Token[int | None]) -> None: _fan_out_index_var.reset(token) +# Per pipeline-utilities §11 / proposal 0011: when a node runs +# inside a parallel-branches branch, this carries the branch's +# name. Mirrors ``_fan_out_index_var`` — set by the parallel- +# branches dispatcher inside each branch's task ``copy_context``, +# inherited through the branch's subgraph execution, read by the +# OTel observer to populate ``openarmature.branch_name`` on +# inner-node spans. +_branch_name_var: ContextVar[str | None] = ContextVar("openarmature.branch_name", default=None) + + +def current_branch_name() -> str | None: + """Return the branch_name of the node currently executing, or + ``None`` outside any parallel-branches branch body (top-level + nodes, subgraph dispatch, fan-out instance bodies that aren't + themselves inside a branch). + """ + return _branch_name_var.get() + + +def _set_branch_name(value: str | None) -> Token[str | None]: + """Set the calling node's branch_name. 
Internal — engine-only.""" + return _branch_name_var.set(value) + + +def _reset_branch_name(token: Token[str | None]) -> None: + _branch_name_var.reset(token) + + _attempt_index_var: ContextVar[int] = ContextVar("openarmature.attempt_index", default=0) @@ -364,6 +392,7 @@ def _reset_active_observer_span(token: Token[object | None]) -> None: "current_active_observer_span", "current_active_observers", "current_attempt_index", + "current_branch_name", "current_correlation_id", "current_dispatch", "current_fan_out_index", @@ -377,6 +406,7 @@ def _reset_active_observer_span(token: Token[object | None]) -> None: "_reset_active_observer_span", "_reset_active_observers", "_reset_attempt_index", + "_reset_branch_name", "_reset_correlation_id", "_reset_fan_out_index", "_reset_invocation_id", @@ -385,6 +415,7 @@ def _reset_active_observer_span(token: Token[object | None]) -> None: "_set_active_observer_span", "_set_active_observers", "_set_attempt_index", + "_set_branch_name", "_set_correlation_id", "_set_fan_out_index", "_set_invocation_id", diff --git a/src/openarmature/observability/otel/observer.py b/src/openarmature/observability/otel/observer.py index 3e25f80..774af10 100644 --- a/src/openarmature/observability/otel/observer.py +++ b/src/openarmature/observability/otel/observer.py @@ -1066,6 +1066,14 @@ def _node_attrs(self, event: NodeEvent, correlation_id: str | None) -> dict[str, } if event.fan_out_index is not None: attrs["openarmature.node.fan_out_index"] = event.fan_out_index + # Per pipeline-utilities §11 / proposal 0011: surface + # branch_name on every inner-node span within a + # parallel-branches branch. Independent of fan_out_index — + # both MAY be present when a branch's subgraph contains a + # fan-out, and the spec §6 uniqueness invariant treats them + # as independent identification slots. 
+ if event.branch_name is not None: + attrs["openarmature.branch_name"] = event.branch_name if correlation_id is not None: attrs["openarmature.correlation_id"] = correlation_id # Per spec §5.4 + proposal 0013 (v0.10.0): fan-out node spans diff --git a/tests/conformance/adapter.py b/tests/conformance/adapter.py index b51b31d..e5c1ba0 100644 --- a/tests/conformance/adapter.py +++ b/tests/conformance/adapter.py @@ -10,6 +10,7 @@ from __future__ import annotations +import asyncio import copy from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass, field @@ -19,12 +20,14 @@ from openarmature.graph import ( END, + BranchSpec, CompiledGraph, EndSentinel, ExplicitMapping, FanOutNode, FieldNameMatching, GraphBuilder, + ParallelBranchesNode, ProjectionStrategy, Reducer, State, @@ -56,6 +59,14 @@ def _parse_type(s: str) -> Any: return float if s == "bool": return bool + # Unparameterized container types — parallel-branches fixtures + # 034/035/037 use ``dict`` and ``list`` as state-field types + # for accumulator slots (branch_errors, merged_dict, collected_labels) + # where the element shape is heterogeneous across branches. + if s == "dict": + return dict[str, Any] + if s == "list": + return list[dict[str, Any]] if s.startswith("list<") and s.endswith(">"): return list[_parse_type(s[5:-1])] if s.startswith("dict<") and s.endswith(">"): @@ -357,6 +368,23 @@ async def fn(_state: Any) -> Mapping[str, Any]: return fn +def _wrap_with_sleep( + fn: Callable[[Any], Awaitable[Mapping[str, Any]]], + sleep_ms: int, +) -> Callable[[Any], Awaitable[Mapping[str, Any]]]: + # ``sleep_ms`` companion modifier on a NodeSpec — sleep that many + # milliseconds before the wrapped body fires. Used by parallel-branches + # fixtures 033 (slow third branch for fail-fast cancellation) and 037 + # (randomized completion timing to verify insertion-order determinism). 
+ delay = sleep_ms / 1000.0 + + async def fn_with_sleep(state: Any) -> Mapping[str, Any]: + await asyncio.sleep(delay) + return await fn(state) + + return fn_with_sleep + + @dataclass(frozen=True) class _TracingFanOutNode(FanOutNode[State, State]): """Conformance helper: a FanOutNode that appends its name to a shared @@ -383,6 +411,24 @@ async def run_with_context( ) +@dataclass(frozen=True) +class _TracingParallelBranchesNode(ParallelBranchesNode[State]): + """Conformance helper: a ParallelBranchesNode that appends its name + to the shared trace list once when the engine runs it. The + parallel-branches dispatcher itself counts as one engine step from + the parent's POV per §11.6, mirroring the fan-out tracing wrapper.""" + + trace_list: list[str] = field(default_factory=list[str]) + + async def run_with_context( + self, + state: State, + context: _InvocationContext, + ) -> Mapping[str, Any]: + self.trace_list.append(self.name) + return await super().run_with_context(state, context) + + @dataclass(frozen=True) class _TracingSubgraphNode(SubgraphNode[State, State]): """Conformance helper: a SubgraphNode that appends its name to a shared @@ -457,6 +503,7 @@ def build_graph( node_middleware: Mapping[str, Sequence[Any]] | None = None, graph_middleware: Sequence[Any] | None = None, fan_out_instance_middleware: Mapping[str, Sequence[Any]] | None = None, + parallel_branches_branch_middleware: Mapping[str, Mapping[str, Sequence[Any]]] | None = None, ) -> BuiltGraph: """Translate a graph-shaped fixture block into a `BuiltGraph`. 
@@ -486,6 +533,7 @@ def build_graph( subgraphs = subgraphs or {} node_middleware = node_middleware or {} fan_out_instance_middleware = fan_out_instance_middleware or {} + parallel_branches_branch_middleware = parallel_branches_branch_middleware or {} for mw in graph_middleware or (): builder.add_middleware(mw) @@ -505,7 +553,8 @@ def build_graph( trace_list=trace, middleware=per_node_mw, ) - elif "fan_out" in node_spec: + continue + if "fan_out" in node_spec: _add_fan_out_node( builder, node_name, @@ -514,55 +563,47 @@ def build_graph( trace, instance_middleware=fan_out_instance_middleware.get(node_name, ()), ) - elif "raises" in node_spec: - builder.add_node( + continue + if "parallel_branches" in node_spec: + _add_parallel_branches_node( + builder, node_name, - _make_raising_fn(node_name, node_spec["raises"], trace), - middleware=per_node_mw, + node_spec["parallel_branches"], + subgraphs, + trace, + branch_middleware=parallel_branches_branch_middleware.get(node_name, {}), ) + continue + + body: Callable[[Any], Awaitable[Mapping[str, Any]]] + if "raises" in node_spec: + body = _make_raising_fn(node_name, node_spec["raises"], trace) elif "flaky" in node_spec: - builder.add_node( - node_name, - _make_flaky_fn(node_name, node_spec["flaky"], trace), - middleware=per_node_mw, - ) + body = _make_flaky_fn(node_name, node_spec["flaky"], trace) elif "flaky_by_index" in node_spec: - builder.add_node( - node_name, - _make_flaky_by_index_fn(node_name, node_spec["flaky_by_index"], trace), - middleware=per_node_mw, - ) + body = _make_flaky_by_index_fn(node_name, node_spec["flaky_by_index"], trace) elif "flaky_instance_only" in node_spec: - builder.add_node( - node_name, - _make_flaky_instance_only_fn(node_name, node_spec["flaky_instance_only"], trace), - middleware=per_node_mw, - ) + body = _make_flaky_instance_only_fn(node_name, node_spec["flaky_instance_only"], trace) elif "update" in node_spec: - builder.add_node( - node_name, - _make_update_fn(node_name, 
node_spec["update"], trace), - middleware=per_node_mw, - ) + body = _make_update_fn(node_name, node_spec["update"], trace) elif "update_pure" in node_spec: - builder.add_node( - node_name, - _make_pure_update_fn(node_name, node_spec["update_pure"], trace), - middleware=per_node_mw, - ) + body = _make_pure_update_fn(node_name, node_spec["update_pure"], trace) elif "update_from_field" in node_spec: - builder.add_node( - node_name, - _make_update_from_field_fn(node_name, node_spec["update_from_field"], trace), - middleware=per_node_mw, - ) + body = _make_update_from_field_fn(node_name, node_spec["update_from_field"], trace) else: raise ValueError( f"node {node_name!r} has no recognized directive " "(update / update_pure / update_from_field / raises / flaky / " - "flaky_by_index / flaky_instance_only / fan_out / subgraph)" + "flaky_by_index / flaky_instance_only / fan_out / parallel_branches / " + "subgraph)" ) + sleep_ms = node_spec.get("sleep_ms") + if sleep_ms is not None: + body = _wrap_with_sleep(body, int(sleep_ms)) + + builder.add_node(node_name, body, middleware=per_node_mw) + for edge_spec in spec.get("edges", []): source = edge_spec["from"] if "to" in edge_spec: @@ -623,6 +664,8 @@ def _record_event(event: NodeEvent) -> dict[str, Any]: rec["error"] = event.error.category if event.fan_out_index is not None: rec["fan_out_index"] = event.fan_out_index + if event.branch_name is not None: + rec["branch_name"] = event.branch_name return rec @@ -734,6 +777,57 @@ def _add_fan_out_node( ) +def _add_parallel_branches_node( + builder: GraphBuilder[Any], + node_name: str, + cfg: Mapping[str, Any], + subgraphs: Mapping[str, CompiledGraph[State]], + trace: list[str], + *, + branch_middleware: Mapping[str, Sequence[Any]], +) -> None: + """Translate a fixture's ``parallel_branches:`` block into a + ``builder.add_parallel_branches_node`` call. 
+ + Each branch's ``subgraph`` name resolves against the shared + ``subgraphs`` registry (built from the fixture's top-level + ``subgraphs:`` block). ``branch_middleware`` maps branch-name to a + pre-translated middleware list; the test driver populates it from + each branch's ``middleware:`` block. + """ + branches_cfg = cast("dict[str, dict[str, Any]]", cfg["branches"]) + branches: dict[str, BranchSpec[Any]] = {} + for branch_name, branch_cfg in branches_cfg.items(): + sub_compiled = subgraphs[branch_cfg["subgraph"]] + branches[branch_name] = BranchSpec( + subgraph=sub_compiled, + inputs=dict(branch_cfg.get("inputs") or {}), + outputs=dict(branch_cfg.get("outputs") or {}), + middleware=tuple(branch_middleware.get(branch_name, ())), + ) + + builder.add_parallel_branches_node( + node_name, + branches=branches, + error_policy=cfg.get("error_policy", "fail_fast"), + errors_field=cfg.get("errors_field"), + ) + + # Swap the registered node for a tracing variant so the + # conformance trace records the dispatcher as one engine step. The + # builder's validation has already run; we only replace the stored + # Node instance. + original = cast("ParallelBranchesNode[State]", builder._nodes[node_name]) + builder._nodes[node_name] = _TracingParallelBranchesNode( + name=original.name, + branches=original.branches, + error_policy=original.error_policy, + errors_field=original.errors_field, + middleware=original.middleware, + trace_list=trace, + ) + + def _resolve_callable_int_resolver(cfg: Mapping[str, Any]) -> Callable[[Any], int]: """Build a state-reader callable from a fixture's callable config. 
diff --git a/tests/conformance/harness/directives.py b/tests/conformance/harness/directives.py index 2a3b40f..e7e58e6 100644 --- a/tests/conformance/harness/directives.py +++ b/tests/conformance/harness/directives.py @@ -235,6 +235,33 @@ class FanOutSpec(_AllowExtras): instance_middleware: list[MiddlewareSpec] | None = None +class ParallelBranchSpec(_AllowExtras): + """One entry inside a ``parallel_branches.branches`` mapping. + + Permissive on extras because fixtures may carry extra knobs + (e.g., per-branch annotations the harness ignores). + """ + + subgraph: str + inputs: dict[str, str] | None = None + outputs: dict[str, str] | None = None + middleware: list[MiddlewareSpec] | None = None + + +class ParallelBranchesSpec(_AllowExtras): + """``parallel_branches:`` block on a NodeSpec (pipeline-utilities §11). + + Mirrors :class:`FanOutSpec` but topology-driven: M heterogeneous + branches, each referencing a different compiled subgraph by name + against the case's top-level ``subgraphs:`` block. Branch insertion + order is preserved per §11.8. + """ + + branches: dict[str, ParallelBranchSpec] + error_policy: Literal["fail_fast", "collect"] | None = None + errors_field: str | None = None + + class CallsLlmSpec(_AllowExtras): """LLM-using node: sends ``messages`` to the harness's mock provider and stores the response (assistant content) in ``stores_response_in``. @@ -294,6 +321,7 @@ class NodeSpec(_ForbidExtras): raises: str | None = None subgraph: str | None = None fan_out: FanOutSpec | None = None + parallel_branches: ParallelBranchesSpec | None = None flaky: FlakySpec | None = None flaky_by_index: FlakyByIndexSpec | None = None flaky_per_index: FlakyPerIndexSpec | None = None @@ -309,6 +337,13 @@ class NodeSpec(_ForbidExtras): also_emits_via_global_tracer: GlobalTracerSpec | None = None # Pair with ``raises`` to specify the error category (graph-engine §4). 
error_category: str | None = None + # Parallel-branches fixtures (033, 037): the node sleeps this many + # milliseconds before its update fires. Used to force deterministic + # branch-completion ordering (037 — different branches finish at + # different wall-clock times yet final state must be insertion-order + # deterministic per §11.8) and to slow a third branch so fail-fast + # cancellation has time to land before it finishes (033). + sleep_ms: int | None = None _PRIMARY_FIELDS = ( "update", @@ -318,6 +353,7 @@ class NodeSpec(_ForbidExtras): "raises", "subgraph", "fan_out", + "parallel_branches", "flaky", "flaky_by_index", "flaky_per_index", diff --git a/tests/conformance/harness/expectations.py b/tests/conformance/harness/expectations.py index 5d1b2b7..c3812b8 100644 --- a/tests/conformance/harness/expectations.py +++ b/tests/conformance/harness/expectations.py @@ -141,6 +141,12 @@ class PipelineUtilitiesExpected(_ForbidExtras): # - dict[recorder_name, list[record]] when multiple recorders (001). # - list[record] flat when a single recorder. trace_records: Any = None + # Parallel-branches fixtures (032-038). On fail_fast, + # ``recoverable_state`` carries the pre-entry parent state + # snapshot per spec §11.5; the harness asserts it equals the + # ``recoverable_state`` attached to the raised + # ``ParallelBranchesBranchFailed``. + recoverable_state: dict[str, Any] | None = None # --------------------------------------------------------------------------- @@ -201,6 +207,7 @@ class ObservabilityExpected(_ForbidExtras): "timing_records", "trace_records", "expected_observer_event", + "recoverable_state", } ) _OBSERVABILITY_KEYS = frozenset( diff --git a/tests/conformance/test_conformance.py b/tests/conformance/test_conformance.py index 9f79246..3d2d902 100644 --- a/tests/conformance/test_conformance.py +++ b/tests/conformance/test_conformance.py @@ -70,9 +70,8 @@ def _fixture_id(path: Path) -> str: # passes." 
Each subsequent PR drops its own rows as it lands the # underlying support. _DEFERRED_FIXTURES: dict[str, str] = { - # proposal 0011 — parallel branches; adds ``branch_name`` to - # NodeEvent (PR-5 of the batch) - "021-observer-branch-name": "0011 parallel branches (PR-5)", + # proposal 0011 — parallel branches; fixture 021 (``branch_name`` + # field on NodeEvent) runs through this driver as of PR-5. } @@ -319,6 +318,9 @@ async def _run_runtime_case(spec: Mapping[str, Any], fixture_id: str) -> None: f"delivery_order mismatch: actual={delivery}, expected={expected_delivery}" ) + if "observer_event_invariants" in expected: + _check_event_invariants(expected["observer_event_invariants"], observer_fixtures) + # 018 — registering an observer with an empty `phases` set raises at # registration time per spec §6. if expected.get("empty_phases_raises_at_registration"): @@ -353,6 +355,52 @@ async def _run_runtime_case(spec: Mapping[str, Any], fixture_id: str) -> None: # --------------------------------------------------------------------------- +def _check_event_invariants( + invariants: Mapping[str, Any], + observer_fixtures: Mapping[str, ObserverFixture], +) -> None: + """Verify ``observer_event_invariants`` block contents against the + captured observer events. Each named invariant has a recognized + shape used by one or more fixtures (021 parallel-branches + branch_name on inner-node events). + + Fixtures consult the first observer's events as the canonical + stream; multi-observer fixtures author their assertions against the + `observer_events` block instead. 
+ """ + if not observer_fixtures: + return + first_obs = next(iter(observer_fixtures.values())) + events = first_obs.events + + no_branch_name_cfg = cast( + "dict[str, Any] | None", + invariants.get("outermost_events_have_no_branch_name"), + ) + if no_branch_name_cfg is not None: + node_names = cast("list[str]", no_branch_name_cfg.get("nodes") or []) + for ev in events: + if ev["node_name"] in node_names: + assert "branch_name" not in ev, ( + f"outermost node {ev['node_name']!r} event MUST NOT carry " + f"branch_name; got {ev.get('branch_name')!r}" + ) + + inner_branch_cfg = cast( + "dict[str, str] | None", + invariants.get("inner_events_carry_correct_branch_name"), + ) + if inner_branch_cfg is not None: + for ev in events: + expected_branch = inner_branch_cfg.get(ev["node_name"]) + if expected_branch is None: + continue + assert ev.get("branch_name") == expected_branch, ( + f"inner node {ev['node_name']!r} expected branch_name={expected_branch!r}, " + f"got {ev.get('branch_name')!r}" + ) + + async def _run_fixture_020(spec: Mapping[str, Any]) -> None: cases = cast("list[dict[str, Any]]", spec["cases"]) for case in cases: diff --git a/tests/conformance/test_fixture_parsing.py b/tests/conformance/test_fixture_parsing.py index 128b185..0e9d4d3 100644 --- a/tests/conformance/test_fixture_parsing.py +++ b/tests/conformance/test_fixture_parsing.py @@ -31,15 +31,10 @@ def _id(case: tuple[str, Path]) -> str: # branch_name) to succeed; those shapes ship with their respective PRs. # Keyed by the test ID format ``/``. 
_DEFERRED_FIXTURES: dict[str, str] = { - # proposal 0011 — parallel branches (PR-5) - "graph-engine/021-observer-branch-name": "0011 parallel branches (PR-5)", - "pipeline-utilities/032-parallel-branches-basic": "0011 parallel branches (PR-5)", - "pipeline-utilities/033-parallel-branches-fail-fast": "0011 parallel branches (PR-5)", - "pipeline-utilities/034-parallel-branches-collect": "0011 parallel branches (PR-5)", - "pipeline-utilities/035-parallel-branches-different-state-schemas": "0011 parallel branches (PR-5)", - "pipeline-utilities/036-parallel-branches-with-branch-middleware-retry": "0011 parallel branches (PR-5)", - "pipeline-utilities/037-parallel-branches-determinism": "0011 parallel branches (PR-5)", - "pipeline-utilities/038-parallel-branches-compose-with-fan-out": "0011 parallel branches (PR-5)", + # proposal 0011's parallel-branches fixtures (032-038 + + # graph-engine/021) were removed from this list as part of + # PR-5; the typed harness parses the parallel_branches: + # node shape via the new ParallelBranchesSpec directive model. 
# proposal 0014's state-migration fixtures (039-046) were removed # from this list as part of PR-4; the CasesFixture model already # parses the seeded_record / migrations shape via its permissive diff --git a/tests/conformance/test_pipeline_utilities.py b/tests/conformance/test_pipeline_utilities.py index 877dd18..7be9b35 100644 --- a/tests/conformance/test_pipeline_utilities.py +++ b/tests/conformance/test_pipeline_utilities.py @@ -24,6 +24,7 @@ from openarmature.graph import ( NodeException, + ParallelBranchesBranchFailed, RuntimeGraphError, ) from openarmature.graph.middleware import ( @@ -35,7 +36,7 @@ deterministic_backoff, ) -from .adapter import build_graph +from .adapter import ObserverFixture, build_graph, make_observer_fn from .middleware_seam import ( ErrorRaiserMiddleware, ErrorRecoveryMiddleware, @@ -69,8 +70,12 @@ def _load(path: Path) -> dict[str, Any]: # Phase 3 target: fan-out (proposal 0005 PU side) covers fixtures 017-023. -# Phase 5 will pick up the checkpointing fixtures (024-031). -_PHASE_3_LAST = 23 +# Phase 5 will pick up the checkpointing fixtures (024-031). PR-5 +# (proposal 0011) drives fixtures 032-038 through this same harness. +# State-migration fixtures 039-047 run via a dedicated runner +# (``test_state_migration.py``); they need a separate driver because +# the `cases:` shape carries seeded-record + migrations + resume blocks. +_LAST_DRIVEN_FIXTURE = 38 def _fixture_paths() -> list[Path]: @@ -81,7 +86,7 @@ def _fixture_paths() -> list[Path]: number = int(p.stem.split("-", 1)[0]) except ValueError: continue - if number <= _PHASE_3_LAST: + if number <= _LAST_DRIVEN_FIXTURE: out.append(p) return out @@ -96,23 +101,24 @@ def _fixture_id(path: Path) -> str: # passes." Each subsequent PR drops its own rows as it lands the # underlying support. 
_DEFERRED_FIXTURES: dict[str, str] = { - # proposal 0011 — parallel branches (PR-5 of the batch) - "032-parallel-branches-basic": "0011 parallel branches (PR-5)", - "033-parallel-branches-fail-fast": "0011 parallel branches (PR-5)", - "034-parallel-branches-collect": "0011 parallel branches (PR-5)", - "035-parallel-branches-different-state-schemas": "0011 parallel branches (PR-5)", - "036-parallel-branches-with-branch-middleware-retry": "0011 parallel branches (PR-5)", - "037-parallel-branches-determinism": "0011 parallel branches (PR-5)", - "038-parallel-branches-compose-with-fan-out": "0011 parallel branches (PR-5)", - # proposal 0014 — state migration (PR-4 of the batch) - "039-state-migration-additive-field": "0014 state migration (PR-4)", - "040-state-migration-chain": "0014 state migration (PR-4)", - "041-state-migration-missing": "0014 state migration (PR-4)", - "042-state-migration-versions-match-no-op": "0014 state migration (PR-4)", - "043-state-migration-parent-states-migrated": "0014 state migration (PR-4)", - "044-state-migration-post-migration-deserialization-fails": "0014 state migration (PR-4)", - "045-state-migration-no-path-in-registry": "0014 state migration (PR-4)", - "046-state-migration-function-raises": "0014 state migration (PR-4)", + # proposal 0011 — parallel branches (PR-5 of the batch) — driven + # by the harness as of this PR; the 8 fixtures (032-038 + + # graph-engine/021) parse + run through the engine. + # proposal 0014 — state migration (PR-4 of the batch) — driven + # by ``test_state_migration.py`` (a separate runner that handles + # the cases-shape seeded_record + migrations + resume blocks). + # Checkpointing fixtures (024-031, proposal 0008) — driven by + # ``test_checkpoint.py`` because their cases-shape carries + # ``first_run_expected_error`` + ``resume:`` blocks that this + # driver doesn't recognize. 
+ "024-checkpoint-save-on-every-completed-event": "checkpointing (test_checkpoint.py)", + "025-checkpoint-resume-from-completed-position": "checkpointing (test_checkpoint.py)", + "026-checkpoint-record-shape": "checkpointing (test_checkpoint.py)", + "027-checkpoint-attempt-index-resets-on-resume": "checkpointing (test_checkpoint.py)", + "028-checkpoint-fan-out-atomic-restart": "checkpointing (test_checkpoint.py)", + "029-checkpoint-subgraph-resume": "checkpointing (test_checkpoint.py)", + "030-checkpoint-not-found": "checkpointing (test_checkpoint.py)", + "031-checkpoint-correlation-id-preserved-across-resume": "checkpointing (test_checkpoint.py)", } @@ -266,6 +272,35 @@ def _translate_middleware_block( return graph_mw, node_mw +def _translate_parallel_branches_branch_middleware( + spec: Mapping[str, Any], + sinks: CaptureSinks, + clock: Callable[[], float] | None = None, +) -> dict[str, dict[str, list[Middleware]]]: + """Walk ``spec.nodes`` for parallel_branches blocks with per-branch + ``middleware:`` and translate each into a list of Middleware + instances. 
Returned map is keyed by parallel-branches node name + then branch name (per spec §11.7 branch middleware) and consumed by + build_graph's ``parallel_branches_branch_middleware`` kwarg.""" + out: dict[str, dict[str, list[Middleware]]] = {} + nodes = cast("dict[str, dict[str, Any]]", spec.get("nodes") or {}) + for node_name, node_spec in nodes.items(): + pb_cfg_raw = node_spec.get("parallel_branches") + if not isinstance(pb_cfg_raw, dict): + continue + pb_cfg = cast("dict[str, Any]", pb_cfg_raw) + branches_cfg = cast("dict[str, dict[str, Any]]", pb_cfg.get("branches") or {}) + per_branch: dict[str, list[Middleware]] = {} + for branch_name, branch_cfg in branches_cfg.items(): + entries = cast("list[dict[str, Any]]", branch_cfg.get("middleware") or []) + if not entries: + continue + per_branch[branch_name] = [_build_middleware(cfg, sinks, clock) for cfg in entries] + if per_branch: + out[node_name] = per_branch + return out + + def _translate_fan_out_instance_middleware( spec: Mapping[str, Any], sinks: CaptureSinks, @@ -387,6 +422,8 @@ async def _run_one(spec: Mapping[str, Any], monkeypatch: pytest.MonkeyPatch) -> # fixtures 020-022 use one or two named subgraph blocks # (``subgraph:``, ``subgraph_with_idx:``) at the top level so the # fan-out config can pick which one to dispatch to per case. + # Parallel-branches fixtures (032-038) use a plural ``subgraphs:`` + # block — a dict mapping subgraph-name to graph-spec. 
subgraphs: dict[str, Any] = {} for sub_key in ("subgraph", "subgraph_with_idx"): sub_spec = spec.get(sub_key) @@ -401,6 +438,24 @@ async def _run_one(spec: Mapping[str, Any], monkeypatch: pytest.MonkeyPatch) -> node_middleware=sub_node_mw, ) subgraphs[sub_spec["name"]] = sub_built.builder.compile() + plural_subgraphs = cast("dict[str, dict[str, Any]] | None", spec.get("subgraphs")) or {} + for sub_name, sub_spec in plural_subgraphs.items(): + sub_graph_mw, sub_node_mw = _translate_middleware_block(sub_spec.get("middleware"), sinks, clock) + # Pass ``subgraphs=subgraphs`` so a subgraph that itself contains + # a fan_out / parallel_branches dispatch (fixture 038) can resolve + # the inner subgraph against entries already compiled in earlier + # iterations of this loop. The fixture's authoring order MUST put + # dependencies before dependents (the spec author's responsibility). + sub_built = build_graph( + sub_spec, + subgraphs=subgraphs, + model_name=f"{sub_name.title()}State", + graph_middleware=sub_graph_mw, + node_middleware=sub_node_mw, + ) + subgraphs[sub_name] = sub_built.builder.compile() + + branch_middleware = _translate_parallel_branches_branch_middleware(spec, sinks, clock) expected = cast("dict[str, Any]", spec.get("expected") or {}) run_count = cast("int", spec.get("run_count", 1)) @@ -417,6 +472,7 @@ async def _run_one(spec: Mapping[str, Any], monkeypatch: pytest.MonkeyPatch) -> graph_middleware=graph_mw, node_middleware=node_mw, fan_out_instance_middleware=fan_out_inst_mw, + parallel_branches_branch_middleware=branch_middleware, ) compiled = built.builder.compile() initial = built.initial_state(spec.get("initial_state", {})) @@ -426,8 +482,27 @@ async def _run_one(spec: Mapping[str, Any], monkeypatch: pytest.MonkeyPatch) -> assert excinfo.value.category == expected_err["category"] if "message" in expected_err and isinstance(excinfo.value, NodeException): assert str(excinfo.value.__cause__) == expected_err["message"] + if "cause_message" in expected_err 
and isinstance(excinfo.value, NodeException): + # ``cause_message`` is the original cause text — the + # leaf of the __cause__ chain. For parallel-branches + # fail_fast, the chain is: + # ParallelBranchesBranchFailed -> NodeException (branch's inner node) -> RuntimeError("...") + # Walk to the deepest non-None __cause__ before + # comparing. + leaf: BaseException = excinfo.value + while leaf.__cause__ is not None: + leaf = leaf.__cause__ + assert str(leaf) == expected_err["cause_message"] + if "branch_name" in expected_err and isinstance(excinfo.value, ParallelBranchesBranchFailed): + assert excinfo.value.branch_name == expected_err["branch_name"] + # ``recoverable_state`` may live nested under ``expected_error`` + # (legacy fan-out shape) or as a sibling under ``expected`` (per + # spec §11.5 for parallel-branches fail_fast fixtures). Both + # carry the same buffer-and-apply invariant. if "recoverable_state" in expected_err and isinstance(excinfo.value, NodeException): assert excinfo.value.recoverable_state.model_dump() == expected_err["recoverable_state"] + if "recoverable_state" in expected and isinstance(excinfo.value, NodeException): + assert excinfo.value.recoverable_state.model_dump() == expected["recoverable_state"] # Some error fixtures still attach trace_records assertions for # what fired before the failure. _check_trace_records( @@ -440,6 +515,7 @@ async def _run_one(spec: Mapping[str, Any], monkeypatch: pytest.MonkeyPatch) -> # stateful middleware (retry counters etc.) doesn't leak across runs. 
final_states: list[dict[str, Any]] = [] traces: list[list[str]] = [] + observer_fixtures: dict[str, ObserverFixture] = {} for run_idx in range(run_count): run_sinks = sinks if run_count == 1 else CaptureSinks() run_graph_mw, run_node_mw = ( @@ -458,19 +534,39 @@ async def _run_one(spec: Mapping[str, Any], monkeypatch: pytest.MonkeyPatch) -> graph_middleware=run_graph_mw, node_middleware=run_node_mw, fan_out_instance_middleware=run_fan_out_inst_mw, + parallel_branches_branch_middleware=branch_middleware, ) run_compiled = run_built.builder.compile() run_initial = run_built.initial_state(spec.get("initial_state", {})) + # Observers — graph-attached only (parallel-branches fixtures + # 036/037/038 use ``attach: graph, target: outer``). We rebuild + # the observer set fresh per run so capture lists don't bleed + # across runs in determinism fixtures. + run_observer_fixtures: dict[str, ObserverFixture] = {} + run_delivery: list[tuple[str, int, str]] = [] + for o in spec.get("observers", []): + phases_list = o.get("phases") + phases = frozenset(phases_list) if phases_list is not None else None + ofx = ObserverFixture( + name=o["name"], + attach=o["attach"], + target=o["target"], + behavior=o["behavior"], + phases=phases, + ) + run_observer_fixtures[ofx.name] = ofx + obs = make_observer_fn(ofx, run_delivery) + if ofx.attach == "graph" and ofx.target == "outer": + run_compiled.attach_observer(obs, phases=phases) run_final = await run_compiled.invoke(run_initial) await run_compiled.drain() final_states.append(run_final.model_dump()) traces.append(list(run_built.trace)) - del run_idx # quiet pyright unused-name + if run_idx == 0: + observer_fixtures = run_observer_fixtures if "final_state" in expected: - assert final_states[0] == expected["final_state"], ( - f"final_state mismatch: actual={final_states[0]}, expected={expected['final_state']}" - ) + _assert_final_state(final_states[0], expected["final_state"], spec) if "execution_order" in expected: assert traces[0] == 
expected["execution_order"], ( f"execution_order mismatch: actual={traces[0]}, expected={expected['execution_order']}" @@ -487,6 +583,13 @@ async def _run_one(spec: Mapping[str, Any], monkeypatch: pytest.MonkeyPatch) -> sinks, ) + if "observer_event_invariants" in expected: + _check_parallel_branches_invariants( + cast("Mapping[str, Any]", expected["observer_event_invariants"]), + observer_fixtures, + spec, + ) + # Timing record assertions. if "timing_records" in expected: # Two shapes per Phase 0 typed harness: dict-of-lists OR a flat list. @@ -511,6 +614,148 @@ async def _run_one(spec: Mapping[str, Any], monkeypatch: pytest.MonkeyPatch) -> pass +def _collect_parallel_branches_errors_fields(spec: Mapping[str, Any]) -> set[str]: + """Return the set of parent-state field names used as + ``errors_field`` on any parallel_branches node in ``spec``. + + Per spec §11.1 ``errors_field`` carries an implementation-defined + record shape; the spec only mandates ``branch_name`` + category. The + engine's record carries additional engine-defined keys (``message``, + ``cause_type``). Fixtures asserting against ``errors_field`` records + use subset semantics — assert the spec-mandated keys are present + with the expected values, ignore the rest. + """ + out: set[str] = set() + nodes = cast("dict[str, dict[str, Any]]", spec.get("nodes") or {}) + for node_spec in nodes.values(): + pb_cfg = cast("dict[str, Any] | None", node_spec.get("parallel_branches")) + if pb_cfg is None: + continue + field_name = pb_cfg.get("errors_field") + if isinstance(field_name, str): + out.add(field_name) + return out + + +def _assert_final_state( + actual: Mapping[str, Any], + expected: Mapping[str, Any], + spec: Mapping[str, Any], +) -> None: + """Compare ``actual`` vs ``expected`` final state. 
Strict equality + everywhere except for parallel-branches ``errors_field`` records, + which compare per-element via subset semantics.""" + errors_fields = _collect_parallel_branches_errors_fields(spec) + assert set(actual.keys()) == set(expected.keys()), ( + f"final_state key mismatch: actual={set(actual.keys())}, expected={set(expected.keys())}" + ) + for key, expected_val in expected.items(): + actual_val = actual[key] + if key in errors_fields and isinstance(expected_val, list) and isinstance(actual_val, list): + actual_list = cast("list[Any]", actual_val) + expected_list = cast("list[Any]", expected_val) + actual_len = len(actual_list) + expected_len = len(expected_list) + assert actual_len == expected_len, ( + f"final_state[{key!r}] length mismatch: actual={actual_len}, expected={expected_len}" + ) + for actual_rec, expected_rec in zip(actual_list, expected_list, strict=True): + if not isinstance(expected_rec, dict) or not isinstance(actual_rec, dict): + assert actual_rec == expected_rec + continue + actual_dict = cast("dict[str, Any]", actual_rec) + expected_dict = cast("dict[str, Any]", expected_rec) + for sub_key, sub_val in expected_dict.items(): + assert sub_key in actual_dict, ( + f"final_state[{key!r}] record missing key {sub_key!r}: actual={actual_dict}" + ) + actual_sub = actual_dict[sub_key] + assert actual_sub == sub_val, ( + f"final_state[{key!r}].{sub_key} mismatch: actual={actual_sub}, expected={sub_val}" + ) + continue + assert actual_val == expected_val, ( + f"final_state[{key!r}] mismatch: actual={actual_val}, expected={expected_val}" + ) + + +def _check_parallel_branches_invariants( + invariants: Mapping[str, Any], + observer_fixtures: Mapping[str, ObserverFixture], + spec: Mapping[str, Any], +) -> None: + """Verify parallel-branches observer-event invariants for fixtures + 036 (branch-middleware retry), 037 (determinism), 038 (compose with + fan-out). 
Each invariant name maps to one of the recognized shapes + below; an unknown name is skipped (forward-compat with new + fixtures the harness hasn't been taught yet). + """ + if not observer_fixtures: + return + obs = next(iter(observer_fixtures.values())) + events = obs.events + + started_events = [ev for ev in events if ev["phase"] == "started"] + + # 037 — branches' started events fire in branches insertion order + # regardless of their inner-node completion timing. + expected_order = invariants.get("branch_started_event_order") + if isinstance(expected_order, list): + seen_order: list[str] = [] + for ev in started_events: + branch = ev.get("branch_name") + if branch is None: + continue + if branch in seen_order: + continue + seen_order.append(branch) + assert seen_order == expected_order, ( + f"branch_started_event_order mismatch: actual={seen_order}, expected={expected_order}" + ) + + # 036 — per-branch attempt_index sequence on each branch's inner + # node. Authored per branch via ``<branch_name>_inner_attempt_indices_seen`` keys. + for key, expected_attempts in invariants.items(): + if not key.endswith("_inner_attempt_indices_seen"): + continue + branch_name = key.removesuffix("_inner_attempt_indices_seen") + attempts = [ev["attempt_index"] for ev in started_events if ev.get("branch_name") == branch_name] + assert attempts == expected_attempts, ( + f"{key} mismatch: actual={attempts}, expected={expected_attempts}" + ) + + # 038 — composition with fan-out invariants.
+ if invariants.get("fan_out_inner_events_carry_both_branch_name_and_fan_out_index"): + fan_out_events = [ev for ev in events if "fan_out_index" in ev] + assert fan_out_events, "expected inner-node events carrying fan_out_index, got none" + for ev in fan_out_events: + assert "branch_name" in ev, f"fan-out inner event missing branch_name: {ev}" + fan_out_branch = invariants.get("fan_out_inner_branch_name_seen") + if isinstance(fan_out_branch, str): + fan_out_branch_names = {ev.get("branch_name") for ev in events if "fan_out_index" in ev} + assert fan_out_branch in fan_out_branch_names, ( + f"fan-out inner events expected to carry branch_name={fan_out_branch!r}; " + f"saw branch_names={fan_out_branch_names}" + ) + expected_indices_raw = invariants.get("fan_out_inner_fan_out_indices_seen") + if isinstance(expected_indices_raw, list): + expected_indices = cast("list[int]", expected_indices_raw) + seen_indices = sorted({ev["fan_out_index"] for ev in events if "fan_out_index" in ev}) + assert seen_indices == sorted(expected_indices), ( + f"fan_out_inner_fan_out_indices_seen mismatch: actual={seen_indices}, " + f"expected={sorted(expected_indices)}" + ) + if invariants.get("plain_inner_events_carry_branch_name_but_no_fan_out_index"): + plain_branch = invariants.get("plain_inner_branch_name_seen") + if isinstance(plain_branch, str): + plain_events = [ + ev for ev in events if ev.get("branch_name") == plain_branch and "fan_out_index" not in ev + ] + assert plain_events, ( + f"expected branch_name={plain_branch!r} inner events without fan_out_index; got none" + ) + + def _check_trace_records( expected_recs: Mapping[str, list[Mapping[str, Any]]] | None, sinks: CaptureSinks, diff --git a/tests/test_smoke.py b/tests/test_smoke.py index f6fc125..6f6c0ff 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -9,7 +9,7 @@ def test_package_versions() -> None: assert openarmature.__version__ == "0.5.0" - assert openarmature.__spec_version__ == "0.16.0" + assert 
openarmature.__spec_version__ == "0.16.1" def test_spec_version_matches_pyproject() -> None: diff --git a/tests/unit/test_parallel_branches.py b/tests/unit/test_parallel_branches.py new file mode 100644 index 0000000..8ee72cb --- /dev/null +++ b/tests/unit/test_parallel_branches.py @@ -0,0 +1,480 @@ +"""Unit tests for the parallel-branches runtime (pipeline-utilities §11). + +Covers spec corner cases the conformance fixtures exercise only +implicitly: + +- compile-time empty-branches rejection +- compile-time empty-branch-name rejection +- compile-time projection validation (inputs/outputs reference declared + fields on the right side of the projection direction) +- compile-time ``errors_field`` validation +- fail_fast: first failure cancels siblings; recoverable_state is the + parent's pre-dispatch snapshot +- collect: per-branch errors recorded; successful branches' projections + merge in branch insertion order +- branch insertion order determines fan-in merge order regardless of + completion timing +- single-branch ``contributions`` are written through the parent's + reducer (the ``_MultiContribution`` sentinel only fires for multi- + branch fields) +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Mapping +from typing import Annotated, Any + +import pytest + +from openarmature.graph import ( + END, + BranchSpec, + CompiledGraph, + GraphBuilder, + MappingReferencesUndeclaredField, + ParallelBranchesBranchFailed, + ParallelBranchesNoBranches, + State, + append, + merge, +) + +# --------------------------------------------------------------------------- +# Shared schemas + helpers +# --------------------------------------------------------------------------- + + +class AlphaState(State): + a_out: int = 0 + + +class BetaState(State): + b_out: int = 0 + + +class GammaState(State): + c_out: int = 0 + + +class ParentState(State): + alpha_result: int = 0 + beta_result: int = 0 + gamma_result: int = 0 + + +def 
_build_alpha_succeeds() -> CompiledGraph[AlphaState]: + async def a(_state: AlphaState) -> Mapping[str, Any]: + return {"a_out": 1} + + return GraphBuilder(AlphaState).set_entry("a").add_node("a", a).add_edge("a", END).compile() + + +def _build_beta_succeeds() -> CompiledGraph[BetaState]: + async def b(_state: BetaState) -> Mapping[str, Any]: + return {"b_out": 2} + + return GraphBuilder(BetaState).set_entry("b").add_node("b", b).add_edge("b", END).compile() + + +def _build_beta_raises(message: str) -> CompiledGraph[BetaState]: + async def b(_state: BetaState) -> Mapping[str, Any]: + raise RuntimeError(message) + + return GraphBuilder(BetaState).set_entry("b").add_node("b", b).add_edge("b", END).compile() + + +def _build_gamma_succeeds() -> CompiledGraph[GammaState]: + async def c(_state: GammaState) -> Mapping[str, Any]: + return {"c_out": 3} + + return GraphBuilder(GammaState).set_entry("c").add_node("c", c).add_edge("c", END).compile() + + +# --------------------------------------------------------------------------- +# Compile-time validation +# --------------------------------------------------------------------------- + + +def test_empty_branches_raises_at_compile_time() -> None: + builder: GraphBuilder[ParentState] = GraphBuilder(ParentState) + with pytest.raises(ParallelBranchesNoBranches) as excinfo: + builder.add_parallel_branches_node("dispatcher", branches={}) + assert excinfo.value.category == "parallel_branches_no_branches" + + +def test_empty_branch_name_raises_at_compile_time() -> None: + builder: GraphBuilder[ParentState] = GraphBuilder(ParentState) + with pytest.raises(ValueError) as excinfo: + builder.add_parallel_branches_node( + "dispatcher", + branches={ + "": BranchSpec( + subgraph=_build_alpha_succeeds(), + outputs={"alpha_result": "a_out"}, + ), + }, + ) + assert "non-empty" in str(excinfo.value) + + +def test_outputs_references_undeclared_parent_field() -> None: + builder: GraphBuilder[ParentState] = GraphBuilder(ParentState) + with 
pytest.raises(MappingReferencesUndeclaredField) as excinfo: + builder.add_parallel_branches_node( + "dispatcher", + branches={ + "alpha": BranchSpec( + subgraph=_build_alpha_succeeds(), + outputs={"missing_parent_field": "a_out"}, + ), + }, + ) + assert excinfo.value.side == "parent" + + +def test_outputs_references_undeclared_subgraph_field() -> None: + builder: GraphBuilder[ParentState] = GraphBuilder(ParentState) + with pytest.raises(MappingReferencesUndeclaredField) as excinfo: + builder.add_parallel_branches_node( + "dispatcher", + branches={ + "alpha": BranchSpec( + subgraph=_build_alpha_succeeds(), + outputs={"alpha_result": "missing_sub_field"}, + ), + }, + ) + assert excinfo.value.side == "subgraph" + + +def test_inputs_references_undeclared_parent_field() -> None: + builder: GraphBuilder[ParentState] = GraphBuilder(ParentState) + with pytest.raises(MappingReferencesUndeclaredField) as excinfo: + builder.add_parallel_branches_node( + "dispatcher", + branches={ + "alpha": BranchSpec( + subgraph=_build_alpha_succeeds(), + inputs={"a_out": "missing_parent_field"}, + outputs={"alpha_result": "a_out"}, + ), + }, + ) + assert excinfo.value.side == "parent" + + +def test_errors_field_references_undeclared_parent_field() -> None: + builder: GraphBuilder[ParentState] = GraphBuilder(ParentState) + with pytest.raises(MappingReferencesUndeclaredField) as excinfo: + builder.add_parallel_branches_node( + "dispatcher", + branches={ + "alpha": BranchSpec( + subgraph=_build_alpha_succeeds(), + outputs={"alpha_result": "a_out"}, + ), + }, + error_policy="collect", + errors_field="not_declared", + ) + assert excinfo.value.side == "parent" + + +# --------------------------------------------------------------------------- +# Runtime — happy path +# --------------------------------------------------------------------------- + + +async def test_three_heterogeneous_branches_merge_to_parent() -> None: + compiled = ( + GraphBuilder(ParentState) + .set_entry("dispatcher") + 
.add_parallel_branches_node( + "dispatcher", + branches={ + "alpha": BranchSpec( + subgraph=_build_alpha_succeeds(), + outputs={"alpha_result": "a_out"}, + ), + "beta": BranchSpec( + subgraph=_build_beta_succeeds(), + outputs={"beta_result": "b_out"}, + ), + "gamma": BranchSpec( + subgraph=_build_gamma_succeeds(), + outputs={"gamma_result": "c_out"}, + ), + }, + ) + .add_edge("dispatcher", END) + .compile() + ) + final = await compiled.invoke(ParentState()) + await compiled.drain() + assert final.alpha_result == 1 + assert final.beta_result == 2 + assert final.gamma_result == 3 + + +# --------------------------------------------------------------------------- +# fail_fast policy +# --------------------------------------------------------------------------- + + +async def test_fail_fast_raises_branch_failed_with_branch_name() -> None: + compiled = ( + GraphBuilder(ParentState) + .set_entry("dispatcher") + .add_parallel_branches_node( + "dispatcher", + branches={ + "alpha": BranchSpec( + subgraph=_build_alpha_succeeds(), + outputs={"alpha_result": "a_out"}, + ), + "beta": BranchSpec( + subgraph=_build_beta_raises("boom"), + outputs={"beta_result": "b_out"}, + ), + }, + error_policy="fail_fast", + ) + .add_edge("dispatcher", END) + .compile() + ) + with pytest.raises(ParallelBranchesBranchFailed) as excinfo: + await compiled.invoke(ParentState()) + await compiled.drain() + assert excinfo.value.branch_name == "beta" + # __cause__ chain: ParallelBranchesBranchFailed -> NodeException("b") -> RuntimeError("boom") + inner = excinfo.value.__cause__ + assert inner is not None + leaf: BaseException = inner + while leaf.__cause__ is not None: + leaf = leaf.__cause__ + assert str(leaf) == "boom" + + +async def test_fail_fast_recoverable_state_drops_buffered_contributions() -> None: + # Per spec §11.5: on fail_fast, NO branch contributions are visible + # in recoverable_state, including the first branch's successful + # work (the buffer-and-apply semantic). 
+ compiled = ( + GraphBuilder(ParentState) + .set_entry("dispatcher") + .add_parallel_branches_node( + "dispatcher", + branches={ + # Successful branch: its result must NOT land in + # recoverable_state even though its inner-node update + # may complete before fail-fast cancellation propagates. + "alpha": BranchSpec( + subgraph=_build_alpha_succeeds(), + outputs={"alpha_result": "a_out"}, + ), + "beta": BranchSpec( + subgraph=_build_beta_raises("boom"), + outputs={"beta_result": "b_out"}, + ), + }, + error_policy="fail_fast", + ) + .add_edge("dispatcher", END) + .compile() + ) + with pytest.raises(ParallelBranchesBranchFailed) as excinfo: + await compiled.invoke(ParentState()) + await compiled.drain() + snapshot = excinfo.value.recoverable_state.model_dump() + # All defaults — alpha's contribution is NOT applied even though + # its branch may have completed before cancellation landed. + assert snapshot == {"alpha_result": 0, "beta_result": 0, "gamma_result": 0} + + +# --------------------------------------------------------------------------- +# collect policy +# --------------------------------------------------------------------------- + + +class ParentWithErrors(State): + alpha_result: int = 0 + beta_result: int = 0 + gamma_result: int = 0 + branch_errors: Annotated[list[dict[str, Any]], append] = [] + + +async def test_collect_records_branch_failures_in_errors_field() -> None: + compiled = ( + GraphBuilder(ParentWithErrors) + .set_entry("dispatcher") + .add_parallel_branches_node( + "dispatcher", + branches={ + "alpha": BranchSpec( + subgraph=_build_alpha_succeeds(), + outputs={"alpha_result": "a_out"}, + ), + "beta": BranchSpec( + subgraph=_build_beta_raises("boom"), + outputs={"beta_result": "b_out"}, + ), + "gamma": BranchSpec( + subgraph=_build_gamma_succeeds(), + outputs={"gamma_result": "c_out"}, + ), + }, + error_policy="collect", + errors_field="branch_errors", + ) + .add_edge("dispatcher", END) + .compile() + ) + final = await
compiled.invoke(ParentWithErrors()) + await compiled.drain() + # Successful branches' contributions land. + assert final.alpha_result == 1 + assert final.gamma_result == 3 + # Failed branch's outputs do NOT fire — beta_result stays at default. + assert final.beta_result == 0 + # One error record for beta carrying the spec-mandated keys. + assert len(final.branch_errors) == 1 + rec = final.branch_errors[0] + assert rec["branch_name"] == "beta" + assert rec["category"] == "node_exception" + + +# --------------------------------------------------------------------------- +# Determinism — insertion order is the merge order +# --------------------------------------------------------------------------- + + +class MergedDictState(State): + merged: Annotated[dict[str, Any], merge] = {} + + +def _build_writer(delay_s: float, value: str) -> CompiledGraph[MergedDictState]: + """Subgraph that sleeps then writes ``{key: value}`` to ``merged``.""" + + async def write(_state: MergedDictState) -> Mapping[str, Any]: + await asyncio.sleep(delay_s) + return {"merged": {"key": value}} + + return ( + GraphBuilder(MergedDictState) + .set_entry("write") + .add_node("write", write) + .add_edge("write", END) + .compile() + ) + + +async def test_branch_fan_in_order_follows_insertion_order_not_completion() -> None: + # Per spec §11.8: when two branches write the same parent field, + # the parent's reducer applies them in branch INSERTION order + # regardless of which branch finishes first. We give alpha (first + # in insertion order) a deliberately long delay and beta (second) + # a short one, so completion order is beta-then-alpha. The merge + # reducer applies alpha first, then beta — beta's value overrides. 
+ compiled = ( + GraphBuilder(MergedDictState) + .set_entry("dispatcher") + .add_parallel_branches_node( + "dispatcher", + branches={ + "alpha": BranchSpec( + subgraph=_build_writer(0.05, "alpha_value"), + outputs={"merged": "merged"}, + ), + "beta": BranchSpec( + subgraph=_build_writer(0.005, "beta_value"), + outputs={"merged": "merged"}, + ), + }, + ) + .add_edge("dispatcher", END) + .compile() + ) + final = await compiled.invoke(MergedDictState()) + await compiled.drain() + # beta wrote after alpha (per insertion-order fan-in), so beta's + # value wins the merge for ``key``. + assert final.merged == {"key": "beta_value"} + + +# --------------------------------------------------------------------------- +# Single-branch field write (no _MultiContribution sentinel firing) +# --------------------------------------------------------------------------- + + +async def test_single_branch_field_writes_through_reducer_normally() -> None: + # When only one branch contributes to a given parent field, the + # value should flow through the parent's reducer as a plain value, + # not as a _MultiContribution sentinel — the fan-in code only + # synthesizes the sentinel for fields touched by multiple branches. 
+ compiled = ( + GraphBuilder(ParentState) + .set_entry("dispatcher") + .add_parallel_branches_node( + "dispatcher", + branches={ + "alpha": BranchSpec( + subgraph=_build_alpha_succeeds(), + outputs={"alpha_result": "a_out"}, + ), + }, + ) + .add_edge("dispatcher", END) + .compile() + ) + final = await compiled.invoke(ParentState()) + await compiled.drain() + assert final.alpha_result == 1 + + +# --------------------------------------------------------------------------- +# Fail_fast cancellation drain — second failure absorbed silently +# --------------------------------------------------------------------------- + + +def _build_alpha_raises(message: str) -> CompiledGraph[AlphaState]: + async def a(_state: AlphaState) -> Mapping[str, Any]: + raise RuntimeError(message) + + return GraphBuilder(AlphaState).set_entry("a").add_node("a", a).add_edge("a", END).compile() + + +async def test_fail_fast_cancellation_drain_absorbs_residual_exceptions() -> None: + # Per spec §11.5 + Q5 cancellation-drain note: under fail_fast, + # the raise is committed to the FIRST failure observed; later + # tasks may race past the cancellation point with their own + # exceptions, but those are absorbed silently by the drain + # ``gather(*, return_exceptions=True)``. No second exception + # surfaces to the caller. 
+ compiled = ( + GraphBuilder(ParentState) + .set_entry("dispatcher") + .add_parallel_branches_node( + "dispatcher", + branches={ + "alpha": BranchSpec( + subgraph=_build_alpha_raises("first"), + outputs={"alpha_result": "a_out"}, + ), + "beta": BranchSpec( + subgraph=_build_beta_raises("second"), + outputs={"beta_result": "b_out"}, + ), + }, + error_policy="fail_fast", + ) + .add_edge("dispatcher", END) + .compile() + ) + with pytest.raises(ParallelBranchesBranchFailed) as excinfo: + await compiled.invoke(ParentState()) + await compiled.drain() + # The raise commits to the first observed failure; the dispatcher + # picks deterministically from the FIRST_EXCEPTION wait — one of + # the two branches surfaces. + assert excinfo.value.branch_name in {"alpha", "beta"}