diff --git a/fiftyone/core/models.py b/fiftyone/core/models.py index f6f557ed1bf..5be128b62d2 100644 --- a/fiftyone/core/models.py +++ b/fiftyone/core/models.py @@ -485,8 +485,17 @@ def _apply_image_model_data_loader( warning="Async failure labeling batches", ) ) + output_processor = getattr( + model, "_output_processor", lambda output, *_, **__: output + ) + context.enter_context(fou.SetAttributes(model, _output_processor=None)) - def save_batch(sample_batch, labels_batch): + def save_batch(sample_batch, output, image_sizes): + labels_batch = output_processor( + output, + image_sizes, + confidence_thresh=model.config.confidence_thresh, + ) with _handle_batch_error(skip_failures, sample_batch): for sample, labels in zip(sample_batch, labels_batch): if filename_maker is not None: @@ -507,14 +516,15 @@ def save_batch(sample_batch, labels_batch): if isinstance(imgs, Exception): raise imgs + image_sizes = imgs.pop("fo_image_size", [(None, None)]) if needs_samples: - labels_batch = model.predict_all( - imgs, samples=sample_batch - ) + output = model.predict_all(imgs, samples=sample_batch) else: - labels_batch = model.predict_all(imgs) + output = model.predict_all(imgs) - submit(save_batch, sample_batch, labels_batch) + submit( + save_batch, sample_batch, output, image_sizes=image_sizes + ) pb.update(len(sample_batch))