Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) 2026 Lerian Studio. All rights reserved.
// Use of this source code is governed by the Elastic License 2.0
// that can be found in the LICENSE file.

package encryption

import "sync"

// indexState tracks whether indexes have been successfully created for a specific database/collection pair.
type indexState struct {
mu sync.Mutex
done bool
}

// indexTracker manages per-database index creation state.
// In multi-tenant mode, each tenant database needs its own indexes.
// This tracker ensures indexes are created exactly once per database, with retry on failure.
type indexTracker struct {
states sync.Map // key: "dbName:collection" -> *indexState
}

// ensureOnce executes fn exactly once per key, but only marks as done on success.
// If fn returns an error, subsequent calls will retry.
func (t *indexTracker) ensureOnce(key string, fn func() error) error {
v, _ := t.states.LoadOrStore(key, &indexState{})
state := v.(*indexState)

state.mu.Lock()
defer state.mu.Unlock()

if state.done {
return nil
}

if err := fn(); err != nil {
return err
}

state.done = true

return nil
}

// reset clears the state for a specific key. Used in integration tests to ensure fresh state
// when each test runs with a new MongoDB container.
//
//nolint:unused // used in *_integration_test.go files (build tag: integration)
func (t *indexTracker) reset(key string) {
t.states.Delete(key)
}

// globalIndexTracker is shared across all encryption repository instances.
// This ensures indexes are created once per database even if multiple repository instances exist.
var globalIndexTracker = &indexTracker{}
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
// Copyright (c) 2026 Lerian Studio. All rights reserved.
// Use of this source code is governed by the Elastic License 2.0
// that can be found in the LICENSE file.

package encryption

import (
"errors"
"sync"
"sync/atomic"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestIndexTracker_EnsureOnce_ExecutesOnceOnSuccess(t *testing.T) {
t.Parallel()

tracker := &indexTracker{}
var callCount int32

for i := 0; i < 3; i++ {
err := tracker.ensureOnce("test-db:test-collection", func() error {
atomic.AddInt32(&callCount, 1)
return nil
})
require.NoError(t, err)
}

assert.Equal(t, int32(1), atomic.LoadInt32(&callCount), "function should be called exactly once on success")
}

func TestIndexTracker_EnsureOnce_RetriesOnFailure(t *testing.T) {
t.Parallel()

tracker := &indexTracker{}
var callCount int32
expectedErr := errors.New("index creation failed")

// First call fails
err := tracker.ensureOnce("test-db:test-collection", func() error {
atomic.AddInt32(&callCount, 1)
return expectedErr
})
require.Error(t, err)
assert.Equal(t, expectedErr, err)
assert.Equal(t, int32(1), atomic.LoadInt32(&callCount))

// Second call should retry (not skipped)
err = tracker.ensureOnce("test-db:test-collection", func() error {
atomic.AddInt32(&callCount, 1)
return nil // Now succeeds
})
require.NoError(t, err)
assert.Equal(t, int32(2), atomic.LoadInt32(&callCount), "function should be retried after failure")

// Third call should be skipped (already succeeded)
err = tracker.ensureOnce("test-db:test-collection", func() error {
atomic.AddInt32(&callCount, 1)
return nil
})
require.NoError(t, err)
assert.Equal(t, int32(2), atomic.LoadInt32(&callCount), "function should not be called again after success")
}

func TestIndexTracker_EnsureOnce_DifferentKeys(t *testing.T) {
t.Parallel()

tracker := &indexTracker{}
var callCountA, callCountB int32

// Key A
err := tracker.ensureOnce("tenant-a:collection", func() error {
atomic.AddInt32(&callCountA, 1)
return nil
})
require.NoError(t, err)

// Key B (different tenant)
err = tracker.ensureOnce("tenant-b:collection", func() error {
atomic.AddInt32(&callCountB, 1)
return nil
})
require.NoError(t, err)

// Key A again (should be skipped)
err = tracker.ensureOnce("tenant-a:collection", func() error {
atomic.AddInt32(&callCountA, 1)
return nil
})
require.NoError(t, err)

assert.Equal(t, int32(1), atomic.LoadInt32(&callCountA), "tenant A should be called once")
assert.Equal(t, int32(1), atomic.LoadInt32(&callCountB), "tenant B should be called once")
}

func TestIndexTracker_EnsureOnce_ConcurrentAccess(t *testing.T) {
t.Parallel()

tracker := &indexTracker{}
var callCount int32
const goroutines = 100

var wg sync.WaitGroup
wg.Add(goroutines)

for i := 0; i < goroutines; i++ {
go func() {
defer wg.Done()

_ = tracker.ensureOnce("concurrent-db:collection", func() error {
atomic.AddInt32(&callCount, 1)
return nil
})
}()
}

wg.Wait()

assert.Equal(t, int32(1), atomic.LoadInt32(&callCount), "function should be called exactly once even with concurrent access")
}

func TestIndexTracker_EnsureOnce_ConcurrentAccessWithFailure(t *testing.T) {
t.Parallel()

tracker := &indexTracker{}
var callCount int32
var failureCount int32
const goroutines = 50

// First, make the first call fail
err := tracker.ensureOnce("concurrent-fail-db:collection", func() error {
atomic.AddInt32(&callCount, 1)
return errors.New("initial failure")
})
require.Error(t, err)
assert.Equal(t, int32(1), atomic.LoadInt32(&callCount))

// Now run concurrent retries - only one should succeed in executing
var wg sync.WaitGroup
wg.Add(goroutines)

for i := 0; i < goroutines; i++ {
go func() {
defer wg.Done()

err := tracker.ensureOnce("concurrent-fail-db:collection", func() error {
atomic.AddInt32(&callCount, 1)
return nil
})
if err != nil {
atomic.AddInt32(&failureCount, 1)
}
}()
}

wg.Wait()

// Due to mutex, only one goroutine should have actually executed
assert.Equal(t, int32(2), atomic.LoadInt32(&callCount), "function should be retried exactly once after initial failure")
assert.Equal(t, int32(0), atomic.LoadInt32(&failureCount), "all concurrent calls should succeed once retry succeeds")
}

func TestIndexTracker_Reset(t *testing.T) {
t.Parallel()

tracker := &indexTracker{}
var callCount int32

// First call
err := tracker.ensureOnce("reset-test:collection", func() error {
atomic.AddInt32(&callCount, 1)
return nil
})
require.NoError(t, err)
assert.Equal(t, int32(1), atomic.LoadInt32(&callCount))

// Reset the key
tracker.reset("reset-test:collection")

// Call again - should execute
err = tracker.ensureOnce("reset-test:collection", func() error {
atomic.AddInt32(&callCount, 1)
return nil
})
require.NoError(t, err)
assert.Equal(t, int32(2), atomic.LoadInt32(&callCount), "function should be called again after reset")
}

func TestIndexTracker_MultipleFailuresBeforeSuccess(t *testing.T) {
t.Parallel()

tracker := &indexTracker{}
var callCount int32

// Simulate multiple failures before success
for i := 0; i < 5; i++ {
_ = tracker.ensureOnce("multi-fail:collection", func() error {
atomic.AddInt32(&callCount, 1)
return errors.New("temporary failure")
})
}
assert.Equal(t, int32(5), atomic.LoadInt32(&callCount), "function should be called on each retry")

// Now succeed
err := tracker.ensureOnce("multi-fail:collection", func() error {
atomic.AddInt32(&callCount, 1)
return nil
})
require.NoError(t, err)
assert.Equal(t, int32(6), atomic.LoadInt32(&callCount))

// Subsequent calls should be skipped
err = tracker.ensureOnce("multi-fail:collection", func() error {
atomic.AddInt32(&callCount, 1)
return nil
})
require.NoError(t, err)
assert.Equal(t, int32(6), atomic.LoadInt32(&callCount), "function should not be called after success")
}
119 changes: 119 additions & 0 deletions components/crm/internal/adapters/mongodb/encryption/keyset.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright (c) 2026 Lerian Studio. All rights reserved.
// Use of this source code is governed by the Elastic License 2.0
// that can be found in the LICENSE file.

package encryption

import (
"time"

"github.com/LerianStudio/midaz/v3/pkg/mmodel"
)

// KeysetMongoDBModel is the MongoDB representation of OrganizationKeyset.
type KeysetMongoDBModel struct {
TenantID string `bson:"tenant_id,omitempty"`
OrganizationID string `bson:"organization_id"`
KEKPath string `bson:"kek_path"`
WrappedKeyset string `bson:"wrapped_keyset"`
KeysetInfo KeysetInfoModel `bson:"keyset_info"`
LegacyKeyImported bool `bson:"legacy_key_imported"`
WrappedHMACKeyset string `bson:"wrapped_hmac_keyset,omitempty"`
HMACKeysetInfo KeysetInfoModel `bson:"hmac_keyset_info,omitempty"`
LegacyHMACKeyImported bool `bson:"legacy_hmac_key_imported"`
Revision int64 `bson:"revision"`
CreatedAt time.Time `bson:"created_at"`
RotatedAt *time.Time `bson:"rotated_at,omitempty"`
}

// KeysetInfoModel is the MongoDB representation of KeysetInfo.
type KeysetInfoModel struct {
PrimaryKeyID uint32 `bson:"primary_key_id"`
Keys []KeyInfoModel `bson:"keys"`
}

// KeyInfoModel is the MongoDB representation of KeyInfo.
type KeyInfoModel struct {
KeyID uint32 `bson:"key_id"`
Status string `bson:"status"`
Type string `bson:"type"`
IsPrimary bool `bson:"is_primary"`
}

// KeysetFromEntity converts a domain OrganizationKeyset to MongoDB model.
func KeysetFromEntity(k *mmodel.OrganizationKeyset) *KeysetMongoDBModel {
if k == nil {
return nil
}

return &KeysetMongoDBModel{
TenantID: k.TenantID,
OrganizationID: k.OrganizationID,
KEKPath: k.KEKPath,
WrappedKeyset: k.WrappedKeyset,
KeysetInfo: keysetInfoFromEntity(k.KeysetInfo),
LegacyKeyImported: k.LegacyKeyImported,
WrappedHMACKeyset: k.WrappedHMACKeyset,
HMACKeysetInfo: keysetInfoFromEntity(k.HMACKeysetInfo),
LegacyHMACKeyImported: k.LegacyHMACKeyImported,
Revision: k.Revision,
CreatedAt: k.CreatedAt,
RotatedAt: k.RotatedAt,
}
}

// ToEntity converts the MongoDB model to a domain OrganizationKeyset.
func (m *KeysetMongoDBModel) ToEntity() *mmodel.OrganizationKeyset {
if m == nil {
return nil
}

return &mmodel.OrganizationKeyset{
TenantID: m.TenantID,
OrganizationID: m.OrganizationID,
KEKPath: m.KEKPath,
WrappedKeyset: m.WrappedKeyset,
KeysetInfo: m.KeysetInfo.toEntity(),
LegacyKeyImported: m.LegacyKeyImported,
WrappedHMACKeyset: m.WrappedHMACKeyset,
HMACKeysetInfo: m.HMACKeysetInfo.toEntity(),
LegacyHMACKeyImported: m.LegacyHMACKeyImported,
Revision: m.Revision,
CreatedAt: m.CreatedAt,
RotatedAt: m.RotatedAt,
}
}

func keysetInfoFromEntity(info mmodel.KeysetInfo) KeysetInfoModel {
keys := make([]KeyInfoModel, len(info.Keys))
for i, k := range info.Keys {
keys[i] = KeyInfoModel{
KeyID: k.KeyID,
Status: k.Status,
Type: k.Type,
IsPrimary: k.IsPrimary,
}
}

return KeysetInfoModel{
PrimaryKeyID: info.PrimaryKeyID,
Keys: keys,
}
}

func (m *KeysetInfoModel) toEntity() mmodel.KeysetInfo {
keys := make([]mmodel.KeyInfo, len(m.Keys))
for i, k := range m.Keys {
keys[i] = mmodel.KeyInfo{
KeyID: k.KeyID,
Status: k.Status,
Type: k.Type,
IsPrimary: k.IsPrimary,
}
}

return mmodel.KeysetInfo{
PrimaryKeyID: m.PrimaryKeyID,
Keys: keys,
}
}
Loading