-
Notifications
You must be signed in to change notification settings - Fork 52
/
Copy pathdataset.py
120 lines (100 loc) · 4.03 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# encoding: utf-8
import numpy as np
import glob
import time
import cv2
import os
from torch.utils.data import Dataset
from cvtransforms import *
import torch
import glob
import re
import copy
import json
import random
import editdistance
class MyDataset(Dataset):
    """Dataset of frame-folder videos paired with word-level ``.align`` transcripts.

    Each non-blank line of ``file_list`` is a path, relative to ``video_path``,
    of a directory holding numbered ``.jpg`` frames. The matching transcript is
    read from ``<anno_path>/<spk>/align/<name>.align``, where ``spk`` is the
    fourth-from-last component of the video path and ``name`` is the last one.

    Characters are encoded as ``letters.index(c) + start``. The loader uses
    ``start=1``, which leaves 0 unused by any character; the decoders
    (``arr2txt`` / ``ctc_arr2txt``) drop every value below ``start``.
    """

    # Output alphabet: index 0 is the space, followed by upper-case A-Z.
    letters = [' ', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']

    def __init__(self, video_path, anno_path, file_list, vid_pad, txt_pad, phase):
        """Index the videos listed in ``file_list``.

        Args:
            video_path: Root directory joined in front of every ``file_list`` entry.
            anno_path: Root directory containing ``<spk>/align/*.align`` files.
            file_list: Text file with one video directory per line; blank
                lines are ignored.
            vid_pad: Number of frames each video is zero-padded to.
            txt_pad: Number of symbols each encoded transcript is zero-padded to.
            phase: When equal to ``'train'``, ``__getitem__`` also applies
                ``HorizontalFlip`` (from ``cvtransforms``) to the frames.
        """
        self.anno_path = anno_path
        self.vid_pad = vid_pad
        self.txt_pad = txt_pad
        self.phase = phase

        with open(file_list, 'r') as f:
            # Blank lines would otherwise produce entries that point at
            # ``video_path`` itself and carry an empty video name.
            self.videos = [os.path.join(video_path, line.strip())
                           for line in f if line.strip()]

        self.data = []
        for vid in self.videos:
            items = vid.split(os.path.sep)
            # (frame directory, speaker directory name, video name)
            self.data.append((vid, items[-4], items[-1]))

    def __getitem__(self, idx):
        """Load, optionally augment, normalize and pad one sample.

        Returns:
            dict with keys:
              ``'vid'``: float tensor of the padded frames, transposed with
                  ``(3, 0, 1, 2)`` so the last (channel) axis comes first;
              ``'txt'``: long tensor of the padded encoded transcript;
              ``'txt_len'``: transcript length before padding;
              ``'vid_len'``: frame count before padding.
        """
        (vid, spk, name) = self.data[idx]
        vid = self._load_vid(vid)
        anno = self._load_anno(os.path.join(self.anno_path, spk, 'align', name + '.align'))

        if self.phase == 'train':
            vid = HorizontalFlip(vid)
        vid = ColorNormalize(vid)

        # Lengths are taken before padding so the consumer can mask the padding.
        vid_len = vid.shape[0]
        anno_len = anno.shape[0]
        vid = self._padding(vid, self.vid_pad)
        anno = self._padding(anno, self.txt_pad)

        return {'vid': torch.FloatTensor(vid.transpose(3, 0, 1, 2)),
                'txt': torch.LongTensor(anno),
                'txt_len': anno_len,
                'vid_len': vid_len}

    def __len__(self):
        """Return the number of indexed videos."""
        return len(self.data)

    def _load_vid(self, p):
        """Read the frames of directory ``p`` in numeric filename order.

        Only files ending in ``.jpg`` are considered; their stem must be an
        integer, which determines frame order. Images that ``cv2.imread``
        cannot decode (it returns None) are skipped. Each frame is resized to
        128x64 (width x height) with Lanczos interpolation.

        Returns:
            float32 array stacking the frames along a new leading axis.

        Raises:
            ValueError: if the directory yields no readable frame.
        """
        names = sorted((fn for fn in os.listdir(p) if fn.endswith('.jpg')),
                       key=lambda fn: int(os.path.splitext(fn)[0]))
        frames = []
        for fn in names:
            im = cv2.imread(os.path.join(p, fn))
            if im is not None:
                frames.append(cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4))
        if not frames:
            # np.stack would raise its own, less informative ValueError here.
            raise ValueError('no readable .jpg frames found in {}'.format(p))
        return np.stack(frames, axis=0).astype(np.float32)

    def _load_anno(self, name):
        """Encode the transcript stored in the ``.align`` file ``name``.

        Each non-blank line is split on whitespace and its third field is
        taken as a word. ``SIL`` / ``SP`` tokens (any case) are dropped. The
        remaining words are joined by single spaces, upper-cased and encoded
        with ``txt2arr(..., 1)``.
        """
        with open(name, 'r') as f:
            words = [fields[2] for fields in (line.split() for line in f) if fields]
        words = [w for w in words if w.upper() not in ('SIL', 'SP')]
        return MyDataset.txt2arr(' '.join(words).upper(), 1)

    def _padding(self, array, length):
        """Zero-pad ``array`` along axis 0 up to ``length`` rows.

        The padding matches ``array``'s dtype and per-row shape, so float32
        frames stay float32 and integer labels stay integer. Empty input is
        handled. Arrays with ``length`` rows or more are returned as a copy
        without truncation.
        """
        array = np.asarray(array)
        missing = length - array.shape[0]
        if missing <= 0:
            return array.copy()
        pad = np.zeros((missing,) + array.shape[1:], dtype=array.dtype)
        return np.concatenate([array, pad], axis=0)

    @staticmethod
    def txt2arr(txt, start):
        """Encode each character of ``txt`` as ``letters.index(c) + start``.

        Raises ValueError for characters not in ``letters``.
        """
        return np.array([MyDataset.letters.index(c) + start for c in txt])

    @staticmethod
    def arr2txt(arr, start):
        """Decode ``arr`` produced by ``txt2arr``; values below ``start`` are dropped.

        Leading/trailing spaces are stripped from the result.
        """
        return ''.join(MyDataset.letters[n - start] for n in arr if n >= start).strip()

    @staticmethod
    def ctc_arr2txt(arr, start):
        """Greedy CTC-style decode of a label sequence.

        Consecutive equal values are collapsed, values below ``start`` are
        dropped, and a space is not emitted directly after another space.
        Leading/trailing spaces are stripped from the result.
        """
        pre = -1
        txt = []
        for n in arr:
            # ``pre`` tracks every value, including dropped ones, so a dropped
            # value between two equal labels lets that label be emitted twice.
            if pre != n and n >= start:
                c = MyDataset.letters[n - start]
                if not (c == ' ' and txt and txt[-1] == ' '):
                    txt.append(c)
            pre = n
        return ''.join(txt).strip()

    @staticmethod
    def wer(predict, truth):
        """Per-pair word error rate.

        Returns, for each zipped (prediction, reference) pair, the word-level
        edit distance divided by the reference word count (words via
        ``split(' ')``).
        """
        word_pairs = [(p.split(' '), t.split(' ')) for p, t in zip(predict, truth)]
        return [1.0 * editdistance.eval(p, t) / len(t) for p, t in word_pairs]

    @staticmethod
    def cer(predict, truth):
        """Per-pair character error rate.

        Returns, for each zipped (prediction, reference) pair, the character
        edit distance divided by the reference length. An empty reference
        string raises ZeroDivisionError.
        """
        return [1.0 * editdistance.eval(p, t) / len(t) for p, t in zip(predict, truth)]