Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/concepts/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ the framework, or jump to whichever concept you need.
- [Composition: conditional edges, subgraphs, projection](composition.md):
routing decisions, encapsulated sub-pipelines, the parent ↔ subgraph
data seam.
- [Middleware](middleware.md): wrap node dispatch with retries,
timing, logging, error transformation; per-node, per-graph,
per-branch, and per-fan-out-instance registration.
- [Fan-out](fan-out.md): running the same subgraph many times in
parallel, results merged back deterministically.
- [Parallel branches](parallel-branches.md): dispatching M
Expand Down
211 changes: 211 additions & 0 deletions docs/concepts/middleware.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Middleware

Middleware wraps the dispatch of a single node. The shape is an async
callable `(state, next) -> partial_update`. Anything you want to happen
around a node, without changing the node itself, lives here: retries,
timing, structured logging, request enrichment, error transformation,
circuit-breaking.

```python
from collections.abc import Mapping
from typing import Any

from openarmature.graph import Middleware, NextCall


class LogAround:
async def __call__(self, state: Any, next_: NextCall) -> Mapping[str, Any]:
print("before")
partial = await next_(state)
print("after")
return partial


_: Middleware = LogAround() # structural conformance check
```

`next` invokes the next layer of the chain (or the wrapped node, at
the innermost end) and returns the partial update from that layer.
Code before `await next(state)` is the pre-node phase (runs on the way
in); code after is the post-node phase (runs on the way out).

## Four registration sites

You can attach middleware at four places. The same `Middleware` shape
works in all of them.

**Per-node**, on a single function node:

```python
builder.add_node("fetch", fetch_fn, middleware=[RetryMiddleware()])
```

**Per-graph**, applied to every node in the graph:

```python
builder.add_middleware(TimingMiddleware(node_name="...", on_complete=record))
```

**Per-branch**, on a single branch of a parallel-branches node:

```python
from openarmature.graph import BranchSpec

branches = {
"sentiment": BranchSpec(
subgraph=sentiment_subgraph,
middleware=(RetryMiddleware(),),
),
"topic": BranchSpec(subgraph=topic_subgraph),
}
builder.add_parallel_branches_node("classify", branches=branches)
```

The branch middleware wraps the whole branch dispatch as one call. A
retry on a branch retries the entire branch from scratch, not an
individual node inside it.

**Per-fan-out-instance**, on the instance dispatch inside a fan-out
node:

```python
builder.add_fan_out_node(
"summarize",
subgraph=summarize_subgraph,
items_field="articles",
item_field="article",
collect_field="article",
target_field="summaries",
instance_middleware=[RetryMiddleware()],
)
```

A retry here retries one instance, not the whole fan-out.

## Composition order

When a node has middleware from multiple sites, per-graph composes
*outside* per-node. The runtime chain at a single function node is:

```
[per_graph_outer_to_inner...] → [per_node_outer_to_inner...] → node
```

The first middleware in `builder.add_middleware()` calls is the
outermost layer; the last is closest to the node. Same rule for
per-node: list order is outer-to-inner.

## The subgraph boundary

Middleware does not cross into a subgraph. The parent's middleware
wraps the `SubgraphNode` dispatch as a single atomic call, and the
subgraph's own middleware (configured on the child builder) wraps the
child's internal nodes independently.

In practical terms: a `RetryMiddleware` on a subgraph-as-node retries
the whole child graph from its entry. A `RetryMiddleware` inside the
child retries one of its individual nodes.

## Error semantics

An exception raised by `next(state)` propagates up through `await
next(state)`. Middleware may:

- **Re-raise**: the simplest case. Don't catch, let it bubble.
- **Catch and recover**: catch the exception and return a partial
update of your own. The rest of the chain continues as if the node
had returned that partial update normally.
- **Catch and transform**: catch one exception type, raise a different
one. The new exception propagates up.
- **Call `next` more than once**: this is what retry middleware does.

A middleware MUST NOT mutate the input `state` object in place. To
hand a transformed state down the chain, pass a new state instance to
`next(...)`.

## Built-in: RetryMiddleware

```python
from openarmature.graph import RetryMiddleware, exponential_jitter_backoff


async def on_retry(exc: Exception, attempt: int) -> None:
log.warning("retrying after %r (attempt %d)", exc, attempt)


retry = RetryMiddleware(
max_attempts=3,
backoff=exponential_jitter_backoff,
on_retry=on_retry,
)
```

Four plug points, all optional:

- **`max_attempts`** is the total attempt count including the first
call. `1` disables retry. Default `3`.
- **`classifier`** is a predicate `(exception, state) -> bool`.
The default (`default_classifier`, importable from
`openarmature.graph`) treats any exception with a `category`
attribute matching the project's `TRANSIENT_CATEGORIES` set as
transient. To retry on additional types, write a classifier that
delegates to `default_classifier` and falls back to your own check.
- **`backoff`** is a callable `(attempt_index) -> seconds`. The default
is exponential with jitter (base 1s, cap 30s, full jitter).
`deterministic_backoff(seconds)` is provided for tests.
- **`on_retry`** is an optional async callback `(exception, attempt)
-> None`. Fires before each sleep. Useful for emitting a structured
"about to retry" event.

A retry's attempt counter propagates as a context variable to every
node event emitted from within the retry, including nodes inside
subgraphs and branches that the retry wraps transitively. So an
observer logging a retried node sees `attempt=1`, `attempt=2`, etc. on
the inner events.

## Built-in: TimingMiddleware

```python
from openarmature.graph import TimingMiddleware, TimingRecord


async def record(rec: TimingRecord) -> None:
metrics.histogram("node_duration_ms", rec.duration_ms, tags={
"node": rec.node_name,
"outcome": rec.outcome,
})


builder.add_node(
"fetch",
fetch_fn,
middleware=[TimingMiddleware(node_name="fetch", on_complete=record)],
)
```

`TimingMiddleware` records the wrapped chain's duration with a
monotonic clock and delivers a `TimingRecord` to your async callback.
The record includes `node_name`, `duration_ms`, `outcome` (`"success"`
or `"exception"`), and `exception_category` (the failing exception's
`category` attribute when present).

Two implementation details worth knowing:

- The callback fires **inline** before the chain's result returns.
Slow callbacks add to the apparent node duration. Keep them fast
(queue work, defer I/O).
- The clock is injectable per instance via the `clock` kwarg.
Test fixtures use this to supply a deterministic stub without
globally patching `time.monotonic` (which would also distort
asyncio's scheduling).

## Related

- [Parallel branches](parallel-branches.md): per-branch middleware
and its interaction with parent-graph middleware.
- [Fan-out](fan-out.md): `instance_middleware` and how it composes
with parent and node-level layers.
- [LLMs](llms.md): how transient-classification flows from provider
errors into `RetryMiddleware`'s default classifier.
- [Observability](observability.md): observer events emitted around
middleware-wrapped nodes carry the retry attempt index.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ nav:
- State and reducers: concepts/state-and-reducers.md
- Graphs: concepts/graphs.md
- Composition: concepts/composition.md
- Middleware: concepts/middleware.md
- Fan-out: concepts/fan-out.md
- Parallel branches: concepts/parallel-branches.md
- LLMs: concepts/llms.md
Expand Down
2 changes: 1 addition & 1 deletion src/openarmature/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""OpenArmature workflow framework for LLM pipelines and tool-calling agents."""
"""OpenArmature: workflow framework for LLM pipelines and tool-calling agents."""

__version__ = "0.6.0"
__spec_version__ = "0.16.1"
2 changes: 1 addition & 1 deletion src/openarmature/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# internal + fan-out nodes per §10.3.
# - Resume via ``invoke(resume_invocation=...)`` restores per §10.4.

"""openarmature.checkpoint checkpointing capability.
"""openarmature.checkpoint: checkpointing capability.

Public surface: the typed :class:`Checkpointer` Protocol,
:class:`CheckpointRecord` / :class:`NodePosition` /
Expand Down
15 changes: 13 additions & 2 deletions src/openarmature/checkpoint/backends/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""In-memory Checkpointer.

Keeps records in a Python ``dict`` keyed by ``invocation_id``. NOT
durable across process crashes useful for tests, short-lived runs,
durable across process crashes; useful for tests, short-lived runs,
and development. Accepts any state shape (the dict holds the
:class:`CheckpointRecord` directly; nothing is serialized).
"""
Expand All @@ -26,7 +26,7 @@ class InMemoryCheckpointer:

**State shape:** any. The record is held by reference, so the
Pydantic state instance the engine produces is what comes back
from :meth:`load` no serialization round-trip. (This is the
from :meth:`load`; no serialization round-trip. (This is the
feature: tests can assert on the saved state's identity.)

**State-migration eligibility:** none. Per spec §10.12.1, a
Expand All @@ -52,14 +52,23 @@ def __init__(self) -> None:
self._lock = asyncio.Lock()

async def save(self, invocation_id: str, record: CheckpointRecord) -> None:
"""Store ``record`` under ``invocation_id``, replacing any
previous record for the same id. Not durable across process
restarts."""
async with self._lock:
self._records[invocation_id] = record

async def load(self, invocation_id: str) -> CheckpointRecord | None:
"""Return the saved record for ``invocation_id`` or ``None``
if nothing has been saved under that id."""
async with self._lock:
return self._records.get(invocation_id)

async def list(self, filter: CheckpointFilter | None = None) -> Iterable[CheckpointSummary]:
"""Enumerate stored invocations as :class:`CheckpointSummary`
rows. With ``filter.correlation_id`` set, restricts the
results to invocations carrying that correlation id;
otherwise returns all rows."""
async with self._lock:
records = list(self._records.values())
summaries = [
Expand All @@ -76,6 +85,8 @@ async def list(self, filter: CheckpointFilter | None = None) -> Iterable[Checkpo
return [s for s in summaries if s.correlation_id == filter.correlation_id]

async def delete(self, invocation_id: str) -> None:
"""Remove the record for ``invocation_id``. No-op when nothing
is saved under that id (no error)."""
async with self._lock:
self._records.pop(invocation_id, None)

Expand Down
27 changes: 22 additions & 5 deletions src/openarmature/checkpoint/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@

Persists records to a SQLite database with WAL mode enabled. Durable
across process crashes within a single host. One row per
``invocation_id`` (upsert retention overwritten on every save).
``invocation_id`` (upsert retention; overwritten on every save).

**Serialization knobs:**

- ``"pickle"`` (default) accepts any pickleable state shape.
- ``"pickle"`` (default): accepts any pickleable state shape.
Python-only on the read side; a TypeScript reimplementation cannot
decode pickle blobs.
- ``"json"`` accepts only JSON-native state shapes (Pydantic
- ``"json"``: accepts only JSON-native state shapes (Pydantic
``model_dump(mode="json")`` output). Cross-language portable; if
the user wants to read python-written records from a TypeScript
consumer (or vice versa), this is the choice.

Choose deliberately at construction time; the same database file
MUST be read with the same serialization mode it was written with
MUST be read with the same serialization mode it was written with;
mismatches surface as :class:`CheckpointRecordInvalid` on
:meth:`load`.

Expand Down Expand Up @@ -87,7 +87,7 @@ def _to_json_native(obj: Any) -> Any:
class SQLiteCheckpointer:
"""SQLite Checkpointer with WAL-mode durability.

**Retention:** upsert one row per ``invocation_id``, overwritten
**Retention:** upsert; one row per ``invocation_id``, overwritten
on every save. Saved records are NOT historical: only the most
recent save for any given ``invocation_id`` is retained.

Expand Down Expand Up @@ -167,6 +167,11 @@ def _decode(self, blob: bytes, recorded_mode: str, invocation_id: str) -> Any:
# ------------------------------------------------------------------

async def save(self, invocation_id: str, record: CheckpointRecord) -> None:
"""Upsert ``record`` under ``invocation_id``. The state,
completed positions, and parent-state stack are serialized via
the configured :class:`SerializationMode` and written in a
single statement. Writes are durable on return (WAL mode,
per-write fsync at the SQLite layer)."""
await self._ensure_initialized()
state_blob = self._encode(record.state)
positions_blob = self._encode([asdict(p) for p in record.completed_positions])
Expand Down Expand Up @@ -207,6 +212,11 @@ def _do() -> None:
await asyncio.to_thread(_do)

async def load(self, invocation_id: str) -> CheckpointRecord | None:
"""Return the saved record for ``invocation_id`` or ``None``
when no row exists. The serialization mode stored with the
row is used to decode the blobs back, so a database written
with one mode can still be loaded after the backend has been
reconfigured."""
await self._ensure_initialized()

def _do() -> tuple[Any, ...] | None:
Expand Down Expand Up @@ -266,6 +276,11 @@ def _do() -> tuple[Any, ...] | None:
)

async def list(self, filter: CheckpointFilter | None = None) -> Iterable[CheckpointSummary]:
"""Enumerate saved invocations as :class:`CheckpointSummary`
rows, ordered by ``last_saved_at`` ascending. With
``filter.correlation_id`` set the SQL query is constrained at
the database (indexed lookup); without a filter the full
table is returned."""
await self._ensure_initialized()

def _do() -> list[tuple[Any, ...]]:
Expand Down Expand Up @@ -308,6 +323,8 @@ def _do() -> list[tuple[Any, ...]]:
return summaries

async def delete(self, invocation_id: str) -> None:
"""Remove the row for ``invocation_id``. No-op when no row
exists (no error). The delete is durable on return."""
await self._ensure_initialized()

def _do() -> None:
Expand Down
Loading