From 4269038f8475da17afbc4f28f1ea929697f6d2a8 Mon Sep 17 00:00:00 2001 From: Josey Date: Wed, 15 Oct 2025 12:25:05 -0400 Subject: [PATCH 1/2] Add `If-Match` header to sample patch endpoints * initial commit * don't send etag back on 412 easy way to cheat the system * remove unused * refactor keeping match closer to db * address PR comments and cleanup save code and refactor tests to account for if-match * check for Etag on success * update last_modified at on if match save * tweaks * use ETag.create --- fiftyone/server/decorators.py | 21 +- fiftyone/server/routes/sample.py | 238 +++++- fiftyone/server/utils/__init__.py | 4 +- fiftyone/server/utils/http.py | 41 + fiftyone/server/utils/json/__init__.py | 32 + fiftyone/server/utils/json/encoder.py | 45 + .../utils/{ => json}/jsonpatch/__init__.py | 7 +- .../utils/{ => json}/jsonpatch/methods.py | 0 .../utils/{ => json}/jsonpatch/patch.py | 16 +- fiftyone/server/utils/json/serialization.py | 70 ++ .../server/utils/json_transform/__init__.py | 10 - .../server/utils/json_transform/transform.py | 66 -- fiftyone/server/utils/json_transform/types.py | 38 - tests/unittests/sample_route_tests.py | 771 +++++++++++------- .../utils/jsonpatch/test_json_patch_patch.py | 2 +- .../server/utils/jsonpatch/test_jsonpatch.py | 2 +- .../utils/jsonpatch/test_jsonpatch_methods.py | 4 +- 17 files changed, 877 insertions(+), 490 deletions(-) create mode 100644 fiftyone/server/utils/http.py create mode 100644 fiftyone/server/utils/json/__init__.py create mode 100644 fiftyone/server/utils/json/encoder.py rename fiftyone/server/utils/{ => json}/jsonpatch/__init__.py (94%) rename fiftyone/server/utils/{ => json}/jsonpatch/methods.py (100%) rename fiftyone/server/utils/{ => json}/jsonpatch/patch.py (89%) create mode 100644 fiftyone/server/utils/json/serialization.py delete mode 100644 fiftyone/server/utils/json_transform/__init__.py delete mode 100644 fiftyone/server/utils/json_transform/transform.py delete mode 100644 fiftyone/server/utils/json_transform/types.py diff --git a/fiftyone/server/decorators.py b/fiftyone/server/decorators.py index f85a1a9781..7a836b13ad 100644 --- a/fiftyone/server/decorators.py +++ b/fiftyone/server/decorators.py @@ -6,38 +6,23 @@ | """ -from json import JSONEncoder import traceback import typing as t import logging -from bson import json_util -import numpy as np from starlette.endpoints import HTTPEndpoint from starlette.exceptions import HTTPException from starlette.responses import JSONResponse, Response from starlette.requests import Request from fiftyone.core.utils import run_sync_task - - -class Encoder(JSONEncoder): - """Custom JSON encoder that handles numpy types.""" - - def default(self, o): - if isinstance(o, np.floating): - return float(o) - - if isinstance(o, np.integer): - return int(o) - - return JSONEncoder.default(self, o) +from fiftyone.server import utils async def create_response(response: dict): """Creates a JSON response from the given dictionary.""" return Response( - await run_sync_task(lambda: json_util.dumps(response, cls=Encoder)), + await run_sync_task(lambda: utils.json.dumps(response)), headers={"Content-Type": "application/json"}, ) @@ -52,7 +37,7 @@ async def wrapper( try: body = await request.body() payload = body.decode("utf-8") - data = json_util.loads(payload) if payload else {} + data = utils.json.loads(payload) response = await func(endpoint, request, data, *args) if isinstance(response, Response): return response diff --git a/fiftyone/server/routes/sample.py b/fiftyone/server/routes/sample.py index d0d6fde89d..0dac859cc0 100644 --- a/fiftyone/server/routes/sample.py +++ b/fiftyone/server/routes/sample.py @@ -6,7 +6,10 @@ | """ +import base64 +import datetime import logging +from typing import Any, List, Union from starlette.endpoints import HTTPEndpoint from starlette.exceptions import HTTPException @@ -14,56 +17,180 @@ import fiftyone as fo import fiftyone.core.odm.utils as fou -from typing import List -from fiftyone.server.utils.jsonpatch import parse -from fiftyone.server.utils import transform_json -from fiftyone.server.decorators import route -from typing import Any + +from fiftyone.server import decorators, utils logger = logging.getLogger(__name__) -def get_sample(dataset_id: str, sample_id: str) -> fo.Sample: - """Retrieves a sample from a dataset. +def get_if_last_modified_at( + request: Request, +) -> Union[datetime.datetime, None]: + """Parses the If-Match header from the request, if present, and returns + the last modified date. Args: - dataset_id: the dataset ID - sample_id: the sample ID + request: The request + + Raises: + HTTPException: If the If-Match header could not be parsed into a + valid date Returns: - the sample + The last modified date, or None if the header is not present + """ + + if_last_modified_at: Union[str, datetime.datetime, None] = None + if request.headers.get("If-Match"): + if_match, _ = utils.http.ETag.parse(request.headers["If-Match"]) + + # As ETag - Currently this is just a based64 encode string of + # last_modified_at + try: + if_last_modified_at = datetime.datetime.fromisoformat( + base64.b64decode(if_match.encode("utf-8")).decode("utf-8") + ) + except Exception: + ... + + # As ISO date + try: + if_last_modified_at = datetime.datetime.fromisoformat(if_match) + except Exception: + ... + + # As Unix timestamp + try: + if_last_modified_at = datetime.datetime.fromtimestamp( + float(if_match) + ) + except Exception: + ... + + if if_last_modified_at is None: + raise HTTPException( + status_code=400, detail="Invalid If-Match header" + ) + return if_last_modified_at + + +def get_sample( + dataset_id: str, + sample_id: str, + if_last_modified_at: Union[datetime.datetime, None], +) -> fo.Sample: + """Retrieves a sample from a dataset. + + Args: + dataset_id: The ID of the dataset + sample_id: The ID of the sample + if_last_modified_at: The if last modified date, if it exists Raises: - HTTPException: if the dataset or sample is not found + HTTPException: If the dataset or sample is not found or the if last + modified date is present and does not match the sample + + Returns: + The sample """ + try: dataset = fou.load_dataset(id=dataset_id) - except ValueError: + except ValueError as err: raise HTTPException( status_code=404, detail=f"Dataset '{dataset_id}' not found", - ) + ) from err try: sample = dataset[sample_id] - except KeyError: + except KeyError as err: raise HTTPException( status_code=404, detail=f"Sample '{sample_id}' not found in dataset '{dataset_id}'", - ) + ) from err + + # Fail early, if very out-of-date + if if_last_modified_at is not None: + if sample.last_modified_at != if_last_modified_at: + raise HTTPException( + status_code=412, detail="If-Match condition failed" + ) return sample +def generate_sample_etag(sample: fo.Sample) -> str: + """Generates an ETag for a sample based on its last modified date. + + Args: + sample: The sample + Returns: + The ETag + """ + value = base64.b64encode( + sample.last_modified_at.isoformat().encode("utf-8") + ).decode("utf-8") + + return utils.http.ETag.create(value) + + +def save_sample( + sample: fo.Sample, if_last_modified_at: Union[datetime.datetime, None] +) -> str: + """Saves a sample to the database. + + Args: + sample: The sample to save + if_last_modified_at: The if last modified date, if it exists + + Returns: + The ETag of the saved sample + """ + + if if_last_modified_at is not None: + d = sample.to_mongo_dict(include_id=True) + d["last_modified_at"] = datetime.datetime.now(datetime.timezone.utc) + + # pylint:disable-next=protected-access + update_result = sample.dataset._sample_collection.replace_one( + { + # pylint:disable-next=protected-access + "_id": sample._id, + "last_modified_at": {"$eq": if_last_modified_at}, + }, + d, + ) + + if update_result.matched_count == 0: + raise HTTPException( + status_code=412, detail="If-Match condition failed" + ) + else: + sample.save() + + # Ensure last_modified_at reflects persisted state before computing + # ETag + try: + sample.reload(hard=True) + except Exception: + # best-effort; still return response + ... + + return generate_sample_etag(sample) + + def handle_json_patch(target: Any, patch_list: List[dict]) -> Any: """Applies a list of JSON patch operations to a target object.""" try: - patches = parse(patch_list, transform_fn=transform_json) - except Exception as e: + patches = utils.json.parse_jsonpatch( + patch_list, transform_fn=utils.json.deserialize + ) + except Exception as err: raise HTTPException( status_code=400, - detail=f"Failed to parse patches due to: {e}", - ) + detail=f"Failed to parse patches due to: {err}", + ) from err errors = {} for i, p in enumerate(patches): @@ -82,14 +209,17 @@ def handle_json_patch(target: Any, patch_list: List[dict]) -> Any: class Sample(HTTPEndpoint): - @route + """Sample endpoints.""" + + @decorators.route async def patch(self, request: Request, data: dict) -> dict: """Applies a list of field updates to a sample. See: https://datatracker.ietf.org/doc/html/rfc6902 Args: - request: Starlette request with dataset_id and sample_id in path params + request: Starlette request with dataset_id and sample_id in path + params data: A dict mapping field names to values. Returns: @@ -104,21 +234,26 @@ async def patch(self, request: Request, data: dict) -> dict: dataset_id, ) - sample = get_sample(dataset_id, sample_id) + if_last_modified_at = get_if_last_modified_at(request) + + sample = get_sample(dataset_id, sample_id, if_last_modified_at) content_type = request.headers.get("Content-Type", "") ctype = content_type.split(";", 1)[0].strip().lower() if ctype == "application/json": - result = self._handle_patch(sample, data) + self._handle_patch(sample, data) elif ctype == "application/json-patch+json": - result = handle_json_patch(sample, data) + handle_json_patch(sample, data) else: raise HTTPException( - status_code=415, - detail=f"Unsupported Content-Type '{ctype}'", + status_code=415, detail=f"Unsupported Content-Type '{ctype}'" ) - sample.save() - return result.to_dict(include_private=True) + + etag = save_sample(sample, if_last_modified_at) + + return utils.json.JSONResponse( + utils.json.serialize(sample), headers={"ETag": etag} + ) def _handle_patch(self, sample: fo.Sample, data: dict) -> dict: errors = {} @@ -128,27 +263,27 @@ def _handle_patch(self, sample: fo.Sample, data: dict) -> dict: sample.clear_field(field_name) continue - sample[field_name] = transform_json(value) + sample[field_name] = utils.json.deserialize(value) except Exception as e: errors[field_name] = str(e) if errors: - raise HTTPException( - status_code=400, - detail=errors, - ) + raise HTTPException(status_code=400, detail=errors) return sample class SampleField(HTTPEndpoint): - @route + """Sample field endpoints.""" + + @decorators.route async def patch(self, request: Request, data: dict) -> dict: """Applies a list of field updates to a sample field in a list by id. See: https://datatracker.ietf.org/doc/html/rfc6902 Args: - request: Starlette request with dataset_id and sample_id in path params + request: Starlette request with dataset_id and sample_id in path + params data: patch of type op, path, value. Returns: @@ -160,39 +295,52 @@ async def patch(self, request: Request, data: dict) -> dict: field_id = request.path_params["field_id"] logger.info( - "Received patch request for field %s with ID %s on sample %s in dataset %s", + ( + "Received patch request for field %s with ID %s on sample %s " + "in dataset %s" + ), path, field_id, sample_id, dataset_id, ) - sample = get_sample(dataset_id, sample_id) + if_last_modified_at = get_if_last_modified_at(request) + + sample = get_sample(dataset_id, sample_id, if_last_modified_at) try: field_list = sample.get_field(path) - except Exception as e: + except Exception as err: raise HTTPException( status_code=404, detail=f"Field '{path}' not found in sample '{sample_id}'", - ) + ) from err if not isinstance(field_list, list): raise HTTPException( - status_code=400, - detail=f"Field '{path}' is not a list", + status_code=400, detail=f"Field '{path}' is not a list" ) field = next((f for f in field_list if f.id == field_id), None) if field is None: raise HTTPException( status_code=404, - detail=f"Field with id '{field_id}' not found in field '{path}'", + detail=( + f"Field with id '{field_id}' not found in field " + f"'{path}'" + ), ) - result = handle_json_patch(field, data) - sample.save() - return result.to_dict() + handle_json_patch(field, data) + + etag = save_sample(sample, if_last_modified_at) + + updated_field = next((f for f in field_list if f.id == field_id), None) + + return utils.json.JSONResponse( + utils.json.serialize(updated_field), headers={"ETag": etag} + ) SampleRoutes = [ diff --git a/fiftyone/server/utils/__init__.py b/fiftyone/server/utils/__init__.py index d27c4e45ad..67e8ae9992 100644 --- a/fiftyone/server/utils/__init__.py +++ b/fiftyone/server/utils/__init__.py @@ -15,9 +15,7 @@ import fiftyone.core.dataset as fod import fiftyone.core.fields as fof -from fiftyone.server.utils.json_transform import ( - transform as transform_json, -) # auto-register resource types +from fiftyone.server.utils import http, json _cache = cachetools.TTLCache(maxsize=10, ttl=900) # ttl in seconds diff --git a/fiftyone/server/utils/http.py b/fiftyone/server/utils/http.py new file mode 100644 index 0000000000..62f0ca0165 --- /dev/null +++ b/fiftyone/server/utils/http.py @@ -0,0 +1,41 @@ +""" +HTTP utils + +| Copyright 2017-2025, Voxel51, Inc. +| `voxel51.com `_ +| +""" + +from typing import Any + + +class ETag: + """Utility class for creating and parsing ETag strings.""" + + @staticmethod + def create(value: Any, is_weak: bool = False) -> str: + """Creates an ETag string from the given value.""" + # Wrap in quotes if not already quoted + if not (value.startswith('"') and value.endswith('"')): + value = f'"{value}"' + + # Add weak prefix if necessary + if is_weak: + return f'W/"{value}"' + + return value + + @staticmethod + def parse(etag: str) -> tuple[str, bool]: + """Parses an ETag string into its value and whether it is weak.""" + + is_weak = False + if etag.startswith("W/"): + is_weak = True + etag = etag[2:] # Remove "W/" prefix + + # Remove surrounding quotes (ETags are typically quoted) + if etag.startswith('"') and etag.endswith('"'): + etag = etag[1:-1] + + return etag, is_weak diff --git a/fiftyone/server/utils/json/__init__.py b/fiftyone/server/utils/json/__init__.py new file mode 100644 index 0000000000..6a16a2a8e3 --- /dev/null +++ b/fiftyone/server/utils/json/__init__.py @@ -0,0 +1,32 @@ +""" + +| Copyright 2017-2025, Voxel51, Inc. +| `voxel51.com `_ +| +""" + +from typing import Any, Union +from bson import json_util + +from starlette.responses import JSONResponse as StarletteJSONResponse + +from fiftyone.server.utils.json.encoder import Encoder +from fiftyone.server.utils.json.jsonpatch import parse as parse_jsonpatch +from fiftyone.server.utils.json.serialization import deserialize, serialize + + +def dumps(obj: Any) -> str: + """Serializes an object to a JSON-formatted string.""" + return json_util.dumps(obj, cls=Encoder) + + +def loads(s: Union[str, bytes, bytearray, None]) -> Any: + """Deserializes a JSON-formatted string to a Python object.""" + return json_util.loads(s) if s else {} + + +class JSONResponse(StarletteJSONResponse): + """Custom JSON response that uses the custom Encoder.""" + + def render(self, content: Any) -> bytes: + return dumps(content).encode("utf-8") diff --git a/fiftyone/server/utils/json/encoder.py b/fiftyone/server/utils/json/encoder.py new file mode 100644 index 0000000000..ef60fb37c4 --- /dev/null +++ b/fiftyone/server/utils/json/encoder.py @@ -0,0 +1,45 @@ +""" + +| Copyright 2017-2025, Voxel51, Inc. +| `voxel51.com `_ +| +""" + +from typing import Any, Union +import json + +from bson import json_util +import numpy as np +from starlette.responses import JSONResponse as StarletteJSONResponse + + +class Encoder(json.JSONEncoder): + """Custom JSON encoder that handles numpy types.""" + + def default(self, o): + """Override the default method to handle numpy types.""" + + if isinstance(o, np.floating): + return float(o) + + if isinstance(o, np.integer): + return int(o) + + return json.JSONEncoder.default(self, o) + + +def dumps(obj: Any) -> str: + """Serializes an object to a JSON-formatted string.""" + return json_util.dumps(obj, cls=Encoder) + + +def loads(s: Union[str, bytes, bytearray, None]) -> Any: + """Deserializes a JSON-formatted string to a Python object.""" + return json_util.loads(s) if s else {} + + +class JSONResponse(StarletteJSONResponse): + """Custom JSON response that uses the custom Encoder.""" + + def render(self, content: Any) -> bytes: + return dumps(content).encode("utf-8") diff --git a/fiftyone/server/utils/jsonpatch/__init__.py b/fiftyone/server/utils/json/jsonpatch/__init__.py similarity index 94% rename from fiftyone/server/utils/jsonpatch/__init__.py rename to fiftyone/server/utils/json/jsonpatch/__init__.py index 4348933e19..4999cefcc7 100644 --- a/fiftyone/server/utils/jsonpatch/__init__.py +++ b/fiftyone/server/utils/json/jsonpatch/__init__.py @@ -8,7 +8,8 @@ from typing import Any, Callable, Iterable, Optional, Union -from fiftyone.server.utils.jsonpatch.methods import ( + +from fiftyone.server.utils.json.jsonpatch.methods import ( add, copy, move, @@ -16,7 +17,8 @@ replace, test, ) -from fiftyone.server.utils.jsonpatch.patch import ( + +from fiftyone.server.utils.json.jsonpatch.patch import ( Patch, Operation, Add, @@ -27,6 +29,7 @@ Test, ) + __PATCH_MAP = { Operation.ADD: Add, Operation.COPY: Copy, diff --git a/fiftyone/server/utils/jsonpatch/methods.py b/fiftyone/server/utils/json/jsonpatch/methods.py similarity index 100% rename from fiftyone/server/utils/jsonpatch/methods.py rename to fiftyone/server/utils/json/jsonpatch/methods.py diff --git a/fiftyone/server/utils/jsonpatch/patch.py b/fiftyone/server/utils/json/jsonpatch/patch.py similarity index 89% rename from fiftyone/server/utils/jsonpatch/patch.py rename to fiftyone/server/utils/json/jsonpatch/patch.py index f43a32caff..2cc930e3f9 100644 --- a/fiftyone/server/utils/jsonpatch/patch.py +++ b/fiftyone/server/utils/json/jsonpatch/patch.py @@ -12,7 +12,7 @@ from typing import Any, Generic, TypeVar, Union -from fiftyone.server.utils.jsonpatch import methods +from fiftyone.server.utils.json.jsonpatch import methods T = TypeVar("T") V = TypeVar("V") @@ -37,12 +37,14 @@ class Patch(abc.ABC): op: Operation - def __init_subclass__(cls): + def __init_subclass__(cls, **kwargs): if not inspect.isabstract(cls) and not isinstance( getattr(cls, "op", None), Operation ): raise TypeError("Subclass must define 'op' class variable") + super().__init_subclass__(**kwargs) + def __init__(self, path: str): self._pointer = methods.to_json_pointer(path) @@ -67,10 +69,10 @@ def apply(self, src: Any) -> Any: """ -class PatchWithValue(Patch, Generic[T], abc.ABC): +class PatchWithValue(Patch, abc.ABC, Generic[V]): """A JSON Patch operation that requires a value.""" - def __init__(self, path: str, value: T): + def __init__(self, path: str, value: V): super().__init__(path) self.value = value @@ -88,7 +90,7 @@ def from_(self) -> str: return self._from_pointer.path -class Add(PatchWithValue): +class Add(PatchWithValue[V], Generic[V]): """Helper class for JSON Patch "add" operation.""" op = Operation.ADD @@ -124,7 +126,7 @@ def apply(self, src: T) -> T: return methods.remove(src, self._pointer) -class Replace(PatchWithValue): +class Replace(PatchWithValue[V], Generic[V]): """Helper class for JSON Patch "replace" operation.""" op = Operation.REPLACE @@ -133,7 +135,7 @@ def apply(self, src: T) -> T: return methods.replace(src, self._pointer, self.value) -class Test(PatchWithValue): +class Test(PatchWithValue[V], Generic[V]): """Helper class for JSON Patch "test" operation.""" op = Operation.TEST diff --git a/fiftyone/server/utils/json/serialization.py b/fiftyone/server/utils/json/serialization.py new file mode 100644 index 0000000000..a3b1428447 --- /dev/null +++ b/fiftyone/server/utils/json/serialization.py @@ -0,0 +1,70 @@ +"""JSON serialization + +| Copyright 2017-2025, Voxel51, Inc. +| `voxel51.com `_ +| +""" + +from typing import Any + +import fiftyone.core.labels as fol +import fiftyone.core.sample as fos + + +def deserialize(value: Any) -> Any: + """Deserializes a value into an a known type. + + Args: + value: The value to deserialize + + Returns: + The deserialized value if able to deserialize, otherwise the input + value. + """ + + if isinstance(value, dict): + if cls_name := value.get("_cls"): + cls = next( + ( + cls + for cls in ( + fol.Classification, + fol.Classifications, + fol.Detection, + fol.Detections, + fol.Polyline, + fol.Polylines, + ) + if cls.__name__ == cls_name + ), + None, + ) + + if cls is None: + raise ValueError( + f"No deserializer registered for class '{cls_name}'" + ) + + return cls.from_dict(value) + + return value + + +def serialize(value: Any) -> Any: + """Serializes an value + + Args: + value: The value to serialize + + Returns: + The serialized value if able to serialize, otherwise the input value. + """ + + cls = type(value) + if cls == fos.Sample: + return value.to_dict(include_private=True) + + if hasattr(value, "to_dict"): + return value.to_dict() + + return value diff --git a/fiftyone/server/utils/json_transform/__init__.py b/fiftyone/server/utils/json_transform/__init__.py deleted file mode 100644 index aff7a01df3..0000000000 --- a/fiftyone/server/utils/json_transform/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -FiftyOne Server utils json transform. - -| Copyright 2017-2025, Voxel51, Inc. -| `voxel51.com `_ -| -""" - -import fiftyone.server.utils.json_transform.types # auto-register resource types -from fiftyone.server.utils.json_transform.transform import transform diff --git a/fiftyone/server/utils/json_transform/transform.py b/fiftyone/server/utils/json_transform/transform.py deleted file mode 100644 index a258effc53..0000000000 --- a/fiftyone/server/utils/json_transform/transform.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Transform a json value. - -| Copyright 2017-2025, Voxel51, Inc. -| `voxel51.com `_ -| -""" -from typing import Any, Callable, Type, TypeVar - -T = TypeVar("T") - -REGISTRY: dict[Type[T], Callable[[dict], T]] = {} - - -def register( - cls: Type[T], # pylint: disable=redefined-builtin -) -> Callable[[Callable[[dict], T]], Callable[[dict], T]]: - """Register a validator function for a resource type. - - Args: - cls Type[T]: The resource type - - Returns: - Callable[[Callable[[dict], T]], Callable[[dict], T]]: A decorator - that registers the decorated function as a validator for the given - resource type. - """ - - def inner(fn: Callable[[dict], T]) -> Callable[[dict], T]: - if not callable(fn): - raise TypeError("fn must be callable") - - if cls in REGISTRY: - raise ValueError( - f"Resource type '{cls.__name__}' validator already registered" - ) - - REGISTRY[cls] = fn - - return fn - - return inner - - -def transform( - value: Any, -) -> Any: - """Transforms a patch value if there is a registered transform method. - Args: - value (Any): The patch value optionally containing "_cls" key. - - Returns: - Any: The transformed value or the original value if no transform is found. - """ - if not isinstance(value, dict): - return value - - func = None - cls_name = value.get("_cls") - if cls_name: - func = next( - (fn for cls, fn in REGISTRY.items() if cls.__name__ == cls_name), - None, - ) - if not func: - raise ValueError(f"No transform registered for class '{cls_name}'") - return func(value) if func else value diff --git a/fiftyone/server/utils/json_transform/types.py b/fiftyone/server/utils/json_transform/types.py deleted file mode 100644 index d1dc08ada3..0000000000 --- a/fiftyone/server/utils/json_transform/types.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Json types registery. - -| Copyright 2017-2025, Voxel51, Inc. -| `voxel51.com `_ -| -""" -from fiftyone.server.utils.json_transform.transform import register -import fiftyone.core.labels as fol - - -@register(fol.Classification) -def transform_classification(value: dict) -> fol.Classification: - return fol.Classification.from_dict(value) - - -@register(fol.Classifications) -def transform_classifications(value: dict) -> fol.Classifications: - return fol.Classifications.from_dict(value) - - -@register(fol.Detection) -def transform_detection(value: dict) -> fol.Detection: - return fol.Detection.from_dict(value) - - -@register(fol.Detections) -def transform_detections(value: dict) -> fol.Detections: - return fol.Detections.from_dict(value) - - -@register(fol.Polyline) -def transform_polyline(value: dict) -> fol.Polyline: - return fol.Polyline.from_dict(value) - - -@register(fol.Polylines) -def transform_polylines(value: dict) -> fol.Polylines: - return fol.Polylines.from_dict(value) diff --git a/tests/unittests/sample_route_tests.py b/tests/unittests/sample_route_tests.py index 38bde0dfea..595243d6b6 100644 --- a/tests/unittests/sample_route_tests.py +++ b/tests/unittests/sample_route_tests.py @@ -5,39 +5,81 @@ | `voxel51.com `_ | """ + # pylint: disable=no-value-for-parameter -import unittest from unittest.mock import MagicMock, AsyncMock import json -import fiftyone as fo -import fiftyone.core.labels as fol from bson import ObjectId, json_util +import pytest from starlette.exceptions import HTTPException from starlette.responses import Response + +import fiftyone as fo +import fiftyone.core.labels as fol import fiftyone.server.routes.sample as fors -class SampleRouteTests(unittest.IsolatedAsyncioTestCase): - def setUp(self): - """Sets up a persistent dataset with a sample for each test.""" - self.mutator = fors.Sample( - scope={"type": "http"}, - receive=AsyncMock(), - send=AsyncMock(), - ) - self.dataset = fo.Dataset() - self.dataset.persistent = True - self.dataset_id = self.dataset._doc.id +@pytest.fixture(name="dataset") +def fixture_dataset(): + """Creates a persistent dataset for testing.""" + dataset = fo.Dataset() + dataset.persistent = True + + try: + yield dataset + finally: + if fo.dataset_exists(dataset.name): + fo.delete_dataset(dataset.name) + + +@pytest.fixture(name="dataset_id") +def fixture_dataset_id(dataset): + """Returns the ID of the dataset.""" + # pylint: disable-next=protected-access + return dataset._doc.id + + +@pytest.fixture(name="if_match", params=[None, "etag", "isodate", "timestamp"]) +def fixture_if_match(request, sample): + """Provides different database connections.""" + if_match_type = request.param + + if if_match_type is None: + return None + if if_match_type == "etag": + return fors.generate_sample_etag(sample) + + if if_match_type == "isodate": + return sample.last_modified_at.isoformat() + + if if_match_type == "timestamp": + return str(sample.last_modified_at.timestamp()) + + raise ValueError(f"Unknown connection type: {if_match_type}") + + +def json_payload(payload: dict) -> bytes: + """Converts a dictionary to a JSON payload.""" + return json_util.dumps(payload).encode("utf-8") + + +class TestSampleRoutes: + """Tests for sample routes""" + + INITIAL_DETECTION_ID = ObjectId() + + @pytest.fixture(name="sample") + def fixture_sample(self, dataset): + """Creates a persistent dataset for testing.""" sample = fo.Sample(filepath="/tmp/test_sample.jpg", tags=["initial"]) - self.initial_detection_id = ObjectId() sample["ground_truth"] = fol.Detections( detections=[ fol.Detection( - id=self.initial_detection_id, + id=self.INITIAL_DETECTION_ID, label="cat", bounding_box=[0.1, 0.1, 0.2, 0.2], ) @@ -45,29 +87,37 @@ def setUp(self): ) sample["primitive_field"] = "initial_value" - self.dataset.add_sample(sample) - self.sample = sample + dataset.add_sample(sample) + + return sample - def tearDown(self): - """Deletes the dataset after each test.""" - if self.dataset and fo.dataset_exists(self.dataset.name): - fo.delete_dataset(self.dataset.name) + @pytest.fixture(name="mutator") + def test_mutator(self): + """Returns the Sample route mutator.""" + return fors.Sample( + scope={"type": "http"}, receive=AsyncMock(), send=AsyncMock() + ) - def _create_mock_request(self, payload, content_type="application/json"): + @pytest.fixture(name="mock_request") + def fixture_mock_request(self, dataset_id, sample, if_match): """Helper to create a mock request object.""" mock_request = MagicMock() mock_request.path_params = { - "dataset_id": self.dataset_id, - "sample_id": str(self.sample.id), + "dataset_id": dataset_id, + "sample_id": str(sample.id), } - mock_request.headers = {"Content-Type": content_type} - mock_request.body = AsyncMock( - return_value=json_util.dumps(payload).encode("utf-8") - ) + mock_request.headers = {"Content-Type": "application/json"} + + if if_match is not None: + mock_request.headers["If-Match"] = if_match + + mock_request.body = AsyncMock(return_value=json_payload({})) + return mock_request - async def test_update_detection(self): + @pytest.mark.asyncio + async def test_update_detection(self, mutator, mock_request, sample): """ Tests updating an existing detection """ @@ -80,7 +130,7 @@ async def test_update_detection(self): "detections": [ { "_cls": "Detection", - "id": str(self.initial_detection_id), + "id": str(self.INITIAL_DETECTION_ID), "label": label, "bounding_box": bounding_box, # updated "confidence": confidence, @@ -91,40 +141,49 @@ async def test_update_detection(self): "tags": None, } - response = await self.mutator.patch( - self._create_mock_request(patch_payload) + mock_request.body.return_value = json_payload(patch_payload) + + ##### + response = await mutator.patch(mock_request) + ##### + + sample.reload() + assert response.headers.get("ETag") == fors.generate_sample_etag( + sample ) + response_dict = json.loads(response.body) - self.assertIsInstance(response, Response) - self.assertEqual(response.status_code, 200) + + assert isinstance(response, Response) + assert response.status_code == 200 + # Assertions on the response - self.assertIsInstance(response_dict, dict) - sample = fo.Sample.from_dict(response_dict) - self.assertEqual( - sample.ground_truth.detections[0].id, - str(self.initial_detection_id), + assert isinstance(response_dict, dict) + updated_sample = fo.Sample.from_dict(response_dict) + assert updated_sample.ground_truth.detections[0].id == str( + self.INITIAL_DETECTION_ID ) - self.assertEqual( - sample.ground_truth.detections[0].bounding_box, bounding_box - ) - self.assertEqual(sample.ground_truth.detections[0].label, label) - # Verify changes in the dataset by reloading the sample - self.sample.reload() + assert ( + updated_sample.ground_truth.detections[0].bounding_box + == bounding_box + ) + assert updated_sample.ground_truth.detections[0].label == label # Verify UPDATE - updated_detection = self.sample.ground_truth.detections[0] - self.assertEqual(updated_detection.id, str(self.initial_detection_id)) - self.assertEqual(updated_detection.bounding_box[0], 0.15) - self.assertEqual(updated_detection.confidence, 0.99) + updated_detection = sample.ground_truth.detections[0] + assert updated_detection.id == str(self.INITIAL_DETECTION_ID) + assert updated_detection.bounding_box[0] == 0.15 + assert updated_detection.confidence == 0.99 # Verify CREATE (Primitive) - self.assertEqual(self.sample.reviewer, "John Doe") + assert sample.reviewer == "John Doe" # Verify DELETE - self.assertEqual(self.sample.tags, []) + assert sample.tags == [] - async def test_add_detection(self): + @pytest.mark.asyncio + async def test_add_detection(self, mutator, mock_request, sample): """ Tests adding a new detection """ @@ -143,17 +202,26 @@ async def test_add_detection(self): ], }, } + mock_request.body.return_value = json_payload(patch_payload) - response = await self.mutator.patch( - self._create_mock_request(patch_payload) + ##### + response = await mutator.patch(mock_request) + ##### + + sample.reload() + assert response.headers.get("ETag") == fors.generate_sample_etag( + sample ) + response_dict = json.loads(response.body) - self.assertIsInstance(response_dict, dict) - updated_detection = self.sample.ground_truth_2.detections[0] - self.assertEqual(updated_detection.bounding_box, bounding_box) - self.assertEqual(updated_detection.confidence, confidence) + assert isinstance(response_dict, dict) - async def test_add_classification(self): + updated_detection = sample.ground_truth_2.detections[0] + assert updated_detection.bounding_box == bounding_box + assert updated_detection.confidence == confidence + + @pytest.mark.asyncio + async def test_add_classification(self, mutator, mock_request, sample): """ Tests adding a new classification """ @@ -166,57 +234,85 @@ async def test_add_classification(self): "confidence": confidence, }, } + mock_request.body.return_value = json_payload(patch_payload) + + ##### + response = await mutator.patch(mock_request) + ##### - response = await self.mutator.patch( - self._create_mock_request(patch_payload) + sample.reload() + assert response.headers.get("ETag") == fors.generate_sample_etag( + sample ) + response_dict = json.loads(response.body) - self.assertIsInstance(response_dict, dict) - updated_detection = self.sample.weather - self.assertEqual(updated_detection.label, label) - self.assertEqual(updated_detection.confidence, confidence) + assert isinstance(response_dict, dict) + updated_detection = sample.weather + assert updated_detection.label == label + assert updated_detection.confidence == confidence - async def test_dataset_not_found(self): - """Tests that a 404 HTTPException is raised for a non-existent dataset.""" - mock_request = MagicMock() - mock_request.path_params = { - "dataset_id": "non-existent-dataset", - "sample_id": str(self.sample.id), - } + @pytest.mark.asyncio + async def test_dataset_not_found(self, mutator, mock_request): + """Tests that a 404 HTTPException is raised for a non-existent + dataset.""" - mock_request.body = AsyncMock( - return_value=json_util.dumps({}).encode("utf-8") - ) - with self.assertRaises(HTTPException) as cm: - await self.mutator.patch(mock_request) + mock_request.path_params["dataset_id"] = "non-existent-dataset" - self.assertEqual(cm.exception.status_code, 404) - self.assertEqual( - cm.exception.detail, "Dataset 'non-existent-dataset' not found" + with pytest.raises(HTTPException) as exc_info: + ##### + await mutator.patch(mock_request) + ##### + + assert exc_info.value.status_code == 404 + assert ( + exc_info.value.detail == "Dataset 'non-existent-dataset' not found" ) - async def test_sample_not_found(self): - """Tests that a 404 HTTPException is raised for a non-existent sample.""" + @pytest.mark.asyncio + async def test_sample_not_found(self, mutator, mock_request, dataset_id): + """Tests that a 404 HTTPException is raised for a non-existent + sample.""" bad_id = str(ObjectId()) - mock_request = MagicMock() - mock_request.path_params = { - "dataset_id": self.dataset_id, - "sample_id": bad_id, - } - mock_request.body = AsyncMock( - return_value=json_util.dumps({}).encode("utf-8") + mock_request.path_params["sample_id"] = bad_id + + with pytest.raises(HTTPException) as exc_info: + ##### + await mutator.patch(mock_request) + ##### + + assert exc_info.value.status_code == 404 + assert ( + exc_info.value.detail + == f"Sample '{bad_id}' not found in dataset '{dataset_id}'" ) - with self.assertRaises(HTTPException) as cm: - await self.mutator.patch(mock_request) - self.assertEqual(cm.exception.status_code, 404) - self.assertEqual( - cm.exception.detail, - f"Sample '{bad_id}' not found in dataset '{self.dataset_id}'", + @pytest.mark.asyncio + async def test_if_match_header_failure( + self, mutator, mock_request, sample, if_match + ): + """Tests that a 412 HTTPException is raised for an invalid If-Match.""" + if if_match is None: + pytest.skip("Fixture returned None, skipping this test.") + + sample["primitive_field"] = "new_value" + sample.save() + + mock_request.body.return_value = json_payload( + {"primitive_field": "newer_value"} ) - async def test_unsupported_label_class(self): + with pytest.raises(HTTPException) as exc_info: + ##### + await mutator.patch(mock_request) + ##### + + assert exc_info.value.status_code == 412 + + @pytest.mark.asyncio + async def test_unsupported_label_class( + self, mutator, mock_request, sample + ): """Tests that an HTTPException is raised for an unknown _cls value.""" patch_payload = { "bad_label": { @@ -224,20 +320,25 @@ async def test_unsupported_label_class(self): "label": "invalid", } } - with self.assertRaises(HTTPException) as cm: - await self.mutator.patch(self._create_mock_request(patch_payload)) - self.assertEqual(cm.exception.status_code, 400) - self.assertEqual( - cm.exception.detail["bad_label"], - "No transform registered for class 'NonExistentLabelType'", + mock_request.body.return_value = json_payload(patch_payload) + + with pytest.raises(HTTPException) as exc_info: + ##### + await mutator.patch(mock_request) + ##### + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail["bad_label"] == ( + "No deserializer registered for class 'NonExistentLabelType'" ) # Verify the sample was not modified - self.sample.reload() - self.assertFalse(self.sample.has_field("bad_label")) + sample.reload() + assert sample.has_field("bad_label") is False - async def test_malformed_label_data(self): + @pytest.mark.asyncio + async def test_malformed_label_data(self, mutator, mock_request, sample): """ Tests that an HTTPException is raised when label data is malformed and cannot be deserialized by from_dict. @@ -250,44 +351,56 @@ async def test_malformed_label_data(self): } } - with self.assertRaises(HTTPException) as cm: - await self.mutator.patch(self._create_mock_request(patch_payload)) + mock_request.body.return_value = json_payload(patch_payload) + + with pytest.raises(HTTPException) as exc_info: + ##### + await mutator.patch(mock_request) + ##### - self.assertEqual(cm.exception.status_code, 400) - response_dict = cm.exception.detail + assert exc_info.value.status_code == 400 + response_dict = exc_info.value.detail - self.assertIn( - "Invalid data to create a `Detections` instance.", - response_dict["ground_truth"], + assert ( + "Invalid data to create a `Detections` instance." + in response_dict["ground_truth"] ) # Verify the original field was not overwritten - self.sample.reload() - self.assertEqual(len(self.sample.ground_truth.detections), 1) - self.assertEqual( - self.sample.ground_truth.detections[0].id, - str(self.initial_detection_id), + sample.reload() + assert len(sample.ground_truth.detections) == 1 + assert sample.ground_truth.detections[0].id == str( + self.INITIAL_DETECTION_ID ) - async def test_patch_replace_primitive_field(self): + @pytest.mark.asyncio + async def test_patch_rplc_primitive(self, mutator, mock_request, sample): """Tests 'replace' on a primitive field with json-patch.""" new_value = "updated_value" patch_payload = [ {"op": "replace", "path": "/primitive_field", "value": new_value} ] - mock_request = self._create_mock_request( - patch_payload, content_type="application/json-patch+json" - ) + mock_request.body.return_value = json_payload(patch_payload) + mock_request.headers["Content-Type"] = "application/json-patch+json" + + ##### + response = await mutator.patch(mock_request) + ##### + sample.reload() - response = await self.mutator.patch(mock_request) + assert response.headers.get("ETag") == fors.generate_sample_etag( + sample + ) response_dict = json.loads(response.body) - self.assertEqual(response_dict["primitive_field"], new_value) + assert response_dict["primitive_field"] == new_value - self.sample.reload() - self.assertEqual(self.sample.primitive_field, new_value) + assert sample.primitive_field == new_value - async def test_patch_replace_nested_label_attribute(self): + @pytest.mark.asyncio + async def test_patch_rplc_nest_label_attr( + self, mutator, mock_request, sample + ): """Tests 'replace' on a nested attribute of a label with json-patch.""" new_label = "dog" patch_payload = [ @@ -297,17 +410,25 @@ async def test_patch_replace_nested_label_attribute(self): "value": new_label, } ] - mock_request = self._create_mock_request( - patch_payload, content_type="application/json-patch+json" - ) - await self.mutator.patch(mock_request) + mock_request.body.return_value = json_payload(patch_payload) + mock_request.headers["Content-Type"] = "application/json-patch+json" - self.sample.reload() - self.assertEqual( - self.sample.ground_truth.detections[0].label, new_label + ##### + response = await mutator.patch(mock_request) + ##### + + sample.reload() + + assert response.headers.get("ETag") == fors.generate_sample_etag( + sample ) - async def test_patch_add_detection_to_list(self): + assert sample.ground_truth.detections[0].label == new_label + + @pytest.mark.asyncio + async def test_patch_add_detect_to_list( + self, mutator, mock_request, sample + ): """Tests 'add' to a list of labels, testing the transform function.""" new_detection = { "_cls": "Detection", @@ -321,111 +442,129 @@ async def test_patch_add_detection_to_list(self): "value": new_detection, } ] - mock_request = self._create_mock_request( - patch_payload, content_type="application/json-patch+json" - ) + mock_request.body.return_value = json_payload(patch_payload) + mock_request.headers["Content-Type"] = "application/json-patch+json" - await self.mutator.patch(mock_request) + ##### + response = await mutator.patch(mock_request) + ##### - self.sample.reload() - self.assertEqual(len(self.sample.ground_truth.detections), 2) - self.assertIsInstance( - self.sample.ground_truth.detections[1], fol.Detection + sample.reload() + + assert response.headers.get("ETag") == fors.generate_sample_etag( + sample ) - self.assertEqual(self.sample.ground_truth.detections[1].label, "dog") - async def test_patch_remove_detection_from_list(self): + assert len(sample.ground_truth.detections) == 2 + assert isinstance(sample.ground_truth.detections[1], fol.Detection) + assert sample.ground_truth.detections[1].label == "dog" + + @pytest.mark.asyncio + async def test_patch_rmv_detect_list(self, mutator, mock_request, sample): """Tests 'remove' from a list of labels.""" - self.assertEqual(len(self.sample.ground_truth.detections), 1) + assert len(sample.ground_truth.detections) == 1 patch_payload = [ {"op": "remove", "path": "/ground_truth/detections/0"} ] - mock_request = self._create_mock_request( - patch_payload, content_type="application/json-patch+json" - ) - await self.mutator.patch(mock_request) + mock_request.body.return_value = json_payload(patch_payload) + mock_request.headers["Content-Type"] = "application/json-patch+json" + + ##### + response = await mutator.patch(mock_request) + ##### + + sample.reload() + + assert response.headers.get("ETag") == fors.generate_sample_etag( + sample + ) - self.sample.reload() - self.assertEqual(len(self.sample.ground_truth.detections), 0) + assert len(sample.ground_truth.detections) == 0 - async def test_patch_multiple_operations(self): + @pytest.mark.asyncio + async def test_patch_multiple_operations( + self, mutator, mock_request, sample + ): """Tests a patch request with multiple operations.""" patch_payload = [ {"op": "replace", "path": "/primitive_field", "value": "multi-op"}, {"op": "remove", "path": "/ground_truth/detections/0"}, ] - mock_request = self._create_mock_request( - patch_payload, content_type="application/json-patch+json" - ) + mock_request.body.return_value = json_payload(patch_payload) + mock_request.headers["Content-Type"] = "application/json-patch+json" + + ##### + response = await mutator.patch(mock_request) + ##### - await self.mutator.patch(mock_request) + sample.reload() + + assert response.headers.get("ETag") == fors.generate_sample_etag( + sample + ) - self.sample.reload() - self.assertEqual(self.sample.primitive_field, "multi-op") - self.assertEqual(len(self.sample.ground_truth.detections), 0) + assert sample.primitive_field == "multi-op" + assert len(sample.ground_truth.detections) == 0 - async def test_patch_invalid_path(self): + @pytest.mark.asyncio + async def test_patch_invalid_path(self, mutator, mock_request): """Tests that a 400 is raised for an invalid path.""" patch_payload = [ {"op": "replace", "path": "/non_existent_field", "value": "test"} ] - mock_request = self._create_mock_request( - patch_payload, content_type="application/json-patch+json" - ) + mock_request.body.return_value = json_payload(patch_payload) + mock_request.headers["Content-Type"] = "application/json-patch+json" - with self.assertRaises(HTTPException) as cm: - await self.mutator.patch(mock_request) + with pytest.raises(HTTPException) as exc_info: + ###### + await mutator.patch(mock_request) + ###### - self.assertEqual(cm.exception.status_code, 400) - self.assertIn(str(patch_payload[0]), cm.exception.detail) + assert exc_info.value.status_code == 400 + assert str(patch_payload[0]) in exc_info.value.detail - async def test_patch_invalid_format(self): + @pytest.mark.asyncio + async def test_patch_invalid_format(self, mutator, mock_request): """Tests that a 400 is raised for a malformed patch operation.""" patch_payload = [ {"path": "/primitive_field", "value": "test"} ] # missing 'op' - mock_request = self._create_mock_request( - patch_payload, content_type="application/json-patch+json" - ) - with self.assertRaises(HTTPException) as cm: - await self.mutator.patch(mock_request) + mock_request.body.return_value = json_payload(patch_payload) + mock_request.headers["Content-Type"] = "application/json-patch+json" - self.assertEqual(cm.exception.status_code, 400) - self.assertIn( - "Failed to parse patches due to", - cm.exception.detail, - ) + with pytest.raises(HTTPException) as exc_info: + ###### + await mutator.patch(mock_request) + ###### + assert exc_info.value.status_code == 400 + assert "Failed to parse patches due to" in exc_info.value.detail -class SampleFieldRouteTests(unittest.IsolatedAsyncioTestCase): - def setUp(self): - """Sets up a persistent dataset with a sample for each test.""" - self.mutator = fors.SampleField( - scope={"type": "http"}, - receive=AsyncMock(), - send=AsyncMock(), - ) - self.dataset = fo.Dataset() - self.dataset.persistent = True - self.dataset_id = self.dataset._doc.id +class TestSampleFieldRoute: + """Tests for sample field routes""" + + DETECTION_ID_1 = ObjectId() + DETECTION_ID_2 = ObjectId() + + @pytest.fixture(name="sample") + def fixture_sample(self, dataset): + """Creates a persistent dataset for testing.""" sample = fo.Sample(filepath="/tmp/test_sample_field.jpg") - self.detection_id_1 = ObjectId() - self.detection_id_2 = ObjectId() sample["ground_truth"] = fol.Detections( detections=[ fol.Detection( - id=self.detection_id_1, + id=self.DETECTION_ID_1, label="cat", bounding_box=[0.1, 0.1, 0.2, 0.2], confidence=0.9, ), fol.Detection( - id=self.detection_id_2, + id=self.DETECTION_ID_2, label="dog", bounding_box=[0.4, 0.4, 0.3, 0.3], confidence=0.8, @@ -434,160 +573,198 @@ def setUp(self): ) sample["scalar_field"] = "not a list" - self.dataset.add_sample(sample) - self.sample = sample + dataset.add_sample(sample) + + return sample - def tearDown(self): - """Deletes the dataset after each test.""" - if self.dataset and fo.dataset_exists(self.dataset.name): - fo.delete_dataset(self.dataset.name) + @pytest.fixture(name="mutator") + def test_mutator(self): + """Returns the Sample fields route mutator.""" + return fors.SampleField( + scope={"type": "http"}, + receive=AsyncMock(), + send=AsyncMock(), + ) - def _create_mock_request(self, payload, field_path, field_id): - """Helper to create a mock request object for SampleField.""" + @pytest.fixture(name="mock_request") + def fixture_mock_request(self, dataset_id, sample, if_match): + """Helper to create a mock request object.""" + mock_request = MagicMock() mock_request = MagicMock() mock_request.path_params = { - "dataset_id": self.dataset_id, - "sample_id": str(self.sample.id), - "field_path": field_path, - "field_id": str(field_id), + "dataset_id": dataset_id, + "sample_id": str(sample.id), + "field_path": "ground_truth.detections", + "field_id": str(self.DETECTION_ID_1), } mock_request.headers = {"Content-Type": "application/json"} - mock_request.body = AsyncMock( - return_value=json_util.dumps(payload).encode("utf-8") - ) + if if_match is not None: + mock_request.headers["If-Match"] = if_match + + mock_request.body = AsyncMock(return_value=json_payload({})) + return mock_request - async def test_update_label_in_list(self): + @pytest.mark.asyncio + async def test_update_label_in_list(self, mutator, mock_request, sample): """Tests updating a label within a list field.""" new_label = "person" patch_payload = [ {"op": "replace", "path": "/label", "value": new_label} ] - field_path = "ground_truth.detections" - field_id = self.detection_id_1 - request = self._create_mock_request( - patch_payload, field_path, field_id - ) - response = await self.mutator.patch(request) + mock_request.body.return_value = json_payload(patch_payload) + + ##### + response = await mutator.patch(mock_request) + ##### + sample.reload() + response_dict = json.loads(response.body) - self.assertIsInstance(response, Response) - self.assertEqual(response.status_code, 200) + assert isinstance(response, Response) + assert response.status_code == 200 + assert response.headers.get("ETag") == fors.generate_sample_etag( + sample + ) + # check response body - self.assertEqual(response_dict["label"], new_label) - self.assertEqual(response_dict["_id"]["$oid"], str(field_id)) + assert response_dict["label"] == new_label + assert response_dict["_id"]["$oid"] == str(self.DETECTION_ID_1) # check database state - self.sample.reload() - detection1 = self.sample.ground_truth.detections[0] - detection2 = self.sample.ground_truth.detections[1] - - self.assertEqual(detection1.id, str(field_id)) - self.assertEqual(detection1.label, new_label) - self.assertEqual( - detection2.id, str(self.detection_id_2) - ) # ensure other item is not modified - self.assertEqual(detection2.label, "dog") - - async def test_dataset_not_found(self): + detection1 = sample.ground_truth.detections[0] + detection2 = sample.ground_truth.detections[1] + + assert detection1.id == str(self.DETECTION_ID_1) + assert detection1.label == new_label + # ensure other item is not modified + assert detection2.id == str(self.DETECTION_ID_2) + assert detection2.label == "dog" + + @pytest.mark.asyncio + async def test_dataset_not_found(self, mutator, mock_request): """Tests that a 404 is raised for a non-existent dataset.""" - request = self._create_mock_request( - [], "ground_truth.detections", self.detection_id_1 - ) - request.path_params["dataset_id"] = "non-existent-dataset" - with self.assertRaises(HTTPException) as cm: - await self.mutator.patch(request) + mock_request.path_params["dataset_id"] = "non-existent-dataset" + + with pytest.raises(HTTPException) as exc_info: + ##### + await mutator.patch(mock_request) + ##### - self.assertEqual(cm.exception.status_code, 404) - self.assertEqual( - cm.exception.detail, "Dataset 'non-existent-dataset' not found" + assert exc_info.value.status_code == 404 + assert ( + exc_info.value.detail == "Dataset 'non-existent-dataset' not found" ) - async def test_sample_not_found(self): + @pytest.mark.asyncio + async def test_sample_not_found(self, mutator, mock_request, dataset_id): """Tests that a 404 is raised for a non-existent sample.""" bad_id = str(ObjectId()) - request = self._create_mock_request( - [], "ground_truth.detections", self.detection_id_1 - ) - request.path_params["sample_id"] = bad_id + mock_request.path_params["sample_id"] = bad_id - with self.assertRaises(HTTPException) as cm: - await self.mutator.patch(request) + with pytest.raises(HTTPException) as exc_info: + ##### + await mutator.patch(mock_request) + ##### - self.assertEqual(cm.exception.status_code, 404) - self.assertEqual( - cm.exception.detail, - f"Sample '{bad_id}' not found in dataset '{self.dataset_id}'", + assert exc_info.value.status_code == 404 + assert ( + exc_info.value.detail + == f"Sample '{bad_id}' not found in dataset '{dataset_id}'" ) - async def test_field_path_not_found(self): + @pytest.mark.asyncio + async def test_if_match_header_failure( + self, mutator, mock_request, sample, if_match + ): + """Tests that a 412 HTTPException is raised for an invalid If-Match.""" + if if_match is None: + pytest.skip("Fixture returned None, skipping this test.") + + # Update the sample to change its last_modified_at + sample["primitive_field"] = "new_value" + sample.save() + + patch_payload = [{"op": "replace", "path": "/label", "value": "fish"}] + + mock_request.body.return_value = json_payload(patch_payload) + + with pytest.raises(HTTPException) as exc_info: + ##### + await mutator.patch(mock_request) + ##### + + assert exc_info.value.status_code == 412 + + @pytest.mark.asyncio + async def test_field_path_not_found(self, mutator, mock_request, sample): """Tests that a 404 is raised for a non-existent field path.""" bad_path = "non_existent.path" - request = self._create_mock_request([], bad_path, self.detection_id_1) + mock_request.path_params["field_path"] = bad_path - with self.assertRaises(HTTPException) as cm: - await self.mutator.patch(request) + with pytest.raises(HTTPException) as exc_info: + ##### + await mutator.patch(mock_request) + ##### - self.assertEqual(cm.exception.status_code, 404) - self.assertEqual( - cm.exception.detail, - f"Field '{bad_path}' not found in sample '{self.sample.id}'", + assert exc_info.value.status_code == 404 + assert ( + exc_info.value.detail + == f"Field '{bad_path}' not found in sample '{sample.id}'" ) - async def test_field_is_not_a_list(self): - """Tests that a 400 is raised if the field path does not point to a list.""" + @pytest.mark.asyncio + async def test_field_is_not_a_list(self, mutator, mock_request): + """Tests that a 400 is raised if the field path does not point to a + list.""" field_path = "scalar_field" - request = self._create_mock_request( - [], field_path, self.detection_id_1 - ) - with self.assertRaises(HTTPException) as cm: - await self.mutator.patch(request) + mock_request.path_params["field_path"] = field_path - self.assertEqual(cm.exception.status_code, 400) - self.assertEqual( - cm.exception.detail, - f"Field '{field_path}' is not a list", - ) + with pytest.raises(HTTPException) as exc_info: + ##### + await mutator.patch(mock_request) + ##### + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == f"Field '{field_path}' is not a list" - async def test_field_id_not_found_in_list(self): + @pytest.mark.asyncio + async def test_field_id_not_found_in_list(self, mutator, mock_request): """Tests that a 404 is raised if the field ID is not in the list.""" bad_id = str(ObjectId()) - field_path = "ground_truth.detections" - request = self._create_mock_request([], field_path, bad_id) + mock_request.path_params["field_id"] = bad_id - with self.assertRaises(HTTPException) as cm: - await self.mutator.patch(request) + with pytest.raises(HTTPException) as exc_info: + ##### + await mutator.patch(mock_request) + ##### - self.assertEqual(cm.exception.status_code, 404) - self.assertEqual( - cm.exception.detail, - f"Field with id '{bad_id}' not found in field '{field_path}'", + assert exc_info.value.status_code == 404 + assert exc_info.value.detail == ( + f"Field with id '{bad_id}' not found in field " + f"'{mock_request.path_params['field_path']}'" ) - async def test_invalid_patch_operation(self): + @pytest.mark.asyncio + async def test_invalid_patch_operation(self, mutator, mock_request): """Tests that a 400 is raised for an invalid patch operation.""" patch_payload = [ {"op": "replace", "path": "/non_existent_attr", "value": "test"} ] - field_path = "ground_truth.detections" - field_id = self.detection_id_1 - request = self._create_mock_request( - patch_payload, field_path, field_id - ) + mock_request.body.return_value = json_payload(patch_payload) + mock_request.headers["Content-Type"] = "application/json-patch+json" - with self.assertRaises(HTTPException) as cm: - await self.mutator.patch(request) + with pytest.raises(HTTPException) as exc_info: + ### + await mutator.patch(mock_request) + ### - self.assertEqual(cm.exception.status_code, 400) - self.assertIn(str(patch_payload[0]), cm.exception.detail) - self.assertIn( - "non_existent_attr", cm.exception.detail[str(patch_payload[0])] + assert exc_info.value.status_code == 400 + assert str(patch_payload[0]) in exc_info.value.detail + assert ( + "non_existent_attr" in exc_info.value.detail[str(patch_payload[0])] ) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/tests/unittests/server/utils/jsonpatch/test_json_patch_patch.py b/tests/unittests/server/utils/jsonpatch/test_json_patch_patch.py index 6ef9b5d7a0..1dc3d46458 100644 --- a/tests/unittests/server/utils/jsonpatch/test_json_patch_patch.py +++ b/tests/unittests/server/utils/jsonpatch/test_json_patch_patch.py @@ -9,7 +9,7 @@ import pytest -from fiftyone.server.utils.jsonpatch import methods, patch +from fiftyone.server.utils.json.jsonpatch import methods, patch @pytest.mark.parametrize( diff --git a/tests/unittests/server/utils/jsonpatch/test_jsonpatch.py b/tests/unittests/server/utils/jsonpatch/test_jsonpatch.py index c94033009e..d09e876ec0 100644 --- a/tests/unittests/server/utils/jsonpatch/test_jsonpatch.py +++ b/tests/unittests/server/utils/jsonpatch/test_jsonpatch.py @@ -9,7 +9,7 @@ import pytest -from fiftyone.server.utils import jsonpatch +from fiftyone.server.utils.json import jsonpatch class TestParse: diff --git a/tests/unittests/server/utils/jsonpatch/test_jsonpatch_methods.py b/tests/unittests/server/utils/jsonpatch/test_jsonpatch_methods.py index ea2d65e878..676bac036b 100644 --- a/tests/unittests/server/utils/jsonpatch/test_jsonpatch_methods.py +++ b/tests/unittests/server/utils/jsonpatch/test_jsonpatch_methods.py @@ -11,8 +11,8 @@ import jsonpointer import pytest -from fiftyone.server.utils.jsonpatch import methods -from fiftyone.server.utils.jsonpatch.methods import ( +from fiftyone.server.utils.json.jsonpatch import methods +from fiftyone.server.utils.json.jsonpatch.methods import ( get, add, remove, From 1a712cf9d5cadc9ec95db4aba3bc5fda632d44c3 Mon Sep 17 00:00:00 2001 From: Stuart Date: Wed, 15 Oct 2025 16:06:47 -0400 Subject: [PATCH 2/2] FOEPD-1417 fix pipeline type from/to json (#6418) * fix pipeline type from/to json; add tests * use dataclass func instead for to_json * fix test --- fiftyone/operators/_types/pipeline.py | 37 +++++--- .../factory/delegated_operation_doc_tests.py | 15 +-- tests/unittests/operators/types_tests.py | 95 +++++++++++++++++++ 3 files changed, 118 insertions(+), 29 deletions(-) create mode 100644 tests/unittests/operators/types_tests.py diff --git a/fiftyone/operators/_types/pipeline.py b/fiftyone/operators/_types/pipeline.py index 97c8a169b2..46d35f27fd 100644 --- a/fiftyone/operators/_types/pipeline.py +++ b/fiftyone/operators/_types/pipeline.py @@ -96,28 +96,35 @@ def stage( return stage @classmethod - def from_json(cls, json_list): - """Loads the pipeline from a list of JSON/python dicts. + def from_json(cls, json_dict): + """Loads the pipeline from a JSON/python dict. - Ex., [ - {"operator_uri": "@voxel51/test/blah", "name": "my_stage", ...}, - ..., - ] + Ex., { + "stages": [ + {"operator_uri": "@voxel51/test/blah", "name": "my_stage"}, + ..., + ] + } Args: - json_list: a list of JSON / python dicts + json_dict: a JSON / python dict representation of the pipeline """ - stages = [PipelineStage(**stage) for stage in json_list] + stages = [ + PipelineStage(**stage) for stage in json_dict.get("stages") or [] + ] return cls(stages=stages) def to_json(self): - """Converts the pipeline to list of JSON/python dicts. + """Converts this pipeline to JSON/python dict representation + + Ex., { + "stages": [ + {"operator_uri": "@voxel51/test/blah", "name": "my_stage"}, + ..., + ] + } - Ex., [ - {"operator_uri": "@voxel51/test/blah", "name": "my_stage", ...}, - ..., - ] Returns: - list of JSON / python dicts + JSON / python dict representation of the pipeline """ - return [stage.to_json() for stage in self.stages] + return dataclasses.asdict(self) diff --git a/tests/unittests/factory/delegated_operation_doc_tests.py b/tests/unittests/factory/delegated_operation_doc_tests.py index 972102a59b..cd347743be 100644 --- a/tests/unittests/factory/delegated_operation_doc_tests.py +++ b/tests/unittests/factory/delegated_operation_doc_tests.py @@ -69,20 +69,7 @@ def test_serialize_pipeline(self): ] ) out = op_doc.to_pymongo() - assert out["pipeline"] == [ - { - "name": "one", - "operator_uri": "@test/op1", - "num_distributed_tasks": None, - "params": None, - }, - { - "name": "two", - "operator_uri": "@test/op2", - "num_distributed_tasks": None, - "params": None, - }, - ] + assert out["pipeline"] == op_doc.pipeline.to_json() op_doc2 = repos.DelegatedOperationDocument() op_doc2.from_pymongo(out) assert op_doc2.pipeline == op_doc.pipeline diff --git a/tests/unittests/operators/types_tests.py b/tests/unittests/operators/types_tests.py new file mode 100644 index 0000000000..305deedc7a --- /dev/null +++ b/tests/unittests/operators/types_tests.py @@ -0,0 +1,95 @@ +""" +FiftyOne operator type tests. + +| Copyright 2017-2025, Voxel51, Inc. +| `voxel51.com `_ +| +""" + +import unittest + +import bson + +import fiftyone as fo +import fiftyone.operators as foo +from fiftyone.operators import types + + +class TestPipelineType(unittest.TestCase): + def test_pipeline_type(self): + pipeline = types.Pipeline() + self.assertListEqual(pipeline.stages, []) + + pipeline = types.Pipeline(stages=[]) + self.assertListEqual(pipeline.stages, []) + + stage1 = types.PipelineStage(operator_uri="my/uri") + stage2 = types.PipelineStage( + operator_uri="my/uri2", + name="stage2", + num_distributed_tasks=5, + params={"foo": "bar"}, + ) + pipeline = types.Pipeline(stages=[stage1, stage2]) + self.assertListEqual(pipeline.stages, [stage1, stage2]) + + pipeline = types.Pipeline() + pipeline.stage(stage1.operator_uri) + pipeline.stage( + stage2.operator_uri, + stage2.name, + stage2.num_distributed_tasks, + stage2.params, + ) + self.assertListEqual(pipeline.stages, [stage1, stage2]) + + def test_serialize(self): + pipeline = types.Pipeline( + stages=[ + types.PipelineStage(operator_uri="my/uri"), + types.PipelineStage( + operator_uri="my/uri2", + name="stage2", + num_distributed_tasks=5, + params={"foo": "bar"}, + ), + ] + ) + dict_rep = pipeline.to_json() + self.assertDictEqual( + dict_rep, + { + "stages": [ + { + "operator_uri": "my/uri", + "name": None, + "num_distributed_tasks": None, + "params": None, + }, + { + "operator_uri": "my/uri2", + "name": "stage2", + "num_distributed_tasks": 5, + "params": {"foo": "bar"}, + }, + ], + }, + ) + new_obj = types.Pipeline.from_json(dict_rep) + self.assertEqual(new_obj, pipeline) + + def test_validation(self): + with self.assertRaises(ValueError): + types.PipelineStage(operator_uri=None) + + with self.assertRaises(ValueError): + types.PipelineStage(operator_uri="my/uri", num_distributed_tasks=0) + + with self.assertRaises(ValueError): + types.PipelineStage( + operator_uri="my/uri", num_distributed_tasks=-5 + ) + + pipe = types.Pipeline() + with self.assertRaises(ValueError): + pipe.stage("my/uri", num_distributed_tasks=-5)