diff --git a/fiftyone/core/models.py b/fiftyone/core/models.py index f6f557ed1bf..0f88ddd9d3d 100644 --- a/fiftyone/core/models.py +++ b/fiftyone/core/models.py @@ -75,6 +75,7 @@ def apply_model( store_logits=False, batch_size=None, num_workers=None, + num_writer_workers=None, skip_failures=True, output_dir=None, rel_dir=None, @@ -105,6 +106,8 @@ def apply_model( batching num_workers (None): the number of workers to use when loading images. Only applicable for Torch-based models + num_writer_workers (None): the number of thread workers to use when + writing predictions. Only applicable for Torch-based models. skip_failures (True): whether to gracefully continue without raising an error if predictions cannot be generated for a sample. Only applicable to :class:`Model` instances @@ -284,6 +287,7 @@ def apply_model( filename_maker, progress, field_mapping, + num_writer_workers, ) if batch_size is not None: @@ -457,6 +461,7 @@ def _apply_image_model_data_loader( filename_maker, progress, field_mapping, + num_writer_workers, ): needs_samples = isinstance(model, SamplesMixin) @@ -480,7 +485,7 @@ def _apply_image_model_data_loader( ctx = context.enter_context(foc.SaveContext(samples)) submit = context.enter_context( fou.async_executor( - max_workers=1, + max_workers=num_writer_workers, skip_failures=skip_failures, warning="Async failure labeling batches", ) @@ -904,6 +909,7 @@ def compute_embeddings( embeddings_field=None, batch_size=None, num_workers=None, + num_writer_workers=None, skip_failures=True, progress=None, **kwargs, @@ -934,6 +940,8 @@ def compute_embeddings( batching num_workers (None): the number of workers to use when loading images. Only applicable for Torch-based models + num_writer_workers (None): the number of thread workers to use when + writing embeddings. Only applicable for Torch-based models. skip_failures (True): whether to gracefully continue without raising an error if embeddings cannot be generated for a sample. Only applicable to :class:`Model` instances @@ -1078,6 +1086,7 @@ def compute_embeddings( skip_failures, progress, field_mapping, + num_writer_workers, ) if batch_size is not None: @@ -1200,6 +1209,7 @@ def _compute_image_embeddings_data_loader( skip_failures, progress, field_mapping, + num_writer_workers, ): data_loader = _make_data_loader( samples, @@ -1224,7 +1234,7 @@ def _compute_image_embeddings_data_loader( submit = context.enter_context( fou.async_executor( - max_workers=1, + max_workers=num_writer_workers, skip_failures=skip_failures, warning="Async failure saving embeddings", ) diff --git a/fiftyone/core/utils.py b/fiftyone/core/utils.py index c3fb2d841a1..d6a48b9027e 100644 --- a/fiftyone/core/utils.py +++ b/fiftyone/core/utils.py @@ -2859,14 +2859,19 @@ def recommend_thread_pool_workers(num_workers=None): If a ``fo.config.max_thread_pool_workers`` is set, this limit is applied. Args: - num_workers (None): a suggested number of workers + num_workers (None): a suggested number of workers. If ``num_workers <= 0``, this + function returns 1. Returns: a number of workers """ + if num_workers is None: num_workers = multiprocessing.cpu_count() + if num_workers <= 0: + num_workers = 1 + if fo.config.max_thread_pool_workers is not None: num_workers = min(num_workers, fo.config.max_thread_pool_workers) @@ -3147,7 +3152,7 @@ def validate_hex_color(value): @contextmanager def async_executor( - *, max_workers, skip_failures=False, warning="Async failure" + *, max_workers=None, skip_failures=False, warning="Async failure" ): """ Context manager that provides a function for submitting tasks to a thread @@ -3160,7 +3165,8 @@ def async_executor( submit(process_item, item) Args: - max_workers: the maximum number of workers to use + max_workers (None): the maximum number of workers to use. By default, + this is determined by :func:`fiftyone.core.utils.recommend_thread_pool_workers`. skip_failures (False): whether to skip exceptions raised by tasks warning ("Async failure"): the warning message to log if a task raises an exception and ``skip_failures == True`` @@ -3168,6 +3174,12 @@ def async_executor( Raises: Exception: if a task raises an exception and ``skip_failures == False`` """ + if max_workers is None: + max_workers = ( + fo.config.default_thread_pool_workers + or recommend_thread_pool_workers(max_workers) + ) + with ThreadPoolExecutor(max_workers=max_workers) as executor: _futures = []