|
3 | 3 | from django.conf import settings
|
4 | 4 | from django.db import NotSupportedError
|
5 | 5 | from django.db.models import DateField, DateTimeField, Expression, FloatField, TimeField
|
6 |
| -from django.db.models.expressions import F, Func, Value |
| 6 | +from django.db.models.expressions import Func |
7 | 7 | from django.db.models.functions import JSONArray
|
8 | 8 | from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf
|
9 | 9 | from django.db.models.functions.datetime import (
|
|
38 | 38 | Trim,
|
39 | 39 | Upper,
|
40 | 40 | )
|
41 |
| -from django.utils.deconstruct import deconstructible |
42 | 41 |
|
43 |
| -from .query_utils import process_lhs, process_rhs |
| 42 | +from .query_utils import process_lhs |
44 | 43 |
|
45 | 44 | MONGO_OPERATORS = {
|
46 | 45 | Ceil: "ceil",
|
@@ -269,28 +268,160 @@ def trunc_time(self, compiler, connection):
|
269 | 268 | }
|
270 | 269 |
|
271 | 270 |
|
272 |
| -@deconstructible(path="django_mongodb_backend.functions.SearchScore") |
273 |
| -class SearchScore(Expression): |
274 |
| - def __init__(self, path, value, operation="equals", **kwargs): |
275 |
| - self.extra_params = kwargs |
276 |
| - self.lhs = path if hasattr(path, "resolve_expression") else F(path) |
277 |
| - if not isinstance(value, str): |
278 |
| - # TODO HANDLE VALUES LIKE Value("some string") |
279 |
| - raise ValueError("STRING NEEDED") |
280 |
| - self.rhs = Value(value) |
281 |
| - self.operation = operation |
| 271 | +class SearchExpression(Expression): |
| 272 | + optional_arguments = [] |
| 273 | + |
| 274 | + def __init__(self, *args, score=None, **kwargs): |
| 275 | + self.score = score |
| 276 | + # Support positional arguments first |
| 277 | + if args and len(args) > len(self.expected_arguments) + len(self.optional_arguments): |
| 278 | + raise ValueError( |
| 279 | + f"Too many positional arguments: expected {len(self.expected_arguments)}" |
| 280 | + ) |
| 281 | + # TODO: REFACTOR. |
| 282 | + for arg_name, arg_type in self.expected_arguments: |
| 283 | + if args: |
| 284 | + value = args.pop(0) |
| 285 | + if arg_name in kwargs: |
| 286 | + raise ValueError( |
| 287 | + f"Argument '{arg_name}' was provided both positionally and as keyword" |
| 288 | + ) |
| 289 | + elif arg_name in kwargs: |
| 290 | + value = kwargs.pop(arg_name) |
| 291 | + else: |
| 292 | + raise ValueError(f"Missing required argument '{arg_name}'") |
| 293 | + if not isinstance(value, arg_type): |
| 294 | + raise ValueError(f"Argument '{arg_name}' must be of type {arg_type.__name__}") |
| 295 | + setattr(self, arg_name, value) |
| 296 | + |
| 297 | + for arg_name, arg_type in self.optional_arguments: |
| 298 | + if args: |
| 299 | + if arg_name in kwargs: |
| 300 | + raise ValueError( |
| 301 | + f"Argument '{arg_name}' was provided both positionally and as keyword" |
| 302 | + ) |
| 303 | + value = args.pop(0) |
| 304 | + elif arg_name in kwargs: |
| 305 | + value = kwargs.pop(arg_name) |
| 306 | + else: |
| 307 | + value = None |
| 308 | + if value is not None and not isinstance(value, arg_type): |
| 309 | + raise ValueError(f"Argument '{arg_name}' must be of type {arg_type.__name__}") |
| 310 | + setattr(self, arg_name, value) |
| 311 | + if kwargs: |
| 312 | + raise ValueError(f"Unexpected keyword arguments: {list(kwargs.keys())}") |
282 | 313 | super().__init__(output_field=FloatField())
|
283 | 314 |
|
| 315 | + def get_source_expressions(self): |
| 316 | + return [] |
| 317 | + |
| 318 | + def __str__(self): |
| 319 | + args = ", ".join(map(str, self.get_source_expressions())) |
| 320 | + return f"{self.search_type}({args})" |
| 321 | + |
284 | 322 | def __repr__(self):
|
285 |
| - return f"search {self.field} = {self.value} | {self.extra_params}" |
| 323 | + return str(self) |
| 324 | + |
| 325 | + def as_sql(self, compiler, connection): |
| 326 | + return "", [] |
| 327 | + |
| 328 | + def _get_query_index(self, field, compiler): |
| 329 | + for search_indexes in compiler.collection.list_search_indexes(): |
| 330 | + mappings = search_indexes["latestDefinition"]["mappings"] |
| 331 | + if mappings["dynamic"] or field in mappings["fields"]: |
| 332 | + return search_indexes["name"] |
| 333 | + return "default" |
286 | 334 |
|
287 | 335 | def as_mql(self, compiler, connection):
|
288 |
| - lhs = process_lhs(self, compiler, connection) |
289 |
| - rhs = process_rhs(self, compiler, connection) |
290 |
| - return {"$search": {self.operation: {"path": lhs[:1], "query": rhs, **self.extra_params}}} |
| 336 | + params = {} |
| 337 | + for arg_name, _ in self.expected_arguments: |
| 338 | + params[arg_name] = getattr(self, arg_name) |
| 339 | + if self.score: |
| 340 | + params["score"] = self.score.as_mql(compiler, connection) |
| 341 | + index = self._get_query_index(params.get("path"), compiler) |
| 342 | + return {"$search": {self.search_type: params, "index": index}} |
| 343 | + |
| 344 | + |
| 345 | +class SearchAutocomplete(SearchExpression): |
| 346 | + search_type = "autocomplete" |
| 347 | + expected_arguments = [("path", str), ("query", str)] |
| 348 | + |
| 349 | + |
| 350 | +class SearchEquals(SearchExpression): |
| 351 | + search_type = "equals" |
| 352 | + expected_arguments = [("path", str), ("value", str)] |
| 353 | + |
| 354 | + |
| 355 | +class SearchExists(SearchExpression): |
| 356 | + search_type = "equals" |
| 357 | + expected_arguments = [("path", str)] |
| 358 | + |
| 359 | + |
| 360 | +class SearchIn(SearchExpression): |
| 361 | + search_type = "equals" |
| 362 | + expected_arguments = [("path", str), ("value", str | list)] |
| 363 | + |
| 364 | + |
| 365 | +class SearchPhrase(SearchExpression): |
| 366 | + search_type = "equals" |
| 367 | + expected_arguments = [("path", str), ("value", str | list)] |
| 368 | + optional_arguments = [("slop", int), ("synonyms", str)] |
| 369 | + |
| 370 | + |
| 371 | +""" |
| 372 | +IT IS BEING REFACTORED |
| 373 | +class SearchOperator(SearchExpression): |
| 374 | + _operation_params = { |
| 375 | + "autocomplete": ("path", {"query"}), |
| 376 | + "equals": ("path", {"value"}), |
| 377 | + "exists": ("path", {}), |
| 378 | + "in": ("path", {"value"}), |
| 379 | + "phrase": ("path", {"query"}), |
| 380 | + "queryString": ("defaultPath", {"query"}), |
| 381 | + "range": ("path", {("lt", "lte"), ("gt", "gte")}), |
| 382 | + "regex": ("path", {"query"}), |
| 383 | + "text": ("path", {"query"}), |
| 384 | + "wildcard": ("path", {"query"}), |
| 385 | + "geoShape": ("path", {"query", "relation", "geometry"}), |
| 386 | + "geoWithin": ("path", {("box", "circle", "geometry")}), |
| 387 | + "moreLikeThis": (None, {"like"}), |
| 388 | + "near": ("path", {"origin", "pivot"}), |
| 389 | + } |
| 390 | +
|
| 391 | + def __init__(self, operation, **kwargs): |
| 392 | + self.lhs = path if path is None or hasattr(path, "resolve_expression") else F(path) |
| 393 | + self.operation = operation |
| 394 | + self.lhs_field, needed_params = self._operation_params[self.operation] |
| 395 | + rhs_values = {} |
| 396 | + for param in needed_params: |
| 397 | + if isinstance(param, str): |
| 398 | + rhs_values[param] = kwargs.pop(param) |
| 399 | + else: |
| 400 | + for key in param: |
| 401 | + if key in kwargs: |
| 402 | + rhs_values[param] = kwargs.pop(key) |
| 403 | + break |
| 404 | + else: |
| 405 | + raise ValueError(f"Not found either {', '.join(param)}") |
| 406 | +
|
| 407 | + self.rhs_values = rhs_values |
| 408 | + self.extra_params = kwargs |
| 409 | + super().__init__(output_field=FloatField()) |
| 410 | +
|
| 411 | + def as_mql(self, compiler, connection): |
| 412 | + params = {**self.rhs_values, **self.extra_params} |
| 413 | + if self.lhs: |
| 414 | + lhs_mql = process_lhs(self, compiler, connection) |
| 415 | + params[self.lhs_field] = lhs_mql[1:] |
| 416 | + index = self._get_query_index(compiler, connection) |
| 417 | + return {"$search": {self.operation: params, "index": index}} |
| 418 | +
|
| 419 | + def get_source_expressions(self): |
| 420 | + return [self.lhs, self.rhs, self.extra_params] |
291 | 421 |
|
292 | 422 | def as_sql(self, compiler, connection):
|
293 | 423 | return "", []
|
| 424 | +""" |
294 | 425 |
|
295 | 426 |
|
296 | 427 | def register_functions():
|
|
0 commit comments