Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion examples/00-hello-world/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
- Pydantic class (``Classification``, ``Summary``): typed
instance on ``Response.parsed``.
- JSON Schema dict (``research``): raw dict on ``Response.parsed``.
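- ``RuntimeConfig`` for per-call sampling knobs — every ``complete()``
here passes ``config=RuntimeConfig(temperature=0.0)`` so repeated runs
are as reproducible as the provider allows (temperature 0 minimizes
sampling variation; it is not a strict guarantee of identical output).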
Comment thread
chris-colinsky marked this conversation as resolved.
Outdated
- Conditional routing on a parsed field (``route`` reads
``state.classification.intent``).
- ``attach_observer`` for boundary visibility.
Expand Down Expand Up @@ -49,7 +52,7 @@
append,
merge,
)
from openarmature.llm import OpenAIProvider, UserMessage
from openarmature.llm import OpenAIProvider, RuntimeConfig, UserMessage


# Pydantic schemas the model is constrained to produce. Passing a
Expand Down Expand Up @@ -84,6 +87,15 @@ class PipelineState(State):
# builders, IDE inspection) import this module without running main().
_provider_instance: OpenAIProvider | None = None

# Per-call sampling knobs. The demo locks the model at temperature 0
# so the routing classification (and the rest of the run) reproduces
# across invocations — useful for tutorial output, less appropriate
# for production where some sampling variety is desirable.
# RuntimeConfig also surfaces max_tokens, top_p, and seed; only
# temperature is set here so the others fall through to provider
# defaults.
_DETERMINISTIC = RuntimeConfig(temperature=0.0)


def _get_provider() -> OpenAIProvider:
global _provider_instance
Expand Down Expand Up @@ -113,6 +125,7 @@ async def classify(state: PipelineState) -> Mapping[str, Any]:
)
],
response_schema=Classification,
config=_DETERMINISTIC,
)
return {"classification": response.parsed, "metadata": {"classified_by": "llm"}}

Expand Down Expand Up @@ -140,6 +153,7 @@ async def research(state: PipelineState) -> Mapping[str, Any]:
"required": ["topics", "follow_up_questions"],
"additionalProperties": False,
},
config=_DETERMINISTIC,
)
return {
"research_plan": response.parsed,
Expand All @@ -161,6 +175,7 @@ async def summarize(state: PipelineState) -> Mapping[str, Any]:
)
],
response_schema=Summary,
config=_DETERMINISTIC,
)
return {
"summary": response.parsed,
Expand Down
69 changes: 66 additions & 3 deletions examples/05-fan-out-with-retry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,24 @@
per-instance: a failure on headline 3 doesn't restart headlines 0-2.
- ``concurrency=3`` caps how many instances run in flight at once. Use
this to be polite to the upstream API.
- ``error_policy`` defaults to ``"fail_fast"`` — the first instance
failure (after retries exhaust) raises and cancels siblings. Set
the ``COLLECT_MODE`` env var to switch to ``"collect"``: each
instance runs independently and per-instance failures land in
``state.instance_errors`` instead of aborting the batch. The
``errors_field="instance_errors"`` knob names where the records go.
- A ``TimingRecord`` is captured per instance via an ``on_complete``
callback. ``TimingRecord`` carries the per-call duration but not the
``fan_out_index`` — that index lives on observer NodeEvents instead.
The demo prints captured durations in completion order plus a
wall-clock vs sum-of-durations comparison that shows concurrency
actually parallelized the work.
- A ``fan_out_config_observer`` reads ``NodeEvent.fan_out_config`` on
the fan-out node's dispatch event. Inner-instance events carry
``fan_out_index`` but not ``fan_out_config``; the config lives on
the fan-out node's own started / completed pair and gives observers
a record of the resolved item_count, concurrency, and error_policy
at dispatch time.

**Configuration** (env vars; OpenAI defaults shown):

Expand Down Expand Up @@ -61,6 +73,7 @@
END,
CompiledGraph,
GraphBuilder,
NodeEvent,
State,
append,
)
Expand Down Expand Up @@ -114,11 +127,14 @@ async def _chat(system: str, user: str) -> str:

class BatchState(State):
"""Outer graph: list of headlines goes in, parallel lists of summaries
and topic tags come out."""
and topic tags come out. ``instance_errors`` only populates under
Comment thread
chris-colinsky marked this conversation as resolved.
Outdated
``error_policy="collect"`` — each failed instance contributes one
record naming its ``fan_out_index`` and the exception category."""

headlines: list[str] = Field(default_factory=list)
summaries: Annotated[list[str], append] = Field(default_factory=list)
topics: Annotated[list[str], append] = Field(default_factory=list)
instance_errors: Annotated[list[dict[str, Any]], append] = Field(default_factory=list[dict[str, Any]])
trace: Annotated[list[str], append] = Field(default_factory=list)


Expand Down Expand Up @@ -216,7 +232,16 @@ async def present(s: BatchState) -> Mapping[str, Any]:
return {"trace": ["present"]}


def build_graph() -> CompiledGraph[BatchState]:
def build_graph(error_policy: str = "fail_fast") -> CompiledGraph[BatchState]:
"""Build the fan-out demo graph.

``error_policy`` switches between ``"fail_fast"`` (default; first
exhausted-retry failure raises and cancels the rest) and
``"collect"`` (each instance runs independently; failures land in
``state.instance_errors`` and the batch produces partial results).
The smoke test calls this with no argument, exercising the default
path; main() lets the COLLECT_MODE env var flip to collect.
"""
headline_subgraph = build_headline_subgraph()

retry = RetryMiddleware(
Expand Down Expand Up @@ -244,6 +269,8 @@ def build_graph() -> CompiledGraph[BatchState]:
extra_outputs={"topics": "topic"},
concurrency=3,
instance_middleware=(retry, timing),
error_policy=error_policy,
errors_field="instance_errors",
)
.add_node("present", present)
.add_edge("announce", "headline_runs")
Expand All @@ -254,6 +281,30 @@ def build_graph() -> CompiledGraph[BatchState]:
)


async def fan_out_config_observer(event: NodeEvent) -> None:
    """Report the fan-out node's resolved dispatch settings.

    The observer stays silent unless the event carries a
    ``fan_out_config`` AND its phase is ``"started"``. Per this
    example's notes, only the fan-out node's own started / completed
    pair carries ``fan_out_config`` (inner-instance events carry
    ``fan_out_index`` instead), so exactly one line is printed per
    dispatch: the node name plus the resolved ``item_count``,
    ``concurrency`` and ``error_policy``.
    """
    resolved = event.fan_out_config
    # Skip events without a config, and skip the "completed" half of
    # the pair so the settings are reported once, at dispatch time.
    if resolved is None or event.phase != "started":
        return

    header = f" [observer] fan-out node {event.node_name!r} dispatching: "
    details = (
        f"item_count={resolved.item_count} "
        f"concurrency={resolved.concurrency} "
        f"error_policy={resolved.error_policy!r}"
    )
    print(header + details)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
Expand All @@ -264,12 +315,19 @@ async def main() -> None:
# doesn't accumulate timings across invocations.
_timings.clear()

graph = build_graph()
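# Set COLLECT_MODE to any non-empty value (e.g. COLLECT_MODE=1) to
# switch the fan-out error policy from the default fail_fast to
# collect. Only the variable's truthiness is checked, so even
# COLLECT_MODE=0 enables collect; leave it unset for fail_fast.
# Under collect, each instance runs independently and per-instance
# failures (after retries exhaust) land in state.instance_errors
# instead of aborting the batch.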
error_policy = "collect" if os.environ.get("COLLECT_MODE") else "fail_fast"
graph = build_graph(error_policy=error_policy)
graph.attach_observer(fan_out_config_observer)

initial = BatchState(headlines=HEADLINES)

print("=" * 72)
print(f"Summarizing {len(HEADLINES)} headlines in parallel (concurrency=3)")
print(f"error_policy={error_policy!r}")
print("=" * 72)
print()

Expand All @@ -284,6 +342,11 @@ async def main() -> None:
print(f" summary: {s}")
print(f" topic: {t}")
print()
if final.instance_errors:
print(f"Captured {len(final.instance_errors)} per-instance error(s):")
for err in final.instance_errors:
print(f" {err}")
print()
print("Per-instance timings (in completion order):")
for nth, record in enumerate(_timings):
print(f" #{nth} {record.duration_ms:7.1f} ms outcome={record.outcome}")
Expand Down
23 changes: 23 additions & 0 deletions examples/06-parallel-branches/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@
mapping (not in completion order). The three branches here write
disjoint parent fields, so the order doesn't affect the result —
but the property holds and would matter if they overlapped.
- A ``branch_attribution_observer`` reads ``NodeEvent.branch_name``
on inner-node events. ``branch_name`` is populated only for
events INSIDE a branch's subgraph; outermost nodes (receive,
enrich, present) have ``branch_name=None``. This is the
per-event attribution that lets observability backends route
metrics / spans by branch.

**Configuration** (env vars; OpenAI defaults shown):

Expand Down Expand Up @@ -64,6 +70,7 @@
BranchSpec,
CompiledGraph,
GraphBuilder,
NodeEvent,
State,
append,
)
Expand Down Expand Up @@ -233,6 +240,21 @@ async def present(s: ArticleState) -> Mapping[str, Any]:
return {"trace": ["present"]}


async def branch_attribution_observer(event: NodeEvent) -> None:
    """Announce the owning branch for each inner node as it starts.

    Only events that have a ``branch_name`` and are in the
    ``"started"`` phase produce output; everything else is ignored.
    Per this example's notes, ``branch_name`` is set for nodes running
    inside a parallel branch's subgraph and is ``None`` for the
    outermost nodes (receive, enrich, present). Each reported event
    prints one ``(branch=…)`` line naming the branch and the node.
    """
    branch = event.branch_name
    # One line per node start; events with no branch attribution and
    # non-"started" phases are deliberately not reported.
    if branch is not None and event.phase == "started":
        print(f" [observer] (branch={branch}) inner node {event.node_name!r} started")


def build_graph() -> CompiledGraph[ArticleState]:
summary = build_summary_subgraph()
sentiment = build_sentiment_subgraph()
Expand Down Expand Up @@ -287,6 +309,7 @@ def build_graph() -> CompiledGraph[ArticleState]:

async def main() -> None:
graph = build_graph()
graph.attach_observer(branch_attribution_observer)

print("=" * 72)
print("Lunar-mission article enrichment — three independent analyses in parallel")
Expand Down
Loading