Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update base_tracker.py #139

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ pip install -r requirements.txt

# Run the Track-Anything gradio demo.
python app.py --device cuda:0

# If you are on Apple Silicon (e.g. an M2 Mac), use the MPS backend instead:
python app.py --device mps
# python app.py --device cuda:0 --sam_model_type vit_b # for lower memory usage
```

Expand Down
45 changes: 26 additions & 19 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def get_frames_from_video(video_input, video_state):
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
print(video_info)
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
gr.update(visible=True),\
gr.update(visible=True), gr.update(visible=True), \
Expand Down Expand Up @@ -336,18 +337,24 @@ def generate_video_from_frames(frames, output_path, fps=30):
output_path (str): The path to save the generated video.
fps (int, optional): The frame rate of the output video. Defaults to 30.
"""
# height, width, layers = frames[0].shape
# fourcc = cv2.VideoWriter_fourcc(*"mp4v")
# video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
# print(output_path)
# for frame in frames:
# video.write(frame)

height, width, layers = frames[0].shape
print(f"Video width: {width}, height: {height}")
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
print(output_path)
for frame in frames:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
video.write(frame)
# zhaifang add
'''
height, width, layers = frames[0].shape
print(f"Video width: {width}, height: {height}")
# video.release()
frames = torch.from_numpy(np.asarray(frames))
if not os.path.exists(os.path.dirname(output_path)):
os.makedirs(os.path.dirname(output_path))
torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
'''
return output_path


Expand Down Expand Up @@ -377,8 +384,8 @@ def generate_video_from_frames(frames, output_path, fps=30):
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
args.port = 12212
args.device = "cuda:3"
args.port = 7860
args.device = "mps"
# args.mask_save = True

# initialize sam, xmem, e2fgvi models
Expand Down Expand Up @@ -428,8 +435,8 @@ def generate_video_from_frames(frames, output_path, fps=30):

# for user video input
with gr.Column():
with gr.Row(scale=0.4):
video_input = gr.Video(autosize=True)
with gr.Row():#scale=0.4
video_input = gr.Video()#autosize=True
with gr.Column():
video_info = gr.Textbox(label="Video Info")
resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
Expand All @@ -454,16 +461,17 @@ def generate_video_from_frames(frames, output_path, fps=30):
interactive=True,
visible=False)
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False).style(height=160)
#clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False).style(height=160)
clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False)
Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False)
image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)

with gr.Column():
run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=False)
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
video_output = gr.Video(autosize=True, visible=False).style(height=360)
video_output = gr.Video(autoplay=True, visible=False)
with gr.Row():
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
inpaint_video_predict_button = gr.Button(value="Inpainting", visible=False)
Expand Down Expand Up @@ -583,20 +591,19 @@ def generate_video_from_frames(frames, output_path, fps=30):
clear_button_click.click(
fn = clear_click,
inputs = [video_state, click_state,],
outputs = [template_frame,click_state, run_status],
outputs = [template_frame,click_state, run_status]
)
# set example
gr.Markdown("## Examples")
gr.Examples(
examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample8.mp4","test-sample4.mp4", \
"test-sample2.mp4","test-sample13.mp4"]],
"test-sample2.mp4","test-sample13.mp4", "RGB_video.mp4"]],
fn=run_example,
inputs=[
video_input
],
outputs=[video_input],
# cache_examples=True,
)
iface.queue(concurrency_count=1)
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
# iface.launch(debug=True, enable_queue=True)

iface.launch(debug=True, server_port=args.port, server_name="127.0.0.1",max_threads=1,share=True)
20 changes: 20 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Packaging script for the ``tracker`` distribution."""
from setuptools import setup, find_packages

# Read the long description through a context manager so the file handle is
# closed deterministically, and with an explicit encoding so the build does
# not depend on the platform's default locale.
with open('README.md', encoding='utf-8') as readme_file:
    readme_text = readme_file.read()

setup(
    name="tracker",
    version="0.2.1",
    # Pick up every package (directory with an __init__.py) under this root.
    packages=find_packages(),
    install_requires=[],
    author="zhaifang",
    author_email="[email protected]",
    description="xmem tracking for 3 sensor short long-term memory",
    long_description=readme_text,
    long_description_content_type='text/markdown',
    url="[email protected]:bingxinhu/Track-Anything.git",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires='>=3.11',
)
4 changes: 2 additions & 2 deletions tools/base_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class BaseSegmenter:
def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
def __init__(self, SAM_checkpoint, model_type, device='mps'):
"""
device: model device
SAM_checkpoint: path of SAM checkpoint
Expand Down Expand Up @@ -85,7 +85,7 @@ def predict(self, prompts, mode, multimask=True):
# initialise BaseSegmenter
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
model_type = 'vit_h'
device = "cuda:4"
device = "mps"
base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)

# image embedding (once embedded, multiple prompts can be applied)
Expand Down
11 changes: 6 additions & 5 deletions track_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ def generator(self, images: list, template_mask:np.ndarray):

def parse_augment():
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default="cuda:0")
parser.add_argument('--device', type=str, default="mps")
parser.add_argument('--sam_model_type', type=str, default="vit_h")
parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications")
parser.add_argument('--debug', action="store_true")
parser.add_argument('--mask_save', default=False)
parser.add_argument('--mask_save', default=True)
args = parser.parse_args()

if args.debug:
Expand All @@ -78,7 +78,7 @@ def parse_augment():
logits = None
painted_images = None
images = []
image = np.array(PIL.Image.open('/hhd3/gaoshang/truck.jpg'))
image = np.array(PIL.Image.open('./img/dogs.jpg'))
args = parse_augment()
# images.append(np.ones((20,20,3)).astype('uint8'))
# images.append(np.ones((20,20,3)).astype('uint8'))
Expand All @@ -87,10 +87,11 @@ def parse_augment():

mask = np.zeros_like(image)[:,:,0]
mask[0,0]= 1
trackany = TrackingAnything('/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth','/ssd1/gaomingqi/checkpoints/XMem-s012.pth', args)
trackany = TrackingAnything('./checkpoints/sam_vit_h_4b8939.pth','./checkpoints/XMem-s012.pth', './checkpoints/E2FGVI-HQ-CVPR22.pth', args)
masks, logits ,painted_images= trackany.generator(images, mask)







Empty file added tracker/__init__.py
Empty file.
Binary file added tracker/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
Binary file added tracker/__pycache__/base_tracker.cpython-311.pyc
Binary file not shown.
24 changes: 18 additions & 6 deletions tracker/base_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import torch
import yaml
import torch.nn.functional as F
from tracker.inference.inference_core import InferenceCore
from tracker.model.network import XMem
from inference.inference_core import InferenceCore

from tracker.util.mask_mapper import MaskMapper
from torchvision import transforms
from tracker.util.range_transform import im_normalization
Expand All @@ -25,6 +26,11 @@ def __init__(self, xmem_checkpoint, device, sam_model=None, model_type=None) ->
device: model device
xmem_checkpoint: checkpoint of XMem model
"""
if device is None:
if torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load configurations
with open("tracker/config/config.yaml", 'r') as stream:
config = yaml.safe_load(stream)
Expand Down Expand Up @@ -103,7 +109,7 @@ def track(self, frame, first_frame_annotation=None):

# print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')

return final_mask, final_mask, painted_image
return final_mask, probs, painted_image

@torch.no_grad()
def sam_refinement(self, frame, logits, ti):
Expand All @@ -126,8 +132,11 @@ def sam_refinement(self, frame, logits, ti):
def clear_memory(self):
    """Reset tracker state and release cached accelerator memory.

    Calls ``clear_memory()`` on ``self.tracker`` and ``clear_labels()`` on
    ``self.mapper``, then empties the allocator cache of the backend that
    ``self.device`` refers to (CUDA or MPS). Other device types are left
    untouched.
    """
    self.tracker.clear_memory()
    self.mapper.clear_labels()

    # self.device may be an indexed string such as "cuda:0" or a torch.device
    # object; an exact comparison against "cuda"/"mps" would miss both, so
    # normalise to the device *type* before choosing the backend.
    device_type = torch.device(self.device).type
    if device_type == "cuda":
        torch.cuda.empty_cache()
    elif device_type == "mps":
        torch.mps.empty_cache()


## how to use:
## 1/3) prepare device and xmem_checkpoint
Expand Down Expand Up @@ -155,7 +164,7 @@ def clear_memory(self):
# how to use
# ------------------------------------------------------------------------------------
# 1/4: set checkpoint and device
device = 'cuda:2'
device = 'mps'
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
# SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
# model_type = 'vit_h'
Expand All @@ -179,7 +188,10 @@ def clear_memory(self):
# ----------------------------------------------
# end
# ----------------------------------------------
print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
if device == "cuda":
print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
if device == "mps":
print(f'max memory allocated: {torch.mps.driver_allocated_memory()/(2**20)} MB')
# set saving path
save_path = '/ssd1/gaomingqi/results/TAM/blackswan'
if not os.path.exists(save_path):
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
7 changes: 4 additions & 3 deletions tracker/inference/inference_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from inference.memory_manager import MemoryManager
from model.network import XMem
from model.aggregate import aggregate

from tracker.inference.memory_manager import MemoryManager

from tracker.model.aggregate import aggregate
from tracker.model.network import XMem
from tracker.util.tensor_util import pad_divide_by, unpad


Expand Down
4 changes: 2 additions & 2 deletions tracker/inference/memory_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import warnings

from inference.kv_memory_store import KeyValueMemoryStore
from model.memory_util import *
from tracker.model.memory_util import *
from tracker.inference.kv_memory_store import KeyValueMemoryStore


class MemoryManager:
Expand Down
Binary file added tracker/model/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
Binary file added tracker/model/__pycache__/aggregate.cpython-311.pyc
Binary file not shown.
Binary file added tracker/model/__pycache__/cbam.cpython-311.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tracker/model/__pycache__/modules.cpython-311.pyc
Binary file not shown.
Binary file added tracker/model/__pycache__/network.cpython-311.pyc
Binary file not shown.
Binary file added tracker/model/__pycache__/resnet.cpython-311.pyc
Binary file not shown.
8 changes: 5 additions & 3 deletions tracker/model/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
import torch.nn as nn
import torch.nn.functional as F

from model.group_modules import *
from model import resnet
from model.cbam import CBAM
from tracker.model import resnet
from tracker.model.cbam import CBAM
from tracker.model.group_modules import GConv2D, GroupResBlock, MainToGroupDistributor, downsample_groups, upsample_groups




class FeatureFusionBlock(nn.Module):
Expand Down
7 changes: 4 additions & 3 deletions tracker/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

import torch
import torch.nn as nn
from tracker.model.aggregate import aggregate
from tracker.model.memory_util import get_affinity, readout
from tracker.model.modules import Decoder, KeyEncoder, KeyProjection, ValueEncoder


from model.aggregate import aggregate
from model.modules import *
from model.memory_util import *


class XMem(nn.Module):
Expand Down
Binary file added tracker/util/__pycache__/__init__.cpython-311.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.