136 changes: 114 additions & 22 deletions fiftyone/core/collections.py
@@ -7,6 +7,7 @@
"""

from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from datetime import datetime
from operator import itemgetter
@@ -17,6 +18,7 @@
import os
import random
import string
import threading
import timeit
import warnings

@@ -77,6 +79,32 @@ def registrar(func):
aggregation = _make_registrar()


class DummyFuture:
def __init__(self, *, value=None, exception=None):
self.value = value
self.exception = exception

def result(self):
if self.exception:
raise self.exception
return self.value


class DummyExecutor:
def __enter__(self):
return self

def __exit__(self, *args):
pass

def submit(self, fn, *args, **kwargs):
try:
result = fn(*args, **kwargs)
return DummyFuture(value=result)
except Exception as e:
return DummyFuture(exception=e)
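

# A minimal usage sketch (illustrative only; `_executor_sketch` is a
# hypothetical helper, not part of this module): DummyFuture and DummyExecutor
# mirror the small slice of the concurrent.futures interface that SaveContext
# relies on, so the synchronous and asynchronous code paths can be identical:
# submit() returns a future whose result() either returns the value or
# re-raises the captured exception.
def _executor_sketch(async_writes=False):
    executor = (
        ThreadPoolExecutor(max_workers=1) if async_writes else DummyExecutor()
    )
    with executor:
        ok = executor.submit(lambda: 1 + 1)
        bad = executor.submit(lambda x: 1 / x, 0)

        assert ok.result() == 2  # same value in both flavors
        try:
            bad.result()
        except ZeroDivisionError:
            pass  # exceptions are deferred until result() in both flavors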


class SaveContext(object):
"""Context that saves samples from a collection according to a configurable
batching strategy.
Expand Down Expand Up @@ -106,6 +134,7 @@ def __init__(
sample_collection,
batch_size=None,
batching_strategy=None,
async_writes=False,
):
batch_size, batching_strategy = fou.parse_batching_strategy(
batch_size=batch_size, batching_strategy=batching_strategy
@@ -130,6 +159,18 @@ def __init__(
self._encoding_ratio = 1.0
self._last_time = None

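        # These buffers are appended to by save() on the caller's thread and
        # drained by _do_save_batch(), which may run on the worker thread, so
        # each is guarded by its own lock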
self.samples_lock = threading.Lock()
self.frames_lock = threading.Lock()
self.batch_ids_lock = threading.Lock()
self.reloading_lock = threading.Lock()

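        # With async writes, a single worker thread preserves the order in
        # which batches are submitted while letting database writes overlap
        # with the caller's work; otherwise writes happen inline via the
        # DummyExecutor shim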
self.executor = (
ThreadPoolExecutor(max_workers=1)
if async_writes
else DummyExecutor()
)
self.futures = []

def __enter__(self):
if self._batching_strategy == "static":
self._curr_batch_size = 0
@@ -138,11 +179,32 @@ def __enter__(self):
elif self._batching_strategy == "latency":
self._last_time = timeit.default_timer()

self.executor.__enter__()
return self

def __exit__(self, *args):
self._save_batch()

error = None
try:
            # Drain self.futures in a loop so that any futures submitted
            # while earlier ones are being awaited are also awaited before
            # the executor shuts down
while self.futures:
futures = self.futures
self.futures = []
for future in futures:
try:
future.result()
except Exception as e:
if error is None:
error = e
self.futures.clear()
finally:
self.executor.__exit__(*args)

        # only surface a deferred write error if no other exception is
        # already propagating out of the context
        if error is not None and (not args or args[0] is None):
            raise error

def save(self, sample):
"""Registers the sample for saving in the next batch.

@@ -160,16 +222,20 @@ def save(self, sample):
updated = sample_ops or frame_ops

if sample_ops:
self._sample_ops.extend(sample_ops)
with self.samples_lock:
self._sample_ops.extend(sample_ops)

if frame_ops:
self._frame_ops.extend(frame_ops)
with self.frames_lock:
self._frame_ops.extend(frame_ops)

if updated and self._is_generated:
self._batch_ids.append(sample.id)
with self.batch_ids_lock:
self._batch_ids.append(sample.id)

if updated and isinstance(sample, fosa.SampleView):
self._reload_parents.append(sample)
with self.reloading_lock:
self._reload_parents.append(sample)

if self._batching_strategy == "static":
self._curr_batch_size += 1
@@ -198,24 +264,43 @@ def save(self, sample):
self._save_batch()
self._last_time = timeit.default_timer()

def _save_batch(self):
def _do_save_batch(self):
encoded_size = -1
if self._sample_ops:
res = foo.bulk_write(
self._sample_ops,
self._sample_coll,
ordered=False,
batcher=False,
)[0]
encoded_size += res.bulk_api_result.get("nBytes", 0)
self._sample_ops.clear()
with self.samples_lock:
sample_ops = self._sample_ops.copy()
self._sample_ops.clear()
try:
res = foo.bulk_write(
sample_ops,
self._sample_coll,
ordered=False,
batcher=False,
)[0]
encoded_size += res.bulk_api_result.get("nBytes", 0)
except Exception:
# requeue to avoid data loss
with self.samples_lock:
self._sample_ops.extend(sample_ops)
raise

if self._frame_ops:
res = foo.bulk_write(
self._frame_ops, self._frame_coll, ordered=False, batcher=False
)[0]
encoded_size += res.bulk_api_result.get("nBytes", 0)
self._frame_ops.clear()
with self.frames_lock:
frame_ops = self._frame_ops.copy()
self._frame_ops.clear()
try:
res = foo.bulk_write(
frame_ops,
self._frame_coll,
ordered=False,
batcher=False,
)[0]
encoded_size += res.bulk_api_result.get("nBytes", 0)
            except Exception:
# requeue to avoid data loss
with self.frames_lock:
self._frame_ops.extend(frame_ops)
raise

self._encoding_ratio = (
self._curr_batch_size_bytes / encoded_size
@@ -224,14 +309,21 @@ def _save_batch(self):
)

if self._batch_ids and self._is_generated:
self.sample_collection._sync_source(ids=self._batch_ids)
self._batch_ids.clear()
with self.batch_ids_lock:
batch_ids = self._batch_ids.copy()
self._batch_ids.clear()
self.sample_collection._sync_source(ids=batch_ids)

if self._reload_parents:
for sample in self._reload_parents:
with self.reloading_lock:
reload_parents = self._reload_parents.copy()
self._reload_parents.clear()
for sample in reload_parents:
sample._reload_parents()

self._reload_parents.clear()
def _save_batch(self):
future = self.executor.submit(self._do_save_batch)
self.futures.append(future)
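

# End-to-end sketch of the async write path (illustrative only;
# `_async_save_sketch` is a hypothetical helper, not a public API): ctx.save()
# buffers ops under the locks above; when a batch fills, _save_batch() hands
# _do_save_batch() to the single worker thread and records the future; and
# __exit__() drains every future so that the first write error, if any, is
# re-raised once the context closes.
def _async_save_sketch(sample_collection):
    with SaveContext(sample_collection, async_writes=True) as ctx:
        for sample in sample_collection.iter_samples():
            sample["processed"] = True  # hypothetical field, for illustration
            ctx.save(sample)  # bulk writes may run on the worker thread

    # leaving the context waits for pending writes and surfaces any error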


class SampleCollection(object):
14 changes: 12 additions & 2 deletions fiftyone/core/models.py
@@ -477,7 +477,12 @@ def _apply_image_model_data_loader(

with contextlib.ExitStack() as context:
pb = context.enter_context(fou.ProgressBar(samples, progress=progress))
ctx = context.enter_context(foc.SaveContext(samples))
ctx = context.enter_context(
foc.SaveContext(
samples,
async_writes=True,
)
)
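        # async writes let the worker thread flush each batch's label updates
        # while the main thread keeps running inference on the next batch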

def save_batch(sample_batch, labels_batch):
with _handle_batch_error(skip_failures, sample_batch):
@@ -1210,7 +1215,12 @@ def _compute_image_embeddings_data_loader(
with contextlib.ExitStack() as context:
pb = context.enter_context(fou.ProgressBar(samples, progress=progress))
if embeddings_field is not None:
ctx = context.enter_context(foc.SaveContext(samples))
ctx = context.enter_context(
foc.SaveContext(
samples,
async_writes=True,
)
)
else:
ctx = None
