This repository was archived by the owner on May 13, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathVideoPreprocessor.py
More file actions
158 lines (126 loc) · 5.61 KB
/
VideoPreprocessor.py
File metadata and controls
158 lines (126 loc) · 5.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import gc
import os
import pathlib
from typing import List, Any, Dict
import torch
import ffmpeg
from ..utils.load_video import v2
from ..utils import find_video_stream
__all__ = ["VideoPreprocessor"]
class VideoPreprocessor(object):
    """Two-stage video preprocessing pipeline.

    Stage one re-encodes a source video through an ffmpeg filter chain
    (fps resampling, rescaling, center cropping) and writes the result to
    a mirrored save path. Stage two loads that video, splits it into
    temporal segments, resamples each segment to a fixed frame count via
    interpolation, and stores the stacked tensor as a ``.pt`` file.
    """

    # Default ffmpeg filter chain; add more entries if necessary.
    # NOTE: __init__ hands each instance its own shallow copy so that
    # stage_one's pop("fps") can never mutate this shared class default
    # (the original code mutated it for every instance).
    __filters: Dict[str, Dict[str, Any]] = {
        "fps": {"fps": 15, "round": "up"},
        "scale": {"w": 320, "h": 320, "sws_flags": "neighbor"},
        "crop": {"out_w": 224, "out_h": 224, "exact": 1, "keep_aspect": 1},
    }

    def __init__(self,
                 fpath: str,
                 save_root: str,
                 dataset_name: str,
                 device: str,
                 num_segments: int = 32,
                 num_frames: int = 30,
                 filters: Dict[str, Dict[str, Any]] = None,
                 ) -> None:
        """
        Args:
            fpath: path to the source video; must contain ``dataset_name``
                as one of its path components.
            save_root: directory component inserted just before
                ``dataset_name`` to build the save path.
            dataset_name: name of the dataset directory inside ``fpath``.
            device: ``"cuda"`` or ``"cpu"``; selects ffmpeg hwaccel and
                tensor placement.
            num_segments: number of temporal segments produced in stage two.
            num_frames: frame count each segment is resampled to.
            filters: optional ffmpeg filter chain; falsy means "use the
                class defaults".
        """
        super().__init__()
        self.__fpath: str = fpath
        self.__spath: str = self._make_spath(dataset_name, save_root)
        self.__device: str = device
        self.__num_segments: int = num_segments
        self.__num_frames: int = num_frames
        # Shallow copy: stage_one pops "fps" from this dict, and without the
        # copy that pop would corrupt the shared class-level default.
        self.__filters = dict(filters) if filters else dict(type(self).__filters)
        self.__is_labeled: bool = self._is_labeled(dataset_name)
        os.makedirs((pathlib.Path(self.__spath)).parent, exist_ok=True)

    @property
    def fpath(self) -> str:
        """Source video path."""
        return self.__fpath

    @property
    def spath(self) -> str:
        """Save path for the preprocessed video / tensor."""
        return self.__spath

    @property
    def is_labeled(self) -> bool:
        """Whether the video sits under a ``labeled`` directory."""
        return self.__is_labeled

    @is_labeled.setter
    def is_labeled(self, value: bool) -> None:
        self.__is_labeled = value

    def _is_labeled(self, ds_name: str) -> bool:
        """Return True when the path component right after ``ds_name``
        is the literal directory name ``"labeled"``.

        Raises:
            ValueError: if ``ds_name`` is not a component of ``fpath``.
            IndexError: if ``ds_name`` is the last component.
        """
        path_components: List[str] = self.__fpath.split(os.sep)
        return path_components[path_components.index(ds_name) + 1] == "labeled"

    def _make_spath(self, dataset_name: str, save_root: str) -> str:
        """Build the save path by inserting ``save_root`` immediately
        before the ``dataset_name`` component of ``fpath``."""
        # Split once and reuse (the original split the path twice).
        path_components: List[str] = self.__fpath.split(os.sep)
        path_components.insert(path_components.index(dataset_name), save_root)
        return f"{os.sep}".join(path_components)

    def stage_one(self, run_async: bool = False) -> None:
        """
        Stage one includes:
            a/ Resampling video with specified fps (skipped for labeled
               videos, whose native frame timing is preserved)
            b/ Rescale frame
            c/ Central crop frame
            d/ Save video stream as output

        Args:
            run_async: when True, launch ffmpeg without waiting for it.
        """
        if os.path.exists(self.__spath):
            return  # already produced; nothing to do
        if self.__is_labeled:
            # Labeled clips keep their native frame rate; safe now because
            # __filters is a per-instance copy.
            self.__filters.pop("fps", None)
        probe_info: Dict[str, Any] = ffmpeg.probe(self.fpath)
        video_stream = find_video_stream(probe_info["streams"])
        try:
            if self.__device == "cuda":
                stream = ffmpeg.input(self.__fpath, hwaccel="cuda")[video_stream]
            else:
                stream = ffmpeg.input(self.__fpath)[video_stream]
            for filter_name, kwargs in self.__filters.items():
                stream = stream.filter(filter_name, **kwargs)
            stream = stream.output(self.__spath, pix_fmt="rgb24", loglevel="error")
            stream = stream.overwrite_output()
            # Explicit statement instead of side-effecting conditional expression.
            if run_async:
                stream.run_async()
            else:
                stream.run()
        except ffmpeg.Error as e:
            # Best-effort batch processing: report, drop any partial output,
            # and move on instead of aborting the whole run.
            print(f"File: {self.__fpath} get {e}\n so ignore it")
            if os.path.exists(self.__spath):
                os.remove(self.__spath)

    def stage_two(self,
                  del_prev_result: bool = False
                  ) -> None:
        """
        Stage two includes:
            a/ Split video into segments
            b/ Temporal sampling by interpolation (up/ down scaling)
            c/ Save video stream as .pt file

        Args:
            del_prev_result: when True, delete the stage-one video file
                after the tensor has been saved.
        """
        ext: str = pathlib.Path(self.__spath).suffix
        pt_path: str = self.__spath.replace(ext, ".pt")
        if os.path.isfile(pt_path):
            return None  # tensor already produced
        video: torch.Tensor = v2(self.__spath, device=self.__device)  # [T,H,W,C] in cpu device
        total_frames: int = video.shape[0]
        # num_segments + 1 boundary points yield num_segments segments.
        seg_start_idx: torch.Tensor = torch.linspace(
            0, total_frames, self.__num_segments + 1
        ).clamp(0, total_frames).int()
        save_tensor: None | torch.Tensor = None
        if self.__is_labeled:
            # Labeled videos are stored as-is, without temporal resampling.
            save_tensor = video
        else:
            for i in range(self.__num_segments):
                start, end = seg_start_idx[i].item(), seg_start_idx[i + 1].item()
                indices: torch.Tensor = torch.arange(start, end, device=video.device, dtype=torch.int32)
                # Downsample with nearest-exact; upsample with trilinear.
                inter_mode = "nearest-exact" if indices.shape[0] > self.__num_frames else "trilinear"
                frames: torch.Tensor = torch.index_select(video, 0, indices)
                # [T,H,W,C] -> [1,C,T,H,W] as expected by 3-D interpolate.
                frames = frames.to("cpu").permute(-1, 0, 1, 2).unsqueeze(0).to(self.__device)
                frames = torch.nn.functional.interpolate(
                    # trilinear requires a floating dtype; cast back to uint8 after.
                    frames.type(torch.float32) if inter_mode == "trilinear" else frames,
                    (self.__num_frames, *frames.shape[-2:]),
                    mode=inter_mode
                ).to("cpu").type(torch.uint8)
                save_tensor = frames if save_tensor is None else torch.vstack((save_tensor, frames))
        torch.save(save_tensor, pt_path)
        # Register the tensor *class* (the original passed the tensor instance,
        # which add_safe_globals does not accept) so a later
        # torch.load(..., weights_only=True) will allow it.
        torch.serialization.add_safe_globals([torch.Tensor])
        if del_prev_result:
            os.remove(self.__spath)
        del video
        gc.collect()
        torch.cuda.empty_cache()
        torch.serialization.clear_safe_globals()
        return None