Skip to content

Commit 7dde566

Browse files
authored
INTPYTHON-752 Update reciprocal_rank_stage (#13)
1 parent c53c0ab commit 7dde566

File tree

2 files changed

+26
-16
lines changed

2 files changed

+26
-16
lines changed

pymongo_search_utils/pipeline.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -102,38 +102,42 @@ def combine_pipelines(
102102

103103

104104
def reciprocal_rank_stage(
105-
score_field: str, penalty: float = 0, **kwargs: Any
105+
score_field: str, penalty: float = 0, weight: float = 1, **kwargs: Any
106106
) -> list[dict[str, Any]]:
107-
"""Stage adds Reciprocal Rank Fusion weighting.
107+
"""
108+
Stage adds Weighted Reciprocal Rank Fusion (WRRF) scoring.
108109
109-
First, it pushes documents retrieved from previous stage
110-
into a temporary sub-document. It then unwinds to establish
111-
the rank to each and applies the penalty.
110+
First, it groups documents into an array, assigns rank by array index,
111+
and then computes a weighted RRF score.
112112
113113
Args:
114-
score_field: A unique string to identify the search being ranked
115-
penalty: A non-negative float.
116-
extra_fields: Any fields other than text_field that one wishes to keep.
114+
score_field: A unique string to identify the search being ranked.
115+
penalty: A non-negative float (e.g., 60 for RRF-60). Controls the denominator.
116+
weight: A float multiplier for this source's importance.
117+
**kwargs: Ignored; allows future extensions or passthrough args.
117118
118119
Returns:
119-
RRF score
120+
Aggregation pipeline stage for weighted RRF scoring.
120121
"""
121122

122-
rrf_pipeline = [
123+
return [
123124
{"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
124125
{"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
125126
{
126127
"$addFields": {
127-
f"docs.{score_field}": {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]},
128+
f"docs.{score_field}": {
129+
"$multiply": [
130+
weight,
131+
{"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]},
132+
]
133+
},
128134
"docs.rank": "$rank",
129135
"_id": "$docs._id",
130136
}
131137
},
132138
{"$replaceRoot": {"newRoot": "$docs"}},
133139
]
134140

135-
return rrf_pipeline # type: ignore[return-value]
136-
137141

138142
def final_hybrid_stage(scores_fields: list[str], limit: int, **kwargs: Any) -> list[dict[str, Any]]:
139143
"""Sum weighted scores, sort, and apply limit.

tests/test_pipeline.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,9 @@ def test_basic_reciprocal_rank(self):
196196
{"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
197197
{
198198
"$addFields": {
199-
"docs.text_score": {"$divide": [1.0, {"$add": ["$rank", 0, 1]}]},
199+
"docs.text_score": {
200+
"$multiply": [1, {"$divide": [1.0, {"$add": ["$rank", 0, 1]}]}]
201+
},
200202
"docs.rank": "$rank",
201203
"_id": "$docs._id",
202204
}
@@ -210,7 +212,7 @@ def test_reciprocal_rank_with_penalty(self):
210212
result = reciprocal_rank_stage(score_field="vector_score", penalty=60)
211213

212214
add_fields_stage = result[2]["$addFields"]
213-
divide_expr = add_fields_stage["docs.vector_score"]["$divide"]
215+
divide_expr = add_fields_stage["docs.vector_score"]["$multiply"][1]["$divide"]
214216
add_expr = divide_expr[1]["$add"]
215217

216218
assert add_expr == ["$rank", 60, 1]
@@ -225,7 +227,11 @@ def test_reciprocal_rank_with_kwargs(self):
225227
result = reciprocal_rank_stage(score_field="test_score", penalty=10, extra_param="ignored")
226228

227229
assert len(result) == 4
228-
assert result[2]["$addFields"]["docs.test_score"]["$divide"][1]["$add"] == ["$rank", 10, 1]
230+
assert result[2]["$addFields"]["docs.test_score"]["$multiply"][1]["$divide"][1]["$add"] == [
231+
"$rank",
232+
10,
233+
1,
234+
]
229235

230236

231237
class TestFinalHybridStage:

0 commit comments

Comments
 (0)