diff --git a/service/worker/migration/activities.go b/service/worker/migration/activities.go index c2c6db10c19..5a922cfb993 100644 --- a/service/worker/migration/activities.go +++ b/service/worker/migration/activities.go @@ -28,6 +28,7 @@ import ( "context" "fmt" "math" + "slices" "sort" "time" @@ -38,12 +39,10 @@ import ( "go.temporal.io/api/workflowservice/v1" "go.temporal.io/sdk/activity" "go.temporal.io/sdk/temporal" - "go.temporal.io/server/api/adminservice/v1" enumsspb "go.temporal.io/server/api/enums/v1" "go.temporal.io/server/api/historyservice/v1" replicationspb "go.temporal.io/server/api/replication/v1" serverClient "go.temporal.io/server/client" - "go.temporal.io/server/client/admin" "go.temporal.io/server/common/definition" "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" @@ -52,26 +51,11 @@ import ( "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/quotas" - "go.temporal.io/server/common/searchattribute" - "google.golang.org/protobuf/types/known/timestamppb" + "go.temporal.io/server/common/rpc/interceptor" + "google.golang.org/grpc/metadata" ) type ( - activities struct { - historyShardCount int32 - executionManager persistence.ExecutionManager - taskManager persistence.TaskManager - namespaceRegistry namespace.Registry - historyClient historyservice.HistoryServiceClient - frontendClient workflowservice.WorkflowServiceClient - clientFactory serverClient.Factory - clientBean serverClient.Bean - logger log.Logger - metricsHandler metrics.Handler - forceReplicationMetricsHandler metrics.Handler - namespaceReplicationQueue persistence.NamespaceReplicationQueue - } - SkippedWorkflowExecution struct { WorkflowExecution *commonpb.WorkflowExecution Reason string @@ -88,6 +72,70 @@ type ( status verifyStatus reason string } + + listWorkflowsResponse struct { + Executions []*commonpb.WorkflowExecution + NextPageToken []byte + Error error + + // These can be used to help report progress of the force-replication scan + LastCloseTime time.Time + LastStartTime time.Time + } + + countWorkflowResponse struct { + WorkflowCount int64 + } + + generateReplicationTasksRequest struct { + NamespaceID string + Executions []*commonpb.WorkflowExecution + RPS float64 + GetParentInfoRPS float64 + } + + verifyReplicationTasksRequest struct { + Namespace string + NamespaceID string + TargetClusterEndpoint string + TargetClusterName string + VerifyInterval time.Duration `validate:"gte=0"` + Executions []*commonpb.WorkflowExecution + } + + verifyReplicationTasksResponse struct { + VerifiedWorkflowCount int64 + } + + metadataRequest struct { + Namespace string + } + + metadataResponse struct { + ShardCount int32 + NamespaceID string + } + + waitCatchupRequest struct { + ActiveCluster string + RemoteCluster string + Namespace string + } + + activities struct { + historyShardCount int32 + executionManager persistence.ExecutionManager + taskManager persistence.TaskManager + namespaceRegistry namespace.Registry + historyClient historyservice.HistoryServiceClient + frontendClient workflowservice.WorkflowServiceClient + clientFactory serverClient.Factory + clientBean serverClient.Bean + logger log.Logger + metricsHandler metrics.Handler + forceReplicationMetricsHandler metrics.Handler + namespaceReplicationQueue persistence.NamespaceReplicationQueue + } ) const ( @@ -106,7 +154,7 @@ func (r verifyResult) isVerified() bool { // TODO: CallerTypePreemptablee should be set in activity background context for all migration 
activities. // However, activity background context is per-worker, which means once set, all activities processed by the -// worker will use CallerTypePreemptable, including those not related to migration. This is not ideal. +// worker will use CallerType Preemptable, including those not related to migration. This is not ideal. // Using a different task queue and a dedicated worker for migration can solve the issue but requires // changing all existing tooling around namespace migration to start workflows & activities on the new task queue. // Another approach is to use separate workers for workflow tasks and activities and keep existing tooling unchanged. @@ -179,7 +227,6 @@ func (a *activities) checkReplicationOnce(ctx context.Context, waitRequest waitR for _, shard := range resp.Shards { clusterInfo, hasClusterInfo := shard.RemoteClusters[waitRequest.RemoteCluster] - actualLag := shard.MaxReplicationTaskVisibilityTime.AsTime().Sub(clusterInfo.AckedTaskVisibilityTime.AsTime()) if hasClusterInfo { // WE are all caught up if shard.MaxReplicationTaskId == clusterInfo.AckedTaskId { @@ -190,7 +237,7 @@ func (a *activities) checkReplicationOnce(ctx context.Context, waitRequest waitR // Caught up to the last checked IDs, and within allowed lagging range if clusterInfo.AckedTaskId >= waitRequest.WaitForTaskIds[shard.ShardId] && (shard.MaxReplicationTaskId-clusterInfo.AckedTaskId <= waitRequest.AllowedLaggingTasks || - actualLag <= waitRequest.AllowedLagging) { + shard.MaxReplicationTaskVisibilityTime.AsTime().Sub(clusterInfo.AckedTaskVisibilityTime.AsTime()) <= waitRequest.AllowedLagging) { readyShardCount++ continue } @@ -210,10 +257,10 @@ func (a *activities) checkReplicationOnce(ctx context.Context, waitRequest waitR tag.NewInt64("AckedTaskId", clusterInfo.AckedTaskId), tag.NewInt64("WaitForTaskId", waitRequest.WaitForTaskIds[shard.ShardId]), tag.NewDurationTag("AllowedLagging", waitRequest.AllowedLagging), - tag.NewDurationTag("ActualLagging", actualLag), + tag.NewDurationTag("ActualLagging", shard.MaxReplicationTaskVisibilityTime.AsTime().Sub(clusterInfo.AckedTaskVisibilityTime.AsTime())), tag.NewInt64("MaxReplicationTaskId", shard.MaxReplicationTaskId), - tag.NewTimePtrTag("MaxReplicationTaskVisibilityTime", shard.MaxReplicationTaskVisibilityTime), - tag.NewTimePtrTag("AckedTaskVisibilityTime", clusterInfo.AckedTaskVisibilityTime), + tag.NewTimeTag("MaxReplicationTaskVisibilityTime", shard.MaxReplicationTaskVisibilityTime.AsTime()), + tag.NewTimeTag("AckedTaskVisibilityTime", clusterInfo.AckedTaskVisibilityTime.AsTime()), tag.NewInt64("AllowedLaggingTasks", waitRequest.AllowedLaggingTasks), tag.NewInt64("ActualLaggingTasks", shard.MaxReplicationTaskId-clusterInfo.AckedTaskId), ) @@ -221,7 +268,7 @@ func (a *activities) checkReplicationOnce(ctx context.Context, waitRequest waitR } // emit metrics about how many shards are ready - metrics.CatchUpReadyShardCountGauge.With(a.metricsHandler).Record( + a.metricsHandler.Gauge(metrics.CatchUpReadyShardCountGauge.Name()).Record( float64(readyShardCount), metrics.OperationTag(metrics.MigrationWorkflowScope), metrics.TargetClusterTag(waitRequest.RemoteCluster)) @@ -299,7 +346,7 @@ func (a *activities) checkHandoverOnce(ctx context.Context, waitRequest waitHand } // emit metrics about how many shards are ready - metrics.HandoverReadyShardCountGauge.With(a.metricsHandler).Record( + a.metricsHandler.Gauge(metrics.HandoverReadyShardCountGauge.Name()).Record( float64(readyShardCount), metrics.OperationTag(metrics.MigrationWorkflowScope), 
metrics.TargetClusterTag(waitRequest.RemoteCluster), @@ -400,63 +447,144 @@ func (a *activities) UpdateActiveCluster(ctx context.Context, req updateActiveCl func (a *activities) ListWorkflows(ctx context.Context, request *workflowservice.ListWorkflowExecutionsRequest) (*listWorkflowsResponse, error) { ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(request.Namespace, headers.CallerTypePreemptable, "")) - // modify query to include all namespace divisions - request.Query = searchattribute.QueryWithAnyNamespaceDivision(request.Query) - resp, err := a.frontendClient.ListWorkflowExecutions(ctx, request) if err != nil { return nil, err } - var lastCloseTime, lastStartTime *timestamppb.Timestamp + var lastCloseTime, lastStartTime time.Time executions := make([]*commonpb.WorkflowExecution, len(resp.Executions)) for i, e := range resp.Executions { executions[i] = e.Execution if e.CloseTime != nil { - lastCloseTime = e.CloseTime + lastCloseTime = e.CloseTime.AsTime() } if e.StartTime != nil { - lastStartTime = e.StartTime + lastStartTime = e.StartTime.AsTime() } } - return &listWorkflowsResponse{Executions: executions, NextPageToken: resp.NextPageToken, LastCloseTime: lastCloseTime.AsTime(), LastStartTime: lastStartTime.AsTime()}, nil + return &listWorkflowsResponse{Executions: executions, NextPageToken: resp.NextPageToken, LastCloseTime: lastCloseTime, LastStartTime: lastStartTime}, nil +} + +func (a *activities) CountWorkflow(ctx context.Context, request *workflowservice.CountWorkflowExecutionsRequest) (*countWorkflowResponse, error) { + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(request.Namespace, headers.CallerTypePreemptable, "")) + + resp, err := a.frontendClient.CountWorkflowExecutions(ctx, request) + if err != nil { + return nil, err + } + return &countWorkflowResponse{ + WorkflowCount: resp.Count, + }, nil } func (a *activities) GenerateReplicationTasks(ctx context.Context, request *generateReplicationTasksRequest) error { - ctx = a.setCallerInfoForGenReplicationTask(ctx, namespace.ID(request.NamespaceID)) + ctx = a.setCallerInfoForServerAPI(ctx, namespace.ID(request.NamespaceID)) rateLimiter := quotas.NewRateLimiter(request.RPS, int(math.Ceil(request.RPS))) start := time.Now() defer func() { - metrics.GenerateReplicationTasksLatency.With(a.forceReplicationMetricsHandler).Record(time.Since(start)) + a.forceReplicationMetricsHandler.Timer(metrics.GenerateReplicationTasksLatency.Name()).Record(time.Since(start)) }() startIndex := 0 if activity.HasHeartbeatDetails(ctx) { - var finishedIndex int - if err := activity.GetHeartbeatDetails(ctx, &finishedIndex); err == nil { - startIndex = finishedIndex + 1 // start from next one + if err := activity.GetHeartbeatDetails(ctx, &startIndex); err == nil { + startIndex = startIndex + 1 // start from next one } } for i := startIndex; i < len(request.Executions); i++ { - we := request.Executions[i] - if err := a.generateWorkflowReplicationTask(ctx, rateLimiter, definition.NewWorkflowKey(request.NamespaceID, we.WorkflowId, we.RunId)); err != nil { - if !isNotFoundServiceError(err) { - a.logger.Error("force-replication failed to generate replication task", tag.WorkflowNamespaceID(request.NamespaceID), tag.WorkflowID(we.WorkflowId), tag.WorkflowRunID(we.RunId), tag.Error(err)) - return err + var executionCandidates []definition.WorkflowKey + executionCandidates = []definition.WorkflowKey{definition.NewWorkflowKey(request.NamespaceID, request.Executions[i].GetWorkflowId(), request.Executions[i].GetRunId())} + + for _, we := range 
executionCandidates { + if err := a.generateWorkflowReplicationTask(ctx, rateLimiter, we); err != nil { + if !isNotFoundServiceError(err) { + a.logger.Error("force-replication failed to generate replication task", + tag.WorkflowNamespaceID(we.GetNamespaceID()), + tag.WorkflowID(we.GetWorkflowID()), + tag.WorkflowRunID(we.GetRunID()), + tag.Error(err)) + return err + } + + a.logger.Warn("force-replication ignore replication task due to NotFoundServiceError", + tag.WorkflowNamespaceID(we.GetNamespaceID()), + tag.WorkflowID(we.GetWorkflowID()), + tag.WorkflowRunID(we.GetRunID()), + tag.Error(err)) } } - activity.RecordHeartbeat(ctx, i) } return nil } -func (a *activities) setCallerInfoForGenReplicationTask( +func (a *activities) generateExecutionsToReplicate( + ctx context.Context, + rateLimiter quotas.RateLimiter, + executionDedupMap map[definition.WorkflowKey]struct{}, + namespaceID string, + baseWf *commonpb.WorkflowExecution, +) ([]definition.WorkflowKey, error) { + + start := time.Now() + defer func() { + a.forceReplicationMetricsHandler.Timer("GenerateParentWorkflowExecutionsLatency").Record(time.Since(start)) + }() + + var resultStack []definition.WorkflowKey + baseWfKey := definition.NewWorkflowKey(namespaceID, baseWf.GetWorkflowId(), baseWf.GetRunId()) + queue := []definition.WorkflowKey{baseWfKey} + for len(queue) > 0 { + var currWorkflow definition.WorkflowKey + currWorkflow, queue = queue[0], queue[1:] + + if _, ok := executionDedupMap[currWorkflow]; ok { + // already in the result set + continue + } + executionDedupMap[currWorkflow] = struct{}{} + + if err := rateLimiter.WaitN(ctx, 1); err != nil { + return nil, err + } + // Reason to use history client + // 1. Reduce networking routing + // 2. Bypass frontend per namespace rate limiter + resp, err := a.historyClient.DescribeWorkflowExecution(ctx, &historyservice.DescribeWorkflowExecutionRequest{ + NamespaceId: currWorkflow.GetNamespaceID(), + Request: &workflowservice.DescribeWorkflowExecutionRequest{ + Execution: &commonpb.WorkflowExecution{ + WorkflowId: currWorkflow.GetWorkflowID(), + RunId: currWorkflow.GetRunID(), + }, + }, + }) + if err != nil { + if isNotFoundServiceError(err) { + continue + } + return nil, err + } + resultStack = append(resultStack, currWorkflow) + + parentExecInfo := resp.GetWorkflowExecutionInfo().GetParentExecution() + if parentExecInfo != nil { + parentExecution := definition.NewWorkflowKey(resp.GetWorkflowExecutionInfo().GetParentNamespaceId(), parentExecInfo.GetWorkflowId(), parentExecInfo.GetRunId()) + queue = append(queue, parentExecution) + } + } + + slices.Reverse(resultStack) + return resultStack, nil +} + +func (a *activities) setCallerInfoForServerAPI( ctx context.Context, namespaceID namespace.ID, ) context.Context { @@ -584,7 +712,7 @@ func (a *activities) checkSkipWorkflowExecution( if isNotFoundServiceError(err) { // The outstanding workflow execution may be deleted (due to retention) on source cluster after replication tasks were generated. // Since retention runs on both source/target clusters, such execution may also be deleted (hence not found) from target cluster. - metrics.EncounterNotFoundWorkflowCount.With(a.forceReplicationMetricsHandler).Record(1) + a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(request.Namespace)).Counter(metrics.EncounterNotFoundWorkflowCount.Name()).Record(1) return verifyResult{ status: skipped, reason: reasonWorkflowNotFound, @@ -599,7 +727,7 @@ func (a *activities) checkSkipWorkflowExecution( // Zombie workflow should be a transient state. 
However, if there is Zombie workflow on the source cluster, // it is skipped to avoid such workflow being processed on the target cluster. if resp.GetDatabaseMutableState().GetExecutionState().GetState() == enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE { - metrics.EncounterZombieWorkflowCount.With(a.forceReplicationMetricsHandler).Record(1) + a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(request.Namespace)).Counter(metrics.EncounterZombieWorkflowCount.Name()).Record(1) a.logger.Info("createReplicationTasks skip Zombie workflow", tags...) return verifyResult{ status: skipped, @@ -611,7 +739,7 @@ func (a *activities) checkSkipWorkflowExecution( if closeTime := resp.GetDatabaseMutableState().GetExecutionInfo().GetCloseTime(); closeTime != nil && ns != nil && ns.Retention() > 0 { deleteTime := closeTime.AsTime().Add(ns.Retention()) if deleteTime.Before(time.Now()) { - metrics.EncounterPassRetentionWorkflowCount.With(a.forceReplicationMetricsHandler).Record(1) + a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(request.Namespace)).Counter(metrics.EncounterPassRetentionWorkflowCount.Name()).Record(1) return verifyResult{ status: skipped, reason: reasonWorkflowCloseToRetention, @@ -627,43 +755,43 @@ func (a *activities) checkSkipWorkflowExecution( func (a *activities) verifySingleReplicationTask( ctx context.Context, request *verifyReplicationTasksRequest, - remoteClient adminservice.AdminServiceClient, + remoteClient workflowservice.WorkflowServiceClient, ns *namespace.Namespace, we *commonpb.WorkflowExecution, -) (result verifyResult, rerr error) { +) (verifyResult, error) { s := time.Now() // Check if execution exists on remote cluster - _, err := remoteClient.DescribeMutableState(ctx, &adminservice.DescribeMutableStateRequest{ + _, err := remoteClient.DescribeWorkflowExecution(ctx, &workflowservice.DescribeWorkflowExecutionRequest{ Namespace: request.Namespace, Execution: we, }) - metrics.VerifyDescribeMutableStateLatency.With(a.forceReplicationMetricsHandler).Record(time.Since(s)) + a.forceReplicationMetricsHandler.Timer(metrics.VerifyDescribeMutableStateLatency.Name()).Record(time.Since(s)) switch err.(type) { case nil: - metrics.VerifyReplicationTaskSuccess.With(a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(request.Namespace))).Record(1) + a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(request.Namespace)).Counter(metrics.VerifyReplicationTaskSuccess.Name()).Record(1) return verifyResult{ status: verified, }, nil case *serviceerror.NotFound: - metrics.VerifyReplicationTaskNotFound.With(a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(request.Namespace))).Record(1) - // Calling checkSkipWorkflowExecution for every NotFound is sub-optimal as most common case to skip is workfow being deleted due to retention. + a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(request.Namespace)).Counter(metrics.VerifyReplicationTaskNotFound.Name()).Record(1) + // Calling checkSkipWorkflowExecution for every NotFound is sub-optimal as most common case to skip is workflow being deleted due to retention. // A better solution is to only check the existence for workflow which is close to retention period. 
return a.checkSkipWorkflowExecution(ctx, request, we, ns) case *serviceerror.NamespaceNotFound: return verifyResult{ status: notVerified, - }, temporal.NewNonRetryableApplicationError("remoteClient.DescribeMutableState call failed", "NamespaceNotFound", err) + }, temporal.NewNonRetryableApplicationError("failed to describe workflow from the remote cluster", "NamespaceNotFound", err) default: - metrics.VerifyReplicationTaskFailed.With(a.forceReplicationMetricsHandler. - WithTags(metrics.NamespaceTag(request.Namespace), metrics.ServiceErrorTypeTag(err))).Record(1) + a.forceReplicationMetricsHandler.WithTags(metrics.NamespaceTag(request.Namespace), metrics.ServiceErrorTypeTag(err)). + Counter(metrics.VerifyReplicationTaskFailed.Name()).Record(1) return verifyResult{ status: notVerified, - }, errors.WithMessage(err, "remoteClient.DescribeMutableState call failed") + }, errors.WithMessage(err, "failed to describe workflow from the remote cluster") } } @@ -671,7 +799,7 @@ func (a *activities) verifyReplicationTasks( ctx context.Context, request *verifyReplicationTasksRequest, details *replicationTasksHeartbeatDetails, - remoteClient adminservice.AdminServiceClient, + remoteClient workflowservice.WorkflowServiceClient, ns *namespace.Namespace, heartbeat func(details replicationTasksHeartbeatDetails), ) (bool, error) { @@ -684,9 +812,11 @@ func (a *activities) verifyReplicationTasks( } heartbeat(*details) - metrics.VerifyReplicationTasksLatency.With(a.forceReplicationMetricsHandler).Record(time.Since(start)) + a.forceReplicationMetricsHandler.Timer(metrics.VerifyReplicationTasksLatency.Name()).Record(time.Since(start)) }() + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(interceptor.DCRedirectionContextHeaderName, "false")) + for ; details.NextIndex < len(request.Executions); details.NextIndex++ { we := request.Executions[details.NextIndex] r, err := a.verifySingleReplicationTask(ctx, request, remoteClient, ns, we) @@ -711,25 +841,7 @@ const ( ) func (a *activities) VerifyReplicationTasks(ctx context.Context, request *verifyReplicationTasksRequest) (verifyReplicationTasksResponse, error) { - ctx = headers.SetCallerInfo(ctx, headers.NewPreemptableCallerInfo(request.Namespace)) var response verifyReplicationTasksResponse - var remoteClient adminservice.AdminServiceClient - var err error - - if len(request.TargetClusterName) > 0 { - remoteClient, err = a.clientBean.GetRemoteAdminClient(request.TargetClusterName) - if err != nil { - return response, err - } - } else { - // TODO: remove once TargetClusterEndpoint is no longer used. - remoteClient = a.clientFactory.NewRemoteAdminClientWithTimeout( - request.TargetClusterEndpoint, - admin.DefaultTimeout, - admin.DefaultLargeTimeout, - ) - } - var details replicationTasksHeartbeatDetails if activity.HasHeartbeatDetails(ctx) { if err := activity.GetHeartbeatDetails(ctx, &details); err != nil { @@ -741,6 +853,11 @@ func (a *activities) VerifyReplicationTasks(ctx context.Context, request *verify activity.RecordHeartbeat(ctx, details) } + _, remoteClient, err := a.clientBean.GetRemoteFrontendClient(request.TargetClusterName) + if err != nil { + return response, err + } + nsEntry, err := a.namespaceRegistry.GetNamespace(namespace.Name(request.Namespace)) if err != nil { return response, err @@ -750,7 +867,7 @@ func (a *activities) VerifyReplicationTasks(ctx context.Context, request *verify // 1. replication lag // 2. Zombie workflow execution // 3. workflow execution was deleted (due to retention) after replication task was created - // 4. 
workflow execution was not applied succesfully on target cluster (i.e, bug) + // 4. workflow execution was not applied successfully on target cluster (i.e, bug) // // The verification step is retried for every VerifyInterval to handle #1. Verification progress // is recorded in activity heartbeat. The verification is considered of making progress if there was at least one new execution @@ -771,6 +888,7 @@ func (a *activities) VerifyReplicationTasks(ctx context.Context, request *verify } if verified == true { + response.VerifiedWorkflowCount = int64(len(request.Executions)) return response, nil } @@ -786,3 +904,116 @@ func (a *activities) VerifyReplicationTasks(ctx context.Context, request *verify } } } + +func (a *activities) WaitCatchup(ctx context.Context, params CatchUpParams) error { + ctx = headers.SetCallerInfo(ctx, headers.NewCallerInfo(params.Namespace, headers.CallerTypeAPI, "")) + + descResp, err := a.frontendClient.DescribeNamespace(ctx, &workflowservice.DescribeNamespaceRequest{ + Namespace: params.Namespace, + }) + if err != nil { + return err + } + + waitCatchupRequest := waitCatchupRequest{ + Namespace: params.Namespace, + RemoteCluster: params.RemoteCluster, + ActiveCluster: descResp.ReplicationConfig.GetActiveClusterName(), + } + + activeAckIDOnShard, err := a.getActiveClusterReplicationStatus(ctx, waitCatchupRequest) + if err != nil { + return err + } + + for { + done, err := a.checkReplicationOnRemoteCluster(ctx, waitCatchupRequest, activeAckIDOnShard) + if err != nil { + return err + } + if done { + return nil + } + + // keep waiting and check again + time.Sleep(time.Second) + activity.RecordHeartbeat(ctx, nil) + } +} + +// Check if remote cluster has caught up on all shards on replication tasks from passive replica. +func (a *activities) getActiveClusterReplicationStatus(ctx context.Context, waitRequest waitCatchupRequest) (map[int32]int64, error) { + activeAckIDOnShard := make(map[int32]int64) + + resp, err := a.historyClient.GetReplicationStatus(ctx, &historyservice.GetReplicationStatusRequest{ + RemoteClusters: []string{waitRequest.ActiveCluster}, + }) + if err != nil { + return activeAckIDOnShard, err + } + + // record the acked task id from active for each shard + for _, shard := range resp.Shards { + activeInfo, hasActiveInfo := shard.RemoteClusters[waitRequest.ActiveCluster] + if hasActiveInfo { + activeAckIDOnShard[shard.ShardId] = activeInfo.AckedTaskId + } + } + + return activeAckIDOnShard, nil +} + +// Check if remote cluster has caught up on all shards on replication tasks from passive replica. +func (a *activities) checkReplicationOnRemoteCluster(ctx context.Context, waitRequest waitCatchupRequest, activeAckIDOnShard map[int32]int64) (bool, error) { + + resp, err := a.historyClient.GetReplicationStatus(ctx, &historyservice.GetReplicationStatusRequest{ + RemoteClusters: []string{waitRequest.RemoteCluster}, + }) + if err != nil { + return false, err + } + + expectedShardCount := len(activeAckIDOnShard) + + readyShardCount := 0 + logged := false + // check that on every shard, all source clusters have caught up with target cluster + for _, shard := range resp.Shards { + clusterInfo, hasClusterInfo := shard.RemoteClusters[waitRequest.RemoteCluster] + if hasClusterInfo { + value, exists := activeAckIDOnShard[shard.ShardId] + // If the active acked task ID is not found, the shard is considered ready, as the remote ack level + // is assumed to be more up-to-date than the active ack level. 
+ if !exists { + readyShardCount++ + continue + } + if clusterInfo.AckedTaskId >= value { + readyShardCount++ + continue + } + } + + // shard is not ready, log first non-ready shard + if !logged { + logged = true + if !hasClusterInfo { + a.logger.Info("Wait handover missing remote cluster info", tag.ShardID(shard.ShardId), tag.ClusterName(waitRequest.RemoteCluster)) + // this is not expected, so fail activity to surface the error, but retryPolicy will keep retrying. + return false, temporal.NewNonRetryableApplicationError(fmt.Sprintf("GetReplicationStatus response for shard %d does not contains remote cluster %s", shard.ShardId, waitRequest.RemoteCluster), "", nil) + } + + a.logger.Info("Wait handover not ready", + tag.NewInt32("ShardId", shard.ShardId), + tag.NewInt64("AckedTaskId", clusterInfo.AckedTaskId), + tag.NewStringTag("Namespace", waitRequest.Namespace), + tag.NewStringTag("RemoteCluster", waitRequest.RemoteCluster), + tag.NewInt64("activeAckIDOnShard", activeAckIDOnShard[shard.ShardId]), + ) + + } + + } + + return readyShardCount == expectedShardCount, nil +} diff --git a/service/worker/migration/activities_test.go b/service/worker/migration/activities_test.go index 88ad7546c09..6c70da8b845 100644 --- a/service/worker/migration/activities_test.go +++ b/service/worker/migration/activities_test.go @@ -31,12 +31,12 @@ import ( "github.com/stretchr/testify/suite" commonpb "go.temporal.io/api/common/v1" + replicationpb "go.temporal.io/api/replication/v1" "go.temporal.io/api/serviceerror" + "go.temporal.io/api/workflowservice/v1" "go.temporal.io/sdk/interceptor" "go.temporal.io/sdk/testsuite" "go.temporal.io/sdk/worker" - "go.temporal.io/server/api/adminservice/v1" - "go.temporal.io/server/api/adminservicemock/v1" enumsspb "go.temporal.io/server/api/enums/v1" "go.temporal.io/server/api/historyservice/v1" "go.temporal.io/server/api/historyservicemock/v1" @@ -66,9 +66,9 @@ type activitiesSuite struct { mockClientFactory *client.MockFactory mockClientBean *client.MockBean - mockFrontendClient *workflowservicemock.MockWorkflowServiceClient - mockHistoryClient *historyservicemock.MockHistoryServiceClient - mockRemoteAdminClient *adminservicemock.MockAdminServiceClient + mockFrontendClient *workflowservicemock.MockWorkflowServiceClient + mockHistoryClient *historyservicemock.MockHistoryServiceClient + mockRemoteClient *workflowservicemock.MockWorkflowServiceClient logger log.Logger mockMetricsHandler *metrics.MockHandler @@ -83,24 +83,17 @@ const ( ) var ( - emptyExecutions = commonpb.WorkflowExecution{} - - execution1 = commonpb.WorkflowExecution{ + execution1 = &commonpb.WorkflowExecution{ WorkflowId: "workflow1", RunId: "run1", } - execution2 = commonpb.WorkflowExecution{ + execution2 = &commonpb.WorkflowExecution{ WorkflowId: "workflow2", RunId: "run2", } - execution3 = commonpb.WorkflowExecution{ - WorkflowId: "workflow3", - RunId: "run3", - } - - completeState = historyservice.DescribeMutableStateResponse{ + completeState = &historyservice.DescribeMutableStateResponse{ DatabaseMutableState: &persistencespb.WorkflowMutableState{ ExecutionState: &persistencespb.WorkflowExecutionState{ State: enumsspb.WORKFLOW_EXECUTION_STATE_COMPLETED, @@ -108,7 +101,7 @@ var ( }, } - zombieState = historyservice.DescribeMutableStateResponse{ + zombieState = &historyservice.DescribeMutableStateResponse{ DatabaseMutableState: &persistencespb.WorkflowMutableState{ ExecutionState: &persistencespb.WorkflowExecutionState{ State: enumsspb.WORKFLOW_EXECUTION_STATE_ZOMBIE, @@ -133,14 +126,14 @@ func (s 
*activitiesSuite) SetupTest() { s.mockFrontendClient = workflowservicemock.NewMockWorkflowServiceClient(s.controller) s.mockHistoryClient = historyservicemock.NewMockHistoryServiceClient(s.controller) - s.mockRemoteAdminClient = adminservicemock.NewMockAdminServiceClient(s.controller) + s.mockRemoteClient = workflowservicemock.NewMockWorkflowServiceClient(s.controller) s.logger = log.NewNoopLogger() s.mockMetricsHandler = metrics.NewMockHandler(s.controller) s.mockMetricsHandler.EXPECT().WithTags(gomock.Any()).Return(s.mockMetricsHandler).AnyTimes() s.mockMetricsHandler.EXPECT().Timer(gomock.Any()).Return(metrics.NoopTimerMetricFunc).AnyTimes() s.mockMetricsHandler.EXPECT().Counter(gomock.Any()).Return(metrics.NoopCounterMetricFunc).AnyTimes() - s.mockClientBean.EXPECT().GetRemoteAdminClient(remoteCluster).Return(s.mockRemoteAdminClient, nil).AnyTimes() + s.mockClientBean.EXPECT().GetRemoteFrontendClient(remoteCluster).Return(nil, s.mockRemoteClient, nil).AnyTimes() s.mockNamespaceRegistry.EXPECT().GetNamespaceName(gomock.Any()). Return(namespace.Name(mockedNamespace), nil).AnyTimes() s.mockNamespaceRegistry.EXPECT().GetNamespace(gomock.Any()). @@ -180,44 +173,48 @@ func (s *activitiesSuite) TestVerifyReplicationTasks_Success() { Namespace: mockedNamespace, NamespaceID: mockedNamespaceID, TargetClusterName: remoteCluster, - Executions: []*commonpb.WorkflowExecution{&execution1, &execution2}, + Executions: []*commonpb.WorkflowExecution{execution1, execution2}, } // Immediately replicated - s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{ + s.mockRemoteClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), protomock.Eq(&workflowservice.DescribeWorkflowExecutionRequest{ Namespace: mockedNamespace, - Execution: &execution1, - })).Return(&adminservice.DescribeMutableStateResponse{}, nil).Times(1) + Execution: execution1, + })).Return(&workflowservice.DescribeWorkflowExecutionResponse{}, nil).Times(1) // Slowly replicated replicationSlowReponses := []struct { - resp *adminservice.DescribeMutableStateResponse + resp *workflowservice.DescribeWorkflowExecutionResponse err error }{ {nil, serviceerror.NewNotFound("")}, {nil, serviceerror.NewNotFound("")}, - {&adminservice.DescribeMutableStateResponse{}, nil}, + {&workflowservice.DescribeWorkflowExecutionResponse{}, nil}, } for _, r := range replicationSlowReponses { - s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{ + s.mockRemoteClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), protomock.Eq(&workflowservice.DescribeWorkflowExecutionRequest{ Namespace: mockedNamespace, - Execution: &execution2, + Execution: execution2, })).Return(r.resp, r.err).Times(1) } s.mockHistoryClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&historyservice.DescribeMutableStateRequest{ NamespaceId: mockedNamespaceID, - Execution: &execution2, - })).Return(&completeState, nil).Times(2) + Execution: execution2, + })).Return(completeState, nil).Times(2) - _, err := env.ExecuteActivity(s.a.VerifyReplicationTasks, &request) + f, err := env.ExecuteActivity(s.a.VerifyReplicationTasks, &request) s.NoError(err) + var output verifyReplicationTasksResponse + err = f.Get(&output) + s.NoError(err) + s.Equal(len(request.Executions), int(output.VerifiedWorkflowCount)) s.Greater(len(iceptor.replicationRecordedHeartbeats), 0) lastHeartBeat := 
iceptor.replicationRecordedHeartbeats[len(iceptor.replicationRecordedHeartbeats)-1] s.Equal(len(request.Executions), lastHeartBeat.NextIndex) - s.ProtoEqual(&execution2, lastHeartBeat.LastNotVerifiedWorkflowExecution) + s.Equal(execution2, lastHeartBeat.LastNotVerifiedWorkflowExecution) } func (s *activitiesSuite) TestVerifyReplicationTasks_SkipWorkflowExecution() { @@ -229,7 +226,7 @@ func (s *activitiesSuite) TestVerifyReplicationTasks_SkipWorkflowExecution() { expectedErr error }{ { - &zombieState, + zombieState, nil, reasonZombieWorkflow, nil, @@ -250,21 +247,21 @@ func (s *activitiesSuite) TestVerifyReplicationTasks_SkipWorkflowExecution() { Namespace: mockedNamespace, NamespaceID: mockedNamespaceID, TargetClusterName: remoteCluster, - Executions: []*commonpb.WorkflowExecution{&execution1}, + Executions: []*commonpb.WorkflowExecution{execution1}, } start := time.Now() for _, t := range testcases { env, iceptor := s.initEnv() - s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{ + s.mockRemoteClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), protomock.Eq(&workflowservice.DescribeWorkflowExecutionRequest{ Namespace: mockedNamespace, - Execution: &execution1, + Execution: execution1, })).Return(nil, serviceerror.NewNotFound("")).Times(1) s.mockHistoryClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&historyservice.DescribeMutableStateRequest{ NamespaceId: mockedNamespaceID, - Execution: &execution1, + Execution: execution1, })).Return(t.resp, t.err).Times(1) _, err := env.ExecuteActivity(s.a.VerifyReplicationTasks, &request) @@ -289,18 +286,18 @@ func (s *activitiesSuite) TestVerifyReplicationTasks_FailedNotFound() { Namespace: mockedNamespace, NamespaceID: mockedNamespaceID, TargetClusterName: remoteCluster, - Executions: []*commonpb.WorkflowExecution{&execution1}, + Executions: []*commonpb.WorkflowExecution{execution1}, } s.mockHistoryClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&historyservice.DescribeMutableStateRequest{ NamespaceId: mockedNamespaceID, - Execution: &execution1, - })).Return(&completeState, nil) + Execution: execution1, + })).Return(completeState, nil) // Workflow not found at target cluster. - s.mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{ + s.mockRemoteClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), protomock.Eq(&workflowservice.DescribeWorkflowExecutionRequest{ Namespace: mockedNamespace, - Execution: &execution1, + Execution: execution1, })).Return(nil, serviceerror.NewNotFound("")).AnyTimes() // Set CheckPoint to an early to trigger failure. 
@@ -316,7 +313,7 @@ func (s *activitiesSuite) TestVerifyReplicationTasks_FailedNotFound() { s.Greater(len(iceptor.replicationRecordedHeartbeats), 0) lastHeartBeat := iceptor.replicationRecordedHeartbeats[len(iceptor.replicationRecordedHeartbeats)-1] s.Equal(0, lastHeartBeat.NextIndex) - s.ProtoEqual(&execution1, lastHeartBeat.LastNotVerifiedWorkflowExecution) + s.Equal(execution1, lastHeartBeat.LastNotVerifiedWorkflowExecution) } func (s *activitiesSuite) TestVerifyReplicationTasks_AlreadyVerified() { @@ -325,7 +322,7 @@ func (s *activitiesSuite) TestVerifyReplicationTasks_AlreadyVerified() { Namespace: mockedNamespace, NamespaceID: mockedNamespaceID, TargetClusterName: remoteCluster, - Executions: []*commonpb.WorkflowExecution{&execution1, &execution2}, + Executions: []*commonpb.WorkflowExecution{execution1, execution2}, } // Set NextIndex to indicate all executions have been verified. No additional mock is needed. @@ -345,31 +342,30 @@ func (s *activitiesSuite) Test_verifySingleReplicationTask() { Namespace: mockedNamespace, NamespaceID: mockedNamespaceID, TargetClusterName: remoteCluster, - Executions: []*commonpb.WorkflowExecution{&execution1, &execution2}, + Executions: []*commonpb.WorkflowExecution{execution1, execution2}, } ctx := context.TODO() - mockRemoteAdminClient := adminservicemock.NewMockAdminServiceClient(s.controller) - mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{ + s.mockRemoteClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), protomock.Eq(&workflowservice.DescribeWorkflowExecutionRequest{ Namespace: mockedNamespace, - Execution: &execution1, - })).Return(&adminservice.DescribeMutableStateResponse{}, nil).Times(1) - result, err := s.a.verifySingleReplicationTask(ctx, &request, mockRemoteAdminClient, &testNamespace, request.Executions[0]) + Execution: execution1, + })).Return(&workflowservice.DescribeWorkflowExecutionResponse{}, nil).Times(1) + result, err := s.a.verifySingleReplicationTask(ctx, &request, s.mockRemoteClient, &testNamespace, request.Executions[0]) s.NoError(err) s.True(result.isVerified()) // Test not verified workflow - mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{ + s.mockRemoteClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), protomock.Eq(&workflowservice.DescribeWorkflowExecutionRequest{ Namespace: mockedNamespace, - Execution: &execution2, - })).Return(&adminservice.DescribeMutableStateResponse{}, serviceerror.NewNotFound("")).Times(1) + Execution: execution2, + })).Return(&workflowservice.DescribeWorkflowExecutionResponse{}, serviceerror.NewNotFound("")).Times(1) s.mockHistoryClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&historyservice.DescribeMutableStateRequest{ NamespaceId: mockedNamespaceID, - Execution: &execution2, - })).Return(&completeState, nil).AnyTimes() + Execution: execution2, + })).Return(completeState, nil).AnyTimes() - result, err = s.a.verifySingleReplicationTask(ctx, &request, mockRemoteAdminClient, &testNamespace, request.Executions[1]) + result, err = s.a.verifySingleReplicationTask(ctx, &request, s.mockRemoteClient, &testNamespace, request.Executions[1]) s.NoError(err) s.False(result.isVerified()) } @@ -382,31 +378,31 @@ const ( executionErr executionState = 2 ) -func createExecutions(mockClient *adminservicemock.MockAdminServiceClient, states []executionState, nextIndex int) []*commonpb.WorkflowExecution { +func createExecutions(mockClient 
*workflowservicemock.MockWorkflowServiceClient, states []executionState, nextIndex int) []*commonpb.WorkflowExecution { var executions []*commonpb.WorkflowExecution for i := 0; i < len(states); i++ { - executions = append(executions, &execution1) + executions = append(executions, execution1) } Loop: for i := nextIndex; i < len(states); i++ { switch states[i] { case executionFound: - mockClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{ + mockClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), protomock.Eq(&workflowservice.DescribeWorkflowExecutionRequest{ Namespace: mockedNamespace, - Execution: &execution1, - })).Return(&adminservice.DescribeMutableStateResponse{}, nil).Times(1) + Execution: execution1, + })).Return(&workflowservice.DescribeWorkflowExecutionResponse{}, nil).Times(1) case executionNotfound: - mockClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{ + mockClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), protomock.Eq(&workflowservice.DescribeWorkflowExecutionRequest{ Namespace: mockedNamespace, - Execution: &execution1, + Execution: execution1, })).Return(nil, serviceerror.NewNotFound("")).Times(1) break Loop case executionErr: - mockClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{ + mockClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), protomock.Eq(&workflowservice.DescribeWorkflowExecutionRequest{ Namespace: mockedNamespace, - Execution: &execution1, + Execution: execution1, })).Return(nil, serviceerror.NewInternal("")).Times(1) } } @@ -474,20 +470,20 @@ func (s *activitiesSuite) Test_verifyReplicationTasks() { s.mockHistoryClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&historyservice.DescribeMutableStateRequest{ NamespaceId: mockedNamespaceID, - Execution: &execution1, - })).Return(&completeState, nil).AnyTimes() + Execution: execution1, + })).Return(completeState, nil).AnyTimes() checkPointTime := time.Now() for _, tc := range tests { var recorder mockHeartBeatRecorder - mockRemoteAdminClient := adminservicemock.NewMockAdminServiceClient(s.controller) - request.Executions = createExecutions(mockRemoteAdminClient, tc.remoteExecutionStates, tc.nextIndex) + mockRemoteClient := workflowservicemock.NewMockWorkflowServiceClient(s.controller) + request.Executions = createExecutions(mockRemoteClient, tc.remoteExecutionStates, tc.nextIndex) details := replicationTasksHeartbeatDetails{ NextIndex: tc.nextIndex, CheckPoint: checkPointTime, } - verified, err := s.a.verifyReplicationTasks(ctx, &request, &details, mockRemoteAdminClient, &testNamespace, recorder.hearbeat) + verified, err := s.a.verifyReplicationTasks(ctx, &request, &details, mockRemoteClient, &testNamespace, recorder.hearbeat) if tc.expectedErr == nil { s.NoError(err) } @@ -496,7 +492,7 @@ func (s *activitiesSuite) Test_verifyReplicationTasks() { s.GreaterOrEqual(len(tc.remoteExecutionStates), details.NextIndex) s.Equal(recorder.lastHeartBeat, details) if details.NextIndex < len(tc.remoteExecutionStates) && tc.remoteExecutionStates[details.NextIndex] == executionNotfound { - s.ProtoEqual(&execution1, details.LastNotVerifiedWorkflowExecution) + s.Equal(execution1, details.LastNotVerifiedWorkflowExecution) } if len(request.Executions) > 0 { @@ -508,19 +504,18 @@ func (s *activitiesSuite) Test_verifyReplicationTasks() { func (s *activitiesSuite) Test_verifyReplicationTasksNoProgress() { var recorder mockHeartBeatRecorder - 
mockRemoteAdminClient := adminservicemock.NewMockAdminServiceClient(s.controller) request := verifyReplicationTasksRequest{ Namespace: mockedNamespace, NamespaceID: mockedNamespaceID, TargetClusterName: remoteCluster, - Executions: createExecutions(mockRemoteAdminClient, []executionState{executionFound, executionFound, executionNotfound, executionFound}, 0), + Executions: createExecutions(s.mockRemoteClient, []executionState{executionFound, executionFound, executionNotfound, executionFound}, 0), } s.mockHistoryClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&historyservice.DescribeMutableStateRequest{ NamespaceId: mockedNamespaceID, - Execution: &execution1, - })).Return(&completeState, nil).AnyTimes() + Execution: execution1, + })).Return(completeState, nil).AnyTimes() checkPointTime := time.Now() details := replicationTasksHeartbeatDetails{ @@ -529,7 +524,7 @@ func (s *activitiesSuite) Test_verifyReplicationTasksNoProgress() { } ctx := context.TODO() - verified, err := s.a.verifyReplicationTasks(ctx, &request, &details, mockRemoteAdminClient, &testNamespace, recorder.hearbeat) + verified, err := s.a.verifyReplicationTasks(ctx, &request, &details, s.mockRemoteClient, &testNamespace, recorder.hearbeat) s.NoError(err) s.False(verified) // Verify has made progress. @@ -539,13 +534,13 @@ func (s *activitiesSuite) Test_verifyReplicationTasksNoProgress() { prevDetails := details // Mock for one more NotFound call - mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), &adminservice.DescribeMutableStateRequest{ + s.mockRemoteClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), protomock.Eq(&workflowservice.DescribeWorkflowExecutionRequest{ Namespace: mockedNamespace, - Execution: &execution1, - }).Return(nil, serviceerror.NewNotFound("")).Times(1) + Execution: execution1, + })).Return(nil, serviceerror.NewNotFound("")).Times(1) // All results should be either NotFound or cached and no progress should be made. 
- verified, err = s.a.verifyReplicationTasks(ctx, &request, &details, mockRemoteAdminClient, &testNamespace, recorder.hearbeat) + verified, err = s.a.verifyReplicationTasks(ctx, &request, &details, s.mockRemoteClient, &testNamespace, recorder.hearbeat) s.NoError(err) s.False(verified) s.Equal(prevDetails, details) @@ -556,7 +551,7 @@ func (s *activitiesSuite) Test_verifyReplicationTasksSkipRetention() { Namespace: mockedNamespace, NamespaceID: mockedNamespaceID, TargetClusterName: remoteCluster, - Executions: []*commonpb.WorkflowExecution{&execution1}, + Executions: []*commonpb.WorkflowExecution{execution1}, } var tests = []struct { @@ -579,15 +574,15 @@ func (s *activitiesSuite) Test_verifyReplicationTasksSkipRetention() { retention := time.Hour closeTime := deleteTime.Add(-retention) - mockRemoteAdminClient := adminservicemock.NewMockAdminServiceClient(s.controller) - mockRemoteAdminClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&adminservice.DescribeMutableStateRequest{ + mockRemoteClient := workflowservicemock.NewMockWorkflowServiceClient(s.controller) + mockRemoteClient.EXPECT().DescribeWorkflowExecution(gomock.Any(), protomock.Eq(&workflowservice.DescribeWorkflowExecutionRequest{ Namespace: mockedNamespace, - Execution: &execution1, + Execution: execution1, })).Return(nil, serviceerror.NewNotFound("")).Times(1) s.mockHistoryClient.EXPECT().DescribeMutableState(gomock.Any(), protomock.Eq(&historyservice.DescribeMutableStateRequest{ NamespaceId: mockedNamespaceID, - Execution: &execution1, + Execution: execution1, })).Return(&historyservice.DescribeMutableStateResponse{ DatabaseMutableState: &persistencespb.WorkflowMutableState{ ExecutionState: &persistencespb.WorkflowExecutionState{ @@ -611,7 +606,7 @@ func (s *activitiesSuite) Test_verifyReplicationTasksSkipRetention() { details := replicationTasksHeartbeatDetails{} ctx := context.TODO() - verified, err := s.a.verifyReplicationTasks(ctx, &request, &details, mockRemoteAdminClient, ns, recorder.hearbeat) + verified, err := s.a.verifyReplicationTasks(ctx, &request, &details, mockRemoteClient, ns, recorder.hearbeat) s.NoError(err) s.Equal(tc.verified, verified) s.Equal(recorder.lastHeartBeat, details) @@ -629,9 +624,10 @@ func (s *activitiesSuite) TestGenerateReplicationTasks_Success() { env, iceptor := s.initEnv() request := generateReplicationTasksRequest{ - NamespaceID: mockedNamespaceID, - RPS: 10, - Executions: []*commonpb.WorkflowExecution{&execution1, &execution2}, + NamespaceID: mockedNamespaceID, + RPS: 10, + GetParentInfoRPS: 10, + Executions: []*commonpb.WorkflowExecution{execution1, execution2}, } for i := 0; i < len(request.Executions); i++ { @@ -651,54 +647,108 @@ func (s *activitiesSuite) TestGenerateReplicationTasks_Success() { s.Equal(lastIdx, lastHeartBeat) } -func (s *activitiesSuite) TestGenerateReplicationTasks_NotFound() { +func (s *activitiesSuite) TestGenerateReplicationTasks_Failed() { env, iceptor := s.initEnv() request := generateReplicationTasksRequest{ - NamespaceID: mockedNamespaceID, - RPS: 10, - Executions: []*commonpb.WorkflowExecution{&execution1}, + NamespaceID: mockedNamespaceID, + RPS: 10, + GetParentInfoRPS: 10, + Executions: []*commonpb.WorkflowExecution{execution1, execution2}, } s.mockHistoryClient.EXPECT().GenerateLastHistoryReplicationTasks(gomock.Any(), protomock.Eq(&historyservice.GenerateLastHistoryReplicationTasksRequest{ NamespaceId: mockedNamespaceID, - Execution: &execution1, - })).Return(nil, serviceerror.NewNotFound("")).Times(1) + Execution: execution1, + 
})).Return(&historyservice.GenerateLastHistoryReplicationTasksResponse{}, nil).Times(1) + + s.mockHistoryClient.EXPECT().GenerateLastHistoryReplicationTasks(gomock.Any(), protomock.Eq(&historyservice.GenerateLastHistoryReplicationTasksRequest{ + NamespaceId: mockedNamespaceID, + Execution: execution2, + })).Return(nil, serviceerror.NewInternal("")).Times(1) _, err := env.ExecuteActivity(s.a.GenerateReplicationTasks, &request) - s.NoError(err) + s.Error(err) s.Greater(len(iceptor.generateReplicationRecordedHeartbeats), 0) lastIdx := len(iceptor.generateReplicationRecordedHeartbeats) - 1 lastHeartBeat := iceptor.generateReplicationRecordedHeartbeats[lastIdx] + // Only the generation of 1st execution suceeded. s.Equal(0, lastHeartBeat) } -func (s *activitiesSuite) TestGenerateReplicationTasks_Failed() { - env, iceptor := s.initEnv() +func (s *activitiesSuite) TestCountWorkflows() { + env, _ := s.initEnv() - request := generateReplicationTasksRequest{ - NamespaceID: mockedNamespaceID, - RPS: 10, - Executions: []*commonpb.WorkflowExecution{&execution1, &execution2}, + request := &workflowservice.CountWorkflowExecutionsRequest{ + Namespace: mockedNamespace, + Query: "abc", } - s.mockHistoryClient.EXPECT().GenerateLastHistoryReplicationTasks(gomock.Any(), protomock.Eq(&historyservice.GenerateLastHistoryReplicationTasksRequest{ - NamespaceId: mockedNamespaceID, - Execution: &execution1, - })).Return(&historyservice.GenerateLastHistoryReplicationTasksResponse{}, nil).Times(1) + s.mockFrontendClient.EXPECT().CountWorkflowExecutions(gomock.Any(), protomock.Eq(request)).Return(&workflowservice.CountWorkflowExecutionsResponse{ + Count: 100, + }, nil).Times(1) - s.mockHistoryClient.EXPECT().GenerateLastHistoryReplicationTasks(gomock.Any(), protomock.Eq(&historyservice.GenerateLastHistoryReplicationTasksRequest{ - NamespaceId: mockedNamespaceID, - Execution: &execution2, - })).Return(nil, serviceerror.NewInternal("")) + f, err := env.ExecuteActivity(s.a.CountWorkflow, request) + s.NoError(err) + var output *countWorkflowResponse + err = f.Get(&output) + s.NoError(err) + s.Equal(int64(100), output.WorkflowCount) +} - _, err := env.ExecuteActivity(s.a.GenerateReplicationTasks, &request) - s.Error(err) +func (s *activitiesSuite) TestWaitCatchUp() { + env, _ := s.initEnv() - s.Greater(len(iceptor.generateReplicationRecordedHeartbeats), 0) - lastIdx := len(iceptor.generateReplicationRecordedHeartbeats) - 1 - lastHeartBeat := iceptor.generateReplicationRecordedHeartbeats[lastIdx] - // Only the generation of 1st execution suceeded. 
- s.Equal(0, lastHeartBeat) + describeNamespaceRequest := &workflowservice.DescribeNamespaceRequest{ + Namespace: mockedNamespace, + } + + getReplicationStatusRequestFromRemote := &historyservice.GetReplicationStatusRequest{ + RemoteClusters: []string{remoteCluster}, + } + + getReplicationStatusRequestFromActive := &historyservice.GetReplicationStatusRequest{ + RemoteClusters: []string{"test_cluster"}, + } + + request := CatchUpParams{ + Namespace: mockedNamespace, + RemoteCluster: remoteCluster, + } + + s.mockFrontendClient.EXPECT().DescribeNamespace(gomock.Any(), protomock.Eq(describeNamespaceRequest)).Return(&workflowservice.DescribeNamespaceResponse{ + ReplicationConfig: &replicationpb.NamespaceReplicationConfig{ + ActiveClusterName: "test_cluster", + }, + }, nil).Times(1) + + s.mockHistoryClient.EXPECT().GetReplicationStatus(gomock.Any(), protomock.Eq(getReplicationStatusRequestFromRemote)).Return(&historyservice.GetReplicationStatusResponse{ + Shards: []*historyservice.ShardReplicationStatus{ + { + ShardId: 1, + RemoteClusters: map[string]*historyservice.ShardReplicationStatusPerCluster{ + remoteCluster: { + AckedTaskId: 123, + }, + }, + }, + }, + }, nil).AnyTimes() + + s.mockHistoryClient.EXPECT().GetReplicationStatus(gomock.Any(), protomock.Eq(getReplicationStatusRequestFromActive)).Return(&historyservice.GetReplicationStatusResponse{ + Shards: []*historyservice.ShardReplicationStatus{ + { + ShardId: 1, + RemoteClusters: map[string]*historyservice.ShardReplicationStatusPerCluster{ + "test_cluster": { + AckedTaskId: 111, + }, + }, + }, + }, + }, nil).AnyTimes() + + _, err := env.ExecuteActivity(s.a.WaitCatchup, request) + s.NoError(err) } diff --git a/service/worker/migration/catchup_workflow.go b/service/worker/migration/catchup_workflow.go new file mode 100644 index 00000000000..47cf5478f09 --- /dev/null +++ b/service/worker/migration/catchup_workflow.go @@ -0,0 +1,82 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package migration + +import ( + "time" + + "go.temporal.io/sdk/temporal" + "go.temporal.io/sdk/workflow" +) + +const ( + catchupWorkflowName = "catchup" +) + +type ( + CatchUpParams struct { + Namespace string + RemoteCluster string + } + + CatchUpOutput struct{} +) + +func CatchupWorkflow(ctx workflow.Context, params CatchUpParams) (CatchUpOutput, error) { + if err := validateCatchupParams(¶ms); err != nil { + return CatchUpOutput{}, err + } + + retryPolicy := &temporal.RetryPolicy{ + InitialInterval: time.Second, + MaximumInterval: time.Second, + BackoffCoefficient: 1, + } + ao := workflow.ActivityOptions{ + StartToCloseTimeout: time.Hour, + HeartbeatTimeout: time.Second * 10, + RetryPolicy: retryPolicy, + } + ctx1 := workflow.WithActivityOptions(ctx, ao) + + var a *activities + err := workflow.ExecuteActivity(ctx1, a.WaitCatchup, params).Get(ctx, nil) + if err != nil { + return CatchUpOutput{}, err + } + + return CatchUpOutput{}, err +} + +func validateCatchupParams(params *CatchUpParams) error { + if len(params.Namespace) == 0 { + return temporal.NewNonRetryableApplicationError("InvalidArgument: Namespace is required", "InvalidArgument", nil) + } + if len(params.RemoteCluster) == 0 { + return temporal.NewNonRetryableApplicationError("InvalidArgument: RemoteCluster is required", "InvalidArgument", nil) + } + + return nil +} diff --git a/service/worker/migration/catchup_workflow_test.go b/service/worker/migration/catchup_workflow_test.go new file mode 100644 index 00000000000..ad4b34475ae --- /dev/null +++ b/service/worker/migration/catchup_workflow_test.go @@ -0,0 +1,50 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +package migration + +import ( + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.temporal.io/sdk/testsuite" +) + +func TestCatchupWorkflow(t *testing.T) { + testSuite := &testsuite.WorkflowTestSuite{} + env := testSuite.NewTestWorkflowEnvironment() + var a *activities + + env.OnActivity(a.WaitCatchup, mock.Anything, mock.Anything).Return(nil) + + env.ExecuteWorkflow(CatchupWorkflow, CatchUpParams{ + Namespace: "test-ns", + RemoteCluster: "test-remote", + }) + + require.True(t, env.IsWorkflowCompleted()) + require.NoError(t, env.GetWorkflowError()) + env.AssertExpectations(t) +} diff --git a/service/worker/migration/force_replication_workflow.go b/service/worker/migration/force_replication_workflow.go index 395378ffec8..1befce55212 100644 --- a/service/worker/migration/force_replication_workflow.go +++ b/service/worker/migration/force_replication_workflow.go @@ -33,7 +33,7 @@ import ( "go.temporal.io/api/workflowservice/v1" "go.temporal.io/sdk/temporal" "go.temporal.io/sdk/workflow" - "go.temporal.io/server/common/primitives" + "go.temporal.io/server/common/metrics" ) type ( @@ -56,6 +56,7 @@ type ( Query string `validate:"required"` // query to list workflows for replication ConcurrentActivityCount int OverallRps float64 // RPS for enqueuing of replication tasks + GetParentInfoRPS float64 // RPS for getting parent child info ListWorkflowsPageSize int // PageSize of ListWorkflow, will paginate through results. PageCountPerExecution int // number of pages to be processed before continue as new, max is 1000. NextPageToken []byte // used by continue as new @@ -71,11 +72,30 @@ type ( LastStartTime time.Time ContinuedAsNewCount int TaskQueueUserDataReplicationParams TaskQueueUserDataReplicationParams + ReplicatedWorkflowCount int64 + TotalForceReplicateWorkflowCount int64 + ReplicatedWorkflowCountPerSecond float64 + + // Used to calculate QPS + QPSQueue QPSQueue // Carry over the replication status after continue-as-new. 
TaskQueueUserDataReplicationStatus TaskQueueUserDataReplicationStatus } + QPSQueue struct { + MaxSize int + Data []QPSData + } + + QPSData struct { + Count int64 + Timestamp time.Time + } + + ForceReplicationOutput struct { + } + TaskQueueUserDataReplicationStatus struct { Done bool FailureMessage string @@ -86,42 +106,10 @@ type ( LastStartTime time.Time TaskQueueUserDataReplicationStatus TaskQueueUserDataReplicationStatus ContinuedAsNewCount int - } - - listWorkflowsResponse struct { - Executions []*commonpb.WorkflowExecution - NextPageToken []byte - Error error - - // These can be used to help report progress of the force-replication scan - LastCloseTime time.Time - LastStartTime time.Time - } - - generateReplicationTasksRequest struct { - NamespaceID string - Executions []*commonpb.WorkflowExecution - RPS float64 - } - - verifyReplicationTasksRequest struct { - Namespace string - NamespaceID string - TargetClusterEndpoint string - TargetClusterName string - VerifyInterval time.Duration `validate:"gte=0"` - Executions []*commonpb.WorkflowExecution - } - - verifyReplicationTasksResponse struct{} - - metadataRequest struct { - Namespace string - } - - metadataResponse struct { - ShardCount int32 - NamespaceID string + TotalWorkflowCount int64 + ReplicatedWorkflowCount int64 + ReplicatedWorkflowCountPerSecond float64 + PageTokenForRestart []byte } ) @@ -130,10 +118,14 @@ var ( InitialInterval: time.Second, MaximumInterval: time.Second * 10, } + + NamespaceTagName = "namespace" + ForceReplicationRpsTagName = "force_replication_rps" ) const ( forceReplicationWorkflowName = "force-replication" + forceTaskQueueUserDataReplicationWorkflow = "force-task-queue-user-data-replication" forceReplicationStatusQueryType = "force-replication-status" taskQueueUserDataReplicationDoneSignalType = "task-queue-user-data-replication-done" taskQueueUserDataReplicationVersionMarker = "replicate-task-queue-user-data" @@ -147,14 +139,20 @@ const ( ) func ForceReplicationWorkflow(ctx workflow.Context, params ForceReplicationParams) error { - ctx = workflow.WithTaskQueue(ctx, primitives.MigrationActivityTQ) + // For now, we'll return the initial page token for simplicity. + // If we want this to be more precise, we could track processed pages. 
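+ // The token captured below is exposed by the force-replication-status query as PageTokenForRestart, so an operator can restart the scan from the page this execution started on.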
+ startPageToken := params.NextPageToken - workflow.SetQueryHandler(ctx, forceReplicationStatusQueryType, func() (ForceReplicationStatus, error) { + _ = workflow.SetQueryHandler(ctx, forceReplicationStatusQueryType, func() (ForceReplicationStatus, error) { return ForceReplicationStatus{ LastCloseTime: params.LastCloseTime, LastStartTime: params.LastStartTime, ContinuedAsNewCount: params.ContinuedAsNewCount, TaskQueueUserDataReplicationStatus: params.TaskQueueUserDataReplicationStatus, + TotalWorkflowCount: params.TotalForceReplicateWorkflowCount, + ReplicatedWorkflowCount: params.ReplicatedWorkflowCount, + ReplicatedWorkflowCountPerSecond: params.ReplicatedWorkflowCountPerSecond, + PageTokenForRestart: startPageToken, }, nil }) @@ -162,6 +160,14 @@ func ForceReplicationWorkflow(ctx workflow.Context, params ForceReplicationParam return err } + if params.TotalForceReplicateWorkflowCount == 0 { + wfCount, err := countWorkflowForReplication(ctx, params) + if err != nil { + return err + } + params.TotalForceReplicateWorkflowCount = wfCount + } + metadataResp, err := getClusterMetadata(ctx, params) if err != nil { return err @@ -187,7 +193,7 @@ func ForceReplicationWorkflow(ctx workflow.Context, params ForceReplicationParam workflowExecutionsCh.Close() }) - if err := enqueueReplicationTasks(ctx, workflowExecutionsCh, metadataResp.NamespaceID, params); err != nil { + if err := enqueueReplicationTasks(ctx, workflowExecutionsCh, metadataResp.NamespaceID, &params); err != nil { return err } @@ -253,9 +259,6 @@ func maybeKickoffTaskQueueUserDataReplication(ctx workflow.Context, params Force } func ForceTaskQueueUserDataReplicationWorkflow(ctx workflow.Context, params TaskQueueUserDataReplicationParamsWithNamespace) error { - ctx = workflow.WithTaskQueue(ctx, primitives.MigrationActivityTQ) // children do not inherit ActivityOptions - - var a *activities ao := workflow.ActivityOptions{ // This shouldn't take "too long", just set an arbitrary long timeout here and rely on heartbeats for liveness detection. StartToCloseTimeout: time.Hour * 24 * 7, @@ -268,7 +271,7 @@ func ForceTaskQueueUserDataReplicationWorkflow(ctx workflow.Context, params Task } actx := workflow.WithActivityOptions(ctx, ao) - + var a *activities err := workflow.ExecuteActivity(actx, a.SeedReplicationQueueWithUserDataEntries, params).Get(ctx, nil) errStr := "" if err != nil { @@ -294,6 +297,9 @@ func validateAndSetForceReplicationParams(params *ForceReplicationParams) error if params.OverallRps <= 0 { params.OverallRps = float64(params.ConcurrentActivityCount) } + if params.GetParentInfoRPS <= 0 { + params.GetParentInfoRPS = float64(params.ConcurrentActivityCount) + } if params.ListWorkflowsPageSize <= 0 { params.ListWorkflowsPageSize = defaultListWorkflowsPageSize @@ -311,12 +317,19 @@ func validateAndSetForceReplicationParams(params *ForceReplicationParams) error params.VerifyIntervalInSeconds = defaultVerifyIntervalInSeconds } + if params.ReplicatedWorkflowCountPerSecond <= 0 { + params.ReplicatedWorkflowCountPerSecond = params.OverallRps + } + + if params.QPSQueue.Data == nil { + params.QPSQueue = NewQPSQueue(params.ConcurrentActivityCount) + params.QPSQueue.Enqueue(params.ReplicatedWorkflowCount) + } + return nil } func getClusterMetadata(ctx workflow.Context, params ForceReplicationParams) (metadataResponse, error) { - var a *activities - // Get cluster metadata, we need namespace ID for history API call. // TODO: remove this step. 
lao := workflow.LocalActivityOptions{ @@ -327,13 +340,12 @@ func getClusterMetadata(ctx workflow.Context, params ForceReplicationParams) (me actx := workflow.WithLocalActivityOptions(ctx, lao) var metadataResp metadataResponse metadataRequest := metadataRequest{Namespace: params.Namespace} + var a *activities err := workflow.ExecuteLocalActivity(actx, a.GetMetadata, metadataRequest).Get(ctx, &metadataResp) return metadataResp, err } func listWorkflowsForReplication(ctx workflow.Context, workflowExecutionsCh workflow.Channel, params *ForceReplicationParams) error { - var a *activities - ao := workflow.ActivityOptions{ StartToCloseTimeout: time.Hour, HeartbeatTimeout: time.Second * 30, @@ -341,14 +353,17 @@ func listWorkflowsForReplication(ctx workflow.Context, workflowExecutionsCh work } actx := workflow.WithActivityOptions(ctx, ao) - + var a *activities for i := 0; i < params.PageCountPerExecution; i++ { - listFuture := workflow.ExecuteActivity(actx, a.ListWorkflows, &workflowservice.ListWorkflowExecutionsRequest{ - Namespace: params.Namespace, - PageSize: int32(params.ListWorkflowsPageSize), - NextPageToken: params.NextPageToken, - Query: params.Query, - }) + listFuture := workflow.ExecuteActivity( + actx, + a.ListWorkflows, + &workflowservice.ListWorkflowExecutionsRequest{ + Namespace: params.Namespace, + PageSize: int32(params.ListWorkflowsPageSize), + NextPageToken: params.NextPageToken, + Query: params.Query, + }) var listResp listWorkflowsResponse if err := listFuture.Get(ctx, &listResp); err != nil { @@ -369,29 +384,54 @@ func listWorkflowsForReplication(ctx workflow.Context, workflowExecutionsCh work return nil } -func enqueueReplicationTasks(ctx workflow.Context, workflowExecutionsCh workflow.Channel, namespaceID string, params ForceReplicationParams) error { +func countWorkflowForReplication(ctx workflow.Context, params ForceReplicationParams) (int64, error) { + ao := workflow.ActivityOptions{ + StartToCloseTimeout: 2 * time.Minute, + RetryPolicy: forceReplicationActivityRetryPolicy, + } + + var a *activities + var output countWorkflowResponse + if err := workflow.ExecuteActivity( + workflow.WithActivityOptions(ctx, ao), + a.CountWorkflow, + &workflowservice.CountWorkflowExecutionsRequest{ + Namespace: params.Namespace, + Query: params.Query, + }).Get(ctx, &output); err != nil { + return 0, err + } + + return output.WorkflowCount, nil +} + +func enqueueReplicationTasks(ctx workflow.Context, workflowExecutionsCh workflow.Channel, namespaceID string, params *ForceReplicationParams) error { selector := workflow.NewSelector(ctx) pendingGenerateTasks := 0 pendingVerifyTasks := 0 ao := workflow.ActivityOptions{ StartToCloseTimeout: time.Hour, - HeartbeatTimeout: time.Second * 30, + HeartbeatTimeout: time.Second * 60, RetryPolicy: forceReplicationActivityRetryPolicy, } actx := workflow.WithActivityOptions(ctx, ao) - var a *activities var futures []workflow.Future var workflowExecutions []*commonpb.WorkflowExecution var lastActivityErr error + var a *activities for workflowExecutionsCh.Receive(ctx, &workflowExecutions) { - generateTaskFuture := workflow.ExecuteActivity(actx, a.GenerateReplicationTasks, &generateReplicationTasksRequest{ - NamespaceID: namespaceID, - Executions: workflowExecutions, - RPS: params.OverallRps / float64(params.ConcurrentActivityCount), - }) + generateTaskFuture := workflow.ExecuteActivity( + actx, + a.GenerateReplicationTasks, + &generateReplicationTasksRequest{ + NamespaceID: namespaceID, + Executions: workflowExecutions, + RPS: params.OverallRps / 
float64(params.ConcurrentActivityCount), + GetParentInfoRPS: params.GetParentInfoRPS / float64(params.ConcurrentActivityCount), + }) pendingGenerateTasks++ selector.AddFuture(generateTaskFuture, func(f workflow.Future) { @@ -404,14 +444,17 @@ func enqueueReplicationTasks(ctx workflow.Context, workflowExecutionsCh workflow futures = append(futures, generateTaskFuture) if params.EnableVerification { - verifyTaskFuture := workflow.ExecuteActivity(actx, a.VerifyReplicationTasks, &verifyReplicationTasksRequest{ - TargetClusterEndpoint: params.TargetClusterEndpoint, - TargetClusterName: params.TargetClusterName, - Namespace: params.Namespace, - NamespaceID: namespaceID, - Executions: workflowExecutions, - VerifyInterval: time.Duration(params.VerifyIntervalInSeconds) * time.Second, - }) + verifyTaskFuture := workflow.ExecuteActivity( + actx, + a.VerifyReplicationTasks, + &verifyReplicationTasksRequest{ + TargetClusterEndpoint: params.TargetClusterEndpoint, + TargetClusterName: params.TargetClusterName, + Namespace: params.Namespace, + NamespaceID: namespaceID, + Executions: workflowExecutions, + VerifyInterval: time.Duration(params.VerifyIntervalInSeconds) * time.Second, + }) pendingVerifyTasks++ selector.AddFuture(verifyTaskFuture, func(f workflow.Future) { @@ -419,6 +462,18 @@ func enqueueReplicationTasks(ctx workflow.Context, workflowExecutionsCh workflow if err := f.Get(ctx, nil); err != nil { lastActivityErr = err + } else { + // Update replication status + params.ReplicatedWorkflowCount += int64(len(workflowExecutions)) + params.QPSQueue.Enqueue(params.ReplicatedWorkflowCount) + params.ReplicatedWorkflowCountPerSecond = params.QPSQueue.CalculateQPS() + + // Report new QPS to metrics + tags := map[string]string{ + metrics.OperationTagName: metrics.MigrationWorkflowScope, + NamespaceTagName: params.Namespace, + } + workflow.GetMetricsHandler(ctx).WithTags(tags).Gauge(ForceReplicationRpsTagName).Update(params.ReplicatedWorkflowCountPerSecond) } }) @@ -441,3 +496,47 @@ func enqueueReplicationTasks(ctx workflow.Context, workflowExecutionsCh workflow return nil } + +// NewQPSQueue initializes a QPSQueue to collect data points for each workflow execution. +// The queue size is set to concurrency + 1 to account for up to 'concurrency' activities +// running simultaneously and the initial starting point. 
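+// For example, with ConcurrentActivityCount = 2 the queue keeps at most 3 samples; if the oldest sample is (Count: 100, Timestamp: t) and the newest is (Count: 160, Timestamp: t+30s), CalculateQPS reports (160-100)/30 = 2 replicated workflows per second.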
+func NewQPSQueue(concurrentActivityCount int) QPSQueue { + return QPSQueue{ + Data: make([]QPSData, 0, concurrentActivityCount+1), + MaxSize: concurrentActivityCount + 1, + } +} + +func (q *QPSQueue) Enqueue(count int64) { + data := QPSData{Count: count, Timestamp: time.Now()} + + // If queue length reaches max capacity, remove the oldest item + if len(q.Data) >= q.MaxSize { + q.Data = q.Data[1:] + } + + q.Data = append(q.Data, data) +} + +func (q *QPSQueue) CalculateQPS() float64 { + // Check if the queue has at least two items + if len(q.Data) < 2 { + return 0.0 + } + + first := q.Data[0] + last := q.Data[len(q.Data)-1] + + // Calculate the count difference and time difference + countDiff := last.Count - first.Count + timeDiff := last.Timestamp.Sub(first.Timestamp).Seconds() + + // If count difference is <= 0 or time difference is <= 0, return a rate of 0 + if countDiff <= 0 || timeDiff <= 0 { + return 0.0 + } + + // Calculate the QPS + qps := float64(countDiff) / timeDiff + return qps +} diff --git a/service/worker/migration/force_replication_workflow_test.go b/service/worker/migration/force_replication_workflow_test.go index c07ebacaa6d..d6f88eb766b 100644 --- a/service/worker/migration/force_replication_workflow_test.go +++ b/service/worker/migration/force_replication_workflow_test.go @@ -26,10 +26,14 @@ package migration import ( "context" + "encoding/json" "errors" + "fmt" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/pborman/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -41,6 +45,7 @@ import ( "go.temporal.io/sdk/temporal" "go.temporal.io/sdk/testsuite" "go.temporal.io/sdk/worker" + "go.temporal.io/sdk/workflow" replicationspb "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/common/log" "go.temporal.io/server/common/persistence" @@ -51,11 +56,11 @@ import ( func TestForceReplicationWorkflow(t *testing.T) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() - + env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) namespaceID := uuid.New() var a *activities - + env.OnActivity(a.CountWorkflow, mock.Anything, mock.Anything).Return(&countWorkflowResponse{WorkflowCount: 10}, nil) env.OnActivity(a.GetMetadata, mock.Anything, metadataRequest{Namespace: "test-ns"}).Return(&metadataResponse{ShardCount: 4, NamespaceID: namespaceID}, nil) totalPageCount := 4 @@ -83,13 +88,10 @@ func TestForceReplicationWorkflow(t *testing.T) { LastCloseTime: closeTime, }, nil }).Times(totalPageCount) - env.OnActivity(a.GenerateReplicationTasks, mock.Anything, mock.Anything).Return(nil).Times(totalPageCount) env.OnActivity(a.VerifyReplicationTasks, mock.Anything, mock.Anything).Return(verifyReplicationTasksResponse{}, nil).Times(totalPageCount) - env.RegisterWorkflow(ForceTaskQueueUserDataReplicationWorkflow) env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil).Times(1) - env.ExecuteWorkflow(ForceReplicationWorkflow, ForceReplicationParams{ Namespace: "test-ns", Query: "", @@ -116,31 +118,26 @@ func TestForceReplicationWorkflow(t *testing.T) { assert.Equal(t, closeTime, status.LastCloseTime) assert.True(t, status.TaskQueueUserDataReplicationStatus.Done) assert.Equal(t, "", status.TaskQueueUserDataReplicationStatus.FailureMessage) + assert.Equal(t, int64(10), status.TotalWorkflowCount) + assert.Equal(t, int64(0), 
status.ReplicatedWorkflowCount) + assert.Equal(t, []byte(nil), status.PageTokenForRestart) } func TestForceReplicationWorkflow_ContinueAsNew(t *testing.T) { - testSuite := &testsuite.WorkflowTestSuite{} - env := testSuite.NewTestWorkflowEnvironment() - - namespaceID := uuid.New() - - var a *activities - env.OnActivity(a.GetMetadata, mock.Anything, metadataRequest{Namespace: "test-ns"}).Return(&metadataResponse{ShardCount: 4, NamespaceID: namespaceID}, nil) - totalPageCount := 4 currentPageCount := 0 - maxPageCountPerExecution := 2 + testMaxPageCountPerExecution := 2 layout := "2006-01-01 00:00Z" startTime, _ := time.Parse(layout, "2020-01-01 00:00Z") closeTime, _ := time.Parse(layout, "2020-02-01 00:00Z") - env.OnActivity(a.ListWorkflows, mock.Anything, mock.Anything).Return(func(ctx context.Context, request *workflowservice.ListWorkflowExecutionsRequest) (*listWorkflowsResponse, error) { + mockListWorkflows := func(ctx context.Context, request *workflowservice.ListWorkflowExecutionsRequest) (*listWorkflowsResponse, error) { assert.Equal(t, "test-ns", request.Namespace) currentPageCount++ if currentPageCount < totalPageCount { return &listWorkflowsResponse{ Executions: []*commonpb.WorkflowExecution{}, - NextPageToken: []byte("fake-page-token"), + NextPageToken: []byte(fmt.Sprintf("fake-page-token-%d", currentPageCount)), LastStartTime: startTime, LastCloseTime: closeTime, }, nil @@ -150,39 +147,135 @@ func TestForceReplicationWorkflow_ContinueAsNew(t *testing.T) { Executions: []*commonpb.WorkflowExecution{}, NextPageToken: nil, // last page }, nil - }).Times(maxPageCountPerExecution) - - env.OnActivity(a.GenerateReplicationTasks, mock.Anything, mock.Anything).Return(nil).Times(maxPageCountPerExecution) - env.OnActivity(a.VerifyReplicationTasks, mock.Anything, mock.Anything).Return(verifyReplicationTasksResponse{}, nil).Times(maxPageCountPerExecution) - - env.RegisterWorkflow(ForceTaskQueueUserDataReplicationWorkflow) - env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil) + } - env.ExecuteWorkflow(ForceReplicationWorkflow, ForceReplicationParams{ + expectedContinueAsNewParams := ForceReplicationParams{ Namespace: "test-ns", Query: "", ConcurrentActivityCount: 2, OverallRps: 10, + GetParentInfoRPS: 2.0, ListWorkflowsPageSize: 1, - PageCountPerExecution: maxPageCountPerExecution, + PageCountPerExecution: testMaxPageCountPerExecution, + NextPageToken: []byte("fake-page-token-2"), EnableVerification: true, TargetClusterEndpoint: "test-target", - }) + TargetClusterName: "", + VerifyIntervalInSeconds: defaultVerifyIntervalInSeconds, + LastCloseTime: closeTime, + LastStartTime: startTime, + ContinuedAsNewCount: 1, + TaskQueueUserDataReplicationParams: TaskQueueUserDataReplicationParams{ + PageSize: 0, + RPS: 0, + }, + ReplicatedWorkflowCount: 0, + TotalForceReplicateWorkflowCount: 10, + } + + expectContinueAsNew := true + + // Run the workflow once. We should get a continue as new error. 
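+ // The helper below decodes the continue-as-new input payload back into ForceReplicationParams so it can be diffed against expectedContinueAsNewParams.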
+ continueAsNewInput, queryStatus := testRunForceReplicationForContinueAsNew(t, + mockListWorkflows, + ForceReplicationParams{ + Namespace: "test-ns", + Query: "", + ConcurrentActivityCount: 2, + OverallRps: 10, + ListWorkflowsPageSize: 1, + PageCountPerExecution: testMaxPageCountPerExecution, + EnableVerification: true, + TargetClusterEndpoint: "test-target", + NextPageToken: []byte("fake-initial-page-token"), + }, + expectContinueAsNew, + testMaxPageCountPerExecution, + ) + require.NotNil(t, continueAsNewInput) + + // tl;dr We do not check TaskQueueUserDataReplicationStatus because we do not know if + // ForceTaskQueueUserDataReplicationWorkflow will complete in a given execution of + // force replication. + // + // ForceTaskQueueUserDataReplicationWorkflow is a child workflow that runs in parallel + // and may span many ContinueAsNew'd executions of force replication: + // + // - It is started on the first execution as a child workflow (when ContinueAsNewCount == 0). + // - It is not started on subsequent executions (when ContinueAsNewCount > 0) + // - Only the final execution waits for it to complete (when NextPageToken == nil) + // + // To keep this test simple and to resolve past test flakes, we do not test all cases. + // We test with ContinueAsNewCount > 0 and NextPageToken != nil, so that: + // + // - ForceTaskQueueUserDataReplicationWorkflow is not started + // - Force replication does not get stuck waiting for some previous execution of + // ForceTaskQueueUserDataReplicationWorkflow to finish + // + // Another test checks that ForceTaskQueueUserDataReplicationWorkflow is invoked correctly. + require.Empty(t, cmp.Diff( + expectedContinueAsNewParams, *continueAsNewInput, + cmpopts.IgnoreFields(ForceReplicationParams{}, "TaskQueueUserDataReplicationStatus"), + cmpopts.IgnoreFields(ForceReplicationParams{}, "QPSQueue"), + )) + + assert.Equal(t, closeTime, queryStatus.LastCloseTime) + assert.Equal(t, startTime, queryStatus.LastStartTime) + assert.Equal(t, 1, queryStatus.ContinuedAsNewCount) + assert.Equal(t, []byte("fake-initial-page-token"), queryStatus.PageTokenForRestart) +} + +func testRunForceReplicationForContinueAsNew(t *testing.T, + mockListWorkflows func(context.Context, *workflowservice.ListWorkflowExecutionsRequest) (*listWorkflowsResponse, error), + input ForceReplicationParams, + expectContinueAsNew bool, + expMaxPageCountPerExecution int, +) (*ForceReplicationParams, ForceReplicationStatus) { + testSuite := &testsuite.WorkflowTestSuite{} + env := testSuite.NewTestWorkflowEnvironment() + env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) + namespaceID := uuid.New() + + var a *activities + if input.TotalForceReplicateWorkflowCount == 0 { + env.OnActivity(a.CountWorkflow, mock.Anything, mock.Anything).Return(&countWorkflowResponse{WorkflowCount: 10}, nil) + } + env.OnActivity(a.GetMetadata, mock.Anything, metadataRequest{Namespace: "test-ns"}).Return(&metadataResponse{ShardCount: 4, NamespaceID: namespaceID}, nil) + env.OnActivity(a.ListWorkflows, mock.Anything, mock.Anything).Return(mockListWorkflows).Times(expMaxPageCountPerExecution) + env.OnActivity(a.GenerateReplicationTasks, mock.Anything, mock.Anything).Return(nil).Times(expMaxPageCountPerExecution) + env.OnActivity(a.VerifyReplicationTasks, mock.Anything, mock.Anything).Return(verifyReplicationTasksResponse{}, nil).Times(expMaxPageCountPerExecution) + // ForceTaskQueueUserDataReplicationWorkflow runs in parallel as a 
child and may span many ContinueAsNew'd + // executions of ForceReplication. The SeedReplicationQueueWithUserDataEntries activity will eventually run + // once, but we aren't guaranteed that it will run during any given execution of ForceReplication. + env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil).Maybe() + env.ExecuteWorkflow(ForceReplicationWorkflow, input) require.True(t, env.IsWorkflowCompleted()) err := env.GetWorkflowError() - require.Error(t, err) - require.Contains(t, err.Error(), "continue as new") env.AssertExpectations(t) + var continueAsNewErr *workflow.ContinueAsNewError + var continueAsNewParams *ForceReplicationParams + if !expectContinueAsNew { + require.NoError(t, err) + } else { + require.Error(t, err) + require.True(t, errors.As(err, &continueAsNewErr)) + + var params ForceReplicationParams + payloads := continueAsNewErr.Input.GetPayloads() + require.Len(t, payloads, 1) + require.NoError(t, json.Unmarshal(payloads[0].GetData(), &params)) + continueAsNewParams = &params + } + envValue, err := env.QueryWorkflow(forceReplicationStatusQueryType) require.NoError(t, err) var status ForceReplicationStatus - envValue.Get(&status) - assert.Equal(t, 1, status.ContinuedAsNewCount) - assert.Equal(t, startTime, status.LastStartTime) - assert.Equal(t, closeTime, status.LastCloseTime) + require.NoError(t, envValue.Get(&status)) + + return continueAsNewParams, status } func TestForceReplicationWorkflow_InvalidInput(t *testing.T) { @@ -213,16 +306,16 @@ func TestForceReplicationWorkflow_InvalidInput(t *testing.T) { func TestForceReplicationWorkflow_ListWorkflowsError(t *testing.T) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() - + env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) namespaceID := uuid.New() var a *activities + env.OnActivity(a.CountWorkflow, mock.Anything, mock.Anything).Return(&countWorkflowResponse{WorkflowCount: 10}, nil) env.OnActivity(a.GetMetadata, mock.Anything, metadataRequest{Namespace: "test-ns"}).Return(&metadataResponse{ShardCount: 4, NamespaceID: namespaceID}, nil) maxPageCountPerExecution := 2 env.OnActivity(a.ListWorkflows, mock.Anything, mock.Anything).Return(nil, errors.New("mock listWorkflows error")) - env.RegisterWorkflow(ForceTaskQueueUserDataReplicationWorkflow) env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil) env.ExecuteWorkflow(ForceReplicationWorkflow, ForceReplicationParams{ @@ -244,10 +337,11 @@ func TestForceReplicationWorkflow_ListWorkflowsError(t *testing.T) { func TestForceReplicationWorkflow_GenerateReplicationTaskRetryableError(t *testing.T) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() - + env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) namespaceID := uuid.New() var a *activities + env.OnActivity(a.CountWorkflow, mock.Anything, mock.Anything).Return(&countWorkflowResponse{WorkflowCount: 10}, nil) env.OnActivity(a.GetMetadata, mock.Anything, metadataRequest{Namespace: "test-ns"}).Return(&metadataResponse{ShardCount: 4, NamespaceID: namespaceID}, nil) totalPageCount := 4 @@ -270,7 +364,6 @@ func TestForceReplicationWorkflow_GenerateReplicationTaskRetryableError(t *testi env.OnActivity(a.GenerateReplicationTasks, mock.Anything, 
mock.Anything).Return(errors.New("mock generate replication tasks error")) - env.RegisterWorkflow(ForceTaskQueueUserDataReplicationWorkflow) env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil) env.ExecuteWorkflow(ForceReplicationWorkflow, ForceReplicationParams{ @@ -292,10 +385,11 @@ func TestForceReplicationWorkflow_GenerateReplicationTaskRetryableError(t *testi func TestForceReplicationWorkflow_GenerateReplicationTaskNonRetryableError(t *testing.T) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() - + env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) namespaceID := uuid.New() var a *activities + env.OnActivity(a.CountWorkflow, mock.Anything, mock.Anything).Return(&countWorkflowResponse{WorkflowCount: 10}, nil) env.OnActivity(a.GetMetadata, mock.Anything, metadataRequest{Namespace: "test-ns"}).Return(&metadataResponse{ShardCount: 4, NamespaceID: namespaceID}, nil) totalPageCount := 4 @@ -323,7 +417,6 @@ func TestForceReplicationWorkflow_GenerateReplicationTaskNonRetryableError(t *te temporal.NewNonRetryableApplicationError(errMsg, "", nil), ).Times(1) - env.RegisterWorkflow(ForceTaskQueueUserDataReplicationWorkflow) env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil) env.ExecuteWorkflow(ForceReplicationWorkflow, ForceReplicationParams{ @@ -347,10 +440,11 @@ func TestForceReplicationWorkflow_GenerateReplicationTaskNonRetryableError(t *te func TestForceReplicationWorkflow_VerifyReplicationTaskNonRetryableError(t *testing.T) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() - + env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) namespaceID := uuid.New() var a *activities + env.OnActivity(a.CountWorkflow, mock.Anything, mock.Anything).Return(&countWorkflowResponse{WorkflowCount: 10}, nil) env.OnActivity(a.GetMetadata, mock.Anything, metadataRequest{Namespace: "test-ns"}).Return(&metadataResponse{ShardCount: 4, NamespaceID: namespaceID}, nil) totalPageCount := 4 @@ -379,7 +473,6 @@ func TestForceReplicationWorkflow_VerifyReplicationTaskNonRetryableError(t *test temporal.NewNonRetryableApplicationError(errMsg, "", nil), ).Times(1) - env.RegisterWorkflow(ForceTaskQueueUserDataReplicationWorkflow) env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return(nil) env.ExecuteWorkflow(ForceReplicationWorkflow, ForceReplicationParams{ @@ -403,10 +496,11 @@ func TestForceReplicationWorkflow_VerifyReplicationTaskNonRetryableError(t *test func TestForceReplicationWorkflow_TaskQueueReplicationFailure(t *testing.T) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() - + env.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) namespaceID := uuid.New() var a *activities + env.OnActivity(a.CountWorkflow, mock.Anything, mock.Anything).Return(&countWorkflowResponse{WorkflowCount: 10}, nil) env.OnActivity(a.GetMetadata, mock.Anything, metadataRequest{Namespace: "test-ns"}).Return(&metadataResponse{ShardCount: 4, NamespaceID: namespaceID}, nil) env.OnActivity(a.ListWorkflows, mock.Anything, mock.Anything).Return(&listWorkflowsResponse{ @@ -414,7 +508,6 @@ func 
TestForceReplicationWorkflow_TaskQueueReplicationFailure(t *testing.T) { NextPageToken: nil, // last page }, nil) env.OnActivity(a.GenerateReplicationTasks, mock.Anything, mock.Anything).Return(nil) - env.RegisterWorkflow(ForceTaskQueueUserDataReplicationWorkflow) env.OnActivity(a.SeedReplicationQueueWithUserDataEntries, mock.Anything, mock.Anything).Return( temporal.NewNonRetryableApplicationError("namespace is required", "InvalidArgument", nil), ) @@ -440,6 +533,7 @@ func TestForceReplicationWorkflow_TaskQueueReplicationFailure(t *testing.T) { require.NoError(t, err) assert.True(t, status.TaskQueueUserDataReplicationStatus.Done) assert.Contains(t, status.TaskQueueUserDataReplicationStatus.FailureMessage, "namespace is required") + assert.Equal(t, []byte(nil), status.PageTokenForRestart) } func TestSeedReplicationQueueWithUserDataEntries_Heartbeats(t *testing.T) { diff --git a/service/worker/migration/fx.go b/service/worker/migration/fx.go index ab7dec7ddaf..3e75c568035 100644 --- a/service/worker/migration/fx.go +++ b/service/worker/migration/fx.go @@ -59,25 +59,38 @@ type ( MetricsHandler metrics.Handler } + fxResult struct { + fx.Out + Component workercommon.WorkerComponent `group:"workerComponent"` + } + replicationWorkerComponent struct { initParams } ) -var Module = workercommon.AnnotateWorkerComponentProvider(newComponent) +var Module = fx.Options( + fx.Provide(NewResult), +) -func newComponent(params initParams) workercommon.WorkerComponent { - return &replicationWorkerComponent{initParams: params} +func NewResult(params initParams) fxResult { + component := &replicationWorkerComponent{ + initParams: params, + } + return fxResult{ + Component: component, + } } func (wc *replicationWorkerComponent) RegisterWorkflow(registry sdkworker.Registry) { + registry.RegisterWorkflowWithOptions(CatchupWorkflow, workflow.RegisterOptions{Name: catchupWorkflowName}) registry.RegisterWorkflowWithOptions(ForceReplicationWorkflow, workflow.RegisterOptions{Name: forceReplicationWorkflowName}) registry.RegisterWorkflowWithOptions(NamespaceHandoverWorkflow, workflow.RegisterOptions{Name: namespaceHandoverWorkflowName}) - registry.RegisterWorkflow(ForceTaskQueueUserDataReplicationWorkflow) + registry.RegisterWorkflowWithOptions(ForceTaskQueueUserDataReplicationWorkflow, workflow.RegisterOptions{Name: forceTaskQueueUserDataReplicationWorkflow}) } func (wc *replicationWorkerComponent) DedicatedWorkflowWorkerOptions() *workercommon.DedicatedWorkerOptions { - // use default worker + // Use default worker return nil } diff --git a/service/worker/migration/handover_workflow.go b/service/worker/migration/handover_workflow.go index 1622433004a..df2811ea87d 100644 --- a/service/worker/migration/handover_workflow.go +++ b/service/worker/migration/handover_workflow.go @@ -30,14 +30,14 @@ import ( enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/sdk/temporal" "go.temporal.io/sdk/workflow" - "go.temporal.io/server/common/primitives" ) const ( namespaceHandoverWorkflowName = "namespace-handover" minimumAllowedLaggingSeconds = 5 - minimumHandoverTimeoutSeconds = 30 + maximumAllowedLaggingSeconds = 120 + maximumHandoverTimeoutSeconds = 10 ) type ( @@ -87,8 +87,6 @@ func NamespaceHandoverWorkflow(ctx workflow.Context, params NamespaceHandoverPar return err } - ctx = workflow.WithTaskQueue(ctx, primitives.MigrationActivityTQ) - retryPolicy := &temporal.RetryPolicy{ InitialInterval: time.Second, MaximumInterval: time.Second, @@ -100,11 +98,10 @@ func NamespaceHandoverWorkflow(ctx workflow.Context, params 
NamespaceHandoverPar } ctx = workflow.WithActivityOptions(ctx, ao) - var a *activities - // ** Step 1: Get Cluster Metadata ** var metadataResp metadataResponse metadataRequest := metadataRequest{Namespace: params.Namespace} + var a *activities err := workflow.ExecuteActivity(ctx, a.GetMetadata, metadataRequest).Get(ctx, &metadataResp) if err != nil { return err @@ -136,7 +133,7 @@ func NamespaceHandoverWorkflow(ctx workflow.Context, params NamespaceHandoverPar return err } - // ** Step 4: Initiate Handover (WARNING: Namespace cannot serve traffic while in this state) + // ** Step 4: RecoverOrInitialize Handover (WARNING: Namespace cannot serve traffic while in this state) handoverRequest := updateStateRequest{ Namespace: params.Namespace, NewState: enumspb.REPLICATION_STATE_HANDOVER, @@ -163,10 +160,11 @@ func NamespaceHandoverWorkflow(ctx workflow.Context, params NamespaceHandoverPar // ** Step 5: Wait for Remote Cluster to completely drain its Replication Tasks ao3 := workflow.ActivityOptions{ - StartToCloseTimeout: time.Second * 30, - HeartbeatTimeout: time.Second * 10, - ScheduleToCloseTimeout: time.Second * time.Duration(params.HandoverTimeoutSeconds), - RetryPolicy: retryPolicy, + StartToCloseTimeout: time.Second * time.Duration(params.HandoverTimeoutSeconds), + HeartbeatTimeout: time.Second * 10, + RetryPolicy: &temporal.RetryPolicy{ + MaximumAttempts: 1, + }, } ctx3 := workflow.WithActivityOptions(ctx, ao3) @@ -203,8 +201,11 @@ func validateAndSetNamespaceHandoverParams(params *NamespaceHandoverParams) erro if params.AllowedLaggingSeconds <= minimumAllowedLaggingSeconds { params.AllowedLaggingSeconds = minimumAllowedLaggingSeconds } - if params.HandoverTimeoutSeconds <= minimumHandoverTimeoutSeconds { - params.HandoverTimeoutSeconds = minimumHandoverTimeoutSeconds + if params.AllowedLaggingSeconds >= maximumAllowedLaggingSeconds { + params.AllowedLaggingSeconds = maximumAllowedLaggingSeconds + } + if params.HandoverTimeoutSeconds >= maximumHandoverTimeoutSeconds { + params.HandoverTimeoutSeconds = maximumHandoverTimeoutSeconds } return nil diff --git a/service/worker/migration/handover_workflow_test.go b/service/worker/migration/handover_workflow_test.go index 2ddacbde551..aa4bbb4eb9b 100644 --- a/service/worker/migration/handover_workflow_test.go +++ b/service/worker/migration/handover_workflow_test.go @@ -36,9 +36,10 @@ import ( func TestHandoverWorkflow(t *testing.T) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestWorkflowEnvironment() + var a *activities + namespaceID := uuid.New() - var a *activities env.OnActivity(a.GetMetadata, mock.Anything, metadataRequest{Namespace: "test-ns"}).Return(&metadataResponse{ShardCount: 4, NamespaceID: namespaceID}, nil) env.OnActivity(a.GetMaxReplicationTaskIDs, mock.Anything).Return(