diff --git a/fiftyone/operators/_types/pipeline.py b/fiftyone/operators/_types/pipeline.py
index 97c8a169b2..46d35f27fd 100644
--- a/fiftyone/operators/_types/pipeline.py
+++ b/fiftyone/operators/_types/pipeline.py
@@ -96,28 +96,35 @@ def stage(
return stage
@classmethod
- def from_json(cls, json_list):
- """Loads the pipeline from a list of JSON/python dicts.
+ def from_json(cls, json_dict):
+ """Loads the pipeline from a JSON/python dict.
- Ex., [
- {"operator_uri": "@voxel51/test/blah", "name": "my_stage", ...},
- ...,
- ]
+ Ex., {
+ "stages": [
+ {"operator_uri": "@voxel51/test/blah", "name": "my_stage"},
+ ...,
+ ]
+ }
Args:
- json_list: a list of JSON / python dicts
+ json_dict: a JSON / python dict representation of the pipeline
"""
- stages = [PipelineStage(**stage) for stage in json_list]
+ stages = [
+ PipelineStage(**stage) for stage in json_dict.get("stages") or []
+ ]
return cls(stages=stages)
def to_json(self):
- """Converts the pipeline to list of JSON/python dicts.
+        """Converts this pipeline to a JSON/python dict representation.
+
+ Ex., {
+ "stages": [
+ {"operator_uri": "@voxel51/test/blah", "name": "my_stage"},
+ ...,
+ ]
+ }
- Ex., [
- {"operator_uri": "@voxel51/test/blah", "name": "my_stage", ...},
- ...,
- ]
Returns:
- list of JSON / python dicts
+ JSON / python dict representation of the pipeline
"""
- return [stage.to_json() for stage in self.stages]
+ return dataclasses.asdict(self)
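
A minimal sketch of the new dict round trip, mirroring the usage exercised in the
new tests/unittests/operators/types_tests.py below:

    from fiftyone.operators import types

    pipeline = types.Pipeline(
        stages=[
            types.PipelineStage(
                operator_uri="@voxel51/test/blah", name="my_stage"
            )
        ]
    )

    doc = pipeline.to_json()
    # {"stages": [{"operator_uri": "@voxel51/test/blah", "name": "my_stage",
    #              "num_distributed_tasks": None, "params": None}]}

    assert types.Pipeline.from_json(doc) == pipeline
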
diff --git a/fiftyone/server/decorators.py b/fiftyone/server/decorators.py
index f85a1a9781..7a836b13ad 100644
--- a/fiftyone/server/decorators.py
+++ b/fiftyone/server/decorators.py
@@ -6,38 +6,23 @@
|
"""
-from json import JSONEncoder
import traceback
import typing as t
import logging
-from bson import json_util
-import numpy as np
from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException
from starlette.responses import JSONResponse, Response
from starlette.requests import Request
from fiftyone.core.utils import run_sync_task
-
-
-class Encoder(JSONEncoder):
- """Custom JSON encoder that handles numpy types."""
-
- def default(self, o):
- if isinstance(o, np.floating):
- return float(o)
-
- if isinstance(o, np.integer):
- return int(o)
-
- return JSONEncoder.default(self, o)
+from fiftyone.server import utils
async def create_response(response: dict):
"""Creates a JSON response from the given dictionary."""
return Response(
- await run_sync_task(lambda: json_util.dumps(response, cls=Encoder)),
+ await run_sync_task(lambda: utils.json.dumps(response)),
headers={"Content-Type": "application/json"},
)
@@ -52,7 +37,7 @@ async def wrapper(
try:
body = await request.body()
payload = body.decode("utf-8")
- data = json_util.loads(payload) if payload else {}
+ data = utils.json.loads(payload)
response = await func(endpoint, request, data, *args)
if isinstance(response, Response):
return response
diff --git a/fiftyone/server/routes/sample.py b/fiftyone/server/routes/sample.py
index d0d6fde89d..0dac859cc0 100644
--- a/fiftyone/server/routes/sample.py
+++ b/fiftyone/server/routes/sample.py
@@ -6,7 +6,10 @@
|
"""
+import base64
+import datetime
import logging
+from typing import Any, List, Union
from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException
@@ -14,56 +17,180 @@
import fiftyone as fo
import fiftyone.core.odm.utils as fou
-from typing import List
-from fiftyone.server.utils.jsonpatch import parse
-from fiftyone.server.utils import transform_json
-from fiftyone.server.decorators import route
-from typing import Any
+
+from fiftyone.server import decorators, utils
logger = logging.getLogger(__name__)
-def get_sample(dataset_id: str, sample_id: str) -> fo.Sample:
- """Retrieves a sample from a dataset.
+def get_if_last_modified_at(
+ request: Request,
+) -> Union[datetime.datetime, None]:
+ """Parses the If-Match header from the request, if present, and returns
+ the last modified date.
Args:
- dataset_id: the dataset ID
- sample_id: the sample ID
+ request: The request
+
+ Raises:
+ HTTPException: If the If-Match header could not be parsed into a
+ valid date
Returns:
- the sample
+ The last modified date, or None if the header is not present
+ """
+
+ if_last_modified_at: Union[str, datetime.datetime, None] = None
+ if request.headers.get("If-Match"):
+ if_match, _ = utils.http.ETag.parse(request.headers["If-Match"])
+
+        # As ETag - currently this is just a base64-encoded string of
+        # last_modified_at
+ try:
+ if_last_modified_at = datetime.datetime.fromisoformat(
+ base64.b64decode(if_match.encode("utf-8")).decode("utf-8")
+ )
+ except Exception:
+ ...
+
+ # As ISO date
+ try:
+ if_last_modified_at = datetime.datetime.fromisoformat(if_match)
+ except Exception:
+ ...
+
+ # As Unix timestamp
+ try:
+ if_last_modified_at = datetime.datetime.fromtimestamp(
+ float(if_match)
+ )
+ except Exception:
+ ...
+
+ if if_last_modified_at is None:
+ raise HTTPException(
+ status_code=400, detail="Invalid If-Match header"
+ )
+ return if_last_modified_at
+
+
+def get_sample(
+ dataset_id: str,
+ sample_id: str,
+ if_last_modified_at: Union[datetime.datetime, None],
+) -> fo.Sample:
+ """Retrieves a sample from a dataset.
+
+ Args:
+ dataset_id: The ID of the dataset
+ sample_id: The ID of the sample
+        if_last_modified_at: The expected last modified date, if provided
Raises:
- HTTPException: if the dataset or sample is not found
+        HTTPException: If the dataset or sample is not found, or if the
+            expected last modified date does not match the sample
+
+ Returns:
+ The sample
"""
+
try:
dataset = fou.load_dataset(id=dataset_id)
- except ValueError:
+ except ValueError as err:
raise HTTPException(
status_code=404,
detail=f"Dataset '{dataset_id}' not found",
- )
+ ) from err
try:
sample = dataset[sample_id]
- except KeyError:
+ except KeyError as err:
raise HTTPException(
status_code=404,
detail=f"Sample '{sample_id}' not found in dataset '{dataset_id}'",
- )
+ ) from err
+
+    # Fail early if the sample is already stale
+ if if_last_modified_at is not None:
+ if sample.last_modified_at != if_last_modified_at:
+ raise HTTPException(
+ status_code=412, detail="If-Match condition failed"
+ )
return sample
+def generate_sample_etag(sample: fo.Sample) -> str:
+ """Generates an ETag for a sample based on its last modified date.
+
+ Args:
+ sample: The sample
+ Returns:
+ The ETag
+ """
+ value = base64.b64encode(
+ sample.last_modified_at.isoformat().encode("utf-8")
+ ).decode("utf-8")
+
+ return utils.http.ETag.create(value)
+
+
+def save_sample(
+ sample: fo.Sample, if_last_modified_at: Union[datetime.datetime, None]
+) -> str:
+ """Saves a sample to the database.
+
+ Args:
+ sample: The sample to save
+        if_last_modified_at: The expected last modified date, if provided
+
+ Returns:
+ The ETag of the saved sample
+ """
+
+ if if_last_modified_at is not None:
+ d = sample.to_mongo_dict(include_id=True)
+ d["last_modified_at"] = datetime.datetime.now(datetime.timezone.utc)
+
+ # pylint:disable-next=protected-access
+ update_result = sample.dataset._sample_collection.replace_one(
+ {
+ # pylint:disable-next=protected-access
+ "_id": sample._id,
+ "last_modified_at": {"$eq": if_last_modified_at},
+ },
+ d,
+ )
+
+ if update_result.matched_count == 0:
+ raise HTTPException(
+ status_code=412, detail="If-Match condition failed"
+ )
+ else:
+ sample.save()
+
+ # Ensure last_modified_at reflects persisted state before computing
+ # ETag
+ try:
+ sample.reload(hard=True)
+ except Exception:
+ # best-effort; still return response
+ ...
+
+ return generate_sample_etag(sample)
+
+
def handle_json_patch(target: Any, patch_list: List[dict]) -> Any:
"""Applies a list of JSON patch operations to a target object."""
try:
- patches = parse(patch_list, transform_fn=transform_json)
- except Exception as e:
+ patches = utils.json.parse_jsonpatch(
+ patch_list, transform_fn=utils.json.deserialize
+ )
+ except Exception as err:
raise HTTPException(
status_code=400,
- detail=f"Failed to parse patches due to: {e}",
- )
+ detail=f"Failed to parse patches due to: {err}",
+ ) from err
errors = {}
for i, p in enumerate(patches):
@@ -82,14 +209,17 @@ def handle_json_patch(target: Any, patch_list: List[dict]) -> Any:
class Sample(HTTPEndpoint):
- @route
+ """Sample endpoints."""
+
+ @decorators.route
async def patch(self, request: Request, data: dict) -> dict:
"""Applies a list of field updates to a sample.
See: https://datatracker.ietf.org/doc/html/rfc6902
Args:
- request: Starlette request with dataset_id and sample_id in path params
+ request: Starlette request with dataset_id and sample_id in path
+ params
data: A dict mapping field names to values.
Returns:
@@ -104,21 +234,26 @@ async def patch(self, request: Request, data: dict) -> dict:
dataset_id,
)
- sample = get_sample(dataset_id, sample_id)
+ if_last_modified_at = get_if_last_modified_at(request)
+
+ sample = get_sample(dataset_id, sample_id, if_last_modified_at)
content_type = request.headers.get("Content-Type", "")
ctype = content_type.split(";", 1)[0].strip().lower()
if ctype == "application/json":
- result = self._handle_patch(sample, data)
+ self._handle_patch(sample, data)
elif ctype == "application/json-patch+json":
- result = handle_json_patch(sample, data)
+ handle_json_patch(sample, data)
else:
raise HTTPException(
- status_code=415,
- detail=f"Unsupported Content-Type '{ctype}'",
+ status_code=415, detail=f"Unsupported Content-Type '{ctype}'"
)
- sample.save()
- return result.to_dict(include_private=True)
+
+ etag = save_sample(sample, if_last_modified_at)
+
+ return utils.json.JSONResponse(
+ utils.json.serialize(sample), headers={"ETag": etag}
+ )
def _handle_patch(self, sample: fo.Sample, data: dict) -> dict:
errors = {}
@@ -128,27 +263,27 @@ def _handle_patch(self, sample: fo.Sample, data: dict) -> dict:
sample.clear_field(field_name)
continue
- sample[field_name] = transform_json(value)
+ sample[field_name] = utils.json.deserialize(value)
except Exception as e:
errors[field_name] = str(e)
if errors:
- raise HTTPException(
- status_code=400,
- detail=errors,
- )
+ raise HTTPException(status_code=400, detail=errors)
return sample
class SampleField(HTTPEndpoint):
- @route
+ """Sample field endpoints."""
+
+ @decorators.route
async def patch(self, request: Request, data: dict) -> dict:
"""Applies a list of field updates to a sample field in a list by id.
See: https://datatracker.ietf.org/doc/html/rfc6902
Args:
- request: Starlette request with dataset_id and sample_id in path params
+ request: Starlette request with dataset_id and sample_id in path
+ params
data: patch of type op, path, value.
Returns:
@@ -160,39 +295,52 @@ async def patch(self, request: Request, data: dict) -> dict:
field_id = request.path_params["field_id"]
logger.info(
- "Received patch request for field %s with ID %s on sample %s in dataset %s",
+ (
+ "Received patch request for field %s with ID %s on sample %s "
+ "in dataset %s"
+ ),
path,
field_id,
sample_id,
dataset_id,
)
- sample = get_sample(dataset_id, sample_id)
+ if_last_modified_at = get_if_last_modified_at(request)
+
+ sample = get_sample(dataset_id, sample_id, if_last_modified_at)
try:
field_list = sample.get_field(path)
- except Exception as e:
+ except Exception as err:
raise HTTPException(
status_code=404,
detail=f"Field '{path}' not found in sample '{sample_id}'",
- )
+ ) from err
if not isinstance(field_list, list):
raise HTTPException(
- status_code=400,
- detail=f"Field '{path}' is not a list",
+ status_code=400, detail=f"Field '{path}' is not a list"
)
field = next((f for f in field_list if f.id == field_id), None)
if field is None:
raise HTTPException(
status_code=404,
- detail=f"Field with id '{field_id}' not found in field '{path}'",
+ detail=(
+ f"Field with id '{field_id}' not found in field "
+ f"'{path}'"
+ ),
)
- result = handle_json_patch(field, data)
- sample.save()
- return result.to_dict()
+ handle_json_patch(field, data)
+
+ etag = save_sample(sample, if_last_modified_at)
+
+ updated_field = next((f for f in field_list if f.id == field_id), None)
+
+ return utils.json.JSONResponse(
+ utils.json.serialize(updated_field), headers={"ETag": etag}
+ )
SampleRoutes = [
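
get_if_last_modified_at() accepts the If-Match value in any of three forms; a
hedged client-side sketch of constructing each one (mirroring the
fixture_if_match parametrization in the updated tests):

    import base64

    lm = sample.last_modified_at  # datetime previously returned to the client

    encoded = base64.b64encode(lm.isoformat().encode("utf-8")).decode("utf-8")
    as_etag = f'"{encoded}"'
    as_isodate = lm.isoformat()
    as_timestamp = str(lm.timestamp())

    # any one of these is a valid If-Match header value for the PATCH routes
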
diff --git a/fiftyone/server/utils/__init__.py b/fiftyone/server/utils/__init__.py
index d27c4e45ad..67e8ae9992 100644
--- a/fiftyone/server/utils/__init__.py
+++ b/fiftyone/server/utils/__init__.py
@@ -15,9 +15,7 @@
import fiftyone.core.dataset as fod
import fiftyone.core.fields as fof
-from fiftyone.server.utils.json_transform import (
- transform as transform_json,
-) # auto-register resource types
+from fiftyone.server.utils import http, json
_cache = cachetools.TTLCache(maxsize=10, ttl=900) # ttl in seconds
diff --git a/fiftyone/server/utils/http.py b/fiftyone/server/utils/http.py
new file mode 100644
index 0000000000..62f0ca0165
--- /dev/null
+++ b/fiftyone/server/utils/http.py
@@ -0,0 +1,41 @@
+"""
+HTTP utils
+
+| Copyright 2017-2025, Voxel51, Inc.
+| `voxel51.com <https://voxel51.com/>`_
+|
+"""
+
+from typing import Any
+
+
+class ETag:
+ """Utility class for creating and parsing ETag strings."""
+
+ @staticmethod
+ def create(value: Any, is_weak: bool = False) -> str:
+ """Creates an ETag string from the given value."""
+ # Wrap in quotes if not already quoted
+ if not (value.startswith('"') and value.endswith('"')):
+ value = f'"{value}"'
+
+ # Add weak prefix if necessary
+ if is_weak:
+            return f"W/{value}"
+
+ return value
+
+ @staticmethod
+ def parse(etag: str) -> tuple[str, bool]:
+ """Parses an ETag string into its value and whether it is weak."""
+
+ is_weak = False
+ if etag.startswith("W/"):
+ is_weak = True
+ etag = etag[2:] # Remove "W/" prefix
+
+ # Remove surrounding quotes (ETags are typically quoted)
+ if etag.startswith('"') and etag.endswith('"'):
+ etag = etag[1:-1]
+
+ return etag, is_weak
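
A quick sanity sketch of the helper (not part of the test suite):

    from fiftyone.server.utils import http

    strong = http.ETag.create("abc123")               # '"abc123"'
    weak = http.ETag.create("abc123", is_weak=True)   # 'W/"abc123"'

    assert http.ETag.parse(strong) == ("abc123", False)
    assert http.ETag.parse(weak) == ("abc123", True)
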
diff --git a/fiftyone/server/utils/json/__init__.py b/fiftyone/server/utils/json/__init__.py
new file mode 100644
index 0000000000..6a16a2a8e3
--- /dev/null
+++ b/fiftyone/server/utils/json/__init__.py
@@ -0,0 +1,32 @@
+"""
+
+| Copyright 2017-2025, Voxel51, Inc.
+| `voxel51.com <https://voxel51.com/>`_
+|
+"""
+
+from typing import Any, Union
+from bson import json_util
+
+from starlette.responses import JSONResponse as StarletteJSONResponse
+
+from fiftyone.server.utils.json.encoder import Encoder
+from fiftyone.server.utils.json.jsonpatch import parse as parse_jsonpatch
+from fiftyone.server.utils.json.serialization import deserialize, serialize
+
+
+def dumps(obj: Any) -> str:
+ """Serializes an object to a JSON-formatted string."""
+ return json_util.dumps(obj, cls=Encoder)
+
+
+def loads(s: Union[str, bytes, bytearray, None]) -> Any:
+ """Deserializes a JSON-formatted string to a Python object."""
+ return json_util.loads(s) if s else {}
+
+
+class JSONResponse(StarletteJSONResponse):
+ """Custom JSON response that uses the custom Encoder."""
+
+ def render(self, content: Any) -> bytes:
+ return dumps(content).encode("utf-8")
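
A sketch of what the consolidated helpers cover (BSON types via json_util,
numpy scalars via the custom Encoder):

    import numpy as np
    from bson import ObjectId

    from fiftyone.server.utils import json as fos_json

    payload = {"_id": ObjectId(), "confidence": np.float32(0.9)}

    text = fos_json.dumps(payload)  # ObjectId -> {"$oid": ...}, np.float32 -> float
    data = fos_json.loads(text)     # round-trips back to an ObjectId

    assert fos_json.loads(None) == {}  # empty request bodies become {}
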
diff --git a/fiftyone/server/utils/json/encoder.py b/fiftyone/server/utils/json/encoder.py
new file mode 100644
index 0000000000..ef60fb37c4
--- /dev/null
+++ b/fiftyone/server/utils/json/encoder.py
@@ -0,0 +1,25 @@
+"""
+
+| Copyright 2017-2025, Voxel51, Inc.
+| `voxel51.com <https://voxel51.com/>`_
+|
+"""
+
+import json
+
+import numpy as np
+
+
+class Encoder(json.JSONEncoder):
+ """Custom JSON encoder that handles numpy types."""
+
+ def default(self, o):
+ """Override the default method to handle numpy types."""
+
+ if isinstance(o, np.floating):
+ return float(o)
+
+ if isinstance(o, np.integer):
+ return int(o)
+
+ return json.JSONEncoder.default(self, o)
diff --git a/fiftyone/server/utils/jsonpatch/__init__.py b/fiftyone/server/utils/json/jsonpatch/__init__.py
similarity index 94%
rename from fiftyone/server/utils/jsonpatch/__init__.py
rename to fiftyone/server/utils/json/jsonpatch/__init__.py
index 4348933e19..4999cefcc7 100644
--- a/fiftyone/server/utils/jsonpatch/__init__.py
+++ b/fiftyone/server/utils/json/jsonpatch/__init__.py
@@ -8,7 +8,8 @@
from typing import Any, Callable, Iterable, Optional, Union
-from fiftyone.server.utils.jsonpatch.methods import (
+
+from fiftyone.server.utils.json.jsonpatch.methods import (
add,
copy,
move,
@@ -16,7 +17,8 @@
replace,
test,
)
-from fiftyone.server.utils.jsonpatch.patch import (
+
+from fiftyone.server.utils.json.jsonpatch.patch import (
Patch,
Operation,
Add,
@@ -27,6 +29,7 @@
Test,
)
+
__PATCH_MAP = {
Operation.ADD: Add,
Operation.COPY: Copy,
diff --git a/fiftyone/server/utils/jsonpatch/methods.py b/fiftyone/server/utils/json/jsonpatch/methods.py
similarity index 100%
rename from fiftyone/server/utils/jsonpatch/methods.py
rename to fiftyone/server/utils/json/jsonpatch/methods.py
diff --git a/fiftyone/server/utils/jsonpatch/patch.py b/fiftyone/server/utils/json/jsonpatch/patch.py
similarity index 89%
rename from fiftyone/server/utils/jsonpatch/patch.py
rename to fiftyone/server/utils/json/jsonpatch/patch.py
index f43a32caff..2cc930e3f9 100644
--- a/fiftyone/server/utils/jsonpatch/patch.py
+++ b/fiftyone/server/utils/json/jsonpatch/patch.py
@@ -12,7 +12,7 @@
from typing import Any, Generic, TypeVar, Union
-from fiftyone.server.utils.jsonpatch import methods
+from fiftyone.server.utils.json.jsonpatch import methods
T = TypeVar("T")
V = TypeVar("V")
@@ -37,12 +37,14 @@ class Patch(abc.ABC):
op: Operation
- def __init_subclass__(cls):
+ def __init_subclass__(cls, **kwargs):
if not inspect.isabstract(cls) and not isinstance(
getattr(cls, "op", None), Operation
):
raise TypeError("Subclass must define 'op' class variable")
+ super().__init_subclass__(**kwargs)
+
def __init__(self, path: str):
self._pointer = methods.to_json_pointer(path)
@@ -67,10 +69,10 @@ def apply(self, src: Any) -> Any:
"""
-class PatchWithValue(Patch, Generic[T], abc.ABC):
+class PatchWithValue(Patch, abc.ABC, Generic[V]):
"""A JSON Patch operation that requires a value."""
- def __init__(self, path: str, value: T):
+ def __init__(self, path: str, value: V):
super().__init__(path)
self.value = value
@@ -88,7 +90,7 @@ def from_(self) -> str:
return self._from_pointer.path
-class Add(PatchWithValue):
+class Add(PatchWithValue[V], Generic[V]):
"""Helper class for JSON Patch "add" operation."""
op = Operation.ADD
@@ -124,7 +126,7 @@ def apply(self, src: T) -> T:
return methods.remove(src, self._pointer)
-class Replace(PatchWithValue):
+class Replace(PatchWithValue[V], Generic[V]):
"""Helper class for JSON Patch "replace" operation."""
op = Operation.REPLACE
@@ -133,7 +135,7 @@ def apply(self, src: T) -> T:
return methods.replace(src, self._pointer, self.value)
-class Test(PatchWithValue):
+class Test(PatchWithValue[V], Generic[V]):
"""Helper class for JSON Patch "test" operation."""
op = Operation.TEST
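
The Generic[V] rework above is a typing-only change; runtime behavior of the
patch helpers is unchanged. A small sketch of the intended usage, assuming
methods.replace resolves JSON pointers as exercised in the existing jsonpatch
tests:

    from fiftyone.server.utils.json.jsonpatch import patch

    op: patch.Replace[str] = patch.Replace("/label", "dog")
    updated = op.apply({"label": "cat"})  # {"label": "dog"}
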
diff --git a/fiftyone/server/utils/json/serialization.py b/fiftyone/server/utils/json/serialization.py
new file mode 100644
index 0000000000..a3b1428447
--- /dev/null
+++ b/fiftyone/server/utils/json/serialization.py
@@ -0,0 +1,70 @@
+"""JSON serialization
+
+| Copyright 2017-2025, Voxel51, Inc.
+| `voxel51.com <https://voxel51.com/>`_
+|
+"""
+
+from typing import Any
+
+import fiftyone.core.labels as fol
+import fiftyone.core.sample as fos
+
+
+def deserialize(value: Any) -> Any:
+    """Deserializes a value into a known type.
+
+ Args:
+ value: The value to deserialize
+
+ Returns:
+ The deserialized value if able to deserialize, otherwise the input
+ value.
+ """
+
+ if isinstance(value, dict):
+ if cls_name := value.get("_cls"):
+ cls = next(
+ (
+ cls
+ for cls in (
+ fol.Classification,
+ fol.Classifications,
+ fol.Detection,
+ fol.Detections,
+ fol.Polyline,
+ fol.Polylines,
+ )
+ if cls.__name__ == cls_name
+ ),
+ None,
+ )
+
+ if cls is None:
+ raise ValueError(
+ f"No deserializer registered for class '{cls_name}'"
+ )
+
+ return cls.from_dict(value)
+
+ return value
+
+
+def serialize(value: Any) -> Any:
+    """Serializes a value.
+
+ Args:
+ value: The value to serialize
+
+ Returns:
+ The serialized value if able to serialize, otherwise the input value.
+ """
+
+ cls = type(value)
+ if cls == fos.Sample:
+ return value.to_dict(include_private=True)
+
+ if hasattr(value, "to_dict"):
+ return value.to_dict()
+
+ return value
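
deserialize()/serialize() replace the deleted json_transform registry below; a
short sketch mirroring how the sample routes use them:

    import fiftyone.core.labels as fol
    from fiftyone.server.utils.json import serialization

    payload = {"_cls": "Classification", "label": "sunny", "confidence": 0.95}

    label = serialization.deserialize(payload)  # -> fol.Classification
    assert isinstance(label, fol.Classification)

    doc = serialization.serialize(label)                 # -> label.to_dict()
    assert serialization.serialize("plain") == "plain"   # passthrough
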
diff --git a/fiftyone/server/utils/json_transform/__init__.py b/fiftyone/server/utils/json_transform/__init__.py
deleted file mode 100644
index aff7a01df3..0000000000
--- a/fiftyone/server/utils/json_transform/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-"""
-FiftyOne Server utils json transform.
-
-| Copyright 2017-2025, Voxel51, Inc.
-| `voxel51.com `_
-|
-"""
-
-import fiftyone.server.utils.json_transform.types # auto-register resource types
-from fiftyone.server.utils.json_transform.transform import transform
diff --git a/fiftyone/server/utils/json_transform/transform.py b/fiftyone/server/utils/json_transform/transform.py
deleted file mode 100644
index a258effc53..0000000000
--- a/fiftyone/server/utils/json_transform/transform.py
+++ /dev/null
@@ -1,66 +0,0 @@
-"""Transform a json value.
-
-| Copyright 2017-2025, Voxel51, Inc.
-| `voxel51.com `_
-|
-"""
-from typing import Any, Callable, Type, TypeVar
-
-T = TypeVar("T")
-
-REGISTRY: dict[Type[T], Callable[[dict], T]] = {}
-
-
-def register(
- cls: Type[T], # pylint: disable=redefined-builtin
-) -> Callable[[Callable[[dict], T]], Callable[[dict], T]]:
- """Register a validator function for a resource type.
-
- Args:
- cls Type[T]: The resource type
-
- Returns:
- Callable[[Callable[[dict], T]], Callable[[dict], T]]: A decorator
- that registers the decorated function as a validator for the given
- resource type.
- """
-
- def inner(fn: Callable[[dict], T]) -> Callable[[dict], T]:
- if not callable(fn):
- raise TypeError("fn must be callable")
-
- if cls in REGISTRY:
- raise ValueError(
- f"Resource type '{cls.__name__}' validator already registered"
- )
-
- REGISTRY[cls] = fn
-
- return fn
-
- return inner
-
-
-def transform(
- value: Any,
-) -> Any:
- """Transforms a patch value if there is a registered transform method.
- Args:
- value (Any): The patch value optionally containing "_cls" key.
-
- Returns:
- Any: The transformed value or the original value if no transform is found.
- """
- if not isinstance(value, dict):
- return value
-
- func = None
- cls_name = value.get("_cls")
- if cls_name:
- func = next(
- (fn for cls, fn in REGISTRY.items() if cls.__name__ == cls_name),
- None,
- )
- if not func:
- raise ValueError(f"No transform registered for class '{cls_name}'")
- return func(value) if func else value
diff --git a/fiftyone/server/utils/json_transform/types.py b/fiftyone/server/utils/json_transform/types.py
deleted file mode 100644
index d1dc08ada3..0000000000
--- a/fiftyone/server/utils/json_transform/types.py
+++ /dev/null
@@ -1,38 +0,0 @@
-"""Json types registery.
-
-| Copyright 2017-2025, Voxel51, Inc.
-| `voxel51.com `_
-|
-"""
-from fiftyone.server.utils.json_transform.transform import register
-import fiftyone.core.labels as fol
-
-
-@register(fol.Classification)
-def transform_classification(value: dict) -> fol.Classification:
- return fol.Classification.from_dict(value)
-
-
-@register(fol.Classifications)
-def transform_classifications(value: dict) -> fol.Classifications:
- return fol.Classifications.from_dict(value)
-
-
-@register(fol.Detection)
-def transform_detection(value: dict) -> fol.Detection:
- return fol.Detection.from_dict(value)
-
-
-@register(fol.Detections)
-def transform_detections(value: dict) -> fol.Detections:
- return fol.Detections.from_dict(value)
-
-
-@register(fol.Polyline)
-def transform_polyline(value: dict) -> fol.Polyline:
- return fol.Polyline.from_dict(value)
-
-
-@register(fol.Polylines)
-def transform_polylines(value: dict) -> fol.Polylines:
- return fol.Polylines.from_dict(value)
diff --git a/tests/unittests/factory/delegated_operation_doc_tests.py b/tests/unittests/factory/delegated_operation_doc_tests.py
index 972102a59b..cd347743be 100644
--- a/tests/unittests/factory/delegated_operation_doc_tests.py
+++ b/tests/unittests/factory/delegated_operation_doc_tests.py
@@ -69,20 +69,7 @@ def test_serialize_pipeline(self):
]
)
out = op_doc.to_pymongo()
- assert out["pipeline"] == [
- {
- "name": "one",
- "operator_uri": "@test/op1",
- "num_distributed_tasks": None,
- "params": None,
- },
- {
- "name": "two",
- "operator_uri": "@test/op2",
- "num_distributed_tasks": None,
- "params": None,
- },
- ]
+ assert out["pipeline"] == op_doc.pipeline.to_json()
op_doc2 = repos.DelegatedOperationDocument()
op_doc2.from_pymongo(out)
assert op_doc2.pipeline == op_doc.pipeline
diff --git a/tests/unittests/operators/types_tests.py b/tests/unittests/operators/types_tests.py
new file mode 100644
index 0000000000..305deedc7a
--- /dev/null
+++ b/tests/unittests/operators/types_tests.py
@@ -0,0 +1,95 @@
+"""
+FiftyOne operator type tests.
+
+| Copyright 2017-2025, Voxel51, Inc.
+| `voxel51.com <https://voxel51.com/>`_
+|
+"""
+
+import unittest
+
+import bson
+
+import fiftyone as fo
+import fiftyone.operators as foo
+from fiftyone.operators import types
+
+
+class TestPipelineType(unittest.TestCase):
+ def test_pipeline_type(self):
+ pipeline = types.Pipeline()
+ self.assertListEqual(pipeline.stages, [])
+
+ pipeline = types.Pipeline(stages=[])
+ self.assertListEqual(pipeline.stages, [])
+
+ stage1 = types.PipelineStage(operator_uri="my/uri")
+ stage2 = types.PipelineStage(
+ operator_uri="my/uri2",
+ name="stage2",
+ num_distributed_tasks=5,
+ params={"foo": "bar"},
+ )
+ pipeline = types.Pipeline(stages=[stage1, stage2])
+ self.assertListEqual(pipeline.stages, [stage1, stage2])
+
+ pipeline = types.Pipeline()
+ pipeline.stage(stage1.operator_uri)
+ pipeline.stage(
+ stage2.operator_uri,
+ stage2.name,
+ stage2.num_distributed_tasks,
+ stage2.params,
+ )
+ self.assertListEqual(pipeline.stages, [stage1, stage2])
+
+ def test_serialize(self):
+ pipeline = types.Pipeline(
+ stages=[
+ types.PipelineStage(operator_uri="my/uri"),
+ types.PipelineStage(
+ operator_uri="my/uri2",
+ name="stage2",
+ num_distributed_tasks=5,
+ params={"foo": "bar"},
+ ),
+ ]
+ )
+ dict_rep = pipeline.to_json()
+ self.assertDictEqual(
+ dict_rep,
+ {
+ "stages": [
+ {
+ "operator_uri": "my/uri",
+ "name": None,
+ "num_distributed_tasks": None,
+ "params": None,
+ },
+ {
+ "operator_uri": "my/uri2",
+ "name": "stage2",
+ "num_distributed_tasks": 5,
+ "params": {"foo": "bar"},
+ },
+ ],
+ },
+ )
+ new_obj = types.Pipeline.from_json(dict_rep)
+ self.assertEqual(new_obj, pipeline)
+
+ def test_validation(self):
+ with self.assertRaises(ValueError):
+ types.PipelineStage(operator_uri=None)
+
+ with self.assertRaises(ValueError):
+ types.PipelineStage(operator_uri="my/uri", num_distributed_tasks=0)
+
+ with self.assertRaises(ValueError):
+ types.PipelineStage(
+ operator_uri="my/uri", num_distributed_tasks=-5
+ )
+
+ pipe = types.Pipeline()
+ with self.assertRaises(ValueError):
+ pipe.stage("my/uri", num_distributed_tasks=-5)
diff --git a/tests/unittests/sample_route_tests.py b/tests/unittests/sample_route_tests.py
index 38bde0dfea..595243d6b6 100644
--- a/tests/unittests/sample_route_tests.py
+++ b/tests/unittests/sample_route_tests.py
@@ -5,39 +5,81 @@
 | `voxel51.com <https://voxel51.com/>`_
|
"""
+
# pylint: disable=no-value-for-parameter
-import unittest
from unittest.mock import MagicMock, AsyncMock
import json
-import fiftyone as fo
-import fiftyone.core.labels as fol
from bson import ObjectId, json_util
+import pytest
from starlette.exceptions import HTTPException
from starlette.responses import Response
+
+import fiftyone as fo
+import fiftyone.core.labels as fol
import fiftyone.server.routes.sample as fors
-class SampleRouteTests(unittest.IsolatedAsyncioTestCase):
- def setUp(self):
- """Sets up a persistent dataset with a sample for each test."""
- self.mutator = fors.Sample(
- scope={"type": "http"},
- receive=AsyncMock(),
- send=AsyncMock(),
- )
- self.dataset = fo.Dataset()
- self.dataset.persistent = True
- self.dataset_id = self.dataset._doc.id
+@pytest.fixture(name="dataset")
+def fixture_dataset():
+ """Creates a persistent dataset for testing."""
+ dataset = fo.Dataset()
+ dataset.persistent = True
+
+ try:
+ yield dataset
+ finally:
+ if fo.dataset_exists(dataset.name):
+ fo.delete_dataset(dataset.name)
+
+
+@pytest.fixture(name="dataset_id")
+def fixture_dataset_id(dataset):
+ """Returns the ID of the dataset."""
+ # pylint: disable-next=protected-access
+ return dataset._doc.id
+
+
+@pytest.fixture(name="if_match", params=[None, "etag", "isodate", "timestamp"])
+def fixture_if_match(request, sample):
+    """Provides the different If-Match header value formats to test."""
+ if_match_type = request.param
+
+ if if_match_type is None:
+ return None
+ if if_match_type == "etag":
+ return fors.generate_sample_etag(sample)
+
+ if if_match_type == "isodate":
+ return sample.last_modified_at.isoformat()
+
+ if if_match_type == "timestamp":
+ return str(sample.last_modified_at.timestamp())
+
+    raise ValueError(f"Unknown If-Match type: {if_match_type}")
+
+
+def json_payload(payload: dict) -> bytes:
+ """Converts a dictionary to a JSON payload."""
+ return json_util.dumps(payload).encode("utf-8")
+
+
+class TestSampleRoutes:
+ """Tests for sample routes"""
+
+ INITIAL_DETECTION_ID = ObjectId()
+
+ @pytest.fixture(name="sample")
+ def fixture_sample(self, dataset):
+        """Adds a sample to the test dataset."""
sample = fo.Sample(filepath="/tmp/test_sample.jpg", tags=["initial"])
- self.initial_detection_id = ObjectId()
sample["ground_truth"] = fol.Detections(
detections=[
fol.Detection(
- id=self.initial_detection_id,
+ id=self.INITIAL_DETECTION_ID,
label="cat",
bounding_box=[0.1, 0.1, 0.2, 0.2],
)
@@ -45,29 +87,37 @@ def setUp(self):
)
sample["primitive_field"] = "initial_value"
- self.dataset.add_sample(sample)
- self.sample = sample
+ dataset.add_sample(sample)
+
+ return sample
- def tearDown(self):
- """Deletes the dataset after each test."""
- if self.dataset and fo.dataset_exists(self.dataset.name):
- fo.delete_dataset(self.dataset.name)
+ @pytest.fixture(name="mutator")
+    def fixture_mutator(self):
+ """Returns the Sample route mutator."""
+ return fors.Sample(
+ scope={"type": "http"}, receive=AsyncMock(), send=AsyncMock()
+ )
- def _create_mock_request(self, payload, content_type="application/json"):
+ @pytest.fixture(name="mock_request")
+ def fixture_mock_request(self, dataset_id, sample, if_match):
"""Helper to create a mock request object."""
mock_request = MagicMock()
mock_request.path_params = {
- "dataset_id": self.dataset_id,
- "sample_id": str(self.sample.id),
+ "dataset_id": dataset_id,
+ "sample_id": str(sample.id),
}
- mock_request.headers = {"Content-Type": content_type}
- mock_request.body = AsyncMock(
- return_value=json_util.dumps(payload).encode("utf-8")
- )
+ mock_request.headers = {"Content-Type": "application/json"}
+
+ if if_match is not None:
+ mock_request.headers["If-Match"] = if_match
+
+ mock_request.body = AsyncMock(return_value=json_payload({}))
+
return mock_request
- async def test_update_detection(self):
+ @pytest.mark.asyncio
+ async def test_update_detection(self, mutator, mock_request, sample):
"""
Tests updating an existing detection
"""
@@ -80,7 +130,7 @@ async def test_update_detection(self):
"detections": [
{
"_cls": "Detection",
- "id": str(self.initial_detection_id),
+ "id": str(self.INITIAL_DETECTION_ID),
"label": label,
"bounding_box": bounding_box, # updated
"confidence": confidence,
@@ -91,40 +141,49 @@ async def test_update_detection(self):
"tags": None,
}
- response = await self.mutator.patch(
- self._create_mock_request(patch_payload)
+ mock_request.body.return_value = json_payload(patch_payload)
+
+ #####
+ response = await mutator.patch(mock_request)
+ #####
+
+ sample.reload()
+ assert response.headers.get("ETag") == fors.generate_sample_etag(
+ sample
)
+
response_dict = json.loads(response.body)
- self.assertIsInstance(response, Response)
- self.assertEqual(response.status_code, 200)
+
+ assert isinstance(response, Response)
+ assert response.status_code == 200
+
# Assertions on the response
- self.assertIsInstance(response_dict, dict)
- sample = fo.Sample.from_dict(response_dict)
- self.assertEqual(
- sample.ground_truth.detections[0].id,
- str(self.initial_detection_id),
+ assert isinstance(response_dict, dict)
+ updated_sample = fo.Sample.from_dict(response_dict)
+ assert updated_sample.ground_truth.detections[0].id == str(
+ self.INITIAL_DETECTION_ID
)
- self.assertEqual(
- sample.ground_truth.detections[0].bounding_box, bounding_box
- )
- self.assertEqual(sample.ground_truth.detections[0].label, label)
- # Verify changes in the dataset by reloading the sample
- self.sample.reload()
+ assert (
+ updated_sample.ground_truth.detections[0].bounding_box
+ == bounding_box
+ )
+ assert updated_sample.ground_truth.detections[0].label == label
# Verify UPDATE
- updated_detection = self.sample.ground_truth.detections[0]
- self.assertEqual(updated_detection.id, str(self.initial_detection_id))
- self.assertEqual(updated_detection.bounding_box[0], 0.15)
- self.assertEqual(updated_detection.confidence, 0.99)
+ updated_detection = sample.ground_truth.detections[0]
+ assert updated_detection.id == str(self.INITIAL_DETECTION_ID)
+ assert updated_detection.bounding_box[0] == 0.15
+ assert updated_detection.confidence == 0.99
# Verify CREATE (Primitive)
- self.assertEqual(self.sample.reviewer, "John Doe")
+ assert sample.reviewer == "John Doe"
# Verify DELETE
- self.assertEqual(self.sample.tags, [])
+ assert sample.tags == []
- async def test_add_detection(self):
+ @pytest.mark.asyncio
+ async def test_add_detection(self, mutator, mock_request, sample):
"""
Tests adding a new detection
"""
@@ -143,17 +202,26 @@ async def test_add_detection(self):
],
},
}
+ mock_request.body.return_value = json_payload(patch_payload)
- response = await self.mutator.patch(
- self._create_mock_request(patch_payload)
+ #####
+ response = await mutator.patch(mock_request)
+ #####
+
+ sample.reload()
+ assert response.headers.get("ETag") == fors.generate_sample_etag(
+ sample
)
+
response_dict = json.loads(response.body)
- self.assertIsInstance(response_dict, dict)
- updated_detection = self.sample.ground_truth_2.detections[0]
- self.assertEqual(updated_detection.bounding_box, bounding_box)
- self.assertEqual(updated_detection.confidence, confidence)
+ assert isinstance(response_dict, dict)
- async def test_add_classification(self):
+ updated_detection = sample.ground_truth_2.detections[0]
+ assert updated_detection.bounding_box == bounding_box
+ assert updated_detection.confidence == confidence
+
+ @pytest.mark.asyncio
+ async def test_add_classification(self, mutator, mock_request, sample):
"""
Tests adding a new classification
"""
@@ -166,57 +234,85 @@ async def test_add_classification(self):
"confidence": confidence,
},
}
+ mock_request.body.return_value = json_payload(patch_payload)
+
+ #####
+ response = await mutator.patch(mock_request)
+ #####
- response = await self.mutator.patch(
- self._create_mock_request(patch_payload)
+ sample.reload()
+ assert response.headers.get("ETag") == fors.generate_sample_etag(
+ sample
)
+
response_dict = json.loads(response.body)
- self.assertIsInstance(response_dict, dict)
- updated_detection = self.sample.weather
- self.assertEqual(updated_detection.label, label)
- self.assertEqual(updated_detection.confidence, confidence)
+ assert isinstance(response_dict, dict)
+ updated_detection = sample.weather
+ assert updated_detection.label == label
+ assert updated_detection.confidence == confidence
- async def test_dataset_not_found(self):
- """Tests that a 404 HTTPException is raised for a non-existent dataset."""
- mock_request = MagicMock()
- mock_request.path_params = {
- "dataset_id": "non-existent-dataset",
- "sample_id": str(self.sample.id),
- }
+ @pytest.mark.asyncio
+ async def test_dataset_not_found(self, mutator, mock_request):
+ """Tests that a 404 HTTPException is raised for a non-existent
+ dataset."""
- mock_request.body = AsyncMock(
- return_value=json_util.dumps({}).encode("utf-8")
- )
- with self.assertRaises(HTTPException) as cm:
- await self.mutator.patch(mock_request)
+ mock_request.path_params["dataset_id"] = "non-existent-dataset"
- self.assertEqual(cm.exception.status_code, 404)
- self.assertEqual(
- cm.exception.detail, "Dataset 'non-existent-dataset' not found"
+ with pytest.raises(HTTPException) as exc_info:
+ #####
+ await mutator.patch(mock_request)
+ #####
+
+ assert exc_info.value.status_code == 404
+ assert (
+ exc_info.value.detail == "Dataset 'non-existent-dataset' not found"
)
- async def test_sample_not_found(self):
- """Tests that a 404 HTTPException is raised for a non-existent sample."""
+ @pytest.mark.asyncio
+ async def test_sample_not_found(self, mutator, mock_request, dataset_id):
+ """Tests that a 404 HTTPException is raised for a non-existent
+ sample."""
bad_id = str(ObjectId())
- mock_request = MagicMock()
- mock_request.path_params = {
- "dataset_id": self.dataset_id,
- "sample_id": bad_id,
- }
- mock_request.body = AsyncMock(
- return_value=json_util.dumps({}).encode("utf-8")
+ mock_request.path_params["sample_id"] = bad_id
+
+ with pytest.raises(HTTPException) as exc_info:
+ #####
+ await mutator.patch(mock_request)
+ #####
+
+ assert exc_info.value.status_code == 404
+ assert (
+ exc_info.value.detail
+ == f"Sample '{bad_id}' not found in dataset '{dataset_id}'"
)
- with self.assertRaises(HTTPException) as cm:
- await self.mutator.patch(mock_request)
- self.assertEqual(cm.exception.status_code, 404)
- self.assertEqual(
- cm.exception.detail,
- f"Sample '{bad_id}' not found in dataset '{self.dataset_id}'",
+ @pytest.mark.asyncio
+ async def test_if_match_header_failure(
+ self, mutator, mock_request, sample, if_match
+ ):
+        """Tests that a 412 HTTPException is raised for a stale If-Match."""
+ if if_match is None:
+ pytest.skip("Fixture returned None, skipping this test.")
+
+ sample["primitive_field"] = "new_value"
+ sample.save()
+
+ mock_request.body.return_value = json_payload(
+ {"primitive_field": "newer_value"}
)
- async def test_unsupported_label_class(self):
+ with pytest.raises(HTTPException) as exc_info:
+ #####
+ await mutator.patch(mock_request)
+ #####
+
+ assert exc_info.value.status_code == 412
+
+ @pytest.mark.asyncio
+ async def test_unsupported_label_class(
+ self, mutator, mock_request, sample
+ ):
"""Tests that an HTTPException is raised for an unknown _cls value."""
patch_payload = {
"bad_label": {
@@ -224,20 +320,25 @@ async def test_unsupported_label_class(self):
"label": "invalid",
}
}
- with self.assertRaises(HTTPException) as cm:
- await self.mutator.patch(self._create_mock_request(patch_payload))
- self.assertEqual(cm.exception.status_code, 400)
- self.assertEqual(
- cm.exception.detail["bad_label"],
- "No transform registered for class 'NonExistentLabelType'",
+ mock_request.body.return_value = json_payload(patch_payload)
+
+ with pytest.raises(HTTPException) as exc_info:
+ #####
+ await mutator.patch(mock_request)
+ #####
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail["bad_label"] == (
+ "No deserializer registered for class 'NonExistentLabelType'"
)
# Verify the sample was not modified
- self.sample.reload()
- self.assertFalse(self.sample.has_field("bad_label"))
+ sample.reload()
+ assert sample.has_field("bad_label") is False
- async def test_malformed_label_data(self):
+ @pytest.mark.asyncio
+ async def test_malformed_label_data(self, mutator, mock_request, sample):
"""
Tests that an HTTPException is raised when label data is malformed and
cannot be deserialized by from_dict.
@@ -250,44 +351,56 @@ async def test_malformed_label_data(self):
}
}
- with self.assertRaises(HTTPException) as cm:
- await self.mutator.patch(self._create_mock_request(patch_payload))
+ mock_request.body.return_value = json_payload(patch_payload)
+
+ with pytest.raises(HTTPException) as exc_info:
+ #####
+ await mutator.patch(mock_request)
+ #####
- self.assertEqual(cm.exception.status_code, 400)
- response_dict = cm.exception.detail
+ assert exc_info.value.status_code == 400
+ response_dict = exc_info.value.detail
- self.assertIn(
- "Invalid data to create a `Detections` instance.",
- response_dict["ground_truth"],
+ assert (
+ "Invalid data to create a `Detections` instance."
+ in response_dict["ground_truth"]
)
# Verify the original field was not overwritten
- self.sample.reload()
- self.assertEqual(len(self.sample.ground_truth.detections), 1)
- self.assertEqual(
- self.sample.ground_truth.detections[0].id,
- str(self.initial_detection_id),
+ sample.reload()
+ assert len(sample.ground_truth.detections) == 1
+ assert sample.ground_truth.detections[0].id == str(
+ self.INITIAL_DETECTION_ID
)
- async def test_patch_replace_primitive_field(self):
+ @pytest.mark.asyncio
+ async def test_patch_rplc_primitive(self, mutator, mock_request, sample):
"""Tests 'replace' on a primitive field with json-patch."""
new_value = "updated_value"
patch_payload = [
{"op": "replace", "path": "/primitive_field", "value": new_value}
]
- mock_request = self._create_mock_request(
- patch_payload, content_type="application/json-patch+json"
- )
+ mock_request.body.return_value = json_payload(patch_payload)
+ mock_request.headers["Content-Type"] = "application/json-patch+json"
+
+ #####
+ response = await mutator.patch(mock_request)
+ #####
+ sample.reload()
- response = await self.mutator.patch(mock_request)
+ assert response.headers.get("ETag") == fors.generate_sample_etag(
+ sample
+ )
response_dict = json.loads(response.body)
- self.assertEqual(response_dict["primitive_field"], new_value)
+ assert response_dict["primitive_field"] == new_value
- self.sample.reload()
- self.assertEqual(self.sample.primitive_field, new_value)
+ assert sample.primitive_field == new_value
- async def test_patch_replace_nested_label_attribute(self):
+ @pytest.mark.asyncio
+ async def test_patch_rplc_nest_label_attr(
+ self, mutator, mock_request, sample
+ ):
"""Tests 'replace' on a nested attribute of a label with json-patch."""
new_label = "dog"
patch_payload = [
@@ -297,17 +410,25 @@ async def test_patch_replace_nested_label_attribute(self):
"value": new_label,
}
]
- mock_request = self._create_mock_request(
- patch_payload, content_type="application/json-patch+json"
- )
- await self.mutator.patch(mock_request)
+ mock_request.body.return_value = json_payload(patch_payload)
+ mock_request.headers["Content-Type"] = "application/json-patch+json"
- self.sample.reload()
- self.assertEqual(
- self.sample.ground_truth.detections[0].label, new_label
+ #####
+ response = await mutator.patch(mock_request)
+ #####
+
+ sample.reload()
+
+ assert response.headers.get("ETag") == fors.generate_sample_etag(
+ sample
)
- async def test_patch_add_detection_to_list(self):
+ assert sample.ground_truth.detections[0].label == new_label
+
+ @pytest.mark.asyncio
+ async def test_patch_add_detect_to_list(
+ self, mutator, mock_request, sample
+ ):
"""Tests 'add' to a list of labels, testing the transform function."""
new_detection = {
"_cls": "Detection",
@@ -321,111 +442,129 @@ async def test_patch_add_detection_to_list(self):
"value": new_detection,
}
]
- mock_request = self._create_mock_request(
- patch_payload, content_type="application/json-patch+json"
- )
+ mock_request.body.return_value = json_payload(patch_payload)
+ mock_request.headers["Content-Type"] = "application/json-patch+json"
- await self.mutator.patch(mock_request)
+ #####
+ response = await mutator.patch(mock_request)
+ #####
- self.sample.reload()
- self.assertEqual(len(self.sample.ground_truth.detections), 2)
- self.assertIsInstance(
- self.sample.ground_truth.detections[1], fol.Detection
+ sample.reload()
+
+ assert response.headers.get("ETag") == fors.generate_sample_etag(
+ sample
)
- self.assertEqual(self.sample.ground_truth.detections[1].label, "dog")
- async def test_patch_remove_detection_from_list(self):
+ assert len(sample.ground_truth.detections) == 2
+ assert isinstance(sample.ground_truth.detections[1], fol.Detection)
+ assert sample.ground_truth.detections[1].label == "dog"
+
+ @pytest.mark.asyncio
+ async def test_patch_rmv_detect_list(self, mutator, mock_request, sample):
"""Tests 'remove' from a list of labels."""
- self.assertEqual(len(self.sample.ground_truth.detections), 1)
+ assert len(sample.ground_truth.detections) == 1
patch_payload = [
{"op": "remove", "path": "/ground_truth/detections/0"}
]
- mock_request = self._create_mock_request(
- patch_payload, content_type="application/json-patch+json"
- )
- await self.mutator.patch(mock_request)
+ mock_request.body.return_value = json_payload(patch_payload)
+ mock_request.headers["Content-Type"] = "application/json-patch+json"
+
+ #####
+ response = await mutator.patch(mock_request)
+ #####
+
+ sample.reload()
+
+ assert response.headers.get("ETag") == fors.generate_sample_etag(
+ sample
+ )
- self.sample.reload()
- self.assertEqual(len(self.sample.ground_truth.detections), 0)
+ assert len(sample.ground_truth.detections) == 0
- async def test_patch_multiple_operations(self):
+ @pytest.mark.asyncio
+ async def test_patch_multiple_operations(
+ self, mutator, mock_request, sample
+ ):
"""Tests a patch request with multiple operations."""
patch_payload = [
{"op": "replace", "path": "/primitive_field", "value": "multi-op"},
{"op": "remove", "path": "/ground_truth/detections/0"},
]
- mock_request = self._create_mock_request(
- patch_payload, content_type="application/json-patch+json"
- )
+ mock_request.body.return_value = json_payload(patch_payload)
+ mock_request.headers["Content-Type"] = "application/json-patch+json"
+
+ #####
+ response = await mutator.patch(mock_request)
+ #####
- await self.mutator.patch(mock_request)
+ sample.reload()
+
+ assert response.headers.get("ETag") == fors.generate_sample_etag(
+ sample
+ )
- self.sample.reload()
- self.assertEqual(self.sample.primitive_field, "multi-op")
- self.assertEqual(len(self.sample.ground_truth.detections), 0)
+ assert sample.primitive_field == "multi-op"
+ assert len(sample.ground_truth.detections) == 0
- async def test_patch_invalid_path(self):
+ @pytest.mark.asyncio
+ async def test_patch_invalid_path(self, mutator, mock_request):
"""Tests that a 400 is raised for an invalid path."""
patch_payload = [
{"op": "replace", "path": "/non_existent_field", "value": "test"}
]
- mock_request = self._create_mock_request(
- patch_payload, content_type="application/json-patch+json"
- )
+ mock_request.body.return_value = json_payload(patch_payload)
+ mock_request.headers["Content-Type"] = "application/json-patch+json"
- with self.assertRaises(HTTPException) as cm:
- await self.mutator.patch(mock_request)
+ with pytest.raises(HTTPException) as exc_info:
+ ######
+ await mutator.patch(mock_request)
+ ######
- self.assertEqual(cm.exception.status_code, 400)
- self.assertIn(str(patch_payload[0]), cm.exception.detail)
+ assert exc_info.value.status_code == 400
+ assert str(patch_payload[0]) in exc_info.value.detail
- async def test_patch_invalid_format(self):
+ @pytest.mark.asyncio
+ async def test_patch_invalid_format(self, mutator, mock_request):
"""Tests that a 400 is raised for a malformed patch operation."""
patch_payload = [
{"path": "/primitive_field", "value": "test"}
] # missing 'op'
- mock_request = self._create_mock_request(
- patch_payload, content_type="application/json-patch+json"
- )
- with self.assertRaises(HTTPException) as cm:
- await self.mutator.patch(mock_request)
+ mock_request.body.return_value = json_payload(patch_payload)
+ mock_request.headers["Content-Type"] = "application/json-patch+json"
- self.assertEqual(cm.exception.status_code, 400)
- self.assertIn(
- "Failed to parse patches due to",
- cm.exception.detail,
- )
+ with pytest.raises(HTTPException) as exc_info:
+ ######
+ await mutator.patch(mock_request)
+ ######
+ assert exc_info.value.status_code == 400
+ assert "Failed to parse patches due to" in exc_info.value.detail
-class SampleFieldRouteTests(unittest.IsolatedAsyncioTestCase):
- def setUp(self):
- """Sets up a persistent dataset with a sample for each test."""
- self.mutator = fors.SampleField(
- scope={"type": "http"},
- receive=AsyncMock(),
- send=AsyncMock(),
- )
- self.dataset = fo.Dataset()
- self.dataset.persistent = True
- self.dataset_id = self.dataset._doc.id
+class TestSampleFieldRoute:
+ """Tests for sample field routes"""
+
+ DETECTION_ID_1 = ObjectId()
+ DETECTION_ID_2 = ObjectId()
+
+ @pytest.fixture(name="sample")
+ def fixture_sample(self, dataset):
+        """Adds a sample to the test dataset."""
sample = fo.Sample(filepath="/tmp/test_sample_field.jpg")
- self.detection_id_1 = ObjectId()
- self.detection_id_2 = ObjectId()
sample["ground_truth"] = fol.Detections(
detections=[
fol.Detection(
- id=self.detection_id_1,
+ id=self.DETECTION_ID_1,
label="cat",
bounding_box=[0.1, 0.1, 0.2, 0.2],
confidence=0.9,
),
fol.Detection(
- id=self.detection_id_2,
+ id=self.DETECTION_ID_2,
label="dog",
bounding_box=[0.4, 0.4, 0.3, 0.3],
confidence=0.8,
@@ -434,160 +573,198 @@ def setUp(self):
)
sample["scalar_field"] = "not a list"
- self.dataset.add_sample(sample)
- self.sample = sample
+ dataset.add_sample(sample)
+
+ return sample
- def tearDown(self):
- """Deletes the dataset after each test."""
- if self.dataset and fo.dataset_exists(self.dataset.name):
- fo.delete_dataset(self.dataset.name)
+ @pytest.fixture(name="mutator")
+    def fixture_mutator(self):
+ """Returns the Sample fields route mutator."""
+ return fors.SampleField(
+ scope={"type": "http"},
+ receive=AsyncMock(),
+ send=AsyncMock(),
+ )
- def _create_mock_request(self, payload, field_path, field_id):
- """Helper to create a mock request object for SampleField."""
+ @pytest.fixture(name="mock_request")
+ def fixture_mock_request(self, dataset_id, sample, if_match):
+ """Helper to create a mock request object."""
mock_request = MagicMock()
mock_request.path_params = {
- "dataset_id": self.dataset_id,
- "sample_id": str(self.sample.id),
- "field_path": field_path,
- "field_id": str(field_id),
+ "dataset_id": dataset_id,
+ "sample_id": str(sample.id),
+ "field_path": "ground_truth.detections",
+ "field_id": str(self.DETECTION_ID_1),
}
mock_request.headers = {"Content-Type": "application/json"}
- mock_request.body = AsyncMock(
- return_value=json_util.dumps(payload).encode("utf-8")
- )
+ if if_match is not None:
+ mock_request.headers["If-Match"] = if_match
+
+ mock_request.body = AsyncMock(return_value=json_payload({}))
+
return mock_request
- async def test_update_label_in_list(self):
+ @pytest.mark.asyncio
+ async def test_update_label_in_list(self, mutator, mock_request, sample):
"""Tests updating a label within a list field."""
new_label = "person"
patch_payload = [
{"op": "replace", "path": "/label", "value": new_label}
]
- field_path = "ground_truth.detections"
- field_id = self.detection_id_1
- request = self._create_mock_request(
- patch_payload, field_path, field_id
- )
- response = await self.mutator.patch(request)
+ mock_request.body.return_value = json_payload(patch_payload)
+
+ #####
+ response = await mutator.patch(mock_request)
+ #####
+ sample.reload()
+
response_dict = json.loads(response.body)
- self.assertIsInstance(response, Response)
- self.assertEqual(response.status_code, 200)
+ assert isinstance(response, Response)
+ assert response.status_code == 200
+ assert response.headers.get("ETag") == fors.generate_sample_etag(
+ sample
+ )
+
# check response body
- self.assertEqual(response_dict["label"], new_label)
- self.assertEqual(response_dict["_id"]["$oid"], str(field_id))
+ assert response_dict["label"] == new_label
+ assert response_dict["_id"]["$oid"] == str(self.DETECTION_ID_1)
# check database state
- self.sample.reload()
- detection1 = self.sample.ground_truth.detections[0]
- detection2 = self.sample.ground_truth.detections[1]
-
- self.assertEqual(detection1.id, str(field_id))
- self.assertEqual(detection1.label, new_label)
- self.assertEqual(
- detection2.id, str(self.detection_id_2)
- ) # ensure other item is not modified
- self.assertEqual(detection2.label, "dog")
-
- async def test_dataset_not_found(self):
+ detection1 = sample.ground_truth.detections[0]
+ detection2 = sample.ground_truth.detections[1]
+
+ assert detection1.id == str(self.DETECTION_ID_1)
+ assert detection1.label == new_label
+ # ensure other item is not modified
+ assert detection2.id == str(self.DETECTION_ID_2)
+ assert detection2.label == "dog"
+
+ @pytest.mark.asyncio
+ async def test_dataset_not_found(self, mutator, mock_request):
"""Tests that a 404 is raised for a non-existent dataset."""
- request = self._create_mock_request(
- [], "ground_truth.detections", self.detection_id_1
- )
- request.path_params["dataset_id"] = "non-existent-dataset"
- with self.assertRaises(HTTPException) as cm:
- await self.mutator.patch(request)
+ mock_request.path_params["dataset_id"] = "non-existent-dataset"
+
+ with pytest.raises(HTTPException) as exc_info:
+ #####
+ await mutator.patch(mock_request)
+ #####
- self.assertEqual(cm.exception.status_code, 404)
- self.assertEqual(
- cm.exception.detail, "Dataset 'non-existent-dataset' not found"
+ assert exc_info.value.status_code == 404
+ assert (
+ exc_info.value.detail == "Dataset 'non-existent-dataset' not found"
)
- async def test_sample_not_found(self):
+ @pytest.mark.asyncio
+ async def test_sample_not_found(self, mutator, mock_request, dataset_id):
"""Tests that a 404 is raised for a non-existent sample."""
bad_id = str(ObjectId())
- request = self._create_mock_request(
- [], "ground_truth.detections", self.detection_id_1
- )
- request.path_params["sample_id"] = bad_id
+ mock_request.path_params["sample_id"] = bad_id
- with self.assertRaises(HTTPException) as cm:
- await self.mutator.patch(request)
+ with pytest.raises(HTTPException) as exc_info:
+ #####
+ await mutator.patch(mock_request)
+ #####
- self.assertEqual(cm.exception.status_code, 404)
- self.assertEqual(
- cm.exception.detail,
- f"Sample '{bad_id}' not found in dataset '{self.dataset_id}'",
+ assert exc_info.value.status_code == 404
+ assert (
+ exc_info.value.detail
+ == f"Sample '{bad_id}' not found in dataset '{dataset_id}'"
)
- async def test_field_path_not_found(self):
+ @pytest.mark.asyncio
+ async def test_if_match_header_failure(
+ self, mutator, mock_request, sample, if_match
+ ):
+        """Tests that a 412 HTTPException is raised for a stale If-Match."""
+ if if_match is None:
+ pytest.skip("Fixture returned None, skipping this test.")
+
+ # Update the sample to change its last_modified_at
+ sample["primitive_field"] = "new_value"
+ sample.save()
+
+ patch_payload = [{"op": "replace", "path": "/label", "value": "fish"}]
+
+ mock_request.body.return_value = json_payload(patch_payload)
+
+ with pytest.raises(HTTPException) as exc_info:
+ #####
+ await mutator.patch(mock_request)
+ #####
+
+ assert exc_info.value.status_code == 412
+
+ @pytest.mark.asyncio
+ async def test_field_path_not_found(self, mutator, mock_request, sample):
"""Tests that a 404 is raised for a non-existent field path."""
bad_path = "non_existent.path"
- request = self._create_mock_request([], bad_path, self.detection_id_1)
+ mock_request.path_params["field_path"] = bad_path
- with self.assertRaises(HTTPException) as cm:
- await self.mutator.patch(request)
+ with pytest.raises(HTTPException) as exc_info:
+ #####
+ await mutator.patch(mock_request)
+ #####
- self.assertEqual(cm.exception.status_code, 404)
- self.assertEqual(
- cm.exception.detail,
- f"Field '{bad_path}' not found in sample '{self.sample.id}'",
+ assert exc_info.value.status_code == 404
+ assert (
+ exc_info.value.detail
+ == f"Field '{bad_path}' not found in sample '{sample.id}'"
)
- async def test_field_is_not_a_list(self):
- """Tests that a 400 is raised if the field path does not point to a list."""
+ @pytest.mark.asyncio
+ async def test_field_is_not_a_list(self, mutator, mock_request):
+ """Tests that a 400 is raised if the field path does not point to a
+ list."""
field_path = "scalar_field"
- request = self._create_mock_request(
- [], field_path, self.detection_id_1
- )
- with self.assertRaises(HTTPException) as cm:
- await self.mutator.patch(request)
+ mock_request.path_params["field_path"] = field_path
- self.assertEqual(cm.exception.status_code, 400)
- self.assertEqual(
- cm.exception.detail,
- f"Field '{field_path}' is not a list",
- )
+ with pytest.raises(HTTPException) as exc_info:
+ #####
+ await mutator.patch(mock_request)
+ #####
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == f"Field '{field_path}' is not a list"
- async def test_field_id_not_found_in_list(self):
+ @pytest.mark.asyncio
+ async def test_field_id_not_found_in_list(self, mutator, mock_request):
"""Tests that a 404 is raised if the field ID is not in the list."""
bad_id = str(ObjectId())
- field_path = "ground_truth.detections"
- request = self._create_mock_request([], field_path, bad_id)
+ mock_request.path_params["field_id"] = bad_id
- with self.assertRaises(HTTPException) as cm:
- await self.mutator.patch(request)
+ with pytest.raises(HTTPException) as exc_info:
+ #####
+ await mutator.patch(mock_request)
+ #####
- self.assertEqual(cm.exception.status_code, 404)
- self.assertEqual(
- cm.exception.detail,
- f"Field with id '{bad_id}' not found in field '{field_path}'",
+ assert exc_info.value.status_code == 404
+ assert exc_info.value.detail == (
+ f"Field with id '{bad_id}' not found in field "
+ f"'{mock_request.path_params['field_path']}'"
)
- async def test_invalid_patch_operation(self):
+ @pytest.mark.asyncio
+ async def test_invalid_patch_operation(self, mutator, mock_request):
"""Tests that a 400 is raised for an invalid patch operation."""
patch_payload = [
{"op": "replace", "path": "/non_existent_attr", "value": "test"}
]
- field_path = "ground_truth.detections"
- field_id = self.detection_id_1
- request = self._create_mock_request(
- patch_payload, field_path, field_id
- )
+ mock_request.body.return_value = json_payload(patch_payload)
+ mock_request.headers["Content-Type"] = "application/json-patch+json"
- with self.assertRaises(HTTPException) as cm:
- await self.mutator.patch(request)
+ with pytest.raises(HTTPException) as exc_info:
+ ###
+ await mutator.patch(mock_request)
+ ###
- self.assertEqual(cm.exception.status_code, 400)
- self.assertIn(str(patch_payload[0]), cm.exception.detail)
- self.assertIn(
- "non_existent_attr", cm.exception.detail[str(patch_payload[0])]
+ assert exc_info.value.status_code == 400
+ assert str(patch_payload[0]) in exc_info.value.detail
+ assert (
+ "non_existent_attr" in exc_info.value.detail[str(patch_payload[0])]
)
-
-
-if __name__ == "__main__":
- unittest.main(verbosity=2)
diff --git a/tests/unittests/server/utils/jsonpatch/test_json_patch_patch.py b/tests/unittests/server/utils/jsonpatch/test_json_patch_patch.py
index 6ef9b5d7a0..1dc3d46458 100644
--- a/tests/unittests/server/utils/jsonpatch/test_json_patch_patch.py
+++ b/tests/unittests/server/utils/jsonpatch/test_json_patch_patch.py
@@ -9,7 +9,7 @@
import pytest
-from fiftyone.server.utils.jsonpatch import methods, patch
+from fiftyone.server.utils.json.jsonpatch import methods, patch
@pytest.mark.parametrize(
diff --git a/tests/unittests/server/utils/jsonpatch/test_jsonpatch.py b/tests/unittests/server/utils/jsonpatch/test_jsonpatch.py
index c94033009e..d09e876ec0 100644
--- a/tests/unittests/server/utils/jsonpatch/test_jsonpatch.py
+++ b/tests/unittests/server/utils/jsonpatch/test_jsonpatch.py
@@ -9,7 +9,7 @@
import pytest
-from fiftyone.server.utils import jsonpatch
+from fiftyone.server.utils.json import jsonpatch
class TestParse:
diff --git a/tests/unittests/server/utils/jsonpatch/test_jsonpatch_methods.py b/tests/unittests/server/utils/jsonpatch/test_jsonpatch_methods.py
index ea2d65e878..676bac036b 100644
--- a/tests/unittests/server/utils/jsonpatch/test_jsonpatch_methods.py
+++ b/tests/unittests/server/utils/jsonpatch/test_jsonpatch_methods.py
@@ -11,8 +11,8 @@
import jsonpointer
import pytest
-from fiftyone.server.utils.jsonpatch import methods
-from fiftyone.server.utils.jsonpatch.methods import (
+from fiftyone.server.utils.json.jsonpatch import methods
+from fiftyone.server.utils.json.jsonpatch.methods import (
get,
add,
remove,