-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathrun_mae_vis.py
287 lines (251 loc) · 10.3 KB
/
run_mae_vis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
# -*- coding: utf-8 -*-
import argparse
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
import os
from pathlib import Path
# import datetime
import numpy as np
# import time
import torch
import torch.backends.cudnn as cudnn
from decord import VideoReader, cpu
from einops import rearrange
from petrel_client.client import Client
from PIL import Image
from timm.data import create_transform
from timm.data.constants import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
IMAGENET_INCEPTION_MEAN,
IMAGENET_INCEPTION_STD,
)
from timm.models import create_model
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage
import modeling_pretrain
import utils
from datasets import DataAugmentationForMAE
from kinetics import VideoClsDataset
from mae import VideoMAE
from masking_generator import (
RandomMaskingGenerator,
TemporalCenteringProgressiveMaskingGenerator,
TemporalConsistencyMaskingGenerator,
TemporalProgressiveMaskingGenerator,
)
from transforms import *
class DataAugmentationForMAE(object):
    """Per-clip preprocessing pipeline plus masked-position generator.

    This definition rebinds the name ``DataAugmentationForMAE`` that the
    file also imports from ``datasets``; the class below is the one the
    script ends up using.

    Attributes read from ``args``:
        mask_type: ``'random'`` or ``'t_consist'``; anything else is rejected.
        input_size: forwarded to ``GroupCenterCrop``.
        window_size, mask_ratio: forwarded to the masking generator.

    ``args.imagenet_default_mean_and_std`` is not consulted: normalization
    always uses the constants stored in ``input_mean`` / ``input_std``.
    """

    def __init__(self, args):
        # Resolve the mask strategy first so an unsupported ``mask_type``
        # fails immediately with a clear message instead of surfacing later
        # as an AttributeError on ``masked_position_generator``.
        if args.mask_type == 'random':
            generator_cls = RandomMaskingGenerator
        elif args.mask_type == 't_consist':
            generator_cls = TemporalConsistencyMaskingGenerator
        else:
            raise ValueError(
                "Unsupported mask_type %r for visualization; expected "
                "'random' or 't_consist'." % (args.mask_type,))

        self.input_mean = [0.485, 0.456, 0.406]
        self.input_std = [0.229, 0.224, 0.225]
        normalize = GroupNormalize(self.input_mean, self.input_std)
        self.train_augmentation = GroupCenterCrop(args.input_size)
        self.transform = transforms.Compose([
            self.train_augmentation,
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            normalize,
        ])
        self.masked_position_generator = generator_cls(
            args.window_size, args.mask_ratio)

    def __call__(self, images):
        """Transform ``images`` and draw a fresh mask.

        Args:
            images: passed unchanged to ``self.transform``, which must return
                a 2-tuple; only its first element is kept.

        Returns:
            ``(processed_data, mask)`` where ``mask`` is whatever
            ``self.masked_position_generator()`` returns.
        """
        process_data, _ = self.transform(images)
        return process_data, self.masked_position_generator()

    def __repr__(self):
        # The leading label is kept exactly as the original emitted it.
        text = "(DataAugmentationForBEiT,\n"
        text += "  transform = %s,\n" % str(self.transform)
        text += "  Masked position generator = %s,\n" % str(
            self.masked_position_generator)
        text += ")"
        return text
def get_args():
    """Parse the command line of the visualization script.

    Positional arguments are ``img_path``, ``save_path`` and ``model_path``;
    every other option has a default.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``-h/--help``.
    """
    # Help is enabled: this parser is parsed directly (never used as a
    # parent parser), so there is no reason to hide -h/--help.
    parser = argparse.ArgumentParser('MAE visualization reconstruction script')
    parser.add_argument('img_path', type=str, help='input image path')
    parser.add_argument('save_path', type=str, help='save image path')
    parser.add_argument('model_path',
                        type=str,
                        help='checkpoint path of model')
    parser.add_argument(
        '--mask_type',
        default='random',
        choices=['random', 't_consist', 't_progressive', 't_center_prog'],
        type=str,
        help='masked strategy of visual tokens/patches')
    parser.add_argument('--num_frames', type=int, default=16)
    parser.add_argument('--sampling_rate', type=int, default=4)
    parser.add_argument('--decoder_depth',
                        default=4,
                        type=int,
                        help='depth of decoder')
    parser.add_argument('--input_size',
                        default=224,
                        type=int,
                        help='images input size for backbone')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='device to use for training / testing')
    # NOTE: with default=True and action='store_true' this flag is always
    # True; it is kept as-is so existing command lines keep working.
    parser.add_argument('--imagenet_default_mean_and_std',
                        default=True,
                        action='store_true')
    parser.add_argument(
        '--mask_ratio',
        default=0.75,
        type=float,
        help='ratio of the visual tokens/patches need be masked')

    # Model parameters
    parser.add_argument('--model',
                        default='pretrain_mae_base_patch16_224',
                        type=str,
                        metavar='MODEL',
                        help='Name of model to vis')
    # The help text now states the real default (it used to say 0.1).
    parser.add_argument('--drop_path',
                        type=float,
                        default=0.0,
                        metavar='PCT',
                        help='Drop path rate (default: 0.0)')
    return parser.parse_args()
def get_model(args):
    """Build the model named by ``args.model`` through ``create_model``.

    Reads ``args.model``, ``args.drop_path`` and ``args.decoder_depth``;
    no pretrained weights are requested here.
    """
    print(f"Creating model: {args.model}")
    build_kwargs = dict(
        pretrained=False,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        decoder_depth=args.decoder_depth,
    )
    return create_model(args.model, **build_kwargs)
def _open_video(path):
    """Open ``path`` as a decord ``VideoReader`` (``s3:`` paths go through petrel)."""
    if path.startswith("s3:"):
        client = Client()
        video_bytes = client.get(path)
        return VideoReader(memoryview(video_bytes), mc=True, ctx=cpu(0))
    # NOTE(review): as in the original code, the reader is used after the
    # file handle is closed; presumably decord consumes the file object
    # while constructing the reader -- confirm against the decord version used.
    with open(path, 'rb') as f:
        return VideoReader(f, ctx=cpu(0))


def _select_frame_ids(num_frames, duration, start=60, stride=2):
    """Return ``num_frames`` indices ``start, start + stride, ...`` clamped to the clip."""
    if duration <= 0:
        raise ValueError("video contains no frames")
    ids = start + stride * np.arange(num_frames)
    # Clamp so that short clips repeat their last frame instead of failing.
    return np.clip(ids, 0, duration - 1).tolist()


def _save_frames(video, save_dir, prefix, num_frames, clamp=False):
    """Save each frame of a ``(1, C, T, H, W)`` tensor as ``<prefix><t>.jpg``."""
    to_pil = ToPILImage()
    for frame_idx in range(num_frames):
        frame = video[0, :, frame_idx, :, :].cpu()
        if clamp:
            # Bounds kept exactly as in the original reconstruction export.
            frame = frame.clamp(0, 0.996)
        to_pil(frame).save(f"{save_dir}/{prefix}{frame_idx}.jpg")


def main(args):
    """Reconstruct one masked clip and write the resulting frames to disk.

    Loads the checkpoint at ``args.model_path``, samples ``args.num_frames``
    frames from ``args.img_path``, runs the model on the masked clip and
    writes ``ori_img*.jpg``, ``rec_img*.jpg`` and ``mask_img*.jpg`` into
    ``args.save_path``.

    Side effects: sets ``args.window_size`` and ``args.patch_size``, prints
    progress information, creates ``args.save_path`` if needed.

    Raises:
        ValueError: if ``args.num_frames`` is not a multiple of the temporal
            patch length (2) or the video has no frames.
    """
    tubelet = 2  # temporal patch length assumed throughout this script

    print(args)
    if args.num_frames % tubelet:
        raise ValueError(
            f"num_frames must be a multiple of {tubelet}, got {args.num_frames}")

    device = torch.device(args.device)
    cudnn.benchmark = True

    model = get_model(args)
    patch_size = model.encoder.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (args.num_frames // tubelet,
                        args.input_size // patch_size[0],
                        args.input_size // patch_size[1])
    args.patch_size = patch_size
    _, grid_h, grid_w = args.window_size

    model.to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(args.model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.eval()

    if args.save_path:
        Path(args.save_path).mkdir(parents=True, exist_ok=True)

    vr = _open_video(args.img_path)
    # With the default num_frames=16 and a long enough clip this yields the
    # same indices as before (60, 62, ..., 90). --sampling_rate is not used.
    frame_id_list = _select_frame_ids(args.num_frames, len(vr))
    video_data = vr.get_batch(frame_id_list).asnumpy()
    print(video_data.shape)
    frames = [Image.fromarray(frame).convert('RGB') for frame in video_data]

    # Named ``augment`` so the torchvision ``transforms`` module is not shadowed.
    augment = DataAugmentationForMAE(args)
    img, bool_masked_pos = augment((frames, None))  # T*C,H,W
    # T*C,H,W -> T,C,H,W -> C,T,H,W
    img = img.view((args.num_frames, 3) + img.size()[-2:]).transpose(0, 1)
    bool_masked_pos = torch.from_numpy(bool_masked_pos)

    patchify = 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c'
    unpatchify = 'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)'
    patch_dims = dict(p0=tubelet, p1=patch_size[0], p2=patch_size[1])

    with torch.no_grad():
        img = img.unsqueeze(0)
        print(img.shape)
        bool_masked_pos = bool_masked_pos.unsqueeze(0)

        img = img.to(device, non_blocking=True)
        bool_masked_pos = bool_masked_pos.to(
            device, non_blocking=True).flatten(1).to(torch.bool)
        outputs = model(img, bool_masked_pos)

        # save original img
        mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[
            None, :, None, None, None]
        std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[
            None, :, None, None, None]
        print(img.shape)
        ori_img = img * std + mean  # in [0, 1]
        _save_frames(ori_img, args.save_path, "ori_img", args.num_frames)

        # Per-patch statistics; the prediction replaces the masked patches
        # in this normalized space. p2 uses the patch width (patch_size[1]).
        img_squeeze = rearrange(ori_img, patchify, **patch_dims)
        patch_mean = img_squeeze.mean(dim=-2, keepdim=True)
        patch_std = img_squeeze.var(
            dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
        img_norm = (img_squeeze - patch_mean) / patch_std
        img_patch = rearrange(img_norm, 'b n p c -> b n (p c)')
        img_patch[bool_masked_pos] = outputs

        # make mask: 1 for visible patches, 0 for masked ones
        mask = torch.ones_like(img_patch)
        mask[bool_masked_pos] = 0
        mask = rearrange(mask, 'b n (p c) -> b n p c', c=3)
        mask = rearrange(mask, unpatchify, h=grid_h, w=grid_w, **patch_dims)

        # save reconstruction img
        rec_img = rearrange(img_patch, 'b n (p c) -> b n p c', c=3)
        # Notice: To visualize the reconstruction image, we add the predict
        # and the original mean and var of each patch. Issue #40
        rec_img = rec_img * patch_std + patch_mean
        rec_img = rearrange(rec_img, unpatchify, h=grid_h, w=grid_w,
                            **patch_dims)
        _save_frames(rec_img, args.save_path, "rec_img", args.num_frames,
                     clamp=True)

        # save random mask img
        img_mask = rec_img * mask
        _save_frames(img_mask, args.save_path, "mask_img", args.num_frames)
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the visualization.
    main(get_args())