diff --git a/go.mod b/go.mod index e302eeefd..1c63155fe 100644 --- a/go.mod +++ b/go.mod @@ -191,6 +191,7 @@ require ( github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.21 // indirect + github.com/maypok86/otter/v2 v2.3.0 // indirect github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866a // indirect github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 // indirect github.com/minio/crc64nvme v1.1.1 // indirect diff --git a/go.sum b/go.sum index 47ea55925..ae68bcb42 100644 --- a/go.sum +++ b/go.sum @@ -677,6 +677,8 @@ github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpe github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/maypok86/otter/v2 v2.3.0 h1:8H8AVVFUSzJwIegKwv1uF5aGitTY+AIrtktg7OcLs8w= +github.com/maypok86/otter/v2 v2.3.0/go.mod h1:XgIdlpmL6jYz882/CAx1E4C1ukfgDKSaw4mWq59+7l8= github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8= github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= diff --git a/worker/pipeline.go b/worker/pipeline.go index 191c73df7..92290c07a 100644 --- a/worker/pipeline.go +++ b/worker/pipeline.go @@ -17,7 +17,6 @@ package worker import ( "context" "strings" - "sync" "time" mapset "github.com/deckarep/golang-set/v2" @@ -35,6 +34,8 @@ import ( "github.com/samber/lo" "go.uber.org/atomic" "go.uber.org/zap" + + "github.com/maypok86/otter/v2" ) type Pipeline struct { @@ -51,7 +52,13 @@ type Pipeline struct { func (p *Pipeline) Recommend(ctx context.Context, users []data.User, progress func(completed, throughput int)) { startRecommendTime := time.Now() - itemCache := NewItemCache(p.DataClient) + // Get total item count to determine cache size + itemCount, err := p.DataClient.CountItems(ctx) + if err != nil { + log.Logger().Error("failed to count items for cache size", zap.Error(err)) + itemCount = 1024 * 10 // fallback to default + } + itemCache := NewItemCache(p.DataClient, itemCount) log.Logger().Info("ranking recommendation", zap.Int("n_working_users", len(users)), zap.Int("n_jobs", p.Jobs), @@ -574,24 +581,33 @@ func (p *Pipeline) applyReplacementDecay( return updated } -// ItemCache is alias of map[string]data.Item. +// ItemCache is a cache for items using W-TinyLFU eviction policy. type ItemCache struct { Client data.Database - Data sync.Map + Data *otter.Cache[string, *data.Item] } -// NewItemCache creates a new ItemCache. -func NewItemCache(client data.Database) *ItemCache { +// NewItemCache creates a new ItemCache with W-TinyLFU eviction. +// The cache size is calculated as max(1024, itemCount * 10%). +func NewItemCache(client data.Database, itemCount int) *ItemCache { + // Calculate cache size: 10% of total items, but at least 1024 + size := itemCount / 10 + if size < 1024 { + size = 1024 + } + cache := otter.Must(&otter.Options[string, *data.Item]{ + MaximumSize: size, + }) return &ItemCache{ Client: client, - Data: sync.Map{}, + Data: cache, } } func (c *ItemCache) GetSlice(ctx context.Context, itemIds []string) ([]*data.Item, error) { requests := make([]string, 0, len(itemIds)) for _, itemId := range itemIds { - if _, exist := c.Data.Load(itemId); !exist { + if _, ok := c.Data.GetIfPresent(itemId); !ok { requests = append(requests, itemId) } } @@ -600,12 +616,11 @@ func (c *ItemCache) GetSlice(ctx context.Context, itemIds []string) ([]*data.Ite return nil, errors.Trace(err) } for _, item := range response { - c.Data.Store(item.ItemId, &item) + c.Data.Set(item.ItemId, &item) } items := make([]*data.Item, 0, len(itemIds)) for _, itemId := range itemIds { - if val, exist := c.Data.Load(itemId); exist { - item := val.(*data.Item) + if item, ok := c.Data.GetIfPresent(itemId); ok { if !item.IsHidden { items = append(items, item) } diff --git a/worker/pipeline_test.go b/worker/pipeline_test.go index d2e6b4467..dcd21fd13 100644 --- a/worker/pipeline_test.go +++ b/worker/pipeline_test.go @@ -51,14 +51,14 @@ func (suite *PipelineTestSuite) TearDownSuite() { } func (suite *PipelineTestSuite) TestGetSlice() { - c := NewItemCache(suite.dataClient) + c := NewItemCache(suite.dataClient, 10000) items, err := c.GetSlice(suite.T().Context(), []string{"1", "2", "3", "4", "5", "6"}) suite.NoError(err) suite.Equal(5, len(items)) } func (suite *PipelineTestSuite) TestGetMap() { - c := NewItemCache(suite.dataClient) + c := NewItemCache(suite.dataClient, 10000) items, err := c.GetMap(suite.T().Context(), []string{"1", "2", "3", "4", "5", "6"}) suite.NoError(err) suite.Equal(5, len(items)) diff --git a/worker/worker_test.go b/worker/worker_test.go index ca5184b0a..5066acf2c 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -725,7 +725,7 @@ func (suite *WorkerTestSuite) TestRankByClickTroughRate() { }) suite.NoError(err) // rank items - itemCache := NewItemCache(suite.DataClient) + itemCache := NewItemCache(suite.DataClient, 10000) result, err := suite.rankByClickTroughRate(ctx, new(mockFactorizationMachine), &data.User{UserId: "1"}, []cache.Score{{Id: "1"}, {Id: "2"}, {Id: "3"}, {Id: "4"}, {Id: "5"}}, itemCache, time.Now()) suite.NoError(err) @@ -762,7 +762,7 @@ func (suite *WorkerTestSuite) TestRankByLLM() { "{{user.UserId}}", "{{item.ItemId}}") suite.NoError(err) - itemCache := NewItemCache(suite.DataClient) + itemCache := NewItemCache(suite.DataClient, 10000) recommendTime := time.Now() result, err := suite.rankByLLM(ctx, nil, ranker, &data.User{UserId: "u1"}, []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "like", UserId: "u1", ItemId: "4"}},