fiftyone/core/collections.py (151 additions, 0 deletions)
@@ -17,6 +17,7 @@
import os
import random
import string
import threading
import timeit
import warnings

@@ -234,6 +235,156 @@ def _save_batch(self):
self._reload_parents.clear()


class AsyncSaveContext(SaveContext):
def __init__(self, *args, executor=None, **kwargs):
if executor is None:
raise ValueError("executor must be specified")
super().__init__(*args, **kwargs)
self.executor = executor
self.futures = []

self.samples_lock = threading.Lock()
self.frames_lock = threading.Lock()
self.batch_ids_lock = threading.Lock()
self.reloading_lock = threading.Lock()

def __enter__(self):
super().__enter__()
self.executor.__enter__()
return self

def __exit__(self, *args):
super().__exit__(*args)

error = None
try:
# Loop-drain self.futures so any submissions triggered by
# super().__exit__() are awaited.
while self.futures:
futures = self.futures
self.futures = []
for future in futures:
try:
future.result()
except Exception as e:
if error is None:
error = e
self.futures.clear()
finally:
self.executor.__exit__(*args)

if error:
raise error

def save(self, sample):
"""Registers the sample for saving in the next batch.

Args:
sample: a :class:`fiftyone.core.sample.Sample` or
:class:`fiftyone.core.sample.SampleView`
"""
if sample._in_db and sample._dataset is not self._dataset:
Contributor:
is all of this just copied from SaveContext but with the locks? Feels like a lot of shared logic/similar code

Contributor Author:
Yes. What are your thoughts on resolving that? Put the locks in SaveContext?

Contributor:
> Put the locks in SaveContext?

Yes, that seems like the best path, both to reduce code duplication and also because, per my comments here, I see this either as (1) the way that SaveContext always works, or (2) a behavior I can get via fo.SaveContext(..., async_writes=True) and similar.

Note that I am making some small tweaks to SaveContext in #4773, so it would be great to get that merged first 😄

^reminder about this as well. More reason not to duplicate the implementation of SaveContext and to get that PR merged first
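
A minimal sketch of the async_writes=True idea floated in this thread (a hypothetical API; the real SaveContext constructor and attributes may differ). It illustrates how the executor and future-draining logic could live directly in SaveContext rather than in a separate subclass:

from concurrent.futures import ThreadPoolExecutor


class SaveContext(object):
    # Sketch only: SaveContext with an opt-in async_writes flag
    def __init__(self, sample_collection, batch_size=None, async_writes=False):
        self.sample_collection = sample_collection
        self.batch_size = batch_size
        self.async_writes = async_writes
        self._executor = None
        self._futures = []

    def __enter__(self):
        if self.async_writes:
            # a single worker overlaps database I/O with computation while
            # preserving the order of submitted batches
            self._executor = ThreadPoolExecutor(max_workers=1)
        return self

    def __exit__(self, *args):
        self._save_batch()  # flush any pending ops
        if self._executor is not None:
            try:
                for future in self._futures:
                    future.result()  # re-raise any write errors
            finally:
                self._executor.shutdown(wait=True)

    def _save_batch(self):
        if self._executor is not None:
            self._futures.append(self._executor.submit(self._do_save_batch))
        else:
            self._do_save_batch()

    def _do_save_batch(self):
        ...  # existing synchronous bulk-write logic
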

raise ValueError(
"Dataset context '%s' cannot save sample from dataset '%s'"
% (self._dataset.name, sample._dataset.name)
)

sample_ops, frame_ops = sample._save(deferred=True)
updated = sample_ops or frame_ops

if sample_ops:
with self.samples_lock:
self._sample_ops.extend(sample_ops)

if frame_ops:
with self.frames_lock:
self._frame_ops.extend(frame_ops)

if updated and self._is_generated:
with self.batch_ids_lock:
self._batch_ids.append(sample.id)

if updated and isinstance(sample, fosa.SampleView):
with self.reloading_lock:
self._reload_parents.append(sample)

if self._batching_strategy == "static":
self._curr_batch_size += 1
if self._curr_batch_size >= self.batch_size:
self._save_batch()
self._curr_batch_size = 0
elif self._batching_strategy == "size":
if sample_ops:
self._curr_batch_size_bytes += sum(
len(str(op)) for op in sample_ops
)

if frame_ops:
self._curr_batch_size_bytes += sum(
len(str(op)) for op in frame_ops
)

if (
self._curr_batch_size_bytes
>= self.batch_size * self._encoding_ratio
):
self._save_batch()
self._curr_batch_size_bytes = 0
elif self._batching_strategy == "latency":
if timeit.default_timer() - self._last_time >= self.batch_size:
self._save_batch()
self._last_time = timeit.default_timer()

def _do_save_batch(self):
encoded_size = -1
if self._sample_ops:
with self.samples_lock:
sample_ops = self._sample_ops.copy()
self._sample_ops.clear()
res = foo.bulk_write(
sample_ops,
self._sample_coll,
ordered=False,
batcher=False,
)[0]
encoded_size += res.bulk_api_result.get("nBytes", 0)

if self._frame_ops:
with self.frames_lock:
frame_ops = self._frame_ops.copy()
self._frame_ops.clear()
res = foo.bulk_write(
frame_ops,
self._frame_coll,
ordered=False,
batcher=False,
)[0]
encoded_size += res.bulk_api_result.get("nBytes", 0)

self._encoding_ratio = (
self._curr_batch_size_bytes / encoded_size
if encoded_size > 0 and self._curr_batch_size_bytes
else 1.0
)

if self._batch_ids and self._is_generated:
with self.batch_ids_lock:
batch_ids = self._batch_ids.copy()
self._batch_ids.clear()
self.sample_collection._sync_source(ids=batch_ids)

if self._reload_parents:
with self.reloading_lock:
reload_parents = self._reload_parents.copy()
self._reload_parents.clear()
for sample in reload_parents:
sample._reload_parents()

def _save_batch(self):
future = self.executor.submit(self._do_save_batch)
self.futures.append(future)


class SampleCollection(object):
"""Abstract class representing an ordered collection of
:class:`fiftyone.core.sample.Sample` instances in a
fiftyone/core/models.py (11 additions, 2 deletions)
@@ -9,6 +9,7 @@
import contextlib
import inspect
import logging
from concurrent.futures import ThreadPoolExecutor

import numpy as np

@@ -477,7 +478,11 @@ def _apply_image_model_data_loader(

with contextlib.ExitStack() as context:
pb = context.enter_context(fou.ProgressBar(samples, progress=progress))
ctx = context.enter_context(foc.SaveContext(samples))
ctx = context.enter_context(
foc.AsyncSaveContext(
samples, executor=ThreadPoolExecutor(max_workers=1)
)
)

def save_batch(sample_batch, labels_batch):
with _handle_batch_error(skip_failures, sample_batch):
@@ -1210,7 +1215,11 @@ def _compute_image_embeddings_data_loader(
with contextlib.ExitStack() as context:
pb = context.enter_context(fou.ProgressBar(samples, progress=progress))
if embeddings_field is not None:
ctx = context.enter_context(foc.SaveContext(samples))
ctx = context.enter_context(
foc.AsyncSaveContext(
samples, executor=ThreadPoolExecutor(max_workers=1)
)
)
else:
ctx = None

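
For reference, a minimal sketch of how AsyncSaveContext would be driven by a caller, mirroring the pattern added to models.py above (the dataset name and the field edit are placeholders, not part of this PR):

from concurrent.futures import ThreadPoolExecutor

import fiftyone as fo
import fiftyone.core.collections as foc

dataset = fo.load_dataset("my_dataset")  # placeholder dataset

with foc.AsyncSaveContext(
    dataset, executor=ThreadPoolExecutor(max_workers=1)
) as ctx:
    for sample in dataset:
        sample["reviewed"] = True  # any field edit
        ctx.save(sample)  # ops are batched and flushed on the worker thread
# exiting the context drains all pending futures and re-raises write errors
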