diff --git a/label_studio_ml/examples/segment_anything_2_video/docker-compose.yml b/label_studio_ml/examples/segment_anything_2_video/docker-compose.yml
index f73413a2..e85e7874 100644
--- a/label_studio_ml/examples/segment_anything_2_video/docker-compose.yml
+++ b/label_studio_ml/examples/segment_anything_2_video/docker-compose.yml
@@ -24,9 +24,9 @@ services:
       # specify device
       - DEVICE=cuda # or 'cpu' (coming soon)
       # SAM2 model config
-      - MODEL_CONFIG=sam2_hiera_l.yaml
+      - MODEL_CONFIG=configs/sam2.1/sam2.1_hiera_l.yaml
       # SAM2 checkpoint
-      - MODEL_CHECKPOINT=sam2_hiera_large.pt
+      - MODEL_CHECKPOINT=sam2.1_hiera_large.pt
       # Specify the Label Studio URL and API key to access
       # uploaded, local storage and cloud storage files.
diff --git a/label_studio_ml/examples/segment_anything_2_video/model.py b/label_studio_ml/examples/segment_anything_2_video/model.py
index 2af5b404..e2986b74 100644
--- a/label_studio_ml/examples/segment_anything_2_video/model.py
+++ b/label_studio_ml/examples/segment_anything_2_video/model.py
@@ -20,8 +20,8 @@
 DEVICE = os.getenv('DEVICE', 'cuda')
 SEGMENT_ANYTHING_2_REPO_PATH = os.getenv('SEGMENT_ANYTHING_2_REPO_PATH', 'segment-anything-2')
-MODEL_CONFIG = os.getenv('MODEL_CONFIG', 'sam2_hiera_l.yaml')
-MODEL_CHECKPOINT = os.getenv('MODEL_CHECKPOINT', 'sam2_hiera_large.pt')
+MODEL_CONFIG = os.getenv('MODEL_CONFIG', 'sam2.1_hiera_l.yaml')
+MODEL_CHECKPOINT = os.getenv('MODEL_CHECKPOINT', 'sam2.1_hiera_large.pt')
 MAX_FRAMES_TO_TRACK = int(os.getenv('MAX_FRAMES_TO_TRACK', 10))
 
 if DEVICE == 'cuda':
@@ -73,8 +73,9 @@ def split_frames(self, video_path, temp_dir, start_frame=0, end_frame=100):
             # Read a frame from the video
             success, frame = video.read()
             if frame_count < start_frame:
+                frame_count += 1
                 continue
-            if frame_count + start_frame >= end_frame:
+            if frame_count >= end_frame - 1:
                 break
 
             # If frame is read correctly, success is True
@@ -217,6 +218,10 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
         """ Returns the predicted mask for a smart keypoint that has been placed."""
         from_name, to_name, value = self.get_first_tag_occurence('VideoRectangle', 'Video')
+
+        if not context or not context.get('result'):
+            # if there is no context, no interaction has happened yet
+            return ModelResponse(predictions=[])
 
         task = tasks[0]
         task_id = task['id']
@@ -273,7 +278,7 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
             _, out_obj_ids, out_mask_logits = predictor.add_new_points(
                 inference_state=inference_state,
-                frame_idx=prompt['frame_idx'],
+                frame_idx=prompt['frame_idx'] - first_frame_idx,
                 obj_id=obj_ids[prompt['obj_id']],
                 points=prompt['points'],
                 labels=prompt['labels']
@@ -281,14 +286,15 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse:
 
         sequence = []
-        debug_dir = './debug-frames'
-        os.makedirs(debug_dir, exist_ok=True)
+        #debug_dir = './debug-frames'
+        #os.makedirs(debug_dir, exist_ok=True)
 
         logger.info(f'Propagating in video from frame {last_frame_idx} to {last_frame_idx + frames_to_track}')
+        rel_last = last_frame_idx - first_frame_idx
         for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
             inference_state=inference_state,
-            start_frame_idx=last_frame_idx,
-            max_frame_num_to_track=frames_to_track
+            start_frame_idx=rel_last,
+            max_frame_num_to_track=rel_last + frames_to_track
         ):
             real_frame_idx = out_frame_idx + first_frame_idx
             for i, out_obj_id in enumerate(out_obj_ids):
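
For readers following the index arithmetic in the model.py hunks: the patch converts between absolute video frame numbers and the positions SAM2 sees after split_frames extracts a window starting at first_frame_idx. Below is a minimal sketch of that convention; to_relative and to_absolute are hypothetical helpers added here for illustration only, while first_frame_idx mirrors the variable used in the patch.

# Illustrative sketch only: the index convention the patch relies on.
# SAM2's video predictor numbers frames from 0 within the extracted
# window, so absolute video indices are shifted by first_frame_idx on
# the way in and shifted back on the way out.

def to_relative(frame_idx: int, first_frame_idx: int) -> int:
    # Position of an absolute video frame inside the extracted window
    # (what the patch computes as prompt['frame_idx'] - first_frame_idx).
    return frame_idx - first_frame_idx

def to_absolute(out_frame_idx: int, first_frame_idx: int) -> int:
    # Map a predictor output index back to the source video
    # (the patch's real_frame_idx = out_frame_idx + first_frame_idx).
    return out_frame_idx + first_frame_idx

if __name__ == '__main__':
    first_frame_idx = 90   # window extracted starting at video frame 90
    prompt_frame = 95      # keypoint placed on absolute video frame 95
    rel = to_relative(prompt_frame, first_frame_idx)
    assert rel == 5                                  # SAM2 sees frame 5
    assert to_absolute(rel, first_frame_idx) == 95   # mapped back to video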