diff --git a/common/metrics/defs.go b/common/metrics/defs.go index bef10f86deb..be6bafb6746 100644 --- a/common/metrics/defs.go +++ b/common/metrics/defs.go @@ -1475,6 +1475,7 @@ const ( ShardDistributorStoreAssignShardScope ShardDistributorStoreAssignShardsScope ShardDistributorStoreDeleteExecutorsScope + ShardDistributorStoreDeleteShardStatsScope ShardDistributorStoreGetHeartbeatScope ShardDistributorStoreGetStateScope ShardDistributorStoreRecordHeartbeatScope @@ -2151,18 +2152,19 @@ var ScopeDefs = map[ServiceIdx]map[ScopeIdx]scopeDefinition{ DiagnosticsWorkflowScope: {operation: "DiagnosticsWorkflow"}, }, ShardDistributor: { - ShardDistributorGetShardOwnerScope: {operation: "GetShardOwner"}, - ShardDistributorHeartbeatScope: {operation: "ExecutorHeartbeat"}, - ShardDistributorAssignLoopScope: {operation: "ShardAssignLoop"}, - ShardDistributorExecutorScope: {operation: "Executor"}, - ShardDistributorStoreGetShardOwnerScope: {operation: "StoreGetShardOwner"}, - ShardDistributorStoreAssignShardScope: {operation: "StoreAssignShard"}, - ShardDistributorStoreAssignShardsScope: {operation: "StoreAssignShards"}, - ShardDistributorStoreDeleteExecutorsScope: {operation: "StoreDeleteExecutors"}, - ShardDistributorStoreGetHeartbeatScope: {operation: "StoreGetHeartbeat"}, - ShardDistributorStoreGetStateScope: {operation: "StoreGetState"}, - ShardDistributorStoreRecordHeartbeatScope: {operation: "StoreRecordHeartbeat"}, - ShardDistributorStoreSubscribeScope: {operation: "StoreSubscribe"}, + ShardDistributorGetShardOwnerScope: {operation: "GetShardOwner"}, + ShardDistributorHeartbeatScope: {operation: "ExecutorHeartbeat"}, + ShardDistributorAssignLoopScope: {operation: "ShardAssignLoop"}, + ShardDistributorExecutorScope: {operation: "Executor"}, + ShardDistributorStoreGetShardOwnerScope: {operation: "StoreGetShardOwner"}, + ShardDistributorStoreAssignShardScope: {operation: "StoreAssignShard"}, + ShardDistributorStoreAssignShardsScope: {operation: "StoreAssignShards"}, + ShardDistributorStoreDeleteExecutorsScope: {operation: "StoreDeleteExecutors"}, + ShardDistributorStoreDeleteShardStatsScope: {operation: "StoreDeleteShardStats"}, + ShardDistributorStoreGetHeartbeatScope: {operation: "StoreGetHeartbeat"}, + ShardDistributorStoreGetStateScope: {operation: "StoreGetState"}, + ShardDistributorStoreRecordHeartbeatScope: {operation: "StoreRecordHeartbeat"}, + ShardDistributorStoreSubscribeScope: {operation: "StoreSubscribe"}, }, } diff --git a/service/sharddistributor/leader/process/processor.go b/service/sharddistributor/leader/process/processor.go index 7f41c21b451..4dee09fa532 100644 --- a/service/sharddistributor/leader/process/processor.go +++ b/service/sharddistributor/leader/process/processor.go @@ -230,16 +230,21 @@ func (p *namespaceProcessor) runCleanupLoop(ctx context.Context) { return case <-ticker.Chan(): p.logger.Info("Periodic heartbeat cleanup triggered.") - p.cleanupStaleExecutors(ctx) + namespaceState, err := p.shardStore.GetState(ctx, p.namespaceCfg.Name) + if err != nil { + p.logger.Error("Failed to get state for cleanup", tag.Error(err)) + continue + } + p.cleanupStaleExecutors(ctx, namespaceState) + p.cleanupStaleShardStats(ctx, namespaceState) } } } // cleanupStaleExecutors removes executors who have not reported a heartbeat recently. -func (p *namespaceProcessor) cleanupStaleExecutors(ctx context.Context) { - namespaceState, err := p.shardStore.GetState(ctx, p.namespaceCfg.Name) - if err != nil { - p.logger.Error("Failed to get state for heartbeat cleanup", tag.Error(err)) +func (p *namespaceProcessor) cleanupStaleExecutors(ctx context.Context, namespaceState *store.NamespaceState) { + if namespaceState == nil { + p.logger.Error("Namespace state missing for heartbeat cleanup") return } @@ -264,6 +269,73 @@ func (p *namespaceProcessor) cleanupStaleExecutors(ctx context.Context) { } } +func (p *namespaceProcessor) cleanupStaleShardStats(ctx context.Context, namespaceState *store.NamespaceState) { + if namespaceState == nil { + p.logger.Error("Namespace state missing for shard stats cleanup") + return + } + + activeShards := make(map[string]struct{}) + now := p.timeSource.Now().Unix() + shardStatsTTL := int64(p.cfg.HeartbeatTTL.Seconds()) + + // 1. build set of active executors + + // add all assigned shards from executors that are ACTIVE and not stale + for executorID, assignedState := range namespaceState.ShardAssignments { + executor, exists := namespaceState.Executors[executorID] + if !exists { + continue + } + + isActive := executor.Status == types.ExecutorStatusACTIVE + isNotStale := (now - executor.LastHeartbeat) <= shardStatsTTL + if isActive && isNotStale { + for shardID := range assignedState.AssignedShards { + activeShards[shardID] = struct{}{} + } + } + } + + // add all shards in ReportedShards where the status is not DONE + for _, heartbeatState := range namespaceState.Executors { + for shardID, shardStatusReport := range heartbeatState.ReportedShards { + if shardStatusReport.Status != types.ShardStatusDONE { + activeShards[shardID] = struct{}{} + } + } + } + + // 2. build set of stale shard stats + + // append all shard stats that are not in the active shards set + var staleShardStats []string + for shardID, stats := range namespaceState.ShardStats { + if _, ok := activeShards[shardID]; ok { + continue + } + recentUpdate := stats.LastUpdateTime > 0 && (now-stats.LastUpdateTime) <= shardStatsTTL + recentMove := stats.LastMoveTime > 0 && (now-stats.LastMoveTime) <= shardStatsTTL + if recentUpdate || recentMove { + // Preserve stats that have been updated recently to allow cooldown/load history to + // survive executor churn. These shards are likely awaiting reassignment, + // so we don't want to delete them. + continue + } + staleShardStats = append(staleShardStats, shardID) + } + + if len(staleShardStats) == 0 { + return + } + + p.logger.Info("Removing stale shard stats") + // Use the leader guard for the delete operation. + if err := p.shardStore.DeleteShardStats(ctx, p.namespaceCfg.Name, staleShardStats, p.election.Guard()); err != nil { + p.logger.Error("Failed to delete stale shard stats", tag.Error(err)) + } +} + // rebalanceShards is the core logic for distributing shards among active executors. func (p *namespaceProcessor) rebalanceShards(ctx context.Context) (err error) { metricsLoopScope := p.metricsClient.Scope(metrics.ShardDistributorAssignLoopScope) diff --git a/service/sharddistributor/leader/process/processor_test.go b/service/sharddistributor/leader/process/processor_test.go index 2b78ab83b8a..1c6e0422caf 100644 --- a/service/sharddistributor/leader/process/processor_test.go +++ b/service/sharddistributor/leader/process/processor_test.go @@ -178,11 +178,80 @@ func TestCleanupStaleExecutors(t *testing.T) { "exec-stale": {LastHeartbeat: now.Add(-2 * time.Second).Unix()}, } - mocks.store.EXPECT().GetState(gomock.Any(), mocks.cfg.Name).Return(&store.NamespaceState{Executors: heartbeats}, nil) + namespaceState := &store.NamespaceState{Executors: heartbeats} mocks.election.EXPECT().Guard().Return(store.NopGuard()) mocks.store.EXPECT().DeleteExecutors(gomock.Any(), mocks.cfg.Name, []string{"exec-stale"}, gomock.Any()).Return(nil) - processor.cleanupStaleExecutors(context.Background()) + processor.cleanupStaleExecutors(context.Background(), namespaceState) +} + +func TestCleanupStaleShardStats(t *testing.T) { + t.Run("stale shard stats are deleted", func(t *testing.T) { + mocks := setupProcessorTest(t, config.NamespaceTypeFixed) + defer mocks.ctrl.Finish() + processor := mocks.factory.CreateProcessor(mocks.cfg, mocks.store, mocks.election).(*namespaceProcessor) + + now := mocks.timeSource.Now() + + heartbeats := map[string]store.HeartbeatState{ + "exec-active": {LastHeartbeat: now.Unix(), Status: types.ExecutorStatusACTIVE}, + "exec-stale": {LastHeartbeat: now.Add(-2 * time.Second).Unix()}, + } + + assignments := map[string]store.AssignedState{ + "exec-active": { + AssignedShards: map[string]*types.ShardAssignment{ + "shard-1": {Status: types.AssignmentStatusREADY}, + "shard-2": {Status: types.AssignmentStatusREADY}, + }, + }, + "exec-stale": { + AssignedShards: map[string]*types.ShardAssignment{ + "shard-3": {Status: types.AssignmentStatusREADY}, + }, + }, + } + + shardStats := map[string]store.ShardStatistics{ + "shard-1": {SmoothedLoad: 1.0, LastUpdateTime: now.Unix(), LastMoveTime: now.Unix()}, + "shard-2": {SmoothedLoad: 2.0, LastUpdateTime: now.Unix(), LastMoveTime: now.Unix()}, + "shard-3": {SmoothedLoad: 3.0, LastUpdateTime: now.Add(-2 * time.Second).Unix(), LastMoveTime: now.Add(-2 * time.Second).Unix()}, + } + + namespaceState := &store.NamespaceState{ + Executors: heartbeats, + ShardAssignments: assignments, + ShardStats: shardStats, + } + + mocks.election.EXPECT().Guard().Return(store.NopGuard()) + mocks.store.EXPECT().DeleteShardStats(gomock.Any(), mocks.cfg.Name, []string{"shard-3"}, gomock.Any()).Return(nil) + processor.cleanupStaleShardStats(context.Background(), namespaceState) + }) + + t.Run("recent shard stats are preserved", func(t *testing.T) { + mocks := setupProcessorTest(t, config.NamespaceTypeFixed) + defer mocks.ctrl.Finish() + processor := mocks.factory.CreateProcessor(mocks.cfg, mocks.store, mocks.election).(*namespaceProcessor) + + now := mocks.timeSource.Now() + + expiredExecutor := now.Add(-2 * time.Second).Unix() + namespaceState := &store.NamespaceState{ + Executors: map[string]store.HeartbeatState{ + "exec-stale": {LastHeartbeat: expiredExecutor}, + }, + ShardAssignments: map[string]store.AssignedState{}, + ShardStats: map[string]store.ShardStatistics{ + "shard-1": {SmoothedLoad: 5.0, LastUpdateTime: now.Unix(), LastMoveTime: now.Unix()}, + }, + } + + processor.cleanupStaleShardStats(context.Background(), namespaceState) + + // No delete expected since stats are recent. + }) + } func TestRebalance_StoreErrors(t *testing.T) { @@ -213,16 +282,13 @@ func TestCleanup_StoreErrors(t *testing.T) { processor := mocks.factory.CreateProcessor(mocks.cfg, mocks.store, mocks.election).(*namespaceProcessor) expectedErr := errors.New("store is down") - mocks.store.EXPECT().GetState(gomock.Any(), mocks.cfg.Name).Return(nil, expectedErr) - processor.cleanupStaleExecutors(context.Background()) - - mocks.store.EXPECT().GetState(gomock.Any(), mocks.cfg.Name).Return(&store.NamespaceState{ + namespaceState := &store.NamespaceState{ Executors: map[string]store.HeartbeatState{"stale": {LastHeartbeat: 0}}, GlobalRevision: 1, - }, nil) + } mocks.election.EXPECT().Guard().Return(store.NopGuard()) mocks.store.EXPECT().DeleteExecutors(gomock.Any(), mocks.cfg.Name, gomock.Any(), gomock.Any()).Return(expectedErr) - processor.cleanupStaleExecutors(context.Background()) + processor.cleanupStaleExecutors(context.Background(), namespaceState) } func TestRunLoop_SubscriptionError(t *testing.T) { diff --git a/service/sharddistributor/store/etcd/etcdkeys/etcdkeys.go b/service/sharddistributor/store/etcd/etcdkeys/etcdkeys.go index c79c832eca5..39e3f917de8 100644 --- a/service/sharddistributor/store/etcd/etcdkeys/etcdkeys.go +++ b/service/sharddistributor/store/etcd/etcdkeys/etcdkeys.go @@ -11,6 +11,7 @@ const ( ExecutorReportedShardsKey = "reported_shards" ExecutorAssignedStateKey = "assigned_state" ShardAssignedKey = "assigned" + ShardStatisticsKey = "statistics" ExecutorMetadataKey = "metadata" ) @@ -70,6 +71,30 @@ func ParseExecutorKey(prefix string, namespace, key string) (executorID, keyType return parts[0], parts[1], nil } +func BuildShardPrefix(prefix string, namespace string) string { + return fmt.Sprintf("%s/shards/", BuildNamespacePrefix(prefix, namespace)) +} + +func BuildShardKey(prefix string, namespace, shardID, keyType string) (string, error) { + if keyType != ShardStatisticsKey { + return "", fmt.Errorf("invalid shard key type: %s", keyType) + } + return fmt.Sprintf("%s%s/%s", BuildShardPrefix(prefix, namespace), shardID, keyType), nil +} + +func ParseShardKey(prefix string, namespace, key string) (shardID, keyType string, err error) { + prefix = BuildShardPrefix(prefix, namespace) + if !strings.HasPrefix(key, prefix) { + return "", "", fmt.Errorf("key '%s' does not have expected prefix '%s'", key, prefix) + } + remainder := strings.TrimPrefix(key, prefix) + parts := strings.Split(remainder, "/") + if len(parts) != 2 { + return "", "", fmt.Errorf("unexpected shard key format: %s", key) + } + return parts[0], parts[1], nil +} + func BuildMetadataKey(prefix string, namespace, executorID, metadataKey string) string { metadataKeyPrefix, err := BuildExecutorKey(prefix, namespace, executorID, ExecutorMetadataKey) if err != nil { diff --git a/service/sharddistributor/store/etcd/etcdkeys/etcdkeys_test.go b/service/sharddistributor/store/etcd/etcdkeys/etcdkeys_test.go index fd4d50fe20d..de041e0ce72 100644 --- a/service/sharddistributor/store/etcd/etcdkeys/etcdkeys_test.go +++ b/service/sharddistributor/store/etcd/etcdkeys/etcdkeys_test.go @@ -16,6 +16,11 @@ func TestBuildExecutorPrefix(t *testing.T) { assert.Equal(t, "/cadence/test-ns/executors/", got) } +func TestBuildShardPrefix(t *testing.T) { + got := BuildShardPrefix("/cadence", "test-ns") + assert.Equal(t, "/cadence/test-ns/shards/", got) +} + func TestBuildExecutorKey(t *testing.T) { got, err := BuildExecutorKey("/cadence", "test-ns", "exec-1", "heartbeat") assert.NoError(t, err) @@ -27,6 +32,17 @@ func TestBuildExecutorKeyFail(t *testing.T) { assert.ErrorContains(t, err, "invalid key type: invalid") } +func TestBuildShardKey(t *testing.T) { + got, err := BuildShardKey("/cadence", "test-ns", "shard-1", "statistics") + assert.NoError(t, err) + assert.Equal(t, "/cadence/test-ns/shards/shard-1/statistics", got) +} + +func TestBuildShardKeyFail(t *testing.T) { + _, err := BuildShardKey("/cadence", "test-ns", "shard-1", "invalid") + assert.ErrorContains(t, err, "invalid shard key type: invalid") +} + func TestParseExecutorKey(t *testing.T) { // Valid key executorID, keyType, err := ParseExecutorKey("/cadence", "test-ns", "/cadence/test-ns/executors/exec-1/heartbeat") @@ -43,6 +59,22 @@ func TestParseExecutorKey(t *testing.T) { assert.ErrorContains(t, err, "unexpected key format: /cadence/test-ns/executors/exec-1/heartbeat/extra") } +func TestParseShardKey(t *testing.T) { + // Valid key + shardID, keyType, err := ParseShardKey("/cadence", "test-ns", "/cadence/test-ns/shards/shard-1/statistics") + assert.NoError(t, err) + assert.Equal(t, "shard-1", shardID) + assert.Equal(t, "statistics", keyType) + + // Prefix missing + _, _, err = ParseShardKey("/cadence", "test-ns", "/cadence/other/shards/shard-1/statistics") + assert.ErrorContains(t, err, "key '/cadence/other/shards/shard-1/statistics' does not have expected prefix '/cadence/test-ns/shards/'") + + // Unexpected format + _, _, err = ParseShardKey("/cadence", "test-ns", "/cadence/test-ns/shards/shard-1/statistics/extra") + assert.ErrorContains(t, err, "unexpected shard key format: /cadence/test-ns/shards/shard-1/statistics/extra") +} + func TestBuildMetadataKey(t *testing.T) { got := BuildMetadataKey("/cadence", "test-ns", "exec-1", "my-metadata-key") assert.Equal(t, "/cadence/test-ns/executors/exec-1/metadata/my-metadata-key", got) diff --git a/service/sharddistributor/store/etcd/executorstore/etcdstore.go b/service/sharddistributor/store/etcd/executorstore/etcdstore.go index bf32a7d54a6..8dd5d641a76 100644 --- a/service/sharddistributor/store/etcd/executorstore/etcdstore.go +++ b/service/sharddistributor/store/etcd/executorstore/etcdstore.go @@ -14,6 +14,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/fx" + "github.com/uber/cadence/common/clock" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" "github.com/uber/cadence/common/types" @@ -32,16 +33,27 @@ type executorStoreImpl struct { prefix string logger log.Logger shardCache *shardcache.ShardToExecutorCache + timeSource clock.TimeSource +} + +// shardStatisticsUpdate holds the staged statistics for a shard so we can write them +// to etcd after the main AssignShards transaction commits. +type shardStatisticsUpdate struct { + key string + shardID string + stats store.ShardStatistics + desiredLastMove int64 // intended LastMoveTime for this update } // ExecutorStoreParams defines the dependencies for the etcd store, for use with fx. type ExecutorStoreParams struct { fx.In - Client *clientv3.Client `optional:"true"` - Cfg config.ShardDistribution - Lifecycle fx.Lifecycle - Logger log.Logger + Client *clientv3.Client `optional:"true"` + Cfg config.ShardDistribution + Lifecycle fx.Lifecycle + Logger log.Logger + TimeSource clock.TimeSource } // NewStore creates a new etcd-backed store and provides it to the fx application. @@ -70,11 +82,17 @@ func NewStore(p ExecutorStoreParams) (store.Store, error) { shardCache := shardcache.NewShardToExecutorCache(etcdCfg.Prefix, etcdClient, p.Logger) + timeSource := p.TimeSource + if timeSource == nil { + timeSource = clock.NewRealTimeSource() + } + store := &executorStoreImpl{ client: etcdClient, prefix: etcdCfg.Prefix, logger: p.Logger, shardCache: shardCache, + timeSource: timeSource, } p.Lifecycle.Append(fx.StartStopHook(store.Start, store.Stop)) @@ -204,6 +222,7 @@ func (s *executorStoreImpl) GetHeartbeat(ctx context.Context, namespace string, func (s *executorStoreImpl) GetState(ctx context.Context, namespace string) (*store.NamespaceState, error) { heartbeatStates := make(map[string]store.HeartbeatState) assignedStates := make(map[string]store.AssignedState) + shardStats := make(map[string]store.ShardStatistics) executorPrefix := etcdkeys.BuildExecutorPrefix(s.prefix, namespace) resp, err := s.client.Get(ctx, executorPrefix, clientv3.WithPrefix()) @@ -245,8 +264,30 @@ func (s *executorStoreImpl) GetState(ctx context.Context, namespace string) (*st assignedStates[executorID] = assigned } + // Fetch shard-level statistics stored under shard namespace keys. + shardPrefix := etcdkeys.BuildShardPrefix(s.prefix, namespace) + shardResp, err := s.client.Get(ctx, shardPrefix, clientv3.WithPrefix()) + if err != nil { + return nil, fmt.Errorf("get shard data: %w", err) + } + for _, kv := range shardResp.Kvs { + shardID, shardKeyType, err := etcdkeys.ParseShardKey(s.prefix, namespace, string(kv.Key)) + if err != nil { + continue + } + if shardKeyType != etcdkeys.ShardStatisticsKey { + continue + } + var shardStatistic store.ShardStatistics + if err := json.Unmarshal(kv.Value, &shardStatistic); err != nil { + continue + } + shardStats[shardID] = shardStatistic + } + return &store.NamespaceState{ Executors: heartbeatStates, + ShardStats: shardStats, ShardAssignments: assignedStates, GlobalRevision: resp.Header.Revision, }, nil @@ -297,6 +338,11 @@ func (s *executorStoreImpl) AssignShards(ctx context.Context, namespace string, var ops []clientv3.Op var comparisons []clientv3.Cmp + statsUpdates, err := s.prepareShardStatisticsUpdates(ctx, namespace, request.NewState.ShardAssignments) + if err != nil { + return fmt.Errorf("prepare shard statistics: %w", err) + } + // 1. Prepare operations to update executor states and shard ownership, // and comparisons to check for concurrent modifications. for executorID, state := range request.NewState.ShardAssignments { @@ -360,6 +406,9 @@ func (s *executorStoreImpl) AssignShards(ctx context.Context, namespace string, return fmt.Errorf("%w: transaction failed, a shard may have been concurrently assigned", store.ErrVersionConflict) } + // Apply shard statistics updates outside the main transaction to stay within etcd's max operations per txn. + s.applyShardStatisticsUpdates(ctx, namespace, statsUpdates) + return nil } @@ -372,16 +421,21 @@ func (s *executorStoreImpl) AssignShard(ctx context.Context, namespace, shardID, if err != nil { return fmt.Errorf("build executor status key: %w", err) } + shardStatsKey, err := etcdkeys.BuildShardKey(s.prefix, namespace, shardID, etcdkeys.ShardStatisticsKey) + if err != nil { + return fmt.Errorf("build shard statistics key: %w", err) + } // Use a read-modify-write loop to handle concurrent updates safely. for { - // 1. Get the current assigned state of the executor. + // 1. Get the current assigned state of the executor and prepare the shard statistics. resp, err := s.client.Get(ctx, assignedState) if err != nil { return fmt.Errorf("get executor state: %w", err) } var state store.AssignedState + var shardStats store.ShardStatistics modRevision := int64(0) // A revision of 0 means the key doesn't exist yet. if len(resp.Kvs) > 0 { @@ -396,6 +450,28 @@ func (s *executorStoreImpl) AssignShard(ctx context.Context, namespace, shardID, state.AssignedShards = make(map[string]*types.ShardAssignment) } + statsResp, err := s.client.Get(ctx, shardStatsKey) + if err != nil { + return fmt.Errorf("get shard statistics: %w", err) + } + now := s.timeSource.Now().Unix() + statsModRevision := int64(0) + if len(statsResp.Kvs) > 0 { + statsModRevision = statsResp.Kvs[0].ModRevision + if err := json.Unmarshal(statsResp.Kvs[0].Value, &shardStats); err != nil { + return fmt.Errorf("unmarshal shard statistics: %w", err) + } + // Statistics already exist, update the last move time. + // This can happen if the shard was previously assigned to an executor, and a lookup happens after the executor is deleted, + // AssignShard is then called to assign the shard to a new executor. + shardStats.LastMoveTime = now + } else { + // Statistics don't exist, initialize them. + shardStats.SmoothedLoad = 0 + shardStats.LastUpdateTime = now + shardStats.LastMoveTime = now + } + // 2. Modify the state in memory, adding the new shard if it's not already there. if _, alreadyAssigned := state.AssignedShards[shardID]; !alreadyAssigned { state.AssignedShards[shardID] = &types.ShardAssignment{Status: types.AssignmentStatusREADY} @@ -406,13 +482,19 @@ func (s *executorStoreImpl) AssignShard(ctx context.Context, namespace, shardID, return fmt.Errorf("marshal new assigned state: %w", err) } + newStatsValue, err := json.Marshal(shardStats) + if err != nil { + return fmt.Errorf("marshal new shard statistics: %w", err) + } + var comparisons []clientv3.Cmp - // 3. Prepare and commit the transaction with three atomic checks. + // 3. Prepare and commit the transaction with four atomic checks. // a) Check that the executor's status is ACTIVE. comparisons = append(comparisons, clientv3.Compare(clientv3.Value(statusKey), "=", _executorStatusRunningJSON)) - // b) Check that the assigned_state key hasn't been changed by another process. + // b) Check that neither the assigned_state nor shard statistics were modified concurrently. comparisons = append(comparisons, clientv3.Compare(clientv3.ModRevision(assignedState), "=", modRevision)) + comparisons = append(comparisons, clientv3.Compare(clientv3.ModRevision(shardStatsKey), "=", statsModRevision)) // c) Check that the cache is up to date. cmp, err := s.shardCache.GetExecutorModRevisionCmp(namespace) if err != nil { @@ -431,7 +513,10 @@ func (s *executorStoreImpl) AssignShard(ctx context.Context, namespace, shardID, txnResp, err := s.client.Txn(ctx). If(comparisons...). - Then(clientv3.OpPut(assignedState, string(newStateValue))). + Then( + clientv3.OpPut(assignedState, string(newStateValue)), + clientv3.OpPut(shardStatsKey, string(newStatsValue)), + ). Commit() if err != nil { @@ -498,6 +583,119 @@ func (s *executorStoreImpl) DeleteExecutors(ctx context.Context, namespace strin return nil } +func (s *executorStoreImpl) DeleteShardStats(ctx context.Context, namespace string, shardIDs []string, guard store.GuardFunc) error { + if len(shardIDs) == 0 { + return nil + } + var ops []clientv3.Op + for _, shardID := range shardIDs { + shardStatsKey, err := etcdkeys.BuildShardKey(s.prefix, namespace, shardID, etcdkeys.ShardStatisticsKey) + if err != nil { + return fmt.Errorf("build shard statistics key: %w", err) + } + ops = append(ops, clientv3.OpDelete(shardStatsKey)) + } + + nativeTxn := s.client.Txn(ctx) + guardedTxn, err := guard(nativeTxn) + + if err != nil { + return fmt.Errorf("apply transaction guard: %w", err) + } + etcdGuardedTxn, ok := guardedTxn.(clientv3.Txn) + if !ok { + return fmt.Errorf("guard function returned invalid transaction type") + } + + etcdGuardedTxn = etcdGuardedTxn.Then(ops...) + resp, err := etcdGuardedTxn.Commit() + if err != nil { + return fmt.Errorf("commit shard statistics deletion: %w", err) + } + if !resp.Succeeded { + return fmt.Errorf("transaction failed, leadership may have changed") + } + return nil +} + func (s *executorStoreImpl) GetShardOwner(ctx context.Context, namespace, shardID string) (*store.ShardOwner, error) { return s.shardCache.GetShardOwner(ctx, namespace, shardID) } + +func (s *executorStoreImpl) prepareShardStatisticsUpdates(ctx context.Context, namespace string, newAssignments map[string]store.AssignedState) ([]shardStatisticsUpdate, error) { + var updates []shardStatisticsUpdate + + for executorID, state := range newAssignments { + for shardID := range state.AssignedShards { + now := s.timeSource.Now().Unix() + + oldOwner, err := s.shardCache.GetShardOwner(ctx, namespace, shardID) + if err != nil && !errors.Is(err, store.ErrShardNotFound) { + return nil, fmt.Errorf("lookup cached shard owner: %w", err) + } + + // we should just skip if the owner hasn't changed + if err == nil && oldOwner.ExecutorID == executorID { + continue + } + + shardStatisticsKey, err := etcdkeys.BuildShardKey(s.prefix, namespace, shardID, etcdkeys.ShardStatisticsKey) + if err != nil { + return nil, fmt.Errorf("build shard statistics key: %w", err) + } + + statsResp, err := s.client.Get(ctx, shardStatisticsKey) + if err != nil { + return nil, fmt.Errorf("get shard statistics: %w", err) + } + + stats := store.ShardStatistics{} + + if len(statsResp.Kvs) > 0 { + if err := json.Unmarshal(statsResp.Kvs[0].Value, &stats); err != nil { + return nil, fmt.Errorf("unmarshal shard statistics: %w", err) + } + } else { + stats.SmoothedLoad = 0 + stats.LastUpdateTime = now + } + + updates = append(updates, shardStatisticsUpdate{ + key: shardStatisticsKey, + shardID: shardID, + stats: stats, + desiredLastMove: now, + }) + } + } + + return updates, nil +} + +// applyShardStatisticsUpdates updates shard statistics. +// Is intentionally made tolerant of failures since the data is telemetry only. +func (s *executorStoreImpl) applyShardStatisticsUpdates(ctx context.Context, namespace string, updates []shardStatisticsUpdate) { + for _, update := range updates { + update.stats.LastMoveTime = update.desiredLastMove + + payload, err := json.Marshal(update.stats) + if err != nil { + s.logger.Warn( + "failed to marshal shard statistics after assignment", + tag.ShardNamespace(namespace), + tag.ShardKey(update.shardID), + tag.Error(err), + ) + continue + } + + if _, err := s.client.Put(ctx, update.key, string(payload)); err != nil { + s.logger.Warn( + "failed to update shard statistics", + tag.ShardNamespace(namespace), + tag.ShardKey(update.shardID), + tag.Error(err), + ) + } + } +} diff --git a/service/sharddistributor/store/etcd/executorstore/etcdstore_test.go b/service/sharddistributor/store/etcd/executorstore/etcdstore_test.go index 903693c00d5..095a058728b 100644 --- a/service/sharddistributor/store/etcd/executorstore/etcdstore_test.go +++ b/service/sharddistributor/store/etcd/executorstore/etcdstore_test.go @@ -521,6 +521,59 @@ func TestAssignShardErrors(t *testing.T) { assert.ErrorIs(t, err, store.ErrVersionConflict, "Error should be ErrVersionConflict for non-active executor") } +// TestShardStatisticsPersistence verifies that shard statistics are preserved on assignment +// when they already exist, and that GetState exposes them. +func TestShardStatisticsPersistence(t *testing.T) { + tc := testhelper.SetupStoreTestCluster(t) + executorStore := createStore(t, tc) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + executorID := "exec-stats" + shardID := "shard-stats" + + // 1. Setup: ensure executor is ACTIVE + require.NoError(t, executorStore.RecordHeartbeat(ctx, tc.Namespace, executorID, store.HeartbeatState{Status: types.ExecutorStatusACTIVE})) + + // 2. Pre-create shard statistics as if coming from prior history + stats := store.ShardStatistics{SmoothedLoad: 12.5, LastUpdateTime: 1234, LastMoveTime: 5678} + shardStatsKey, err := etcdkeys.BuildShardKey(tc.EtcdPrefix, tc.Namespace, shardID, etcdkeys.ShardStatisticsKey) + require.NoError(t, err) + payload, err := json.Marshal(stats) + require.NoError(t, err) + _, err = tc.Client.Put(ctx, shardStatsKey, string(payload)) + require.NoError(t, err) + + // 3. Assign the shard via AssignShard (should not clobber existing metrics) + require.NoError(t, executorStore.AssignShard(ctx, tc.Namespace, shardID, executorID)) + + // 4. Verify via GetState that metrics are preserved and exposed + nsState, err := executorStore.GetState(ctx, tc.Namespace) + require.NoError(t, err) + require.Contains(t, nsState.ShardStats, shardID) + updatedStats := nsState.ShardStats[shardID] + assert.Equal(t, stats.SmoothedLoad, updatedStats.SmoothedLoad) + assert.Equal(t, stats.LastUpdateTime, updatedStats.LastUpdateTime) + // This should be greater than the last move time + assert.Greater(t, updatedStats.LastMoveTime, stats.LastMoveTime) + + // 5. Also ensure assignment recorded correctly + require.Contains(t, nsState.ShardAssignments[executorID].AssignedShards, shardID) +} + +// TestGetShardStatisticsForMissingShard verifies GetState does not report statistics for unknown shards. +func TestGetShardStatisticsForMissingShard(t *testing.T) { + tc := testhelper.SetupStoreTestCluster(t) + executorStore := createStore(t, tc) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // No metrics are written; GetState should not contain unknown shard + st, err := executorStore.GetState(ctx, tc.Namespace) + require.NoError(t, err) + assert.NotContains(t, st.ShardStats, "unknown") +} + // --- Test Setup --- func stringStatus(s types.ExecutorStatus) string { diff --git a/service/sharddistributor/store/state.go b/service/sharddistributor/store/state.go index 4b1f58f9e52..70c75ea0df6 100644 --- a/service/sharddistributor/store/state.go +++ b/service/sharddistributor/store/state.go @@ -19,6 +19,7 @@ type AssignedState struct { type NamespaceState struct { Executors map[string]HeartbeatState + ShardStats map[string]ShardStatistics ShardAssignments map[string]AssignedState GlobalRevision int64 } @@ -27,6 +28,12 @@ type ShardState struct { ExecutorID string } +type ShardStatistics struct { + SmoothedLoad float64 `json:"smoothed_load"` // EWMA of shard load that persists across executor changes + LastUpdateTime int64 `json:"last_update_time"` // heartbeat timestamp that last updated the EWMA + LastMoveTime int64 `json:"last_move_time"` // timestamp for the latest reassignment, used for cooldowns +} + type ShardOwner struct { ExecutorID string Metadata map[string]string diff --git a/service/sharddistributor/store/store.go b/service/sharddistributor/store/store.go index 6b51d4b1fa3..a9500408933 100644 --- a/service/sharddistributor/store/store.go +++ b/service/sharddistributor/store/store.go @@ -60,6 +60,7 @@ type Store interface { AssignShards(ctx context.Context, namespace string, request AssignShardsRequest, guard GuardFunc) error Subscribe(ctx context.Context, namespace string) (<-chan int64, error) DeleteExecutors(ctx context.Context, namespace string, executorIDs []string, guard GuardFunc) error + DeleteShardStats(ctx context.Context, namespace string, shardIDs []string, guard GuardFunc) error GetShardOwner(ctx context.Context, namespace, shardID string) (*ShardOwner, error) AssignShard(ctx context.Context, namespace, shardID, executorID string) error diff --git a/service/sharddistributor/store/store_mock.go b/service/sharddistributor/store/store_mock.go index c686c65e487..a246f8e3f85 100644 --- a/service/sharddistributor/store/store_mock.go +++ b/service/sharddistributor/store/store_mock.go @@ -106,6 +106,20 @@ func (mr *MockStoreMockRecorder) DeleteExecutors(ctx, namespace, executorIDs, gu return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteExecutors", reflect.TypeOf((*MockStore)(nil).DeleteExecutors), ctx, namespace, executorIDs, guard) } +// DeleteShardStats mocks base method. +func (m *MockStore) DeleteShardStats(ctx context.Context, namespace string, shardIDs []string, guard GuardFunc) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteShardStats", ctx, namespace, shardIDs, guard) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteShardStats indicates an expected call of DeleteShardStats. +func (mr *MockStoreMockRecorder) DeleteShardStats(ctx, namespace, shardIDs, guard any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteShardStats", reflect.TypeOf((*MockStore)(nil).DeleteShardStats), ctx, namespace, shardIDs, guard) +} + // GetHeartbeat mocks base method. func (m *MockStore) GetHeartbeat(ctx context.Context, namespace, executorID string) (*HeartbeatState, *AssignedState, error) { m.ctrl.T.Helper() diff --git a/service/sharddistributor/store/wrappers/metered/store_generated.go b/service/sharddistributor/store/wrappers/metered/store_generated.go index db2f3bb7a84..91f1523ecfc 100644 --- a/service/sharddistributor/store/wrappers/metered/store_generated.go +++ b/service/sharddistributor/store/wrappers/metered/store_generated.go @@ -69,6 +69,16 @@ func (c *meteredStore) DeleteExecutors(ctx context.Context, namespace string, ex return } +func (c *meteredStore) DeleteShardStats(ctx context.Context, namespace string, shardIDs []string, guard store.GuardFunc) (err error) { + op := func() error { + err = c.wrapped.DeleteShardStats(ctx, namespace, shardIDs, guard) + return err + } + + err = c.call(metrics.ShardDistributorStoreDeleteShardStatsScope, op, metrics.NamespaceTag(namespace)) + return +} + func (c *meteredStore) GetHeartbeat(ctx context.Context, namespace string, executorID string) (hp1 *store.HeartbeatState, ap1 *store.AssignedState, err error) { op := func() error { hp1, ap1, err = c.wrapped.GetHeartbeat(ctx, namespace, executorID)