Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions workflow/dynamic_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"iter"
"sync"

"github.com/google/jsonschema-go/jsonschema"

Expand Down Expand Up @@ -104,7 +105,11 @@ func (n *dynamicNode[IN, OUT]) Run(ctx agent.Context, input any) iter.Seq2[*sess
return
}

emit := makeEmit(yield, ctx)
// One mutex serializes every yield for this activation: the emit
// passed to the DynamicFn and the same emit driven by RunNode via
// the sub-scheduler. Concurrent children must not yield at once.
var emitMu sync.Mutex
emit := makeEmit(yield, ctx, &emitMu)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would it make sense to create mutex within makeEmit function? this will make a simpler func prototype.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an excelant idea!

sub := newDynamicSubScheduler(ctx, n.composePath(ctx), emit)
orchestratorCtx := newDynamicNodeContext(ctx, sub.ParentPath(), "", sub, sub.OutputForAncestors())

Expand Down Expand Up @@ -186,8 +191,15 @@ func (n *dynamicNode[IN, OUT]) composePath(parent NodeContext) string {
// When yield returns false without ctx cancellation (no current
// consumer triggers this, but the contract must not depend on it),
// return context.Canceled as a stand-in.
func makeEmit(yield func(*session.Event, error) bool, parentCtx NodeContext) func(*session.Event) error {
//
// mu serializes yield: a DynamicFn may run concurrent children (see
// WithUseSubBranch) that all emit through this one callback, and calling
// the same yield from multiple goroutines panics the iterator and races
// the parent runNode's completion accumulator.
func makeEmit(yield func(*session.Event, error) bool, parentCtx NodeContext, mu *sync.Mutex) func(*session.Event) error {
return func(ev *session.Event) error {
mu.Lock()
defer mu.Unlock()
if err := parentCtx.Err(); err != nil {
return err
}
Expand Down
74 changes: 74 additions & 0 deletions workflow/run_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"errors"
"iter"
"reflect"
"strconv"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -633,3 +634,76 @@ func (n *countingStubNode) runCount() int {
defer n.mu.Unlock()
return n.calls
}

// TestRunNode_ConcurrentChildren_NoRace runs several children from
// separate goroutines (the WithUseSubBranch pattern). They all emit
// through one shared yield, so without serialization this races the
// range-over-func loop state and panics the iterator. A start gate
// releases the goroutines together to maximize the overlap. Each
// goroutine recovers its panic so the failure is reported even without
// -race; -race is the reliable signal.
func TestRunNode_ConcurrentChildren_NoRace(t *testing.T) {
const n = 8

var (
mu sync.Mutex
panics []any
)
errs := make([]error, n)

orch := NewDynamicNode[string, string](
"orch",
func(ctx NodeContext, _ string, _ func(*session.Event) error) (string, error) {
start := make(chan struct{})
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
// Distinct child + run-id per goroutine so the only shared
// mutable state is the parent yield path; a shared run-id
// would instead exercise the idempotency-cache race.
child := newStubNode("child", "out")
go func(i int) {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
mu.Lock()
panics = append(panics, r)
mu.Unlock()
}
}()
<-start
_, errs[i] = RunNode[string](ctx, child, nil,
WithUseSubBranch(), WithRunID("c"+strconv.Itoa(i)))
}(i)
}
close(start)
wg.Wait()
return "done", errors.Join(errs...)
},
NodeConfig{},
)

events, err := drainDynamicWithErr(t, orch, "")

mu.Lock()
gotPanics := append([]any(nil), panics...)
mu.Unlock()
if len(gotPanics) > 0 {
t.Fatalf("%d recovered panic(s) from concurrent children sharing one yield; first: %v",
len(gotPanics), gotPanics[0])
}
if err != nil {
t.Fatalf("orchestrator error: %v", err)
}

// n child outputs forwarded up + the parent's own terminal output.
outputs := 0
for _, ev := range events {
if ev.Output != nil {
outputs++
}
}
if want := n + 1; outputs != want {
t.Errorf("output-bearing events = %d, want %d", outputs, want)
}
}
Loading