Skip to content

Commit 3013982

Browse files
Update keras_rs/src/losses/lambda_weights.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 317b012 commit 3013982

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

keras_rs/src/losses/lambda_weights.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,10 @@ def _validate_inputs(self, labels, ranks):
351351
def pair_weights(self, labels, ranks):
352352
"""
353353
ListMLE doesn't use pairwise weights as it's a listwise method.
354-
Returns None to indicate this method is not applicable.
354+
Returns a tensor of zeros to indicate no pairwise contribution.
355355
"""
356-
return None
356+
shape = ops.shape(labels)
357+
return ops.zeros((shape[0], shape[1], shape[1]), dtype="float32")
357358

358359
def individual_weights(self, labels, ranks):
359360
"""

0 commit comments

Comments
 (0)