Skip to content

Commit 6e348e9

Browse files
committed
Limit number of background tasks
1 parent 96d4bf2 commit 6e348e9

File tree

2 files changed

+75
-21
lines changed

2 files changed

+75
-21
lines changed

fiftyone/core/models.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -478,13 +478,6 @@ def _apply_image_model_data_loader(
478478
with contextlib.ExitStack() as context:
479479
pb = context.enter_context(fou.ProgressBar(samples, progress=progress))
480480
ctx = context.enter_context(foc.SaveContext(samples))
481-
submit = context.enter_context(
482-
fou.async_executor(
483-
max_workers=1,
484-
skip_failures=skip_failures,
485-
warning="Async failure labeling batches",
486-
)
487-
)
488481

489482
def save_batch(sample_batch, labels_batch):
490483
with _handle_batch_error(skip_failures, sample_batch):
@@ -499,9 +492,15 @@ def save_batch(sample_batch, labels_batch):
499492
)
500493
ctx.save(sample)
501494

502-
for sample_batch, imgs in zip(
503-
fou.iter_batches(samples, batch_size),
504-
data_loader,
495+
for submit, (sample_batch, imgs) in fou.async_iterator(
496+
zip(
497+
fou.iter_batches(samples, batch_size),
498+
data_loader,
499+
),
500+
limit=10,
501+
max_workers=1,
502+
skip_failures=skip_failures,
503+
warning="Async failure labeling batches",
505504
):
506505
with _handle_batch_error(skip_failures, sample_batch):
507506
if isinstance(imgs, Exception):
@@ -1221,24 +1220,22 @@ def _compute_image_embeddings_data_loader(
12211220
else:
12221221
ctx = None
12231222

1224-
submit = context.enter_context(
1225-
fou.async_executor(
1226-
max_workers=1,
1227-
skip_failures=skip_failures,
1228-
warning="Async failure saving embeddings",
1229-
)
1230-
)
1231-
12321223
def save_batch(sample_batch, embeddings_batch):
12331224
with _handle_batch_error(skip_failures, sample_batch):
12341225
for sample, embedding in zip(sample_batch, embeddings_batch):
12351226
sample[embeddings_field] = embedding
12361227
if ctx:
12371228
ctx.save(sample)
12381229

1239-
for sample_batch, imgs in zip(
1240-
fou.iter_batches(samples, batch_size),
1241-
data_loader,
1230+
for submit, (sample_batch, imgs) in fou.async_iterator(
1231+
zip(
1232+
fou.iter_batches(samples, batch_size),
1233+
data_loader,
1234+
),
1235+
limit=10,
1236+
max_workers=1,
1237+
skip_failures=skip_failures,
1238+
warning="Async failure saving embeddings",
12421239
):
12431240
embeddings_batch = [None] * len(sample_batch)
12441241

fiftyone/core/utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3222,3 +3222,60 @@ def submit(*args, **kwargs):
32223222
if not skip_failures:
32233223
raise e
32243224
logger.warning(warning, exc_info=True)
3225+
3226+
3227+
def async_iterator(
    iterator,
    *,
    limit,
    max_workers,
    skip_failures=False,
    warning="Async failure",
):
    """
    Wraps an iterable with a thread pool executor and provides a function to
    add tasks to a background queue. When the background queue has more than
    the max number of allowed tasks, iteration pauses until the oldest
    pending task completes.

    Example::

        for submit, item in async_iterator(iterator, limit=4, max_workers=4):
            submit(task, item)

    Args:
        iterator: an iterable to consume
        limit: the maximum number of background tasks to allow before
            pausing iteration
        max_workers: the maximum number of workers to use
        skip_failures (False): whether to skip exceptions raised by tasks
        warning ("Async failure"): the warning message to log if a task
            raises an exception and ``skip_failures == True``

    Yields:
        ``(submit, item)`` tuples, where ``submit(fn, *args, **kwargs)``
        schedules ``fn`` on the thread pool and returns its future

    Raises:
        Exception: if a task raises an exception and ``skip_failures == False``.
            The exception surfaces when the task is awaited, i.e., either
            once more than ``limit`` tasks are pending or after ``iterator``
            is exhausted
    """

    def wait_for_task(future):
        # This must be a plain function, not a generator: a ``yield`` here
        # would turn each call into an un-run generator object, so the
        # future would never be awaited and failures would be silently lost
        try:
            future.result()
        except Exception:
            if not skip_failures:
                raise

            logger.warning(warning, exc_info=True)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Pending futures in submission order (oldest first)
        futures = []

        def submit(*args, **kwargs):
            future = executor.submit(*args, **kwargs)
            futures.append(future)
            return future

        for item in iterator:
            yield submit, item

            # Apply backpressure: block on the oldest tasks until at most
            # ``limit`` tasks remain pending
            while len(futures) > limit:
                wait_for_task(futures.pop(0))

        # Drain the remaining tasks so that their failures are reported
        for future in futures:
            wait_for_task(future)

0 commit comments

Comments
 (0)