-
Notifications
You must be signed in to change notification settings - Fork 678
[NO MERGE] Ray postprocessing and writes. kinda goes zoom. #6293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,8 @@ | |
foutr = fou.lazy_import("fiftyone.utils.transformers") | ||
fouu = fou.lazy_import("fiftyone.utils.ultralytics") | ||
|
||
foray = fou.lazy_import("fiftyone.core.ray.base") | ||
foray_writers = fou.lazy_import("fiftyone.core.ray.writers") | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -459,7 +461,18 @@ def _apply_image_model_data_loader( | |
|
||
with contextlib.ExitStack() as context: | ||
pb = context.enter_context(fou.ProgressBar(samples, progress=progress)) | ||
ctx = context.enter_context(foc.SaveContext(samples)) | ||
output_processor = model._output_processor | ||
ctx = context.enter_context( | ||
foray.ActorPoolContext( | ||
samples, | ||
foray_writers.LabelWriter, | ||
num_workers=16, | ||
label_field=label_field, | ||
confidence_thresh=confidence_thresh, | ||
post_processor=output_processor, | ||
) | ||
) | ||
context.enter_context(fou.SetAttributes(model, _output_processor=None)) | ||
|
||
for sample_batch, imgs in zip( | ||
fou.iter_batches(samples, batch_size), | ||
|
@@ -470,22 +483,13 @@ def _apply_image_model_data_loader( | |
raise imgs | ||
|
||
if needs_samples: | ||
labels_batch = model.predict_all( | ||
ids, labels_batch = model.predict_all( | ||
imgs, samples=sample_batch | ||
) | ||
else: | ||
labels_batch = model.predict_all(imgs) | ||
|
||
for sample, labels in zip(sample_batch, labels_batch): | ||
if filename_maker is not None: | ||
_export_arrays(labels, sample.filepath, filename_maker) | ||
ids, labels_batch = model.predict_all(imgs) | ||
|
||
sample.add_labels( | ||
labels, | ||
label_field=label_field, | ||
confidence_thresh=confidence_thresh, | ||
) | ||
ctx.save(sample) | ||
ctx.submit(ids, labels_batch) | ||
Comment on lines
+486
to
+492
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Breaking change in predict_all return signature The change from This needs careful coordination:
# Consider adding compatibility wrapper:
def _predict_all_compat(model, imgs, samples=None):
"""Wrapper to handle both old and new predict_all signatures."""
result = model.predict_all(imgs, samples=samples) if samples else model.predict_all(imgs)
if isinstance(result, tuple) and len(result) == 2:
return result # New format: (ids, predictions)
else:
# Old format: just predictions, extract IDs from imgs if available
ids = imgs.get("_id", None)
return ids, result 🤖 Prompt for AI Agents
|
||
|
||
except Exception as e: | ||
if not skip_failures: | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,4 @@ | ||||||||||||||||||||||||||
import ray | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if not ray.is_initialized(): | ||||||||||||||||||||||||||
ray.init() | ||||||||||||||||||||||||||
Comment on lines
+3
to
+4
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Critical: Ray initialization at import time can cause issues Initializing Ray at module import time is problematic for several reasons:
Consider lazy initialization or providing an explicit initialization function: import ray
-if not ray.is_initialized():
- ray.init()
+def ensure_ray_initialized(**kwargs):
+ """Initialize Ray if not already initialized.
+
+ Args:
+ **kwargs: Optional Ray initialization parameters
+ """
+ if not ray.is_initialized():
+ ray.init(**kwargs) Then call this function only when Ray features are actually needed, such as in 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,72 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import ray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import fiftyone.core.view as fov | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def serialize_samples(samples): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not a good pattern because you're guaranteeing that the process will require a database connection just to resolve the file path. It would be much better to just resolve the file path directly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What file paths are we talking about? Generally speaking, we want these workers to have a database connection. One of the goals is for them to interact with FO datasets in parallel to the main process. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No you don't want to give workers database connections. There is zero benefit because the data (media) is not even in the database and cannot be retrieved using the database connection There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are writer workers. They must each hold some connection or share access to a pool of connections for us to write multiple things in parallel. Frankly we don't even need a ton of them because most of what they do is sit around waiting for I/O, so a single one that's multithreaded would probably be just fine. Unrelated grievances with our multi-worker read system can be discussed elsewhere. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dataset_name = samples._root_dataset.name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
stages = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
samples._serialize() if isinstance(samples, fov.DatasetView) else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return dataset_name, stages | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def deserialize_samples(serialized_samples): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import fiftyone as fo | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dataset_name, stages = serialized_samples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dataset = fo.load_dataset(dataset_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if stages is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return fov.DatasetView._build(dataset, stages) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return dataset | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class FiftyOneActor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Class for FiftyOne Ray actors. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
serialized_samples: a serialized representation of a | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
:class:`fiftyone.core.collections.SampleCollection` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__(self, serialized_samples, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
super().__init__(**kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.samples = deserialize_samples(serialized_samples) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+33
to
+35
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do not forward arbitrary kwargs to object.init (will raise TypeError).
Apply this diff: def __init__(self, serialized_samples, **kwargs):
- super().__init__(**kwargs)
- self.samples = deserialize_samples(serialized_samples)
+ # Don't pass kwargs to object.__init__ (TypeError). Store for subclasses.
+ super().__init__()
+ self.samples = deserialize_samples(serialized_samples)
+ self._init_kwargs = kwargs 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class ActorPoolContext: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Context manager for a pool of Ray actors. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
samples: a :class:`fiftyone.core.collections.SampleCollection` | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
actor_type: the :class:`FiftyOneActor` subclass to instantiate | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for each worker | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
num_workers (int): the number of workers in the pool | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__(self, samples, actor_type, *args, num_workers=4, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
super().__init__() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.serialized_samples_ref = ray.put(serialize_samples(samples)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.num_workers = num_workers | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.actor_type = actor_type | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.actors = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.actor_type.remote( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.serialized_samples_ref, *args, **kwargs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for _ in range(self.num_workers) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.pool = ray.util.ActorPool(self.actors) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __enter__(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __exit__(self, *args): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Clean up refs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for actor in self.actors: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
del actor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
del self.serialized_samples_ref | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+64
to
+70
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Ensure graceful teardown: drain pending results and terminate Ray actors. Deleting Python references does not stop Ray actors; processes continue until the cluster GC kills them. Drain pending results to avoid backpressure, then explicitly terminate actors. Apply this diff: def __exit__(self, *args):
- # Clean up refs
- for actor in self.actors:
- del actor
-
- del self.serialized_samples_ref
+ # Drain any pending results so ActorPool marks actors idle
+ try:
+ while hasattr(self.pool, "has_next") and self.pool.has_next():
+ self.pool.get_next_unordered()
+ except Exception:
+ pass
+ # Explicitly terminate remote actors
+ for actor in getattr(self, "actors", []):
+ try:
+ ray.kill(actor)
+ except Exception:
+ pass
+ # Release references
+ try:
+ self.actors.clear()
+ except Exception:
+ pass
+ self.pool = None
+ try:
+ del self.serialized_samples_ref
+ except Exception:
+ pass 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def submit(self, ids, payloads): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.pool.submit(lambda a, v: a.run.remote(*v), (ids, payloads)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+71
to
+72
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ActorPool will stall without consuming results; opportunistically drain to keep progress.
Apply this diff to opportunistically release one completed task per submit: def submit(self, ids, payloads):
- self.pool.submit(lambda a, v: a.run.remote(*v), (ids, payloads))
+ self.pool.submit(lambda a, v: a.run.remote(*v), (ids, payloads))
+ # Opportunistically free one finished task so actors re-enter the idle pool
+ if hasattr(self.pool, "has_next") and self.pool.has_next():
+ self.pool.get_next_unordered() Additionally, add helper methods so callers can drain explicitly (outside this hunk): # Add to ActorPoolContext class (e.g., after submit)
def get_next(self):
return self.pool.get_next_unordered()
def drain(self):
while self.pool.has_next():
self.pool.get_next_unordered() 🤖 Prompt for AI Agents
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,40 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import ray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import fiftyone.core.ray.base as foray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import fiftyone.core.collections as foc | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from fiftyone.core.ray.base import FiftyOneActor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@ray.remote | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class LabelWriter(FiftyOneActor): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
serialized_samples, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
label_field, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
confidence_thresh=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
post_processor=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
**kwargs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
super().__init__(serialized_samples, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.label_field = label_field | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.confidence_thresh = confidence_thresh | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.post_processor = post_processor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.ctx = foc.SaveContext(self.samples) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def run(self, ids, payloads): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
samples_batch = self.samples.select(ids) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.post_processor is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
payloads = self.post_processor( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
*payloads, confidence_thresh=self.confidence_thresh | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
with self.ctx: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for sample, payload in zip(samples_batch, payloads): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sample.add_labels( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
payload, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is what actually needs to be refactored. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any optimizations to the base code are welcome. That said, notice that we are offloading multiple things here from the main process, namely:
all of these parts have some impact on final |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
label_field=self.label_field, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
confidence_thresh=self.confidence_thresh, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.ctx.save(sample) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+25
to
+40
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add error handling and validation The def run(self, ids, payloads):
+ if ids is None or not ids:
+ return # Nothing to process
+
samples_batch = self.samples.select(ids)
+
+ # Validate that we got the expected number of samples
+ if len(samples_batch) != len(ids):
+ raise ValueError(
+ f"Expected {len(ids)} samples but got {len(samples_batch)}"
+ )
if self.post_processor is not None:
- payloads = self.post_processor(
- *payloads, confidence_thresh=self.confidence_thresh
- )
+ try:
+ payloads = self.post_processor(
+ *payloads, confidence_thresh=self.confidence_thresh
+ )
+ except Exception as e:
+ raise RuntimeError(f"Post-processor failed: {e}") from e
+
+ # Ensure payloads matches samples count after processing
+ if len(payloads) != len(samples_batch):
+ raise ValueError(
+ f"Payload count {len(payloads)} doesn't match sample count {len(samples_batch)}"
+ )
with self.ctx:
for sample, payload in zip(samples_batch, payloads):
- sample.add_labels(
- payload,
- label_field=self.label_field,
- confidence_thresh=self.confidence_thresh,
- )
- self.ctx.save(sample)
+ try:
+ sample.add_labels(
+ payload,
+ label_field=self.label_field,
+ confidence_thresh=self.confidence_thresh,
+ )
+ self.ctx.save(sample)
+ except Exception as e:
+ # Log error but continue processing other samples
+ print(f"Failed to save labels for sample {sample.id}: {e}")
+ # Or re-raise if you want to fail the entire batch
+ # raise 📝 Committable suggestion
Suggested change
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -883,7 +883,8 @@ def predict_all(self, imgs): | |
of dicts of :class:`fiftyone.core.labels.Label` instances | ||
containing the predictions | ||
""" | ||
return self._predict_all(imgs) | ||
ids = imgs.pop("_id", None) | ||
return ids, self._predict_all(imgs) | ||
|
||
def _predict_all(self, imgs): | ||
if self._preprocess and self._transforms is not None: | ||
|
@@ -921,9 +922,9 @@ def _predict_all(self, imgs): | |
|
||
if self._output_processor is None: | ||
if isinstance(output, torch.Tensor): | ||
output = output.detach().cpu().numpy() | ||
output = output.detach().cpu() | ||
|
||
return output | ||
return output, (width, height) | ||
Comment on lines
+925
to
+927
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Inconsistent return type when no output processor When there's no output processor, the method returns Consider either:
|
||
|
||
if self.has_logits: | ||
self._output_processor.store_logits = self.store_logits | ||
|
@@ -1913,17 +1914,20 @@ def __getitem__(self, idx): | |
return self.__getitems__([idx])[0] | ||
|
||
def __getitems__(self, indices): | ||
_ids = [self.ids[idx] for idx in indices] | ||
if self.vectorize: | ||
batch = self._prepare_batch_vectorized(indices) | ||
else: | ||
batch = self._prepare_batch_db(indices) | ||
|
||
res = [] | ||
for d in batch: | ||
for i, d in enumerate(batch): | ||
if isinstance(d, Exception): | ||
res.append(d) | ||
else: | ||
res.append(self._get_item(d)) | ||
_processed = self._get_item(d) | ||
_processed.update({"_id": _ids[i]}) | ||
res.append(_processed) | ||
|
||
return res | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -681,7 +681,17 @@ def _predict_all(self, args): | |
) | ||
|
||
else: | ||
return output | ||
for k, v in output.items(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this |
||
if isinstance(v, torch.Tensor): | ||
output[k] = v.detach().cpu() | ||
elif isinstance(v, (tuple, list)): | ||
output[k] = [ | ||
i.detach().cpu() | ||
for i in v | ||
if isinstance(i, torch.Tensor) | ||
] | ||
|
||
return output, image_sizes | ||
|
||
def _forward_pass(self, args): | ||
return self._model( | ||
|
@@ -709,6 +719,10 @@ def collate_fn(batch): | |
keys = batch[0].keys() | ||
res = {} | ||
for k in keys: | ||
if not isinstance(batch[0][k], (torch.Tensor, np.ndarray)): | ||
# not a tensor, just return the list | ||
res[k] = [b[k] for b in batch] | ||
continue | ||
# Gather shapes for dimension analysis | ||
shapes = [b[k].shape for b in batch] | ||
# Find the max size in each dimension | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clearing _output_processor may cause issues
Setting
_output_processor=None
on the model modifies its internal state, which could cause problems if the model is used elsewhere or if an error occurs before it's restored.The current approach modifies the model's internal state which could lead to:
Consider passing the output processor directly to the actor pool without modifying the model.