From bd6cd2555d1bb0e57a34ce74b0add36cb7fb6c76 Mon Sep 17 00:00:00 2001 From: Ruin09 Date: Wed, 16 Aug 2023 00:56:13 +0900 Subject: [PATCH] fix: Fixed memoization is unchecked after mutex synchronization. Fixes #11219 (#11578) Signed-off-by: shmruin --- workflow/controller/operator.go | 127 +++++++++++------- .../controller/operator_concurrency_test.go | 91 +++++++++++++ 2 files changed, 172 insertions(+), 46 deletions(-) diff --git a/workflow/controller/operator.go b/workflow/controller/operator.go index 8869532f8895..1118333b4fe9 100644 --- a/workflow/controller/operator.go +++ b/workflow/controller/operator.go @@ -1795,52 +1795,6 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string, return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err } - // If memoization is on, check if node output exists in cache - if node == nil && processedTmpl.Memoize != nil { - memoizationCache := woc.controller.cacheFactory.GetCache(controllercache.ConfigMapCache, processedTmpl.Memoize.Cache.ConfigMap.Name) - if memoizationCache == nil { - err := fmt.Errorf("cache could not be found or created") - woc.log.WithFields(log.Fields{"cacheName": processedTmpl.Memoize.Cache.ConfigMap.Name}).WithError(err) - return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err - } - - entry, err := memoizationCache.Load(ctx, processedTmpl.Memoize.Key) - if err != nil { - return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err - } - - hit := entry.Hit() - var outputs *wfv1.Outputs - if processedTmpl.Memoize.MaxAge != "" { - maxAge, err := time.ParseDuration(processedTmpl.Memoize.MaxAge) - if err != nil { - err := fmt.Errorf("invalid maxAge: %s", err) - return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err - } - maxAgeOutputs, ok := entry.GetOutputsWithMaxAge(maxAge) - if !ok { - // The outputs are expired, so this cache entry is not hit - hit = false - } - outputs = maxAgeOutputs - } else { - outputs = entry.GetOutputs() - } - - memoizationStatus := &wfv1.MemoizationStatus{ - Hit: hit, - Key: processedTmpl.Memoize.Key, - CacheName: processedTmpl.Memoize.Cache.ConfigMap.Name, - } - if hit { - node = woc.initializeCacheHitNode(nodeName, processedTmpl, templateScope, orgTmpl, opts.boundaryID, outputs, memoizationStatus) - } else { - node = woc.initializeCacheNode(nodeName, processedTmpl, templateScope, orgTmpl, opts.boundaryID, memoizationStatus) - } - woc.wf.Status.Nodes[node.ID] = *node - woc.updated = true - } - if node != nil { if node.Fulfilled() { woc.controller.syncManager.Release(woc.wf, node.ID, processedTmpl.Synchronization) @@ -1888,6 +1842,8 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string, return node, err } + unlockedNode := false + if processedTmpl.Synchronization != nil { lockAcquired, wfUpdated, msg, err := woc.controller.syncManager.TryAcquire(woc.wf, woc.wf.NodeID(nodeName), processedTmpl.Synchronization) if err != nil { @@ -1909,10 +1865,71 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string, if node != nil { node = woc.markNodeWaitingForLock(node.Name, "") } + // Set this value to check that this node is using synchronization, and has acquired the lock + unlockedNode = true } woc.updated = woc.updated || wfUpdated } + + // Check memoization cache if the node is about to be created, or was created in the past but is only now allowed to run due to acquiring a lock + if processedTmpl.Memoize != nil { + if node == nil || unlockedNode { + memoizationCache := woc.controller.cacheFactory.GetCache(controllercache.ConfigMapCache, processedTmpl.Memoize.Cache.ConfigMap.Name) + if memoizationCache == nil { + err := fmt.Errorf("cache could not be found or created") + woc.log.WithFields(log.Fields{"cacheName": processedTmpl.Memoize.Cache.ConfigMap.Name}).WithError(err) + return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err + } + + entry, err := memoizationCache.Load(ctx, processedTmpl.Memoize.Key) + if err != nil { + return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err + } + + hit := entry.Hit() + var outputs *wfv1.Outputs + if processedTmpl.Memoize.MaxAge != "" { + maxAge, err := time.ParseDuration(processedTmpl.Memoize.MaxAge) + if err != nil { + err := fmt.Errorf("invalid maxAge: %s", err) + return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err + } + maxAgeOutputs, ok := entry.GetOutputsWithMaxAge(maxAge) + if !ok { + // The outputs are expired, so this cache entry is not hit + hit = false + } + outputs = maxAgeOutputs + } else { + outputs = entry.GetOutputs() + } + + memoizationStatus := &wfv1.MemoizationStatus{ + Hit: hit, + Key: processedTmpl.Memoize.Key, + CacheName: processedTmpl.Memoize.Cache.ConfigMap.Name, + } + if hit { + if node == nil { + node = woc.initializeCacheHitNode(nodeName, processedTmpl, templateScope, orgTmpl, opts.boundaryID, outputs, memoizationStatus) + } else { + woc.log.Infof("Node %s is using mutex with memoize. Cache is hit.", nodeName) + woc.updateAsCacheHitNode(node, outputs, memoizationStatus) + } + } else { + if node == nil { + node = woc.initializeCacheNode(nodeName, processedTmpl, templateScope, orgTmpl, opts.boundaryID, memoizationStatus) + } else { + woc.log.Infof("Node %s is using mutex with memoize. Cache is NOT hit", nodeName) + woc.updateAsCacheNode(node, memoizationStatus) + } + } + woc.wf.Status.Nodes[node.ID] = *node + woc.updated = true + } + } + // If the user has specified retries, node becomes a special retry node. // This node acts as a parent of all retries that will be done for // the container. The status of this node should be "Success" if any @@ -2355,6 +2372,24 @@ func (woc *wfOperationCtx) initializeNode(nodeName string, nodeType wfv1.NodeTyp return &node } +// Update a node status with cache status +func (woc *wfOperationCtx) updateAsCacheNode(node *wfv1.NodeStatus, memStat *wfv1.MemoizationStatus) { + node.MemoizationStatus = memStat + + woc.wf.Status.Nodes[node.ID] = *node + woc.updated = true +} + +// Update a node status that has been cached and marked as finished +func (woc *wfOperationCtx) updateAsCacheHitNode(node *wfv1.NodeStatus, outputs *wfv1.Outputs, memStat *wfv1.MemoizationStatus, message ...string) { + node.Phase = wfv1.NodeSucceeded + node.Outputs = outputs + node.FinishedAt = metav1.Time{Time: time.Now().UTC()} + + woc.updateAsCacheNode(node, memStat) + woc.log.Infof("%s node %v updated %s%s", node.Type, node.ID, node.Phase, message) +} + // markNodePhase marks a node with the given phase, creating the node if necessary and handles timestamps func (woc *wfOperationCtx) markNodePhase(nodeName string, phase wfv1.NodePhase, message ...string) *wfv1.NodeStatus { node := woc.wf.GetNodeByName(nodeName) diff --git a/workflow/controller/operator_concurrency_test.go b/workflow/controller/operator_concurrency_test.go index 418d88d4cf6b..a371427dc384 100644 --- a/workflow/controller/operator_concurrency_test.go +++ b/workflow/controller/operator_concurrency_test.go @@ -3,12 +3,14 @@ package controller import ( "context" "encoding/json" + "fmt" "os" "strconv" "strings" "testing" "github.com/stretchr/testify/assert" + apiv1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" @@ -957,3 +959,92 @@ func TestSynchronizationForPendingShuttingdownWfs(t *testing.T) { assert.Equal(t, wfv1.WorkflowRunning, wocTwo.execWf.Status.Phase) }) } + +func TestWorkflowMemoizationWithMutex(t *testing.T) { + wf := wfv1.MustUnmarshalWorkflow(`apiVersion: argoproj.io/v1alpha1 +kind: Workflow +metadata: + generateName: example-steps-simple + namespace: default +spec: + entrypoint: main + templates: + - name: main + steps: + - - name: job-1 + template: sleep + arguments: + parameters: + - name: sleep_duration + value: 10 + - name: job-2 + template: sleep + arguments: + parameters: + - name: sleep_duration + value: 5 + + - name: sleep + synchronization: + mutex: + name: mutex-example-steps-simple + inputs: + parameters: + - name: sleep_duration + script: + image: alpine:latest + command: [/bin/sh] + source: | + echo "Sleeping for {{ inputs.parameters.sleep_duration }}" + sleep {{ inputs.parameters.sleep_duration }} + memoize: + key: "memo-key-1" + cache: + configMap: + name: cache-example-steps-simple + `) + cancel, controller := newController(wf) + defer cancel() + + ctx := context.Background() + + woc := newWorkflowOperationCtx(wf, controller) + woc.operate(ctx) + + holdingJobs := make(map[string]string) + for _, node := range woc.wf.Status.Nodes { + holdingJobs[node.ID] = node.DisplayName + } + + // Check initial status: job-1 acquired the lock + job1AcquiredLock := false + if woc.wf.Status.Synchronization != nil && woc.wf.Status.Synchronization.Mutex != nil { + for _, holding := range woc.wf.Status.Synchronization.Mutex.Holding { + if holdingJobs[holding.Holder] == "job-1" { + fmt.Println("acquired: ", holding.Holder) + job1AcquiredLock = true + } + } + } + assert.True(t, job1AcquiredLock) + + // Make job-1's pod succeed + makePodsPhase(ctx, woc, apiv1.PodSucceeded, func(pod *apiv1.Pod) { + if pod.ObjectMeta.Name == "job-1" { + pod.Status.Phase = apiv1.PodSucceeded + } + }) + woc.operate(ctx) + + // Check final status: both job-1 and job-2 succeeded, job-2 simply hit the cache + for _, node := range woc.wf.Status.Nodes { + switch node.DisplayName { + case "job-1": + assert.Equal(t, wfv1.NodeSucceeded, node.Phase) + assert.False(t, node.MemoizationStatus.Hit) + case "job-2": + assert.Equal(t, wfv1.NodeSucceeded, node.Phase) + assert.True(t, node.MemoizationStatus.Hit) + } + } +}