import numpy as np
import torch
from PIL import Image
from torch.nn import functional as F
from torch.utils.data.dataset import Dataset

import pframe_dataset_shared

class FrameSequenceDataset(Dataset):
    """
    Dataset that yields sequences of `num_frames_per_sequence` frames, e.g.:
        first element:  ( (f11_y, f11_u, f11_v), (f12_y, f12_u, f12_v) )  # video 1, frames 1, 2
        second element: ( (f12_y, f12_u, f12_v), (f13_y, f13_u, f13_v) )  # video 1, frames 2, 3
    If merge_channels=True, the channels of each frame are merged into one tensor, yielding
        first element:  ( f11, f12 )  # video 1, frames 1, 2
        second element: ( f12, f13 )  # video 1, frames 2, 3
    Data format is always torch-default CHW, dtype is float32, and values are in [0, 1].
    """
    def __init__(self, data_root, merge_channels=False, num_frames_per_sequence=2):
        self.tuple_ps = pframe_dataset_shared.get_paths_for_frame_sequences(
            data_root, num_frames_per_sequence=num_frames_per_sequence)
        self.merge_channels = merge_channels
        self.image_to_tensor = lambda pic: image_to_tensor(pic, normalize=True)

    def __len__(self):
        return len(self.tuple_ps)

    def __getitem__(self, idx):
        # A tuple of tuples of paths, e.g.,
        # ( (f21_y, f21_u, f21_v), (f22_y, f22_u, f22_v) )
        frame_seq = self.tuple_ps[idx]
        return tuple(self.load_frame(y_p, u_p, v_p)
                     for y_p, u_p, v_p in frame_seq)

    def load_frame(self, y_p, u_p, v_p):
        y, u, v = (self.image_to_tensor(Image.open(p)) for p in (y_p, u_p, v_p))
        if not self.merge_channels:
            return y, u, v
        yuv = yuv_420_to_444(y, u, v)
        return yuv
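

# Example usage (an added sketch, not part of the original module; assumes
# `data_root` is laid out as expected by
# pframe_dataset_shared.get_paths_for_frame_sequences):
#
#   ds = FrameSequenceDataset(data_root, merge_channels=True)
#   loader = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=True)
#   frame_t, frame_tp1 = next(iter(loader))  # each a float32 NCHW tensor in [0, 1]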


def yuv_420_to_444(y, u, v):
    """ Convert Y, U, V planes, given in 4:2:0, to a single YUV 4:4:4 tensor. Expects CHW data format. """
    u, v = map(_upsample_nearest_neighbor, (u, v))  # upsample U, V to match Y's resolution
    return torch.cat((y, u, v), dim=0)  # merge into one CHW tensor


def _upsample_nearest_neighbor(t, factor=2):
    """ Upsample CHW tensor `t` by `factor`, using nearest-neighbor interpolation. """
    return F.interpolate(t.unsqueeze(0), scale_factor=factor, mode='nearest').squeeze(0)
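# Shape illustration (added note): a (1, 64, 64) chroma plane from 4:2:0 becomes
# (1, 128, 128) with the default factor=2, matching the full-resolution Y plane.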


def image_to_tensor(pic, normalize=True):
    """
    Convert a ``PIL Image`` to a CHW tensor.
    Copied from torchvision.transforms.functional.to_tensor, adapted
    to only support PIL inputs and to add a `normalize` flag.
    :param pic: PIL Image
    :param normalize: If False, return uint8; otherwise return float32 in range [0, 1]
    """
    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor) and normalize:
        return img.float().div(255)
    else:
        return img
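

if __name__ == '__main__':
    # Minimal smoke test (an added sketch, not part of the original module):
    # build synthetic 4:2:0 planes as grayscale PIL images, convert them with
    # image_to_tensor, merge them to 4:4:4, and check the resulting shape.
    y_img = Image.fromarray(np.zeros((128, 128), dtype=np.uint8), mode='L')
    u_img = Image.fromarray(np.zeros((64, 64), dtype=np.uint8), mode='L')
    v_img = Image.fromarray(np.zeros((64, 64), dtype=np.uint8), mode='L')
    y, u, v = (image_to_tensor(im) for im in (y_img, u_img, v_img))
    yuv = yuv_420_to_444(y, u, v)
    assert yuv.shape == (3, 128, 128), yuv.shape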