Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 18 additions & 21 deletions fiftyone/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,13 +478,6 @@ def _apply_image_model_data_loader(
with contextlib.ExitStack() as context:
pb = context.enter_context(fou.ProgressBar(samples, progress=progress))
ctx = context.enter_context(foc.SaveContext(samples))
submit = context.enter_context(
fou.async_executor(
max_workers=1,
skip_failures=skip_failures,
warning="Async failure labeling batches",
)
)

def save_batch(sample_batch, labels_batch):
with _handle_batch_error(skip_failures, sample_batch):
Expand All @@ -499,9 +492,15 @@ def save_batch(sample_batch, labels_batch):
)
ctx.save(sample)

for sample_batch, imgs in zip(
fou.iter_batches(samples, batch_size),
data_loader,
for submit, (sample_batch, imgs) in fou.async_iterator(
zip(
fou.iter_batches(samples, batch_size),
data_loader,
),
limit=1,
max_workers=1,
skip_failures=skip_failures,
warning="Async failure labeling batches",
):
with _handle_batch_error(skip_failures, sample_batch):
if isinstance(imgs, Exception):
Expand Down Expand Up @@ -1221,24 +1220,22 @@ def _compute_image_embeddings_data_loader(
else:
ctx = None

submit = context.enter_context(
fou.async_executor(
max_workers=1,
skip_failures=skip_failures,
warning="Async failure saving embeddings",
)
)

def save_batch(sample_batch, embeddings_batch):
with _handle_batch_error(skip_failures, sample_batch):
for sample, embedding in zip(sample_batch, embeddings_batch):
sample[embeddings_field] = embedding
if ctx:
ctx.save(sample)

for sample_batch, imgs in zip(
fou.iter_batches(samples, batch_size),
data_loader,
for submit, (sample_batch, imgs) in fou.async_iterator(
zip(
fou.iter_batches(samples, batch_size),
data_loader,
),
limit=1,
max_workers=1,
skip_failures=skip_failures,
warning="Async failure saving embeddings",
):
embeddings_batch = [None] * len(sample_batch)

Expand Down
88 changes: 84 additions & 4 deletions fiftyone/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import atexit
from bson import json_util
from base64 import b64encode, b64decode
from collections import defaultdict
from collections import defaultdict, deque
from contextlib import contextmanager, suppress
from copy import deepcopy
from datetime import date, datetime
Expand All @@ -36,6 +36,7 @@
import sys
import shutil
import tempfile
import threading
import timeit
import types
import uuid
Expand Down Expand Up @@ -3219,19 +3220,98 @@ def async_executor(
Exception: if a task raises an exception and ``skip_failures == False``
"""
with ThreadPoolExecutor(max_workers=max_workers) as executor:
_futures = []
futures = deque()
lock = threading.Lock()

def remove(future):
with lock:
try:
futures.remove(future)
except ValueError:
pass

def submit(*args, **kwargs):
future = executor.submit(*args, **kwargs)
_futures.append(future)
with lock:
futures.append(future)
future.add_done_callback(remove)
return future

yield submit

for future in _futures:
while futures:
future = futures.popleft()
try:
future.result()
except Exception as e:
if not skip_failures:
raise e
logger.warning(warning, exc_info=True)


def async_iterator(
iterator,
*,
limit,
max_workers,
skip_failures=False,
warning="Async failure",
):
"""
Wraps an iterable with a thread pool executor and provides a function to add
tasks to a background queue. When the background queue has more than
max number of allowed tasks, iteration pauses until a task completes.

Example::

for submit, item in async_iterator(iterator, limit=4, max_workers=4):
submit(task(item))

Args:
iterator: an iterable to consume
limit: the maximum number of background tasks to allow before
pausing iteration
max_workers: the maximum number of workers to use
skip_failures (False): whether to skip exceptions raised by tasks
warning ("Async failure"): the warning message to log if a task
raises an exception and ``skip_failures == True``

Raises:
Exception: if a task raises an exception and ``skip_failures == False``
"""

def wait_for_task(future):
try:
yield future.result()
except Exception as e:
if not skip_failures:
raise e
logger.warning(warning, exc_info=True)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = deque()
lock = threading.Lock()

def remove(future):
with lock:
try:
futures.remove(future)
except ValueError:
pass

def submit(*args, **kwargs):
future = executor.submit(*args, **kwargs)
with lock:
futures.append(future)
future.add_done_callback(remove)
return future

for item in iterator:
yield submit, item
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this supposed to be submit(item)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. This is yielding the submit function and the current item to the consumer loop, so the consumer can call submit where it wants and with the function it wants to run async.


while len(futures) > limit:
wait_for_task(futures.popleft())

while futures:
future = futures.popleft()
wait_for_task(future)
Loading