diff --git a/fiftyone/core/collections.py b/fiftyone/core/collections.py
index c28a7181b5..ff0521732b 100644
--- a/fiftyone/core/collections.py
+++ b/fiftyone/core/collections.py
@@ -7,6 +7,7 @@
 """
 from collections import defaultdict, deque
+from concurrent.futures import ThreadPoolExecutor
 from copy import copy
 from datetime import datetime
 from operator import itemgetter
@@ -17,6 +18,7 @@
 import os
 import random
 import string
+import threading
 import timeit
 import warnings
@@ -77,6 +79,32 @@ def registrar(func):
 aggregation = _make_registrar()
 
 
+class DummyFuture:
+    def __init__(self, *, value=None, exception=None):
+        self.value = value
+        self.exception = exception
+
+    def result(self):
+        if self.exception:
+            raise self.exception
+        return self.value
+
+
+class DummyExecutor:
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        pass
+
+    def submit(self, fn, *args, **kwargs):
+        try:
+            result = fn(*args, **kwargs)
+            return DummyFuture(value=result)
+        except Exception as e:
+            return DummyFuture(exception=e)
+
+
 class SaveContext(object):
     """Context that saves samples from a collection according to a
     configurable batching strategy.
@@ -106,6 +134,7 @@ def __init__(
         sample_collection,
         batch_size=None,
         batching_strategy=None,
+        async_writes=False,
     ):
         batch_size, batching_strategy = fou.parse_batching_strategy(
             batch_size=batch_size, batching_strategy=batching_strategy
         )
@@ -130,6 +159,18 @@ def __init__(
         self._encoding_ratio = 1.0
         self._last_time = None
 
+        self.samples_lock = threading.Lock()
+        self.frames_lock = threading.Lock()
+        self.batch_ids_lock = threading.Lock()
+        self.reloading_lock = threading.Lock()
+
+        self.executor = (
+            ThreadPoolExecutor(max_workers=1)
+            if async_writes
+            else DummyExecutor()
+        )
+        self.futures = []
+
     def __enter__(self):
         if self._batching_strategy == "static":
             self._curr_batch_size = 0
@@ -138,11 +179,32 @@ def __enter__(self):
         elif self._batching_strategy == "latency":
             self._last_time = timeit.default_timer()
 
+        self.executor.__enter__()
         return self
 
     def __exit__(self, *args):
         self._save_batch()
 
+        error = None
+        try:
+            # Drain `self.futures` in a loop so that any futures appended
+            # while we are waiting are also awaited
+            while self.futures:
+                futures = self.futures
+                self.futures = []
+                for future in futures:
+                    try:
+                        future.result()
+                    except Exception as e:
+                        if error is None:
+                            error = e
+            self.futures.clear()
+        finally:
+            self.executor.__exit__(*args)
+
+        if error and (not args or args[0] is None):
+            raise error
+
     def save(self, sample):
         """Registers the sample for saving in the next batch.
@@ -160,16 +222,20 @@ def save(self, sample):
         updated = sample_ops or frame_ops
 
         if sample_ops:
-            self._sample_ops.extend(sample_ops)
+            with self.samples_lock:
+                self._sample_ops.extend(sample_ops)
 
         if frame_ops:
-            self._frame_ops.extend(frame_ops)
+            with self.frames_lock:
+                self._frame_ops.extend(frame_ops)
 
         if updated and self._is_generated:
-            self._batch_ids.append(sample.id)
+            with self.batch_ids_lock:
+                self._batch_ids.append(sample.id)
 
         if updated and isinstance(sample, fosa.SampleView):
-            self._reload_parents.append(sample)
+            with self.reloading_lock:
+                self._reload_parents.append(sample)
 
         if self._batching_strategy == "static":
             self._curr_batch_size += 1
@@ -198,24 +264,43 @@ def save(self, sample):
                 self._save_batch()
                 self._last_time = timeit.default_timer()
 
-    def _save_batch(self):
+    def _do_save_batch(self):
         encoded_size = -1
         if self._sample_ops:
-            res = foo.bulk_write(
-                self._sample_ops,
-                self._sample_coll,
-                ordered=False,
-                batcher=False,
-            )[0]
-            encoded_size += res.bulk_api_result.get("nBytes", 0)
-            self._sample_ops.clear()
+            with self.samples_lock:
+                sample_ops = self._sample_ops.copy()
+                self._sample_ops.clear()
+            try:
+                res = foo.bulk_write(
+                    sample_ops,
+                    self._sample_coll,
+                    ordered=False,
+                    batcher=False,
+                )[0]
+                encoded_size += res.bulk_api_result.get("nBytes", 0)
+            except Exception:
+                # requeue to avoid data loss
+                with self.samples_lock:
+                    self._sample_ops.extend(sample_ops)
+                raise
 
         if self._frame_ops:
-            res = foo.bulk_write(
-                self._frame_ops, self._frame_coll, ordered=False, batcher=False
-            )[0]
-            encoded_size += res.bulk_api_result.get("nBytes", 0)
-            self._frame_ops.clear()
+            with self.frames_lock:
+                frame_ops = self._frame_ops.copy()
+                self._frame_ops.clear()
+            try:
+                res = foo.bulk_write(
+                    frame_ops,
+                    self._frame_coll,
+                    ordered=False,
+                    batcher=False,
+                )[0]
+                encoded_size += res.bulk_api_result.get("nBytes", 0)
+            except Exception:
+                # requeue to avoid data loss
+                with self.frames_lock:
+                    self._frame_ops.extend(frame_ops)
+                raise
 
         self._encoding_ratio = (
             self._curr_batch_size_bytes / encoded_size
@@ -224,14 +309,21 @@ def _save_batch(self):
         )
 
         if self._batch_ids and self._is_generated:
-            self.sample_collection._sync_source(ids=self._batch_ids)
-            self._batch_ids.clear()
+            with self.batch_ids_lock:
+                batch_ids = self._batch_ids.copy()
+                self._batch_ids.clear()
+            self.sample_collection._sync_source(ids=batch_ids)
 
         if self._reload_parents:
-            for sample in self._reload_parents:
+            with self.reloading_lock:
+                reload_parents = self._reload_parents.copy()
+                self._reload_parents.clear()
+            for sample in reload_parents:
                 sample._reload_parents()
 
-            self._reload_parents.clear()
+    def _save_batch(self):
+        future = self.executor.submit(self._do_save_batch)
+        self.futures.append(future)
 
 
 class SampleCollection(object):
diff --git a/fiftyone/core/models.py b/fiftyone/core/models.py
index ec2aeaea62..f925103562 100644
--- a/fiftyone/core/models.py
+++ b/fiftyone/core/models.py
@@ -477,7 +477,12 @@ def _apply_image_model_data_loader(
     with contextlib.ExitStack() as context:
         pb = context.enter_context(fou.ProgressBar(samples, progress=progress))
-        ctx = context.enter_context(foc.SaveContext(samples))
+        ctx = context.enter_context(
+            foc.SaveContext(
+                samples,
+                async_writes=True,
+            )
+        )
 
         def save_batch(sample_batch, labels_batch):
             with _handle_batch_error(skip_failures, sample_batch):
@@ -1210,7 +1215,12 @@ def _compute_image_embeddings_data_loader(
     with contextlib.ExitStack() as context:
         pb = context.enter_context(fou.ProgressBar(samples, progress=progress))
 
         if embeddings_field is not None:
-            ctx = context.enter_context(foc.SaveContext(samples))
+            ctx = context.enter_context(
+                foc.SaveContext(
+                    samples,
+                    async_writes=True,
+                )
+            )
         else:
             ctx = None
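
For reference, a minimal sketch (not part of the patch) of how the new `async_writes` flag could be exercised from user code, assuming this patch is applied. The quickstart dataset and the `processed` field below are illustrative only:

    import fiftyone.core.collections as foc
    import fiftyone.zoo as foz

    dataset = foz.load_zoo_dataset("quickstart")

    # With async_writes=True, each flushed batch is submitted to a
    # single-worker ThreadPoolExecutor so bulk writes overlap with sample
    # processing; with the default async_writes=False, the DummyExecutor
    # runs the writes inline, preserving the current behavior
    with foc.SaveContext(dataset, async_writes=True) as ctx:
        for sample in dataset.iter_samples(progress=True):
            sample["processed"] = True  # hypothetical field, for illustration
            ctx.save(sample)

Any write error raised by a background batch is re-raised when the context exits, after all pending futures have been drained.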