Skip to content

Commit 088af0f

Browse files
committed
Add SearchExpressions
1 parent 3b90847 commit 088af0f

File tree

4 files changed

+180
-30
lines changed

4 files changed

+180
-30
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,4 @@ repos:
8181
rev: "v2.2.6"
8282
hooks:
8383
- id: codespell
84-
args: ["-L", "nin"]
84+
# codespell's -L/--ignore-words-list takes ONE comma-separated, lowercase word list;
# any further argv entries are interpreted as paths to check, not as ignored words.
args: ["-L", "nin,searchin"]

django_mongodb_backend/compiler.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from django.utils.functional import cached_property
1818
from pymongo import ASCENDING, DESCENDING
1919

20-
from .functions import SearchScore
20+
from .functions import SearchExpression
2121
from .query import MongoQuery, wrap_database_errors
2222

2323

@@ -115,31 +115,31 @@ def _prepare_search_expressions_for_pipeline(self, expression, target, search_id
115115
for sub_expr in self._get_search_expressions(expression):
116116
alias = f"__search_expr.search{next(search_idx)}"
117117
replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias)
118-
return replacements, searches
118+
return replacements, list(searches.values())
119119

120120
def _prepare_search_query_for_aggregation_pipeline(self, order_by):
121121
replacements = {}
122-
searches = {}
123-
search_idx = itertools.count(start=1)
122+
searches = []
123+
annotation_group_idx = itertools.count(start=1)
124124
for target, expr in self.query.annotation_select.items():
125125
new_replacements, expr_searches = self._prepare_search_expressions_for_pipeline(
126-
expr, target, search_idx
126+
expr, target, annotation_group_idx
127127
)
128128
replacements.update(new_replacements)
129-
searches.update(expr_searches)
129+
searches += expr_searches
130130

131131
for expr, _ in order_by:
132132
new_replacements, expr_searches = self._prepare_search_expressions_for_pipeline(
133-
expr, None, search_idx
133+
expr, None, annotation_group_idx
134134
)
135135
replacements.update(new_replacements)
136-
searches.update(expr_searches)
136+
searches += expr_searches
137137

138138
having_replacements, having_group = self._prepare_search_expressions_for_pipeline(
139-
self.having, None, search_idx
139+
self.having, None, annotation_group_idx
140140
)
141141
replacements.update(having_replacements)
142-
searches.update(having_group)
142+
searches += having_group
143143
return searches, replacements
144144

145145
def _prepare_annotations_for_aggregation_pipeline(self, order_by):
@@ -249,14 +249,21 @@ def _build_aggregation_pipeline(self, ids, group):
249249
pipeline.append({"$unset": "_id"})
250250
return pipeline
251251

252+
def _compound_searches_queries(self, searches):
253+
if not searches:
254+
return []
255+
if len(searches) > 1:
256+
raise ValueError("Cannot perform more than one search operation.")
257+
return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": "searchScore"}}}]
258+
252259
def pre_sql_setup(self, with_col_aliases=False):
253260
extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases)
254261
searches, search_replacements = self._prepare_search_query_for_aggregation_pipeline(
255262
order_by
256263
)
257264
group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by)
258265
all_replacements = {**search_replacements, **group_replacements}
259-
self.search_pipeline = searches
266+
self.search_pipeline = self._compound_searches_queries(searches)
260267
# query.group_by is either:
261268
# - None: no GROUP BY
262269
# - True: group by select fields
@@ -607,7 +614,7 @@ def _get_aggregate_expressions(self, expr):
607614
return self._get_all_expressions_of_type(expr, Aggregate)
608615

609616
def _get_search_expressions(self, expr):
610-
return self._get_all_expressions_of_type(expr, SearchScore)
617+
return self._get_all_expressions_of_type(expr, SearchExpression)
611618

612619
def _get_all_expressions_of_type(self, expr, target_type):
613620
stack = [expr]

django_mongodb_backend/functions.py

Lines changed: 148 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from django.conf import settings
44
from django.db import NotSupportedError
55
from django.db.models import DateField, DateTimeField, Expression, FloatField, TimeField
6-
from django.db.models.expressions import F, Func, Value
6+
from django.db.models.expressions import Func
77
from django.db.models.functions import JSONArray
88
from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf
99
from django.db.models.functions.datetime import (
@@ -38,9 +38,8 @@
3838
Trim,
3939
Upper,
4040
)
41-
from django.utils.deconstruct import deconstructible
4241

43-
from .query_utils import process_lhs, process_rhs
42+
from .query_utils import process_lhs
4443

4544
MONGO_OPERATORS = {
4645
Ceil: "ceil",
@@ -269,28 +268,160 @@ def trunc_time(self, compiler, connection):
269268
}
270269

271270

272-
@deconstructible(path="django_mongodb_backend.functions.SearchScore")
273-
class SearchScore(Expression):
274-
def __init__(self, path, value, operation="equals", **kwargs):
275-
self.extra_params = kwargs
276-
self.lhs = path if hasattr(path, "resolve_expression") else F(path)
277-
if not isinstance(value, str):
278-
# TODO HANDLE VALUES LIKE Value("some string")
279-
raise ValueError("STRING NEEDED")
280-
self.rhs = Value(value)
281-
self.operation = operation
271+
class SearchExpression(Expression):
    """Base class for expressions that compile to a ``$search`` pipeline stage.

    Subclasses declare:

    * ``search_type``: the key placed inside ``$search`` by ``as_mql()``.
    * ``expected_arguments``: ``(name, type)`` pairs for required arguments.
    * ``optional_arguments``: ``(name, type)`` pairs for arguments that default
      to ``None``.

    Arguments may be passed positionally (required ones first, then optional
    ones, in declaration order) or by keyword. Every declared argument is
    stored as an instance attribute of the same name.
    """

    search_type = None
    # Immutable defaults so that a subclass that forgets to declare them still
    # works and so that no list is shared between classes.
    expected_arguments = ()
    optional_arguments = ()

    def __init__(self, *args, score=None, **kwargs):
        """Validate and store the declared arguments.

        Raises:
            ValueError: on too many positional arguments, an argument given both
                positionally and by keyword, a missing required argument, a
                value of the wrong type, or an unknown keyword argument.
        """
        self.score = score
        # ``*args`` is a tuple; copy it into a list so values can be consumed.
        positional = list(args)
        max_positional = len(self.expected_arguments) + len(self.optional_arguments)
        if len(positional) > max_positional:
            raise ValueError(
                f"Too many positional arguments: expected at most {max_positional}"
            )
        for arg_name, arg_type in self.expected_arguments:
            value = self._take_argument(arg_name, positional, kwargs, required=True)
            self._check_type(arg_name, value, arg_type)
            setattr(self, arg_name, value)
        for arg_name, arg_type in self.optional_arguments:
            value = self._take_argument(arg_name, positional, kwargs, required=False)
            if value is not None:
                self._check_type(arg_name, value, arg_type)
            setattr(self, arg_name, value)
        if kwargs:
            raise ValueError(f"Unexpected keyword arguments: {list(kwargs.keys())}")
        super().__init__(output_field=FloatField())

    @staticmethod
    def _take_argument(arg_name, positional, keywords, *, required):
        """Pop the value for ``arg_name`` from ``positional`` or ``keywords``.

        Positional values win; supplying the same argument both ways is an
        error. Returns ``None`` for an absent optional argument.
        """
        if positional:
            if arg_name in keywords:
                raise ValueError(
                    f"Argument '{arg_name}' was provided both positionally and as keyword"
                )
            return positional.pop(0)
        if arg_name in keywords:
            return keywords.pop(arg_name)
        if required:
            raise ValueError(f"Missing required argument '{arg_name}'")
        return None

    @staticmethod
    def _check_type(arg_name, value, arg_type):
        """Raise ``ValueError`` unless ``value`` is an instance of ``arg_type``."""
        if not isinstance(value, arg_type):
            # Union types such as ``str | list`` may not expose ``__name__``;
            # fall back to their string form instead of failing while building
            # the error message.
            type_name = getattr(arg_type, "__name__", str(arg_type))
            raise ValueError(f"Argument '{arg_name}' must be of type {type_name}")

    def get_source_expressions(self):
        return []

    def __str__(self):
        args = ", ".join(map(str, self.get_source_expressions()))
        return f"{self.search_type}({args})"

    def __repr__(self):
        return str(self)

    def as_sql(self, compiler, connection):
        return "", []

    def _get_query_index(self, field, compiler):
        """Return the name of the first search index that covers ``field``.

        An index covers the field when its mappings are dynamic or list the
        field explicitly. Falls back to ``"default"`` when none matches.
        """
        for search_index in compiler.collection.list_search_indexes():
            # Use .get() throughout: not every listed index definition is
            # guaranteed to carry "mappings", "dynamic" or "fields".
            mappings = search_index.get("latestDefinition", {}).get("mappings", {})
            if mappings.get("dynamic") or field in mappings.get("fields", {}):
                return search_index["name"]
        return "default"

    def as_mql(self, compiler, connection):
        """Build the ``{"$search": {...}}`` stage for this expression."""
        params = {arg_name: getattr(self, arg_name) for arg_name, _ in self.expected_arguments}
        # Optional arguments were validated in __init__; forward the ones that
        # were actually supplied instead of silently dropping them.
        for arg_name, _ in self.optional_arguments:
            value = getattr(self, arg_name)
            if value is not None:
                params[arg_name] = value
        if self.score:
            params["score"] = self.score.as_mql(compiler, connection)
        index = self._get_query_index(params.get("path"), compiler)
        return {"$search": {self.search_type: params, "index": index}}
343+
344+
345+
class SearchAutocomplete(SearchExpression):
346+
search_type = "autocomplete"
347+
expected_arguments = [("path", str), ("query", str)]
348+
349+
350+
class SearchEquals(SearchExpression):
351+
search_type = "equals"
352+
expected_arguments = [("path", str), ("value", str)]
353+
354+
355+
class SearchExists(SearchExpression):
    """Search expression for the ``exists`` operator: takes only a ``path``."""

    # Was "equals" (copy-paste slip), which emitted an equals query with no value.
    search_type = "exists"
    expected_arguments = [("path", str)]
358+
359+
360+
class SearchIn(SearchExpression):
    """Search expression for the ``in`` operator: a ``path`` and one or more values."""

    # Was "equals" (copy-paste slip); a list value is only meaningful for "in".
    search_type = "in"
    expected_arguments = [("path", str), ("value", str | list)]
363+
364+
365+
class SearchPhrase(SearchExpression):
    """Search expression for the ``phrase`` operator.

    Takes a ``path`` and the phrase text (``value``), plus optional ``slop``
    and ``synonyms``.
    """

    # Was "equals" (copy-paste slip).
    search_type = "phrase"
    expected_arguments = [("path", str), ("value", str | list)]
    optional_arguments = [("slop", int), ("synonyms", str)]

    def as_mql(self, compiler, connection):
        """Build the ``$search`` stage, using the field names ``phrase`` expects."""
        mql = super().as_mql(compiler, connection)
        params = mql["$search"][self.search_type]
        # The phrase operator names its text field "query"; the Python argument
        # stays "value" so existing keyword callers keep working.
        params["query"] = params.pop("value")
        # Forward supplied optional arguments (idempotent if already present).
        for arg_name, _ in self.optional_arguments:
            value = getattr(self, arg_name, None)
            if value is not None:
                params[arg_name] = value
        return mql
369+
370+
371+
"""
372+
IT IS BEING REFACTORED
373+
class SearchOperator(SearchExpression):
374+
_operation_params = {
375+
"autocomplete": ("path", {"query"}),
376+
"equals": ("path", {"value"}),
377+
"exists": ("path", {}),
378+
"in": ("path", {"value"}),
379+
"phrase": ("path", {"query"}),
380+
"queryString": ("defaultPath", {"query"}),
381+
"range": ("path", {("lt", "lte"), ("gt", "gte")}),
382+
"regex": ("path", {"query"}),
383+
"text": ("path", {"query"}),
384+
"wildcard": ("path", {"query"}),
385+
"geoShape": ("path", {"query", "relation", "geometry"}),
386+
"geoWithin": ("path", {("box", "circle", "geometry")}),
387+
"moreLikeThis": (None, {"like"}),
388+
"near": ("path", {"origin", "pivot"}),
389+
}
390+
391+
def __init__(self, operation, **kwargs):
392+
self.lhs = path if path is None or hasattr(path, "resolve_expression") else F(path)
393+
self.operation = operation
394+
self.lhs_field, needed_params = self._operation_params[self.operation]
395+
rhs_values = {}
396+
for param in needed_params:
397+
if isinstance(param, str):
398+
rhs_values[param] = kwargs.pop(param)
399+
else:
400+
for key in param:
401+
if key in kwargs:
402+
rhs_values[param] = kwargs.pop(key)
403+
break
404+
else:
405+
raise ValueError(f"Not found either {', '.join(param)}")
406+
407+
self.rhs_values = rhs_values
408+
self.extra_params = kwargs
409+
super().__init__(output_field=FloatField())
410+
411+
def as_mql(self, compiler, connection):
412+
params = {**self.rhs_values, **self.extra_params}
413+
if self.lhs:
414+
lhs_mql = process_lhs(self, compiler, connection)
415+
params[self.lhs_field] = lhs_mql[1:]
416+
index = self._get_query_index(compiler, connection)
417+
return {"$search": {self.operation: params, "index": index}}
418+
419+
def get_source_expressions(self):
420+
return [self.lhs, self.rhs, self.extra_params]
291421
292422
def as_sql(self, compiler, connection):
293423
return "", []
424+
"""
294425

295426

296427
def register_functions():

tests/queries_/test_search.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from django.test import TestCase
2+
3+
from .models import Article
4+
5+
from django_mongodb_backend.functions import SearchEquals
6+
7+
8+
class SearchTests(TestCase):
9+
def test_1(self):
10+
Article.objects.create(headline="cross", number=1, body="body")
11+
aa = Article.objects.annotate(score=SearchEquals(path="headline", value="cross")).all()
12+
print(aa)

0 commit comments

Comments
 (0)