Merged
3 changes: 3 additions & 0 deletions docs/concepts/index.md
@@ -10,6 +10,9 @@ the framework, or jump to whichever concept you need.
- [Composition: conditional edges, subgraphs, projection](composition.md):
routing decisions, encapsulated sub-pipelines, the parent ↔ subgraph
data seam.
+- [Middleware](middleware.md): wrap node dispatch with retries,
+  timing, logging, error transformation; per-node, per-graph,
+  per-branch, and per-fan-out-instance registration.
- [Fan-out](fan-out.md): running the same subgraph many times in
parallel, results merged back deterministically.
- [Parallel branches](parallel-branches.md): dispatching M
211 changes: 211 additions & 0 deletions docs/concepts/middleware.md
@@ -0,0 +1,211 @@
# Middleware

Middleware wraps the dispatch of a single node. The shape is an async
callable `(state, next) -> partial_update`. Anything you want to happen
around a node, without changing the node itself, lives here: retries,
timing, structured logging, request enrichment, error transformation,
circuit-breaking.

```python
from collections.abc import Mapping
from typing import Any

from openarmature.graph import Middleware, NextCall


class LogAround:
    async def __call__(self, state: Any, next_: NextCall) -> Mapping[str, Any]:
        print("before")
        partial = await next_(state)
        print("after")
        return partial


_: Middleware = LogAround() # structural conformance check
```

`next` (spelled `next_` in the example above, so it does not shadow
Python's built-in `next`) invokes the next layer of the chain (or the
wrapped node, at the innermost end) and returns the partial update
from that layer. Code before `await next(state)` is the pre-node phase
(runs on the way in); code after is the post-node phase (runs on the
way out).
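
To make the two phases concrete, here is a minimal, framework-free sketch of how a chain could be folded around a node. The `compose` helper, the type aliases, and the plain-`dict` state are assumptions made for illustration only; they are not the `openarmature.graph` implementation.

```python
import asyncio
from collections.abc import Awaitable, Callable, Mapping
from typing import Any

# Hypothetical aliases for this sketch; the real types live in openarmature.graph.
State = Mapping[str, Any]
Partial = Mapping[str, Any]
Next = Callable[[State], Awaitable[Partial]]
MiddlewareFn = Callable[[State, Next], Awaitable[Partial]]


def compose(middlewares: list[MiddlewareFn], node: Next) -> Next:
    """Fold middlewares around node; the first entry becomes the outermost layer."""
    chain = node
    for mw in reversed(middlewares):
        # Bind mw and the current chain as defaults so each layer keeps its own pair.
        def layer(state: State, mw: MiddlewareFn = mw, inner: Next = chain) -> Awaitable[Partial]:
            return mw(state, inner)

        chain = layer
    return chain


trace: list[str] = []


async def log_around(state: State, next_: Next) -> Partial:
    trace.append("before")        # pre-node phase, on the way in
    partial = await next_(state)  # next layer, or the node itself
    trace.append("after")         # post-node phase, on the way out
    return partial


async def node(state: State) -> Partial:
    trace.append("node")
    return {"count": state["count"] + 1}


result = asyncio.run(compose([log_around], node)({"count": 1}))
```

After the run, `trace` holds the pre-node entry, the node entry, and the post-node entry in that order, and `result` is the node's partial update passed back unchanged.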

## Four registration sites

You can attach middleware at four places. The same `Middleware` shape
works in all of them.

**Per-node**, on a single function node:

```python
builder.add_node("fetch", fetch_fn, middleware=[RetryMiddleware()])
```

**Per-graph**, applied to every node in the graph:

```python
builder.add_middleware(TimingMiddleware(node_name="...", on_complete=record))
```

**Per-branch**, on a single branch of a parallel-branches node:

```python
from openarmature.graph import BranchSpec

branches = {
    "sentiment": BranchSpec(
        subgraph=sentiment_subgraph,
        middleware=(RetryMiddleware(),),
    ),
    "topic": BranchSpec(subgraph=topic_subgraph),
}
builder.add_parallel_branches_node("classify", branches=branches)
```

The branch middleware wraps the whole branch dispatch as one call. A
retry on a branch retries the entire branch from scratch, not an
individual node inside it.

**Per-fan-out-instance**, on the instance dispatch inside a fan-out
node:

```python
builder.add_fan_out_node(
    "summarize",
    subgraph=summarize_subgraph,
    items_field="articles",
    item_field="article",
    collect_field="article",
    target_field="summaries",
    instance_middleware=[RetryMiddleware()],
)
```

A retry here retries one instance, not the whole fan-out.

## Composition order

When a node has middleware from multiple sites, per-graph composes
*outside* per-node. The runtime chain at a single function node is:

```
[per_graph_outer_to_inner...] → [per_node_outer_to_inner...] → node
```

The first middleware in `builder.add_middleware()` calls is the
outermost layer; the last is closest to the node. Same rule for
per-node: list order is outer-to-inner.
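
The ordering rule can be checked with a small stand-alone sketch. `build_chain`, `wrap`, and `tagged` are hypothetical helpers written for this illustration, assuming the runtime simply concatenates the per-graph list and the per-node list before wrapping the node; the real engine may be structured differently.

```python
import asyncio

trace = []


def tagged(name):
    """Build a middleware that records when it is entered and exited."""
    async def mw(state, next_):
        trace.append(f"{name}:in")
        partial = await next_(state)
        trace.append(f"{name}:out")
        return partial
    return mw


async def node(state):
    trace.append("node")
    return {}


def wrap(mw, inner):
    """Return a callable that runs mw with inner as its next layer."""
    return lambda state: mw(state, inner)


def build_chain(per_graph, per_node, node):
    # Per-graph layers sit outside per-node layers; list order is outer-to-inner.
    chain = node
    for mw in reversed([*per_graph, *per_node]):
        chain = wrap(mw, chain)
    return chain


chain = build_chain(
    per_graph=[tagged("graph_a"), tagged("graph_b")],
    per_node=[tagged("node_a")],
    node=node,
)
asyncio.run(chain({}))
```

The recorded trace enters `graph_a`, then `graph_b`, then `node_a`, runs the node, and unwinds in the reverse order.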

## The subgraph boundary

Middleware does not cross into a subgraph. The parent's middleware
wraps the `SubgraphNode` dispatch as a single atomic call, and the
subgraph's own middleware (configured on the child builder) wraps the
child's internal nodes independently.

In practical terms: a `RetryMiddleware` on a subgraph-as-node retries
the whole child graph from its entry. A `RetryMiddleware` inside the
child retries one of its individual nodes.

## Error semantics

An exception raised by `next(state)` propagates up through `await
next(state)`. Middleware may:

- **Re-raise**: the simplest case. Don't catch, let it bubble.
- **Catch and recover**: catch the exception and return a partial
update of your own. The rest of the chain continues as if the node
had returned that partial update normally.
- **Catch and transform**: catch one exception type, raise a different
one. The new exception propagates up.
- **Call `next` more than once**: this is what retry middleware does.

A middleware MUST NOT mutate the input `state` object in place. To
hand a transformed state down the chain, pass a new state instance to
`next(...)`.
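
A rough sketch of three of these behaviors, written as plain async functions. The exception classes, the `request_id` field, and the use of a `dict` as state are hypothetical; with a typed state class you would build a new instance rather than a new `dict`.

```python
import asyncio


class UpstreamTimeout(Exception):
    """Hypothetical failure raised by the wrapped node."""


class FetchFailed(Exception):
    """Hypothetical domain error exposed to callers."""


async def recover_with_default(state, next_):
    # Catch and recover: return a partial update of our own.
    try:
        return await next_(state)
    except UpstreamTimeout:
        return {"result": None, "degraded": True}


async def transform_error(state, next_):
    # Catch and transform: raise a different exception type.
    try:
        return await next_(state)
    except UpstreamTimeout as exc:
        raise FetchFailed("fetch node timed out") from exc


async def enrich(state, next_):
    # Pass a new state object down the chain; the input is left untouched.
    return await next_({**state, "request_id": "req-1"})


async def flaky_node(state):
    raise UpstreamTimeout()


async def echo_node(state):
    return {"seen_request_id": state.get("request_id")}


original = {"query": "q"}
recovered = asyncio.run(recover_with_default(original, flaky_node))
enriched = asyncio.run(enrich(original, echo_node))
```

`transform_error` is exercised the same way; the caller sees `FetchFailed` with the original `UpstreamTimeout` attached as its `__cause__`.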

## Built-in: RetryMiddleware

```python
import logging

from openarmature.graph import RetryMiddleware, exponential_jitter_backoff

log = logging.getLogger(__name__)


async def on_retry(exc: Exception, attempt: int) -> None:
    log.warning("retrying after %r (attempt %d)", exc, attempt)


retry = RetryMiddleware(
    max_attempts=3,
    backoff=exponential_jitter_backoff,
    on_retry=on_retry,
)
```

Four plug points, all optional:

- **`max_attempts`** is the total attempt count including the first
call. `1` disables retry. Default `3`.
- **`classifier`** is a predicate `(exception, state) -> bool`.
The default (`default_classifier`, importable from
`openarmature.graph`) treats any exception with a `category`
attribute matching the project's `TRANSIENT_CATEGORIES` set as
transient. To retry on additional types, write a classifier that
delegates to `default_classifier` and falls back to your own check.
- **`backoff`** is a callable `(attempt_index) -> seconds`. The default
is exponential with jitter (base 1s, cap 30s, full jitter).
`deterministic_backoff(seconds)` is provided for tests.
- **`on_retry`** is an optional async callback `(exception, attempt)
-> None`. Fires before each sleep. Useful for emitting a structured
"about to retry" event.
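
As an illustration of the `classifier` and `backoff` plug points, the sketch below shows a classifier that delegates to a default and then falls back to its own check, plus a full-jitter backoff with the stated base and cap. Everything here is a stand-in: the category names, `stand_in_default_classifier`, and the assumption that `attempt_index` starts at zero are not taken from the library.

```python
import random

# Hypothetical stand-ins: the real set and default classifier come from
# openarmature.graph; these only mimic the behavior described above.
TRANSIENT_CATEGORIES = {"rate_limit", "timeout"}


def stand_in_default_classifier(exc: Exception, state: object) -> bool:
    return getattr(exc, "category", None) in TRANSIENT_CATEGORIES


def classifier(exc: Exception, state: object) -> bool:
    """Delegate to the default first, then fall back to a local check."""
    if stand_in_default_classifier(exc, state):
        return True
    return isinstance(exc, ConnectionResetError)


def full_jitter_backoff(attempt_index: int, base: float = 1.0, cap: float = 30.0) -> float:
    """Pick a uniform random delay in [0, min(cap, base * 2**attempt_index)] seconds."""
    ceiling = min(cap, base * 2 ** attempt_index)
    return random.uniform(0.0, ceiling)
```

In real code the classifier would call the imported `default_classifier` instead of the stand-in; both callables keep the `(exception, state) -> bool` and `(attempt_index) -> seconds` shapes described in the list above.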

A retry's attempt counter propagates as a context variable to every
node event emitted from within the retry, including nodes inside
subgraphs and branches that the retry wraps transitively. So an
observer logging a retried node sees `attempt=1`, `attempt=2`, etc. on
the inner events.
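
The propagation mechanism can be pictured with Python's standard `contextvars` module. This is a sketch of the general technique only: `attempt_var`, the `Transient` exception, and the retry loop are hypothetical, and the library's own variable and event plumbing are not shown in this document.

```python
import asyncio
import contextvars

# Hypothetical context variable; the library's own name is not shown here.
attempt_var = contextvars.ContextVar("retry_attempt", default=0)

seen = []  # attempt numbers observed from inside the wrapped node


class Transient(Exception):
    """Hypothetical retryable failure."""


async def retry(state, next_, max_attempts=3):
    for attempt in range(1, max_attempts + 1):
        token = attempt_var.set(attempt)  # visible to everything awaited below
        try:
            return await next_(state)
        except Transient:
            if attempt == max_attempts:
                raise
        finally:
            attempt_var.reset(token)


async def flaky_node(state):
    seen.append(attempt_var.get())  # an observer would read the same value
    if len(seen) < 3:
        raise Transient()
    return {"ok": True}


result = asyncio.run(retry({}, flaky_node))
```

Because the value is set before `next_` is awaited, any code running inside that call, however deeply nested, reads the current attempt without it being passed as an argument.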

## Built-in: TimingMiddleware

```python
from openarmature.graph import TimingMiddleware, TimingRecord


async def record(rec: TimingRecord) -> None:
    # `metrics` stands for whatever metrics client your application uses.
    metrics.histogram("node_duration_ms", rec.duration_ms, tags={
        "node": rec.node_name,
        "outcome": rec.outcome,
    })


builder.add_node(
    "fetch",
    fetch_fn,
    middleware=[TimingMiddleware(node_name="fetch", on_complete=record)],
)
```

`TimingMiddleware` records the wrapped chain's duration with a
monotonic clock and delivers a `TimingRecord` to your async callback.
The record includes `node_name`, `duration_ms`, `outcome` (`"success"`
or `"exception"`), and `exception_category` (the failing exception's
`category` attribute when present).

Two implementation details worth knowing:

- The callback fires **inline** before the chain's result returns.
Slow callbacks add to the apparent node duration. Keep them fast
(queue work, defer I/O).
- The clock is injectable per instance via the `clock` kwarg.
Test fixtures use this to supply a deterministic stub without
globally patching `time.monotonic` (which would also distort
asyncio's scheduling).
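
Both details can be sketched with a small stand-in class. `SketchTiming` and `SketchRecord` are hypothetical and only mirror the behavior described above; the sketch also assumes the clock returns seconds, as `time.monotonic` does.

```python
import asyncio
import time
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchRecord:  # hypothetical stand-in for TimingRecord
    node_name: str
    duration_ms: float
    outcome: str


class SketchTiming:
    def __init__(self, node_name, on_complete, clock=time.monotonic):
        self._node_name = node_name
        self._on_complete = on_complete
        self._clock = clock  # injectable, so tests need not patch time.monotonic

    async def __call__(self, state, next_):
        start = self._clock()
        outcome = "success"
        try:
            return await next_(state)
        except Exception:
            outcome = "exception"
            raise
        finally:
            duration_ms = (self._clock() - start) * 1000.0
            # Inline: the caller does not get the result until this completes.
            await self._on_complete(SketchRecord(self._node_name, duration_ms, outcome))


records = []


async def record(rec):
    records.append(rec)


async def node(state):
    return {"ok": True}


ticks = iter([10.0, 10.25])  # deterministic stub clock: start, then end
timing = SketchTiming("fetch", record, clock=lambda: next(ticks))
result = asyncio.run(timing({}, node))
```

With the stub clock the recorded duration is fixed by the two tick values, so a test can assert on it exactly.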

## Related

- [Parallel branches](parallel-branches.md): per-branch middleware
and its interaction with parent-graph middleware.
- [Fan-out](fan-out.md): `instance_middleware` and how it composes
with parent and node-level layers.
- [LLMs](llms.md): how transient-classification flows from provider
errors into `RetryMiddleware`'s default classifier.
- [Observability](observability.md): observer events emitted around
middleware-wrapped nodes carry the retry attempt index.
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -106,6 +106,7 @@ nav:
- State and reducers: concepts/state-and-reducers.md
- Graphs: concepts/graphs.md
- Composition: concepts/composition.md
+- Middleware: concepts/middleware.md
- Fan-out: concepts/fan-out.md
- Parallel branches: concepts/parallel-branches.md
- LLMs: concepts/llms.md
2 changes: 1 addition & 1 deletion src/openarmature/__init__.py
@@ -1,4 +1,4 @@
-"""OpenArmature workflow framework for LLM pipelines and tool-calling agents."""
+"""OpenArmature: workflow framework for LLM pipelines and tool-calling agents."""

__version__ = "0.6.0"
__spec_version__ = "0.16.1"
2 changes: 1 addition & 1 deletion src/openarmature/checkpoint/__init__.py
@@ -6,7 +6,7 @@
# internal + fan-out nodes per §10.3.
# - Resume via ``invoke(resume_invocation=...)`` restores per §10.4.

-"""openarmature.checkpoint checkpointing capability.
+"""openarmature.checkpoint: checkpointing capability.

Public surface: the typed :class:`Checkpointer` Protocol,
:class:`CheckpointRecord` / :class:`NodePosition` /
15 changes: 13 additions & 2 deletions src/openarmature/checkpoint/backends/memory.py
@@ -3,7 +3,7 @@
"""In-memory Checkpointer.

Keeps records in a Python ``dict`` keyed by ``invocation_id``. NOT
-durable across process crashes useful for tests, short-lived runs,
+durable across process crashes; useful for tests, short-lived runs,
and development. Accepts any state shape (the dict holds the
:class:`CheckpointRecord` directly; nothing is serialized).
"""
@@ -26,7 +26,7 @@ class InMemoryCheckpointer:

     **State shape:** any. The record is held by reference, so the
     Pydantic state instance the engine produces is what comes back
-    from :meth:`load` no serialization round-trip. (This is the
+    from :meth:`load`; no serialization round-trip. (This is the
     feature: tests can assert on the saved state's identity.)

     **State-migration eligibility:** none. Per spec §10.12.1, a
@@ -52,14 +52,23 @@ def __init__(self) -> None:
         self._lock = asyncio.Lock()

     async def save(self, invocation_id: str, record: CheckpointRecord) -> None:
+        """Store ``record`` under ``invocation_id``, replacing any
+        previous record for the same id. Not durable across process
+        restarts."""
         async with self._lock:
             self._records[invocation_id] = record

     async def load(self, invocation_id: str) -> CheckpointRecord | None:
+        """Return the saved record for ``invocation_id`` or ``None``
+        if nothing has been saved under that id."""
         async with self._lock:
             return self._records.get(invocation_id)

     async def list(self, filter: CheckpointFilter | None = None) -> Iterable[CheckpointSummary]:
+        """Enumerate stored invocations as :class:`CheckpointSummary`
+        rows. With ``filter.correlation_id`` set, restricts the
+        results to invocations carrying that correlation id;
+        otherwise returns all rows."""
         async with self._lock:
             records = list(self._records.values())
         summaries = [
@@ -76,6 +85,8 @@ async def list(self, filter: CheckpointFilter | None = None) -> Iterable[Checkpo
return [s for s in summaries if s.correlation_id == filter.correlation_id]

     async def delete(self, invocation_id: str) -> None:
+        """Remove the record for ``invocation_id``. No-op when nothing
+        is saved under that id (no error)."""
         async with self._lock:
             self._records.pop(invocation_id, None)

27 changes: 22 additions & 5 deletions src/openarmature/checkpoint/backends/sqlite.py
@@ -4,20 +4,20 @@

Persists records to a SQLite database with WAL mode enabled. Durable
across process crashes within a single host. One row per
-``invocation_id`` (upsert retention overwritten on every save).
+``invocation_id`` (upsert retention; overwritten on every save).

**Serialization knobs:**

-- ``"pickle"`` (default) accepts any pickleable state shape.
+- ``"pickle"`` (default): accepts any pickleable state shape.
Python-only on the read side; a TypeScript reimplementation cannot
decode pickle blobs.
-- ``"json"`` accepts only JSON-native state shapes (Pydantic
+- ``"json"``: accepts only JSON-native state shapes (Pydantic
``model_dump(mode="json")`` output). Cross-language portable; if
the user wants to read python-written records from a TypeScript
consumer (or vice versa), this is the choice.

Choose deliberately at construction time; the same database file
-MUST be read with the same serialization mode it was written with
+MUST be read with the same serialization mode it was written with;
mismatches surface as :class:`CheckpointRecordInvalid` on
:meth:`load`.

@@ -87,7 +87,7 @@ def _to_json_native(obj: Any) -> Any:
 class SQLiteCheckpointer:
     """SQLite Checkpointer with WAL-mode durability.

-    **Retention:** upsert one row per ``invocation_id``, overwritten
+    **Retention:** upsert; one row per ``invocation_id``, overwritten
     on every save. Saved records are NOT historical: only the most
     recent save for any given ``invocation_id`` is retained.

@@ -167,6 +167,11 @@ def _decode(self, blob: bytes, recorded_mode: str, invocation_id: str) -> Any:
# ------------------------------------------------------------------

     async def save(self, invocation_id: str, record: CheckpointRecord) -> None:
+        """Upsert ``record`` under ``invocation_id``. The state,
+        completed positions, and parent-state stack are serialized via
+        the configured :class:`SerializationMode` and written in a
+        single statement. Writes are durable on return (WAL mode,
+        per-write fsync at the SQLite layer)."""
         await self._ensure_initialized()
         state_blob = self._encode(record.state)
         positions_blob = self._encode([asdict(p) for p in record.completed_positions])
@@ -207,6 +212,11 @@ def _do() -> None:
         await asyncio.to_thread(_do)

     async def load(self, invocation_id: str) -> CheckpointRecord | None:
+        """Return the saved record for ``invocation_id`` or ``None``
+        when no row exists. The serialization mode stored with the
+        row is used to decode the blobs back, so a database written
+        with one mode can still be loaded after the backend has been
+        reconfigured."""
         await self._ensure_initialized()

         def _do() -> tuple[Any, ...] | None:
@@ -266,6 +276,11 @@ def _do() -> tuple[Any, ...] | None:
)

     async def list(self, filter: CheckpointFilter | None = None) -> Iterable[CheckpointSummary]:
+        """Enumerate saved invocations as :class:`CheckpointSummary`
+        rows, ordered by ``last_saved_at`` ascending. With
+        ``filter.correlation_id`` set the SQL query is constrained at
+        the database (indexed lookup); without a filter the full
+        table is returned."""
         await self._ensure_initialized()

         def _do() -> list[tuple[Any, ...]]:
@@ -308,6 +323,8 @@ def _do() -> list[tuple[Any, ...]]:
         return summaries

     async def delete(self, invocation_id: str) -> None:
+        """Remove the row for ``invocation_id``. No-op when no row
+        exists (no error). The delete is durable on return."""
         await self._ensure_initialized()

         def _do() -> None: