Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
638daec
docs: add AFM weighted training development plan
zhenghaoz Mar 25, 2026
4f8c529
docs: update AFM weighted training plan with expression-based weight …
zhenghaoz Mar 25, 2026
35f11ca
docs: use expr-lang/expr for weight expression parsing
zhenghaoz Mar 25, 2026
b2f4500
feat: add feedback_weight config for weighted training
zhenghaoz Mar 25, 2026
393bacf
docs: mark Phase 2 (config extension) as completed
zhenghaoz Mar 25, 2026
3d36124
docs: refactor architecture - model layer should not be aware of Weig…
zhenghaoz Mar 25, 2026
00ec1b0
docs: simplify design - inline expression parsing in Dataset construc…
zhenghaoz Mar 25, 2026
b5539ae
feat: add weight expression parsing with expr-lang/expr
zhenghaoz Mar 25, 2026
3dd1ac9
docs: mark Phase 1 (expression parsing) as completed
zhenghaoz Mar 25, 2026
de46ed1
feat: add weight support to Dataset
zhenghaoz Mar 25, 2026
f63b128
docs: mark Phase 3 (data layer extension) as completed
zhenghaoz Mar 25, 2026
d6f5de8
feat: add weighted training support to AFM
zhenghaoz Mar 25, 2026
e2eea0a
docs: mark Phase 4 (model layer extension) as partially completed
zhenghaoz Mar 25, 2026
870fcbc
test: add weighted training test for AFM
zhenghaoz Mar 25, 2026
4374895
docs: mark Phase 5 (testing) as completed
zhenghaoz Mar 25, 2026
996f7e2
fix: correct import order in model/ctr tests
zhenghaoz Mar 26, 2026
49310a5
fix: use tagged switch instead of if-else (staticcheck QF1003)
zhenghaoz Mar 26, 2026
6200165
feat: add weighted feedback support for item-to-item recommendation
zhenghaoz Mar 28, 2026
7fa9e5c
refactor: move weight functions to common/weight package
zhenghaoz Mar 28, 2026
fffd30b
test: remove TestFactorizationMachines_WeightedTraining
zhenghaoz Mar 28, 2026
86a64bf
refactor: move weight functions from model/ctr to logics package
zhenghaoz Mar 28, 2026
7b60168
refactor: simplify ctr.Dataset to only keep SampleWeights
zhenghaoz Mar 28, 2026
65d9243
refactor: rename SampleWeights to Weights in ctr.Dataset
zhenghaoz Mar 28, 2026
06643d7
refactor: update weight test tolerance type and clean up unused impor…
zhenghaoz Apr 13, 2026
d0fe7a6
feat: add FeedbackWeightExpression and related tests for dynamic weig…
zhenghaoz Apr 13, 2026
26a6a39
feat: integrate FeedbackWeightExpression into Recommender and simplif…
zhenghaoz Apr 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions common/expression/weight.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright 2026 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
	"fmt"
	"math"

	"github.com/expr-lang/expr"
	"github.com/expr-lang/expr/vm"
	"github.com/gorse-io/gorse/common/log"
	"go.uber.org/zap"
)

// FeedbackWeightExpression wraps compiled weight expressions by feedback type.
type FeedbackWeightExpression struct {
	// programs maps a feedback type to its pre-compiled weight expression.
	// Feedback types without an entry fall back to the default weight 1.0 in Evaluate.
	programs map[string]*vm.Program
}

func env(value float64) map[string]any {
return map[string]any{
"Value": value,
"abs": math.Abs,
"ceil": math.Ceil,
"floor": math.Floor,
"round": math.Round,
"sqrt": math.Sqrt,
"cbrt": math.Cbrt,
"log": math.Log,
"log2": math.Log2,
"log10": math.Log10,
"log1p": math.Log1p,
"exp": math.Exp,
"exp2": math.Exp2,
"expm1": math.Expm1,
"pow": math.Pow,
"sin": math.Sin,
"cos": math.Cos,
"tan": math.Tan,
"asin": math.Asin,
"acos": math.Acos,
"atan": math.Atan,
"sinh": math.Sinh,
"cosh": math.Cosh,
"tanh": math.Tanh,
"max": math.Max,
"min": math.Min,
}
}

// NewFeedbackWeightExpression compiles weight expressions and wraps them by feedback type.
// Each entry in feedbackWeight maps a feedback type to an expr-lang expression that may
// reference the feedback value as "Value" and the math helpers exposed by env.
// Compilation happens once here so Evaluate only runs pre-compiled programs.
// Returns an error naming the offending feedback type if any expression fails to compile.
func NewFeedbackWeightExpression(feedbackWeight map[string]string) (*FeedbackWeightExpression, error) {
	programs := make(map[string]*vm.Program, len(feedbackWeight))
	for feedbackType, exprStr := range feedbackWeight {
		// Compile against a representative environment so identifier and type
		// errors are caught now rather than at evaluation time.
		program, err := expr.Compile(exprStr, expr.Env(env(0.0)))
		if err != nil {
			// Include the feedback type so a bad config entry is easy to locate.
			return nil, fmt.Errorf("compiling weight expression for feedback type %q: %w", feedbackType, err)
		}
		programs[feedbackType] = program
	}
	return &FeedbackWeightExpression{programs: programs}, nil
}

// Evaluate evaluates the weight for the given feedback type and value.
// If there is no expression for the feedback type, the default weight 1.0 is returned.
func (weightExpr *FeedbackWeightExpression) Evaluate(feedbackType string, value float64) float32 {
	// Neutral weight used whenever no expression applies or evaluation fails.
	const defaultWeight = 1.0
	program, exists := weightExpr.programs[feedbackType]
	if !exists {
		return defaultWeight
	}
	output, runErr := expr.Run(program, env(value))
	if runErr != nil {
		// An evaluation failure must not abort the caller; log it and fall
		// back to the neutral weight.
		log.Logger().Error("failed to evaluate weight expression",
			zap.String("feedback_type", feedbackType),
			zap.Float64("value", value),
			zap.Error(runErr))
		return defaultWeight
	}
	return ToFloat32(output)
}

// ToFloat32 converts various numeric types to float32.
// Any non-numeric input yields the neutral weight 1.0.
func ToFloat32(v any) float32 {
	switch n := v.(type) {
	case float64:
		// expr-lang arithmetic typically produces float64, so test it first.
		return float32(n)
	case float32:
		return n
	case int:
		return float32(n)
	case int64:
		return float32(n)
	case int32:
		return float32(n)
	case int16:
		return float32(n)
	case int8:
		return float32(n)
	case uint:
		return float32(n)
	case uint64:
		return float32(n)
	case uint32:
		return float32(n)
	case uint16:
		return float32(n)
	case uint8:
		return float32(n)
	default:
		// Unsupported result types map to the default weight.
		return 1.0
	}
}
101 changes: 101 additions & 0 deletions common/expression/weight_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright 2026 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
"math"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewFeedbackWeightExpression(t *testing.T) {
	t.Run("valid expressions", func(t *testing.T) {
		exprs := map[string]string{
			"click":  "1",
			"rating": "Value * 2",
		}
		we, err := NewFeedbackWeightExpression(exprs)
		require.NoError(t, err)
		require.NotNil(t, we)
		assert.Len(t, we.programs, 2)
	})

	t.Run("invalid expression", func(t *testing.T) {
		we, err := NewFeedbackWeightExpression(map[string]string{"click": "invalid++"})
		assert.Error(t, err)
		assert.Nil(t, we)
	})
}

func TestFeedbackWeightExpression_Evaluate(t *testing.T) {
	exprs := map[string]string{
		"click":     "1",
		"rating":    "Value * 2",
		"view_time": "log1p(Value)",
	}
	we, err := NewFeedbackWeightExpression(exprs)
	require.NoError(t, err)

	t.Run("constant expression", func(t *testing.T) {
		assert.Equal(t, float32(1), we.Evaluate("click", 123))
	})

	t.Run("value expression", func(t *testing.T) {
		assert.Equal(t, float32(7), we.Evaluate("rating", 3.5))
	})

	t.Run("math expression", func(t *testing.T) {
		assert.InDelta(t, float32(math.Log1p(99)), we.Evaluate("view_time", 99), 0.001)
	})

	t.Run("missing feedback type returns default", func(t *testing.T) {
		assert.Equal(t, float32(1), we.Evaluate("purchase", 5))
	})
}

func TestToFloat32(t *testing.T) {
tests := []struct {
name string
input any
want float32
}{
{"float32", float32(1.5), float32(1.5)},
{"float64", float64(2.5), float32(2.5)},
{"int", int(3), float32(3)},
{"int8", int8(4), float32(4)},
{"int16", int16(5), float32(5)},
{"int32", int32(6), float32(6)},
{"int64", int64(7), float32(7)},
{"uint", uint(8), float32(8)},
{"uint8", uint8(9), float32(9)},
{"uint16", uint16(10), float32(10)},
{"uint32", uint32(11), float32(11)},
{"uint64", uint64(12), float32(12)},
{"unsupported type", "invalid", float32(1.0)},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ToFloat32(tt.input)
assert.Equal(t, tt.want, got)
})
}
}
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ type DataSourceConfig struct {
PositiveFeedbackTypes []expression.FeedbackTypeExpression `mapstructure:"positive_feedback_types"` // positive feedback type
NegativeFeedbackTypes []expression.FeedbackTypeExpression `mapstructure:"negative_feedback_types"` // negative feedback type (highest priority)
ReadFeedbackTypes []expression.FeedbackTypeExpression `mapstructure:"read_feedback_types"` // feedback type for read event
FeedbackWeight map[string]string `mapstructure:"feedback_weight"` // feedback weight expressions
PositiveFeedbackTTL uint `mapstructure:"positive_feedback_ttl" validate:"gte=0"` // time-to-live of positive feedbacks
ItemTTL uint `mapstructure:"item_ttl" validate:"gte=0"` // item-to-live of items
}
Expand Down
1 change: 1 addition & 0 deletions dataset/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type CTRSplit interface {
Get(i int) ([]int32, []float32, [][]float32, float32)
GetItemEmbeddingDim() []int
GetItemEmbeddingIndex() *Index
GetWeight(i int) float32
}

type Dataset struct {
Expand Down
12 changes: 10 additions & 2 deletions logics/recommend.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type Recommender struct {
config config.RecommendConfig
cacheClient cache.Database
dataClient data.Database
weightExpr *expression.FeedbackWeightExpression

online bool
coldstart bool
Expand All @@ -55,6 +56,10 @@ type Recommender struct {
type RecommenderFunc func(ctx context.Context) ([]cache.Score, string, error)

func NewRecommender(config config.RecommendConfig, cacheClient cache.Database, dataClient data.Database, online bool, userId string, categories []string) (*Recommender, error) {
weightExpr, err := expression.NewFeedbackWeightExpression(config.DataSource.FeedbackWeight)
if err != nil {
return nil, errors.Trace(err)
}
// Load user feedback
userFeedback, err := dataClient.GetUserFeedback(context.Background(), userId, new(time.Now()))
if err != nil {
Expand All @@ -78,6 +83,7 @@ func NewRecommender(config config.RecommendConfig, cacheClient cache.Database, d
config: config,
cacheClient: cacheClient,
dataClient: dataClient,
weightExpr: weightExpr,
userId: userId,
userFeedback: userFeedback,
online: online,
Expand Down Expand Up @@ -249,11 +255,12 @@ func (r *Recommender) recommendItemToItem(name string) RecommenderFunc {
}
}
}
// collect scores
// collect scores with weighted aggregation
scores := make(map[string]float64)
categories := make(map[string][]string)
digests := mapset.NewSet[string]()
for _, feedback := range userFeedback {
fbWeight := float64(r.weightExpr.Evaluate(feedback.FeedbackType, feedback.Value))
similarItems, err := r.cacheClient.SearchScores(ctx, cache.ItemToItem, cache.Key(name, feedback.ItemId), r.categories, 0, r.config.CacheSize)
if err != nil {
return nil, "", errors.Trace(err)
Expand All @@ -264,7 +271,8 @@ func (r *Recommender) recommendItemToItem(name string) RecommenderFunc {
}
for _, item := range similarItems {
if !r.excludeSet.Contains(item.Id) {
scores[item.Id] += item.Score
// weighted score aggregation
scores[item.Id] += item.Score * fbWeight
categories[item.Id] = item.Categories
digests.Add(digest)
}
Expand Down
19 changes: 19 additions & 0 deletions model/ctr/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ type Dataset struct {
ItemEmbeddingIndex *dataset.Index
PositiveCount int
NegativeCount int
// Weight support
Weights []float32 // Computed weight for each sample (set by tasks)
}

// CountUsers returns the number of users.
Expand Down Expand Up @@ -254,6 +256,15 @@ func (dataset *Dataset) Get(i int) ([]int32, []float32, [][]float32, float32) {
return indices, values, embedding, dataset.Target[i]
}

// GetWeight returns the weight for the i-th sample.
// Returns 1.0 if no weight is set or i is out of range (default behavior).
func (dataset *Dataset) GetWeight(i int) float32 {
	// A nil Weights slice has zero length, so a single bounds check covers the
	// "no weights set" case; the i >= 0 guard prevents a panic on a negative
	// index, which the previous nil-and-upper-bound check did not.
	if i >= 0 && i < len(dataset.Weights) {
		return dataset.Weights[i]
	}
	return 1.0
}

// LoadLibFMFile loads libFM format file.
func LoadLibFMFile(path string) (features [][]lo.Tuple2[int32, float32], targets []float32, maxLabel int32, err error) {
// open file
Expand Down Expand Up @@ -356,6 +367,10 @@ func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset) {
testSet.ContextLabels = append(testSet.ContextLabels, dataset.ContextLabels[i])
}
testSet.Target = append(testSet.Target, dataset.Target[i])

if dataset.Weights != nil {
testSet.Weights = append(testSet.Weights, dataset.Weights[i])
}
if dataset.Target[i] > 0 {
testSet.PositiveCount++
} else {
Expand All @@ -369,6 +384,10 @@ func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset) {
trainSet.ContextLabels = append(trainSet.ContextLabels, dataset.ContextLabels[i])
}
trainSet.Target = append(trainSet.Target, dataset.Target[i])

if dataset.Weights != nil {
trainSet.Weights = append(trainSet.Weights, dataset.Weights[i])
}
if dataset.Target[i] > 0 {
trainSet.PositiveCount++
} else {
Expand Down
23 changes: 23 additions & 0 deletions model/ctr/data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,26 @@ func TestDataset_Split(t *testing.T) {
assert.Equal(t, 3, test.PositiveCount)
assert.Equal(t, 3, test.NegativeCount)
}

func TestDataset_GetWeight(t *testing.T) {
	t.Run("no weights returns 1.0", func(t *testing.T) {
		var empty Dataset
		assert.Equal(t, float32(1.0), empty.GetWeight(0))
	})

	t.Run("with weights", func(t *testing.T) {
		weighted := &Dataset{Weights: []float32{1.0, 2.0, 3.0}}
		for i, want := range []float32{1.0, 2.0, 3.0} {
			assert.Equal(t, want, weighted.GetWeight(i))
		}
	})

	t.Run("out of range returns 1.0", func(t *testing.T) {
		short := &Dataset{Weights: []float32{1.0}}
		assert.Equal(t, float32(1.0), short.GetWeight(100))
	})
}
6 changes: 5 additions & 1 deletion model/ctr/fm.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ func (fm *AFM) Fit(ctx context.Context, trainSet, testSet dataset.CTRSplit, conf
var x []lo.Tuple2[[]int32, []float32]
var e [][][]float32
var y []float32
var w []float32
for i := 0; i < trainSet.Count(); i++ {
indices, values, embeddings, target := trainSet.Get(i)
// Apply scalers to numerical features
Expand All @@ -336,8 +337,10 @@ func (fm *AFM) Fit(ctx context.Context, trainSet, testSet dataset.CTRSplit, conf
x = append(x, lo.Tuple2[[]int32, []float32]{A: indices, B: scaledValues})
e = append(e, embeddings)
y = append(y, target)
w = append(w, trainSet.GetWeight(i))
}
indices, values, embeddings, target := fm.convertToTensors(x, e, y)
weights := nn.NewTensor(w, len(w))

var optimizer nn.Optimizer
switch fm.optimizer {
Expand Down Expand Up @@ -368,8 +371,9 @@ func (fm *AFM) Fit(ctx context.Context, trainSet, testSet dataset.CTRSplit, conf
batchEmbedding[k] = embeddings[k].Slice(i, j)
}
batchTarget := target.Slice(i, j)
batchWeights := weights.Slice(i, j)
batchOutput := fm.Forward(batchIndices, batchValues, batchEmbedding, config.Jobs)
batchLoss := nn.BCEWithLogits(batchTarget, batchOutput, nil)
batchLoss := nn.BCEWithLogits(batchTarget, batchOutput, batchWeights)
cost += batchLoss.Data()[0]
optimizer.ZeroGrad()
batchLoss.Backward()
Expand Down
Loading