Skip to content

Commit 3912fb9

Browse files
Add Lambda Weights for PairWiseLoss and ListMLELoss
1 parent a94d9fc commit 3912fb9

File tree

2 files changed

+636
-0
lines changed

2 files changed

+636
-0
lines changed
Lines changed: 379 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,379 @@
1+
import abc
2+
from typing import Optional, Dict, Any, Callable
3+
from keras_rs.src import types
4+
5+
import keras
6+
from keras import ops
7+
8+
def check_tensor_shapes(tensors):
    """Validate that every tensor is rank 2 and that all shapes agree.

    Shapes are read from each tensor's static ``shape`` attribute and compared
    in plain Python, so no backend op result is ever used as an ``if``
    condition. A dimension whose static size is unknown (``None``) is treated
    as compatible with any size.

    Args:
        tensors: A sequence of tensors (or values convertible to tensors).
            An empty sequence is accepted and skips every check.

    Raises:
        ValueError: If any tensor is not rank 2, or if a tensor's shape
            differs from the shape of the first tensor.
    """
    if not tensors:
        return

    shapes = [
        tuple(ops.convert_to_tensor(tensor).shape) for tensor in tensors
    ]

    # Every tensor must be rank 2.
    for i, shape in enumerate(shapes):
        if len(shape) != 2:
            raise ValueError(
                f"Tensor {i} must have rank 2, got rank {len(shape)}"
            )

    # Every tensor must match the first one, dimension by dimension.
    reference_shape = shapes[0]
    for i, shape in enumerate(shapes[1:], 1):
        mismatch = any(
            dim is not None and ref_dim is not None and dim != ref_dim
            for dim, ref_dim in zip(shape, reference_shape)
        )
        if mismatch:
            raise ValueError(
                f"Tensor {i} shape {shape} incompatible with reference shape {reference_shape}"
            )


def apply_pairwise_op(
    x: types.Tensor,
    op: Callable[[types.Tensor, types.Tensor], types.Tensor],
) -> types.Tensor:
    """Call a binary op on two expanded views of ``x``.

    Args:
        x: The tensor to pair with itself.
        op: A two-argument callable. It receives ``x`` with a new trailing
            axis as its first argument and ``x`` with a new axis inserted
            before the last one as its second argument.

    Returns:
        Whatever ``op`` returns for those two views. With a broadcasting
        elementwise op this holds the result for every pair of entries
        along the last axis of ``x``.
    """
    as_column = ops.expand_dims(x, axis=-1)
    as_row = ops.expand_dims(x, axis=-2)
    return op(as_column, as_row)


def is_label_valid(labels):
    """Return a boolean tensor marking which labels are valid.

    A label is considered valid when it is greater than or equal to zero.
    """
    label_tensor = ops.convert_to_tensor(labels)
    return ops.greater_equal(label_tensor, 0.)


def get_valid_pairs_and_clean_labels(labels):
    """Build the valid-pair mask and replace invalid labels with zero.

    Args:
        labels: A rank-2 tensor (or tensor-like) of labels.

    Returns:
        A ``(valid_pairs, labels)`` tuple. ``valid_pairs`` is the pairwise
        logical AND of the per-item validity mask; ``labels`` is the input
        with every invalid entry set to zero.

    Raises:
        ValueError: If ``labels`` is not rank 2.
    """
    labels = ops.convert_to_tensor(labels)

    labels_shape = ops.shape(labels)
    if len(labels_shape) != 2:
        raise ValueError(f"Expected labels to have rank 2, but got rank {len(labels_shape)}")

    valid_mask = is_label_valid(labels)
    cleaned_labels = ops.where(valid_mask, labels, ops.zeros_like(labels))
    pair_mask = apply_pairwise_op(valid_mask, ops.logical_and)
    return pair_mask, cleaned_labels


def _get_shuffle_indices(shape, mask=None, shuffle_ties=False, seed=None):
    """Return argsort indices that order entries for (optional) tie shuffling.

    Args:
        shape: Shape of the index tensor to produce.
        mask: Optional boolean tensor. Entries where it is False are pushed
            to the end of the ordering.
        shuffle_ties: If True, the sort keys are uniform random values; if
            False, they are all zeros.
        seed: Seed forwarded to `keras.random.uniform`.

    Returns:
        The indices that sort the keys along the last axis.
    """
    # Random keys when ties are to be shuffled, otherwise all zeros.
    if shuffle_ties:
        shuffle_values = keras.random.uniform(shape, seed=seed)
    else:
        shuffle_values = ops.zeros(shape, dtype="float32")

    # The keys lie in [0, 1), so adding 2.0 to the masked-out entries places
    # them after every unmasked entry in the argsort below.
    if mask is not None:
        shuffle_values = ops.where(mask, shuffle_values, shuffle_values + 2.0)

    # The axis is passed by keyword: the previous positional `True` was
    # consumed as the `axis` argument (i.e. axis 1), which only coincides
    # with the last axis for rank-2 shapes.
    # NOTE(review): tie ordering for the all-zero keys depends on the
    # backend's argsort; confirm if a stable order is required.
    return ops.argsort(shuffle_values, axis=-1)


def sort_by_scores(scores, features_list, topn=None):
    """Reorder each feature tensor by the top-scoring positions of `scores`.

    Args:
        scores: A rank-2 tensor of scores; it is cast to float32.
        features_list: Iterable of tensors laid out like `scores`
            (one entry per score position).
        topn: Optional cutoff on how many positions to keep per row. When
            None, or larger than the list size, the whole list is kept.

    Returns:
        A list with one tensor per entry of `features_list`, each gathered
        at the indices returned by `ops.top_k(scores, topn, sorted=True)`.

    Raises:
        ValueError: If `scores` is not rank 2.
    """
    scores = ops.cast(scores, "float32")

    scores_shape = ops.shape(scores)
    if len(scores_shape) != 2:
        raise ValueError(f"Expected scores to have rank 2, but got rank {len(scores_shape)}")

    list_size = scores_shape[1]

    # Keep `topn` a plain Python int whenever possible so it can be used as
    # the `k` of top_k without becoming a backend tensor.
    if topn is None:
        topn = list_size
    elif isinstance(topn, int) and isinstance(list_size, int):
        topn = min(topn, list_size)
    else:
        topn = ops.minimum(topn, list_size)

    # Indices of the top-`topn` scores in each row: [batch, topn].
    _, indices = ops.top_k(scores, topn, sorted=True)

    # Row-wise gather: result[b, k] = feat[b, indices[b, k]].
    return [
        ops.take_along_axis(feat, indices, axis=1) for feat in features_list
    ]


def inverse_max_dcg(labels,
                    gain_fn=lambda labels: ops.power(2.0, labels) - 1.,
                    rank_discount_fn=lambda rank: 1. / ops.log1p(rank),
                    topn=None):
    """Compute the reciprocal of the best achievable DCG for each list.

    Args:
        labels: Rank-2 tensor of labels; it is also used as the sort key to
            obtain the ideal ordering.
        gain_fn: Maps labels to gains.
        rank_discount_fn: Maps 1-based float ranks to discounts.
        topn: Optional cutoff forwarded to `sort_by_scores`.

    Returns:
        A tensor with one column (the sum keeps the reduced axis) holding
        ``1 / max_dcg`` where the max DCG is positive and 0 elsewhere.
    """
    (ideal_labels,) = sort_by_scores(labels, [labels], topn=topn)

    # 1-based positions 1..k as float32, then their discounts with a leading
    # axis of size 1 so they broadcast over the batch.
    num_positions = ops.shape(ideal_labels)[1]
    positions = ops.cast(ops.arange(num_positions) + 1, dtype="float32")
    position_discount = ops.expand_dims(rank_discount_fn(positions), axis=0)

    max_dcg = ops.sum(
        gain_fn(ideal_labels) * position_discount, axis=1, keepdims=True
    )

    # Lists whose max DCG is not positive get an inverse of zero.
    has_gain = ops.greater(max_dcg, 0.)
    return ops.where(has_gain, 1. / max_dcg, ops.zeros_like(max_dcg))


def log2_inverse(ranks):
    """Return ``1 / log2(rank + 1)`` for each rank, computed in float32."""
    shifted = ops.cast(ranks, dtype="float32") + 1.0
    # Base-2 logarithm expressed through natural logs.
    log2_shifted = ops.log(shifted) / ops.log(2.0)
    return 1.0 / log2_shifted


class LambdaWeight(abc.ABC):
    """Interface for ranking-metric optimization via lambda weights.

    Part of the LambdaLoss framework
    (https://ai.google/research/pubs/pub47258). Implementations provide
    concrete lambda weight models that can be combined with standard losses
    such as logistic loss and softmax loss.

    The implementation targets TensorFlow, JAX, and PyTorch through the
    unified Keras 3 API.
    """

    @abc.abstractmethod
    def pair_weights(self, labels, ranks):
        """Return pairwise weights for a ranking loss.

        Args:
            labels: Tensor of shape [batch_size, list_size] with relevance
                labels.
            ranks: Tensor of shape [batch_size, list_size] with current
                ranks (1-based).

        Returns:
            A tensor that can weight example pairs, with shape
            [batch_size, list_size, list_size].
        """
        raise NotImplementedError('Calling an abstract method.')

    # Not abstract: a subclass that does not override this method gets the
    # labels back unchanged as the per-example weights.
    def individual_weights(self, labels, ranks):
        """Return the weight tensor for individual examples.

        Args:
            labels: A dense tensor of labels with shape
                [batch_size, list_size].
            ranks: A dense tensor of ranks with the same shape as `labels`
                that are sorted by logits.

        Returns:
            A tensor that can weight individual examples, with shape
            [batch_size, list_size]. This base implementation returns
            `labels` as-is.
        """
        return labels


class LabelDiffLambdaWeight(LambdaWeight):
    """LambdaWeight whose pair weight is the absolute label gap of a pair."""

    def pair_weights(self, labels, ranks):
        """Return the absolute pairwise label difference; `ranks` is unused."""
        label_gap = apply_pairwise_op(labels, ops.subtract)
        return ops.abs(label_gap)


class AbstractDCGLambdaWeight(LambdaWeight):
    """Abstract LambdaWeight for the Discounted Cumulative Gain (DCG) metric.

    Subclasses supply `pair_rank_discount`; this class combines it with the
    label gains to produce pair weights and per-item weights.
    """

    def __init__(self,
                 topn=None,
                 gain_fn=lambda label: label,
                 rank_discount_fn=lambda rank: 1. / rank,
                 normalized=False):
        """Initializer.

        Ranks are 1-based, not 0-based.

        Args:
            topn: (int) The topn for the DCG metric.
            gain_fn: (function) Transforms labels.
            rank_discount_fn: (function) The rank discount function.
            normalized: (bool) If True, normalize weight by the max DCG.
        """
        self._topn = topn
        self._gain_fn = gain_fn
        self._rank_discount_fn = rank_discount_fn
        self._normalized = normalized

    @abc.abstractmethod
    def pair_rank_discount(self, ranks, topn):
        """Compute the rank-based discount for a pair.

        Args:
            ranks: A 2D `Tensor` for the 1-based ranks.
            topn: A scalar `Tensor` for the topn cutoff.

        Returns:
            A pairwise weights `Tensor` based on the `rank_discount_fn`.
        """
        raise NotImplementedError('Calling an abstract method.')

    def _compute_gain(self, labels):
        """Return the gain of `labels`, scaled by 1/maxDCG when normalized."""
        gain = self._gain_fn(labels)
        if self._normalized:
            # Plain multiplication rather than `*=`: with the default
            # identity gain_fn, `gain` is the very same object as `labels`,
            # and an augmented assignment may update it in place on backends
            # with mutable tensors (e.g. PyTorch).
            gain = gain * inverse_max_dcg(
                labels,
                gain_fn=self._gain_fn,
                rank_discount_fn=self._rank_discount_fn,
                topn=self._topn)
        return gain

    def pair_weights(self, labels, ranks):
        """Return the pair weights; see `LambdaWeight.pair_weights`.

        The weight of a pair is the absolute gain difference of its two
        items (zero when either label is invalid), times the subclass's
        pair rank discount, times the list size.
        """
        check_tensor_shapes([labels, ranks])
        valid_pair, labels = get_valid_pairs_and_clean_labels(labels)

        pair_gain = apply_pairwise_op(self._compute_gain(labels), ops.subtract)
        pair_gain = pair_gain * ops.cast(valid_pair, dtype="float32")

        list_size = ops.shape(labels)[1]
        # Without an explicit cutoff, the whole list counts as the top-n.
        topn = self._topn or list_size
        pair_weight = ops.abs(pair_gain) * self.pair_rank_discount(ranks, topn)
        return pair_weight * ops.cast(list_size, dtype="float32")

    def individual_weights(self, labels, ranks):
        """Return per-item weights: gain times the discount of the item's rank.

        Invalid labels are replaced with zero before the gain is computed.
        """
        check_tensor_shapes([labels, ranks])
        labels = ops.convert_to_tensor(labels)
        labels = ops.where(
            is_label_valid(labels), labels, ops.zeros_like(labels))

        gain = self._compute_gain(labels)
        rank_discount = self._rank_discount_fn(ops.cast(ranks, dtype="float32"))
        return gain * rank_discount


class DCGLambdaWeight(AbstractDCGLambdaWeight):
    """LambdaWeight for the Discounted Cumulative Gain metric."""

    def __init__(self,
                 topn=None,
                 gain_fn=lambda label: label,
                 rank_discount_fn=lambda rank: 1. / rank,
                 normalized=False,
                 smooth_fraction=0.):
        """Initializer.

        Ranks are 1-based, not 0-based. For ranks i and j, two kinds of pair
        discount are blended:
            u = |rank_discount_fn(|i-j|) - rank_discount_fn(|i-j| + 1)|
            v = |rank_discount_fn(i) - rank_discount_fn(j)|
        u is the one introduced in the LambdaLoss paper
        (https://ai.google/research/pubs/pub47258); v is the original one
        from the LambdaMART paper "From RankNet to LambdaRank to LambdaMART:
        An Overview". The rank contribution to the pair weight is
            (1-smooth_fraction) * u + smooth_fraction * v.

        Args:
            topn: (int) The topn for the DCG metric.
            gain_fn: (function) Transforms labels.
            rank_discount_fn: (function) The rank discount function.
            normalized: (bool) If True, normalize weight by the max DCG.
            smooth_fraction: (float) Controls the contribution from
                LambdaMART.

        Raises:
            ValueError: If `smooth_fraction` is outside [0, 1].
        """
        super().__init__(topn, gain_fn, rank_discount_fn, normalized)
        if not 0. <= smooth_fraction <= 1.:
            raise ValueError('smooth_fraction %s should be in range [0, 1].' %
                             smooth_fraction)
        self._smooth_fraction = smooth_fraction

    def pair_rank_discount(self, ranks, topn):
        """Blend the relative-rank and absolute-rank pair discounts.

        Args:
            ranks: Tensor of 1-based ranks.
            topn: Cutoff; a pair contributes only if at least one of its
                items has rank <= topn.

        Returns:
            The pairwise discount tensor, zeroed for pairs with both items
            ranked beyond `topn`.
        """
        # True for pairs in which at least one item sits within the top-n.
        within_topn = ops.less_equal(ranks, topn)
        either_in_topn = apply_pairwise_op(within_topn, ops.logical_or)

        # u: discount driven by the rank gap (LambdaLoss paper). Pairs with
        # a zero gap, or with both items outside the top-n, get zero.
        rank_gap = ops.cast(
            ops.abs(apply_pairwise_op(ranks, ops.subtract)), dtype="float32")
        gap_discount = ops.abs(
            self._rank_discount_fn(rank_gap) -
            self._rank_discount_fn(rank_gap + 1))
        relative_discount = ops.where(
            ops.logical_and(ops.greater(rank_gap, 0), either_in_topn),
            gap_discount,
            ops.zeros_like(rank_gap))

        # v: discount from the absolute ranks (LambdaMART paper). An item
        # ranked beyond topn contributes a discount of zero.
        float_ranks = ops.cast(ranks, dtype="float32")
        item_discount = ops.where(
            ops.greater(ranks, topn),
            ops.zeros_like(float_ranks),
            self._rank_discount_fn(float_ranks))
        absolute_discount = ops.abs(
            apply_pairwise_op(item_discount, ops.subtract))

        blended = ((1. - self._smooth_fraction) * relative_discount
                   + self._smooth_fraction * absolute_discount)
        return blended * ops.cast(either_in_topn, dtype="float32")


class ListMLELambdaWeight(LambdaWeight):
    """Lambda weights for the ListMLE (List Maximum Likelihood Estimation) loss.

    ListMLE optimizes the probability of generating the correct ranking
    order. The weights here are position-based discounts, so that top
    positions are emphasized more than lower ones.
    """

    def __init__(self, rank_discount_fn: Optional[Callable] = None):
        """Initialize ListMLE lambda weights.

        Args:
            rank_discount_fn: Function that takes ranks and returns discount
                weights. When not given, `log2_inverse` is used
                (1 / log2(rank + 1)).
        """
        self.rank_discount_fn = (
            rank_discount_fn if rank_discount_fn else log2_inverse
        )

    def _validate_inputs(self, labels, ranks):
        """Convert both inputs to tensors and check shape and rank base.

        Returns:
            The converted `(labels, ranks)` pair.

        Raises:
            ValueError: If the shapes differ or the smallest rank is below 1.
        """
        labels = ops.convert_to_tensor(labels)
        ranks = ops.convert_to_tensor(ranks)

        if labels.shape != ranks.shape:
            raise ValueError(f"Labels shape {labels.shape} must match ranks shape {ranks.shape}")

        # Ranks are expected to be 1-based, so the minimum must be >= 1.
        # NOTE(review): this branches in Python on a tensor value; confirm
        # it is only reached with concrete (eager) tensors.
        min_rank = ops.min(ranks)
        if min_rank < 1:
            raise ValueError(f"Ranks must be 1-based (minimum value is {min_rank})")

        return labels, ranks

    def pair_weights(self, labels, ranks):
        """Return None: ListMLE is listwise and defines no pairwise weights."""
        return None

    def individual_weights(self, labels, ranks):
        """Compute the per-item weights for the ListMLE loss.

        Every item gets a base weight of one, scaled by the discount of its
        rank, so the result depends on `ranks` only.

        Args:
            labels: Tensor [batch_size, list_size] with relevance labels.
            ranks: Tensor [batch_size, list_size] with current ranks
                (1-based).

        Returns:
            Tensor [batch_size, list_size] with position discount weights.
        """
        labels, ranks = self._validate_inputs(labels, ranks)

        discount = self.rank_discount_fn(ops.cast(ranks, dtype="float32"))
        uniform_weights = ops.ones_like(labels, dtype="float32")
        return uniform_weights * discount

0 commit comments

Comments
 (0)