diff --git a/common/expression/weight.go b/common/expression/weight.go
new file mode 100644
index 000000000..cd3293993
--- /dev/null
+++ b/common/expression/weight.go
@@ -0,0 +1,156 @@
+// Copyright 2026 gorse Project Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package expression
+
+import (
+	"fmt"
+	"math"
+
+	"github.com/expr-lang/expr"
+	"github.com/expr-lang/expr/vm"
+	"github.com/gorse-io/gorse/common/log"
+	"go.uber.org/zap"
+)
+
+// defaultWeight is the weight used when a feedback type has no expression or
+// when its expression cannot produce a usable number.
+const defaultWeight float32 = 1.0
+
+// FeedbackWeightExpression holds one compiled weight expression per feedback
+// type. Build it with NewFeedbackWeightExpression.
+type FeedbackWeightExpression struct {
+	programs map[string]*vm.Program
+}
+
+// env builds the evaluation environment for a weight expression: the feedback
+// value is exposed as "Value" next to a set of functions from package math.
+// A fresh map is returned on every call, so callers never share one.
+func env(value float64) map[string]any {
+	return map[string]any{
+		"Value": value,
+		"abs":   math.Abs,
+		"ceil":  math.Ceil,
+		"floor": math.Floor,
+		"round": math.Round,
+		"sqrt":  math.Sqrt,
+		"cbrt":  math.Cbrt,
+		"log":   math.Log,
+		"log2":  math.Log2,
+		"log10": math.Log10,
+		"log1p": math.Log1p,
+		"exp":   math.Exp,
+		"exp2":  math.Exp2,
+		"expm1": math.Expm1,
+		"pow":   math.Pow,
+		"sin":   math.Sin,
+		"cos":   math.Cos,
+		"tan":   math.Tan,
+		"asin":  math.Asin,
+		"acos":  math.Acos,
+		"atan":  math.Atan,
+		"sinh":  math.Sinh,
+		"cosh":  math.Cosh,
+		"tanh":  math.Tanh,
+		"max":   math.Max,
+		"min":   math.Min,
+	}
+}
+
+// NewFeedbackWeightExpression compiles the weight expression configured for
+// each feedback type. The map key is the feedback type and the map value is the
+// expression source. Every expression must evaluate to a number; the first one
+// that fails to compile aborts construction and its error is returned together
+// with the offending feedback type. A nil or empty map is valid.
+func NewFeedbackWeightExpression(feedbackWeight map[string]string) (*FeedbackWeightExpression, error) {
+	programs := make(map[string]*vm.Program, len(feedbackWeight))
+	for feedbackType, exprStr := range feedbackWeight {
+		// AsFloat64 rejects non-numeric expressions at configuration time
+		// instead of letting them silently fall back to the default weight.
+		program, err := expr.Compile(exprStr, expr.Env(env(0.0)), expr.AsFloat64())
+		if err != nil {
+			return nil, fmt.Errorf("compile weight expression for feedback type %q: %w", feedbackType, err)
+		}
+		programs[feedbackType] = program
+	}
+	return &FeedbackWeightExpression{programs: programs}, nil
+}
+
+// Evaluate returns the weight of a feedback of the given type and value.
+//
+// The default weight 1.0 is returned when the receiver is nil, when no
+// expression is registered for the feedback type, when evaluation fails, or
+// when the result is NaN or infinite (for example log(0) or sqrt(-1)). The last
+// two cases are logged. Non-finite weights are rejected because a single one
+// would turn every score or loss it is multiplied into non-finite as well.
+func (weightExpr *FeedbackWeightExpression) Evaluate(feedbackType string, value float64) float32 {
+	if weightExpr == nil {
+		return defaultWeight
+	}
+	program, ok := weightExpr.programs[feedbackType]
+	if !ok {
+		return defaultWeight
+	}
+	result, err := expr.Run(program, env(value))
+	if err != nil {
+		log.Logger().Error("failed to evaluate weight expression",
+			zap.String("feedback_type", feedbackType),
+			zap.Float64("value", value),
+			zap.Error(err))
+		return defaultWeight
+	}
+	// Check the converted value: a finite float64 can still overflow float32.
+	weight := ToFloat32(result)
+	if f := float64(weight); math.IsNaN(f) || math.IsInf(f, 0) {
+		log.Logger().Error("weight expression produced a non-finite weight",
+			zap.String("feedback_type", feedbackType),
+			zap.Float64("value", value),
+			zap.Any("result", result))
+		return defaultWeight
+	}
+	return weight
+}
+
+// ToFloat32 converts a value of any built-in integer or floating-point type to
+// float32. Values of any other type yield the default weight 1.0.
+func ToFloat32(v any) float32 {
+	switch val := v.(type) {
+	case float32:
+		return val
+	case float64:
+		return float32(val)
+	case int:
+		return float32(val)
+	case int8:
+		return float32(val)
+	case int16:
+		return float32(val)
+	case int32:
+		return float32(val)
+	case int64:
+		return float32(val)
+	case uint:
+		return float32(val)
+	case uint8:
+		return float32(val)
+	case uint16:
+		return float32(val)
+	case uint32:
+		return float32(val)
+	case uint64:
+		return float32(val)
+	default:
+		return defaultWeight
+	}
+}
diff --git a/common/expression/weight_test.go b/common/expression/weight_test.go
new file mode 100644
index 000000000..ad7b446a7
--- /dev/null
+++ b/common/expression/weight_test.go
@@ -0,0 +1,101 @@
+// Copyright 2026 gorse Project Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewFeedbackWeightExpression(t *testing.T) { + t.Run("valid expressions", func(t *testing.T) { + weightExpr, err := NewFeedbackWeightExpression(map[string]string{ + "click": "1", + "rating": "Value * 2", + }) + require.NoError(t, err) + require.NotNil(t, weightExpr) + assert.Len(t, weightExpr.programs, 2) + }) + + t.Run("invalid expression", func(t *testing.T) { + weightExpr, err := NewFeedbackWeightExpression(map[string]string{ + "click": "invalid++", + }) + assert.Error(t, err) + assert.Nil(t, weightExpr) + }) +} + +func TestFeedbackWeightExpression_Evaluate(t *testing.T) { + weightExpr, err := NewFeedbackWeightExpression(map[string]string{ + "click": "1", + "rating": "Value * 2", + "view_time": "log1p(Value)", + }) + require.NoError(t, err) + + t.Run("constant expression", func(t *testing.T) { + got := weightExpr.Evaluate("click", 123) + assert.Equal(t, float32(1), got) + }) + + t.Run("value expression", func(t *testing.T) { + got := weightExpr.Evaluate("rating", 3.5) + assert.Equal(t, float32(7), got) + }) + + t.Run("math expression", func(t *testing.T) { + got := weightExpr.Evaluate("view_time", 99) + assert.InDelta(t, float32(math.Log1p(99)), got, 0.001) + }) + + t.Run("missing feedback type returns default", func(t *testing.T) { + got := weightExpr.Evaluate("purchase", 5) + assert.Equal(t, float32(1), got) + }) +} + +func TestToFloat32(t *testing.T) { + tests := []struct { + name string + input 
any + want float32 + }{ + {"float32", float32(1.5), float32(1.5)}, + {"float64", float64(2.5), float32(2.5)}, + {"int", int(3), float32(3)}, + {"int8", int8(4), float32(4)}, + {"int16", int16(5), float32(5)}, + {"int32", int32(6), float32(6)}, + {"int64", int64(7), float32(7)}, + {"uint", uint(8), float32(8)}, + {"uint8", uint8(9), float32(9)}, + {"uint16", uint16(10), float32(10)}, + {"uint32", uint32(11), float32(11)}, + {"uint64", uint64(12), float32(12)}, + {"unsupported type", "invalid", float32(1.0)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ToFloat32(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/config/config.go b/config/config.go index 8a2ea6512..6a49360aa 100644 --- a/config/config.go +++ b/config/config.go @@ -245,6 +245,7 @@ type DataSourceConfig struct { PositiveFeedbackTypes []expression.FeedbackTypeExpression `mapstructure:"positive_feedback_types"` // positive feedback type NegativeFeedbackTypes []expression.FeedbackTypeExpression `mapstructure:"negative_feedback_types"` // negative feedback type (highest priority) ReadFeedbackTypes []expression.FeedbackTypeExpression `mapstructure:"read_feedback_types"` // feedback type for read event + FeedbackWeight map[string]string `mapstructure:"feedback_weight"` // feedback weight expressions PositiveFeedbackTTL uint `mapstructure:"positive_feedback_ttl" validate:"gte=0"` // time-to-live of positive feedbacks ItemTTL uint `mapstructure:"item_ttl" validate:"gte=0"` // item-to-live of items } diff --git a/dataset/dataset.go b/dataset/dataset.go index a985116ae..90cf3266a 100644 --- a/dataset/dataset.go +++ b/dataset/dataset.go @@ -72,6 +72,7 @@ type CTRSplit interface { Get(i int) ([]int32, []float32, [][]float32, float32) GetItemEmbeddingDim() []int GetItemEmbeddingIndex() *Index + GetWeight(i int) float32 } type Dataset struct { diff --git a/logics/recommend.go b/logics/recommend.go index a542375d5..dd91ed4d1 100644 --- a/logics/recommend.go +++ 
b/logics/recommend.go @@ -43,6 +43,7 @@ type Recommender struct { config config.RecommendConfig cacheClient cache.Database dataClient data.Database + weightExpr *expression.FeedbackWeightExpression online bool coldstart bool @@ -55,6 +56,10 @@ type Recommender struct { type RecommenderFunc func(ctx context.Context) ([]cache.Score, string, error) func NewRecommender(config config.RecommendConfig, cacheClient cache.Database, dataClient data.Database, online bool, userId string, categories []string) (*Recommender, error) { + weightExpr, err := expression.NewFeedbackWeightExpression(config.DataSource.FeedbackWeight) + if err != nil { + return nil, errors.Trace(err) + } // Load user feedback userFeedback, err := dataClient.GetUserFeedback(context.Background(), userId, new(time.Now())) if err != nil { @@ -78,6 +83,7 @@ func NewRecommender(config config.RecommendConfig, cacheClient cache.Database, d config: config, cacheClient: cacheClient, dataClient: dataClient, + weightExpr: weightExpr, userId: userId, userFeedback: userFeedback, online: online, @@ -249,11 +255,12 @@ func (r *Recommender) recommendItemToItem(name string) RecommenderFunc { } } } - // collect scores + // collect scores with weighted aggregation scores := make(map[string]float64) categories := make(map[string][]string) digests := mapset.NewSet[string]() for _, feedback := range userFeedback { + fbWeight := float64(r.weightExpr.Evaluate(feedback.FeedbackType, feedback.Value)) similarItems, err := r.cacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key(name, feedback.ItemId), r.categories, 0, r.config.CacheSize) if err != nil { return nil, "", errors.Trace(err) @@ -264,7 +271,8 @@ func (r *Recommender) recommendItemToItem(name string) RecommenderFunc { } for _, item := range similarItems { if !r.excludeSet.Contains(item.Id) { - scores[item.Id] += item.Score + // weighted score aggregation + scores[item.Id] += item.Score * fbWeight categories[item.Id] = item.Categories digests.Add(digest) } diff --git 
a/model/ctr/data.go b/model/ctr/data.go index 0501564e3..b318fccb2 100644 --- a/model/ctr/data.go +++ b/model/ctr/data.go @@ -151,6 +151,8 @@ type Dataset struct { ItemEmbeddingIndex *dataset.Index PositiveCount int NegativeCount int + // Weight support + Weights []float32 // Computed weight for each sample (set by tasks) } // CountUsers returns the number of users. @@ -254,6 +256,15 @@ func (dataset *Dataset) Get(i int) ([]int32, []float32, [][]float32, float32) { return indices, values, embedding, dataset.Target[i] } +// GetWeight returns the weight for the i-th sample. +// Returns 1.0 if no weight is set (default behavior). +func (dataset *Dataset) GetWeight(i int) float32 { + if dataset.Weights != nil && i < len(dataset.Weights) { + return dataset.Weights[i] + } + return 1.0 +} + // LoadLibFMFile loads libFM format file. func LoadLibFMFile(path string) (features [][]lo.Tuple2[int32, float32], targets []float32, maxLabel int32, err error) { // open file @@ -356,6 +367,10 @@ func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset) { testSet.ContextLabels = append(testSet.ContextLabels, dataset.ContextLabels[i]) } testSet.Target = append(testSet.Target, dataset.Target[i]) + + if dataset.Weights != nil { + testSet.Weights = append(testSet.Weights, dataset.Weights[i]) + } if dataset.Target[i] > 0 { testSet.PositiveCount++ } else { @@ -369,6 +384,10 @@ func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset) { trainSet.ContextLabels = append(trainSet.ContextLabels, dataset.ContextLabels[i]) } trainSet.Target = append(trainSet.Target, dataset.Target[i]) + + if dataset.Weights != nil { + trainSet.Weights = append(trainSet.Weights, dataset.Weights[i]) + } if dataset.Target[i] > 0 { trainSet.PositiveCount++ } else { diff --git a/model/ctr/data_test.go b/model/ctr/data_test.go index e168bbcc4..52f36e6f8 100644 --- a/model/ctr/data_test.go +++ b/model/ctr/data_test.go @@ -190,3 +190,26 @@ func TestDataset_Split(t *testing.T) { 
assert.Equal(t, 3, test.PositiveCount) assert.Equal(t, 3, test.NegativeCount) } + +func TestDataset_GetWeight(t *testing.T) { + t.Run("no weights returns 1.0", func(t *testing.T) { + dataset := &Dataset{} + assert.Equal(t, float32(1.0), dataset.GetWeight(0)) + }) + + t.Run("with weights", func(t *testing.T) { + dataset := &Dataset{ + Weights: []float32{1.0, 2.0, 3.0}, + } + assert.Equal(t, float32(1.0), dataset.GetWeight(0)) + assert.Equal(t, float32(2.0), dataset.GetWeight(1)) + assert.Equal(t, float32(3.0), dataset.GetWeight(2)) + }) + + t.Run("out of range returns 1.0", func(t *testing.T) { + dataset := &Dataset{ + Weights: []float32{1.0}, + } + assert.Equal(t, float32(1.0), dataset.GetWeight(100)) + }) +} diff --git a/model/ctr/fm.go b/model/ctr/fm.go index d974900bf..3a1990d3a 100644 --- a/model/ctr/fm.go +++ b/model/ctr/fm.go @@ -323,6 +323,7 @@ func (fm *AFM) Fit(ctx context.Context, trainSet, testSet dataset.CTRSplit, conf var x []lo.Tuple2[[]int32, []float32] var e [][][]float32 var y []float32 + var w []float32 for i := 0; i < trainSet.Count(); i++ { indices, values, embeddings, target := trainSet.Get(i) // Apply scalers to numerical features @@ -336,8 +337,10 @@ func (fm *AFM) Fit(ctx context.Context, trainSet, testSet dataset.CTRSplit, conf x = append(x, lo.Tuple2[[]int32, []float32]{A: indices, B: scaledValues}) e = append(e, embeddings) y = append(y, target) + w = append(w, trainSet.GetWeight(i)) } indices, values, embeddings, target := fm.convertToTensors(x, e, y) + weights := nn.NewTensor(w, len(w)) var optimizer nn.Optimizer switch fm.optimizer { @@ -368,8 +371,9 @@ func (fm *AFM) Fit(ctx context.Context, trainSet, testSet dataset.CTRSplit, conf batchEmbedding[k] = embeddings[k].Slice(i, j) } batchTarget := target.Slice(i, j) + batchWeights := weights.Slice(i, j) batchOutput := fm.Forward(batchIndices, batchValues, batchEmbedding, config.Jobs) - batchLoss := nn.BCEWithLogits(batchTarget, batchOutput, nil) + batchLoss := 
nn.BCEWithLogits(batchTarget, batchOutput, batchWeights) cost += batchLoss.Data()[0] optimizer.ZeroGrad() batchLoss.Backward()