Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions fiftyone/core/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,26 @@ def _save_batch(self):
self._reload_parents.clear()


class AsyncSaveContext(SaveContext):
def __init__(self, *args, executor=None, **kwargs):
if executor is None:
raise ValueError("executor must be specified")
super().__init__(*args, **kwargs)
self.executor = executor

def __enter__(self):
super().__enter__()
self.executor.__enter__()
return self

def __exit__(self, *args):
super().__exit__(*args)
self.executor.__exit__(*args)

def _save_batch(self):
self.executor.submit(super()._save_batch)


class SampleCollection(object):
"""Abstract class representing an ordered collection of
:class:`fiftyone.core.sample.Sample` instances in a
Expand Down
13 changes: 11 additions & 2 deletions fiftyone/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import contextlib
import inspect
import logging
from concurrent.futures import ThreadPoolExecutor

import numpy as np

Expand Down Expand Up @@ -477,7 +478,11 @@ def _apply_image_model_data_loader(

with contextlib.ExitStack() as context:
pb = context.enter_context(fou.ProgressBar(samples, progress=progress))
ctx = context.enter_context(foc.SaveContext(samples))
ctx = context.enter_context(
foc.AsyncSaveContext(
samples, executor=ThreadPoolExecutor(max_workers=1)
)
)

def save_batch(sample_batch, labels_batch):
with _handle_batch_error(skip_failures, sample_batch):
Expand Down Expand Up @@ -1210,7 +1215,11 @@ def _compute_image_embeddings_data_loader(
with contextlib.ExitStack() as context:
pb = context.enter_context(fou.ProgressBar(samples, progress=progress))
if embeddings_field is not None:
ctx = context.enter_context(foc.SaveContext(samples))
ctx = context.enter_context(
foc.AsyncSaveContext(
samples, executor=ThreadPoolExecutor(max_workers=1)
)
)
else:
ctx = None

Expand Down
Loading