Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Evaluation results are now cached and reused between rules.
Browse files Browse the repository at this point in the history
This will now cache rule evaluation between AST node of the same run. It prevents two equal AST nodes from executing at the same time, using the result for all concurrent instances. Also, it will use the cached result in subsequent instance of the same AST node.
apognu committed Jan 30, 2025
1 parent 8597f7a commit fc96036
Showing 9 changed files with 276 additions and 54 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -158,6 +158,7 @@ require (
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
github.com/mitchellh/go-wordwrap v1.0.1 // indirect
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/term v0.5.0 // indirect
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -358,6 +358,8 @@ github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6B
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0=
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
10 changes: 9 additions & 1 deletion models/ast/ast_node.go
Original file line number Diff line number Diff line change
@@ -4,10 +4,12 @@ import (
"fmt"

"github.com/cockroachdb/errors"
"github.com/mitchellh/hashstructure/v2"
)

type Node struct {
Index int
Index int `hash:"ignore"`
UsedCache bool `hash:"ignore"`

// A node is a constant xOR a function
Function Function
@@ -52,6 +54,12 @@ func (node Node) ReadConstantNamedChildString(name string) (string, error) {
return value, nil
}

func (node Node) Hash() uint64 {
hash, _ := hashstructure.Hash(node, hashstructure.FormatV2, nil)

return hash
}

// Cost calculates the weights of an AST subtree to reorder, when the parent is commutative,
// nodes to prioritize faster ones.
func (node Node) Cost() int {
3 changes: 3 additions & 0 deletions models/ast/ast_node_evaluation.go
Original file line number Diff line number Diff line change
@@ -14,6 +14,9 @@ type NodeEvaluation struct {
// Skipped indicates whether this node was evaluated at all or not. A `true` values means the
// engine determined the result of this node would not impact the overall decision's outcome.
Skipped bool
// UsedCache indicates that a node results where already computed by a previous AST node and pulled
// from the cache.
UsedCache bool

Function Function
ReturnValue any
128 changes: 90 additions & 38 deletions usecases/ast_eval/evaluate_ast.go
Original file line number Diff line number Diff line change
@@ -2,12 +2,29 @@ package ast_eval

import (
"context"
"fmt"
"sync"

"github.com/checkmarble/marble-backend/models/ast"
"github.com/checkmarble/marble-backend/pure_utils"
"golang.org/x/sync/singleflight"
)

func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node ast.Node) (ast.NodeEvaluation, bool) {
type EvaluationCache struct {
Cache sync.Map
Executor *singleflight.Group
}

func NewEvaluationCache() *EvaluationCache {
return &EvaluationCache{
Cache: sync.Map{},
Executor: new(singleflight.Group),
}
}

func EvaluateAst(ctx context.Context, cache *EvaluationCache,
environment AstEvaluationEnvironment, node ast.Node,
) (ast.NodeEvaluation, bool) {
// Early exit for constant, because it should have no children.
if node.Function == ast.FUNC_CONSTANT {
return ast.NodeEvaluation{
@@ -18,13 +35,30 @@ func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node
}, true
}

type nodeEvaluationResponse struct {
eval ast.NodeEvaluation
ok bool
}

hash := node.Hash()

if cache != nil {
if cached, ok := cache.Cache.Load(hash); ok {
response := cached.(nodeEvaluationResponse)
response.eval.Index = node.Index
response.eval.UsedCache = true

return response.eval, response.ok
}
}

childEvaluationFail := false

// Only interested in lazy callback which will have default value if an error is returned
attrs, _ := node.Function.Attributes()

evalChild := func(child ast.Node) (childEval ast.NodeEvaluation, evalNext bool) {
childEval, ok := EvaluateAst(ctx, environment, child)
childEval, ok := EvaluateAst(ctx, cache, environment, child)

if !ok {
childEvaluationFail = true
@@ -38,53 +72,71 @@ func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node
return
}

weightedNodes := NewWeightedNodes(environment, node, node.Children)
cachedExecutor := new(singleflight.Group)

// eval each child
evaluation := ast.NodeEvaluation{
Index: node.Index,
Function: node.Function,
Children: weightedNodes.Reorder(pure_utils.MapWhile(weightedNodes.Sorted(), evalChild)),
NamedChildren: pure_utils.MapValuesWhile(node.NamedChildren, evalChild),
if cache != nil {
cachedExecutor = cache.Executor
}

if childEvaluationFail {
// an error occurred in at least one of the children. Stop the evaluation.
eval, _, _ := cachedExecutor.Do(fmt.Sprintf("%d", hash), func() (any, error) {
weightedNodes := NewWeightedNodes(environment, node, node.Children)

// the frontend expects an ErrUndefinedFunction error to be present even when no evaluation happened.
if node.Function == ast.FUNC_UNDEFINED {
evaluation.Errors = append(evaluation.Errors, ast.ErrUndefinedFunction)
// eval each child
evaluation := ast.NodeEvaluation{
Index: node.Index,
Function: node.Function,
Children: weightedNodes.Reorder(pure_utils.MapWhile(weightedNodes.Sorted(), evalChild)),
NamedChildren: pure_utils.MapValuesWhile(node.NamedChildren, evalChild),
}

return evaluation, false
}
if childEvaluationFail {
// an error occurred in at least one of the children. Stop the evaluation.

getReturnValue := func(e ast.NodeEvaluation) any { return e.ReturnValue }
arguments := ast.Arguments{
Args: pure_utils.Map(evaluation.Children, getReturnValue),
NamedArgs: pure_utils.MapValues(evaluation.NamedChildren, getReturnValue),
}
// the frontend expects an ErrUndefinedFunction error to be present even when no evaluation happened.
if node.Function == ast.FUNC_UNDEFINED {
evaluation.Errors = append(evaluation.Errors, ast.ErrUndefinedFunction)
}

evaluator, err := environment.GetEvaluator(node.Function)
if err != nil {
evaluation.Errors = append(evaluation.Errors, err)
return evaluation, false
}
return nodeEvaluationResponse{evaluation, false}, nil
}

evaluation.ReturnValue, evaluation.Errors = evaluator.Evaluate(ctx, arguments)
getReturnValue := func(e ast.NodeEvaluation) any { return e.ReturnValue }
arguments := ast.Arguments{
Args: pure_utils.Map(evaluation.Children, getReturnValue),
NamedArgs: pure_utils.MapValues(evaluation.NamedChildren, getReturnValue),
}

if evaluation.Errors == nil {
// Assign an empty array to indicate that the evaluation occured.
// The evaluator is not supposed to return a nil array of errors, but let's be nice.
evaluation.Errors = []error{}
}
evaluator, err := environment.GetEvaluator(node.Function)
if err != nil {
evaluation.Errors = append(evaluation.Errors, err)
return nodeEvaluationResponse{evaluation, false}, nil
}

ok := len(evaluation.Errors) == 0
evaluation.ReturnValue, evaluation.Errors = evaluator.Evaluate(ctx, arguments)

if !ok {
// The evaluator is supposed to return nil ReturnValue when an error is present.
evaluation.ReturnValue = nil
}
if evaluation.Errors == nil {
// Assign an empty array to indicate that the evaluation occured.
// The evaluator is not supposed to return a nil array of errors, but let's be nice.
evaluation.Errors = []error{}
}

ok := len(evaluation.Errors) == 0

if !ok {
// The evaluator is supposed to return nil ReturnValue when an error is present.
evaluation.ReturnValue = nil
}

evaluationResponse := nodeEvaluationResponse{evaluation, ok}

if cache != nil {
cache.Cache.Store(hash, evaluationResponse)
}

return evaluationResponse, nil
})

evaluation := eval.(nodeEvaluationResponse)

return evaluation, ok
return evaluation.eval, evaluation.ok
}
3 changes: 2 additions & 1 deletion usecases/ast_eval/evaluate_ast_expression.go
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@ type EvaluateAstExpression struct {

func (evaluator *EvaluateAstExpression) EvaluateAstExpression(
ctx context.Context,
cache *EvaluationCache,
ruleAstExpression ast.Node,
organizationId string,
payload models.ClientObject,
@@ -27,7 +28,7 @@ func (evaluator *EvaluateAstExpression) EvaluateAstExpression(
DatabaseAccessReturnFakeValue: false,
})

evaluation, ok := EvaluateAst(ctx, environment, ruleAstExpression)
evaluation, ok := EvaluateAst(ctx, cache, environment, ruleAstExpression)
if !ok {
return evaluation, errors.Join(evaluation.FlattenErrors()...)
}
164 changes: 154 additions & 10 deletions usecases/ast_eval/evaluate_ast_test.go
Original file line number Diff line number Diff line change
@@ -2,6 +2,8 @@ package ast_eval

import (
"context"
"sync"
"sync/atomic"
"testing"

"github.com/checkmarble/marble-backend/models/ast"
@@ -14,7 +16,7 @@ import (
func TestEval(t *testing.T) {
environment := NewAstEvaluationEnvironment()
root := ast.NewAstCompareBalance()
evaluation, ok := EvaluateAst(context.TODO(), environment, root)
evaluation, ok := EvaluateAst(context.TODO(), nil, environment, root)
assert.True(t, ok)
assert.Len(t, evaluation.Errors, 0)
assert.Equal(t, true, evaluation.ReturnValue)
@@ -23,7 +25,7 @@ func TestEval(t *testing.T) {
func TestEvalUndefinedFunction(t *testing.T) {
environment := NewAstEvaluationEnvironment()
root := ast.Node{Function: ast.FUNC_UNDEFINED}
evaluation, ok := EvaluateAst(context.TODO(), environment, root)
evaluation, ok := EvaluateAst(context.TODO(), nil, environment, root)
assert.False(t, ok)
if assert.Len(t, evaluation.Errors, 1) {
assert.ErrorIs(t, evaluation.Errors[0], ast.ErrUndefinedFunction)
@@ -33,22 +35,22 @@ func TestEvalUndefinedFunction(t *testing.T) {
func TestEvalAndOrFunction(t *testing.T) {
environment := NewAstEvaluationEnvironment()

evaluation, ok := EvaluateAst(context.TODO(), environment, NewAstAndTrue())
evaluation, ok := EvaluateAst(context.TODO(), nil, environment, NewAstAndTrue())
assert.True(t, ok)
assert.Len(t, evaluation.Errors, 0)
assert.Equal(t, true, evaluation.ReturnValue)

evaluation, ok = EvaluateAst(context.TODO(), environment, NewAstAndFalse())
evaluation, ok = EvaluateAst(context.TODO(), nil, environment, NewAstAndFalse())
assert.True(t, ok)
assert.Len(t, evaluation.Errors, 0)
assert.Equal(t, false, evaluation.ReturnValue)

evaluation, ok = EvaluateAst(context.TODO(), environment, NewAstOrTrue())
evaluation, ok = EvaluateAst(context.TODO(), nil, environment, NewAstOrTrue())
assert.True(t, ok)
assert.Len(t, evaluation.Errors, 0)
assert.Equal(t, true, evaluation.ReturnValue)

evaluation, ok = EvaluateAst(context.TODO(), environment, NewAstOrFalse())
evaluation, ok = EvaluateAst(context.TODO(), nil, environment, NewAstOrFalse())
assert.True(t, ok)
assert.Len(t, evaluation.Errors, 0)
assert.Equal(t, false, evaluation.ReturnValue)
@@ -104,7 +106,7 @@ func TestLazyAnd(t *testing.T) {
AddChild(ast.Node{Constant: true})).
AddChild(ast.Node{Function: ast.FUNC_UNKNOWN})

evaluation, ok := EvaluateAst(context.TODO(), environment, root)
evaluation, ok := EvaluateAst(context.TODO(), nil, environment, root)

switch value {
case false:
@@ -127,7 +129,7 @@ func TestLazyOr(t *testing.T) {
AddChild(ast.Node{Constant: true})).
AddChild(ast.Node{Function: ast.FUNC_UNKNOWN})

evaluation, ok := EvaluateAst(context.TODO(), environment, root)
evaluation, ok := EvaluateAst(context.TODO(), nil, environment, root)

switch value {
case true:
@@ -169,7 +171,7 @@ func TestLazyBooleanNulls(t *testing.T) {
}
}

evaluation, _ := EvaluateAst(context.TODO(), environment, root)
evaluation, _ := EvaluateAst(context.TODO(), nil, environment, root)

switch {
case tt.res == nil:
@@ -202,11 +204,153 @@ func TestAggregatesOrderedLast(t *testing.T) {
AddChild(ast.Node{Function: TEST_FUNC_COSTLY}).
AddChild(ast.Node{Constant: true})

evaluation, ok := EvaluateAst(context.TODO(), environment, root)
evaluation, ok := EvaluateAst(context.TODO(), nil, environment, root)

assert.True(t, ok)
assert.Equal(t, ast.NodeEvaluation{Index: 0, Skipped: true, ReturnValue: nil}, evaluation.Children[0])
assert.Equal(t, false, evaluation.Children[1].Skipped)
assert.Equal(t, true, evaluation.Children[1].ReturnValue)
assert.Equal(t, true, evaluation.ReturnValue)
}

func TestAstNodeHash(t *testing.T) {
tts := []struct {
lhs ast.Node
rhs ast.Node
equal bool
}{
{ast.Node{Constant: true}, ast.Node{Constant: true}, true},
{ast.Node{Constant: true}, ast.Node{Constant: false}, false},
{
ast.Node{Children: []ast.Node{{Constant: true}, {Constant: false}}},
ast.Node{Children: []ast.Node{{Constant: true}, {Constant: false}}},
true,
},
{
ast.Node{Children: []ast.Node{{Constant: true}, {Constant: false}}},
ast.Node{Children: []ast.Node{{Constant: true}, {Constant: true}}},
false,
},
{
ast.Node{
NamedChildren: map[string]ast.Node{
"x": {Constant: true},
},
},
ast.Node{
NamedChildren: map[string]ast.Node{
"x": {Constant: true},
},
},
true,
},
{
ast.Node{
NamedChildren: map[string]ast.Node{
"x": {Constant: true},
},
},
ast.Node{
NamedChildren: map[string]ast.Node{
"x": {Constant: false},
},
},
false,
},
}

for _, tt := range tts {
assert.Equal(t, tt.equal, tt.lhs.Hash() == tt.rhs.Hash())
}
}

type countingNode struct {
hits atomic.Int64
}

func (n *countingNode) Evaluate(ctx context.Context, arguments ast.Arguments) (any, []error) {
n.hits.Add(1)

return evaluate.MakeEvaluateResult(true)
}

func TestCachedEvaluation(t *testing.T) {
ast.FuncAttributesMap[TEST_FUNC_COSTLY] = ast.FuncAttributes{
Cost: 1000,
}

defer delete(ast.FuncAttributesMap, TEST_FUNC_COSTLY)

node := &countingNode{}

environment := NewAstEvaluationEnvironment()
environment.AddEvaluator(TEST_FUNC_COSTLY, node)

var wg sync.WaitGroup

cache := NewEvaluationCache()

for range 10 {
wg.Add(1)

go func() {
defer wg.Done()

root := ast.Node{Function: ast.FUNC_AND}.
AddChild(ast.Node{Function: TEST_FUNC_COSTLY}).
AddChild(ast.Node{Function: TEST_FUNC_COSTLY}).
AddChild(ast.Node{Function: TEST_FUNC_COSTLY}).
AddChild(ast.Node{
Function: ast.FUNC_AND,
Children: []ast.Node{
{Function: TEST_FUNC_COSTLY},
{Function: TEST_FUNC_COSTLY},
},
}).
AddChild(ast.Node{Constant: true})

_, _ = EvaluateAst(context.TODO(), cache, environment, root)
}()
}

wg.Wait()

assert.Equal(t, int64(1), node.hits.Load())
}

func TestCachedEvaluationWithDifferentParams(t *testing.T) {
ast.FuncAttributesMap[TEST_FUNC_COSTLY] = ast.FuncAttributes{
Cost: 1000,
}

defer delete(ast.FuncAttributesMap, TEST_FUNC_COSTLY)

node := &countingNode{}

environment := NewAstEvaluationEnvironment()
environment.AddEvaluator(TEST_FUNC_COSTLY, node)

var wg sync.WaitGroup

cache := NewEvaluationCache()

for range 10 {
wg.Add(1)

go func() {
defer wg.Done()

root := ast.Node{Function: ast.FUNC_AND}.
AddChild(ast.Node{Function: TEST_FUNC_COSTLY, Children: []ast.Node{{Constant: 1}}}).
AddChild(ast.Node{Function: TEST_FUNC_COSTLY, Children: []ast.Node{{Constant: 2}}}).
AddChild(ast.Node{Function: TEST_FUNC_COSTLY, Children: []ast.Node{{Constant: 1}}}).
AddChild(ast.Node{Function: TEST_FUNC_COSTLY, Children: []ast.Node{{Constant: 2}}})

_, _ = EvaluateAst(context.TODO(), cache, environment, root)
}()
}

wg.Wait()

assert.Equal(t, int64(2), node.hits.Load())
}
13 changes: 12 additions & 1 deletion usecases/evaluate_scenario/evaluate_scenario.go
Original file line number Diff line number Diff line change
@@ -68,10 +68,13 @@ func processScenarioIteration(ctx context.Context, params ScenarioEvaluationPara
ingestedDataReadRepository: repositories.IngestedDataReadRepository,
}

cache := ast_eval.NewEvaluationCache()

// Evaluate the trigger

errEval := evalScenarioTrigger(
ctx,
cache,
repositories,
*iteration.TriggerConditionAstExpression,
dataAccessor.organizationId,
@@ -111,6 +114,7 @@ func processScenarioIteration(ctx context.Context, params ScenarioEvaluationPara
// Evaluate all rules
score, ruleExecutions, errEval := evalAllScenarioRules(
ctx,
cache,
repositories,
iteration.Rules,
dataAccessor,
@@ -287,6 +291,7 @@ func EvalScenario(

func evalScenarioRule(
ctx context.Context,
cache *ast_eval.EvaluationCache,
repositories ScenarioEvaluationRepositories,
rule models.Rule,
dataAccessor DataAccessor,
@@ -321,6 +326,7 @@ func evalScenarioRule(
// Evaluate single rule
ruleEvaluation, err := repositories.EvaluateAstExpression.EvaluateAstExpression(
ctx,
cache,
*rule.FormulaAstExpression,
dataAccessor.organizationId,
dataAccessor.ClientObject,
@@ -375,6 +381,7 @@ func evalScenarioRule(

func evalScenarioTrigger(
ctx context.Context,
cache *ast_eval.EvaluationCache,
repositories ScenarioEvaluationRepositories,
triggerAstExpression ast.Node,
organizationId string,
@@ -387,6 +394,7 @@ func evalScenarioTrigger(

triggerEvaluation, err := repositories.EvaluateAstExpression.EvaluateAstExpression(
ctx,
cache,
triggerAstExpression,
organizationId,
payload,
@@ -422,6 +430,7 @@ func evalScenarioTrigger(

func evalAllScenarioRules(
ctx context.Context,
cache *ast_eval.EvaluationCache,
repositories ScenarioEvaluationRepositories,
rules []models.Rule,
dataAccessor DataAccessor,
@@ -448,7 +457,8 @@ func evalAllScenarioRules(
}

// Eval each rule
scoreModifier, ruleExecution, err := evalScenarioRule(ctx, repositories, rule, dataAccessor, dataModel, snoozes)
scoreModifier, ruleExecution, err := evalScenarioRule(ctx, cache,
repositories, rule, dataAccessor, dataModel, snoozes)
if err != nil {
return err // First err will cancel the ctx
}
@@ -521,6 +531,7 @@ func EvalCaseName(

caseNameEvaluation, err := repositories.EvaluateAstExpression.EvaluateAstExpression(
ctx,
nil,
*scenario.DecisionToCaseNameTemplate,
params.Scenario.OrganizationId,
params.ClientObject,
6 changes: 3 additions & 3 deletions usecases/scenarios/scenario_validation.go
Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@ func (self *ValidateScenarioIterationImpl) Validate(ctx context.Context,
Code: models.TriggerConditionRequired,
})
} else {
result.Trigger.TriggerEvaluation, _ = ast_eval.EvaluateAst(ctx, dryRunEnvironment, *trigger)
result.Trigger.TriggerEvaluation, _ = ast_eval.EvaluateAst(ctx, nil, dryRunEnvironment, *trigger)
if _, ok := result.Trigger.TriggerEvaluation.ReturnValue.(bool); !ok {
result.Trigger.Errors = append(result.Trigger.Errors, models.ScenarioValidationError{
Error: errors.Wrap(models.BadParameterError,
@@ -110,7 +110,7 @@ func (self *ValidateScenarioIterationImpl) Validate(ctx context.Context,
})
result.Rules.Rules[rule.Id] = ruleValidation
} else {
ruleValidation.RuleEvaluation, _ = ast_eval.EvaluateAst(ctx, dryRunEnvironment, *formula)
ruleValidation.RuleEvaluation, _ = ast_eval.EvaluateAst(ctx, nil, dryRunEnvironment, *formula)
if _, ok := ruleValidation.RuleEvaluation.ReturnValue.(bool); !ok {
ruleValidation.Errors = append(ruleValidation.Errors, models.ScenarioValidationError{
Error: errors.Wrap(models.BadParameterError,
@@ -149,7 +149,7 @@ func (self *ValidateScenarioAstImpl) Validate(ctx context.Context,
"unknown specified type '%s'", expectedReturnTypeStr)
}

astEvaluation, _ := ast_eval.EvaluateAst(ctx, dryRunEnvironment, *astNode)
astEvaluation, _ := ast_eval.EvaluateAst(ctx, nil, dryRunEnvironment, *astNode)
astEvaluationReturnType := reflect.TypeOf(astEvaluation.ReturnValue)

if len(astEvaluation.FlattenErrors()) == 0 && astEvaluationReturnType != expectedReturnType {

0 comments on commit fc96036

Please sign in to comment.