###############################################################################
### Functions used to compute and visualize the learnt Space-Time attention
### in TimeSformer
###############################################################################
from pathlib import Path
from timesformer.models.vit import *
from timesformer.datasets import utils as utils
from timesformer.config.defaults import get_cfg
from einops import rearrange, repeat, reduce
import cv2
from google.colab.patches import cv2_imshow
import numpy as np            # used below; imported explicitly rather than via the star import
import torch
from torch import einsum      # used in combine_divided_attention
import torchvision.transforms as transforms
from PIL import Image
import json
import matplotlib.pyplot as plt

DEFAULT_MEAN = [0.45, 0.45, 0.45]
DEFAULT_STD = [0.225, 0.225, 0.225]

# convert a frame read by cv2 into a normalized input tensor for the model
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(DEFAULT_MEAN, DEFAULT_STD),
    transforms.Resize(224),
    transforms.CenterCrop(224),
])

# convert a frame path into an image array suitable for cv2_imshow()
transform_plot = transforms.Compose([
    lambda p: cv2.imread(str(p), cv2.IMREAD_COLOR),
    transforms.ToTensor(),
    transforms.Resize(224),
    transforms.CenterCrop(224),
    lambda x: rearrange(x * 255, 'c h w -> h w c').numpy()
])

def get_frames(path_to_video, num_frames=8):
    "Return a list of paths to num_frames frames sampled uniformly from the video."
    path_to_frames = list(path_to_video.iterdir())
    # frame files are assumed to end in a 6-digit frame index
    path_to_frames.sort(key=lambda f: int(f.with_suffix('').name[-6:]))
    assert num_frames <= len(path_to_frames), "num_frames can't exceed the number of frames extracted from the video"
    if len(path_to_frames) == num_frames:
        return path_to_frames
    else:
        # split the video into num_frames segments and take the middle frame of each
        video_length = len(path_to_frames)
        seg_size = float(video_length - 1) / num_frames
        seq = []
        for i in range(num_frames):
            start = int(np.round(seg_size * i))
            end = int(np.round(seg_size * (i + 1)))
            seq.append((start + end) // 2)
        return [path_to_frames[p] for p in seq]

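# A hedged worked example of the sampling rule above (pure arithmetic, no
# files needed): with, say, 20 extracted frames and num_frames=8, each segment
# spans (20 - 1) / 8 frames and the midpoint index of each segment is picked.
def example_sampling_indices(video_length=20, num_frames=8):
    "Illustrative only: reproduce the index math from get_frames."
    seg_size = float(video_length - 1) / num_frames
    return [(int(np.round(seg_size * i)) + int(np.round(seg_size * (i + 1)))) // 2
            for i in range(num_frames)]
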
def create_video_input(path_to_video):
    "Create the (1, C, T, H, W) input tensor for the TimeSformer model."
    path_to_frames = get_frames(path_to_video)
    frames = [transform(cv2.imread(str(p), cv2.IMREAD_COLOR)) for p in path_to_frames]
    frames = torch.stack(frames, dim=0)              # (t, c, h, w)
    frames = rearrange(frames, 't c h w -> c t h w')
    frames = frames.unsqueeze(dim=0)                 # add a batch dimension
    return frames

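# Hedged usage sketch: `create_video_input` assumes a directory of extracted
# frames whose stems end in a 6-digit index (see get_frames); the directory
# name below is hypothetical.
def example_build_input(frames_dir='demo/frames'):
    "Illustrative only: the returned tensor should be (1, 3, 8, 224, 224)."
    return create_video_input(Path(frames_dir))
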
def show_mask_on_image(img, mask):
    "Blend a [0, 1] mask onto an image as a JET heatmap; returns a uint8 array."
    img = np.float32(img) / 255
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

def create_masks(masks_in, np_imgs):
    "Resize each mask to its image and overlay it as a heatmap."
    masks = []
    for mask, img in zip(masks_in, np_imgs):
        mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
        mask = show_mask_on_image(img, mask)
        masks.append(mask)
    return masks

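# Hedged sanity check for the overlay helpers, using synthetic data only: a
# random 14x14 map (the per-frame attention map size for 224x224 inputs with
# 16x16 patches) blended onto a flat gray image.
def example_overlay():
    "Illustrative only: returns a uint8 (224, 224, 3) heatmap overlay."
    img = np.full((224, 224, 3), 128, dtype=np.uint8)
    mask = np.random.rand(14, 14).astype(np.float32)
    return create_masks([mask], [img])[0]
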
def combine_divided_attention(attn_t, attn_s):
    ## time attention
    # average the time attention weights across heads
    attn_t = attn_t.mean(dim=1)
    # prepend an identity matrix for the cls_token, which only attends to itself
    I = torch.eye(attn_t.size(-1)).unsqueeze(0)
    attn_t = torch.cat([I, attn_t], 0)
    # add an identity matrix to account for the skip connection
    attn_t = attn_t + torch.eye(attn_t.size(-1))[None, ...]
    # renormalize
    attn_t = attn_t / attn_t.sum(-1)[..., None]

    ## space attention
    # average across heads
    attn_s = attn_s.mean(dim=1)
    # add the residual and renormalize
    attn_s = attn_s + torch.eye(attn_s.size(-1))[None, ...]
    attn_s = attn_s / attn_s.sum(-1)[..., None]

    ## combine the space and time attention
    attn_ts = einsum('tpk, ktq -> ptkq', attn_s, attn_t)

    ## average the cls_token attention across the frames
    # splice out the attention for the cls_token
    attn_cls = attn_ts[0, :, :, :]
    # average the cls_token attention and repeat it across the frames
    attn_cls_a = attn_cls.mean(dim=0)
    attn_cls_a = repeat(attn_cls_a, 'p t -> j p t', j=attn_ts.size(1))  # number of frames (8 by default)
    # add it back
    attn_ts = torch.cat([attn_cls_a.unsqueeze(0), attn_ts[1:, :, :, :]], 0)
    return attn_ts

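# Hedged shape check for combine_divided_attention using random tensors. The
# sizes are assumptions matching TimeSformer-base on 8x224x224 clips: T=8
# frames, P=196 patches (14x14), H=12 heads; the cls_token makes P+1 spatial
# tokens.
def example_combine_shapes():
    "Illustrative only: the combined attention should be (P+1, T, P+1, T)."
    T, P, H = 8, 196, 12
    attn_t = torch.rand(P, H, T, T)          # temporal attention per patch location
    attn_s = torch.rand(T, H, P + 1, P + 1)  # spatial attention per frame
    return combine_divided_attention(attn_t, attn_s).shape  # torch.Size([197, 8, 197, 8])
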
class DividedAttentionRollout():
    def __init__(self, model, **kwargs):
        self.model = model
        self.hooks = []

    def get_attn_t(self, module, input, output):
        self.time_attentions.append(output.detach().cpu())

    def get_attn_s(self, module, input, output):
        self.space_attentions.append(output.detach().cpu())

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()

    def __call__(self, path_to_video):
        input_tensor = create_video_input(path_to_video)
        self.model.zero_grad()
        self.hooks = []
        self.time_attentions = []
        self.space_attentions = []
        self.attentions = []
        # hook the attention-dropout layers to capture the per-block attention maps
        for name, m in self.model.named_modules():
            if 'temporal_attn.attn_drop' in name:
                self.hooks.append(m.register_forward_hook(self.get_attn_t))
            elif 'attn.attn_drop' in name:
                self.hooks.append(m.register_forward_hook(self.get_attn_s))
        preds = self.model(input_tensor)
        self.remove_hooks()
        # fuse the time and space attention of each block
        for attn_t, attn_s in zip(self.time_attentions, self.space_attentions):
            self.attentions.append(combine_divided_attention(attn_t, attn_s))
        # roll the attention out across the blocks
        p, t = self.attentions[0].shape[0], self.attentions[0].shape[1]
        result = torch.eye(p * t)
        for attention in self.attentions:
            attention = rearrange(attention, 'p1 t1 p2 t2 -> (p1 t1) (p2 t2)')
            result = torch.matmul(attention, result)
        mask = rearrange(result, '(p1 t1) (p2 t2) -> p1 t1 p2 t2', p1=p, p2=p)
        # average over the query frames, then keep the cls_token's attention
        # to the patch tokens (dropping the cls_token column)
        mask = mask.mean(dim=1)
        mask = mask[0, 1:, :]
        width = int(mask.size(0) ** 0.5)
        mask = rearrange(mask, '(h w) t -> h w t', w=width).numpy()
        mask = mask / np.max(mask)
        return mask
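
# Hedged end-to-end sketch. `model` is assumed to be a TimeSformer with
# divided space-time attention, built elsewhere from the repo's config and a
# checkpoint, and configured for 8 frames at 224x224 to match the transforms
# above; the frame directory is hypothetical.
def example_rollout(model, frames_dir='demo/frames'):
    "Illustrative only: compute rollout masks and overlay them on each frame."
    rollout = DividedAttentionRollout(model)
    mask = rollout(Path(frames_dir))                  # (14, 14, 8): one map per frame
    frame_paths = get_frames(Path(frames_dir))
    np_imgs = [transform_plot(p) for p in frame_paths]
    per_frame = [np.ascontiguousarray(m) for m in rearrange(mask, 'h w t -> t h w')]
    for overlay in create_masks(per_frame, np_imgs):
        cv2_imshow(overlay)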