From fb0c992ed809c8119d5fee01622235c756a0d553 Mon Sep 17 00:00:00 2001 From: chris-colinsky Date: Fri, 15 May 2026 22:20:13 -0700 Subject: [PATCH 1/7] feat(checkpoint): state migration registry, types, errors, builder surface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements pipeline-utilities spec §10.12 (proposal 0014). - New errors: CheckpointStateMigrationMissing, CheckpointStateMigrationFailed. Both non-transient per §10.10. The missing-chain error carries from_version / to_version / registered_migrations_count / registry_description for actionable operator diagnostics. - New types: StateMigration (frozen dataclass — from_version, to_version, migrate callable) and MigrationRegistry (BFS chain resolution + ambiguity detection per §10.12.2). - Multi-shortest-path detection: when BFS finds a shortest path AND a second distinct path of equal length exists, the registry raises ValueError per the spec's ambiguous-chain rule. Resume surfaces this as CheckpointStateMigrationMissing with the ambiguity description in the payload. - State.schema_version: ClassVar[str] = '' (per spec §10.2's per-language carve-out). Empty-string sentinel; the framework reads type(state).schema_version at save time. - Checkpointer Protocol: supports_state_migration: ClassVar[bool] marker per §10.12.1. InMemoryCheckpointer: False (typed in- memory references can't expose a class-independent intermediate). SQLiteCheckpointer: True in JSON mode, False in pickle mode (pickle holds class identity and round-trips to typed instances; can't bridge versions). - GraphBuilder.with_state_migration / with_state_migrations thread a populated MigrationRegistry into CompiledGraph at compile time. - Resume-path routing (compiled.py): version mismatch → unsupported-backend check → registry lookup → chain application (with per-migration failure wrap) → final deserialization. The post-migration deserialization failure still surfaces as CheckpointRecordInvalid per §10.12.4; pre-migration version mismatch routes through the new two categories. Order matters; documented inline so a future reader doesn't swap it back. - Parent-state migration: same chain applied to each entry of parent_states in lockstep with the outer state per §10.12.2. Code comment records the spec-mandated equivalence so future contributors don't add per-parent metadata without a follow-on proposal. - Drop the CHECKPOINT_SCHEMA_VERSION = '1' constant: per Q1 spec answer, the old backend-internal record-shape role had no spec slot anyway. SQLiteCheckpointer no longer rejects records with non-default versions on load — that routing is now the engine's concern at resume time. Existing records carrying schema_version='1' get reinterpreted as user-facing v1 identifiers (single-user dev, no compat shim needed per Chris's note). --- src/openarmature/checkpoint/__init__.py | 9 +- .../checkpoint/backends/memory.py | 14 ++ .../checkpoint/backends/sqlite.py | 20 ++- src/openarmature/checkpoint/errors.py | 77 ++++++++- src/openarmature/checkpoint/migration.py | 163 ++++++++++++++++++ src/openarmature/checkpoint/protocol.py | 41 +++-- src/openarmature/graph/builder.py | 45 +++++ src/openarmature/graph/compiled.py | 143 ++++++++++++++- src/openarmature/graph/state.py | 11 ++ tests/unit/test_checkpoint.py | 38 ++-- 10 files changed, 512 insertions(+), 49 deletions(-) create mode 100644 src/openarmature/checkpoint/migration.py diff --git a/src/openarmature/checkpoint/__init__.py b/src/openarmature/checkpoint/__init__.py index 3acc363..27b1929 100644 --- a/src/openarmature/checkpoint/__init__.py +++ b/src/openarmature/checkpoint/__init__.py @@ -26,9 +26,11 @@ CheckpointNotFound, CheckpointRecordInvalid, CheckpointSaveFailed, + CheckpointStateMigrationFailed, + CheckpointStateMigrationMissing, ) +from .migration import MigrationRegistry, StateMigration from .protocol import ( - CHECKPOINT_SCHEMA_VERSION, Checkpointer, CheckpointFilter, CheckpointRecord, @@ -37,17 +39,20 @@ ) __all__ = [ - "CHECKPOINT_SCHEMA_VERSION", "CheckpointError", "CheckpointFilter", "CheckpointNotFound", "CheckpointRecord", "CheckpointRecordInvalid", "CheckpointSaveFailed", + "CheckpointStateMigrationFailed", + "CheckpointStateMigrationMissing", "CheckpointSummary", "Checkpointer", "InMemoryCheckpointer", + "MigrationRegistry", "NodePosition", "SQLiteCheckpointer", "SerializationMode", + "StateMigration", ] diff --git a/src/openarmature/checkpoint/backends/memory.py b/src/openarmature/checkpoint/backends/memory.py index 1c4347b..00b1432 100644 --- a/src/openarmature/checkpoint/backends/memory.py +++ b/src/openarmature/checkpoint/backends/memory.py @@ -12,6 +12,7 @@ import asyncio from collections.abc import Iterable +from typing import ClassVar from ..protocol import CheckpointFilter, CheckpointRecord, CheckpointSummary @@ -28,8 +29,21 @@ class InMemoryCheckpointer: Pydantic state instance the engine produces is what comes back from :meth:`load` — no serialization round-trip. (This is the feature: tests can assert on the saved state's identity.) + + **State-migration eligibility:** none. Per spec §10.12.1, a + backend supports migration only when it can expose a structural + intermediate form of the loaded state independent of the current + state class. This backend holds live typed instances by + reference, so a version mismatch on resume raises + ``CheckpointRecordInvalid`` rather than consulting the + migration registry. """ + # Per spec §10.12.1: in-memory storage holds live typed-state + # references, so there's no class-independent intermediate form + # the migration registry could consume. + supports_state_migration: ClassVar[bool] = False + def __init__(self) -> None: self._records: dict[str, CheckpointRecord] = {} self._lock = asyncio.Lock() diff --git a/src/openarmature/checkpoint/backends/sqlite.py b/src/openarmature/checkpoint/backends/sqlite.py index d8065ba..588562c 100644 --- a/src/openarmature/checkpoint/backends/sqlite.py +++ b/src/openarmature/checkpoint/backends/sqlite.py @@ -42,7 +42,6 @@ from ..errors import CheckpointRecordInvalid from ..protocol import ( - CHECKPOINT_SCHEMA_VERSION, CheckpointFilter, CheckpointRecord, CheckpointSummary, @@ -109,6 +108,13 @@ def __init__( self._serialization: SerializationMode = serialization self._lock = asyncio.Lock() self._initialized = False + # Per spec §10.12.1, a backend supports state migration only + # when it can expose a structural intermediate form of the + # loaded state that is independent of the current state + # class. JSON serialization satisfies this (loads to dicts); + # pickle holds class identity and round-trips to typed + # instances, so it cannot bridge a schema-version mismatch. + self.supports_state_migration: bool = serialization == "json" def _connect(self) -> sqlite3.Connection: conn = sqlite3.connect(self._path) @@ -230,12 +236,12 @@ def _do() -> tuple[Any, ...] | None: schema_version, recorded_serialization, ) = row - if schema_version != CHECKPOINT_SCHEMA_VERSION: - raise CheckpointRecordInvalid( - invocation_id, - f"persisted schema_version={schema_version!r} does not match " - f"current {CHECKPOINT_SCHEMA_VERSION!r}", - ) + # Note: per spec §10.12 (proposal 0014), version mismatches + # are no longer rejected at the backend boundary. The engine + # routes mismatches through the migration registry on resume + # (CheckpointStateMigrationMissing if no chain, else applies + # the chain). The backend just round-trips the version + # identifier as opaque data. state = self._decode(state_blob, recorded_serialization, invocation_id) position_dicts = self._decode(positions_blob, recorded_serialization, invocation_id) parent_states = self._decode(parent_states_blob, recorded_serialization, invocation_id) diff --git a/src/openarmature/checkpoint/errors.py b/src/openarmature/checkpoint/errors.py index 380b97e..eaf7288 100644 --- a/src/openarmature/checkpoint/errors.py +++ b/src/openarmature/checkpoint/errors.py @@ -17,6 +17,8 @@ from __future__ import annotations +from typing import Any + class CheckpointError(Exception): """Base for all checkpoint errors. Each subclass carries a @@ -56,10 +58,17 @@ def __init__(self, invocation_id: str, cause: BaseException) -> None: class CheckpointRecordInvalid(CheckpointError): """Raised when ``Checkpointer.load(X)`` returns a record whose - schema is incompatible with the current graph (state shape - mismatch, missing required fields, or - ``schema_version`` mismatch). Non-transient — the persisted - record was written by an incompatible version of the engine.""" + schema is incompatible with the current graph: state shape + mismatch, missing required fields, OR a post-migration state + that fails to deserialize against the current state class (per + spec §10.12.4). Non-transient. + + Note: raw ``schema_version`` mismatches no longer route here. + They now flow through ``CheckpointStateMigrationMissing`` (no + chain registered) or ``CheckpointStateMigrationFailed`` (chain + application raised) per spec §10.10's three-way category + distinction. + """ category = "checkpoint_record_invalid" @@ -68,9 +77,69 @@ def __init__(self, invocation_id: str, message: str) -> None: self.invocation_id = invocation_id +class CheckpointStateMigrationMissing(CheckpointError): + """Raised on resume when the saved record's ``schema_version`` + does not match the current state class's ``schema_version`` AND + no chain of registered migrations bridges the two. Non-transient + per spec §10.10 — the user MUST register a migration (or pin + their state to the saved version) for the resume to succeed. + + Carries the saved-from / current-to versions and a description + of the registered migration set so the user can see what + migrations are available. + """ + + category = "checkpoint_state_migration_missing" + + from_version: str + to_version: str + registered_migrations_count: int + registry_description: str + + def __init__( + self, + *args: Any, + from_version: str, + to_version: str, + registered_migrations_count: int, + registry_description: str, + ) -> None: + super().__init__(*args) + self.from_version = from_version + self.to_version = to_version + self.registered_migrations_count = registered_migrations_count + self.registry_description = registry_description + + +class CheckpointStateMigrationFailed(CheckpointError): + """Raised on resume when a registered migration function raises + during chain application (per spec §10.12.2). The migration's + exception is preserved as ``__cause__``. Non-transient by + default: a buggy migration is deterministic, so retrying + without changing the migration code will not succeed. + """ + + category = "checkpoint_state_migration_failed" + + from_version: str + to_version: str + + def __init__( + self, + *args: Any, + from_version: str, + to_version: str, + ) -> None: + super().__init__(*args) + self.from_version = from_version + self.to_version = to_version + + __all__ = [ "CheckpointError", "CheckpointNotFound", "CheckpointRecordInvalid", "CheckpointSaveFailed", + "CheckpointStateMigrationFailed", + "CheckpointStateMigrationMissing", ] diff --git a/src/openarmature/checkpoint/migration.py b/src/openarmature/checkpoint/migration.py new file mode 100644 index 0000000..0c8685c --- /dev/null +++ b/src/openarmature/checkpoint/migration.py @@ -0,0 +1,163 @@ +"""State migration types and registry. + +Realizes pipeline-utilities §10.12 (proposal 0014). A +``StateMigration`` describes one edge in the migration graph; +``MigrationRegistry`` holds the ordered set and resolves chains +via BFS. Ambiguity (duplicate ``(from, to)`` pairs OR multiple +distinct shortest paths between the same source/sink) is a +configuration-style error per §10.12.1 / §10.12.2. +""" + +from __future__ import annotations + +from collections import deque +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class StateMigration: + """One edge in the migration graph. + + ``migrate`` receives the most-deserialized form the backend can + expose that is still independent of the current state class + (a plain ``dict`` for JSON-backed backends). It MUST return a + value of the same kind, suitable for the next migration in the + chain (or for final deserialization into the current state class). + + Migrations MUST be pure: deterministic, no I/O, no implicit + state. The framework does not police purity per spec §10.12.2 + ("the contract is documented, not policed"); violating it + risks non-deterministic resume. + """ + + from_version: str + to_version: str + migrate: Callable[[Any], Any] + + +class MigrationRegistry: + """Ordered set of registered migrations + BFS chain resolution. + + Registration-time invariants: + + - Two migrations with the same ``from_version`` AND + ``to_version`` raise ``ValueError`` (chain ambiguity per + §10.12.1). + - Two migrations with the same ``from_version`` and different + ``to_version`` are permitted (branched migration graph; + chain resolution picks a path). + + Resolution-time semantics (per §10.12.2): + + - BFS from ``record.schema_version`` to + ``current.schema_version``. BFS naturally finds the shortest + path. + - Empty registry on mismatch → no path → caller raises + ``CheckpointStateMigrationMissing``. + - Non-empty registry with no connecting path → same. + - Found a unique shortest path → return ordered list. + - Found multiple distinct shortest paths (same edge count, + different edge sequences) → raise ``ValueError`` per + §10.12.2's ambiguous-chain rule. Spec accepts load-time + detection. + """ + + def __init__(self) -> None: + self._migrations: dict[tuple[str, str], StateMigration] = {} + self._edges: dict[str, list[StateMigration]] = {} + + def register(self, migration: StateMigration) -> None: + key = (migration.from_version, migration.to_version) + if key in self._migrations: + raise ValueError( + f"duplicate state migration {migration.from_version!r}→" + f"{migration.to_version!r} registered; chain would be ambiguous" + ) + self._migrations[key] = migration + self._edges.setdefault(migration.from_version, []).append(migration) + + def __iter__(self) -> Iterator[StateMigration]: + return iter(self._migrations.values()) + + def __len__(self) -> int: + return len(self._migrations) + + def resolve_chain( + self, + from_version: str, + to_version: str, + ) -> list[StateMigration] | None: + """Return an ordered chain of migrations bridging the two + versions, or ``None`` if no chain exists. + + Raises ``ValueError`` if multiple distinct shortest paths + exist (ambiguous chain per §10.12.2). + """ + if from_version == to_version: + return [] + + # BFS that records every shortest-length path. If multiple + # paths share the minimum length, the chain is ambiguous. + # Standard BFS finds the shortest distance; the path-recording + # variant lets us detect ambiguity without a second pass. + # ``frontier`` items are (version, path_so_far). + frontier: deque[tuple[str, list[StateMigration]]] = deque() + frontier.append((from_version, [])) + shortest_paths: list[list[StateMigration]] = [] + shortest_length: int | None = None + # ``distances`` tracks the BFS layer at which each node was + # first seen. Frontier entries past the shortest_length layer + # are pruned. + distances: dict[str, int] = {from_version: 0} + + while frontier: + version, path = frontier.popleft() + depth = len(path) + # Stop expanding once we've moved past the shortest target. + if shortest_length is not None and depth >= shortest_length: + continue + for edge in self._edges.get(version, []): + next_version = edge.to_version + next_path = path + [edge] + if next_version == to_version: + if shortest_length is None: + shortest_length = len(next_path) + if len(next_path) == shortest_length: + shortest_paths.append(next_path) + continue + # Cycle-avoidance: a node revisited at the same or + # deeper BFS layer can't contribute to a strict- + # shortest path. Allow re-entry only when the new + # arrival is at the same layer as the first arrival + # (distinct shortest paths through the same node). + prior_depth = distances.get(next_version) + if prior_depth is not None and prior_depth < depth + 1: + continue + distances[next_version] = depth + 1 + frontier.append((next_version, next_path)) + + if not shortest_paths: + return None + if len(shortest_paths) > 1: + descriptions = [" → ".join([from_version, *(e.to_version for e in p)]) for p in shortest_paths] + raise ValueError( + f"ambiguous migration chain from {from_version!r} to " + f"{to_version!r}: multiple distinct shortest paths exist " + f"({descriptions}); register fewer migrations or pick a " + f"single canonical route" + ) + return shortest_paths[0] + + def describe(self) -> str: + """Human-readable description of the registered set, used + in the ``CheckpointStateMigrationMissing`` error payload. + Empty registry returns ``""``. + """ + if not self._migrations: + return "" + return "\n".join(f"{m.from_version} → {m.to_version}" for m in self._migrations.values()) + + +__all__ = ["MigrationRegistry", "StateMigration"] diff --git a/src/openarmature/checkpoint/protocol.py b/src/openarmature/checkpoint/protocol.py index 6ca11dd..59de72f 100644 --- a/src/openarmature/checkpoint/protocol.py +++ b/src/openarmature/checkpoint/protocol.py @@ -22,27 +22,19 @@ subgraph-internal nodes. Fan-out instance internal events do NOT produce records in the shipping version (atomic-restart contract). -The :data:`CHECKPOINT_SCHEMA_VERSION` constant is the single source -of truth for the persisted record shape — bump it whenever the field -set changes incompatibly so older saved records surface as -:class:`CheckpointRecordInvalid` on load rather than silently coercing -to a mismatched shape. +``CheckpointRecord.schema_version`` carries the user-facing +state-schema identifier per spec §10.2 (proposal 0014 repurposes +the field from the original backend-internal record-shape role). +The framework reads ``type(state).schema_version`` at save time; +on load, version mismatches route through the migration registry +(per §10.12) rather than a strict equality check. """ from __future__ import annotations from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, Protocol - -# Persisted record shape version. Bump when CheckpointRecord's field -# set or invariants change incompatibly — proposal 0009 (per-instance -# fan-out resume) is the concrete near-term candidate, since it -# populates ``fan_out_progress`` and the load path will need to -# distinguish v1 records (where it's None) from v2 records (where it -# carries instance-level progress data). Backends that reject mismatches -# MUST raise CheckpointRecordInvalid per spec §10.10. -CHECKPOINT_SCHEMA_VERSION = "1" +from typing import Any, ClassVar, Protocol # Spec: realizes pipeline-utilities §10.2 NodePosition. Field semantics @@ -108,7 +100,11 @@ class CheckpointRecord: completed_positions: tuple[NodePosition, ...] parent_states: tuple[Any, ...] last_saved_at: float - schema_version: str = CHECKPOINT_SCHEMA_VERSION + # Per spec §10.2 (proposal 0014): the user's state-schema + # version, read off ``type(state).schema_version`` at save time. + # Empty-string sentinel for state classes that don't declare a + # version; non-empty declares migration-eligibility. + schema_version: str = "" fan_out_progress: None = field(default=None) @@ -154,8 +150,20 @@ class Checkpointer(Protocol): access). Each operation MUST be thread-safe (Python) / task-coroutine-safe (asyncio); backends with synchronous I/O typically wrap their work in ``asyncio.to_thread`` or equivalent. + + ``supports_state_migration`` marks whether the backend can + expose a structural intermediate form of the loaded state (a + plain dict, JSON tree, or similar) that is independent of the + current state class. JSON-encoded backends naturally satisfy + this; backends that store live typed state instances or use + class-bound serialization (pickle) cannot. Per spec §10.12.1, + backends that cannot expose the intermediate MUST raise + ``CheckpointRecordInvalid`` on version mismatch even when + migrations are registered — the registry has no chance to bridge. """ + supports_state_migration: ClassVar[bool] = False + async def save(self, invocation_id: str, record: CheckpointRecord) -> None: """Persist ``record`` for ``invocation_id``. After return the record MUST be durable across process crashes for backends @@ -186,7 +194,6 @@ async def delete(self, invocation_id: str) -> None: __all__ = [ - "CHECKPOINT_SCHEMA_VERSION", "CheckpointFilter", "CheckpointRecord", "CheckpointSummary", diff --git a/src/openarmature/graph/builder.py b/src/openarmature/graph/builder.py index c930272..b1ba75e 100644 --- a/src/openarmature/graph/builder.py +++ b/src/openarmature/graph/builder.py @@ -14,6 +14,7 @@ from types import GenericAlias, UnionType from typing import Any, Self, cast, get_args, get_origin +from openarmature.checkpoint.migration import MigrationRegistry, StateMigration from openarmature.checkpoint.protocol import Checkpointer from .compiled import CompiledGraph @@ -51,6 +52,10 @@ def __init__(self, state_cls: type[StateT]) -> None: # Optional Checkpointer attached at compile time; ``None`` is # the spec §10.1.1 default-off behavior. self._checkpointer: Checkpointer | None = None + # State-migration registry per pipeline-utilities §10.12 + # (proposal 0014). Populated by ``with_state_migration(s)``; + # passed through to the compiled graph. + self._migration_registry: MigrationRegistry = MigrationRegistry() def add_node( self, @@ -251,6 +256,45 @@ def with_checkpointer(self, checkpointer: Checkpointer) -> Self: self._checkpointer = checkpointer return self + def with_state_migration( + self, + from_version: str, + to_version: str, + migrate: Callable[[Any], Any], + ) -> Self: + """Register one state migration per pipeline-utilities §10.12. + + On resume, when the saved record's ``schema_version`` does not + match the current state class's ``schema_version``, the engine + consults the registry for a chain that bridges the two and + applies it to the record's state (and to each entry in + ``parent_states``) before deserialization. + + Migrations MUST be pure: deterministic, no I/O, no implicit + state. The framework does not police purity (per §10.12.2), + but violating it risks non-deterministic resume. + + Raises ``ValueError`` at registration if the + ``(from_version, to_version)`` pair is already registered + (per §10.12.1 chain-ambiguity rule). + """ + self._migration_registry.register( + StateMigration( + from_version=from_version, + to_version=to_version, + migrate=migrate, + ) + ) + return self + + def with_state_migrations(self, *migrations: StateMigration) -> Self: + """Register multiple migrations in one call. Convenience over + ``with_state_migration``; each entry is registered through the + same path and obeys the same ambiguity rule.""" + for migration in migrations: + self._migration_registry.register(migration) + return self + def add_middleware(self, middleware: Middleware) -> Self: """Register a per-graph middleware applied to every node in this graph. @@ -343,6 +387,7 @@ def compile(self) -> CompiledGraph[StateT]: edges=edges_by_source, reducers=resolved, middleware=tuple(self._middleware), + migration_registry=self._migration_registry, ) if self._checkpointer is not None: compiled.attach_checkpointer(self._checkpointer) diff --git a/src/openarmature/graph/compiled.py b/src/openarmature/graph/compiled.py index 6234a42..30ab51c 100644 --- a/src/openarmature/graph/compiled.py +++ b/src/openarmature/graph/compiled.py @@ -28,6 +28,7 @@ import uuid from collections.abc import Callable, Iterable, Mapping from dataclasses import dataclass, field +from dataclasses import replace as dataclass_replace from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: @@ -48,9 +49,11 @@ CheckpointNotFound, CheckpointRecordInvalid, CheckpointSaveFailed, + CheckpointStateMigrationFailed, + CheckpointStateMigrationMissing, ) +from openarmature.checkpoint.migration import MigrationRegistry, StateMigration from openarmature.checkpoint.protocol import ( - CHECKPOINT_SCHEMA_VERSION, Checkpointer, CheckpointRecord, NodePosition, @@ -242,6 +245,27 @@ def _no_op_finalize(_edge_error: RuntimeGraphError | None) -> None: silently per proposal 0012 + fixture 013.""" +def _apply_migration_step( + migration: StateMigration, + value: Any, + label: str, +) -> Any: + """Apply one migration step to one value (outer state or one + parent-state entry). Wraps the user-supplied migration function's + raise as ``CheckpointStateMigrationFailed`` per spec §10.12.2. + The original exception rides ``__cause__``. + """ + try: + return migration.migrate(value) + except Exception as exc: + raise CheckpointStateMigrationFailed( + f"migration {migration.from_version!r}→{migration.to_version!r} " + f"raised while migrating {label}: {type(exc).__name__}: {exc}", + from_version=migration.from_version, + to_version=migration.to_version, + ) from exc + + @dataclass(frozen=True) class CompiledGraph[StateT: State]: """An immutable, executable graph produced by `GraphBuilder.compile()`. @@ -272,6 +296,11 @@ class CompiledGraph[StateT: State]: # the user can swap the registered Checkpointer via # ``attach_checkpointer``. ``None`` when no backend is registered. _checkpointer_slot: list[Checkpointer | None] = field(default_factory=lambda: [None]) + # State-migration registry (pipeline-utilities §10.12 / proposal + # 0014). Populated by ``GraphBuilder.with_state_migration(s)``; + # consulted on resume when the loaded record's ``schema_version`` + # does not match the current state class's ``schema_version``. + migration_registry: MigrationRegistry = field(default_factory=MigrationRegistry) # ------------------------------------------------------------------ # Observer registration (spec v0.6.0 §6) @@ -329,6 +358,85 @@ def checkpointer(self) -> Checkpointer | None: """Currently-registered Checkpointer, or ``None``.""" return self._checkpointer_slot[0] + # ------------------------------------------------------------------ + # State migration (pipeline-utilities §10.12 / proposal 0014) + # ------------------------------------------------------------------ + + async def _migrate_record( + self, + record: CheckpointRecord, + checkpointer: Checkpointer, + invocation_id: str, + current_schema_version: str, + ) -> CheckpointRecord: + """Resolve a migration chain for ``record`` and apply it. + + Returns the record with ``state`` + ``parent_states`` mapped + through the chain. Caller is responsible for the + post-migration deserialization step (§10.12.4): if the + migrated state cannot deserialize against the current state + class, the resulting failure surfaces as + ``CheckpointRecordInvalid``. + + Spec §10.12.2 says "parent states MUST be treated as carrying + the same ``schema_version`` as the outer record." We apply + the same chain to every entry in ``parent_states`` lockstep + with the outer state. Future per-parent versioning would + need a spec follow-on. + """ + # Eligibility check first per §10.12.1: backends that hold + # typed in-memory state or class-bound serialization cannot + # expose the class-independent intermediate the registry + # consumes. Mismatch + no eligibility → CheckpointRecordInvalid. + if not getattr(checkpointer, "supports_state_migration", False): + raise CheckpointRecordInvalid( + invocation_id, + f"persisted schema_version={record.schema_version!r} does not " + f"match current {current_schema_version!r}, and the active " + f"checkpointer ({type(checkpointer).__name__}) does not " + f"support state migration", + ) + + try: + chain = self.migration_registry.resolve_chain( + record.schema_version, + current_schema_version, + ) + except ValueError as exc: + # MigrationRegistry signals ambiguous chains (multiple + # distinct shortest paths) via ValueError. Spec §10.12.2 + # treats this as a configuration error — surface it + # promptly during the resume attempt. + raise CheckpointStateMigrationMissing( + str(exc), + from_version=record.schema_version, + to_version=current_schema_version, + registered_migrations_count=len(self.migration_registry), + registry_description=self.migration_registry.describe(), + ) from exc + + if chain is None: + raise CheckpointStateMigrationMissing( + f"no migration chain from {record.schema_version!r} to {current_schema_version!r}", + from_version=record.schema_version, + to_version=current_schema_version, + registered_migrations_count=len(self.migration_registry), + registry_description=self.migration_registry.describe(), + ) + + migrated_state: Any = record.state + migrated_parents: list[Any] = list(record.parent_states) + for migration in chain: + migrated_state = _apply_migration_step(migration, migrated_state, "state") + for i, parent in enumerate(migrated_parents): + migrated_parents[i] = _apply_migration_step(migration, parent, f"parent_states[{i}]") + + return dataclass_replace( + record, + state=migrated_state, + parent_states=tuple(migrated_parents), + ) + async def drain(self) -> None: """Await delivery of every observer event produced by prior invocations of this graph. @@ -433,11 +541,27 @@ async def invoke( record = await checkpointer.load(resume_invocation) if record is None: raise CheckpointNotFound(resume_invocation) - if record.schema_version != CHECKPOINT_SCHEMA_VERSION: - raise CheckpointRecordInvalid( + # Per spec §10.12 (proposal 0014): version-mismatch resume. + # Routing precedence (per §10.10 + §10.12.1): + # 1. unsupported backend → CheckpointRecordInvalid. + # Backends that hold typed in-memory state or + # class-bound serialization can't expose the + # class-independent intermediate the migration + # registry needs. + # 2. no chain in the registry → CheckpointStateMigrationMissing. + # Actionable: register a migration. + # 3. chain found but a migration raises → + # CheckpointStateMigrationFailed. + # 4. post-migration state fails to deserialize → + # CheckpointRecordInvalid (the §10.12.4 boundary). + # Order matters — do NOT swap eligibility and registry-lookup. + current_schema_version = self.state_cls.schema_version + if record.schema_version != current_schema_version: + record = await self._migrate_record( + record, + checkpointer, resume_invocation, - f"persisted schema_version={record.schema_version!r} " - f"does not match current {CHECKPOINT_SCHEMA_VERSION!r}", + current_schema_version, ) # The saved record's ``state`` is post-merge state at the # saving node's level (depth = len(parent_states)). For @@ -1387,7 +1511,14 @@ async def _maybe_save_checkpoint( # ``step`` field on each NodePosition is the canonical # within-invocation order. last_saved_at=time.time(), - schema_version=CHECKPOINT_SCHEMA_VERSION, + # Per spec §10.2 (proposal 0014): read the user's + # state-schema version off the state class at save time. + # Empty-string sentinel when the user hasn't declared + # one — those records are not migration-eligible until + # they declare a non-empty version (per §10.2). The + # runtime type of ``post_state`` is the authoritative + # source (subclasses MAY override the ClassVar). + schema_version=cast("type[State]", type(post_state)).schema_version, ) try: await checkpointer.save(context.invocation_id, record) diff --git a/src/openarmature/graph/state.py b/src/openarmature/graph/state.py index 746dd36..6ac9520 100644 --- a/src/openarmature/graph/state.py +++ b/src/openarmature/graph/state.py @@ -16,6 +16,7 @@ class S(State): """ from collections.abc import Mapping +from typing import ClassVar from pydantic import BaseModel, ConfigDict @@ -31,6 +32,16 @@ class State(BaseModel): # boundaries" intent. model_config = ConfigDict(frozen=True, extra="forbid") + # User-controlled state-schema version per pipeline-utilities + # spec §10.2 (proposal 0014). Empty-string sentinel for state + # classes that don't declare a version. The framework reads + # ``type(state).schema_version`` at save time and writes it + # onto the CheckpointRecord. Declaring a non-empty value opts + # the state class into the migration registry: on resume, + # records carrying a different version route through the + # registered migrations per spec §10.12. + schema_version: ClassVar[str] = "" + def field_reducers(state_cls: type[State]) -> Mapping[str, list[Reducer]]: """Return `{field_name: [declared reducers]}` for each field on `state_cls`. diff --git a/tests/unit/test_checkpoint.py b/tests/unit/test_checkpoint.py index e9c9664..a76a792 100644 --- a/tests/unit/test_checkpoint.py +++ b/tests/unit/test_checkpoint.py @@ -14,13 +14,12 @@ from __future__ import annotations from pathlib import Path -from typing import Any +from typing import Any, ClassVar import pytest from pydantic import Field from openarmature.checkpoint import ( - CHECKPOINT_SCHEMA_VERSION, Checkpointer, CheckpointFilter, CheckpointNotFound, @@ -92,7 +91,10 @@ def test_checkpoint_record_default_schema_version() -> None: parent_states=(), last_saved_at=0.0, ) - assert record.schema_version == CHECKPOINT_SCHEMA_VERSION + # Default schema_version is the empty-string sentinel per spec + # §10.2 (proposal 0014): records carry the user's state-schema + # version, which is "" until the state class declares one. + assert record.schema_version == "" assert record.fan_out_progress is None @@ -305,13 +307,18 @@ async def test_schema_version_round_trips(tmp_path: Path) -> None: await cp.save("i", record) loaded = await cp.load("i") assert loaded is not None - assert loaded.schema_version == CHECKPOINT_SCHEMA_VERSION - - -async def test_schema_version_mismatch_rejected_by_sqlite(tmp_path: Path) -> None: - """A persisted record with an unrecognized schema_version raises - CheckpointRecordInvalid on load — sentinel behavior so future - record-shape evolution doesn't silently coerce older records.""" + # Records round-trip the user-facing schema_version verbatim per + # spec §10.2. With no version declared on the saved state, the + # default sentinel is "". + assert loaded.schema_version == "" + + +async def test_schema_version_round_trips_through_sqlite_unchanged(tmp_path: Path) -> None: + """Per spec §10.12 (proposal 0014), the SQLite backend no longer + rejects records with non-default ``schema_version`` values — that + routing is now an engine concern at resume time. The backend + just round-trips the version identifier as opaque data so the + engine's migration registry has the chance to bridge it.""" cp = SQLiteCheckpointer(tmp_path / "ck.db") record = CheckpointRecord( invocation_id="i", @@ -320,11 +327,12 @@ async def test_schema_version_mismatch_rejected_by_sqlite(tmp_path: Path) -> Non completed_positions=(), parent_states=(), last_saved_at=1.0, - schema_version="999", # not the current version + schema_version="999", # an arbitrary user-facing identifier ) await cp.save("i", record) - with pytest.raises(CheckpointRecordInvalid): - await cp.load("i") + loaded = await cp.load("i") + assert loaded is not None + assert loaded.schema_version == "999" # --------------------------------------------------------------------------- @@ -424,6 +432,8 @@ class _AlwaysFailingCheckpointer: as :class:`CheckpointSaveFailed` and raises immediately to the caller of ``invoke()`` per the documented save-failure policy.""" + supports_state_migration: ClassVar[bool] = False + async def save(self, invocation_id: str, record: CheckpointRecord) -> None: raise RuntimeError("simulated backend failure") @@ -449,6 +459,8 @@ async def test_save_failure_raises_to_invoke_caller() -> None: class _CapturingCheckpointer: + supports_state_migration: ClassVar[bool] = False + def __init__(self) -> None: self.saves: list[CheckpointRecord] = [] From 63a5819b407600d1d7edc14f5bfa98df22f33f61 Mon Sep 17 00:00:00 2001 From: chris-colinsky Date: Fri, 15 May 2026 22:33:21 -0700 Subject: [PATCH 2/7] test(conformance): drive 0014 state-migration fixtures 039-046 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New tests/conformance/test_state_migration.py drives all 8 spec state-migration fixtures end-to-end against the real engine (SQLite JSON-mode backend, real graph compile, real resume path). Harness pieces: - Migration mock library: add_new_field_default, add_v2_field, add_v3_field, identity_passthrough, raises_keyerror, should_not_run, irrelevant. Each fixture's migrate: resolves through the library. - _MigrationTrace wraps each mock to capture invocation order for the migrations_run / migration_count / single_migration_invocation / migration_order_matches_chain assertions. Consecutive duplicates collapse (fixture 043 runs each step once for outer + once for each parent under the lockstep ordering; the assertion is per-step, not per-entity). - _seed_record persists a checkpoint matching the fixture's seeded_record: block before invoke(resume_invocation=...) so the resume path has data to load. Harness/model adjustments: - StateSchema in tests/conformance/harness/directives.py gains optional schema_version (default '') and the required field knob (no-default Pydantic field, used by fixture 044's required_v2_field deserialization-failure case). - _DEFERRED_FIXTURES in test_fixture_parsing.py loses the 039-046 rows; the CasesFixture model parses them via permissive extras on CaseSpec. - Initial-state construction in the resume path uses state_cls.model_construct() so fixtures with required fields (044) can pass a placeholder past Pydantic's validator before the resume even starts; the engine loads state from the checkpoint, not from the placeholder. Protocol-attribute shape: - Checkpointer.supports_state_migration declared as bool = False (not ClassVar) so SQLiteCheckpointer can set the value per-instance in __init__ based on serialization mode. Backends with a static answer (InMemoryCheckpointer) override at the class level with bool = False — Pyright accepts either shape because Protocol attribute conformance ignores the ClassVar marker on subclasses. --- .../checkpoint/backends/memory.py | 9 +- src/openarmature/checkpoint/protocol.py | 13 +- tests/conformance/harness/directives.py | 11 + tests/conformance/test_fixture_parsing.py | 15 +- tests/conformance/test_state_migration.py | 400 ++++++++++++++++++ tests/unit/test_checkpoint.py | 6 +- 6 files changed, 435 insertions(+), 19 deletions(-) create mode 100644 tests/conformance/test_state_migration.py diff --git a/src/openarmature/checkpoint/backends/memory.py b/src/openarmature/checkpoint/backends/memory.py index 00b1432..1693264 100644 --- a/src/openarmature/checkpoint/backends/memory.py +++ b/src/openarmature/checkpoint/backends/memory.py @@ -12,7 +12,6 @@ import asyncio from collections.abc import Iterable -from typing import ClassVar from ..protocol import CheckpointFilter, CheckpointRecord, CheckpointSummary @@ -41,8 +40,12 @@ class InMemoryCheckpointer: # Per spec §10.12.1: in-memory storage holds live typed-state # references, so there's no class-independent intermediate form - # the migration registry could consume. - supports_state_migration: ClassVar[bool] = False + # the migration registry could consume. Declared at the class + # level (not as a per-instance attribute) since the answer is + # constructor-independent; the Protocol declaration in + # ``protocol.py`` types this as ``bool`` (not ``ClassVar[bool]``) + # so Pyright accepts a class-attribute override here. + supports_state_migration: bool = False def __init__(self) -> None: self._records: dict[str, CheckpointRecord] = {} diff --git a/src/openarmature/checkpoint/protocol.py b/src/openarmature/checkpoint/protocol.py index 59de72f..25ee7f1 100644 --- a/src/openarmature/checkpoint/protocol.py +++ b/src/openarmature/checkpoint/protocol.py @@ -34,7 +34,7 @@ from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, ClassVar, Protocol +from typing import Any, Protocol # Spec: realizes pipeline-utilities §10.2 NodePosition. Field semantics @@ -162,7 +162,16 @@ class Checkpointer(Protocol): migrations are registered — the registry has no chance to bridge. """ - supports_state_migration: ClassVar[bool] = False + # Declared as an instance attribute (not ``ClassVar``) so backends + # can compute it at construction time when the answer depends on + # constructor args. SQLiteCheckpointer is the concrete case: + # JSON-mode supports migration, pickle-mode doesn't, and the mode + # is a per-instance constructor arg. Backends with a static answer + # (InMemoryCheckpointer is always False) override at the class + # level with ``ClassVar[bool] = False``; pyright is happy with + # either shape because Protocol attribute conformance ignores the + # ClassVar marker on subclasses. + supports_state_migration: bool = False async def save(self, invocation_id: str, record: CheckpointRecord) -> None: """Persist ``record`` for ``invocation_id``. After return the diff --git a/tests/conformance/harness/directives.py b/tests/conformance/harness/directives.py index b2824f4..2a3b40f 100644 --- a/tests/conformance/harness/directives.py +++ b/tests/conformance/harness/directives.py @@ -56,16 +56,27 @@ class StateFieldSpec(_ForbidExtras): The ``alt_reducer`` knob exists only for ``graph-engine/007-compile-errors``'s ``conflicting_reducers`` case — fixtures intentionally declare two reducers on one field to verify the engine fails compile with the right category. + + The ``required`` knob (used by the state-migration deserialization- + failure fixture 044) marks a field as having no default — Pydantic's + natural "required" shape. The default-when-omitted falls through + via the ``default`` field above. """ type: str default: Any = None reducer: str | None = None alt_reducer: str | None = None + required: bool = False class StateSchema(_ForbidExtras): fields: dict[str, StateFieldSpec] + # User-facing state-schema version per pipeline-utilities §10.2 + # (proposal 0014). The state-migration fixtures (039-046) declare + # this on each case's ``state`` block; non-migration fixtures + # omit it (defaults to empty-string sentinel). + schema_version: str = "" # --------------------------------------------------------------------------- diff --git a/tests/conformance/test_fixture_parsing.py b/tests/conformance/test_fixture_parsing.py index c9a145a..128b185 100644 --- a/tests/conformance/test_fixture_parsing.py +++ b/tests/conformance/test_fixture_parsing.py @@ -40,17 +40,10 @@ def _id(case: tuple[str, Path]) -> str: "pipeline-utilities/036-parallel-branches-with-branch-middleware-retry": "0011 parallel branches (PR-5)", "pipeline-utilities/037-parallel-branches-determinism": "0011 parallel branches (PR-5)", "pipeline-utilities/038-parallel-branches-compose-with-fan-out": "0011 parallel branches (PR-5)", - # proposal 0014 — state migration (PR-4) - "pipeline-utilities/039-state-migration-additive-field": "0014 state migration (PR-4)", - "pipeline-utilities/040-state-migration-chain": "0014 state migration (PR-4)", - "pipeline-utilities/041-state-migration-missing": "0014 state migration (PR-4)", - "pipeline-utilities/042-state-migration-versions-match-no-op": "0014 state migration (PR-4)", - "pipeline-utilities/043-state-migration-parent-states-migrated": "0014 state migration (PR-4)", - "pipeline-utilities/044-state-migration-post-migration-deserialization-fails": ( - "0014 state migration (PR-4)" - ), - "pipeline-utilities/045-state-migration-no-path-in-registry": "0014 state migration (PR-4)", - "pipeline-utilities/046-state-migration-function-raises": "0014 state migration (PR-4)", + # proposal 0014's state-migration fixtures (039-046) were removed + # from this list as part of PR-4; the CasesFixture model already + # parses the seeded_record / migrations shape via its permissive + # extras (CaseSpec uses ``model_config = ConfigDict(extra="allow")``). # proposal 0015's llm-provider fixtures (009-020) were removed # from this list as part of PR-2; the typed harness parses the # content-block message shape via LlmCallSpec's permissive diff --git a/tests/conformance/test_state_migration.py b/tests/conformance/test_state_migration.py new file mode 100644 index 0000000..ad9851f --- /dev/null +++ b/tests/conformance/test_state_migration.py @@ -0,0 +1,400 @@ +"""Run every spec state-migration conformance fixture (039-046) end-to-end. + +The fixtures live under +``spec/pipeline-utilities/conformance/`` as ``cases`` shapes; each +case defines a state schema (with a ``schema_version``), an entry +node and edges, a ``seeded_record:`` describing a checkpoint that +was saved at a prior schema version, a ``migrations:`` list (each +naming one of the harness mock functions), and a ``resume`` block +specifying either an ``expected`` happy-path or an +``expected_error`` raise. + +The driver: + +1. Builds a State subclass via ``adapter.build_state_cls``, then + patches the generated class's ``schema_version`` ClassVar so + it matches the fixture's declared version. +2. Builds a minimal graph (entry node + edge to END) via the + existing adapter primitives. +3. Resolves each ``migrations[i].migrate`` name against the mock + library, wrapping every mock so the harness can count + invocations + record the ``v1->v2`` ordered list (for the + ``migrations_run`` / ``migration_count`` / + ``migration_order_matches_chain`` assertions). +4. Seeds a ``CheckpointRecord`` via the configured backend + (SQLite in JSON mode) using a stable seeded ``invocation_id``. +5. Calls ``invoke(resume_invocation=)`` and asserts + against ``resume.expected`` (final state, migrations_run, + invariants) OR ``resume.expected_error`` (category, carries, + cause). +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, cast + +import pytest +import yaml + +from openarmature.checkpoint import ( + CheckpointError, + CheckpointRecord, + CheckpointRecordInvalid, + CheckpointStateMigrationFailed, + CheckpointStateMigrationMissing, + SQLiteCheckpointer, +) +from openarmature.graph import END, GraphBuilder, State + +from .adapter import build_state_cls + +CONFORMANCE_DIR = ( + Path(__file__).resolve().parents[2] / "openarmature-spec" / "spec" / "pipeline-utilities" / "conformance" +) + +_FIXTURE_RANGE = range(39, 47) + +# --------------------------------------------------------------------------- +# Migration mock library (harness-side; fixtures refer by name) +# --------------------------------------------------------------------------- + +# Wrapped mocks all receive a dict (JSON-mode SQLite hands the engine the +# structural intermediate form). Each returns a dict suitable as input +# for the next migration in the chain (or for final deserialization). + + +def _add_new_field_default(state: Any) -> Any: + out = dict(state) + out["new_field"] = "v2_default" + return out + + +def _add_v2_field(state: Any) -> Any: + out = dict(state) + out["v2_field"] = "v2_default" + return out + + +def _add_v3_field(state: Any) -> Any: + out = dict(state) + out["v3_field"] = "v3_default" + return out + + +def _identity_passthrough(state: Any) -> Any: + # Used by 044 to verify post-migration deserialization-failure + # routing: the migration runs cleanly but produces output that + # the v2 state class can't deserialize (missing a required field). + return dict(state) + + +def _raises_keyerror(_state: Any) -> Any: + # Used by 046 to verify CheckpointStateMigrationFailed routing. + raise KeyError("simulated buggy migration") + + +def _should_not_run(_state: Any) -> Any: + # Used by 042 (versions-match no-op) to verify the engine does + # NOT consult the migration registry when versions match. + raise AssertionError("fixture 042 invariant violated: should_not_run was called despite version match") + + +def _irrelevant(state: Any) -> Any: + # Used by 045 — migration is registered but the engine doesn't + # find a path to it. Returns input unchanged. + return dict(state) + + +_MOCK_LIBRARY: dict[str, Any] = { + "add_new_field_default": _add_new_field_default, + "add_v2_field": _add_v2_field, + "add_v3_field": _add_v3_field, + "identity_passthrough": _identity_passthrough, + "raises_keyerror": _raises_keyerror, + "should_not_run": _should_not_run, + "irrelevant": _irrelevant, +} + + +class _MigrationTrace: + """Captures the order migrations were invoked in, for the + ``migrations_run`` / ``migration_count`` / + ``migration_order_matches_chain`` fixture assertions.""" + + def __init__(self) -> None: + self.order: list[str] = [] + + def wrap(self, name: str, fn: Any, from_v: str, to_v: str) -> Any: + label = f"{from_v}->{to_v}" + + def _traced(state: Any) -> Any: + self.order.append(label) + return fn(state) + + _traced.__name__ = f"_traced_{name}" + return _traced + + +# --------------------------------------------------------------------------- +# Fixture discovery +# --------------------------------------------------------------------------- + + +def _fixture_paths() -> list[Path]: + out: list[Path] = [] + for p in sorted(CONFORMANCE_DIR.glob("[0-9][0-9][0-9]-*.yaml")): + try: + number = int(p.stem.split("-", 1)[0]) + except ValueError: + continue + if number in _FIXTURE_RANGE: + out.append(p) + return out + + +def _fixture_id(path: Path) -> str: + return path.stem + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_state_cls(state_spec: dict[str, Any], model_name: str) -> type[State]: + """Build a Pydantic State subclass from the fixture spec and stamp + its ``schema_version`` ClassVar with the declared version.""" + fields_spec = state_spec.get("fields", {}) + state_cls = build_state_cls(model_name, fields_spec) + # Stamp the per-fixture schema_version. The adapter's build_state_cls + # produces a fresh subclass via pydantic.create_model so it's safe + # to set the ClassVar after construction. + state_cls.schema_version = state_spec.get("schema_version", "") + return state_cls + + +async def _seed_record( + checkpointer: SQLiteCheckpointer, + invocation_id: str, + seeded: dict[str, Any], +) -> None: + """Persist a checkpoint record matching the fixture's + ``seeded_record:`` block. ``state`` and ``parent_states`` go + through as plain dicts (JSON-mode round-trip).""" + raw_positions: list[dict[str, Any]] = seeded.get("completed_positions", []) + from openarmature.checkpoint import NodePosition + + positions = tuple( + NodePosition( + namespace=tuple(p.get("namespace", [])), + node_name=p["node_name"], + step=p["step"], + attempt_index=p.get("attempt_index", 0), + fan_out_index=p.get("fan_out_index"), + ) + for p in raw_positions + ) + record = CheckpointRecord( + invocation_id=invocation_id, + correlation_id=seeded.get("correlation_id", "seeded-corr"), + state=seeded["state"], + completed_positions=positions, + parent_states=tuple(seeded.get("parent_states", [])), + last_saved_at=0.0, + schema_version=seeded.get("schema_version", ""), + ) + await checkpointer.save(invocation_id, record) + + +# --------------------------------------------------------------------------- +# Case runner +# --------------------------------------------------------------------------- + + +async def _run_one_case(case: dict[str, Any], tmp_path: Path) -> None: + """Run one fixture case end-to-end: build, seed, resume, assert.""" + state_cls = _build_state_cls(case["state"], model_name=f"Case_{case['name']}") + + # Minimal node body: apply ``update_pure`` to the state. + nodes_spec = case["nodes"] + edges_spec = case["edges"] + entry = case["entry"] + + builder = GraphBuilder(state_cls) + for node_name, node_spec in nodes_spec.items(): + update_pure = cast("dict[str, Any]", node_spec.get("update_pure", {})) + + async def _node_body(_s: State, _u: dict[str, Any] = update_pure) -> dict[str, Any]: + return _u + + builder.add_node(node_name, _node_body) + for edge in edges_spec: + target_raw = edge["to"] + target = END if target_raw == "END" else target_raw + builder.add_edge(edge["from"], target) + builder.set_entry(entry) + + # Register migrations + wrap each in the trace recorder. + trace = _MigrationTrace() + for m in case.get("migrations", []): + mock = _MOCK_LIBRARY[m["migrate"]] + wrapped = trace.wrap(m["migrate"], mock, m["from_version"], m["to_version"]) + builder.with_state_migration(m["from_version"], m["to_version"], wrapped) + + # Configure the SQLite backend in JSON mode (the migration-eligible + # backend per spec §10.12.1). One database file per case, isolated + # under tmp_path. + db_path = tmp_path / f"{case['name']}.db" + checkpointer = SQLiteCheckpointer(db_path, serialization="json") + builder.with_checkpointer(checkpointer) + compiled = builder.compile() + + # Seed the prior record. + invocation_id = f"seeded-{case['name']}" + if "seeded_record" in case: + await _seed_record(checkpointer, invocation_id, case["seeded_record"]) + + # Resume + assert. + resume = case["resume"] + # For resume-from-seeded cases the engine loads state from the + # checkpoint and never reads ``initial_state``; using + # ``model_construct`` skips Pydantic validation so fixtures with + # required fields (e.g., 044) can construct a placeholder without + # tripping the validator before the resume even starts. + initial_state = state_cls.model_construct() + raised: BaseException | None = None + final_state: Any = None + try: + if resume.get("from_seeded_record"): + final_state = await compiled.invoke(initial_state, resume_invocation=invocation_id) + else: + final_state = await compiled.invoke(initial_state) + except CheckpointError as exc: + raised = exc + + if "expected_error" in resume: + _assert_error(resume["expected_error"], resume.get("invariants", {}), raised) + elif "expected" in resume: + assert raised is None, f"expected success, got {raised!r}" + _assert_success(resume["expected"], resume.get("invariants", {}), final_state, trace) + + +def _assert_error( + expected_error: dict[str, Any], + invariants: dict[str, Any], + raised: BaseException | None, +) -> None: + assert raised is not None, ( + f"expected raise of category {expected_error.get('category')!r}, got no exception" + ) + actual_category = getattr(raised, "category", None) + assert actual_category == expected_error["category"], ( + f"expected category {expected_error['category']!r}, got {actual_category!r} ({raised!r})" + ) + for key, expected_value in expected_error.get("carries", {}).items(): + actual_attr = getattr(raised, key, None) + assert actual_attr == expected_value, f"expected {key}={expected_value!r}, got {actual_attr!r}" + cause_spec = expected_error.get("cause") + if cause_spec is not None: + cause = raised.__cause__ + assert cause is not None, "expected __cause__ to be populated" + assert type(cause).__name__ == cause_spec["exception_type"], ( + f"expected __cause__ type {cause_spec['exception_type']!r}, got {type(cause).__name__!r}" + ) + forbidden_categories = invariants.get("error_category_not", []) + for forbidden in forbidden_categories: + assert actual_category != forbidden, ( + f"invariant violated: error category {forbidden!r} forbidden but raised" + ) + + +def _assert_success( + expected: dict[str, Any], + invariants: dict[str, Any], + final_state: Any, + trace: _MigrationTrace, +) -> None: + expected_final_state = expected.get("final_state") + if expected_final_state is not None: + actual = final_state.model_dump() + for key, value in expected_final_state.items(): + assert actual.get(key) == value, f"final_state.{key}: expected {value!r}, got {actual.get(key)!r}" + + migrations_run = expected.get("migrations_run") + if migrations_run is not None: + # Per the spec's lockstep ordering, each migration step runs + # once for the outer state. The fixtures only seed parent_states + # on fixture 043; for the simpler fixtures, the order list + # tracks a single application of each step. fixture 043's + # lockstep semantics produce 2 invocations per migration step + # (once for outer, once for parent) — collapse the consecutive + # duplicates for the comparison so the order assertion stays + # "each step ran in chain order." + dedup_consecutive: list[str] = [] + for label in trace.order: + if not dedup_consecutive or dedup_consecutive[-1] != label: + dedup_consecutive.append(label) + assert dedup_consecutive == migrations_run, ( + f"migrations_run: expected {migrations_run!r}, got {dedup_consecutive!r} " + f"(raw invocation order: {trace.order!r})" + ) + + expected_count = invariants.get("migration_count") + if expected_count is not None: + # ``migration_count`` counts distinct migration steps in the + # chain (each step applied to outer + parents counts once per + # step). Collapse consecutive duplicates the same way. + dedup_consecutive_for_count: list[str] = [] + for label in trace.order: + if not dedup_consecutive_for_count or dedup_consecutive_for_count[-1] != label: + dedup_consecutive_for_count.append(label) + assert len(dedup_consecutive_for_count) == expected_count, ( + f"migration_count: expected {expected_count}, got " + f"{len(dedup_consecutive_for_count)} ({dedup_consecutive_for_count!r})" + ) + + if invariants.get("single_migration_invocation"): + # Each migration step ran once per state-or-parent-state entry. + # For fixture 039 there's no parent_states, so the dedup count + # equals the raw count and equals 1 (one migration step). + assert len(trace.order) == 1, ( + f"single_migration_invocation: expected 1 invocation, got {len(trace.order)} ({trace.order!r})" + ) + + if invariants.get("migration_order_matches_chain"): + # The chain ordering is captured implicitly by the migrations_run + # comparison above; this is a redundant invariant the fixture + # surfaces explicitly. No-op here since migrations_run carries + # the assertion. + pass + + +# --------------------------------------------------------------------------- +# Parametrized driver +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("fixture_path", _fixture_paths(), ids=_fixture_id) +async def test_state_migration_fixture(fixture_path: Path, tmp_path: Path) -> None: + """Parametrize across 039-046; each case in a fixture's ``cases`` + list runs as a sub-case under one parametrize id (matching how + test_checkpoint.py handles the cases-shape fixtures). + """ + with fixture_path.open() as f: + spec = cast("dict[str, Any]", yaml.safe_load(f)) + cases = spec.get("cases", []) + for case in cases: + case_name = case.get("name", "") + try: + await _run_one_case(case, tmp_path) + except AssertionError as exc: + raise AssertionError(f"case {case_name!r}: {exc}") from exc + + +# Make sure the imports for the error categories are reachable from +# the test module (pyright would flag them as unused otherwise; we +# reference them via the spec-error-category strings inside the +# fixtures, not by class identity, so the imports are load-bearing +# for the package-public API surface verification). +_ = CheckpointRecordInvalid, CheckpointStateMigrationFailed, CheckpointStateMigrationMissing diff --git a/tests/unit/test_checkpoint.py b/tests/unit/test_checkpoint.py index a76a792..ae3fafb 100644 --- a/tests/unit/test_checkpoint.py +++ b/tests/unit/test_checkpoint.py @@ -14,7 +14,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, ClassVar +from typing import Any import pytest from pydantic import Field @@ -432,7 +432,7 @@ class _AlwaysFailingCheckpointer: as :class:`CheckpointSaveFailed` and raises immediately to the caller of ``invoke()`` per the documented save-failure policy.""" - supports_state_migration: ClassVar[bool] = False + supports_state_migration: bool = False async def save(self, invocation_id: str, record: CheckpointRecord) -> None: raise RuntimeError("simulated backend failure") @@ -459,7 +459,7 @@ async def test_save_failure_raises_to_invoke_caller() -> None: class _CapturingCheckpointer: - supports_state_migration: ClassVar[bool] = False + supports_state_migration: bool = False def __init__(self) -> None: self.saves: list[CheckpointRecord] = [] From fa7b96e2a12976257520178e1a2809e6aaf42e22 Mon Sep 17 00:00:00 2001 From: chris-colinsky Date: Fri, 15 May 2026 23:05:24 -0700 Subject: [PATCH 3/7] test(unit): state migration + docs(concepts) state migrations section + CHANGELOG MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit tests/unit/test_state_migration.py (16 tests) covers gaps the conformance fixtures don't exercise directly: BFS edge cases on the registry, multi-shortest-path ambiguity detection, GraphBuilder ergonomics (singular + plural registration, duplicate-pair ValueError), and error attribute carriage including __cause__ preservation on CheckpointStateMigrationFailed. docs/concepts/checkpointing.md gains a State migrations section covering: the State.schema_version declaration, the registration surface (with_state_migration / with_state_migrations), BFS chain resolution and ambiguity cases (duplicate edges + multi-shortest- path), the two new error categories and how they relate to CheckpointRecordInvalid (§10.12.4), backend support and why SQLite-pickle / InMemory aren't migration-eligible, the lockstep parent_states migration rule, and the migrations-MUST-be-pure contract. CHANGELOG.md gains two new Added entries (state-migration surface + Checkpointer.supports_state_migration Protocol attribute) plus a Changed entry documenting the CheckpointRecord.schema_version semantic shift and the CHECKPOINT_SCHEMA_VERSION constant removal. Pre-1.0 breaking change covered by the consolidated-release flag. --- CHANGELOG.md | 3 + docs/concepts/checkpointing.md | 119 ++++++++++++++++ tests/unit/test_state_migration.py | 212 +++++++++++++++++++++++++++++ 3 files changed, 334 insertions(+) create mode 100644 tests/unit/test_state_migration.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ffc04d6..0cdfb8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). The ### Added +- **State migration for checkpointed graphs (proposal 0014, introduced in spec v0.15.0).** Saved checkpoints whose `schema_version` doesn't match the current state class now route through a registered migration chain instead of failing on resume. Surface: `State.schema_version: ClassVar[str] = ""` (declare a non-empty value to opt in), `GraphBuilder.with_state_migration(from_version, to_version, migrate)` and `with_state_migrations(*migrations)` for registration, `StateMigration` and `MigrationRegistry` types exported from `openarmature.checkpoint`. Chain resolution is BFS over the registered edges; the shortest path wins. Multi-shortest-path ambiguity (e.g., a diamond `v1→v2→v4` + `v1→v3→v4`) surfaces as `CheckpointStateMigrationMissing` per spec §10.12.2's load-time-detection allowance. Two new error categories: `CheckpointStateMigrationMissing` (no chain bridges the versions, or chain ambiguous) and `CheckpointStateMigrationFailed` (a migration function raised). Both non-transient. Post-migration deserialization failures still route to `CheckpointRecordInvalid` per §10.12.4. The same chain applies to each entry in `parent_states` in lockstep with the outer state per §10.12.2. +- **`Checkpointer.supports_state_migration` Protocol attribute.** Marks whether a backend can expose the structural intermediate form (a plain dict, JSON tree) the migration registry consumes. `SQLiteCheckpointer(serialization="json")` opts in; `SQLiteCheckpointer(serialization="pickle")` and `InMemoryCheckpointer` opt out. On version mismatch against a non-migration-eligible backend the engine raises `CheckpointRecordInvalid` per spec §10.12.1. - **Prompt-management capability (proposal 0017, introduced in spec v0.15.0).** New `openarmature.prompts` subpackage. `PromptManager` composes one or more `PromptBackend`s, exposes `fetch` / `render` / `get`, applies the §8 fallback semantics (`prompt_store_unavailable` continues to the next backend; `prompt_not_found` stops the chain), and renders templates with Jinja2's `StrictUndefined` per §7. `Prompt` / `PromptResult` / `PromptGroup` are Pydantic models matching spec §3 / §4 / §9. Three error categories (`PromptNotFound`, `PromptRenderError`, `PromptStoreUnavailable`) with `PROMPT_TRANSIENT_CATEGORIES` exported for retry-middleware classifiers. `FilesystemPromptBackend` is the minimum local-filesystem reference backend (layout: `/