@@ -102,38 +102,42 @@ def combine_pipelines(
102102
103103
def reciprocal_rank_stage(
    score_field: str, penalty: float = 0, weight: float = 1, **kwargs: Any
) -> list[dict[str, Any]]:
    """Build the pipeline stages that add a Weighted Reciprocal Rank Fusion score.

    The incoming documents are collected into a single array, unwound again
    so that each one carries its array index as ``rank``, and then scored as::

        weight / (rank + penalty + 1)

    ``rank`` is the zero-based array index produced by ``$unwind``; the ``+ 1``
    keeps the denominator non-zero for the first document when ``penalty``
    is 0. Ranking follows the order in which documents arrive from the
    preceding stage.

    Args:
        score_field: Name of the field, written onto each document, that
            receives the score. Should be unique per search being ranked.
        penalty: A non-negative float added to the denominator
            (e.g., 60 for RRF-60).
        weight: A float multiplier for this source's importance. The default
            of 1 yields the unweighted reciprocal-rank score.
        **kwargs: Ignored; accepted so callers can pass extra arguments through.

    Returns:
        A list of four aggregation stages (``$group``, ``$unwind``,
        ``$addFields``, ``$replaceRoot``). After them, each document is back
        at the root with ``score_field`` and ``rank`` added.
    """
    # NOTE(review): the $group below gathers every incoming document into one
    # array; presumably the preceding stage limits the result set - confirm.
    reciprocal_rank = {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]}

    return [
        {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}},
        {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}},
        {
            "$addFields": {
                f"docs.{score_field}": {"$multiply": [weight, reciprocal_rank]},
                "docs.rank": "$rank",
                "_id": "$docs._id",
            }
        },
        # Promote the annotated sub-document back to the top level.
        {"$replaceRoot": {"newRoot": "$docs"}},
    ]
136-
137141
138142def final_hybrid_stage (scores_fields : list [str ], limit : int , ** kwargs : Any ) -> list [dict [str , Any ]]:
139143 """Sum weighted scores, sort, and apply limit.
0 commit comments