diff --git a/components/crm/internal/adapters/mongodb/encryption/index_tracker.go b/components/crm/internal/adapters/mongodb/encryption/index_tracker.go new file mode 100644 index 000000000..c414062de --- /dev/null +++ b/components/crm/internal/adapters/mongodb/encryption/index_tracker.go @@ -0,0 +1,54 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package encryption + +import "sync" + +// indexState tracks whether indexes have been successfully created for a specific database/collection pair. +type indexState struct { + mu sync.Mutex + done bool +} + +// indexTracker manages per-database index creation state. +// In multi-tenant mode, each tenant database needs its own indexes. +// This tracker ensures indexes are created exactly once per database, with retry on failure. +type indexTracker struct { + states sync.Map // key: "dbName:collection" -> *indexState +} + +// ensureOnce executes fn exactly once per key, but only marks as done on success. +// If fn returns an error, subsequent calls will retry. +func (t *indexTracker) ensureOnce(key string, fn func() error) error { + v, _ := t.states.LoadOrStore(key, &indexState{}) + state := v.(*indexState) + + state.mu.Lock() + defer state.mu.Unlock() + + if state.done { + return nil + } + + if err := fn(); err != nil { + return err + } + + state.done = true + + return nil +} + +// reset clears the state for a specific key. Used in integration tests to ensure fresh state +// when each test runs with a new MongoDB container. +// +//nolint:unused // used in *_integration_test.go files (build tag: integration) +func (t *indexTracker) reset(key string) { + t.states.Delete(key) +} + +// globalIndexTracker is shared across all encryption repository instances. +// This ensures indexes are created once per database even if multiple repository instances exist. +var globalIndexTracker = &indexTracker{} diff --git a/components/crm/internal/adapters/mongodb/encryption/index_tracker_test.go b/components/crm/internal/adapters/mongodb/encryption/index_tracker_test.go new file mode 100644 index 000000000..eed0ff195 --- /dev/null +++ b/components/crm/internal/adapters/mongodb/encryption/index_tracker_test.go @@ -0,0 +1,221 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package encryption + +import ( + "errors" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIndexTracker_EnsureOnce_ExecutesOnceOnSuccess(t *testing.T) { + t.Parallel() + + tracker := &indexTracker{} + var callCount int32 + + for i := 0; i < 3; i++ { + err := tracker.ensureOnce("test-db:test-collection", func() error { + atomic.AddInt32(&callCount, 1) + return nil + }) + require.NoError(t, err) + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount), "function should be called exactly once on success") +} + +func TestIndexTracker_EnsureOnce_RetriesOnFailure(t *testing.T) { + t.Parallel() + + tracker := &indexTracker{} + var callCount int32 + expectedErr := errors.New("index creation failed") + + // First call fails + err := tracker.ensureOnce("test-db:test-collection", func() error { + atomic.AddInt32(&callCount, 1) + return expectedErr + }) + require.Error(t, err) + assert.Equal(t, expectedErr, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount)) + + // Second call should retry (not skipped) + err = tracker.ensureOnce("test-db:test-collection", func() error { + atomic.AddInt32(&callCount, 1) + return nil // Now succeeds + }) + require.NoError(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount), "function should be retried after failure") + + // Third call should be skipped (already succeeded) + err = tracker.ensureOnce("test-db:test-collection", func() error { + atomic.AddInt32(&callCount, 1) + return nil + }) + require.NoError(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount), "function should not be called again after success") +} + +func TestIndexTracker_EnsureOnce_DifferentKeys(t *testing.T) { + t.Parallel() + + tracker := &indexTracker{} + var callCountA, callCountB int32 + + // Key A + err := tracker.ensureOnce("tenant-a:collection", func() error { + atomic.AddInt32(&callCountA, 1) + return nil + }) + require.NoError(t, err) + + // Key B (different tenant) + err = tracker.ensureOnce("tenant-b:collection", func() error { + atomic.AddInt32(&callCountB, 1) + return nil + }) + require.NoError(t, err) + + // Key A again (should be skipped) + err = tracker.ensureOnce("tenant-a:collection", func() error { + atomic.AddInt32(&callCountA, 1) + return nil + }) + require.NoError(t, err) + + assert.Equal(t, int32(1), atomic.LoadInt32(&callCountA), "tenant A should be called once") + assert.Equal(t, int32(1), atomic.LoadInt32(&callCountB), "tenant B should be called once") +} + +func TestIndexTracker_EnsureOnce_ConcurrentAccess(t *testing.T) { + t.Parallel() + + tracker := &indexTracker{} + var callCount int32 + const goroutines = 100 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + + _ = tracker.ensureOnce("concurrent-db:collection", func() error { + atomic.AddInt32(&callCount, 1) + return nil + }) + }() + } + + wg.Wait() + + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount), "function should be called exactly once even with concurrent access") +} + +func TestIndexTracker_EnsureOnce_ConcurrentAccessWithFailure(t *testing.T) { + t.Parallel() + + tracker := &indexTracker{} + var callCount int32 + var failureCount int32 + const goroutines = 50 + + // First, make the first call fail + err := tracker.ensureOnce("concurrent-fail-db:collection", func() error { + atomic.AddInt32(&callCount, 1) + return errors.New("initial failure") + }) + require.Error(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount)) + + // Now run concurrent retries - only one should succeed in executing + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + + err := tracker.ensureOnce("concurrent-fail-db:collection", func() error { + atomic.AddInt32(&callCount, 1) + return nil + }) + if err != nil { + atomic.AddInt32(&failureCount, 1) + } + }() + } + + wg.Wait() + + // Due to mutex, only one goroutine should have actually executed + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount), "function should be retried exactly once after initial failure") + assert.Equal(t, int32(0), atomic.LoadInt32(&failureCount), "all concurrent calls should succeed once retry succeeds") +} + +func TestIndexTracker_Reset(t *testing.T) { + t.Parallel() + + tracker := &indexTracker{} + var callCount int32 + + // First call + err := tracker.ensureOnce("reset-test:collection", func() error { + atomic.AddInt32(&callCount, 1) + return nil + }) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&callCount)) + + // Reset the key + tracker.reset("reset-test:collection") + + // Call again - should execute + err = tracker.ensureOnce("reset-test:collection", func() error { + atomic.AddInt32(&callCount, 1) + return nil + }) + require.NoError(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&callCount), "function should be called again after reset") +} + +func TestIndexTracker_MultipleFailuresBeforeSuccess(t *testing.T) { + t.Parallel() + + tracker := &indexTracker{} + var callCount int32 + + // Simulate multiple failures before success + for i := 0; i < 5; i++ { + _ = tracker.ensureOnce("multi-fail:collection", func() error { + atomic.AddInt32(&callCount, 1) + return errors.New("temporary failure") + }) + } + assert.Equal(t, int32(5), atomic.LoadInt32(&callCount), "function should be called on each retry") + + // Now succeed + err := tracker.ensureOnce("multi-fail:collection", func() error { + atomic.AddInt32(&callCount, 1) + return nil + }) + require.NoError(t, err) + assert.Equal(t, int32(6), atomic.LoadInt32(&callCount)) + + // Subsequent calls should be skipped + err = tracker.ensureOnce("multi-fail:collection", func() error { + atomic.AddInt32(&callCount, 1) + return nil + }) + require.NoError(t, err) + assert.Equal(t, int32(6), atomic.LoadInt32(&callCount), "function should not be called after success") +} diff --git a/components/crm/internal/adapters/mongodb/encryption/keyset.go b/components/crm/internal/adapters/mongodb/encryption/keyset.go new file mode 100644 index 000000000..8aa2d2a76 --- /dev/null +++ b/components/crm/internal/adapters/mongodb/encryption/keyset.go @@ -0,0 +1,119 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package encryption + +import ( + "time" + + "github.com/LerianStudio/midaz/v3/pkg/mmodel" +) + +// KeysetMongoDBModel is the MongoDB representation of OrganizationKeyset. +type KeysetMongoDBModel struct { + TenantID string `bson:"tenant_id,omitempty"` + OrganizationID string `bson:"organization_id"` + KEKPath string `bson:"kek_path"` + WrappedKeyset string `bson:"wrapped_keyset"` + KeysetInfo KeysetInfoModel `bson:"keyset_info"` + LegacyKeyImported bool `bson:"legacy_key_imported"` + WrappedHMACKeyset string `bson:"wrapped_hmac_keyset,omitempty"` + HMACKeysetInfo KeysetInfoModel `bson:"hmac_keyset_info,omitempty"` + LegacyHMACKeyImported bool `bson:"legacy_hmac_key_imported"` + Revision int64 `bson:"revision"` + CreatedAt time.Time `bson:"created_at"` + RotatedAt *time.Time `bson:"rotated_at,omitempty"` +} + +// KeysetInfoModel is the MongoDB representation of KeysetInfo. +type KeysetInfoModel struct { + PrimaryKeyID uint32 `bson:"primary_key_id"` + Keys []KeyInfoModel `bson:"keys"` +} + +// KeyInfoModel is the MongoDB representation of KeyInfo. +type KeyInfoModel struct { + KeyID uint32 `bson:"key_id"` + Status string `bson:"status"` + Type string `bson:"type"` + IsPrimary bool `bson:"is_primary"` +} + +// KeysetFromEntity converts a domain OrganizationKeyset to MongoDB model. +func KeysetFromEntity(k *mmodel.OrganizationKeyset) *KeysetMongoDBModel { + if k == nil { + return nil + } + + return &KeysetMongoDBModel{ + TenantID: k.TenantID, + OrganizationID: k.OrganizationID, + KEKPath: k.KEKPath, + WrappedKeyset: k.WrappedKeyset, + KeysetInfo: keysetInfoFromEntity(k.KeysetInfo), + LegacyKeyImported: k.LegacyKeyImported, + WrappedHMACKeyset: k.WrappedHMACKeyset, + HMACKeysetInfo: keysetInfoFromEntity(k.HMACKeysetInfo), + LegacyHMACKeyImported: k.LegacyHMACKeyImported, + Revision: k.Revision, + CreatedAt: k.CreatedAt, + RotatedAt: k.RotatedAt, + } +} + +// ToEntity converts the MongoDB model to a domain OrganizationKeyset. +func (m *KeysetMongoDBModel) ToEntity() *mmodel.OrganizationKeyset { + if m == nil { + return nil + } + + return &mmodel.OrganizationKeyset{ + TenantID: m.TenantID, + OrganizationID: m.OrganizationID, + KEKPath: m.KEKPath, + WrappedKeyset: m.WrappedKeyset, + KeysetInfo: m.KeysetInfo.toEntity(), + LegacyKeyImported: m.LegacyKeyImported, + WrappedHMACKeyset: m.WrappedHMACKeyset, + HMACKeysetInfo: m.HMACKeysetInfo.toEntity(), + LegacyHMACKeyImported: m.LegacyHMACKeyImported, + Revision: m.Revision, + CreatedAt: m.CreatedAt, + RotatedAt: m.RotatedAt, + } +} + +func keysetInfoFromEntity(info mmodel.KeysetInfo) KeysetInfoModel { + keys := make([]KeyInfoModel, len(info.Keys)) + for i, k := range info.Keys { + keys[i] = KeyInfoModel{ + KeyID: k.KeyID, + Status: k.Status, + Type: k.Type, + IsPrimary: k.IsPrimary, + } + } + + return KeysetInfoModel{ + PrimaryKeyID: info.PrimaryKeyID, + Keys: keys, + } +} + +func (m *KeysetInfoModel) toEntity() mmodel.KeysetInfo { + keys := make([]mmodel.KeyInfo, len(m.Keys)) + for i, k := range m.Keys { + keys[i] = mmodel.KeyInfo{ + KeyID: k.KeyID, + Status: k.Status, + Type: k.Type, + IsPrimary: k.IsPrimary, + } + } + + return mmodel.KeysetInfo{ + PrimaryKeyID: m.PrimaryKeyID, + Keys: keys, + } +} diff --git a/components/crm/internal/adapters/mongodb/encryption/keyset.mongodb.go b/components/crm/internal/adapters/mongodb/encryption/keyset.mongodb.go new file mode 100644 index 000000000..68aee58e8 --- /dev/null +++ b/components/crm/internal/adapters/mongodb/encryption/keyset.mongodb.go @@ -0,0 +1,233 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package encryption + +import ( + "context" + "errors" + "fmt" + + libCommons "github.com/LerianStudio/lib-commons/v5/commons" + libMongo "github.com/LerianStudio/lib-commons/v5/commons/mongo" + libOpenTelemetry "github.com/LerianStudio/lib-commons/v5/commons/opentelemetry" + tmcore "github.com/LerianStudio/lib-commons/v5/commons/tenant-manager/core" + "github.com/LerianStudio/midaz/v3/pkg/mmodel" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.opentelemetry.io/otel/attribute" +) + +const keysetCollection = "organization_keyset" + +// KeysetRepository provides an interface for operations related to keyset entities. +// +//go:generate go run go.uber.org/mock/mockgen@v0.6.0 --destination=keyset.mongodb_mock.go --package=encryption . KeysetRepository +type KeysetRepository interface { + Save(ctx context.Context, keyset *mmodel.OrganizationKeyset) error + Get(ctx context.Context, organizationID string) (*mmodel.OrganizationKeyset, error) + Update(ctx context.Context, keyset *mmodel.OrganizationKeyset, expectedRevision int64) error +} + +// KeysetMongoDBRepository is a MongoDB-specific implementation of KeysetRepository. +type KeysetMongoDBRepository struct { + connection *libMongo.Client +} + +// NewKeysetMongoDBRepository returns a new instance of KeysetMongoDBRepository using the given MongoDB connection. +// In multi-tenant mode, connection may be nil — the per-request tenant context provides the database. +func NewKeysetMongoDBRepository(connection *libMongo.Client) (*KeysetMongoDBRepository, error) { + r := &KeysetMongoDBRepository{ + connection: connection, + } + + if connection != nil { + if _, err := r.connection.Database(context.Background()); err != nil { + return nil, fmt.Errorf("failed to connect to MongoDB for keyset repository: %w", err) + } + } + + return r, nil +} + +func (r *KeysetMongoDBRepository) Save(ctx context.Context, keyset *mmodel.OrganizationKeyset) error { + _, tracer, _, _ := libCommons.NewTrackingFromContext(ctx) //nolint:dogsled // consistent with codebase pattern + + ctx, span := tracer.Start(ctx, "mongodb.keyset.save") + defer span.End() + + if keyset == nil { + return fmt.Errorf("keyset is required") + } + + span.SetAttributes(attribute.String("app.request.organization_id", keyset.OrganizationID)) + + if err := keyset.Validate(); err != nil { + libOpenTelemetry.HandleSpanError(span, "Keyset validation failed", err) + return err + } + + if keyset.Revision == 0 { + keyset.Revision = 1 + } + + collection, err := r.collection(ctx) + if err != nil { + libOpenTelemetry.HandleSpanError(span, "Failed to get collection", err) + return err + } + + if err := r.ensureIndexes(ctx, collection); err != nil { + libOpenTelemetry.HandleSpanError(span, "Failed to create keyset indexes", err) + return fmt.Errorf("create keyset indexes: %w", err) + } + + model := KeysetFromEntity(keyset) + + filter := bson.M{"organization_id": keyset.OrganizationID} + update := bson.M{"$setOnInsert": model} + + result, err := collection.UpdateOne(ctx, filter, update, options.Update().SetUpsert(true)) + if err != nil { + libOpenTelemetry.HandleSpanError(span, "Failed to save organization keyset", err) + return fmt.Errorf("save organization keyset: %w", err) + } + + if result.MatchedCount > 0 { + return mmodel.ErrKeysetAlreadyExists + } + + return nil +} + +func (r *KeysetMongoDBRepository) Get(ctx context.Context, organizationID string) (*mmodel.OrganizationKeyset, error) { + _, tracer, _, _ := libCommons.NewTrackingFromContext(ctx) //nolint:dogsled // consistent with codebase pattern + + ctx, span := tracer.Start(ctx, "mongodb.keyset.get") + defer span.End() + + span.SetAttributes(attribute.String("app.request.organization_id", organizationID)) + + collection, err := r.collection(ctx) + if err != nil { + libOpenTelemetry.HandleSpanError(span, "Failed to get collection", err) + return nil, err + } + + var model KeysetMongoDBModel + + if err := collection.FindOne(ctx, bson.M{"organization_id": organizationID}).Decode(&model); err != nil { + if errors.Is(err, mongo.ErrNoDocuments) { + return nil, mmodel.ErrKeysetNotFound + } + + libOpenTelemetry.HandleSpanError(span, "Failed to get organization keyset", err) + + return nil, fmt.Errorf("get organization keyset: %w", err) + } + + return model.ToEntity(), nil +} + +func (r *KeysetMongoDBRepository) Update(ctx context.Context, keyset *mmodel.OrganizationKeyset, expectedRevision int64) error { + _, tracer, _, _ := libCommons.NewTrackingFromContext(ctx) //nolint:dogsled // consistent with codebase pattern + + ctx, span := tracer.Start(ctx, "mongodb.keyset.update") + defer span.End() + + if keyset == nil { + return fmt.Errorf("keyset is required") + } + + span.SetAttributes( + attribute.String("app.request.organization_id", keyset.OrganizationID), + attribute.Int64("app.request.expected_revision", expectedRevision), + ) + + if err := keyset.Validate(); err != nil { + libOpenTelemetry.HandleSpanError(span, "Keyset validation failed", err) + return err + } + + collection, err := r.collection(ctx) + if err != nil { + libOpenTelemetry.HandleSpanError(span, "Failed to get collection", err) + return err + } + + // Create model from entity and set the new revision on the model, not on the input entity. + // This prevents mutation of the caller's object if the database operation fails. + model := KeysetFromEntity(keyset) + model.Revision = expectedRevision + 1 + + result, err := collection.ReplaceOne(ctx, bson.M{"organization_id": keyset.OrganizationID, "revision": expectedRevision}, model) + if err != nil { + libOpenTelemetry.HandleSpanError(span, "Failed to update organization keyset", err) + return fmt.Errorf("update organization keyset: %w", err) + } + + if result.MatchedCount == 0 { + return mmodel.ErrKeysetRevisionConflict + } + + span.SetAttributes(attribute.Int64("db.rows_affected", result.ModifiedCount)) + + return nil +} + +// getDatabase resolves the MongoDB database for the current request. +// In multi-tenant mode, the middleware injects a tenant-specific *mongo.Database into context. +// In single-tenant mode (or when no tenant context exists), falls back to the static connection. +func (r *KeysetMongoDBRepository) getDatabase(ctx context.Context) (*mongo.Database, error) { + if r.connection == nil { + if db := tmcore.GetMBContext(ctx); db != nil { + return db, nil + } + + return nil, fmt.Errorf("no database connection available: multi-tenant context required but not present, and no static connection configured") + } + + if db := tmcore.GetMBContext(ctx); db != nil { + return db, nil + } + + return r.connection.Database(ctx) +} + +func (r *KeysetMongoDBRepository) collection(ctx context.Context) (*mongo.Collection, error) { + db, err := r.getDatabase(ctx) + if err != nil { + return nil, err + } + + return db.Collection(keysetCollection), nil +} + +// ensureIndexes ensures indexes exist for the keyset collection. +// Uses per-database tracking to handle multi-tenant mode correctly. +// Retries on failure — indexes are only marked as done after successful creation. +func (r *KeysetMongoDBRepository) ensureIndexes(ctx context.Context, collection *mongo.Collection) error { + key := collection.Database().Name() + ":" + keysetCollection + + return globalIndexTracker.ensureOnce(key, func() error { + return r.createIndexes(ctx, collection) + }) +} + +// createIndexes ensures indexes exist for the keyset collection. +func (r *KeysetMongoDBRepository) createIndexes(ctx context.Context, collection *mongo.Collection) error { + indexModels := []mongo.IndexModel{ + { + Keys: bson.D{{Key: "organization_id", Value: 1}}, + Options: options.Index().SetUnique(true), + }, + } + + _, err := collection.Indexes().CreateMany(ctx, indexModels) + + return err +} + +var _ KeysetRepository = (*KeysetMongoDBRepository)(nil) diff --git a/components/crm/internal/adapters/mongodb/encryption/keyset.mongodb_integration_test.go b/components/crm/internal/adapters/mongodb/encryption/keyset.mongodb_integration_test.go new file mode 100644 index 000000000..23c750881 --- /dev/null +++ b/components/crm/internal/adapters/mongodb/encryption/keyset.mongodb_integration_test.go @@ -0,0 +1,579 @@ +//go:build integration + +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package encryption + +import ( + "context" + "testing" + "time" + + "github.com/LerianStudio/midaz/v3/pkg/mmodel" + mongotestutil "github.com/LerianStudio/midaz/v3/tests/utils/mongodb" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/bson" +) + +// ============================================================================ +// Test Helpers +// ============================================================================ + +// createKeysetRepository creates a KeysetMongoDBRepository for integration testing. +// Resets the global index tracker for this database to ensure a fresh state, +// since each test runs with a new MongoDB container. +func createKeysetRepository(t *testing.T, container *mongotestutil.ContainerResult) *KeysetMongoDBRepository { + t.Helper() + + // Reset index tracker state for this database — each test has a fresh container + globalIndexTracker.reset(container.DBName + ":" + keysetCollection) + + conn := mongotestutil.CreateConnection(t, container.URI, container.DBName) + + repo, err := NewKeysetMongoDBRepository(conn) + require.NoError(t, err) + + return repo +} + +// createValidKeyset creates a valid OrganizationKeyset for testing. +func createValidKeyset(organizationID string) *mmodel.OrganizationKeyset { + now := time.Now().UTC().Truncate(time.Second) + + return &mmodel.OrganizationKeyset{ + OrganizationID: organizationID, + KEKPath: "transit/keys/crm-" + organizationID, + WrappedKeyset: "vault:v1:encrypted-dek-" + uuid.New().String()[:8], + KeysetInfo: mmodel.KeysetInfo{ + PrimaryKeyID: 1, + Keys: []mmodel.KeyInfo{ + {KeyID: 1, Status: "ENABLED", Type: "AES256_GCM", IsPrimary: true}, + }, + }, + WrappedHMACKeyset: "vault:v1:encrypted-hmac-" + uuid.New().String()[:8], + HMACKeysetInfo: mmodel.KeysetInfo{ + PrimaryKeyID: 1, + Keys: []mmodel.KeyInfo{ + {KeyID: 1, Status: "ENABLED", Type: "HMAC_SHA256", IsPrimary: true}, + }, + }, + LegacyKeyImported: false, + LegacyHMACKeyImported: false, + Revision: 0, + CreatedAt: now, + } +} + +// ============================================================================ +// Save Tests +// ============================================================================ + +func TestIntegration_KeysetRepo_Save(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + + // Act + err := repo.Save(ctx, keyset) + + // Assert + require.NoError(t, err, "Save should not return error") + + // Verify via direct query + count := mongotestutil.CountDocuments(t, container.Database, keysetCollection, bson.M{"organization_id": organizationID}) + assert.Equal(t, int64(1), count, "should have exactly 1 document") +} + +func TestIntegration_KeysetRepo_Save_SetsRevisionToOne(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-rev-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + keyset.Revision = 0 + + // Act + err := repo.Save(ctx, keyset) + require.NoError(t, err) + + // Assert - Get and verify revision was set to 1 + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, int64(1), result.Revision, "revision should be set to 1 on save") +} + +func TestIntegration_KeysetRepo_Save_AlreadyExists(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-dup-" + uuid.New().String()[:8] + keyset1 := createValidKeyset(organizationID) + + // First save should succeed + err := repo.Save(ctx, keyset1) + require.NoError(t, err, "first save should succeed") + + // Act - Try to save again with same organization_id + keyset2 := createValidKeyset(organizationID) + err = repo.Save(ctx, keyset2) + + // Assert + require.Error(t, err, "second save should fail") + assert.ErrorIs(t, err, mmodel.ErrKeysetAlreadyExists, "should return ErrKeysetAlreadyExists") +} + +func TestIntegration_KeysetRepo_Save_DifferentOrganizations(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + org1 := "org-1-" + uuid.New().String()[:8] + org2 := "org-2-" + uuid.New().String()[:8] + + keyset1 := createValidKeyset(org1) + keyset2 := createValidKeyset(org2) + + // Act + err1 := repo.Save(ctx, keyset1) + err2 := repo.Save(ctx, keyset2) + + // Assert - Both should succeed + require.NoError(t, err1, "first org save should succeed") + require.NoError(t, err2, "second org save should succeed") + + count := mongotestutil.CountDocuments(t, container.Database, keysetCollection, bson.M{}) + assert.Equal(t, int64(2), count, "should have 2 documents") +} + +func TestIntegration_KeysetRepo_Save_WithHMACKeyset(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-hmac-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + + // Act + err := repo.Save(ctx, keyset) + require.NoError(t, err) + + // Assert - Get and verify HMAC fields + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.NotEmpty(t, result.WrappedHMACKeyset, "HMAC keyset should be persisted") + assert.Equal(t, uint32(1), result.HMACKeysetInfo.PrimaryKeyID, "HMAC keyset info should be persisted") +} + +func TestIntegration_KeysetRepo_Save_WithoutHMACKeyset(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-no-hmac-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + keyset.WrappedHMACKeyset = "" + keyset.HMACKeysetInfo = mmodel.KeysetInfo{} + + // Act + err := repo.Save(ctx, keyset) + require.NoError(t, err) + + // Assert + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Empty(t, result.WrappedHMACKeyset, "HMAC keyset should be empty") +} + +// ============================================================================ +// Get Tests +// ============================================================================ + +func TestIntegration_KeysetRepo_Get(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-get-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + + err := repo.Save(ctx, keyset) + require.NoError(t, err) + + // Act + result, err := repo.Get(ctx, organizationID) + + // Assert + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, organizationID, result.OrganizationID) + assert.Equal(t, keyset.KEKPath, result.KEKPath) + assert.Equal(t, keyset.WrappedKeyset, result.WrappedKeyset) + assert.Equal(t, keyset.KeysetInfo.PrimaryKeyID, result.KeysetInfo.PrimaryKeyID) +} + +func TestIntegration_KeysetRepo_Get_NotFound(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + nonExistentOrg := "org-notfound-" + uuid.New().String()[:8] + + // Act + result, err := repo.Get(ctx, nonExistentOrg) + + // Assert + require.Error(t, err, "should return error for non-existent organization") + assert.Nil(t, result) + assert.ErrorIs(t, err, mmodel.ErrKeysetNotFound, "should return ErrKeysetNotFound") +} + +func TestIntegration_KeysetRepo_Get_ReturnsAllFields(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-fields-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + keyset.LegacyKeyImported = true + keyset.LegacyHMACKeyImported = true + + err := repo.Save(ctx, keyset) + require.NoError(t, err) + + // Act + result, err := repo.Get(ctx, organizationID) + + // Assert + require.NoError(t, err) + require.NotNil(t, result) + + assert.Equal(t, keyset.OrganizationID, result.OrganizationID) + assert.Equal(t, keyset.KEKPath, result.KEKPath) + assert.Equal(t, keyset.WrappedKeyset, result.WrappedKeyset) + assert.Equal(t, keyset.WrappedHMACKeyset, result.WrappedHMACKeyset) + assert.True(t, result.LegacyKeyImported, "legacy_key_imported should be true") + assert.True(t, result.LegacyHMACKeyImported, "legacy_hmac_key_imported should be true") + assert.Equal(t, int64(1), result.Revision, "revision should be 1") + assert.False(t, result.CreatedAt.IsZero(), "created_at should be set") +} + +// ============================================================================ +// Update Tests +// ============================================================================ + +func TestIntegration_KeysetRepo_Update(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-update-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + + err := repo.Save(ctx, keyset) + require.NoError(t, err) + + // Get the saved keyset to confirm revision + saved, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + require.Equal(t, int64(1), saved.Revision) + + // Modify keyset + saved.WrappedKeyset = "vault:v2:new-encrypted-dek" + saved.KEKPath = "transit/keys/crm-" + organizationID + "-rotated" + + // Act + err = repo.Update(ctx, saved, 1) + + // Assert + require.NoError(t, err) + + // Verify update + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, "vault:v2:new-encrypted-dek", result.WrappedKeyset) + assert.Equal(t, "transit/keys/crm-"+organizationID+"-rotated", result.KEKPath) +} + +func TestIntegration_KeysetRepo_Update_IncrementRevision(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-increv-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + + err := repo.Save(ctx, keyset) + require.NoError(t, err) + + saved, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + require.Equal(t, int64(1), saved.Revision) + + // Act - Update with correct revision + saved.WrappedKeyset = "vault:v2:updated-dek" + err = repo.Update(ctx, saved, 1) + require.NoError(t, err) + + // Assert - Revision should increment + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, int64(2), result.Revision, "revision should increment to 2") +} + +func TestIntegration_KeysetRepo_Update_RevisionConflict(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-conflict-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + + err := repo.Save(ctx, keyset) + require.NoError(t, err) + + saved, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + + // Act - Try to update with wrong revision + saved.WrappedKeyset = "vault:v2:should-fail" + wrongRevision := int64(999) + err = repo.Update(ctx, saved, wrongRevision) + + // Assert + require.Error(t, err, "should return error for revision conflict") + assert.ErrorIs(t, err, mmodel.ErrKeysetRevisionConflict, "should return ErrKeysetRevisionConflict") + + // Verify original data unchanged + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.NotEqual(t, "vault:v2:should-fail", result.WrappedKeyset, "data should not be updated") + assert.Equal(t, int64(1), result.Revision, "revision should still be 1") +} + +func TestIntegration_KeysetRepo_Update_DoesNotMutateInputOnFailure(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-nomutate-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + + err := repo.Save(ctx, keyset) + require.NoError(t, err) + + saved, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + + initialRevision := saved.Revision + wrongRevision := int64(999) + + // Act - Update should fail + err = repo.Update(ctx, saved, wrongRevision) + + // Assert - Input object should not be mutated + require.Error(t, err) + assert.Equal(t, initialRevision, saved.Revision, "input revision should not be mutated on failure") +} + +func TestIntegration_KeysetRepo_Update_NotFound(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + nonExistentOrg := "org-updatenotfound-" + uuid.New().String()[:8] + keyset := createValidKeyset(nonExistentOrg) + keyset.Revision = 1 + + // Act - Try to update non-existent keyset + err := repo.Update(ctx, keyset, 1) + + // Assert + require.Error(t, err, "should return error for non-existent organization") + assert.ErrorIs(t, err, mmodel.ErrKeysetRevisionConflict, "should return revision conflict (no document matched)") +} + +func TestIntegration_KeysetRepo_Update_MultipleUpdates(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-multi-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + + err := repo.Save(ctx, keyset) + require.NoError(t, err) + + // Act - Perform multiple sequential updates + for i := 1; i <= 5; i++ { + current, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + + current.WrappedKeyset = "vault:v" + string(rune('0'+i)) + ":dek-update-" + string(rune('0'+i)) + err = repo.Update(ctx, current, current.Revision) + require.NoError(t, err, "update %d should succeed", i) + } + + // Assert + final, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, int64(6), final.Revision, "revision should be 6 after 5 updates") +} + +func TestIntegration_KeysetRepo_Update_RotatedAt(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-rotated-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + + err := repo.Save(ctx, keyset) + require.NoError(t, err) + + saved, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + + // Set rotated_at timestamp + rotatedAt := time.Now().UTC().Truncate(time.Second) + saved.RotatedAt = &rotatedAt + saved.WrappedKeyset = "vault:v2:rotated-dek" + + // Act + err = repo.Update(ctx, saved, saved.Revision) + require.NoError(t, err) + + // Assert + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + require.NotNil(t, result.RotatedAt, "rotated_at should be set") + assert.Equal(t, rotatedAt.Unix(), result.RotatedAt.Unix(), "rotated_at should match") +} + +// ============================================================================ +// Index Constraint Tests +// ============================================================================ + +func TestIntegration_KeysetRepo_UniqueIndex_OrganizationID(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-unique-" + uuid.New().String()[:8] + + // Save first keyset + keyset1 := createValidKeyset(organizationID) + err := repo.Save(ctx, keyset1) + require.NoError(t, err) + + // Act - Try to insert directly via MongoDB (bypassing Save logic) + keyset2Model := KeysetFromEntity(createValidKeyset(organizationID)) + _, err = container.Database.Collection(keysetCollection).InsertOne(ctx, keyset2Model) + + // Assert - Should fail due to unique index + require.Error(t, err, "direct insert with duplicate organization_id should fail") + assert.Contains(t, err.Error(), "duplicate key", "should be a duplicate key error") +} + +// ============================================================================ +// Round-Trip Tests +// ============================================================================ + +func TestIntegration_KeysetRepo_RoundTrip(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-roundtrip-" + uuid.New().String()[:8] + original := createValidKeyset(organizationID) + original.LegacyKeyImported = true + original.LegacyHMACKeyImported = true + + // Act - Save and retrieve + err := repo.Save(ctx, original) + require.NoError(t, err) + + result, err := repo.Get(ctx, organizationID) + + // Assert - All fields should match + require.NoError(t, err) + require.NotNil(t, result) + + assert.Equal(t, original.OrganizationID, result.OrganizationID) + assert.Equal(t, original.KEKPath, result.KEKPath) + assert.Equal(t, original.WrappedKeyset, result.WrappedKeyset) + assert.Equal(t, original.WrappedHMACKeyset, result.WrappedHMACKeyset) + assert.Equal(t, original.KeysetInfo.PrimaryKeyID, result.KeysetInfo.PrimaryKeyID) + assert.Equal(t, original.HMACKeysetInfo.PrimaryKeyID, result.HMACKeysetInfo.PrimaryKeyID) + assert.Equal(t, original.LegacyKeyImported, result.LegacyKeyImported) + assert.Equal(t, original.LegacyHMACKeyImported, result.LegacyHMACKeyImported) + assert.Equal(t, int64(1), result.Revision) + + // Keys array + require.Len(t, result.KeysetInfo.Keys, len(original.KeysetInfo.Keys)) + assert.Equal(t, original.KeysetInfo.Keys[0].KeyID, result.KeysetInfo.Keys[0].KeyID) + assert.Equal(t, original.KeysetInfo.Keys[0].Status, result.KeysetInfo.Keys[0].Status) + assert.Equal(t, original.KeysetInfo.Keys[0].Type, result.KeysetInfo.Keys[0].Type) + assert.Equal(t, original.KeysetInfo.Keys[0].IsPrimary, result.KeysetInfo.Keys[0].IsPrimary) +} + +func TestIntegration_KeysetRepo_RoundTrip_WithMultipleKeys(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createKeysetRepository(t, container) + ctx := context.Background() + + organizationID := "org-multikey-" + uuid.New().String()[:8] + keyset := createValidKeyset(organizationID) + + // Add multiple keys (simulating key rotation history) + keyset.KeysetInfo.Keys = []mmodel.KeyInfo{ + {KeyID: 1, Status: "DISABLED", Type: "AES256_GCM", IsPrimary: false}, + {KeyID: 2, Status: "DISABLED", Type: "AES256_GCM", IsPrimary: false}, + {KeyID: 3, Status: "ENABLED", Type: "AES256_GCM", IsPrimary: true}, + } + keyset.KeysetInfo.PrimaryKeyID = 3 + + // Act + err := repo.Save(ctx, keyset) + require.NoError(t, err) + + result, err := repo.Get(ctx, organizationID) + + // Assert + require.NoError(t, err) + require.Len(t, result.KeysetInfo.Keys, 3, "should have 3 keys") + assert.Equal(t, uint32(3), result.KeysetInfo.PrimaryKeyID) + + // Verify key order and values + assert.Equal(t, uint32(1), result.KeysetInfo.Keys[0].KeyID) + assert.Equal(t, "DISABLED", result.KeysetInfo.Keys[0].Status) + assert.Equal(t, uint32(3), result.KeysetInfo.Keys[2].KeyID) + assert.True(t, result.KeysetInfo.Keys[2].IsPrimary) +} diff --git a/components/crm/internal/adapters/mongodb/encryption/keyset.mongodb_test.go b/components/crm/internal/adapters/mongodb/encryption/keyset.mongodb_test.go new file mode 100644 index 000000000..40c6d119b --- /dev/null +++ b/components/crm/internal/adapters/mongodb/encryption/keyset.mongodb_test.go @@ -0,0 +1,340 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package encryption + +import ( + "context" + "testing" + "time" + + libMongo "github.com/LerianStudio/lib-commons/v5/commons/mongo" + tmcore "github.com/LerianStudio/lib-commons/v5/commons/tenant-manager/core" + "github.com/LerianStudio/midaz/v3/pkg/mmodel" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +func TestKeysetMongoDBRepositoryImplementsRepository(t *testing.T) { + t.Parallel() + + repo, err := NewKeysetMongoDBRepository(nil) + + require.NoError(t, err) + require.Implements(t, (*KeysetRepository)(nil), repo) +} + +func TestNewKeysetMongoDBRepository_NilConnection(t *testing.T) { + t.Parallel() + + repo, err := NewKeysetMongoDBRepository(nil) + + require.NoError(t, err) + require.NotNil(t, repo) + assert.Nil(t, repo.connection) +} + +func TestKeysetMongoDBRepository_Save_NilKeyset(t *testing.T) { + t.Parallel() + + repo, err := NewKeysetMongoDBRepository(nil) + require.NoError(t, err) + + err = repo.Save(context.Background(), nil) + + require.Error(t, err) + assert.Contains(t, err.Error(), "keyset is required") +} + +func TestKeysetMongoDBRepository_Save_ValidationError(t *testing.T) { + t.Parallel() + + repo, err := NewKeysetMongoDBRepository(nil) + require.NoError(t, err) + + tests := []struct { + name string + keyset *mmodel.OrganizationKeyset + wantErrMsg string + }{ + { + name: "empty organization_id", + keyset: &mmodel.OrganizationKeyset{ + OrganizationID: "", + KEKPath: "transit/keys/test", + WrappedKeyset: "vault:v1:dek", + KeysetInfo: mmodel.KeysetInfo{PrimaryKeyID: 1}, + }, + wantErrMsg: "organization_id is required", + }, + { + name: "empty kek_path", + keyset: &mmodel.OrganizationKeyset{ + OrganizationID: "org-a", + KEKPath: "", + WrappedKeyset: "vault:v1:dek", + KeysetInfo: mmodel.KeysetInfo{PrimaryKeyID: 1}, + }, + wantErrMsg: "kek_path is required", + }, + { + name: "empty wrapped_keyset", + keyset: &mmodel.OrganizationKeyset{ + OrganizationID: "org-a", + KEKPath: "transit/keys/test", + WrappedKeyset: "", + KeysetInfo: mmodel.KeysetInfo{PrimaryKeyID: 1}, + }, + wantErrMsg: "wrapped_keyset is required", + }, + { + name: "zero primary_key_id", + keyset: &mmodel.OrganizationKeyset{ + OrganizationID: "org-a", + KEKPath: "transit/keys/test", + WrappedKeyset: "vault:v1:dek", + KeysetInfo: mmodel.KeysetInfo{PrimaryKeyID: 0}, + }, + wantErrMsg: "keyset_info.primary_key_id is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := repo.Save(context.Background(), tt.keyset) + + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErrMsg) + }) + } +} + +func TestKeysetMongoDBRepository_Update_NilKeyset(t *testing.T) { + t.Parallel() + + repo, err := NewKeysetMongoDBRepository(nil) + require.NoError(t, err) + + err = repo.Update(context.Background(), nil, 1) + + require.Error(t, err) + assert.Contains(t, err.Error(), "keyset is required") +} + +func TestKeysetMongoDBRepository_Update_ValidationError(t *testing.T) { + t.Parallel() + + repo, err := NewKeysetMongoDBRepository(nil) + require.NoError(t, err) + + invalidKeyset := &mmodel.OrganizationKeyset{ + OrganizationID: "", + KEKPath: "transit/keys/test", + } + + err = repo.Update(context.Background(), invalidKeyset, 1) + + require.Error(t, err) + assert.Contains(t, err.Error(), "organization_id is required") +} + +func TestKeysetMongoDBRepository_getDatabase_NoConnection(t *testing.T) { + t.Parallel() + + repo, err := NewKeysetMongoDBRepository(nil) + require.NoError(t, err) + + // Without tenant context and without connection, should error + db, err := repo.getDatabase(context.Background()) + + require.Error(t, err) + assert.Nil(t, db) + assert.Contains(t, err.Error(), "no database connection available") +} + +func TestKeysetMongoDBRepository_collection_NoConnection(t *testing.T) { + t.Parallel() + + repo, err := NewKeysetMongoDBRepository(nil) + require.NoError(t, err) + + // Without tenant context and without connection, should error + coll, err := repo.collection(context.Background()) + + require.Error(t, err) + assert.Nil(t, coll) + assert.Contains(t, err.Error(), "no database connection available") +} + +func TestKeysetMongoDBRepository_Save_SetsRevisionToOne(t *testing.T) { + t.Parallel() + + keyset := validTestKeyset() + keyset.Revision = 0 + + // Verify validation passes + require.NoError(t, keyset.Validate()) + + // After Save would set Revision to 1 (tested in integration tests) + // This unit test just ensures the struct is properly configured + assert.Equal(t, int64(0), keyset.Revision) +} + +func TestValidTestKeyset_IsValid(t *testing.T) { + t.Parallel() + + keyset := validTestKeyset() + + require.NoError(t, keyset.Validate()) + assert.NotEmpty(t, keyset.OrganizationID) + assert.NotEmpty(t, keyset.KEKPath) + assert.NotEmpty(t, keyset.WrappedKeyset) + assert.NotZero(t, keyset.KeysetInfo.PrimaryKeyID) +} + +func TestKeysetMongoDBRepository_Get_NoConnection(t *testing.T) { + t.Parallel() + + repo, err := NewKeysetMongoDBRepository(nil) + require.NoError(t, err) + + keyset, err := repo.Get(context.Background(), "org-a") + + require.Error(t, err) + assert.Nil(t, keyset) + assert.Contains(t, err.Error(), "no database connection available") +} + +// newDisconnectedDatabase creates a MongoDB database handle for testing tenant isolation. +func newDisconnectedDatabase(t *testing.T, dbName string) *mongo.Database { + t.Helper() + + client, err := mongo.Connect(context.Background(), options.Client().ApplyURI("mongodb://localhost:27017")) + require.NoError(t, err, "mongo.Connect should succeed for a disconnected handle") + + t.Cleanup(func() { + require.NoError(t, client.Disconnect(context.Background()), "client disconnect should not error") + }) + + return client.Database(dbName) +} + +// newPlaceholderConnection creates a placeholder libMongo.Client for testing. +func newPlaceholderConnection(_ string) *libMongo.Client { + return &libMongo.Client{} +} + +func TestKeysetMongoDBRepository_getDatabase_TenantContextTakesPrecedence(t *testing.T) { + t.Parallel() + + tenantDB := newDisconnectedDatabase(t, "tenant-keyset-priority") + repo := &KeysetMongoDBRepository{connection: newPlaceholderConnection("static-db")} + ctx := tmcore.ContextWithMB(context.Background(), tenantDB) + + db, err := repo.getDatabase(ctx) + + require.NoError(t, err, "getDatabase should not return error when tenant DB is in context") + require.NotNil(t, db, "returned database must not be nil") + assert.Same(t, tenantDB, db, "tenant DB must take precedence over static connection") + assert.Equal(t, "tenant-keyset-priority", db.Name(), "returned DB name should be the tenant DB name") +} + +func TestKeysetMongoDBRepository_getDatabase_FallbackToStaticConnection(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ctx context.Context + wantErr bool + }{ + {name: "plain_background_context_no_tenant", ctx: context.Background(), wantErr: true}, + {name: "context_with_unrelated_values", ctx: context.WithValue(context.Background(), struct{}{}, "unrelated"), wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := &KeysetMongoDBRepository{connection: newPlaceholderConnection("fallback-db")} + + db, err := repo.getDatabase(tt.ctx) + + if tt.wantErr { + require.Error(t, err, "getDatabase should return error when static connection has no live MongoDB") + assert.Nil(t, db, "database should be nil when static connection fails") + } + }) + } +} + +func TestKeysetMongoDBRepository_getDatabase_NilConnectionWithTenantContext(t *testing.T) { + t.Parallel() + + tenantDB := newDisconnectedDatabase(t, "tenant-keyset-nil-conn") + repo := &KeysetMongoDBRepository{connection: nil} + ctx := tmcore.ContextWithMB(context.Background(), tenantDB) + + db, err := repo.getDatabase(ctx) + + require.NoError(t, err, "getDatabase should not error when tenant DB is in context") + require.NotNil(t, db, "returned database must not be nil") + assert.Same(t, tenantDB, db, "must return the exact tenant DB from context") + assert.Equal(t, "tenant-keyset-nil-conn", db.Name(), "database name should match the tenant DB") +} + +func TestKeysetMongoDBRepository_getDatabase_TwoTenants_ResolveToDifferentDatabases(t *testing.T) { + t.Parallel() + + tenantAcmeDB := newDisconnectedDatabase(t, "tenant-keyset-acme") + tenantGlobexDB := newDisconnectedDatabase(t, "tenant-keyset-globex") + + repo := &KeysetMongoDBRepository{connection: newPlaceholderConnection("static-db")} + + ctxTenantA := tmcore.ContextWithMB(context.Background(), tenantAcmeDB) + ctxTenantB := tmcore.ContextWithMB(context.Background(), tenantGlobexDB) + + dbA, errA := repo.getDatabase(ctxTenantA) + dbB, errB := repo.getDatabase(ctxTenantB) + + require.NoError(t, errA, "getDatabase should not error for tenant A") + require.NoError(t, errB, "getDatabase should not error for tenant B") + require.NotNil(t, dbA, "tenant A database must not be nil") + require.NotNil(t, dbB, "tenant B database must not be nil") + + assert.NotSame(t, dbA, dbB, "two different tenants must resolve to different *mongo.Database instances") + assert.NotEqual(t, dbA.Name(), dbB.Name(), "two different tenants must have different database names") + assert.Equal(t, "tenant-keyset-acme", dbA.Name(), "tenant A database name must match") + assert.Equal(t, "tenant-keyset-globex", dbB.Name(), "tenant B database name must match") +} + +func validTestKeyset() *mmodel.OrganizationKeyset { + now := time.Now().UTC() + + return &mmodel.OrganizationKeyset{ + OrganizationID: "org-test", + KEKPath: "transit/keys/crm-org-test", + WrappedKeyset: "vault:v1:encrypted-dek", + WrappedHMACKeyset: "vault:v1:encrypted-hmac", + KeysetInfo: mmodel.KeysetInfo{ + PrimaryKeyID: 1, + Keys: []mmodel.KeyInfo{ + {KeyID: 1, Status: "ENABLED", Type: "AES256_GCM", IsPrimary: true}, + }, + }, + HMACKeysetInfo: mmodel.KeysetInfo{ + PrimaryKeyID: 1, + Keys: []mmodel.KeyInfo{ + {KeyID: 1, Status: "ENABLED", Type: "HMAC_SHA256", IsPrimary: true}, + }, + }, + LegacyKeyImported: false, + LegacyHMACKeyImported: false, + Revision: 1, + CreatedAt: now, + } +} diff --git a/components/crm/internal/adapters/mongodb/encryption/keyset_test.go b/components/crm/internal/adapters/mongodb/encryption/keyset_test.go new file mode 100644 index 000000000..76a76cffc --- /dev/null +++ b/components/crm/internal/adapters/mongodb/encryption/keyset_test.go @@ -0,0 +1,161 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package encryption + +import ( + "testing" + "time" + + "github.com/LerianStudio/midaz/v3/pkg/mmodel" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestKeysetFromEntity(t *testing.T) { + t.Parallel() + + now := time.Now().UTC() + rotatedAt := now.Add(time.Hour) + + entity := &mmodel.OrganizationKeyset{ + TenantID: "tenant-a", + OrganizationID: "org-a", + KEKPath: "transit/keys/org-a", + WrappedKeyset: "vault:v1:encrypted-dek", + KeysetInfo: mmodel.KeysetInfo{ + PrimaryKeyID: 2, + Keys: []mmodel.KeyInfo{ + {KeyID: 1, Status: "ENABLED", Type: "LEGACY_AES_GCM", IsPrimary: false}, + {KeyID: 2, Status: "ENABLED", Type: "AES256_GCM", IsPrimary: true}, + }, + }, + LegacyKeyImported: true, + WrappedHMACKeyset: "vault:v1:encrypted-hmac", + HMACKeysetInfo: mmodel.KeysetInfo{PrimaryKeyID: 1}, + LegacyHMACKeyImported: true, + Revision: 5, + CreatedAt: now, + RotatedAt: &rotatedAt, + } + + model := KeysetFromEntity(entity) + + require.NotNil(t, model) + assert.Equal(t, entity.TenantID, model.TenantID) + assert.Equal(t, entity.OrganizationID, model.OrganizationID) + assert.Equal(t, entity.KEKPath, model.KEKPath) + assert.Equal(t, entity.WrappedKeyset, model.WrappedKeyset) + assert.Equal(t, entity.KeysetInfo.PrimaryKeyID, model.KeysetInfo.PrimaryKeyID) + assert.Len(t, model.KeysetInfo.Keys, 2) + assert.Equal(t, entity.LegacyKeyImported, model.LegacyKeyImported) + assert.Equal(t, entity.WrappedHMACKeyset, model.WrappedHMACKeyset) + assert.Equal(t, entity.LegacyHMACKeyImported, model.LegacyHMACKeyImported) + assert.Equal(t, entity.Revision, model.Revision) + assert.Equal(t, entity.CreatedAt, model.CreatedAt) + assert.Equal(t, entity.RotatedAt, model.RotatedAt) +} + +func TestKeysetFromEntity_NilEntity(t *testing.T) { + t.Parallel() + + model := KeysetFromEntity(nil) + + assert.Nil(t, model) +} + +func TestKeysetMongoDBModel_ToEntity(t *testing.T) { + t.Parallel() + + now := time.Now().UTC() + rotatedAt := now.Add(time.Hour) + + model := &KeysetMongoDBModel{ + TenantID: "tenant-a", + OrganizationID: "org-a", + KEKPath: "transit/keys/org-a", + WrappedKeyset: "vault:v1:encrypted-dek", + KeysetInfo: KeysetInfoModel{ + PrimaryKeyID: 2, + Keys: []KeyInfoModel{ + {KeyID: 1, Status: "ENABLED", Type: "LEGACY_AES_GCM", IsPrimary: false}, + {KeyID: 2, Status: "ENABLED", Type: "AES256_GCM", IsPrimary: true}, + }, + }, + LegacyKeyImported: true, + WrappedHMACKeyset: "vault:v1:encrypted-hmac", + HMACKeysetInfo: KeysetInfoModel{PrimaryKeyID: 1}, + LegacyHMACKeyImported: true, + Revision: 5, + CreatedAt: now, + RotatedAt: &rotatedAt, + } + + entity := model.ToEntity() + + require.NotNil(t, entity) + assert.Equal(t, model.TenantID, entity.TenantID) + assert.Equal(t, model.OrganizationID, entity.OrganizationID) + assert.Equal(t, model.KEKPath, entity.KEKPath) + assert.Equal(t, model.WrappedKeyset, entity.WrappedKeyset) + assert.Equal(t, model.KeysetInfo.PrimaryKeyID, entity.KeysetInfo.PrimaryKeyID) + assert.Len(t, entity.KeysetInfo.Keys, 2) + assert.Equal(t, model.LegacyKeyImported, entity.LegacyKeyImported) + assert.Equal(t, model.WrappedHMACKeyset, entity.WrappedHMACKeyset) + assert.Equal(t, model.LegacyHMACKeyImported, entity.LegacyHMACKeyImported) + assert.Equal(t, model.Revision, entity.Revision) + assert.Equal(t, model.CreatedAt, entity.CreatedAt) + assert.Equal(t, model.RotatedAt, entity.RotatedAt) +} + +func TestKeysetMongoDBModel_ToEntity_NilModel(t *testing.T) { + t.Parallel() + + var model *KeysetMongoDBModel + + entity := model.ToEntity() + + assert.Nil(t, entity) +} + +func TestKeysetConversion_RoundTrip(t *testing.T) { + t.Parallel() + + now := time.Now().UTC() + + original := &mmodel.OrganizationKeyset{ + TenantID: "tenant-a", + OrganizationID: "org-a", + KEKPath: "transit/keys/org-a", + WrappedKeyset: "vault:v1:encrypted-dek", + KeysetInfo: mmodel.KeysetInfo{ + PrimaryKeyID: 2, + Keys: []mmodel.KeyInfo{ + {KeyID: 1, Status: "ENABLED", Type: "LEGACY_AES_GCM", IsPrimary: false}, + {KeyID: 2, Status: "ENABLED", Type: "AES256_GCM", IsPrimary: true}, + }, + }, + LegacyKeyImported: true, + WrappedHMACKeyset: "vault:v1:encrypted-hmac", + HMACKeysetInfo: mmodel.KeysetInfo{PrimaryKeyID: 1}, + LegacyHMACKeyImported: true, + Revision: 5, + CreatedAt: now, + } + + // Convert to model and back + model := KeysetFromEntity(original) + recovered := model.ToEntity() + + assert.Equal(t, original.TenantID, recovered.TenantID) + assert.Equal(t, original.OrganizationID, recovered.OrganizationID) + assert.Equal(t, original.KEKPath, recovered.KEKPath) + assert.Equal(t, original.WrappedKeyset, recovered.WrappedKeyset) + assert.Equal(t, original.KeysetInfo.PrimaryKeyID, recovered.KeysetInfo.PrimaryKeyID) + assert.Equal(t, len(original.KeysetInfo.Keys), len(recovered.KeysetInfo.Keys)) + assert.Equal(t, original.LegacyKeyImported, recovered.LegacyKeyImported) + assert.Equal(t, original.WrappedHMACKeyset, recovered.WrappedHMACKeyset) + assert.Equal(t, original.LegacyHMACKeyImported, recovered.LegacyHMACKeyImported) + assert.Equal(t, original.Revision, recovered.Revision) +} diff --git a/components/crm/internal/adapters/mongodb/encryption/registry.go b/components/crm/internal/adapters/mongodb/encryption/registry.go new file mode 100644 index 000000000..e1c95594f --- /dev/null +++ b/components/crm/internal/adapters/mongodb/encryption/registry.go @@ -0,0 +1,80 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package encryption + +import ( + "time" + + "github.com/LerianStudio/midaz/v3/pkg/mmodel" +) + +// RegistryMongoDBModel is the MongoDB representation of OrganizationRegistryRecord. +type RegistryMongoDBModel struct { + TenantID string `bson:"tenant_id,omitempty"` + OrganizationID string `bson:"organization_id"` + Status mmodel.RegistryStatus `bson:"status"` + ProtectionModel mmodel.ProtectionModel `bson:"protection_model"` + CurrentVersion int `bson:"current_version"` + ReadableVersions []int `bson:"readable_versions"` + Revision int64 `bson:"revision"` + LegacyReadable bool `bson:"legacy_readable"` + CreatedAt time.Time `bson:"created_at"` + UpdatedAt time.Time `bson:"updated_at"` + CreatedBy string `bson:"created_by"` + UpdatedBy string `bson:"updated_by"` + LastTransitionReason string `bson:"last_transition_reason"` +} + +// RegistryFromEntity converts a domain OrganizationRegistryRecord to MongoDB model. +func RegistryFromEntity(r *mmodel.OrganizationRegistryRecord) *RegistryMongoDBModel { + if r == nil { + return nil + } + + readableVersions := make([]int, len(r.ReadableVersions)) + copy(readableVersions, r.ReadableVersions) + + return &RegistryMongoDBModel{ + TenantID: r.TenantID, + OrganizationID: r.OrganizationID, + Status: r.Status, + ProtectionModel: r.ProtectionModel, + CurrentVersion: r.CurrentVersion, + ReadableVersions: readableVersions, + Revision: r.Revision, + LegacyReadable: r.LegacyReadable, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + CreatedBy: r.CreatedBy, + UpdatedBy: r.UpdatedBy, + LastTransitionReason: r.LastTransitionReason, + } +} + +// ToEntity converts the MongoDB model to a domain OrganizationRegistryRecord. +func (m *RegistryMongoDBModel) ToEntity() *mmodel.OrganizationRegistryRecord { + if m == nil { + return nil + } + + readableVersions := make([]int, len(m.ReadableVersions)) + copy(readableVersions, m.ReadableVersions) + + return &mmodel.OrganizationRegistryRecord{ + TenantID: m.TenantID, + OrganizationID: m.OrganizationID, + Status: m.Status, + ProtectionModel: m.ProtectionModel, + CurrentVersion: m.CurrentVersion, + ReadableVersions: readableVersions, + Revision: m.Revision, + LegacyReadable: m.LegacyReadable, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + CreatedBy: m.CreatedBy, + UpdatedBy: m.UpdatedBy, + LastTransitionReason: m.LastTransitionReason, + } +} diff --git a/components/crm/internal/adapters/mongodb/encryption/registry.mongodb.go b/components/crm/internal/adapters/mongodb/encryption/registry.mongodb.go new file mode 100644 index 000000000..8398e66aa --- /dev/null +++ b/components/crm/internal/adapters/mongodb/encryption/registry.mongodb.go @@ -0,0 +1,219 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package encryption + +import ( + "context" + "errors" + "fmt" + + libCommons "github.com/LerianStudio/lib-commons/v5/commons" + libMongo "github.com/LerianStudio/lib-commons/v5/commons/mongo" + libOpenTelemetry "github.com/LerianStudio/lib-commons/v5/commons/opentelemetry" + tmcore "github.com/LerianStudio/lib-commons/v5/commons/tenant-manager/core" + "github.com/LerianStudio/midaz/v3/pkg/mmodel" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.opentelemetry.io/otel/attribute" +) + +const registryCollection = "organization_registry" + +// RegistryRepository provides an interface for operations related to registry entities. +// +//go:generate go run go.uber.org/mock/mockgen@v0.6.0 --destination=registry.mongodb_mock.go --package=encryption . RegistryRepository +type RegistryRepository interface { + Save(ctx context.Context, record *mmodel.OrganizationRegistryRecord) error + Get(ctx context.Context, organizationID string) (*mmodel.OrganizationRegistryRecord, error) + Update(ctx context.Context, record *mmodel.OrganizationRegistryRecord, expectedRevision int64) error +} + +// RegistryMongoDBRepository is a MongoDB-specific implementation of RegistryRepository. +type RegistryMongoDBRepository struct { + connection *libMongo.Client +} + +// NewRegistryMongoDBRepository returns a new instance of RegistryMongoDBRepository using the given MongoDB connection. +// In multi-tenant mode, connection may be nil — the per-request tenant context provides the database. +func NewRegistryMongoDBRepository(connection *libMongo.Client) (*RegistryMongoDBRepository, error) { + r := &RegistryMongoDBRepository{ + connection: connection, + } + + if connection != nil { + if _, err := r.connection.Database(context.Background()); err != nil { + return nil, fmt.Errorf("failed to connect to MongoDB for registry repository: %w", err) + } + } + + return r, nil +} + +func (r *RegistryMongoDBRepository) Save(ctx context.Context, record *mmodel.OrganizationRegistryRecord) error { + _, tracer, _, _ := libCommons.NewTrackingFromContext(ctx) //nolint:dogsled // consistent with codebase pattern + + ctx, span := tracer.Start(ctx, "mongodb.registry.save") + defer span.End() + + if record == nil { + return fmt.Errorf("registry record is required") + } + + span.SetAttributes(attribute.String("app.request.organization_id", record.OrganizationID)) + + collection, err := r.collection(ctx) + if err != nil { + libOpenTelemetry.HandleSpanError(span, "Failed to get collection", err) + return err + } + + if err := r.ensureIndexes(ctx, collection); err != nil { + libOpenTelemetry.HandleSpanError(span, "Failed to create registry indexes", err) + return fmt.Errorf("create registry indexes: %w", err) + } + + model := RegistryFromEntity(record) + + filter := bson.M{"organization_id": record.OrganizationID} + update := bson.M{"$setOnInsert": model} + + result, err := collection.UpdateOne(ctx, filter, update, options.Update().SetUpsert(true)) + if err != nil { + libOpenTelemetry.HandleSpanError(span, "Failed to save organization registry", err) + return fmt.Errorf("save organization registry: %w", err) + } + + if result.MatchedCount > 0 { + return mmodel.ErrRegistryAlreadyExists + } + + return nil +} + +func (r *RegistryMongoDBRepository) Get(ctx context.Context, organizationID string) (*mmodel.OrganizationRegistryRecord, error) { + _, tracer, _, _ := libCommons.NewTrackingFromContext(ctx) //nolint:dogsled // consistent with codebase pattern + + ctx, span := tracer.Start(ctx, "mongodb.registry.get") + defer span.End() + + span.SetAttributes(attribute.String("app.request.organization_id", organizationID)) + + collection, err := r.collection(ctx) + if err != nil { + libOpenTelemetry.HandleSpanError(span, "Failed to get collection", err) + return nil, err + } + + var model RegistryMongoDBModel + + if err := collection.FindOne(ctx, bson.M{"organization_id": organizationID}).Decode(&model); err != nil { + if errors.Is(err, mongo.ErrNoDocuments) { + return nil, mmodel.ErrRegistryNotFound + } + + libOpenTelemetry.HandleSpanError(span, "Failed to get organization registry", err) + + return nil, fmt.Errorf("get organization registry: %w", err) + } + + return model.ToEntity(), nil +} + +func (r *RegistryMongoDBRepository) Update(ctx context.Context, record *mmodel.OrganizationRegistryRecord, expectedRevision int64) error { + _, tracer, _, _ := libCommons.NewTrackingFromContext(ctx) //nolint:dogsled // consistent with codebase pattern + + ctx, span := tracer.Start(ctx, "mongodb.registry.update") + defer span.End() + + if record == nil { + return fmt.Errorf("registry record is required") + } + + span.SetAttributes( + attribute.String("app.request.organization_id", record.OrganizationID), + attribute.Int64("app.request.expected_revision", expectedRevision), + ) + + collection, err := r.collection(ctx) + if err != nil { + libOpenTelemetry.HandleSpanError(span, "Failed to get collection", err) + return err + } + + // Create model from entity and set the new revision on the model, not on the input entity. + // This prevents mutation of the caller's object if the database operation fails. + model := RegistryFromEntity(record) + model.Revision = expectedRevision + 1 + + result, err := collection.ReplaceOne(ctx, bson.M{"organization_id": record.OrganizationID, "revision": expectedRevision}, model) + if err != nil { + libOpenTelemetry.HandleSpanError(span, "Failed to update organization registry", err) + return fmt.Errorf("update organization registry: %w", err) + } + + if result.MatchedCount == 0 { + return mmodel.ErrRegistryRevisionConflict + } + + span.SetAttributes(attribute.Int64("db.rows_affected", result.ModifiedCount)) + + return nil +} + +// getDatabase resolves the MongoDB database for the current request. +// In multi-tenant mode, the middleware injects a tenant-specific *mongo.Database into context. +// In single-tenant mode (or when no tenant context exists), falls back to the static connection. +func (r *RegistryMongoDBRepository) getDatabase(ctx context.Context) (*mongo.Database, error) { + if r.connection == nil { + if db := tmcore.GetMBContext(ctx); db != nil { + return db, nil + } + + return nil, fmt.Errorf("no database connection available: multi-tenant context required but not present, and no static connection configured") + } + + if db := tmcore.GetMBContext(ctx); db != nil { + return db, nil + } + + return r.connection.Database(ctx) +} + +func (r *RegistryMongoDBRepository) collection(ctx context.Context) (*mongo.Collection, error) { + db, err := r.getDatabase(ctx) + if err != nil { + return nil, err + } + + return db.Collection(registryCollection), nil +} + +// ensureIndexes ensures indexes exist for the registry collection. +// Uses per-database tracking to handle multi-tenant mode correctly. +// Retries on failure — indexes are only marked as done after successful creation. +func (r *RegistryMongoDBRepository) ensureIndexes(ctx context.Context, collection *mongo.Collection) error { + key := collection.Database().Name() + ":" + registryCollection + + return globalIndexTracker.ensureOnce(key, func() error { + return r.createIndexes(ctx, collection) + }) +} + +// createIndexes ensures indexes exist for the registry collection. +func (r *RegistryMongoDBRepository) createIndexes(ctx context.Context, collection *mongo.Collection) error { + indexModels := []mongo.IndexModel{ + { + Keys: bson.D{{Key: "organization_id", Value: 1}}, + Options: options.Index().SetUnique(true), + }, + } + + _, err := collection.Indexes().CreateMany(ctx, indexModels) + + return err +} + +var _ RegistryRepository = (*RegistryMongoDBRepository)(nil) diff --git a/components/crm/internal/adapters/mongodb/encryption/registry.mongodb_integration_test.go b/components/crm/internal/adapters/mongodb/encryption/registry.mongodb_integration_test.go new file mode 100644 index 000000000..f0646daf1 --- /dev/null +++ b/components/crm/internal/adapters/mongodb/encryption/registry.mongodb_integration_test.go @@ -0,0 +1,654 @@ +//go:build integration + +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package encryption + +import ( + "context" + "testing" + + "github.com/LerianStudio/midaz/v3/pkg/mmodel" + mongotestutil "github.com/LerianStudio/midaz/v3/tests/utils/mongodb" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/bson" +) + +// ============================================================================ +// Test Helpers +// ============================================================================ + +// createRegistryRepository creates a RegistryMongoDBRepository for integration testing. +// Resets the global index tracker for this database to ensure a fresh state, +// since each test runs with a new MongoDB container. +func createRegistryRepository(t *testing.T, container *mongotestutil.ContainerResult) *RegistryMongoDBRepository { + t.Helper() + + // Reset index tracker state for this database — each test has a fresh container + globalIndexTracker.reset(container.DBName + ":" + registryCollection) + + conn := mongotestutil.CreateConnection(t, container.URI, container.DBName) + + repo, err := NewRegistryMongoDBRepository(conn) + require.NoError(t, err) + + return repo +} + +// createValidRegistry creates a valid OrganizationRegistryRecord for testing. +func createValidRegistry(t *testing.T, tenantID, organizationID string) *mmodel.OrganizationRegistryRecord { + t.Helper() + + record, err := mmodel.NewOrganizationRegistryRecord(tenantID, organizationID, "system", "integration test setup") + require.NoError(t, err) + + return record +} + +// ============================================================================ +// Save Tests +// ============================================================================ + +func TestIntegration_RegistryRepo_Save(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-" + uuid.New().String()[:8] + organizationID := "org-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + + // Act + err := repo.Save(ctx, registry) + + // Assert + require.NoError(t, err, "Save should not return error") + + // Verify via direct query + count := mongotestutil.CountDocuments(t, container.Database, registryCollection, bson.M{"organization_id": organizationID}) + assert.Equal(t, int64(1), count, "should have exactly 1 document") +} + +func TestIntegration_RegistryRepo_Save_AlreadyExists(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-dup-" + uuid.New().String()[:8] + organizationID := "org-dup-" + uuid.New().String()[:8] + + registry1 := createValidRegistry(t, tenantID, organizationID) + + // First save should succeed + err := repo.Save(ctx, registry1) + require.NoError(t, err, "first save should succeed") + + // Act - Try to save again with same organization_id + registry2 := createValidRegistry(t, tenantID, organizationID) + err = repo.Save(ctx, registry2) + + // Assert + require.Error(t, err, "second save should fail") + assert.ErrorIs(t, err, mmodel.ErrRegistryAlreadyExists, "should return ErrRegistryAlreadyExists") +} + +func TestIntegration_RegistryRepo_Save_DifferentOrganizations(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-" + uuid.New().String()[:8] + org1 := "org-1-" + uuid.New().String()[:8] + org2 := "org-2-" + uuid.New().String()[:8] + + registry1 := createValidRegistry(t, tenantID, org1) + registry2 := createValidRegistry(t, tenantID, org2) + + // Act + err1 := repo.Save(ctx, registry1) + err2 := repo.Save(ctx, registry2) + + // Assert - Both should succeed + require.NoError(t, err1, "first org save should succeed") + require.NoError(t, err2, "second org save should succeed") + + count := mongotestutil.CountDocuments(t, container.Database, registryCollection, bson.M{}) + assert.Equal(t, int64(2), count, "should have 2 documents") +} + +func TestIntegration_RegistryRepo_Save_InitialStatus(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-status-" + uuid.New().String()[:8] + organizationID := "org-status-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + + // Act + err := repo.Save(ctx, registry) + require.NoError(t, err) + + // Assert - Get and verify initial status (NewOrganizationRegistryRecord sets pending_migration) + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, mmodel.RegistryStatusPendingMigration, result.Status, "initial status should be pending_migration") + assert.Equal(t, mmodel.ProtectionModelLegacy, result.ProtectionModel, "initial protection model should be legacy") +} + +func TestIntegration_RegistryRepo_Save_WithReadableVersions(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-versions-" + uuid.New().String()[:8] + organizationID := "org-versions-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + registry.ReadableVersions = []int{1, 2, 3} + registry.CurrentVersion = 3 + + // Act + err := repo.Save(ctx, registry) + require.NoError(t, err) + + // Assert + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3}, result.ReadableVersions) + assert.Equal(t, 3, result.CurrentVersion) +} + +// ============================================================================ +// Get Tests +// ============================================================================ + +func TestIntegration_RegistryRepo_Get(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-get-" + uuid.New().String()[:8] + organizationID := "org-get-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + + err := repo.Save(ctx, registry) + require.NoError(t, err) + + // Act + result, err := repo.Get(ctx, organizationID) + + // Assert + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, organizationID, result.OrganizationID) + assert.Equal(t, tenantID, result.TenantID) + assert.Equal(t, mmodel.RegistryStatusPendingMigration, result.Status) +} + +func TestIntegration_RegistryRepo_Get_NotFound(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + nonExistentOrg := "org-notfound-" + uuid.New().String()[:8] + + // Act + result, err := repo.Get(ctx, nonExistentOrg) + + // Assert + require.Error(t, err, "should return error for non-existent organization") + assert.Nil(t, result) + assert.ErrorIs(t, err, mmodel.ErrRegistryNotFound, "should return ErrRegistryNotFound") +} + +func TestIntegration_RegistryRepo_Get_ReturnsAllFields(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-fields-" + uuid.New().String()[:8] + organizationID := "org-fields-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + registry.LegacyReadable = true + registry.CurrentVersion = 2 + registry.ReadableVersions = []int{1, 2} + + err := repo.Save(ctx, registry) + require.NoError(t, err) + + // Act + result, err := repo.Get(ctx, organizationID) + + // Assert + require.NoError(t, err) + require.NotNil(t, result) + + assert.Equal(t, tenantID, result.TenantID) + assert.Equal(t, organizationID, result.OrganizationID) + assert.Equal(t, mmodel.RegistryStatusPendingMigration, result.Status) + assert.Equal(t, mmodel.ProtectionModelLegacy, result.ProtectionModel) + assert.Equal(t, 2, result.CurrentVersion) + assert.Equal(t, []int{1, 2}, result.ReadableVersions) + assert.True(t, result.LegacyReadable) + assert.Equal(t, "system", result.CreatedBy) + assert.Equal(t, "system", result.UpdatedBy) + assert.Equal(t, "integration test setup", result.LastTransitionReason) + assert.False(t, result.CreatedAt.IsZero()) + assert.False(t, result.UpdatedAt.IsZero()) +} + +// ============================================================================ +// Update Tests +// ============================================================================ + +func TestIntegration_RegistryRepo_Update(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-update-" + uuid.New().String()[:8] + organizationID := "org-update-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + + err := repo.Save(ctx, registry) + require.NoError(t, err) + + // Get the saved registry + saved, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + + // Modify registry + saved.Status = mmodel.RegistryStatusActive + saved.ProtectionModel = mmodel.ProtectionModelEnvelope + saved.CurrentVersion = 1 + saved.ReadableVersions = []int{1} + saved.UpdatedBy = "migration-service" + saved.LastTransitionReason = "migration completed" + + // Act + err = repo.Update(ctx, saved, saved.Revision) + + // Assert + require.NoError(t, err) + + // Verify update + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, mmodel.RegistryStatusActive, result.Status) + assert.Equal(t, mmodel.ProtectionModelEnvelope, result.ProtectionModel) + assert.Equal(t, 1, result.CurrentVersion) + assert.Equal(t, []int{1}, result.ReadableVersions) + assert.Equal(t, "migration-service", result.UpdatedBy) + assert.Equal(t, "migration completed", result.LastTransitionReason) +} + +func TestIntegration_RegistryRepo_Update_IncrementRevision(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-increv-" + uuid.New().String()[:8] + organizationID := "org-increv-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + + err := repo.Save(ctx, registry) + require.NoError(t, err) + + saved, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + initialRevision := saved.Revision + + // Act + saved.Status = mmodel.RegistryStatusPendingMigration + err = repo.Update(ctx, saved, initialRevision) + require.NoError(t, err) + + // Assert + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, initialRevision+1, result.Revision, "revision should increment") +} + +func TestIntegration_RegistryRepo_Update_RevisionConflict(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-conflict-" + uuid.New().String()[:8] + organizationID := "org-conflict-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + + err := repo.Save(ctx, registry) + require.NoError(t, err) + + saved, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + + // Act - Try to update with wrong revision + saved.Status = mmodel.RegistryStatusFailed + wrongRevision := int64(999) + err = repo.Update(ctx, saved, wrongRevision) + + // Assert + require.Error(t, err, "should return error for revision conflict") + assert.ErrorIs(t, err, mmodel.ErrRegistryRevisionConflict, "should return ErrRegistryRevisionConflict") + + // Verify original data unchanged + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.NotEqual(t, mmodel.RegistryStatusFailed, result.Status, "status should not be updated") +} + +func TestIntegration_RegistryRepo_Update_DoesNotMutateInputOnFailure(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-nomutate-" + uuid.New().String()[:8] + organizationID := "org-nomutate-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + + err := repo.Save(ctx, registry) + require.NoError(t, err) + + saved, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + + initialRevision := saved.Revision + wrongRevision := int64(999) + + // Act - Update should fail + err = repo.Update(ctx, saved, wrongRevision) + + // Assert - Input object should not be mutated + require.Error(t, err) + assert.Equal(t, initialRevision, saved.Revision, "input revision should not be mutated on failure") +} + +func TestIntegration_RegistryRepo_Update_NotFound(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-notfound-" + uuid.New().String()[:8] + nonExistentOrg := "org-updatenotfound-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, nonExistentOrg) + + // Act - Try to update non-existent registry + err := repo.Update(ctx, registry, 1) + + // Assert + require.Error(t, err, "should return error for non-existent organization") + assert.ErrorIs(t, err, mmodel.ErrRegistryRevisionConflict, "should return revision conflict (no document matched)") +} + +func TestIntegration_RegistryRepo_Update_MultipleUpdates(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-multi-" + uuid.New().String()[:8] + organizationID := "org-multi-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + + err := repo.Save(ctx, registry) + require.NoError(t, err) + + // Act - Perform multiple sequential updates (simulating migration workflow) + statusTransitions := []mmodel.RegistryStatus{ + mmodel.RegistryStatusPendingMigration, + mmodel.RegistryStatusPartiallyMigrated, + mmodel.RegistryStatusMigrationComplete, + mmodel.RegistryStatusActive, + } + + for _, status := range statusTransitions { + current, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + + current.Status = status + current.LastTransitionReason = "transition to " + string(status) + err = repo.Update(ctx, current, current.Revision) + require.NoError(t, err, "update to %s should succeed", status) + } + + // Assert + final, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, mmodel.RegistryStatusActive, final.Status) + assert.Equal(t, int64(5), final.Revision, "revision should be 5 after 4 updates") // 1 initial + 4 updates +} + +// ============================================================================ +// Status Transition Tests +// ============================================================================ + +func TestIntegration_RegistryRepo_Update_StatusTransitions(t *testing.T) { + tests := []struct { + name string + fromStatus mmodel.RegistryStatus + toStatus mmodel.RegistryStatus + }{ + {"pending_to_partially", mmodel.RegistryStatusPendingMigration, mmodel.RegistryStatusPartiallyMigrated}, + {"partially_to_complete", mmodel.RegistryStatusPartiallyMigrated, mmodel.RegistryStatusMigrationComplete}, + {"complete_to_active", mmodel.RegistryStatusMigrationComplete, mmodel.RegistryStatusActive}, + {"pending_to_failed", mmodel.RegistryStatusPendingMigration, mmodel.RegistryStatusFailed}, + {"failed_to_pending", mmodel.RegistryStatusFailed, mmodel.RegistryStatusPendingMigration}, + {"active_to_blocked", mmodel.RegistryStatusActive, mmodel.RegistryStatusBlocked}, + {"active_to_legacy", mmodel.RegistryStatusActive, mmodel.RegistryStatusLegacy}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-" + uuid.New().String()[:8] + organizationID := "org-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + registry.Status = tt.fromStatus + + err := repo.Save(ctx, registry) + require.NoError(t, err) + + saved, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + + // Act + saved.Status = tt.toStatus + saved.LastTransitionReason = tt.name + err = repo.Update(ctx, saved, saved.Revision) + + // Assert + require.NoError(t, err) + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, tt.toStatus, result.Status) + }) + } +} + +// ============================================================================ +// Protection Model Tests +// ============================================================================ + +func TestIntegration_RegistryRepo_Update_ProtectionModel(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-protection-" + uuid.New().String()[:8] + organizationID := "org-protection-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + + err := repo.Save(ctx, registry) + require.NoError(t, err) + + saved, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, mmodel.ProtectionModelLegacy, saved.ProtectionModel) + + // Act - Switch to envelope encryption + saved.ProtectionModel = mmodel.ProtectionModelEnvelope + saved.Status = mmodel.RegistryStatusActive + err = repo.Update(ctx, saved, saved.Revision) + + // Assert + require.NoError(t, err) + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, mmodel.ProtectionModelEnvelope, result.ProtectionModel) +} + +// ============================================================================ +// Index Constraint Tests +// ============================================================================ + +func TestIntegration_RegistryRepo_UniqueIndex_OrganizationID(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-unique-" + uuid.New().String()[:8] + organizationID := "org-unique-" + uuid.New().String()[:8] + + // Save first registry + registry1 := createValidRegistry(t, tenantID, organizationID) + err := repo.Save(ctx, registry1) + require.NoError(t, err) + + // Act - Try to insert directly via MongoDB (bypassing Save logic) + registry2Model := RegistryFromEntity(createValidRegistry(t, tenantID, organizationID)) + _, err = container.Database.Collection(registryCollection).InsertOne(ctx, registry2Model) + + // Assert - Should fail due to unique index + require.Error(t, err, "direct insert with duplicate organization_id should fail") + assert.Contains(t, err.Error(), "duplicate key", "should be a duplicate key error") +} + +// ============================================================================ +// Round-Trip Tests +// ============================================================================ + +func TestIntegration_RegistryRepo_RoundTrip(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-roundtrip-" + uuid.New().String()[:8] + organizationID := "org-roundtrip-" + uuid.New().String()[:8] + original := createValidRegistry(t, tenantID, organizationID) + original.LegacyReadable = true + original.CurrentVersion = 3 + original.ReadableVersions = []int{1, 2, 3} + + // Act - Save and retrieve + err := repo.Save(ctx, original) + require.NoError(t, err) + + result, err := repo.Get(ctx, organizationID) + + // Assert - All fields should match + require.NoError(t, err) + require.NotNil(t, result) + + assert.Equal(t, original.TenantID, result.TenantID) + assert.Equal(t, original.OrganizationID, result.OrganizationID) + assert.Equal(t, mmodel.RegistryStatusPendingMigration, result.Status, "status should be pending_migration from constructor") + assert.Equal(t, original.ProtectionModel, result.ProtectionModel) + assert.Equal(t, original.CurrentVersion, result.CurrentVersion) + assert.Equal(t, original.ReadableVersions, result.ReadableVersions) + assert.Equal(t, original.LegacyReadable, result.LegacyReadable) + assert.Equal(t, original.CreatedBy, result.CreatedBy) + assert.Equal(t, original.UpdatedBy, result.UpdatedBy) + assert.Equal(t, original.LastTransitionReason, result.LastTransitionReason) +} + +func TestIntegration_RegistryRepo_RoundTrip_EmptyReadableVersions(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-empty-" + uuid.New().String()[:8] + organizationID := "org-empty-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + registry.ReadableVersions = []int{} + + // Act + err := repo.Save(ctx, registry) + require.NoError(t, err) + + result, err := repo.Get(ctx, organizationID) + + // Assert + require.NoError(t, err) + assert.Empty(t, result.ReadableVersions, "empty readable_versions should be preserved") +} + +// ============================================================================ +// Concurrent Update Tests +// ============================================================================ + +func TestIntegration_RegistryRepo_ConcurrentUpdate_OptimisticLocking(t *testing.T) { + // Arrange + container := mongotestutil.SetupContainer(t) + repo := createRegistryRepository(t, container) + ctx := context.Background() + + tenantID := "tenant-concurrent-" + uuid.New().String()[:8] + organizationID := "org-concurrent-" + uuid.New().String()[:8] + registry := createValidRegistry(t, tenantID, organizationID) + + err := repo.Save(ctx, registry) + require.NoError(t, err) + + // Simulate two concurrent reads + snapshot1, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + + snapshot2, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + + // First update succeeds + snapshot1.Status = mmodel.RegistryStatusPendingMigration + err = repo.Update(ctx, snapshot1, snapshot1.Revision) + require.NoError(t, err, "first update should succeed") + + // Act - Second update should fail (stale revision) + snapshot2.Status = mmodel.RegistryStatusActive + err = repo.Update(ctx, snapshot2, snapshot2.Revision) + + // Assert + require.Error(t, err, "second update should fail due to stale revision") + assert.ErrorIs(t, err, mmodel.ErrRegistryRevisionConflict) + + // Verify first update was applied + result, err := repo.Get(ctx, organizationID) + require.NoError(t, err) + assert.Equal(t, mmodel.RegistryStatusPendingMigration, result.Status) + assert.Equal(t, int64(2), result.Revision) +} diff --git a/components/crm/internal/adapters/mongodb/encryption/registry.mongodb_test.go b/components/crm/internal/adapters/mongodb/encryption/registry.mongodb_test.go new file mode 100644 index 000000000..10d064994 --- /dev/null +++ b/components/crm/internal/adapters/mongodb/encryption/registry.mongodb_test.go @@ -0,0 +1,249 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package encryption + +import ( + "context" + "testing" + + tmcore "github.com/LerianStudio/lib-commons/v5/commons/tenant-manager/core" + "github.com/LerianStudio/midaz/v3/pkg/mmodel" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRegistryMongoDBRepositoryImplementsRepository(t *testing.T) { + t.Parallel() + + repo, err := NewRegistryMongoDBRepository(nil) + + require.NoError(t, err) + require.Implements(t, (*RegistryRepository)(nil), repo) +} + +func TestNewRegistryMongoDBRepository_NilConnection(t *testing.T) { + t.Parallel() + + repo, err := NewRegistryMongoDBRepository(nil) + + require.NoError(t, err) + require.NotNil(t, repo) + assert.Nil(t, repo.connection) +} + +func TestRegistryMongoDBRepository_Save_NilRecord(t *testing.T) { + t.Parallel() + + repo, err := NewRegistryMongoDBRepository(nil) + require.NoError(t, err) + + err = repo.Save(context.Background(), nil) + + require.Error(t, err) + assert.Contains(t, err.Error(), "registry record is required") +} + +func TestRegistryMongoDBRepository_Update_NilRecord(t *testing.T) { + t.Parallel() + + repo, err := NewRegistryMongoDBRepository(nil) + require.NoError(t, err) + + err = repo.Update(context.Background(), nil, 1) + + require.Error(t, err) + assert.Contains(t, err.Error(), "registry record is required") +} + +func TestRegistryMongoDBRepository_getDatabase_NoConnection(t *testing.T) { + t.Parallel() + + repo, err := NewRegistryMongoDBRepository(nil) + require.NoError(t, err) + + // Without tenant context and without connection, should error + db, err := repo.getDatabase(context.Background()) + + require.Error(t, err) + assert.Nil(t, db) + assert.Contains(t, err.Error(), "no database connection available") +} + +func TestRegistryMongoDBRepository_collection_NoConnection(t *testing.T) { + t.Parallel() + + repo, err := NewRegistryMongoDBRepository(nil) + require.NoError(t, err) + + // Without tenant context and without connection, should error + coll, err := repo.collection(context.Background()) + + require.Error(t, err) + assert.Nil(t, coll) + assert.Contains(t, err.Error(), "no database connection available") +} + +func TestRegistryMongoDBRepository_Save_ValidRecord(t *testing.T) { + t.Parallel() + + repo, err := NewRegistryMongoDBRepository(nil) + require.NoError(t, err) + + record, err := mmodel.NewOrganizationRegistryRecord("tenant-a", "org-a", "system", "initial setup") + require.NoError(t, err) + + // Without connection, Save will fail at collection() but validation passes + err = repo.Save(context.Background(), record) + + require.Error(t, err) + // Error should be about connection, not validation + assert.Contains(t, err.Error(), "no database connection available") +} + +func TestRegistryMongoDBRepository_Get_NoConnection(t *testing.T) { + t.Parallel() + + repo, err := NewRegistryMongoDBRepository(nil) + require.NoError(t, err) + + record, err := repo.Get(context.Background(), "org-a") + + require.Error(t, err) + assert.Nil(t, record) + assert.Contains(t, err.Error(), "no database connection available") +} + +func TestRegistryMongoDBRepository_Update_ValidRecord(t *testing.T) { + t.Parallel() + + repo, err := NewRegistryMongoDBRepository(nil) + require.NoError(t, err) + + record, err := mmodel.NewOrganizationRegistryRecord("tenant-a", "org-a", "system", "initial setup") + require.NoError(t, err) + + // Without connection, Update will fail at collection() but validation passes + err = repo.Update(context.Background(), record, 1) + + require.Error(t, err) + // Error should be about connection, not validation + assert.Contains(t, err.Error(), "no database connection available") +} + +func TestRegistryMongoDBRepository_Update_DoesNotMutateInputOnFailure(t *testing.T) { + t.Parallel() + + repo, err := NewRegistryMongoDBRepository(nil) + require.NoError(t, err) + + record, err := mmodel.NewOrganizationRegistryRecord("tenant-a", "org-a", "system", "initial setup") + require.NoError(t, err) + + initialRevision := record.Revision + expectedRevision := int64(5) + + // Update should NOT mutate the input record when the operation fails. + // The revision increment is applied only to the internal model, not the caller's object. + err = repo.Update(context.Background(), record, expectedRevision) + + require.Error(t, err) + // The input record should remain unchanged because the operation failed. + assert.Equal(t, initialRevision, record.Revision, "input record should not be mutated on failure") +} + +func TestRegistryMongoDBRepository_getDatabase_TenantContextTakesPrecedence(t *testing.T) { + t.Parallel() + + tenantDB := newDisconnectedDatabase(t, "tenant-registry-priority") + repo := &RegistryMongoDBRepository{connection: newPlaceholderConnection("static-db")} + ctx := tmcore.ContextWithMB(context.Background(), tenantDB) + + db, err := repo.getDatabase(ctx) + + require.NoError(t, err, "getDatabase should not return error when tenant DB is in context") + require.NotNil(t, db, "returned database must not be nil") + assert.Same(t, tenantDB, db, "tenant DB must take precedence over static connection") + assert.Equal(t, "tenant-registry-priority", db.Name(), "returned DB name should be the tenant DB name") +} + +func TestRegistryMongoDBRepository_getDatabase_FallbackToStaticConnection(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ctx context.Context + wantErr bool + }{ + {name: "plain_background_context_no_tenant", ctx: context.Background(), wantErr: true}, + {name: "context_with_unrelated_values", ctx: context.WithValue(context.Background(), struct{}{}, "unrelated"), wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + repo := &RegistryMongoDBRepository{connection: newPlaceholderConnection("fallback-db")} + + db, err := repo.getDatabase(tt.ctx) + + if tt.wantErr { + require.Error(t, err, "getDatabase should return error when static connection has no live MongoDB") + assert.Nil(t, db, "database should be nil when static connection fails") + } + }) + } +} + +func TestRegistryMongoDBRepository_getDatabase_NilConnectionWithTenantContext(t *testing.T) { + t.Parallel() + + tenantDB := newDisconnectedDatabase(t, "tenant-registry-nil-conn") + repo := &RegistryMongoDBRepository{connection: nil} + ctx := tmcore.ContextWithMB(context.Background(), tenantDB) + + db, err := repo.getDatabase(ctx) + + require.NoError(t, err, "getDatabase should not error when tenant DB is in context") + require.NotNil(t, db, "returned database must not be nil") + assert.Same(t, tenantDB, db, "must return the exact tenant DB from context") + assert.Equal(t, "tenant-registry-nil-conn", db.Name(), "database name should match the tenant DB") +} + +func TestRegistryMongoDBRepository_getDatabase_NilConnectionWithoutTenantContext(t *testing.T) { + t.Parallel() + + repo := &RegistryMongoDBRepository{connection: nil} + + db, err := repo.getDatabase(context.Background()) + + require.Error(t, err, "getDatabase must return error when connection is nil and no tenant context exists") + assert.Nil(t, db, "database must be nil when no connection and no tenant context") + assert.Contains(t, err.Error(), "no database connection available", "error message should indicate no connection is available") +} + +func TestRegistryMongoDBRepository_getDatabase_TwoTenants_ResolveToDifferentDatabases(t *testing.T) { + t.Parallel() + + tenantAcmeDB := newDisconnectedDatabase(t, "tenant-registry-acme") + tenantGlobexDB := newDisconnectedDatabase(t, "tenant-registry-globex") + + repo := &RegistryMongoDBRepository{connection: newPlaceholderConnection("static-db")} + + ctxTenantA := tmcore.ContextWithMB(context.Background(), tenantAcmeDB) + ctxTenantB := tmcore.ContextWithMB(context.Background(), tenantGlobexDB) + + dbA, errA := repo.getDatabase(ctxTenantA) + dbB, errB := repo.getDatabase(ctxTenantB) + + require.NoError(t, errA, "getDatabase should not error for tenant A") + require.NoError(t, errB, "getDatabase should not error for tenant B") + require.NotNil(t, dbA, "tenant A database must not be nil") + require.NotNil(t, dbB, "tenant B database must not be nil") + + assert.NotSame(t, dbA, dbB, "two different tenants must resolve to different *mongo.Database instances") + assert.NotEqual(t, dbA.Name(), dbB.Name(), "two different tenants must have different database names") + assert.Equal(t, "tenant-registry-acme", dbA.Name(), "tenant A database name must match") + assert.Equal(t, "tenant-registry-globex", dbB.Name(), "tenant B database name must match") +} diff --git a/components/crm/internal/adapters/mongodb/encryption/registry_test.go b/components/crm/internal/adapters/mongodb/encryption/registry_test.go new file mode 100644 index 000000000..2be1c5d11 --- /dev/null +++ b/components/crm/internal/adapters/mongodb/encryption/registry_test.go @@ -0,0 +1,180 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package encryption + +import ( + "testing" + "time" + + "github.com/LerianStudio/midaz/v3/pkg/mmodel" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRegistryFromEntity(t *testing.T) { + t.Parallel() + + now := time.Now().UTC() + + entity := &mmodel.OrganizationRegistryRecord{ + TenantID: "tenant-a", + OrganizationID: "org-a", + Status: mmodel.RegistryStatusActive, + ProtectionModel: mmodel.ProtectionModelEnvelope, + CurrentVersion: 1, + ReadableVersions: []int{1}, + Revision: 2, + LegacyReadable: true, + CreatedAt: now, + UpdatedAt: now, + CreatedBy: "system", + UpdatedBy: "admin", + LastTransitionReason: "activated", + } + + model := RegistryFromEntity(entity) + + require.NotNil(t, model) + assert.Equal(t, entity.TenantID, model.TenantID) + assert.Equal(t, entity.OrganizationID, model.OrganizationID) + assert.Equal(t, entity.Status, model.Status) + assert.Equal(t, entity.ProtectionModel, model.ProtectionModel) + assert.Equal(t, entity.CurrentVersion, model.CurrentVersion) + assert.Equal(t, entity.ReadableVersions, model.ReadableVersions) + assert.Equal(t, entity.Revision, model.Revision) + assert.Equal(t, entity.LegacyReadable, model.LegacyReadable) + assert.Equal(t, entity.CreatedAt, model.CreatedAt) + assert.Equal(t, entity.UpdatedAt, model.UpdatedAt) + assert.Equal(t, entity.CreatedBy, model.CreatedBy) + assert.Equal(t, entity.UpdatedBy, model.UpdatedBy) + assert.Equal(t, entity.LastTransitionReason, model.LastTransitionReason) +} + +func TestRegistryFromEntity_NilEntity(t *testing.T) { + t.Parallel() + + model := RegistryFromEntity(nil) + + assert.Nil(t, model) +} + +func TestRegistryFromEntity_CopiesSlice(t *testing.T) { + t.Parallel() + + entity := &mmodel.OrganizationRegistryRecord{ + ReadableVersions: []int{1, 2, 3}, + } + + model := RegistryFromEntity(entity) + + // Modify original slice + entity.ReadableVersions[0] = 999 + + // Model should not be affected + assert.Equal(t, 1, model.ReadableVersions[0]) +} + +func TestRegistryMongoDBModel_ToEntity(t *testing.T) { + t.Parallel() + + now := time.Now().UTC() + + model := &RegistryMongoDBModel{ + TenantID: "tenant-a", + OrganizationID: "org-a", + Status: mmodel.RegistryStatusActive, + ProtectionModel: mmodel.ProtectionModelEnvelope, + CurrentVersion: 1, + ReadableVersions: []int{1}, + Revision: 2, + LegacyReadable: true, + CreatedAt: now, + UpdatedAt: now, + CreatedBy: "system", + UpdatedBy: "admin", + LastTransitionReason: "activated", + } + + entity := model.ToEntity() + + require.NotNil(t, entity) + assert.Equal(t, model.TenantID, entity.TenantID) + assert.Equal(t, model.OrganizationID, entity.OrganizationID) + assert.Equal(t, model.Status, entity.Status) + assert.Equal(t, model.ProtectionModel, entity.ProtectionModel) + assert.Equal(t, model.CurrentVersion, entity.CurrentVersion) + assert.Equal(t, model.ReadableVersions, entity.ReadableVersions) + assert.Equal(t, model.Revision, entity.Revision) + assert.Equal(t, model.LegacyReadable, entity.LegacyReadable) + assert.Equal(t, model.CreatedAt, entity.CreatedAt) + assert.Equal(t, model.UpdatedAt, entity.UpdatedAt) + assert.Equal(t, model.CreatedBy, entity.CreatedBy) + assert.Equal(t, model.UpdatedBy, entity.UpdatedBy) + assert.Equal(t, model.LastTransitionReason, entity.LastTransitionReason) +} + +func TestRegistryMongoDBModel_ToEntity_NilModel(t *testing.T) { + t.Parallel() + + var model *RegistryMongoDBModel + + entity := model.ToEntity() + + assert.Nil(t, entity) +} + +func TestRegistryMongoDBModel_ToEntity_CopiesSlice(t *testing.T) { + t.Parallel() + + model := &RegistryMongoDBModel{ + ReadableVersions: []int{1, 2, 3}, + } + + entity := model.ToEntity() + + // Modify model slice + model.ReadableVersions[0] = 999 + + // Entity should not be affected + assert.Equal(t, 1, entity.ReadableVersions[0]) +} + +func TestRegistryConversion_RoundTrip(t *testing.T) { + t.Parallel() + + now := time.Now().UTC() + + original := &mmodel.OrganizationRegistryRecord{ + TenantID: "tenant-a", + OrganizationID: "org-a", + Status: mmodel.RegistryStatusActive, + ProtectionModel: mmodel.ProtectionModelEnvelope, + CurrentVersion: 1, + ReadableVersions: []int{1, 2}, + Revision: 5, + LegacyReadable: true, + CreatedAt: now, + UpdatedAt: now, + CreatedBy: "system", + UpdatedBy: "admin", + LastTransitionReason: "key rotated", + } + + // Convert to model and back + model := RegistryFromEntity(original) + recovered := model.ToEntity() + + assert.Equal(t, original.TenantID, recovered.TenantID) + assert.Equal(t, original.OrganizationID, recovered.OrganizationID) + assert.Equal(t, original.Status, recovered.Status) + assert.Equal(t, original.ProtectionModel, recovered.ProtectionModel) + assert.Equal(t, original.CurrentVersion, recovered.CurrentVersion) + assert.Equal(t, original.ReadableVersions, recovered.ReadableVersions) + assert.Equal(t, original.Revision, recovered.Revision) + assert.Equal(t, original.LegacyReadable, recovered.LegacyReadable) + assert.Equal(t, original.CreatedBy, recovered.CreatedBy) + assert.Equal(t, original.UpdatedBy, recovered.UpdatedBy) + assert.Equal(t, original.LastTransitionReason, recovered.LastTransitionReason) +} diff --git a/pkg/constant/errors.go b/pkg/constant/errors.go index d89a064b8..6a2d29eff 100644 --- a/pkg/constant/errors.go +++ b/pkg/constant/errors.go @@ -250,3 +250,14 @@ var ( ErrRelatedPartyStartDateRequired = errors.New("CRM-0028") ErrRelatedPartyEndDateInvalid = errors.New("CRM-0029") ) + +// Encryption and keyset management errors. +var ( + ErrKeysetNotFound = errors.New("ENC-0001") + ErrKeysetAlreadyExists = errors.New("ENC-0002") + ErrKeysetRevisionConflict = errors.New("ENC-0003") + ErrRegistryNotFound = errors.New("ENC-0004") + ErrRegistryAlreadyExists = errors.New("ENC-0005") + ErrRegistryRevisionConflict = errors.New("ENC-0006") + ErrRegistryInvalidTransition = errors.New("ENC-0007") +) diff --git a/pkg/mmodel/organization_keyset.go b/pkg/mmodel/organization_keyset.go new file mode 100644 index 000000000..4e27cf09f --- /dev/null +++ b/pkg/mmodel/organization_keyset.go @@ -0,0 +1,95 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package mmodel + +import ( + "fmt" + "time" + + "github.com/LerianStudio/midaz/v3/pkg/constant" +) + +// Re-export errors from constant package for backward compatibility. +// Callers should migrate to using constant.ErrKeyset* directly. +var ( + ErrKeysetNotFound = constant.ErrKeysetNotFound + ErrKeysetAlreadyExists = constant.ErrKeysetAlreadyExists + ErrKeysetRevisionConflict = constant.ErrKeysetRevisionConflict +) + +// OrganizationKeyset stores wrapped keyset metadata for an organization. +// Wrapped keysets are encrypted by a KEK in the KMS provider. +type OrganizationKeyset struct { + TenantID string + OrganizationID string + KEKPath string + WrappedKeyset string + KeysetInfo KeysetInfo + LegacyKeyImported bool + WrappedHMACKeyset string + HMACKeysetInfo KeysetInfo + LegacyHMACKeyImported bool + Revision int64 + CreatedAt time.Time + RotatedAt *time.Time +} + +// KeysetInfo contains metadata about a Tink keyset without exposing key material. +type KeysetInfo struct { + PrimaryKeyID uint32 + Keys []KeyInfo +} + +// KeyInfo describes a single key within a keyset. +type KeyInfo struct { + KeyID uint32 + Status string + Type string + IsPrimary bool +} + +// Validate checks that required fields are present. +func (k *OrganizationKeyset) Validate() error { + if k.OrganizationID == "" { + return fmt.Errorf("organization_id is required") + } + + if k.KEKPath == "" { + return fmt.Errorf("kek_path is required") + } + + if k.WrappedKeyset == "" { + return fmt.Errorf("wrapped_keyset is required") + } + + if k.KeysetInfo.PrimaryKeyID == 0 { + return fmt.Errorf("keyset_info.primary_key_id is required") + } + + // If HMAC keyset is provided, validate its info + if k.WrappedHMACKeyset != "" && k.HMACKeysetInfo.PrimaryKeyID == 0 { + return fmt.Errorf("hmac_keyset_info.primary_key_id is required when wrapped_hmac_keyset is provided") + } + + return nil +} + +// SafeView returns a copy with wrapped keysets redacted for logging/API responses. +func (k *OrganizationKeyset) SafeView() OrganizationKeyset { + return OrganizationKeyset{ + TenantID: k.TenantID, + OrganizationID: k.OrganizationID, + KEKPath: k.KEKPath, + WrappedKeyset: "[REDACTED]", + KeysetInfo: k.KeysetInfo, + LegacyKeyImported: k.LegacyKeyImported, + WrappedHMACKeyset: "[REDACTED]", + HMACKeysetInfo: k.HMACKeysetInfo, + LegacyHMACKeyImported: k.LegacyHMACKeyImported, + Revision: k.Revision, + CreatedAt: k.CreatedAt, + RotatedAt: k.RotatedAt, + } +} diff --git a/pkg/mmodel/organization_keyset_test.go b/pkg/mmodel/organization_keyset_test.go new file mode 100644 index 000000000..d00f2166a --- /dev/null +++ b/pkg/mmodel/organization_keyset_test.go @@ -0,0 +1,149 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package mmodel + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOrganizationKeyset_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + keyset OrganizationKeyset + wantErrMsg string + }{ + { + name: "valid keyset", + keyset: OrganizationKeyset{ + OrganizationID: "org-a", + KEKPath: "transit/keys/test", + WrappedKeyset: "vault:v1:encrypted", + KeysetInfo: KeysetInfo{PrimaryKeyID: 1}, + }, + wantErrMsg: "", + }, + { + name: "empty organization_id", + keyset: OrganizationKeyset{ + OrganizationID: "", + KEKPath: "transit/keys/test", + WrappedKeyset: "vault:v1:encrypted", + KeysetInfo: KeysetInfo{PrimaryKeyID: 1}, + }, + wantErrMsg: "organization_id is required", + }, + { + name: "empty kek_path", + keyset: OrganizationKeyset{ + OrganizationID: "org-a", + KEKPath: "", + WrappedKeyset: "vault:v1:encrypted", + KeysetInfo: KeysetInfo{PrimaryKeyID: 1}, + }, + wantErrMsg: "kek_path is required", + }, + { + name: "empty wrapped_keyset", + keyset: OrganizationKeyset{ + OrganizationID: "org-a", + KEKPath: "transit/keys/test", + WrappedKeyset: "", + KeysetInfo: KeysetInfo{PrimaryKeyID: 1}, + }, + wantErrMsg: "wrapped_keyset is required", + }, + { + name: "zero primary_key_id", + keyset: OrganizationKeyset{ + OrganizationID: "org-a", + KEKPath: "transit/keys/test", + WrappedKeyset: "vault:v1:encrypted", + KeysetInfo: KeysetInfo{PrimaryKeyID: 0}, + }, + wantErrMsg: "keyset_info.primary_key_id is required", + }, + { + name: "hmac_keyset_without_hmac_info", + keyset: OrganizationKeyset{ + OrganizationID: "org-a", + KEKPath: "transit/keys/test", + WrappedKeyset: "vault:v1:encrypted", + KeysetInfo: KeysetInfo{PrimaryKeyID: 1}, + WrappedHMACKeyset: "vault:v1:hmac-encrypted", + HMACKeysetInfo: KeysetInfo{PrimaryKeyID: 0}, + }, + wantErrMsg: "hmac_keyset_info.primary_key_id is required when wrapped_hmac_keyset is provided", + }, + { + name: "valid keyset with hmac", + keyset: OrganizationKeyset{ + OrganizationID: "org-a", + KEKPath: "transit/keys/test", + WrappedKeyset: "vault:v1:encrypted", + KeysetInfo: KeysetInfo{PrimaryKeyID: 1}, + WrappedHMACKeyset: "vault:v1:hmac-encrypted", + HMACKeysetInfo: KeysetInfo{PrimaryKeyID: 2}, + }, + wantErrMsg: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.keyset.Validate() + + if tt.wantErrMsg == "" { + require.NoError(t, err) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErrMsg) + } + }) + } +} + +func TestOrganizationKeyset_SafeView(t *testing.T) { + t.Parallel() + + now := time.Now().UTC() + keyset := OrganizationKeyset{ + TenantID: "tenant-a", + OrganizationID: "org-a", + KEKPath: "transit/keys/test", + WrappedKeyset: "vault:v1:secret-dek-material", + WrappedHMACKeyset: "vault:v1:secret-hmac-material", + KeysetInfo: KeysetInfo{PrimaryKeyID: 1}, + HMACKeysetInfo: KeysetInfo{PrimaryKeyID: 2}, + LegacyKeyImported: true, + LegacyHMACKeyImported: true, + Revision: 5, + CreatedAt: now, + } + + safe := keyset.SafeView() + + // Verify wrapped keysets are redacted + assert.Equal(t, "[REDACTED]", safe.WrappedKeyset) + assert.Equal(t, "[REDACTED]", safe.WrappedHMACKeyset) + + // Verify other fields are preserved + assert.Equal(t, keyset.TenantID, safe.TenantID) + assert.Equal(t, keyset.OrganizationID, safe.OrganizationID) + assert.Equal(t, keyset.KEKPath, safe.KEKPath) + assert.Equal(t, keyset.KeysetInfo, safe.KeysetInfo) + assert.Equal(t, keyset.HMACKeysetInfo, safe.HMACKeysetInfo) + assert.Equal(t, keyset.LegacyKeyImported, safe.LegacyKeyImported) + assert.Equal(t, keyset.LegacyHMACKeyImported, safe.LegacyHMACKeyImported) + assert.Equal(t, keyset.Revision, safe.Revision) + assert.Equal(t, keyset.CreatedAt, safe.CreatedAt) +} diff --git a/pkg/mmodel/organization_registry.go b/pkg/mmodel/organization_registry.go new file mode 100644 index 000000000..c9ffd929a --- /dev/null +++ b/pkg/mmodel/organization_registry.go @@ -0,0 +1,144 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package mmodel + +import ( + "fmt" + "strings" + "time" + + "github.com/LerianStudio/midaz/v3/pkg/constant" +) + +type RegistryStatus string + +type ProtectionModel string + +const ( + RegistryStatusLegacy RegistryStatus = "legacy" + RegistryStatusPendingMigration RegistryStatus = "pending_migration" + RegistryStatusActive RegistryStatus = "active" + RegistryStatusPartiallyMigrated RegistryStatus = "partially_migrated" + RegistryStatusMigrationComplete RegistryStatus = "migration_complete" + RegistryStatusFailed RegistryStatus = "failed" + RegistryStatusBlocked RegistryStatus = "blocked" +) + +const ( + ProtectionModelLegacy ProtectionModel = "legacy" + ProtectionModelEnvelope ProtectionModel = "envelope" +) + +// Re-export errors from constant package for backward compatibility. +// Callers should migrate to using constant.ErrRegistry* directly. +var ( + ErrRegistryRevisionConflict = constant.ErrRegistryRevisionConflict + ErrRegistryNotFound = constant.ErrRegistryNotFound + ErrRegistryAlreadyExists = constant.ErrRegistryAlreadyExists +) + +// OrganizationRegistryRecord tracks the encryption state of an organization. +type OrganizationRegistryRecord struct { + TenantID string + OrganizationID string + Status RegistryStatus + ProtectionModel ProtectionModel + CurrentVersion int + ReadableVersions []int + Revision int64 + LegacyReadable bool + CreatedAt time.Time + UpdatedAt time.Time + CreatedBy string + UpdatedBy string + LastTransitionReason string +} + +// NewOrganizationRegistryRecord creates a new registry record with initial state. +func NewOrganizationRegistryRecord(tenantID, organizationID, actor, reason string) (*OrganizationRegistryRecord, error) { + tenantID = strings.TrimSpace(tenantID) + organizationID = strings.TrimSpace(organizationID) + actor = strings.TrimSpace(actor) + reason = strings.TrimSpace(reason) + + if tenantID == "" { + return nil, fmt.Errorf("tenant_id is required") + } + + if organizationID == "" { + return nil, fmt.Errorf("organization_id is required") + } + + if actor == "" { + return nil, fmt.Errorf("actor is required") + } + + if reason == "" { + return nil, fmt.Errorf("reason is required") + } + + now := time.Now().UTC() + + return &OrganizationRegistryRecord{ + TenantID: tenantID, + OrganizationID: organizationID, + Status: RegistryStatusPendingMigration, + ProtectionModel: ProtectionModelLegacy, + Revision: 1, + LegacyReadable: true, + CreatedAt: now, + UpdatedAt: now, + CreatedBy: actor, + UpdatedBy: actor, + LastTransitionReason: reason, + }, nil +} + +// Activate transitions the organization to envelope encryption mode. +// Only records in pending_migration status can be activated. +func (r *OrganizationRegistryRecord) Activate(expectedRevision int64, actor, reason string) error { + if r.Status != RegistryStatusPendingMigration { + return fmt.Errorf("%w: cannot activate from status %s, only pending_migration can be activated", + constant.ErrRegistryInvalidTransition, r.Status) + } + + if r.Revision != expectedRevision { + return ErrRegistryRevisionConflict + } + + r.Status = RegistryStatusActive + r.ProtectionModel = ProtectionModelEnvelope + r.CurrentVersion = 1 + r.ReadableVersions = []int{1} + r.Revision++ + r.UpdatedAt = time.Now().UTC() + r.UpdatedBy = strings.TrimSpace(actor) + r.LastTransitionReason = strings.TrimSpace(reason) + + return nil +} + +// UsesEnvelopeMode returns true if organization uses envelope encryption. +func (r *OrganizationRegistryRecord) UsesEnvelopeMode() bool { + return r.ProtectionModel == ProtectionModelEnvelope +} + +// CanReadLegacy returns true if legacy-encrypted data can still be read. +func (r *OrganizationRegistryRecord) CanReadLegacy() bool { + return r.LegacyReadable +} + +// CurrentWriteKeysetVersion returns the keyset version for new encryptions. +func (r *OrganizationRegistryRecord) CurrentWriteKeysetVersion() int { + return r.CurrentVersion +} + +// ReadableKeysetVersions returns all keyset versions that can be decrypted. +func (r *OrganizationRegistryRecord) ReadableKeysetVersions() []int { + versions := make([]int, len(r.ReadableVersions)) + copy(versions, r.ReadableVersions) + + return versions +} diff --git a/pkg/mmodel/organization_registry_test.go b/pkg/mmodel/organization_registry_test.go new file mode 100644 index 000000000..577421954 --- /dev/null +++ b/pkg/mmodel/organization_registry_test.go @@ -0,0 +1,206 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package mmodel + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewOrganizationRegistryRecord(t *testing.T) { + t.Parallel() + + record, err := NewOrganizationRegistryRecord("tenant-a", "org-a", "system", "initial setup") + + require.NoError(t, err) + require.NotNil(t, record) + assert.Equal(t, "tenant-a", record.TenantID) + assert.Equal(t, "org-a", record.OrganizationID) + assert.Equal(t, RegistryStatusPendingMigration, record.Status) + assert.Equal(t, ProtectionModelLegacy, record.ProtectionModel) + assert.Equal(t, int64(1), record.Revision) + assert.True(t, record.LegacyReadable) + assert.Equal(t, "system", record.CreatedBy) + assert.Equal(t, "system", record.UpdatedBy) + assert.Equal(t, "initial setup", record.LastTransitionReason) +} + +func TestNewOrganizationRegistryRecord_Validation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tenantID string + organizationID string + actor string + reason string + wantErrMsg string + }{ + { + name: "empty tenant", + tenantID: "", + organizationID: "org-a", + actor: "system", + reason: "test", + wantErrMsg: "tenant_id is required", + }, + { + name: "empty organization", + tenantID: "tenant-a", + organizationID: "", + actor: "system", + reason: "test", + wantErrMsg: "organization_id is required", + }, + { + name: "empty actor", + tenantID: "tenant-a", + organizationID: "org-a", + actor: "", + reason: "test", + wantErrMsg: "actor is required", + }, + { + name: "empty reason", + tenantID: "tenant-a", + organizationID: "org-a", + actor: "system", + reason: "", + wantErrMsg: "reason is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + record, err := NewOrganizationRegistryRecord(tt.tenantID, tt.organizationID, tt.actor, tt.reason) + + require.Error(t, err) + assert.Nil(t, record) + assert.Contains(t, err.Error(), tt.wantErrMsg) + }) + } +} + +func TestOrganizationRegistryRecord_Activate(t *testing.T) { + t.Parallel() + + record, err := NewOrganizationRegistryRecord("tenant-a", "org-a", "system", "initial setup") + require.NoError(t, err) + + err = record.Activate(1, "admin", "keyset provisioned") + + require.NoError(t, err) + assert.Equal(t, RegistryStatusActive, record.Status) + assert.Equal(t, ProtectionModelEnvelope, record.ProtectionModel) + assert.Equal(t, 1, record.CurrentVersion) + assert.Equal(t, []int{1}, record.ReadableVersions) + assert.Equal(t, int64(2), record.Revision) + assert.Equal(t, "admin", record.UpdatedBy) + assert.Equal(t, "keyset provisioned", record.LastTransitionReason) +} + +func TestOrganizationRegistryRecord_ActivateRevisionConflict(t *testing.T) { + t.Parallel() + + record, err := NewOrganizationRegistryRecord("tenant-a", "org-a", "system", "initial setup") + require.NoError(t, err) + + // Try to activate with wrong expected revision + err = record.Activate(99, "admin", "keyset provisioned") + + require.Error(t, err) + assert.ErrorIs(t, err, ErrRegistryRevisionConflict) + // Status should not change + assert.Equal(t, RegistryStatusPendingMigration, record.Status) +} + +func TestOrganizationRegistryRecord_ActivateInvalidTransition(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + status RegistryStatus + }{ + {name: "from active", status: RegistryStatusActive}, + {name: "from legacy", status: RegistryStatusLegacy}, + {name: "from partially_migrated", status: RegistryStatusPartiallyMigrated}, + {name: "from migration_complete", status: RegistryStatusMigrationComplete}, + {name: "from failed", status: RegistryStatusFailed}, + {name: "from blocked", status: RegistryStatusBlocked}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + record := &OrganizationRegistryRecord{ + TenantID: "tenant-a", + OrganizationID: "org-a", + Status: tt.status, + Revision: 1, + } + + err := record.Activate(1, "admin", "keyset provisioned") + + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot activate from status") + assert.Contains(t, err.Error(), string(tt.status)) + // Status should not change + assert.Equal(t, tt.status, record.Status) + }) + } +} + +func TestOrganizationRegistryRecord_UsesEnvelopeMode(t *testing.T) { + t.Parallel() + + record, err := NewOrganizationRegistryRecord("tenant-a", "org-a", "system", "initial setup") + require.NoError(t, err) + + // Initially uses legacy mode + assert.False(t, record.UsesEnvelopeMode()) + + // After activation uses envelope mode + err = record.Activate(1, "admin", "activated") + require.NoError(t, err) + assert.True(t, record.UsesEnvelopeMode()) +} + +func TestOrganizationRegistryRecord_CanReadLegacy(t *testing.T) { + t.Parallel() + + record, err := NewOrganizationRegistryRecord("tenant-a", "org-a", "system", "initial setup") + require.NoError(t, err) + + // Initially can read legacy + assert.True(t, record.CanReadLegacy()) + + // After activation can still read legacy (during migration) + err = record.Activate(1, "admin", "activated") + require.NoError(t, err) + assert.True(t, record.CanReadLegacy()) +} + +func TestOrganizationRegistryRecord_ReadableKeysetVersions(t *testing.T) { + t.Parallel() + + record, err := NewOrganizationRegistryRecord("tenant-a", "org-a", "system", "initial setup") + require.NoError(t, err) + + err = record.Activate(1, "admin", "activated") + require.NoError(t, err) + + versions := record.ReadableKeysetVersions() + + assert.Equal(t, []int{1}, versions) + + // Verify returned slice is a copy + versions[0] = 999 + assert.Equal(t, []int{1}, record.ReadableVersions) +}