
Commit b4ae619

Merge branch 'main' of github.com:magic-research/Sa2VA
2 parents: 1a3e544 + c490564

File tree: 4 files changed, +323 -1 lines

README.md (+1 -1)
@@ -1,6 +1,6 @@
 # Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos
 
-[\[🏠 Sa2VA\]](https://lxtgh.github.io/project/sa2va) [\[📜 arXiv\]](https://arxiv.org/abs/2501.04001) [\[🤗 HuggingFace\]](https://huggingface.co/collections/ByteDance/sa2va-model-zoo-677e3084d71b5f108d00e093) [\[🎥 Introduction\]]() [\[🧑‍💻 GitHub\]](https://github.com/magic-research/Sa2VA) [\[Gradio Demo (Ours internal: Sa2VA-4B)\]](https://5512470799b6b35fbc.gradio.live/) [\[Gradio Demo (By HuggingFace Offical)\]](https://huggingface.co/spaces/fffiloni/Sa2VA-simple-demo)
+[\[🏠 Sa2VA\]](https://lxtgh.github.io/project/sa2va) [\[📜 arXiv\]](https://arxiv.org/abs/2501.04001) [\[🤗 HuggingFace\]](https://huggingface.co/collections/ByteDance/sa2va-model-zoo-677e3084d71b5f108d00e093) [\[🎥 Introduction\]]() [\[🧑‍💻 GitHub\]](https://github.com/magic-research/Sa2VA) [\[Gradio Demo (Ours internal: Sa2VA-4B)\]](https://5512470799b6b35fbc.gradio.live/) [\[Gradio Demo (By HuggingFace Offical)\]](https://huggingface.co/spaces/fffiloni/Sa2VA-simple-demo) [\[🤖 Replicate Demo\]](https://replicate.com/bytedance)
 
 
 [**Haobo Yuan**](https://yuanhaobo.me/)<sup>1*</sup> · [**Xiangtai Li**](https://lxtgh.github.io/)<sup>2*&dagger;</sup> · [**Tao Zhang**](https://zhang-tao-whu.github.io/)<sup>2,3*</sup> · [**Zilong Huang**](http://speedinghzl.github.io/)<sup>2</sup> · [**Shilin Xu**](https://xushilin1.github.io/)<sup>4</sup> · [**Shunping Ji**](https://scholar.google.com/citations?user=FjoRmF4AAAAJ&hl=en)<sup>3</sup> · [**Yunhai Tong**](https://scholar.google.com/citations?user=T4gqdPkAAAAJ&hl=zh-CN)<sup>4</sup> ·
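The new badge links to ByteDance's Replicate account rather than a specific model page. As a hedged sketch, a hosted Sa2VA model could be called from Python roughly as follows; the model slug and the input names are assumptions modeled on demo/predict-img.py below, not anything this commit confirms.

# Hypothetical sketch of calling a Sa2VA deployment on Replicate.
# The slug "bytedance/sa2va-8b" is an assumption; see
# https://replicate.com/bytedance for the actual listings.
import replicate

output = replicate.run(
    "bytedance/sa2va-8b",  # hypothetical model slug
    input={
        "image": open("cat.jpg", "rb"),
        "instruction": "Please segment the cat.",
    },
)
print(output)  # expected shape: an "img" mask file plus a "response" string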

demo/cog.yaml (new file, +26 lines)
build:
  gpu: true
  cuda: "12.4"
  python_version: "3.10"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
    - "ffmpeg"
  python_packages:
    - "torch==2.4.0"
    - "torchvision"
    - "transformers==4.42.3"
    - "opencv-python-headless<4.10"
    - "peft<0.14.0"
    - "timm==1.0.9"
    - "einops==0.8.0"
    - "sentencepiece==0.2.0"
    - "mmengine<1"
    - "accelerate"
    - "numpy<2"

  run:
    - FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip install flash-attn --no-build-isolation
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget

predict: "predict.py:Predictor"
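The predict entry points at predict.py:Predictor, while this commit adds predict-img.py and predict-vid.py, so one of them presumably gets copied to predict.py before building. Under that assumption, a local test with the Cog CLI might look like this sketch (cat.jpg is a placeholder):

# Hedged sketch: wire up the image predictor and run it locally with Cog.
cp demo/predict-img.py predict.py
cog predict -i image=@cat.jpg -i instruction="Please segment the cat."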

demo/predict-img.py (new file, +104 lines)
from cog import BasePredictor, Input, Path, BaseModel
import os
import cv2
import time
import subprocess
import numpy as np
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
from mmengine.visualization import Visualizer
from typing import Optional

MODEL_CACHE = "checkpoints"
# MODEL_URL = "https://weights.replicate.delivery/default/ByteDance/Sa2VA-4B/model.tar"
MODEL_URL = "https://weights.replicate.delivery/default/ByteDance/Sa2VA-8B/model.tar"
# MODEL_URL = "https://weights.replicate.delivery/default/ByteDance/Sa2VA-26B/model.tar"

class Output(BaseModel):
    img: Optional[Path]
    response: str

def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)

class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        os.environ["TRANSFORMERS_OFFLINE"] = "1"

        # Download weights if they don't exist
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        # Load model and tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_CACHE,
            torch_dtype="auto",
            device_map="cuda:0",
            trust_remote_code=True,
        ).eval().cuda()

        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_CACHE,
            trust_remote_code=True,
        )

    def predict(
        self,
        image: Path = Input(description="Input image for segmentation"),
        instruction: str = Input(description="Text instruction for the model"),
    ) -> Output:
        """Run a single prediction on the model"""
        # Prepare the image
        image = Image.open(str(image)).convert('RGB')

        # Prepare the input
        text_prompts = f"<image>{instruction}"
        input_dict = {
            'image': image,
            'text': text_prompts,
            'past_text': '',
            'mask_prompts': None,
            'tokenizer': self.tokenizer,
        }

        # Get model prediction
        return_dict = self.model.predict_forward(**input_dict)
        answer = return_dict["prediction"]

        # Handle segmentation if present
        output_path = None
        if '[SEG]' in answer:
            pred_masks = return_dict["prediction_masks"][0]

            # Ensure mask is in the correct format
            if isinstance(pred_masks, np.ndarray):
                binary_mask = (pred_masks > 0.5).astype('uint8') * 255
            else:
                binary_mask = (pred_masks.cpu().numpy() > 0.5).astype('uint8') * 255

            # Ensure mask has valid dimensions
            if binary_mask.ndim == 2:
                height, width = binary_mask.shape
            elif binary_mask.ndim == 3:
                # If we have a 3D array, take the first channel
                binary_mask = binary_mask[0] if binary_mask.shape[0] == 1 else binary_mask[:, :, 0]
                height, width = binary_mask.shape
            else:
                return Output(img=None, response=str(answer))

            # Check if dimensions are valid and mask is not empty
            if width > 0 and height > 0 and np.any(binary_mask):
                # Create output directory if it doesn't exist
                os.makedirs("/tmp", exist_ok=True)

                # Save the binary mask
                output_path = "/tmp/output.png"
                if cv2.imwrite(output_path, binary_mask):
                    return Output(img=Path(output_path), response=str(answer))

        return Output(img=None, response=str(answer))
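predict-img.py returns the raw binary mask rather than a blended overlay. Below is a minimal post-processing sketch, assuming placeholder file names (cat.jpg as the original input, /tmp/output.png as the saved mask), that tints the predicted region in green, much like visualize() does in predict-vid.py further down:

# Hedged sketch: composite the predicted mask onto the original image.
# "cat.jpg" and "overlay.png" are placeholder names.
import cv2

img = cv2.imread("cat.jpg")
mask = cv2.imread("/tmp/output.png", cv2.IMREAD_GRAYSCALE)
# Match the mask size to the image, in case the model resized internally.
mask = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)

overlay = img.copy()
overlay[mask > 127] = (0, 255, 0)  # solid green where the mask fires
# 40% tint, mirroring alphas=0.4 used by predict-vid.py's visualize().
blended = cv2.addWeighted(img, 0.6, overlay, 0.4, 0)
cv2.imwrite("overlay.png", blended)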

demo/predict-vid.py (new file, +192 lines)
from cog import BasePredictor, Input, Path, BaseModel
import os
import cv2
import time
import shutil
import subprocess
import numpy as np
from PIL import Image
import tempfile
from transformers import AutoModelForCausalLM, AutoTokenizer
from mmengine.visualization import Visualizer
from typing import Optional
from third_parts import VideoReader

MODEL_CACHE = "checkpoints"
# MODEL_URL = "https://weights.replicate.delivery/default/ByteDance/Sa2VA-4B/model.tar"
MODEL_URL = "https://weights.replicate.delivery/default/ByteDance/Sa2VA-8B/model.tar"
# MODEL_URL = "https://weights.replicate.delivery/default/ByteDance/Sa2VA-26B/model.tar"

class Output(BaseModel):
    masked_video: Optional[Path]
    response: str

def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)

def read_video(video_path, video_interval):
    # First verify the video can be opened
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise ValueError(f"Failed to open video file: {video_path}")
    cap.release()

    # Read frames using VideoReader
    vid_frames = VideoReader(video_path)[::video_interval]
    if len(vid_frames) == 0:
        raise ValueError(f"No frames could be read from video: {video_path}")

    temp_dir = tempfile.mkdtemp()
    os.makedirs(temp_dir, exist_ok=True)
    image_paths = []
    processed_frames = []

    for frame_idx, frame_image in enumerate(vid_frames):
        if frame_image is None:
            continue

        # Convert BGR to RGB
        frame_image = frame_image[..., ::-1]  # BGR to RGB
        frame_image = Image.fromarray(frame_image)
        processed_frames.append(frame_image)

        image_path = os.path.join(temp_dir, f"frame_{frame_idx:04d}.jpg")
        frame_image.save(image_path, format="JPEG")
        image_paths.append(image_path)

    if not processed_frames:
        raise ValueError("No valid frames were processed from the video")

    return processed_frames, image_paths

def visualize(pred_mask, image_path, work_dir):
    visualizer = Visualizer()
    img = cv2.imread(image_path)
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()

    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)
    return output_path

class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        os.environ["TRANSFORMERS_OFFLINE"] = "1"

        # Download weights if they don't exist
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        # Load model and tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_CACHE,
            torch_dtype="auto",
            device_map="cuda:0",
            trust_remote_code=True,
        ).eval().cuda()

        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_CACHE,
            trust_remote_code=True,
        )

    def predict(
        self,
        video: Path = Input(description="Input video for segmentation"),
        instruction: str = Input(description="Text instruction for the model"),
        frame_interval: int = Input(description="Frame interval for processing", default=6, ge=1, le=30),
    ) -> Output:
        """Run a single prediction on the model"""
        # Clean up past runs: remove the /tmp/output folder
        if os.path.exists("/tmp/output"):
            shutil.rmtree("/tmp/output")
        os.makedirs("/tmp/output")

        # Process video frames
        vid_frames, image_paths = read_video(str(video), frame_interval)

        # Get video properties for output
        cap = cv2.VideoCapture(str(video))
        if not cap.isOpened():
            raise ValueError("Failed to open video file")
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        if original_fps == 0:
            original_fps = 30.0  # Default to 30fps if unable to read
        new_fps = original_fps / frame_interval if frame_interval > 1 else original_fps
        cap.release()

        # Prepare the input
        question = f"<image>{instruction}"
        result = self.model.predict_forward(
            video=vid_frames,
            text=question,
            tokenizer=self.tokenizer,
        )
        prediction = result['prediction']

        masked_video_path = None

        if '[SEG]' in prediction:
            _seg_idx = 0
            pred_masks = result['prediction_masks'][_seg_idx]
            seg_frames = []
            masked_only_frames = []

            temp_dir = tempfile.mkdtemp()
            os.makedirs(temp_dir, exist_ok=True)

            # Process each frame
            for frame_idx in range(len(vid_frames)):
                pred_mask = pred_masks[frame_idx]

                # Create visualized frame with segmentation overlay
                seg_frame = visualize(pred_mask, image_paths[frame_idx], temp_dir)
                seg_frames.append(seg_frame)

                # Create binary mask frame
                binary_mask = (pred_mask.astype('uint8') * 255)
                binary_mask_path = os.path.join(temp_dir, f"binary_mask_{frame_idx}.png")
                cv2.imwrite(binary_mask_path, binary_mask)
                masked_only_frames.append(binary_mask_path)

            # Read first frame for dimensions
            frame = cv2.imread(seg_frames[0])
            height, width, layers = frame.shape

            # Create output video files
            masked_video_path = "/tmp/output/masked_video.mp4"
            temp_masked_path = "/tmp/output/temp_masked.avi"

            # Define video writer using a more basic codec
            fourcc = cv2.VideoWriter_fourcc(*'MJPG')
            masked_video_writer = cv2.VideoWriter(temp_masked_path, fourcc, new_fps, (width, height), isColor=False)

            # Write frames to video
            for mask_frame_path in masked_only_frames:
                mask_frame = cv2.imread(mask_frame_path, cv2.IMREAD_GRAYSCALE)
                masked_video_writer.write(mask_frame)

            # Release video writer
            masked_video_writer.release()

            # Convert to web-compatible MP4 using ffmpeg
            subprocess.run([
                'ffmpeg', '-i', temp_masked_path, '-c:v', 'libx264',
                '-preset', 'fast', '-pix_fmt', 'yuv420p', masked_video_path
            ], check=True)

            # Clean up temporary file
            os.remove(temp_masked_path)

        return Output(
            masked_video=Path(masked_video_path) if masked_video_path else None,
            response=str(prediction)
        )
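A hedged local smoke test for the video predictor, again assuming the file has been copied to predict.py as cog.yaml expects; demo.mp4 and the instruction text are placeholders:

# Hedged sketch: call the Cog predictor directly, outside the HTTP server.
from predict import Predictor

predictor = Predictor()
predictor.setup()  # first run downloads the Sa2VA-8B weights into ./checkpoints
out = predictor.predict(
    video="demo.mp4",  # placeholder clip
    instruction="Please segment the person.",
    frame_interval=6,  # keep every 6th frame, matching the default
)
print(out.response)      # contains [SEG] when a mask track was produced
print(out.masked_video)  # /tmp/output/masked_video.mp4, or None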
