Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache node evaluation to reuse results within a run #811

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ require (
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
github.com/mitchellh/go-wordwrap v1.0.1 // indirect
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/term v0.5.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6B
github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0=
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
Expand Down
9 changes: 8 additions & 1 deletion models/ast/ast_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import (
"fmt"

"github.com/cockroachdb/errors"
"github.com/mitchellh/hashstructure/v2"
)

type Node struct {
Index int
Index int `hash:"ignore"`

// A node is a constant xOR a function
Function Function
Expand Down Expand Up @@ -52,6 +53,12 @@ func (node Node) ReadConstantNamedChildString(name string) (string, error) {
return value, nil
}

// Hash returns a 64-bit structural hash of the node, computed by
// hashstructure with the FormatV2 layout and default options. Fields tagged
// `hash:"ignore"` (such as Index) do not take part in the hash.
func (node Node) Hash() uint64 {
	// The error from hashstructure is intentionally dropped so that callers get
	// a plain value; whatever sum accompanies a failure is returned unchanged.
	var sum uint64
	sum, _ = hashstructure.Hash(node, hashstructure.FormatV2, nil)

	return sum
}

// Cost calculates the weights of an AST subtree to reorder, when the parent is commutative,
// nodes to prioritize faster ones.
func (node Node) Cost() int {
Expand Down
120 changes: 116 additions & 4 deletions models/ast/ast_node_evaluation.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ast

import (
"fmt"
"time"

"github.com/cockroachdb/errors"
)
Expand All @@ -10,10 +11,8 @@ type NodeEvaluation struct {
// Index of the initial node within its level of the AST tree, used to
// reorder the results as they were. This should become obsolete when each
// node has a unique ID.
Index int
// Skipped indicates whether this node was evaluated at all or not. A `true` values means the
// engine determined the result of this node would not impact the overall decision's outcome.
Skipped bool
Index int
EvaluationPlan NodeEvaluationPlan

Function Function
ReturnValue any
Expand All @@ -23,6 +22,16 @@ type NodeEvaluation struct {
NamedChildren map[string]NodeEvaluation
}

// NodeEvaluationPlan describes how the evaluation of a single node was
// obtained: whether it was skipped, whether it came from the cache, and how
// long it took.
type NodeEvaluationPlan struct {
	// Skipped indicates whether this node was evaluated at all. A `true` value
	// means the engine determined that the result of this node would not
	// impact the overall decision's outcome.
	Skipped bool

	// Cached indicates whether this particular evaluation was pulled from the
	// cached value of a previously-executed node.
	Cached bool

	// Took is the duration recorded for this evaluation. EvaluateAst sets it to
	// the time elapsed since it started handling the node, and to zero when the
	// result is loaded from the cache.
	Took time.Duration
}

func (root NodeEvaluation) FlattenErrors() []error {
errs := make([]error, 0)

Expand Down Expand Up @@ -63,3 +72,106 @@ func (root NodeEvaluation) GetStringReturnValue() (string, error) {

return "", errors.New(fmt.Sprintf("ast expression expected to return a string, got '%T' instead", root.ReturnValue))
}

// SetCached marks this evaluation as cached and recursively does the same for
// every positional and named child beneath it.
func (root *NodeEvaluation) SetCached() {
	root.EvaluationPlan.Cached = true

	// Slice elements are addressable, so positional children are updated in place.
	for i := 0; i < len(root.Children); i++ {
		root.Children[i].SetCached()
	}

	// Map values are not addressable: each named child is updated on a copy
	// which is then written back under the same key.
	for name, namedChild := range root.NamedChildren {
		namedChild.SetCached()
		root.NamedChildren[name] = namedChild
	}
}

// EvaluationStats is a tree of per-node statistics mirroring a NodeEvaluation
// tree, as produced by BuildEvaluationStats.
type EvaluationStats struct {
	// Function is the function of the evaluated node.
	Function Function
	// Took is the duration recorded in the node's evaluation plan.
	Took time.Duration

	// Nodes is the number of descendants (positional and named children,
	// recursively) found below this node; the node itself is not counted.
	Nodes int
	// SkippedCount is the number of skipped nodes in this subtree, this node
	// included.
	SkippedCount int
	// CachedCount is the number of cached nodes in this subtree, this node
	// included.
	CachedCount int

	// Skipped reports whether this node was skipped.
	Skipped bool
	// Cached reports whether this node, or one of its ancestors, was cached.
	Cached bool

	// Children holds the stats of the positional children, in order, followed
	// by the stats of the named children.
	Children []EvaluationStats
}

func BuildEvaluationStats(root NodeEvaluation, parentCached bool) EvaluationStats {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seeing how the main source of execution delay today is linked to computing aggregates, I think this may be at the same time a bit overkill for generic nodes and lacking a detailed breakdown for aggregate nodes (and perhaps tomorrow, for sanction check/API call nodes).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(we can keep it though, it's fine)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could filter out some low-cost nodes, if necessary. The issue is that costly functions could be nested under those "simple" nodes, which is why I walk the whole tree to gather statistics.

stats := EvaluationStats{
Function: root.Function,
Took: root.EvaluationPlan.Took,
Nodes: len(root.Children) + len(root.NamedChildren),
Children: make([]EvaluationStats, len(root.Children),
len(root.Children)+len(root.NamedChildren)),
}

if root.EvaluationPlan.Skipped {
stats.Skipped = true
stats.SkippedCount = 1
}
if parentCached || root.EvaluationPlan.Cached {
stats.Cached = true
stats.CachedCount = 1
}

for idx, child := range root.Children {
stats.Children[idx] = BuildEvaluationStats(child, stats.Cached)

stats.Nodes += stats.Children[idx].Nodes
stats.SkippedCount += stats.Children[idx].SkippedCount
stats.CachedCount += stats.Children[idx].CachedCount
}
for _, child := range root.NamedChildren {
namedChildrenStats := BuildEvaluationStats(child, stats.Cached)

stats.Nodes += namedChildrenStats.Nodes
stats.SkippedCount += namedChildrenStats.SkippedCount
stats.CachedCount += namedChildrenStats.CachedCount

stats.Children = append(stats.Children, namedChildrenStats)
}

return stats
}

// FunctionStats aggregates evaluation statistics for all the nodes sharing the
// same function.
type FunctionStats struct {
	// Count is the number of nodes seen for the function.
	Count int `json:"count"`
	// Cached is how many of those nodes were flagged as cached.
	Cached int `json:"cached"`
	// Skipped is how many of those nodes were flagged as skipped.
	Skipped int `json:"skipped"`
	// Took is the sum of the durations recorded for those nodes.
	Took time.Duration `json:"took"`
}

// FunctionStats walks the stats tree rooted at stats and returns the
// aggregated statistics of every node, keyed by the debug string of the
// node's function.
func (stats EvaluationStats) FunctionStats() map[string]FunctionStats {
	perFunction := map[string]FunctionStats{}
	buildFunctionStats(perFunction, stats)

	return perFunction
}

// buildFunctionStats adds the statistics of stats, then recursively of all its
// children, to acc under the debug string of each node's function.
func buildFunctionStats(acc map[string]FunctionStats, stats EvaluationStats) {
	name := stats.Function.DebugString()

	// Reading a missing key yields the zero FunctionStats, which is the correct
	// starting point for a function seen for the first time.
	entry := acc[name]
	entry.Count++
	entry.Took += stats.Took
	if stats.Skipped {
		entry.Skipped++
	}
	if stats.Cached {
		entry.Cached++
	}
	acc[name] = entry

	for i := range stats.Children {
		buildFunctionStats(acc, stats.Children[i])
	}
}
2 changes: 1 addition & 1 deletion models/ast/node_evaluation_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func AdaptNodeEvaluationDto(evaluation NodeEvaluation) NodeEvaluationDto {
Errors: pure_utils.Map(evaluation.Errors, AdaptEvaluationErrorDto),
Children: pure_utils.Map(evaluation.Children, AdaptNodeEvaluationDto),
NamedChildren: pure_utils.MapValues(evaluation.NamedChildren, AdaptNodeEvaluationDto),
Skipped: evaluation.Skipped,
Skipped: evaluation.EvaluationPlan.Skipped,
}
}

Expand Down
144 changes: 107 additions & 37 deletions usecases/ast_eval/evaluate_ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,32 @@ package ast_eval

import (
"context"
"fmt"
"sync"
"time"

"github.com/checkmarble/marble-backend/models/ast"
"github.com/checkmarble/marble-backend/pure_utils"
"golang.org/x/sync/singleflight"
)

func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node ast.Node) (ast.NodeEvaluation, bool) {
// EvaluationCache carries the state used by EvaluateAst to reuse node
// evaluations.
type EvaluationCache struct {
	// Cache stores evaluation results keyed by node hash.
	Cache sync.Map
	// Executor is the singleflight group through which EvaluateAst runs node
	// evaluations, keyed by node hash.
	Executor *singleflight.Group
}

// NewEvaluationCache returns an empty EvaluationCache with a fresh
// singleflight group.
func NewEvaluationCache() *EvaluationCache {
	// The zero value of sync.Map is an empty map ready for use, so only the
	// executor needs explicit initialisation.
	cache := new(EvaluationCache)
	cache.Executor = &singleflight.Group{}

	return cache
}

func EvaluateAst(ctx context.Context, cache *EvaluationCache,
environment AstEvaluationEnvironment, node ast.Node,
) (ast.NodeEvaluation, bool) {
start := time.Now()

// Early exit for constant, because it should have no children.
if node.Function == ast.FUNC_CONSTANT {
return ast.NodeEvaluation{
Expand All @@ -18,13 +38,34 @@ func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node
}, true
}

type nodeEvaluationResponse struct {
eval ast.NodeEvaluation
ok bool
}

hash := node.Hash()

if cache != nil {
if cached, ok := cache.Cache.Load(hash); ok {
response := cached.(nodeEvaluationResponse)
response.eval.Index = node.Index

response.eval.EvaluationPlan = ast.NodeEvaluationPlan{
Took: 0,
Cached: true,
}

return response.eval, response.ok
}
}

childEvaluationFail := false

// Only interested in lazy callback which will have default value if an error is returned
attrs, _ := node.Function.Attributes()

evalChild := func(child ast.Node) (childEval ast.NodeEvaluation, evalNext bool) {
childEval, ok := EvaluateAst(ctx, environment, child)
childEval, ok := EvaluateAst(ctx, cache, environment, child)

if !ok {
childEvaluationFail = true
Expand All @@ -38,53 +79,82 @@ func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node
return
}

weightedNodes := NewWeightedNodes(environment, node, node.Children)
cachedExecutor := new(singleflight.Group)
notCached := false

// eval each child
evaluation := ast.NodeEvaluation{
Index: node.Index,
Function: node.Function,
Children: weightedNodes.Reorder(pure_utils.MapWhile(weightedNodes.Sorted(), evalChild)),
NamedChildren: pure_utils.MapValuesWhile(node.NamedChildren, evalChild),
if cache != nil {
cachedExecutor = cache.Executor
}

if childEvaluationFail {
// an error occurred in at least one of the children. Stop the evaluation.
eval, _, _ := cachedExecutor.Do(fmt.Sprintf("%d", hash), func() (any, error) {
notCached = true
weightedNodes := NewWeightedNodes(environment, node, node.Children)

// the frontend expects an ErrUndefinedFunction error to be present even when no evaluation happened.
if node.Function == ast.FUNC_UNDEFINED {
evaluation.Errors = append(evaluation.Errors, ast.ErrUndefinedFunction)
// eval each child
evaluation := ast.NodeEvaluation{
Index: node.Index,
Function: node.Function,
Children: weightedNodes.Reorder(pure_utils.MapWhile(weightedNodes.Sorted(), evalChild)),
NamedChildren: pure_utils.MapValuesWhile(node.NamedChildren, evalChild),
}

return evaluation, false
}
if childEvaluationFail {
// an error occurred in at least one of the children. Stop the evaluation.

getReturnValue := func(e ast.NodeEvaluation) any { return e.ReturnValue }
arguments := ast.Arguments{
Args: pure_utils.Map(evaluation.Children, getReturnValue),
NamedArgs: pure_utils.MapValues(evaluation.NamedChildren, getReturnValue),
}
// the frontend expects an ErrUndefinedFunction error to be present even when no evaluation happened.
if node.Function == ast.FUNC_UNDEFINED {
evaluation.Errors = append(evaluation.Errors, ast.ErrUndefinedFunction)
}

evaluator, err := environment.GetEvaluator(node.Function)
if err != nil {
evaluation.Errors = append(evaluation.Errors, err)
return evaluation, false
}
return nodeEvaluationResponse{evaluation, false}, nil
}

evaluation.ReturnValue, evaluation.Errors = evaluator.Evaluate(ctx, arguments)
getReturnValue := func(e ast.NodeEvaluation) any { return e.ReturnValue }
arguments := ast.Arguments{
Args: pure_utils.Map(evaluation.Children, getReturnValue),
NamedArgs: pure_utils.MapValues(evaluation.NamedChildren, getReturnValue),
}

if evaluation.Errors == nil {
// Assign an empty array to indicate that the evaluation occurred.
// The evaluator is not supposed to return a nil array of errors, but let's be nice.
evaluation.Errors = []error{}
}
evaluator, err := environment.GetEvaluator(node.Function)
if err != nil {
evaluation.Errors = append(evaluation.Errors, err)
return nodeEvaluationResponse{evaluation, false}, nil
}

evaluation.ReturnValue, evaluation.Errors = evaluator.Evaluate(ctx, arguments)

ok := len(evaluation.Errors) == 0
if evaluation.Errors == nil {
// Assign an empty array to indicate that the evaluation occurred.
// The evaluator is not supposed to return a nil array of errors, but let's be nice.
evaluation.Errors = []error{}
}

ok := len(evaluation.Errors) == 0

if !ok {
// The evaluator is supposed to return nil ReturnValue when an error is present.
evaluation.ReturnValue = nil
}

evaluationResponse := nodeEvaluationResponse{evaluation, ok}

if cache != nil {
cache.Cache.Store(hash, evaluationResponse)
}

return evaluationResponse, nil
})

evaluation := eval.(nodeEvaluationResponse)
evaluation.eval.Index = node.Index

evaluation.eval.EvaluationPlan = ast.NodeEvaluationPlan{
Took: time.Since(start),
}

if !ok {
// The evaluator is supposed to return nil ReturnValue when an error is present.
evaluation.ReturnValue = nil
if !notCached {
evaluation.eval.SetCached()
}

return evaluation, ok
return evaluation.eval, evaluation.ok
}
Loading