diff --git a/internal/services/integrationtesting/query_plan_consistency_test.go b/internal/services/integrationtesting/query_plan_consistency_test.go index dbc97c6a9..a21e6d370 100644 --- a/internal/services/integrationtesting/query_plan_consistency_test.go +++ b/internal/services/integrationtesting/query_plan_consistency_test.go @@ -3,9 +3,7 @@ package integrationtesting_test import ( - "fmt" "path/filepath" - "strings" "testing" "github.com/stretchr/testify/require" @@ -102,55 +100,70 @@ func runQueryPlanAssertions(t *testing.T, handle *queryPlanConsistencyHandle) { v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, }, } { - entry := entry t.Run(entry.name, func(t *testing.T) { for _, assertion := range entry.assertions { - assertion := assertion t.Run(assertion.RelationshipWithContextString, func(t *testing.T) { - require := require.New(t) - - rel := assertion.Relationship - it, err := query.BuildIteratorFromSchema(handle.schema, rel.Resource.ObjectType, rel.Resource.Relation) - require.NoError(err) - - qctx := handle.buildContext(t) + // Run both unoptimized and optimized versions + for _, optimizationMode := range []struct { + name string + optimize bool + }{ + {"unoptimized", false}, + {"optimized", true}, + } { + t.Run(optimizationMode.name, func(t *testing.T) { + require := require.New(t) + + rel := assertion.Relationship + it, err := query.BuildIteratorFromSchema(handle.schema, rel.Resource.ObjectType, rel.Resource.Relation) + require.NoError(err) + + // Apply static optimizations if requested + if optimizationMode.optimize { + it, _, err = query.ApplyOptimizations(it, query.StaticOptimizations) + require.NoError(err) + } - // Add caveat context from assertion if available - if len(assertion.CaveatContext) > 0 { - qctx.CaveatContext = assertion.CaveatContext - } + qctx := handle.buildContext(t) - seq, err := qctx.Check(it, []query.Object{query.GetObject(rel.Resource)}, rel.Subject) - require.NoError(err) + // Add caveat context from assertion if available + if len(assertion.CaveatContext) > 0 { + qctx.CaveatContext = assertion.CaveatContext + } - rels, err := query.CollectAll(seq) - require.NoError(err) + seq, err := qctx.Check(it, []query.Object{query.GetObject(rel.Resource)}, rel.Subject) + require.NoError(err) + + rels, err := query.CollectAll(seq) + require.NoError(err) + + // Print trace if test fails + if qctx.TraceLogger != nil { + defer func() { + if t.Failed() { + t.Logf("Trace for %s:\n%s", entry.name, qctx.TraceLogger.DumpTrace()) + // Also print the tree structure for debugging + if it != nil { + t.Logf("Tree structure:\n%s", it.Explain().String()) + } + } + }() + } - // Print trace if test fails - if qctx.TraceLogger != nil { - defer func() { - if t.Failed() { - t.Logf("Trace for %s:\n%s", entry.name, qctx.TraceLogger.DumpTrace()) - // Also print the tree structure for debugging - if it != nil { - t.Logf("Tree structure:\n%s", explainTree(it, 0)) + switch entry.expectedPermissionship { + case v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION: + require.Len(rels, 1) + require.NotNil(rels[0].Caveat) + case v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION: + require.Len(rels, 1) + require.Nil(rels[0].Caveat) + case v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION: + if len(rels) != 0 && qctx.TraceLogger != nil { + t.Logf("Expected 0 relations but got %d. Trace:\n%s", len(rels), qctx.TraceLogger.DumpTrace()) } + require.Len(rels, 0) } - }() - } - - switch entry.expectedPermissionship { - case v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION: - require.Len(rels, 1) - require.NotNil(rels[0].Caveat) - case v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION: - require.Len(rels, 1) - require.Nil(rels[0].Caveat) - case v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION: - if len(rels) != 0 && qctx.TraceLogger != nil { - t.Logf("Expected 0 relations but got %d. Trace:\n%s", len(rels), qctx.TraceLogger.DumpTrace()) - } - require.Len(rels, 0) + }) } }) } @@ -159,21 +172,3 @@ func runQueryPlanAssertions(t *testing.T, handle *queryPlanConsistencyHandle) { } }) } - -// explainTree recursively explains the tree structure for debugging -func explainTree(iter query.Iterator, depth int) string { - indent := strings.Repeat(" ", depth) - explain := iter.Explain() - result := fmt.Sprintf("%s%s: %s\n", indent, explain.Name, explain.Info) - - for _, subExplain := range explain.SubExplain { - // For SubExplain, we need to create a dummy iterator to get the tree structure - // This is a simplified approach - in practice we'd need access to the actual sub-iterators - subResult := fmt.Sprintf("%s %s: %s\n", indent, subExplain.Name, subExplain.Info) - result += subResult - // Note: We can't recursively call explainTree on SubExplain because it's not an Iterator - // This gives us one level of detail which should be sufficient for debugging - } - - return result -} diff --git a/pkg/query/alias.go b/pkg/query/alias.go index f6d50a194..0dff12e40 100644 --- a/pkg/query/alias.go +++ b/pkg/query/alias.go @@ -39,7 +39,8 @@ func (a *Alias) CheckImpl(ctx *Context, resources []Object, subject ObjectAndRel return nil, err } - return func(yield func(Path, error) bool) { + // Create combined sequence with self-edge and rewritten paths + combined := func(yield func(Path, error) bool) { // Yield the self-edge first if !yield(selfPath, nil) { return @@ -57,7 +58,10 @@ func (a *Alias) CheckImpl(ctx *Context, resources []Object, subject ObjectAndRel return } } - }, nil + } + + // Wrap with deduplication to handle duplicate paths after rewriting + return DeduplicatePathSeq(combined), nil } } @@ -67,7 +71,7 @@ func (a *Alias) CheckImpl(ctx *Context, resources []Object, subject ObjectAndRel return nil, err } - return func(yield func(Path, error) bool) { + rewritten := func(yield func(Path, error) bool) { for path, err := range subSeq { if err != nil { yield(Path{}, err) @@ -79,7 +83,10 @@ func (a *Alias) CheckImpl(ctx *Context, resources []Object, subject ObjectAndRel return } } - }, nil + } + + // Wrap with deduplication to handle duplicate paths after rewriting + return DeduplicatePathSeq(rewritten), nil } func (a *Alias) IterSubjectsImpl(ctx *Context, resource Object) (PathSeq, error) { diff --git a/pkg/query/alias_test.go b/pkg/query/alias_test.go index 46298136b..a85bc0d0c 100644 --- a/pkg/query/alias_test.go +++ b/pkg/query/alias_test.go @@ -42,9 +42,9 @@ func TestAliasIterator(t *testing.T) { require.Equal("...", rel.Subject.Relation) } - // Should have same number of relations as original sub-iterator - // (alice has viewer, editor, owner on doc1) - require.Len(rels, 3, "should have 3 rewritten relations") + // Should have 1 deduplicated relation after rewriting + // (alice has viewer, editor, owner on doc1, but all get rewritten to "read" and deduplicated) + require.Len(rels, 1, "should have 1 deduplicated relation after rewriting") }) t.Run("Check_SelfEdgeDetection", func(t *testing.T) { diff --git a/pkg/query/build_tree.go b/pkg/query/build_tree.go index f41d9f6a2..f7f474004 100644 --- a/pkg/query/build_tree.go +++ b/pkg/query/build_tree.go @@ -276,11 +276,8 @@ func (b *iteratorBuilder) buildBaseRelationIterator(br *schema.BaseRelation, wit } // We must check the effective arrow of a subrelation if we have one - union := NewUnion() - union.addSubIterator(base) - arrow := NewArrow(base.Clone(), rightside) - union.addSubIterator(arrow) + union := NewUnion(base, arrow) return union, nil } diff --git a/pkg/query/intersection.go b/pkg/query/intersection.go index 2e5cc356b..4faa93581 100644 --- a/pkg/query/intersection.go +++ b/pkg/query/intersection.go @@ -12,8 +12,13 @@ type Intersection struct { var _ Iterator = &Intersection{} -func NewIntersection() *Intersection { - return &Intersection{} +func NewIntersection(subiterators ...Iterator) *Intersection { + if len(subiterators) == 0 { + return &Intersection{} + } + return &Intersection{ + subIts: subiterators, + } } func (i *Intersection) addSubIterator(subIt Iterator) { diff --git a/pkg/query/optimize.go b/pkg/query/optimize.go new file mode 100644 index 000000000..110dec11e --- /dev/null +++ b/pkg/query/optimize.go @@ -0,0 +1,97 @@ +package query + +// TypedOptimizerFunc is a function that transforms an iterator of a specific type T +// into a potentially optimized iterator. It returns the optimized iterator, a boolean +// indicating whether any optimization was performed, and an error if the optimization failed. +// +// The type parameter T constrains the function to operate only on specific iterator types, +// providing compile-time type safety when creating typed optimizers. +type TypedOptimizerFunc[T Iterator] func(it T) (Iterator, bool, error) + +// OptimizerFunc is a type-erased wrapper around TypedOptimizerFunc[T] that can be +// stored in a homogeneous list while maintaining type safety at runtime. +type OptimizerFunc func(it Iterator) (Iterator, bool, error) + +// WrapOptimizer wraps a typed TypedOptimizerFunc[T] into a type-erased OptimizerFunc. +// This allows optimizer functions for different concrete iterator types to be stored +// together in a heterogeneous list. +func WrapOptimizer[T Iterator](fn TypedOptimizerFunc[T]) OptimizerFunc { + return func(it Iterator) (Iterator, bool, error) { + if v, ok := it.(T); ok { + return fn(v) + } + return it, false, nil + } +} + +// StaticOptimizations is a list of optimization functions that can be safely applied +// to any iterator tree without needing runtime information or context. +var StaticOptimizations = []OptimizerFunc{ + RemoveNullIterators, + CollapseSingletonUnionAndIntersection, + WrapOptimizer(PushdownCaveatEvaluation), +} + +// ApplyOptimizations recursively applies a list of optimizer functions to an iterator +// tree, transforming it into an optimized form. +// +// The function operates bottom-up, optimizing leafs and subiterators first, and replacing the +// subtrees up to the top, which it then returns. +// +// Parameters: +// - it: The iterator tree to optimize +// - fns: A list of optimizer functions to apply +// +// Returns: +// - The optimized iterator (which may be the same as the input if no optimizations applied) +// - A boolean indicating whether any changes were made +// - An error if any optimization failed +func ApplyOptimizations(it Iterator, fns []OptimizerFunc) (Iterator, bool, error) { + var err error + origSubs := it.Subiterators() + changed := false + if len(origSubs) != 0 { + // Make a copy of the subiterators slice to avoid mutating the original iterator + subs := make([]Iterator, len(origSubs)) + copy(subs, origSubs) + + subChanged := false + for i, subit := range subs { + newit, ok, err := ApplyOptimizations(subit, fns) + if err != nil { + return nil, false, err + } + if ok { + subs[i] = newit + subChanged = true + } + } + if subChanged { + changed = true + it, err = it.ReplaceSubiterators(subs) + if err != nil { + return nil, false, err + } + } + } + + // Apply each optimizer to the current iterator + // If any optimizer transforms the iterator, recursively optimize the new tree + for _, fn := range fns { + newit, fnChanged, err := fn(it) + if err != nil { + return nil, false, err + } + if fnChanged { + // The iterator was transformed - recursively optimize the new tree + // to ensure all optimizations are fully applied + optimizedIt, _, err := ApplyOptimizations(newit, fns) + if err != nil { + return nil, false, err + } + // Return true for changed since we did transform the iterator + return optimizedIt, true, nil + } + } + return it, changed, nil +} diff --git a/pkg/query/optimize_caveat.go b/pkg/query/optimize_caveat.go new file mode 100644 index 000000000..0ab927cd5 --- /dev/null +++ b/pkg/query/optimize_caveat.go @@ -0,0 +1,93 @@ +package query + +import ( + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// PushdownCaveatEvaluation pushes caveat evaluation down through certain composite iterators +// to allow earlier filtering and better performance. +// +// This optimization transforms: +// +// Caveat(Union[A, B]) -> Union[Caveat(A), B] (if only A contains the caveat) +// Caveat(Union[A, B]) -> Union[Caveat(A), Caveat(B)] (if both contain the caveat) +// +// The pushdown does NOT occur through IntersectionArrow iterators, as they have special +// semantics that require caveat evaluation to happen after the intersection. +func PushdownCaveatEvaluation(c *CaveatIterator) (Iterator, bool, error) { + // Don't push through IntersectionArrow + if _, ok := c.subiterator.(*IntersectionArrow); ok { + return c, false, nil + } + + // Don't push down if the subiterator is already a CaveatIterator + // This prevents infinite recursion + if _, ok := c.subiterator.(*CaveatIterator); ok { + return c, false, nil + } + + // Get the subiterators of the child + subs := c.subiterator.Subiterators() + if len(subs) == 0 { + // No subiterators to push down into (e.g., leaf iterator) + return c, false, nil + } + + // Find which subiterators contain relations with this caveat + newSubs := make([]Iterator, len(subs)) + changed := false + for i, sub := range subs { + if containsCaveat(sub, c.caveat) { + // Wrap this subiterator with the caveat + newSubs[i] = NewCaveatIterator(sub, c.caveat) + changed = true + } else { + // Leave unchanged + newSubs[i] = sub + } + } + + if !changed { + return c, false, nil + } + + // Replace the subiterators in the child iterator + newChild, err := c.subiterator.ReplaceSubiterators(newSubs) + if err != nil { + return nil, false, err + } + + // Return the child without the caveat wrapper + return newChild, true, nil +} + +// containsCaveat checks if an iterator tree contains a RelationIterator +// that references the given caveat. +func containsCaveat(it Iterator, caveat *core.ContextualizedCaveat) bool { + found := false + _, err := Walk(it, func(node Iterator) (Iterator, error) { + if rel, ok := node.(*RelationIterator); ok { + if relationContainsCaveat(rel, caveat) { + found = true + } + } + return node, nil + }) + if err != nil { + spiceerrors.MustPanicf("should never error -- callback contains no errors, but linters must always check") + } + + return found +} + +// relationContainsCaveat checks if a RelationIterator's base relation +// has a caveat that matches the given caveat name. +func relationContainsCaveat(rel *RelationIterator, caveat *core.ContextualizedCaveat) bool { + if rel.base == nil || caveat == nil { + return false + } + + // Check if the relation has this caveat + return rel.base.Caveat() == caveat.CaveatName +} diff --git a/pkg/query/optimize_caveat_test.go b/pkg/query/optimize_caveat_test.go new file mode 100644 index 000000000..702a88009 --- /dev/null +++ b/pkg/query/optimize_caveat_test.go @@ -0,0 +1,348 @@ +package query + +import ( + "testing" + + "github.com/stretchr/testify/require" + + core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/schema/v2" +) + +// createTestCaveatForPushdown creates a test ContextualizedCaveat +func createTestCaveatForPushdown(name string) *core.ContextualizedCaveat { + return &core.ContextualizedCaveat{ + CaveatName: name, + Context: nil, + } +} + +// createTestRelationIterator creates a RelationIterator with a caveat +func createTestRelationIterator(caveatName string) *RelationIterator { + // Create a BaseRelation with the caveat + baseRelation := schema.NewTestBaseRelationWithFeatures("document", "viewer", "user", "", caveatName, false) + return NewRelationIterator(baseRelation) +} + +// createTestRelationIteratorNoCaveat creates a RelationIterator without a caveat +func createTestRelationIteratorNoCaveat() *RelationIterator { + baseRelation := schema.NewTestBaseRelationWithFeatures("document", "viewer", "user", "", "", false) + return NewRelationIterator(baseRelation) +} + +func TestPushdownCaveatEvaluation(t *testing.T) { + t.Parallel() + + t.Run("pushes caveat through union when both sides have caveat", func(t *testing.T) { + t.Parallel() + + caveat := createTestCaveatForPushdown("test_caveat") + + // Create Union[Relation(with caveat), Relation(with caveat)] + rel1 := createTestRelationIterator("test_caveat") + rel2 := createTestRelationIterator("test_caveat") + union := NewUnion(rel1, rel2) + + // Wrap in caveat: Caveat(Union[Rel1, Rel2]) + caveatIterator := NewCaveatIterator(union, caveat) + + // Apply optimization + result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ + WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), + }) + require.NoError(t, err) + require.True(t, changed) + + // Should become Union[Caveat(Rel1), Caveat(Rel2)] + resultUnion, ok := result.(*Union) + require.True(t, ok, "Expected result to be a Union") + require.Len(t, resultUnion.subIts, 2) + + // Both should be wrapped in caveats + _, ok1 := resultUnion.subIts[0].(*CaveatIterator) + _, ok2 := resultUnion.subIts[1].(*CaveatIterator) + require.True(t, ok1, "First subiterator should be a CaveatIterator") + require.True(t, ok2, "Second subiterator should be a CaveatIterator") + }) + + t.Run("pushes caveat through union only on side with caveat", func(t *testing.T) { + t.Parallel() + + caveat := createTestCaveatForPushdown("test_caveat") + + // Create Union[Relation(with caveat), Relation(no caveat)] + rel1 := createTestRelationIterator("test_caveat") + rel2 := createTestRelationIteratorNoCaveat() + union := NewUnion(rel1, rel2) + + // Wrap in caveat: Caveat(Union[Rel1, Rel2]) + caveatIterator := NewCaveatIterator(union, caveat) + + // Apply optimization + result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ + WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), + }) + require.NoError(t, err) + require.True(t, changed) + + // Should become Union[Caveat(Rel1), Rel2] + resultUnion, ok := result.(*Union) + require.True(t, ok, "Expected result to be a Union") + require.Len(t, resultUnion.subIts, 2) + + // First should be wrapped, second should not + caveat1, ok1 := resultUnion.subIts[0].(*CaveatIterator) + rel2Result, ok2 := resultUnion.subIts[1].(*RelationIterator) + require.True(t, ok1, "First subiterator should be a CaveatIterator") + require.True(t, ok2, "Second subiterator should be a RelationIterator (not wrapped)") + + // Verify the caveat wraps the correct relation + caveat1Sub, ok := caveat1.subiterator.(*RelationIterator) + require.True(t, ok) + require.Equal(t, rel1, caveat1Sub) + require.Equal(t, rel2, rel2Result) + }) + + t.Run("does not push caveat through intersection arrow", func(t *testing.T) { + t.Parallel() + + caveat := createTestCaveatForPushdown("test_caveat") + + // Create an IntersectionArrow with a relation that has the caveat + rel := createTestRelationIterator("test_caveat") + relNoCaveat := createTestRelationIteratorNoCaveat() + intersectionArrow := NewIntersectionArrow(rel, relNoCaveat) + + // Wrap in caveat + caveatIterator := NewCaveatIterator(intersectionArrow, caveat) + + // Apply optimization + result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ + WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), + }) + require.NoError(t, err) + require.False(t, changed, "Should not optimize through IntersectionArrow") + + // Should remain as Caveat(IntersectionArrow) + resultCaveat, ok := result.(*CaveatIterator) + require.True(t, ok, "Expected result to still be a CaveatIterator") + _, ok = resultCaveat.subiterator.(*IntersectionArrow) + require.True(t, ok, "Subiterator should still be IntersectionArrow") + }) + + t.Run("does not push when no subiterators have caveat", func(t *testing.T) { + t.Parallel() + + caveat := createTestCaveatForPushdown("test_caveat") + + // Create Union[Relation(no caveat), Relation(no caveat)] + rel1 := createTestRelationIteratorNoCaveat() + rel2 := createTestRelationIteratorNoCaveat() + union := NewUnion(rel1, rel2) + + // Wrap in caveat: Caveat(Union[Rel1, Rel2]) + caveatIterator := NewCaveatIterator(union, caveat) + + // Apply optimization + result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ + WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), + }) + require.NoError(t, err) + require.False(t, changed) + + // Should remain unchanged + resultCaveat, ok := result.(*CaveatIterator) + require.True(t, ok) + require.Equal(t, caveatIterator, resultCaveat) + }) + + t.Run("does not push through leaf iterator", func(t *testing.T) { + t.Parallel() + + caveat := createTestCaveatForPushdown("test_caveat") + + // Create Caveat(Relation) - leaf has no subiterators + rel := createTestRelationIterator("test_caveat") + caveatIterator := NewCaveatIterator(rel, caveat) + + // Apply optimization + result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ + WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), + }) + require.NoError(t, err) + require.False(t, changed) + + // Should remain unchanged + resultCaveat, ok := result.(*CaveatIterator) + require.True(t, ok) + require.Equal(t, caveatIterator, resultCaveat) + }) + + t.Run("pushes through nested union", func(t *testing.T) { + t.Parallel() + + caveat := createTestCaveatForPushdown("test_caveat") + + // Create Caveat(Union[Union[Rel1, Rel2], Rel3]) + rel1 := createTestRelationIterator("test_caveat") + rel2 := createTestRelationIteratorNoCaveat() + innerUnion := NewUnion(rel1, rel2) + + rel3 := createTestRelationIterator("test_caveat") + outerUnion := NewUnion(innerUnion, rel3) + + caveatIterator := NewCaveatIterator(outerUnion, caveat) + + // Apply optimization + result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ + WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), + }) + require.NoError(t, err) + require.True(t, changed) + + // Due to recursive optimization, this will become: + // Union[Union[Caveat(Rel1), Rel2], Caveat(Rel3)] + // The outer caveat pushes down to wrap innerUnion and rel3 + // Then the caveat on innerUnion recursively pushes down to only wrap rel1 + resultUnion, ok := result.(*Union) + require.True(t, ok) + require.Len(t, resultUnion.subIts, 2) + + // First should be Union[Caveat(Rel1), Rel2] (caveat pushed down further) + innerResultUnion, ok1 := resultUnion.subIts[0].(*Union) + require.True(t, ok1, "First subiterator should be a Union (caveat pushed down)") + require.Len(t, innerResultUnion.subIts, 2) + _, ok = innerResultUnion.subIts[0].(*CaveatIterator) + require.True(t, ok, "First element of inner union should be Caveat(Rel1)") + _, ok = innerResultUnion.subIts[1].(*RelationIterator) + require.True(t, ok, "Second element of inner union should be Rel2 (no caveat)") + + // Second should be Caveat(Rel3) + caveat2, ok2 := resultUnion.subIts[1].(*CaveatIterator) + require.True(t, ok2) + _, ok = caveat2.subiterator.(*RelationIterator) + require.True(t, ok, "Second subiterator should be Caveat(Relation)") + }) + + t.Run("works with intersection of relations", func(t *testing.T) { + t.Parallel() + + caveat := createTestCaveatForPushdown("test_caveat") + + // Create Caveat(Intersection[Rel1(with caveat), Rel2(no caveat)]) + rel1 := createTestRelationIterator("test_caveat") + rel2 := createTestRelationIteratorNoCaveat() + intersection := NewIntersection(rel1, rel2) + + caveatIterator := NewCaveatIterator(intersection, caveat) + + // Apply optimization + result, changed, err := ApplyOptimizations(caveatIterator, []OptimizerFunc{ + WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), + }) + require.NoError(t, err) + require.True(t, changed) + + // Should become Intersection[Caveat(Rel1), Rel2] + resultIntersection, ok := result.(*Intersection) + require.True(t, ok) + require.Len(t, resultIntersection.subIts, 2) + + // First should be wrapped, second should not + _, ok1 := resultIntersection.subIts[0].(*CaveatIterator) + _, ok2 := resultIntersection.subIts[1].(*RelationIterator) + require.True(t, ok1, "First subiterator should be a CaveatIterator") + require.True(t, ok2, "Second subiterator should be a RelationIterator") + }) +} + +func TestContainsCaveat(t *testing.T) { + t.Parallel() + + caveat := createTestCaveatForPushdown("test_caveat") + + t.Run("detects caveat in relation iterator", func(t *testing.T) { + t.Parallel() + + rel := createTestRelationIterator("test_caveat") + require.True(t, containsCaveat(rel, caveat)) + }) + + t.Run("does not detect when caveat name differs", func(t *testing.T) { + t.Parallel() + + rel := createTestRelationIterator("other_caveat") + require.False(t, containsCaveat(rel, caveat)) + }) + + t.Run("does not detect when no caveat", func(t *testing.T) { + t.Parallel() + + rel := createTestRelationIteratorNoCaveat() + require.False(t, containsCaveat(rel, caveat)) + }) + + t.Run("detects caveat in nested structure", func(t *testing.T) { + t.Parallel() + + rel1 := createTestRelationIteratorNoCaveat() + rel2 := createTestRelationIterator("test_caveat") + union := NewUnion(rel1, rel2) + + require.True(t, containsCaveat(union, caveat)) + }) + + t.Run("does not detect caveat in structure without it", func(t *testing.T) { + t.Parallel() + + rel1 := createTestRelationIteratorNoCaveat() + rel2 := createTestRelationIteratorNoCaveat() + union := NewUnion(rel1, rel2) + + require.False(t, containsCaveat(union, caveat)) + }) + + t.Run("handles nil caveat in relationContainsCaveat", func(t *testing.T) { + t.Parallel() + + rel := createTestRelationIterator("test_caveat") + require.False(t, relationContainsCaveat(rel, nil)) + }) + + t.Run("handles relation with nil base in relationContainsCaveat", func(t *testing.T) { + t.Parallel() + + caveat := createTestCaveatForPushdown("test_caveat") + // Create a RelationIterator with nil base + rel := &RelationIterator{base: nil} + require.False(t, relationContainsCaveat(rel, caveat)) + }) +} + +func TestPushdownCaveatEvaluationEdgeCases(t *testing.T) { + t.Parallel() + + t.Run("does not push through nested CaveatIterator", func(t *testing.T) { + t.Parallel() + + caveat := createTestCaveatForPushdown("test_caveat") + + // Create Caveat(Caveat(Relation)) + rel := createTestRelationIterator("test_caveat") + innerCaveat := NewCaveatIterator(rel, caveat) + outerCaveat := NewCaveatIterator(innerCaveat, caveat) + + // Apply optimization + result, changed, err := ApplyOptimizations(outerCaveat, []OptimizerFunc{ + WrapOptimizer[*CaveatIterator](PushdownCaveatEvaluation), + }) + require.NoError(t, err) + require.False(t, changed, "Should not push through nested CaveatIterator to prevent infinite recursion") + + // Should remain unchanged + resultCaveat, ok := result.(*CaveatIterator) + require.True(t, ok) + _, ok = resultCaveat.subiterator.(*CaveatIterator) + require.True(t, ok, "Subiterator should still be a CaveatIterator") + }) +} diff --git a/pkg/query/optimize_simple.go b/pkg/query/optimize_simple.go new file mode 100644 index 000000000..04cfe434c --- /dev/null +++ b/pkg/query/optimize_simple.go @@ -0,0 +1,60 @@ +package query + +import "slices" + +// CollapseSingletonUnionAndIntersection removes unnecessary union and intersection wrappers +// that contain only a single subiterator. +func CollapseSingletonUnionAndIntersection(it Iterator) (Iterator, bool, error) { + switch v := it.(type) { + case *Union: + if len(v.subIts) == 1 { + return v.subIts[0], true, nil + } + case *Intersection: + if len(v.subIts) == 1 { + return v.subIts[0], true, nil + } + } + return it, false, nil +} + +// RemoveNullIterators removes null iterators from union and intersection operations. +// Unions, removes the empty set (A | 0 = A), Intersection, returns a null itself (A & 0 = 0) +func RemoveNullIterators(it Iterator) (Iterator, bool, error) { + switch v := it.(type) { + case *Union: + subs := v.Subiterators() + hasEmpty := false + newSubs := make([]Iterator, 0) + for _, s := range subs { + if isEmptyFixed(s) { + hasEmpty = true + } else { + newSubs = append(newSubs, s) + } + } + if hasEmpty { + // If all subiterators were empty, return empty + if len(newSubs) == 0 { + return NewEmptyFixedIterator(), true, nil + } + newit, err := it.ReplaceSubiterators(newSubs) + return newit, true, err + } + case *Intersection: + if slices.ContainsFunc(v.Subiterators(), isEmptyFixed) { + return NewEmptyFixedIterator(), true, nil + } + } + return it, false, nil +} + +// isEmptyFixed detects an empty, fixed iterator, used as a null. +func isEmptyFixed(it Iterator) bool { + if v, ok := it.(*FixedIterator); ok { + if len(v.paths) == 0 { + return true + } + } + return false +} diff --git a/pkg/query/optimize_test.go b/pkg/query/optimize_test.go new file mode 100644 index 000000000..537275e7c --- /dev/null +++ b/pkg/query/optimize_test.go @@ -0,0 +1,419 @@ +package query + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// newNonEmptyFixedIterator creates a FixedIterator with at least one path for testing +func newNonEmptyFixedIterator() *FixedIterator { + return NewFixedIterator(Path{ + Resource: Object{ObjectType: "doc", ObjectID: "test"}, + Relation: "viewer", + Subject: ObjectAndRelation{ + ObjectType: "user", + ObjectID: "alice", + Relation: "...", + }, + }) +} + +func TestWrapOptimizer(t *testing.T) { + t.Parallel() + + t.Run("matches correct type", func(t *testing.T) { + t.Parallel() + + // Create a typed optimizer that only works on Union + typedOptimizer := func(u *Union) (Iterator, bool, error) { + if len(u.subIts) == 1 { + return u.subIts[0], true, nil + } + return u, false, nil + } + + // Wrap it and use in ApplyOptimizations + wrapped := WrapOptimizer[*Union](typedOptimizer) + + // Test with a Union - should match and optimize + fixed := newNonEmptyFixedIterator() + union := NewUnion(fixed) + + result, changed, err := ApplyOptimizations(union, []OptimizerFunc{wrapped}) + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, fixed, result) + }) + + t.Run("does not match wrong type", func(t *testing.T) { + t.Parallel() + + // Create a typed optimizer that only works on Union + typedOptimizer := func(u *Union) (Iterator, bool, error) { + return u, true, nil // Would return true if called + } + + // Wrap it and use in ApplyOptimizations + wrapped := WrapOptimizer[*Union](typedOptimizer) + + // Test with an Intersection - should not match + intersection := NewIntersection() + result, changed, err := ApplyOptimizations(intersection, []OptimizerFunc{wrapped}) + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, intersection, result) + }) +} + +func TestCollapseSingletonUnionAndIntersection(t *testing.T) { + t.Parallel() + + t.Run("collapses singleton union", func(t *testing.T) { + t.Parallel() + + fixed := newNonEmptyFixedIterator() + union := NewUnion(fixed) + + result, changed, err := ApplyOptimizations(union, []OptimizerFunc{CollapseSingletonUnionAndIntersection}) + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, fixed, result) + }) + + t.Run("collapses singleton intersection", func(t *testing.T) { + t.Parallel() + + fixed := newNonEmptyFixedIterator() + intersection := NewIntersection(fixed) + + result, changed, err := ApplyOptimizations(intersection, []OptimizerFunc{CollapseSingletonUnionAndIntersection}) + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, fixed, result) + }) + + t.Run("does not collapse multi-element union", func(t *testing.T) { + t.Parallel() + + union := NewUnion(newNonEmptyFixedIterator(), newNonEmptyFixedIterator()) + + result, changed, err := ApplyOptimizations(union, []OptimizerFunc{CollapseSingletonUnionAndIntersection}) + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, union, result) + }) + + t.Run("does not collapse multi-element intersection", func(t *testing.T) { + t.Parallel() + + intersection := NewIntersection(newNonEmptyFixedIterator(), newNonEmptyFixedIterator()) + + result, changed, err := ApplyOptimizations(intersection, []OptimizerFunc{CollapseSingletonUnionAndIntersection}) + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, intersection, result) + }) + + t.Run("does not collapse other iterator types", func(t *testing.T) { + t.Parallel() + + fixed := newNonEmptyFixedIterator() + result, changed, err := ApplyOptimizations(fixed, []OptimizerFunc{CollapseSingletonUnionAndIntersection}) + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, fixed, result) + }) +} + +func TestRemoveNullIterators(t *testing.T) { + t.Parallel() + + t.Run("removes empty fixed from union", func(t *testing.T) { + t.Parallel() + + fixed := newNonEmptyFixedIterator() + empty := NewEmptyFixedIterator() + union := NewUnion(fixed, empty) + + result, changed, err := ApplyOptimizations(union, StaticOptimizations) + require.NoError(t, err) + require.True(t, changed) + + // Should remove the empty iterator, leaving a singleton union, which then gets collapsed + require.Equal(t, fixed, result) + }) + + t.Run("removes multiple empty fixed from union", func(t *testing.T) { + t.Parallel() + + fixed1 := newNonEmptyFixedIterator() + fixed2 := newNonEmptyFixedIterator() + empty1 := NewEmptyFixedIterator() + empty2 := NewEmptyFixedIterator() + union := NewUnion(fixed1, empty1, fixed2, empty2) + + result, changed, err := ApplyOptimizations(union, StaticOptimizations) + require.NoError(t, err) + require.True(t, changed) + + // Should remove both empty iterators, leaving a union with 2 elements + resultUnion, ok := result.(*Union) + require.True(t, ok) + require.Len(t, resultUnion.subIts, 2) + require.Equal(t, fixed1, resultUnion.subIts[0]) + require.Equal(t, fixed2, resultUnion.subIts[1]) + }) + + t.Run("replaces intersection with empty if it contains empty and fixed", func(t *testing.T) { + t.Parallel() + + fixed := newNonEmptyFixedIterator() + empty := NewEmptyFixedIterator() + intersection := NewIntersection(fixed, empty) + + result, changed, err := ApplyOptimizations(intersection, StaticOptimizations) + require.NoError(t, err) + require.True(t, changed) + + // Should return an empty fixed iterator + resultFixed, ok := result.(*FixedIterator) + require.True(t, ok) + require.Len(t, resultFixed.paths, 0) + }) + + t.Run("does not change union without empty iterators", func(t *testing.T) { + t.Parallel() + + fixed1 := newNonEmptyFixedIterator() + fixed2 := newNonEmptyFixedIterator() + union := NewUnion(fixed1, fixed2) + + result, changed, err := ApplyOptimizations(union, StaticOptimizations) + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, union, result) + }) + + t.Run("does not change intersection without empty iterators", func(t *testing.T) { + t.Parallel() + + fixed1 := newNonEmptyFixedIterator() + fixed2 := newNonEmptyFixedIterator() + intersection := NewIntersection(fixed1, fixed2) + + result, changed, err := ApplyOptimizations(intersection, StaticOptimizations) + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, intersection, result) + }) + + t.Run("returns empty when all union subiterators are empty", func(t *testing.T) { + t.Parallel() + + empty1 := NewEmptyFixedIterator() + empty2 := NewEmptyFixedIterator() + empty3 := NewEmptyFixedIterator() + union := NewUnion(empty1, empty2, empty3) + + result, changed, err := ApplyOptimizations(union, []OptimizerFunc{RemoveNullIterators}) + require.NoError(t, err) + require.True(t, changed) + + // Should return an empty fixed iterator + resultFixed, ok := result.(*FixedIterator) + require.True(t, ok) + require.Len(t, resultFixed.paths, 0) + }) +} + +func TestApplyOptimizations(t *testing.T) { + t.Parallel() + + t.Run("applies optimization to nested iterators", func(t *testing.T) { + t.Parallel() + + // Create a union with a nested singleton union + fixed := newNonEmptyFixedIterator() + innerUnion := NewUnion(fixed) + outerUnion := NewUnion(innerUnion, newNonEmptyFixedIterator()) + + result, changed, err := ApplyOptimizations(outerUnion, StaticOptimizations) + require.NoError(t, err) + require.True(t, changed) + + // The outer union should still be a union (has 2 elements) + // but the inner singleton union should be collapsed + resultUnion, ok := result.(*Union) + require.True(t, ok) + require.Len(t, resultUnion.subIts, 2) + require.Equal(t, fixed, resultUnion.subIts[0]) + }) + + t.Run("chains multiple optimizations", func(t *testing.T) { + t.Parallel() + + // Create a union with a singleton union inside + fixed := newNonEmptyFixedIterator() + innerUnion := NewUnion(fixed) + outerUnion := NewUnion(innerUnion) + + result, changed, err := ApplyOptimizations(outerUnion, StaticOptimizations) + require.NoError(t, err) + require.True(t, changed) + // After optimizations: + // 1. Inner union collapsed: outerUnion has [fixed] + // 2. Outer union collapsed: returns fixed + require.Equal(t, fixed, result) + }) + + t.Run("returns unchanged when no optimizations apply", func(t *testing.T) { + t.Parallel() + + union := NewUnion(newNonEmptyFixedIterator(), newNonEmptyFixedIterator()) + + result, changed, err := ApplyOptimizations(union, StaticOptimizations) + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, union, result) + }) + + t.Run("applies multiple optimizations in sequence", func(t *testing.T) { + t.Parallel() + + // Create two different optimizers + unionOptimizer := func(it Iterator) (Iterator, bool, error) { + if u, ok := it.(*Union); ok && len(u.subIts) == 1 { + return u.subIts[0], true, nil + } + return it, false, nil + } + + intersectionOptimizer := func(it Iterator) (Iterator, bool, error) { + if i, ok := it.(*Intersection); ok && len(i.subIts) == 1 { + return i.subIts[0], true, nil + } + return it, false, nil + } + + // Test that both are applied + fixed := newNonEmptyFixedIterator() + intersection := NewIntersection(fixed) + union := NewUnion(intersection) + + result, changed, err := ApplyOptimizations(union, []OptimizerFunc{unionOptimizer, intersectionOptimizer}) + require.NoError(t, err) + require.True(t, changed) + // Both the union and intersection should be collapsed + require.Equal(t, fixed, result) + }) + + t.Run("handles empty optimizer list", func(t *testing.T) { + t.Parallel() + + union := NewUnion(newNonEmptyFixedIterator()) + + result, changed, err := ApplyOptimizations(union, []OptimizerFunc{}) + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, union, result) + }) + + t.Run("optimizer order independence - removes empty then collapses singleton", func(t *testing.T) { + t.Parallel() + + // Create a union with an empty iterator and one non-empty iterator + // After removing the empty, we have a singleton union that should be collapsed + fixed := newNonEmptyFixedIterator() + empty := NewEmptyFixedIterator() + union := NewUnion(fixed, empty) + + // Test with RemoveNullIterators first, then CollapseSingletonUnionAndIntersection + result1, changed1, err := ApplyOptimizations(union, []OptimizerFunc{ + RemoveNullIterators, + CollapseSingletonUnionAndIntersection, + }) + require.NoError(t, err) + require.True(t, changed1) + + // Test with CollapseSingletonUnionAndIntersection first, then RemoveNullIterators + result2, changed2, err := ApplyOptimizations(union, []OptimizerFunc{ + CollapseSingletonUnionAndIntersection, + RemoveNullIterators, + }) + require.NoError(t, err) + require.True(t, changed2) + + // Both orders should produce the same final result: the fixed iterator + require.Equal(t, fixed, result1) + require.Equal(t, fixed, result2) + require.Equal(t, result1, result2) + }) + + t.Run("optimizer order independence - nested structure", func(t *testing.T) { + t.Parallel() + + // Create a more complex nested structure: + // Union[Intersection[Fixed, Empty], Fixed] + // Expected result: Fixed (the second one) + // After RemoveNullIterators: Union[Empty, Fixed] -> Union[Fixed] -> Fixed + // OR after processing intersection first: Union[Empty, Fixed] -> Union[Fixed] -> Fixed + + fixed1 := newNonEmptyFixedIterator() + fixed2 := newNonEmptyFixedIterator() + empty := NewEmptyFixedIterator() + innerIntersection := NewIntersection(fixed1, empty) + outerUnion := NewUnion(innerIntersection, fixed2) + + result1, changed1, err := ApplyOptimizations(outerUnion, []OptimizerFunc{ + RemoveNullIterators, + CollapseSingletonUnionAndIntersection, + }) + require.NoError(t, err) + require.True(t, changed1) + + result2, changed2, err := ApplyOptimizations(outerUnion, []OptimizerFunc{ + CollapseSingletonUnionAndIntersection, + RemoveNullIterators, + }) + require.NoError(t, err) + require.True(t, changed2) + + // Both should result in the same structure + // The intersection with empty should become empty, leaving Union[Empty, Fixed2] + // Then removing empty gives Union[Fixed2], which collapses to Fixed2 + require.Equal(t, result1, result2) + require.Equal(t, fixed2, result1) + require.Equal(t, fixed2, result2) + }) + + t.Run("optimizer order independence - union with all empty", func(t *testing.T) { + t.Parallel() + + // Union[Empty, Empty] should become Empty regardless of order + empty1 := NewEmptyFixedIterator() + empty2 := NewEmptyFixedIterator() + union := NewUnion(empty1, empty2) + + result1, changed1, err := ApplyOptimizations(union, []OptimizerFunc{ + RemoveNullIterators, + CollapseSingletonUnionAndIntersection, + }) + require.NoError(t, err) + require.True(t, changed1) + + result2, changed2, err := ApplyOptimizations(union, []OptimizerFunc{ + CollapseSingletonUnionAndIntersection, + RemoveNullIterators, + }) + require.NoError(t, err) + require.True(t, changed2) + + // Both should result in an empty fixed iterator + // After removing all empties, we get Union[], which then gets optimized + // to an empty fixed iterator by CollapseSingletonUnionAndIntersection + require.Equal(t, result1, result2) + _, ok := result1.(*FixedIterator) + require.True(t, ok, "result1 should be a FixedIterator (empty), got %T", result1) + }) +} diff --git a/pkg/query/path.go b/pkg/query/path.go index 4fad9570c..945b71a77 100644 --- a/pkg/query/path.go +++ b/pkg/query/path.go @@ -46,6 +46,12 @@ func (p Path) Key() string { return fmt.Sprintf("%s#%s@%s", p.Resource.Key(), p.Relation, ObjectAndRelationKey(p.Subject)) } +// EndpointsKey returns a unique string key for this Path based on its resource and subject only, +// excluding the relation. This matches the semantics of EqualsEndpoints. +func (p Path) EndpointsKey() string { + return fmt.Sprintf("%s@%s", p.Resource.Key(), ObjectAndRelationKey(p.Subject)) +} + // MergeOr combines the paths, ORing the caveats and expiration and metadata together. // Returns a new Path with the merged values. func (p Path) MergeOr(other Path) (Path, error) { @@ -300,3 +306,39 @@ func CollectAll(seq PathSeq) ([]Path, error) { } return out, nil } + +// DeduplicatePathSeq returns a new PathSeq that deduplicates paths based on their +// endpoints (resource and subject, excluding relation). Paths with the same endpoints +// are merged using OR semantics (caveats are OR'd, no caveat wins over caveat). +// This collects all paths first, deduplicates with merging, then yields results. +func DeduplicatePathSeq(seq PathSeq) PathSeq { + return func(yield func(Path, error) bool) { + seen := make(map[string]Path) + for path, err := range seq { + if err != nil { + yield(Path{}, err) + return + } + + key := path.EndpointsKey() + if existing, exists := seen[key]; !exists { + seen[key] = path + } else { + // Merge with existing path using OR semantics + merged, err := existing.MergeOr(path) + if err != nil { + yield(Path{}, err) + return + } + seen[key] = merged + } + } + + // Yield all deduplicated paths + for _, path := range seen { + if !yield(path, nil) { + return + } + } + } +} diff --git a/pkg/query/path_test.go b/pkg/query/path_test.go index 44814e62c..1638ab593 100644 --- a/pkg/query/path_test.go +++ b/pkg/query/path_test.go @@ -52,9 +52,9 @@ func TestPath_IsExpired(t *testing.T) { t.Run("exact_now_expiration", func(t *testing.T) { t.Parallel() - now := time.Now() + now := time.Now().Add(-time.Millisecond) path := &Path{Expiration: &now} - // Should be considered expired if exactly at current time + // Should be considered expired if expiration is in the past require.True(path.IsExpired()) }) } diff --git a/pkg/query/tracing_test.go b/pkg/query/tracing_test.go index 725915117..138b22899 100644 --- a/pkg/query/tracing_test.go +++ b/pkg/query/tracing_test.go @@ -98,6 +98,5 @@ func TestIteratorTracing(t *testing.T) { require.True(t, strings.Contains(trace, "-> Union: check(document:doc1, user:alice)")) require.True(t, strings.Contains(trace, "<- Union: returned 1 paths")) require.True(t, strings.Contains(trace, "Union: processing 2 sub-iterators")) - require.True(t, strings.Contains(trace, "Union: deduplicated to")) }) } diff --git a/pkg/query/union.go b/pkg/query/union.go index b8265811d..f1558b816 100644 --- a/pkg/query/union.go +++ b/pkg/query/union.go @@ -12,8 +12,13 @@ type Union struct { var _ Iterator = &Union{} -func NewUnion() *Union { - return &Union{} +func NewUnion(subiterators ...Iterator) *Union { + if len(subiterators) == 0 { + return &Union{} + } + return &Union{ + subIts: subiterators, + } } func (u *Union) addSubIterator(subIt Iterator) { @@ -21,62 +26,37 @@ func (u *Union) addSubIterator(subIt Iterator) { } func (u *Union) CheckImpl(ctx *Context, resources []Object, subject ObjectAndRelation) (PathSeq, error) { - var out []Path - // Collect paths from all sub-iterators ctx.TraceStep(u, "processing %d sub-iterators with %d resources", len(u.subIts), len(resources)) - for iterIdx, it := range u.subIts { - ctx.TraceStep(u, "processing sub-iterator %d", iterIdx) - - pathSeq, err := ctx.Check(it, resources, subject) - if err != nil { - return nil, err - } - paths, err := CollectAll(pathSeq) - if err != nil { - return nil, err - } + // Create a concatenated sequence from all sub-iterators + combinedSeq := func(yield func(Path, error) bool) { + for iterIdx, it := range u.subIts { + ctx.TraceStep(u, "processing sub-iterator %d", iterIdx) - ctx.TraceStep(u, "sub-iterator %d returned %d paths", iterIdx, len(paths)) - out = append(out, paths...) - } - - ctx.TraceStep(u, "collected %d total paths before deduplication", len(out)) - - // Deduplicate paths based on resource for CheckImpl - // Since the subject is fixed in CheckImpl, we only need to deduplicate by resource - seen := make(map[string]Path) - for _, path := range out { - // Use resource object (type + id) as key for deduplication, not the full resource with relation - key := path.Resource.Key() - if existing, exists := seen[key]; !exists { - seen[key] = path - } else { - // If we already have a path for this resource, - // merge it with the new one using OR semantics - merged, err := existing.MergeOr(path) + pathSeq, err := ctx.Check(it, resources, subject) if err != nil { - return nil, err + yield(Path{}, err) + return } - seen[key] = merged - } - } - - // Convert map to slice - deduplicatedSlice := make([]Path, 0, len(seen)) - for _, path := range seen { - deduplicatedSlice = append(deduplicatedSlice, path) - } - - ctx.TraceStep(u, "deduplicated to %d paths", len(deduplicatedSlice)) - return func(yield func(Path, error) bool) { - for _, path := range deduplicatedSlice { - if !yield(path, nil) { - return + pathCount := 0 + for path, err := range pathSeq { + if err != nil { + yield(Path{}, err) + return + } + pathCount++ + if !yield(path, nil) { + return + } } + + ctx.TraceStep(u, "sub-iterator %d returned %d paths", iterIdx, pathCount) } - }, nil + } + + // Wrap with deduplication + return DeduplicatePathSeq(combinedSeq), nil } func (u *Union) IterSubjectsImpl(ctx *Context, resource Object) (PathSeq, error) {