Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions combinators.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,35 @@ func (g *deferredGen[V]) value(t *T) V {
return g.g.value(t)
}

func filter[V any](g *Generator[V], fn func(V) bool) *Generator[V] {
return newGenerator[V](&filteredGen[V]{
g: g,
fn: fn,
})
type filterOpts struct {
maxAttempts int
}
type filterOption func(*filterOpts) *filterOpts

func filter[V any](g *Generator[V], fn func(V) bool, opts ...filterOption) *Generator[V] {
impl := &filteredGen[V]{
g: g,
fn: fn,
filterOpts: &filterOpts{},
}

impl.applyOptions(opts...)

return newGenerator[V](impl)
}

func WithMaxAttempts(max int) filterOption {
return func(o *filterOpts) *filterOpts {
if max <= 0 {
panic(invalidData("max attempts should be greater than 0"))
}
o.maxAttempts = max
return o
}
}

type filteredGen[V any] struct {
*filterOpts
g *Generator[V]
fn func(V) bool
}
Expand All @@ -93,7 +114,15 @@ func (g *filteredGen[V]) String() string {
}

func (g *filteredGen[V]) value(t *T) V {
return find(g.maybeValue, t, small)
tries := g.maxAttempts
if tries <= 0 {
// If no max attempts specified, use the default value.
tries = flags.filterMaxattempts

}

return find(g.maybeValue, t, tries)

}

func (g *filteredGen[V]) maybeValue(t *T) (V, bool) {
Expand All @@ -106,6 +135,12 @@ func (g *filteredGen[V]) maybeValue(t *T) (V, bool) {
}
}

func (g *filteredGen[V]) applyOptions(opts ...filterOption) {
for _, opt := range opts {
g.filterOpts = opt(g.filterOpts)
}
}

func find[V any](gen func(*T) (V, bool), t *T, tries int) V {
for n := 0; n < tries; n++ {
i := t.s.beginGroup(tryLabel, false)
Expand Down
22 changes: 12 additions & 10 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,22 @@ var (
)

type cmdline struct {
checks int
steps int
failfile string
nofailfile bool
seed uint64
log bool
verbose bool
debug bool
debugvis bool
shrinkTime time.Duration
checks int
filterMaxattempts int
steps int
failfile string
nofailfile bool
seed uint64
log bool
verbose bool
debug bool
debugvis bool
shrinkTime time.Duration
}

func init() {
flag.IntVar(&flags.checks, "rapid.checks", 100, "rapid: number of checks to perform")
flag.IntVar(&flags.filterMaxattempts, "rapid.filtermaxattempts", 1000, "rapid: maximum number of attempts to draw a valid value from a Filter generator")
flag.IntVar(&flags.steps, "rapid.steps", 30, "rapid: average number of Repeat actions to execute")
flag.StringVar(&flags.failfile, "rapid.failfile", "", "rapid: fail file to use to reproduce test failure")
flag.BoolVar(&flags.nofailfile, "rapid.nofailfile", false, "rapid: do not write fail files on test failures")
Expand Down
10 changes: 10 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ func TestPanicTraceback(t *testing.T) {
return err
},
},
{
"impossible filter with options",
"pgregory.net/rapid.find[...]",
false,
func(t *T) *testError {
g := Bool().Filter(func(bool) bool { return false }, WithMaxAttempts(10))
_, err := recoverValue(g, t)
return err
},
},
{
"broken custom generator",
"pgregory.net/rapid.brokenGen",
Expand Down
4 changes: 2 additions & 2 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ func (g *Generator[V]) Example(seed ...int) V {
}

// Filter creates a generator producing only values from g for which fn returns true.
func (g *Generator[V]) Filter(fn func(V) bool) *Generator[V] {
return filter(g, fn)
func (g *Generator[V]) Filter(fn func(V) bool, opts ...filterOption) *Generator[V] {
return filter(g, fn, opts...)
}

// AsAny creates a generator producing values from g converted to any.
Expand Down
Loading