diff --git a/internal/fs/inode/file.go b/internal/fs/inode/file.go index 93571df59f..54015e33d7 100644 --- a/internal/fs/inode/file.go +++ b/internal/fs/inode/file.go @@ -517,6 +517,9 @@ func (f *FileInode) Destroy() (err error) { } else if f.content != nil { f.content.Destroy() } + if f.mrdInstance != nil { + f.mrdInstance.Destroy() + } return } diff --git a/internal/fs/inode/file_test.go b/internal/fs/inode/file_test.go index 0c04600de6..68a2b3820f 100644 --- a/internal/fs/inode/file_test.go +++ b/internal/fs/inode/file_test.go @@ -33,6 +33,7 @@ import ( "github.com/googlecloudplatform/gcsfuse/v3/internal/storage/gcs" "github.com/googlecloudplatform/gcsfuse/v3/internal/storage/storageutil" "github.com/googlecloudplatform/gcsfuse/v3/internal/util" + "github.com/googlecloudplatform/gcsfuse/v3/metrics" "github.com/googlecloudplatform/gcsfuse/v3/tools/integration_tests/util/setup" "github.com/jacobsa/fuse/fuseops" "github.com/jacobsa/syncutil" @@ -545,6 +546,25 @@ func (t *FileTest) TestTruncateNegative() { assert.False(t.T(), gcsSynced) } +func (t *FileTest) TestDestroy_MrdInstanceDestroyed() { + // Manually initialize MRD pool since FileInode.Read doesn't use it directly. + mi := t.in.GetMRDInstance() + require.NotNil(t.T(), mi) + // Perform a read on MrdInstance to trigger pool creation. + buf := make([]byte, 1) + _, err := mi.Read(t.ctx, buf, 0, metrics.NewNoopMetrics()) + require.NoError(t.T(), err) + // Verify pool is initialized. + assert.Greater(t.T(), int(mi.Size()), 0) + + // Destroy the inode. + err = t.in.Destroy() + + require.NoError(t.T(), err) + // Verify MRD instance is destroyed (pool closed and set to nil). + assert.Equal(t.T(), uint64(0), mi.Size()) +} + func (t *FileTest) TestWriteThenSync() { testcases := []struct { name string diff --git a/internal/gcsx/mrd_instance.go b/internal/gcsx/mrd_instance.go index 68309eb21e..5c71bea15a 100644 --- a/internal/gcsx/mrd_instance.go +++ b/internal/gcsx/mrd_instance.go @@ -204,17 +204,39 @@ func (mi *MrdInstance) RecreateMRD() error { return nil } -// Destroy closes all MRD instances in the pool and releases associated resources. -func (mi *MrdInstance) Destroy() { +// closePool closes all MRD instances in the pool and releases associated resources. +func (mi *MrdInstance) closePool() { mi.poolMu.Lock() defer mi.poolMu.Unlock() if mi.mrdPool != nil { - // Delete the instance. + // Close the pool. mi.mrdPool.Close() mi.mrdPool = nil } } +// Destroy completely destroys the MrdInstance, cleaning up +// its resources and ensuring it is removed from the cache. This should be +// called when the owning inode is destroyed. +func (mi *MrdInstance) Destroy() { + mi.refCountMu.Lock() + defer mi.refCountMu.Unlock() + + // If it's in use, this indicates a potential lifecycle mismatch between the + // inode and its readers. + if mi.refCount > 0 { + logger.Warnf("MrdInstance::Destroy called on an instance with refCount %d", mi.refCount) + } + + // Remove from cache. + if mi.mrdCache != nil { + mi.mrdCache.Erase(getKey(mi.inodeId)) + } + + // Close the pool. + mi.closePool() +} + // getKey generates a unique key for the MrdInstance based on its inode ID. func getKey(id fuseops.InodeID) string { return strconv.FormatUint(uint64(id), 10) @@ -237,26 +259,24 @@ func (mi *MrdInstance) IncrementRefCount() { } } -// destroyEvictedCacheEntries is a helper function to destroy evicted MrdInstance objects. -// It handles type assertion and ensures that only truly inactive instances are destroyed. -// This function should not be called when refCountMu is held. -func destroyEvictedCacheEntries(evictedValues []lru.ValueType) { - for _, instance := range evictedValues { - mrdInstance, ok := instance.(*MrdInstance) - if !ok { - logger.Errorf("destroyEvictedCacheEntries: Invalid value type, expected *MrdInstance, got %T", instance) - } else { - // Check if the instance was resurrected. - mrdInstance.refCountMu.Lock() - if mrdInstance.refCount > 0 { - mrdInstance.refCountMu.Unlock() - continue - } - // Safe to destroy. Hold refCountMu to prevent concurrent resurrection. - mrdInstance.Destroy() - mrdInstance.refCountMu.Unlock() - } +// handleEviction handles the cleanup of the MrdInstance when it is evicted from the cache. +// Race protection: MrdInstance could be reopened (refCount>0) or re-added to cache before eviction. +func (mi *MrdInstance) handleEviction() { + mi.refCountMu.Lock() + defer mi.refCountMu.Unlock() + + // Check if mrdInstance was reopened (refCount>0) - must skip eviction. + if mi.refCount > 0 { + return + } + + // Check if mrdInstance was re-added to cache (refCount went 0→1→0 in between eviction and closure.) + // Lock order: refCountMu -> cache.mu (consistent with Increment/DecrementRefCount) + if mi.mrdCache != nil && mi.mrdCache.LookUpWithoutChangingOrder(getKey(mi.inodeId)) == mi { + return } + + mi.closePool() } // DecrementRefCount decreases the reference count. When the count drops to zero, the @@ -285,8 +305,8 @@ func (mi *MrdInstance) DecrementRefCount() { if err != nil { logger.Errorf("MrdInstance::DecrementRefCount: Failed to insert MrdInstance for object (%s) into cache, destroying immediately: %v", mi.object.Name, err) // The instance could not be inserted into the cache. Since the refCount is 0, - // we must destroy it now to prevent it from being leaked. - mi.Destroy() + // we must close it now to prevent it from being leaked. + mi.closePool() return } logger.Tracef("MrdInstance::DecrementRefCount: MrdInstance for object (%s) added to cache", mi.object.Name) @@ -298,7 +318,14 @@ func (mi *MrdInstance) DecrementRefCount() { // Evict outside all locks. mi.refCountMu.Unlock() - destroyEvictedCacheEntries(evictedValues) + for _, instance := range evictedValues { + mrdInstance, ok := instance.(*MrdInstance) + if !ok { + logger.Errorf("MrdInstance::DecrementRefCount: Invalid value type, expected *MrdInstance, got %T", instance) + } else { + mrdInstance.handleEviction() + } + } // Reacquire the lock ensuring safe defer's Unlock. mi.refCountMu.Lock() } diff --git a/internal/gcsx/mrd_instance_test.go b/internal/gcsx/mrd_instance_test.go index 780d519504..06d31a9d4d 100644 --- a/internal/gcsx/mrd_instance_test.go +++ b/internal/gcsx/mrd_instance_test.go @@ -15,14 +15,18 @@ package gcsx import ( + "bytes" "context" "fmt" + "os" "strconv" "testing" "time" "github.com/googlecloudplatform/gcsfuse/v3/cfg" "github.com/googlecloudplatform/gcsfuse/v3/internal/cache/lru" + "github.com/googlecloudplatform/gcsfuse/v3/internal/logger" + "github.com/googlecloudplatform/gcsfuse/v3/internal/storage" "github.com/googlecloudplatform/gcsfuse/v3/internal/storage/fake" "github.com/googlecloudplatform/gcsfuse/v3/internal/storage/gcs" @@ -327,38 +331,6 @@ func (t *MrdInstanceTest) TestDecrementRefCount_Eviction() { assert.NotNil(t.T(), t.cache.LookUpWithoutChangingOrder("other2")) } -func (t *MrdInstanceTest) TestDestroyEvictedCacheEntries() { - // 1. Instance to be destroyed - mi1 := NewMrdInstance(t.object, t.bucket, t.cache, 1, t.config) - fakeMRD1 := fake.NewFakeMultiRangeDownloader(t.object, nil) - t.bucket.On("NewMultiRangeDownloader", mock.Anything, mock.Anything).Return(fakeMRD1, nil).Once() - buf := make([]byte, 1) - _, err := mi1.Read(context.Background(), buf, 0, metrics.NewNoopMetrics()) - assert.NoError(t.T(), err) - assert.NotNil(t.T(), mi1.mrdPool) - // 2. Instance that is resurrected (refCount > 0) - mi2 := NewMrdInstance(t.object, t.bucket, t.cache, 2, t.config) - fakeMRD2 := fake.NewFakeMultiRangeDownloader(t.object, nil) - t.bucket.On("NewMultiRangeDownloader", mock.Anything, mock.Anything).Return(fakeMRD2, nil).Once() - _, err = mi2.Read(context.Background(), buf, 0, metrics.NewNoopMetrics()) - assert.NoError(t.T(), err) - assert.NotNil(t.T(), mi2.mrdPool) - mi2.refCount = 1 - // Entries to be evicted - evicted := []lru.ValueType{mi1, mi2} - - destroyEvictedCacheEntries(evicted) - - // Verify mi1 is destroyed - mi1.poolMu.RLock() - assert.Nil(t.T(), mi1.mrdPool) - mi1.poolMu.RUnlock() - // Verify mi2 is NOT destroyed - mi2.poolMu.RLock() - assert.NotNil(t.T(), mi2.mrdPool) - mi2.poolMu.RUnlock() -} - func (t *MrdInstanceTest) TestGetKey() { testCases := []struct { inodeID fuseops.InodeID @@ -410,3 +382,163 @@ func (t *MrdInstanceTest) TestEnsureMRDPool_Failure() { assert.Nil(t.T(), t.mrdInstance.mrdPool) assert.Contains(t.T(), err.Error(), "init error") } + +func (t *MrdInstanceTest) TestSize() { + assert.Equal(t.T(), uint64(0), t.mrdInstance.Size()) + fakeMRD := fake.NewFakeMultiRangeDownloader(t.object, nil) + t.bucket.On("NewMultiRangeDownloader", mock.Anything, mock.Anything).Return(fakeMRD, nil).Once() + buf := make([]byte, 1) + _, err := t.mrdInstance.Read(context.Background(), buf, 0, metrics.NewNoopMetrics()) + assert.NoError(t.T(), err) + + poolSize := t.mrdInstance.Size() + + // Pool size is 1 based on SetupTest config (PoolSize: 1) + assert.Equal(t.T(), uint64(1), poolSize) + t.mrdInstance.Destroy() + assert.Equal(t.T(), uint64(0), t.mrdInstance.Size()) +} + +func (t *MrdInstanceTest) TestDestroy_RemovesFromCache() { + // Manually insert into cache + key := getKey(t.inodeID) + _, err := t.cache.Insert(key, t.mrdInstance) + assert.NoError(t.T(), err) + assert.NotNil(t.T(), t.cache.LookUpWithoutChangingOrder(key)) + + t.mrdInstance.Destroy() + + assert.Nil(t.T(), t.cache.LookUpWithoutChangingOrder(key)) +} + +func (t *MrdInstanceTest) TestDestroy_WithRefCount() { + t.mrdInstance.refCount = 1 + // Should log warning but proceed to destroy + fakeMRD := fake.NewFakeMultiRangeDownloader(t.object, nil) + t.bucket.On("NewMultiRangeDownloader", mock.Anything, mock.Anything).Return(fakeMRD, nil).Once() + err := t.mrdInstance.ensureMRDPool() + assert.NoError(t.T(), err) + assert.NotNil(t.T(), t.mrdInstance.mrdPool) + // Capture logs to verify error message + var buf bytes.Buffer + logger.SetOutput(&buf) + defer logger.SetOutput(os.Stdout) + + t.mrdInstance.Destroy() + + assert.Nil(t.T(), t.mrdInstance.mrdPool) + assert.Contains(t.T(), buf.String(), "MrdInstance::Destroy called on an instance with refCount 1") +} + +func (t *MrdInstanceTest) TestDecrementRefCount_Negative() { + t.mrdInstance.refCount = 0 + // Capture logs to verify error message + var buf bytes.Buffer + logger.SetOutput(&buf) + defer logger.SetOutput(os.Stdout) + + // Should log error and return, not panic. RefCount should remain 0. + t.mrdInstance.DecrementRefCount() + + assert.Equal(t.T(), int64(0), t.mrdInstance.refCount) + assert.Contains(t.T(), buf.String(), "MrdInstance::DecrementRefCount: Refcount cannot be negative") +} + +func (t *MrdInstanceTest) TestDecrementRefCount_CacheInsertFailure() { + // Create a cache with capacity 1 + smallCache := lru.NewCache(1) + // Create instance with pool size 2 (so Size() returns 2). + config := &cfg.Config{Mrd: cfg.MrdConfig{PoolSize: 2}} + mi := NewMrdInstance(t.object, t.bucket, smallCache, t.inodeID, config) + fakeMRD := fake.NewFakeMultiRangeDownloader(t.object, nil) + t.bucket.On("NewMultiRangeDownloader", mock.Anything, mock.Anything).Return(fakeMRD, nil) + // Initialize pool. + err := mi.ensureMRDPool() + assert.NoError(t.T(), err) + mi.refCount = 1 + // Capture logs to verify error message + var buf bytes.Buffer + logger.SetOutput(&buf) + defer logger.SetOutput(os.Stdout) + + // This should fail to insert into cache (Size 2 > Cap 1) and should close the pool instantly. + mi.DecrementRefCount() + + assert.Equal(t.T(), int64(0), mi.refCount) + mi.poolMu.RLock() + assert.Nil(t.T(), mi.mrdPool) + mi.poolMu.RUnlock() + assert.Contains(t.T(), buf.String(), "Failed to insert MrdInstance") +} + +func (t *MrdInstanceTest) TestClosePool() { + fakeMRD := fake.NewFakeMultiRangeDownloader(t.object, nil) + t.bucket.On("NewMultiRangeDownloader", mock.Anything, mock.Anything).Return(fakeMRD, nil).Once() + err := t.mrdInstance.ensureMRDPool() + assert.NoError(t.T(), err) + assert.NotNil(t.T(), t.mrdInstance.mrdPool) + + t.mrdInstance.closePool() + + t.mrdInstance.poolMu.RLock() + assert.Nil(t.T(), t.mrdInstance.mrdPool) + t.mrdInstance.poolMu.RUnlock() +} + +func (t *MrdInstanceTest) TestHandleEviction_Resurrected() { + fakeMRD := fake.NewFakeMultiRangeDownloader(t.object, nil) + t.bucket.On("NewMultiRangeDownloader", mock.Anything, mock.Anything).Return(fakeMRD, nil).Once() + // Initialize pool. + err := t.mrdInstance.ensureMRDPool() + assert.NoError(t.T(), err) + // Simulate resurrection (refCount > 0). + t.mrdInstance.refCount = 1 + + t.mrdInstance.handleEviction() + + // Pool should still exist because refCount > 0. + t.mrdInstance.poolMu.RLock() + assert.NotNil(t.T(), t.mrdInstance.mrdPool) + t.mrdInstance.poolMu.RUnlock() +} + +func (t *MrdInstanceTest) TestHandleEviction_ReAddedToCache() { + fakeMRD := fake.NewFakeMultiRangeDownloader(t.object, nil) + t.bucket.On("NewMultiRangeDownloader", mock.Anything, mock.Anything).Return(fakeMRD, nil).Once() + // Initialize pool. + err := t.mrdInstance.ensureMRDPool() + assert.NoError(t.T(), err) + // Add to cache to simulate it being re-added concurrently. + key := getKey(t.inodeID) + _, err = t.cache.Insert(key, t.mrdInstance) + assert.NoError(t.T(), err) + // refCount is 0, but it is in the cache. + t.mrdInstance.refCount = 0 + + t.mrdInstance.handleEviction() + + // Pool should still exist because it's in the cache. + t.mrdInstance.poolMu.RLock() + assert.NotNil(t.T(), t.mrdInstance.mrdPool) + t.mrdInstance.poolMu.RUnlock() +} + +func (t *MrdInstanceTest) TestHandleEviction_SafeToClose() { + fakeMRD := fake.NewFakeMultiRangeDownloader(t.object, nil) + t.bucket.On("NewMultiRangeDownloader", mock.Anything, mock.Anything).Return(fakeMRD, nil).Once() + // Initialize pool. + err := t.mrdInstance.ensureMRDPool() + assert.NoError(t.T(), err) + // Ensure not in cache. + key := getKey(t.inodeID) + t.cache.Erase(key) + // refCount is 0. + t.mrdInstance.refCount = 0 + + t.mrdInstance.handleEviction() + + // Pool should be closed (nil). + t.mrdInstance.poolMu.RLock() + assert.Nil(t.T(), t.mrdInstance.mrdPool) + t.mrdInstance.poolMu.RUnlock() +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go index ed5a1e810c..c70cb4037e 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -229,6 +229,13 @@ func Fatal(format string, v ...any) { os.Exit(1) } +// SetOutput sets the output destination for the default logger. +// This is primarily used for testing. +func SetOutput(w io.Writer) { + defaultLogger = slog.New(defaultLoggerFactory.createJsonOrTextHandler(w, programLevel, "")) + slog.SetDefault(defaultLogger) +} + type loggerFactory struct { // If nil, log to stdout or stderr. Otherwise, log to this file. file *os.File