diff --git a/go.mod b/go.mod index 7218c3dfc..c8446da84 100644 --- a/go.mod +++ b/go.mod @@ -158,6 +158,7 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mfridman/interpolate v0.0.2 // indirect github.com/mitchellh/go-wordwrap v1.0.1 // indirect + github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/term v0.5.0 // indirect diff --git a/go.sum b/go.sum index 3c7728db5..bb3bdfe4d 100644 --- a/go.sum +++ b/go.sum @@ -358,6 +358,8 @@ github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6B github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= +github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= +github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= diff --git a/models/ast/ast_node.go b/models/ast/ast_node.go index fbcfdeaa4..12749f35d 100644 --- a/models/ast/ast_node.go +++ b/models/ast/ast_node.go @@ -4,10 +4,12 @@ import ( "fmt" "github.com/cockroachdb/errors" + "github.com/mitchellh/hashstructure/v2" ) type Node struct { - Index int + Index int `hash:"ignore"` + UsedCache bool `hash:"ignore"` // A node is a constant xOR a function Function Function @@ -52,6 +54,12 @@ func (node Node) ReadConstantNamedChildString(name string) (string, error) { return value, nil } +func (node Node) Hash() uint64 { + hash, _ := hashstructure.Hash(node, hashstructure.FormatV2, nil) + + return hash +} + // Cost calculates the weights of an AST subtree to reorder, when the parent is commutative, // nodes to prioritize faster ones. func (node Node) Cost() int { diff --git a/models/ast/ast_node_evaluation.go b/models/ast/ast_node_evaluation.go index 93a7eadaf..582acdebf 100644 --- a/models/ast/ast_node_evaluation.go +++ b/models/ast/ast_node_evaluation.go @@ -14,6 +14,9 @@ type NodeEvaluation struct { // Skipped indicates whether this node was evaluated at all or not. A `true` values means the // engine determined the result of this node would not impact the overall decision's outcome. Skipped bool + // UsedCache indicates that a node results where already computed by a previous AST node and pulled + // from the cache. + UsedCache bool Function Function ReturnValue any diff --git a/usecases/ast_eval/evaluate_ast.go b/usecases/ast_eval/evaluate_ast.go index 5f7f2c9d1..be7d285a4 100644 --- a/usecases/ast_eval/evaluate_ast.go +++ b/usecases/ast_eval/evaluate_ast.go @@ -2,12 +2,29 @@ package ast_eval import ( "context" + "fmt" + "sync" "github.com/checkmarble/marble-backend/models/ast" "github.com/checkmarble/marble-backend/pure_utils" + "golang.org/x/sync/singleflight" ) -func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node ast.Node) (ast.NodeEvaluation, bool) { +type EvaluationCache struct { + Cache sync.Map + Executor *singleflight.Group +} + +func NewEvaluationCache() *EvaluationCache { + return &EvaluationCache{ + Cache: sync.Map{}, + Executor: new(singleflight.Group), + } +} + +func EvaluateAst(ctx context.Context, cache *EvaluationCache, + environment AstEvaluationEnvironment, node ast.Node, +) (ast.NodeEvaluation, bool) { // Early exit for constant, because it should have no children. if node.Function == ast.FUNC_CONSTANT { return ast.NodeEvaluation{ @@ -18,13 +35,30 @@ func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node }, true } + type nodeEvaluationResponse struct { + eval ast.NodeEvaluation + ok bool + } + + hash := node.Hash() + + if cache != nil { + if cached, ok := cache.Cache.Load(hash); ok { + response := cached.(nodeEvaluationResponse) + response.eval.Index = node.Index + response.eval.UsedCache = true + + return response.eval, response.ok + } + } + childEvaluationFail := false // Only interested in lazy callback which will have default value if an error is returned attrs, _ := node.Function.Attributes() evalChild := func(child ast.Node) (childEval ast.NodeEvaluation, evalNext bool) { - childEval, ok := EvaluateAst(ctx, environment, child) + childEval, ok := EvaluateAst(ctx, cache, environment, child) if !ok { childEvaluationFail = true @@ -38,53 +72,71 @@ func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node return } - weightedNodes := NewWeightedNodes(environment, node, node.Children) + cachedExecutor := new(singleflight.Group) - // eval each child - evaluation := ast.NodeEvaluation{ - Index: node.Index, - Function: node.Function, - Children: weightedNodes.Reorder(pure_utils.MapWhile(weightedNodes.Sorted(), evalChild)), - NamedChildren: pure_utils.MapValuesWhile(node.NamedChildren, evalChild), + if cache != nil { + cachedExecutor = cache.Executor } - if childEvaluationFail { - // an error occurred in at least one of the children. Stop the evaluation. + eval, _, _ := cachedExecutor.Do(fmt.Sprintf("%d", hash), func() (any, error) { + weightedNodes := NewWeightedNodes(environment, node, node.Children) - // the frontend expects an ErrUndefinedFunction error to be present even when no evaluation happened. - if node.Function == ast.FUNC_UNDEFINED { - evaluation.Errors = append(evaluation.Errors, ast.ErrUndefinedFunction) + // eval each child + evaluation := ast.NodeEvaluation{ + Index: node.Index, + Function: node.Function, + Children: weightedNodes.Reorder(pure_utils.MapWhile(weightedNodes.Sorted(), evalChild)), + NamedChildren: pure_utils.MapValuesWhile(node.NamedChildren, evalChild), } - return evaluation, false - } + if childEvaluationFail { + // an error occurred in at least one of the children. Stop the evaluation. - getReturnValue := func(e ast.NodeEvaluation) any { return e.ReturnValue } - arguments := ast.Arguments{ - Args: pure_utils.Map(evaluation.Children, getReturnValue), - NamedArgs: pure_utils.MapValues(evaluation.NamedChildren, getReturnValue), - } + // the frontend expects an ErrUndefinedFunction error to be present even when no evaluation happened. + if node.Function == ast.FUNC_UNDEFINED { + evaluation.Errors = append(evaluation.Errors, ast.ErrUndefinedFunction) + } - evaluator, err := environment.GetEvaluator(node.Function) - if err != nil { - evaluation.Errors = append(evaluation.Errors, err) - return evaluation, false - } + return nodeEvaluationResponse{evaluation, false}, nil + } - evaluation.ReturnValue, evaluation.Errors = evaluator.Evaluate(ctx, arguments) + getReturnValue := func(e ast.NodeEvaluation) any { return e.ReturnValue } + arguments := ast.Arguments{ + Args: pure_utils.Map(evaluation.Children, getReturnValue), + NamedArgs: pure_utils.MapValues(evaluation.NamedChildren, getReturnValue), + } - if evaluation.Errors == nil { - // Assign an empty array to indicate that the evaluation occured. - // The evaluator is not supposed to return a nil array of errors, but let's be nice. - evaluation.Errors = []error{} - } + evaluator, err := environment.GetEvaluator(node.Function) + if err != nil { + evaluation.Errors = append(evaluation.Errors, err) + return nodeEvaluationResponse{evaluation, false}, nil + } - ok := len(evaluation.Errors) == 0 + evaluation.ReturnValue, evaluation.Errors = evaluator.Evaluate(ctx, arguments) - if !ok { - // The evaluator is supposed to return nil ReturnValue when an error is present. - evaluation.ReturnValue = nil - } + if evaluation.Errors == nil { + // Assign an empty array to indicate that the evaluation occured. + // The evaluator is not supposed to return a nil array of errors, but let's be nice. + evaluation.Errors = []error{} + } + + ok := len(evaluation.Errors) == 0 + + if !ok { + // The evaluator is supposed to return nil ReturnValue when an error is present. + evaluation.ReturnValue = nil + } + + evaluationResponse := nodeEvaluationResponse{evaluation, ok} + + if cache != nil { + cache.Cache.Store(hash, evaluationResponse) + } + + return evaluationResponse, nil + }) + + evaluation := eval.(nodeEvaluationResponse) - return evaluation, ok + return evaluation.eval, evaluation.ok } diff --git a/usecases/ast_eval/evaluate_ast_expression.go b/usecases/ast_eval/evaluate_ast_expression.go index a2c176805..9c5d28227 100644 --- a/usecases/ast_eval/evaluate_ast_expression.go +++ b/usecases/ast_eval/evaluate_ast_expression.go @@ -15,6 +15,7 @@ type EvaluateAstExpression struct { func (evaluator *EvaluateAstExpression) EvaluateAstExpression( ctx context.Context, + cache *EvaluationCache, ruleAstExpression ast.Node, organizationId string, payload models.ClientObject, @@ -27,7 +28,7 @@ func (evaluator *EvaluateAstExpression) EvaluateAstExpression( DatabaseAccessReturnFakeValue: false, }) - evaluation, ok := EvaluateAst(ctx, environment, ruleAstExpression) + evaluation, ok := EvaluateAst(ctx, cache, environment, ruleAstExpression) if !ok { return evaluation, errors.Join(evaluation.FlattenErrors()...) } diff --git a/usecases/ast_eval/evaluate_ast_test.go b/usecases/ast_eval/evaluate_ast_test.go index 57fd3d54f..2c791af1c 100644 --- a/usecases/ast_eval/evaluate_ast_test.go +++ b/usecases/ast_eval/evaluate_ast_test.go @@ -2,6 +2,8 @@ package ast_eval import ( "context" + "sync" + "sync/atomic" "testing" "github.com/checkmarble/marble-backend/models/ast" @@ -14,7 +16,7 @@ import ( func TestEval(t *testing.T) { environment := NewAstEvaluationEnvironment() root := ast.NewAstCompareBalance() - evaluation, ok := EvaluateAst(context.TODO(), environment, root) + evaluation, ok := EvaluateAst(context.TODO(), nil, environment, root) assert.True(t, ok) assert.Len(t, evaluation.Errors, 0) assert.Equal(t, true, evaluation.ReturnValue) @@ -23,7 +25,7 @@ func TestEval(t *testing.T) { func TestEvalUndefinedFunction(t *testing.T) { environment := NewAstEvaluationEnvironment() root := ast.Node{Function: ast.FUNC_UNDEFINED} - evaluation, ok := EvaluateAst(context.TODO(), environment, root) + evaluation, ok := EvaluateAst(context.TODO(), nil, environment, root) assert.False(t, ok) if assert.Len(t, evaluation.Errors, 1) { assert.ErrorIs(t, evaluation.Errors[0], ast.ErrUndefinedFunction) @@ -33,22 +35,22 @@ func TestEvalUndefinedFunction(t *testing.T) { func TestEvalAndOrFunction(t *testing.T) { environment := NewAstEvaluationEnvironment() - evaluation, ok := EvaluateAst(context.TODO(), environment, NewAstAndTrue()) + evaluation, ok := EvaluateAst(context.TODO(), nil, environment, NewAstAndTrue()) assert.True(t, ok) assert.Len(t, evaluation.Errors, 0) assert.Equal(t, true, evaluation.ReturnValue) - evaluation, ok = EvaluateAst(context.TODO(), environment, NewAstAndFalse()) + evaluation, ok = EvaluateAst(context.TODO(), nil, environment, NewAstAndFalse()) assert.True(t, ok) assert.Len(t, evaluation.Errors, 0) assert.Equal(t, false, evaluation.ReturnValue) - evaluation, ok = EvaluateAst(context.TODO(), environment, NewAstOrTrue()) + evaluation, ok = EvaluateAst(context.TODO(), nil, environment, NewAstOrTrue()) assert.True(t, ok) assert.Len(t, evaluation.Errors, 0) assert.Equal(t, true, evaluation.ReturnValue) - evaluation, ok = EvaluateAst(context.TODO(), environment, NewAstOrFalse()) + evaluation, ok = EvaluateAst(context.TODO(), nil, environment, NewAstOrFalse()) assert.True(t, ok) assert.Len(t, evaluation.Errors, 0) assert.Equal(t, false, evaluation.ReturnValue) @@ -104,7 +106,7 @@ func TestLazyAnd(t *testing.T) { AddChild(ast.Node{Constant: true})). AddChild(ast.Node{Function: ast.FUNC_UNKNOWN}) - evaluation, ok := EvaluateAst(context.TODO(), environment, root) + evaluation, ok := EvaluateAst(context.TODO(), nil, environment, root) switch value { case false: @@ -127,7 +129,7 @@ func TestLazyOr(t *testing.T) { AddChild(ast.Node{Constant: true})). AddChild(ast.Node{Function: ast.FUNC_UNKNOWN}) - evaluation, ok := EvaluateAst(context.TODO(), environment, root) + evaluation, ok := EvaluateAst(context.TODO(), nil, environment, root) switch value { case true: @@ -169,7 +171,7 @@ func TestLazyBooleanNulls(t *testing.T) { } } - evaluation, _ := EvaluateAst(context.TODO(), environment, root) + evaluation, _ := EvaluateAst(context.TODO(), nil, environment, root) switch { case tt.res == nil: @@ -202,7 +204,7 @@ func TestAggregatesOrderedLast(t *testing.T) { AddChild(ast.Node{Function: TEST_FUNC_COSTLY}). AddChild(ast.Node{Constant: true}) - evaluation, ok := EvaluateAst(context.TODO(), environment, root) + evaluation, ok := EvaluateAst(context.TODO(), nil, environment, root) assert.True(t, ok) assert.Equal(t, ast.NodeEvaluation{Index: 0, Skipped: true, ReturnValue: nil}, evaluation.Children[0]) @@ -210,3 +212,145 @@ func TestAggregatesOrderedLast(t *testing.T) { assert.Equal(t, true, evaluation.Children[1].ReturnValue) assert.Equal(t, true, evaluation.ReturnValue) } + +func TestAstNodeHash(t *testing.T) { + tts := []struct { + lhs ast.Node + rhs ast.Node + equal bool + }{ + {ast.Node{Constant: true}, ast.Node{Constant: true}, true}, + {ast.Node{Constant: true}, ast.Node{Constant: false}, false}, + { + ast.Node{Children: []ast.Node{{Constant: true}, {Constant: false}}}, + ast.Node{Children: []ast.Node{{Constant: true}, {Constant: false}}}, + true, + }, + { + ast.Node{Children: []ast.Node{{Constant: true}, {Constant: false}}}, + ast.Node{Children: []ast.Node{{Constant: true}, {Constant: true}}}, + false, + }, + { + ast.Node{ + NamedChildren: map[string]ast.Node{ + "x": {Constant: true}, + }, + }, + ast.Node{ + NamedChildren: map[string]ast.Node{ + "x": {Constant: true}, + }, + }, + true, + }, + { + ast.Node{ + NamedChildren: map[string]ast.Node{ + "x": {Constant: true}, + }, + }, + ast.Node{ + NamedChildren: map[string]ast.Node{ + "x": {Constant: false}, + }, + }, + false, + }, + } + + for _, tt := range tts { + assert.Equal(t, tt.equal, tt.lhs.Hash() == tt.rhs.Hash()) + } +} + +type countingNode struct { + hits atomic.Int64 +} + +func (n *countingNode) Evaluate(ctx context.Context, arguments ast.Arguments) (any, []error) { + n.hits.Add(1) + + return evaluate.MakeEvaluateResult(true) +} + +func TestCachedEvaluation(t *testing.T) { + ast.FuncAttributesMap[TEST_FUNC_COSTLY] = ast.FuncAttributes{ + Cost: 1000, + } + + defer delete(ast.FuncAttributesMap, TEST_FUNC_COSTLY) + + node := &countingNode{} + + environment := NewAstEvaluationEnvironment() + environment.AddEvaluator(TEST_FUNC_COSTLY, node) + + var wg sync.WaitGroup + + cache := NewEvaluationCache() + + for range 10 { + wg.Add(1) + + go func() { + defer wg.Done() + + root := ast.Node{Function: ast.FUNC_AND}. + AddChild(ast.Node{Function: TEST_FUNC_COSTLY}). + AddChild(ast.Node{Function: TEST_FUNC_COSTLY}). + AddChild(ast.Node{Function: TEST_FUNC_COSTLY}). + AddChild(ast.Node{ + Function: ast.FUNC_AND, + Children: []ast.Node{ + {Function: TEST_FUNC_COSTLY}, + {Function: TEST_FUNC_COSTLY}, + }, + }). + AddChild(ast.Node{Constant: true}) + + _, _ = EvaluateAst(context.TODO(), cache, environment, root) + }() + } + + wg.Wait() + + assert.Equal(t, int64(1), node.hits.Load()) +} + +func TestCachedEvaluationWithDifferentParams(t *testing.T) { + ast.FuncAttributesMap[TEST_FUNC_COSTLY] = ast.FuncAttributes{ + Cost: 1000, + } + + defer delete(ast.FuncAttributesMap, TEST_FUNC_COSTLY) + + node := &countingNode{} + + environment := NewAstEvaluationEnvironment() + environment.AddEvaluator(TEST_FUNC_COSTLY, node) + + var wg sync.WaitGroup + + cache := NewEvaluationCache() + + for range 10 { + wg.Add(1) + + go func() { + defer wg.Done() + + root := ast.Node{Function: ast.FUNC_AND}. + AddChild(ast.Node{Function: TEST_FUNC_COSTLY, Children: []ast.Node{{Constant: 1}}}). + AddChild(ast.Node{Function: TEST_FUNC_COSTLY, Children: []ast.Node{{Constant: 2}}}). + AddChild(ast.Node{Function: TEST_FUNC_COSTLY, Children: []ast.Node{{Constant: 1}}}). + AddChild(ast.Node{Function: TEST_FUNC_COSTLY, Children: []ast.Node{{Constant: 2}}}) + + _, _ = EvaluateAst(context.TODO(), cache, environment, root) + }() + } + + wg.Wait() + + assert.Equal(t, int64(2), node.hits.Load()) +} diff --git a/usecases/evaluate_scenario/evaluate_scenario.go b/usecases/evaluate_scenario/evaluate_scenario.go index 4a671040d..c46f66b5c 100644 --- a/usecases/evaluate_scenario/evaluate_scenario.go +++ b/usecases/evaluate_scenario/evaluate_scenario.go @@ -68,10 +68,13 @@ func processScenarioIteration(ctx context.Context, params ScenarioEvaluationPara ingestedDataReadRepository: repositories.IngestedDataReadRepository, } + cache := ast_eval.NewEvaluationCache() + // Evaluate the trigger errEval := evalScenarioTrigger( ctx, + cache, repositories, *iteration.TriggerConditionAstExpression, dataAccessor.organizationId, @@ -111,6 +114,7 @@ func processScenarioIteration(ctx context.Context, params ScenarioEvaluationPara // Evaluate all rules score, ruleExecutions, errEval := evalAllScenarioRules( ctx, + cache, repositories, iteration.Rules, dataAccessor, @@ -287,6 +291,7 @@ func EvalScenario( func evalScenarioRule( ctx context.Context, + cache *ast_eval.EvaluationCache, repositories ScenarioEvaluationRepositories, rule models.Rule, dataAccessor DataAccessor, @@ -321,6 +326,7 @@ func evalScenarioRule( // Evaluate single rule ruleEvaluation, err := repositories.EvaluateAstExpression.EvaluateAstExpression( ctx, + cache, *rule.FormulaAstExpression, dataAccessor.organizationId, dataAccessor.ClientObject, @@ -375,6 +381,7 @@ func evalScenarioRule( func evalScenarioTrigger( ctx context.Context, + cache *ast_eval.EvaluationCache, repositories ScenarioEvaluationRepositories, triggerAstExpression ast.Node, organizationId string, @@ -387,6 +394,7 @@ func evalScenarioTrigger( triggerEvaluation, err := repositories.EvaluateAstExpression.EvaluateAstExpression( ctx, + cache, triggerAstExpression, organizationId, payload, @@ -422,6 +430,7 @@ func evalScenarioTrigger( func evalAllScenarioRules( ctx context.Context, + cache *ast_eval.EvaluationCache, repositories ScenarioEvaluationRepositories, rules []models.Rule, dataAccessor DataAccessor, @@ -448,7 +457,8 @@ func evalAllScenarioRules( } // Eval each rule - scoreModifier, ruleExecution, err := evalScenarioRule(ctx, repositories, rule, dataAccessor, dataModel, snoozes) + scoreModifier, ruleExecution, err := evalScenarioRule(ctx, cache, + repositories, rule, dataAccessor, dataModel, snoozes) if err != nil { return err // First err will cancel the ctx } @@ -521,6 +531,7 @@ func EvalCaseName( caseNameEvaluation, err := repositories.EvaluateAstExpression.EvaluateAstExpression( ctx, + nil, *scenario.DecisionToCaseNameTemplate, params.Scenario.OrganizationId, params.ClientObject, diff --git a/usecases/scenarios/scenario_validation.go b/usecases/scenarios/scenario_validation.go index 5bf0fc937..b78e7e43f 100644 --- a/usecases/scenarios/scenario_validation.go +++ b/usecases/scenarios/scenario_validation.go @@ -89,7 +89,7 @@ func (self *ValidateScenarioIterationImpl) Validate(ctx context.Context, Code: models.TriggerConditionRequired, }) } else { - result.Trigger.TriggerEvaluation, _ = ast_eval.EvaluateAst(ctx, dryRunEnvironment, *trigger) + result.Trigger.TriggerEvaluation, _ = ast_eval.EvaluateAst(ctx, nil, dryRunEnvironment, *trigger) if _, ok := result.Trigger.TriggerEvaluation.ReturnValue.(bool); !ok { result.Trigger.Errors = append(result.Trigger.Errors, models.ScenarioValidationError{ Error: errors.Wrap(models.BadParameterError, @@ -110,7 +110,7 @@ func (self *ValidateScenarioIterationImpl) Validate(ctx context.Context, }) result.Rules.Rules[rule.Id] = ruleValidation } else { - ruleValidation.RuleEvaluation, _ = ast_eval.EvaluateAst(ctx, dryRunEnvironment, *formula) + ruleValidation.RuleEvaluation, _ = ast_eval.EvaluateAst(ctx, nil, dryRunEnvironment, *formula) if _, ok := ruleValidation.RuleEvaluation.ReturnValue.(bool); !ok { ruleValidation.Errors = append(ruleValidation.Errors, models.ScenarioValidationError{ Error: errors.Wrap(models.BadParameterError, @@ -149,7 +149,7 @@ func (self *ValidateScenarioAstImpl) Validate(ctx context.Context, "unknown specified type '%s'", expectedReturnTypeStr) } - astEvaluation, _ := ast_eval.EvaluateAst(ctx, dryRunEnvironment, *astNode) + astEvaluation, _ := ast_eval.EvaluateAst(ctx, nil, dryRunEnvironment, *astNode) astEvaluationReturnType := reflect.TypeOf(astEvaluation.ReturnValue) if len(astEvaluation.FlattenErrors()) == 0 && astEvaluationReturnType != expectedReturnType {