diff --git a/combinators.go b/combinators.go index 8195099..194a758 100644 --- a/combinators.go +++ b/combinators.go @@ -76,14 +76,35 @@ func (g *deferredGen[V]) value(t *T) V { return g.g.value(t) } -func filter[V any](g *Generator[V], fn func(V) bool) *Generator[V] { - return newGenerator[V](&filteredGen[V]{ - g: g, - fn: fn, - }) +type filterOpts struct { + maxAttempts int +} +type filterOption func(*filterOpts) *filterOpts + +func filter[V any](g *Generator[V], fn func(V) bool, opts ...filterOption) *Generator[V] { + impl := &filteredGen[V]{ + g: g, + fn: fn, + filterOpts: &filterOpts{}, + } + + impl.applyOptions(opts...) + + return newGenerator[V](impl) +} + +func WithMaxAttempts(max int) filterOption { + return func(o *filterOpts) *filterOpts { + if max <= 0 { + panic(invalidData("max attempts should be greater than 0")) + } + o.maxAttempts = max + return o + } } type filteredGen[V any] struct { + *filterOpts g *Generator[V] fn func(V) bool } @@ -93,7 +114,15 @@ func (g *filteredGen[V]) String() string { } func (g *filteredGen[V]) value(t *T) V { - return find(g.maybeValue, t, small) + tries := g.maxAttempts + if tries <= 0 { + // If no max attempts specified, use the default value. + tries = flags.filterMaxattempts + + } + + return find(g.maybeValue, t, tries) + } func (g *filteredGen[V]) maybeValue(t *T) (V, bool) { @@ -106,6 +135,12 @@ func (g *filteredGen[V]) maybeValue(t *T) (V, bool) { } } +func (g *filteredGen[V]) applyOptions(opts ...filterOption) { + for _, opt := range opts { + g.filterOpts = opt(g.filterOpts) + } +} + func find[V any](gen func(*T) (V, bool), t *T, tries int) V { for n := 0; n < tries; n++ { i := t.s.beginGroup(tryLabel, false) diff --git a/engine.go b/engine.go index 57c5674..9a8a5c2 100644 --- a/engine.go +++ b/engine.go @@ -47,20 +47,22 @@ var ( ) type cmdline struct { - checks int - steps int - failfile string - nofailfile bool - seed uint64 - log bool - verbose bool - debug bool - debugvis bool - shrinkTime time.Duration + checks int + filterMaxattempts int + steps int + failfile string + nofailfile bool + seed uint64 + log bool + verbose bool + debug bool + debugvis bool + shrinkTime time.Duration } func init() { flag.IntVar(&flags.checks, "rapid.checks", 100, "rapid: number of checks to perform") + flag.IntVar(&flags.filterMaxattempts, "rapid.filtermaxattempts", 1000, "rapid: maximum number of attempts to draw a valid value from a Filter generator") flag.IntVar(&flags.steps, "rapid.steps", 30, "rapid: average number of Repeat actions to execute") flag.StringVar(&flags.failfile, "rapid.failfile", "", "rapid: fail file to use to reproduce test failure") flag.BoolVar(&flags.nofailfile, "rapid.nofailfile", false, "rapid: do not write fail files on test failures") diff --git a/engine_test.go b/engine_test.go index fcee09b..785698b 100644 --- a/engine_test.go +++ b/engine_test.go @@ -40,6 +40,16 @@ func TestPanicTraceback(t *testing.T) { return err }, }, + { + "impossible filter with options", + "pgregory.net/rapid.find[...]", + false, + func(t *T) *testError { + g := Bool().Filter(func(bool) bool { return false }, WithMaxAttempts(10)) + _, err := recoverValue(g, t) + return err + }, + }, { "broken custom generator", "pgregory.net/rapid.brokenGen", diff --git a/generator.go b/generator.go index 128c7b5..062ad33 100644 --- a/generator.go +++ b/generator.go @@ -91,8 +91,8 @@ func (g *Generator[V]) Example(seed ...int) V { } // Filter creates a generator producing only values from g for which fn returns true. -func (g *Generator[V]) Filter(fn func(V) bool) *Generator[V] { - return filter(g, fn) +func (g *Generator[V]) Filter(fn func(V) bool, opts ...filterOption) *Generator[V] { + return filter(g, fn, opts...) } // AsAny creates a generator producing values from g converted to any.