1
+ from cog import BasePredictor , Input , Path , BaseModel
2
+ import os
3
+ import cv2
4
+ import time
5
+ import shutil
6
+ import subprocess
7
+ import numpy as np
8
+ from PIL import Image
9
+ import tempfile
10
+ from transformers import AutoModelForCausalLM , AutoTokenizer
11
+ from mmengine .visualization import Visualizer
12
+ from typing import Optional
13
+ from third_parts import VideoReader
14
+
15
# Local directory the model weights are loaded from (and downloaded into).
MODEL_CACHE = "checkpoints"

# Archive the weights are fetched from when MODEL_CACHE does not exist yet.
# Alternative model sizes are kept below for reference; exactly one
# MODEL_URL assignment must stay active.
# MODEL_URL = "https://weights.replicate.delivery/default/ByteDance/Sa2VA-4B/model.tar"
MODEL_URL = "https://weights.replicate.delivery/default/ByteDance/Sa2VA-8B/model.tar"
# MODEL_URL = "https://weights.replicate.delivery/default/ByteDance/Sa2VA-26B/model.tar"
19
+
20
class Output(BaseModel):
    # Result returned by Predictor.predict.

    # Path of the rendered mask video, or None when the model reply
    # contained no "[SEG]" token and therefore no masks were rendered.
    masked_video: Optional[Path]
    # The model's textual reply, converted with str().
    response: str
23
+
24
def download_weights(url, dest):
    """Fetch the model weights with the external ``pget`` tool.

    Runs ``pget -xf <url> <dest>`` and prints the source, the destination
    and the elapsed wall-clock time in seconds.

    Args:
        url: Location of the weights archive.
        dest: Destination path handed to ``pget``.

    Raises:
        subprocess.CalledProcessError: If ``pget`` exits with a non-zero
            status.
    """
    began = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    command = ["pget", "-xf", url, dest]
    subprocess.check_call(command, close_fds=False)
    elapsed = time.time() - began
    print("downloading took: ", elapsed)
30
+
31
def read_video(video_path, video_interval):
    """Sample frames from a video and store each one as a JPEG file.

    Args:
        video_path: Location of the video file.
        video_interval: Step between sampled frames; every
            ``video_interval``-th frame returned by ``VideoReader`` is kept.

    Returns:
        A ``(frames, paths)`` tuple. ``frames`` holds the sampled frames as
        PIL images and ``paths`` holds, at matching indices, the paths of
        their JPEG copies. All JPEGs live in one freshly created temporary
        directory, which the caller is responsible for deleting.

    Raises:
        ValueError: If the video cannot be opened, contains no frames, or
            none of the sampled frames is usable.
    """
    # Probe with OpenCV first so an unreadable file fails with a clear
    # message. The capture is released on every path, including failure.
    cap = cv2.VideoCapture(str(video_path))
    try:
        opened = cap.isOpened()
    finally:
        cap.release()
    if not opened:
        raise ValueError(f"Failed to open video file: {video_path}")

    vid_frames = VideoReader(video_path)[::video_interval]
    # Explicit length check: the sliced result may be an array-like whose
    # truth value is not usable.
    if len(vid_frames) == 0:
        raise ValueError(f"No frames could be read from video: {video_path}")

    # mkdtemp() already creates the directory, so no makedirs() is needed.
    temp_dir = tempfile.mkdtemp()
    image_paths = []
    processed_frames = []

    try:
        for frame_idx, frame_image in enumerate(vid_frames):
            if frame_image is None:
                continue

            # Reverse the channel axis (BGR to RGB) before handing the
            # frame to PIL.
            frame_image = Image.fromarray(frame_image[..., ::-1])
            processed_frames.append(frame_image)

            # File names use the index within the sampled sequence, so
            # skipped frames leave gaps in the numbering.
            image_path = os.path.join(temp_dir, f"frame_{frame_idx:04d}.jpg")
            frame_image.save(image_path, format="JPEG")
            image_paths.append(image_path)

        if not processed_frames:
            raise ValueError("No valid frames were processed from the video")
    except Exception:
        # The caller never learns about temp_dir when we fail, so remove it
        # here instead of leaking it.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise

    return processed_frames, image_paths
65
+
66
def visualize(pred_mask, image_path, work_dir):
    """Overlay a segmentation mask on an image and save the result.

    The mask is drawn in green with 40% opacity on top of the image read
    from ``image_path``.

    Args:
        pred_mask: Binary mask passed to ``Visualizer.draw_binary_masks``.
        image_path: Path of the image to draw on.
        work_dir: Directory the rendered image is written to.

    Returns:
        Path of the written image: ``work_dir`` joined with the file name
        of ``image_path``.

    Raises:
        ValueError: If ``image_path`` cannot be read as an image.
        OSError: If the rendered image cannot be written.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising;
    # fail here with a clear message instead of deep inside the visualizer.
    if img is None:
        raise ValueError(f"Failed to read image: {image_path}")

    visualizer = Visualizer()
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()

    output_path = os.path.join(work_dir, os.path.basename(image_path))
    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(output_path, visual_result):
        raise OSError(f"Failed to write image: {output_path}")
    return output_path
76
+
77
class Predictor(BasePredictor):
    """Cog predictor that answers a text instruction about a video.

    When the model's reply contains a ``[SEG]`` token, the first set of
    predicted masks is additionally rendered as a black-and-white MP4.
    """

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        os.environ["TRANSFORMERS_OFFLINE"] = "1"

        # Download weights if they don't exist
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        # Load model and tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_CACHE,
            torch_dtype="auto",
            device_map="cuda:0",
            trust_remote_code=True,
        ).eval().cuda()

        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_CACHE,
            trust_remote_code=True,
        )

    def predict(
        self,
        video: Path = Input(description="Input video for segmentation"),
        instruction: str = Input(description="Text instruction for the model"),
        frame_interval: int = Input(
            description="Frame interval for processing", default=6, ge=1, le=30
        ),
    ) -> Output:
        """Run a single prediction on the model.

        Args:
            video: Video to analyse.
            instruction: Text instruction; it is sent to the model prefixed
                with ``<image>``.
            frame_interval: Every ``frame_interval``-th frame is fed to the
                model; the output frame rate is reduced by the same factor.

        Returns:
            An ``Output`` with the model's reply and, when the reply
            contains ``[SEG]``, the path of the rendered mask video.

        Raises:
            ValueError: If the video cannot be opened or yields no frames.
            RuntimeError: If the intermediate video writer cannot be opened.
            subprocess.CalledProcessError: If the ffmpeg conversion fails.
        """
        output_dir = "/tmp/output"
        # Start from an empty output directory so files from earlier runs
        # cannot be mistaken for results of this one.
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir)

        # Process video frames
        vid_frames, image_paths = read_video(str(video), frame_interval)
        # read_video writes all frames into a single mkdtemp() directory and
        # raises if there are none, so image_paths[0] exists and its parent
        # is that per-call directory.
        frames_dir = os.path.dirname(image_paths[0])

        try:
            new_fps = self._output_fps(str(video), frame_interval)

            result = self.model.predict_forward(
                video=vid_frames,
                text=f"<image>{instruction}",
                tokenizer=self.tokenizer,
            )
            prediction = result['prediction']

            masked_video_path = None
            if '[SEG]' in prediction:
                masked_video_path = os.path.join(output_dir, "masked_video.mp4")
                # Only the first mask set is rendered.
                self._write_mask_video(
                    result['prediction_masks'][0],
                    image_paths,
                    new_fps,
                    masked_video_path,
                )
        finally:
            # The extracted frames are only needed during this call; remove
            # them so repeated predictions do not fill up the temp space.
            shutil.rmtree(frames_dir, ignore_errors=True)

        return Output(
            masked_video=Path(masked_video_path) if masked_video_path else None,
            response=str(prediction),
        )

    @staticmethod
    def _output_fps(video_path, frame_interval):
        """Return the frame rate of the output: source fps / frame_interval."""
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise ValueError("Failed to open video file")
            original_fps = cap.get(cv2.CAP_PROP_FPS)
        finally:
            # Release the capture on the failure path too.
            cap.release()

        # Default to 30fps if unable to read. Written as "not > 0" so that
        # NaN and negative values are caught along with 0.
        if not original_fps > 0:
            original_fps = 30.0
        return original_fps / frame_interval

    @staticmethod
    def _write_mask_video(pred_masks, image_paths, fps, output_path):
        """Render one binary mask per frame into an H.264 MP4 at output_path."""
        # All intermediate files live in a directory that is removed when
        # the block exits, whether it succeeds or raises.
        with tempfile.TemporaryDirectory() as work_dir:
            seg_frames = []
            mask_frame_paths = []

            for frame_idx, image_path in enumerate(image_paths):
                pred_mask = pred_masks[frame_idx]

                # Visualized frame with segmentation overlay. Only the first
                # one is read back (for the frame size); none of them is
                # part of the returned output.
                seg_frames.append(visualize(pred_mask, image_path, work_dir))

                # Binary mask frame: 0 for background, 255 for the mask.
                binary_mask = pred_mask.astype('uint8') * 255
                binary_mask_path = os.path.join(
                    work_dir, f"binary_mask_{frame_idx}.png"
                )
                cv2.imwrite(binary_mask_path, binary_mask)
                mask_frame_paths.append(binary_mask_path)

            # Read first frame for dimensions
            height, width = cv2.imread(seg_frames[0]).shape[:2]

            # Write an MJPG AVI first, then let ffmpeg transcode it into a
            # web-compatible MP4.
            temp_masked_path = os.path.join(work_dir, "temp_masked.avi")
            fourcc = cv2.VideoWriter_fourcc(*'MJPG')
            writer = cv2.VideoWriter(
                temp_masked_path, fourcc, fps, (width, height), isColor=False
            )
            try:
                if not writer.isOpened():
                    raise RuntimeError("Failed to open video writer for mask video")
                for mask_frame_path in mask_frame_paths:
                    writer.write(cv2.imread(mask_frame_path, cv2.IMREAD_GRAYSCALE))
            finally:
                writer.release()

            subprocess.run([
                'ffmpeg', '-y', '-i', temp_masked_path, '-c:v', 'libx264',
                '-preset', 'fast', '-pix_fmt', 'yuv420p',
                # libx264 with yuv420p rejects odd frame sizes; pad up to
                # the next even width/height (a no-op for even sizes).
                '-vf', 'pad=ceil(iw/2)*2:ceil(ih/2)*2',
                output_path,
            ], check=True)
0 commit comments