-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
92 lines (74 loc) · 2.62 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import numpy as np
import pickle
import torch
import random
from tqdm import tqdm
class DataSet(torch.utils.data.Dataset):
    """Dataset backed by a ``.npy`` array and a pickled label file.

    Items are ``(data, label)``; when ``joints_path`` is given, items are
    ``(data, label, joints)`` where ``joints`` is a ``torch.tensor`` built
    from the pickled per-sample record.
    """

    def __init__(self,
                 data_path: str,
                 label_path: str,
                 joints_path=None):
        self.data_path = data_path
        self.label_path = label_path
        self.joints_path = joints_path
        self.load_data()

    def load_data(self):
        """Read labels, optional per-sample joint records, and the array.

        Raises:
            ValueError: if the label pickle is not a 2-item pair, or the
                data array is not 5-dimensional.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # label_path / joints_path at trusted files.
        with open(self.label_path, 'rb') as handle:
            names, labels = pickle.load(handle)
        self.sample_name = names
        self.label = labels

        if self.joints_path:
            with open(self.joints_path, 'rb') as handle:
                self.missing_joints = pickle.load(handle)

        self.data = np.load(self.data_path)
        # Layout N C T V M; the full unpack rejects arrays that are not 5-D.
        num_samples, _channels, _frames, _joints, _persons = self.data.shape
        self.size = num_samples

    def __len__(self) -> int:
        """Return the number of samples (first axis of the data array)."""
        return self.size

    def __getitem__(self, index: int) -> tuple:
        """Return a copied data sample, its label, and optionally joints."""
        sample = np.array(self.data[index])
        target = self.label[index]
        if not self.joints_path:
            return sample, target
        return sample, target, torch.tensor(self.missing_joints[index])
class Feeder_semi(torch.utils.data.Dataset):
    """Class-balanced labelled subset of a dataset for semi-supervised use.

    ``load_data`` keeps, for every class, ``round(label_percent * count)``
    randomly chosen samples, discards the rest, and preserves the original
    relative order of the kept samples.

    Args:
        data_path: path to a ``.npy`` array indexed by sample on axis 0.
        label_path: path to a pickle holding a ``(sample_name, label)`` pair.
        label_percent: fraction of each class to keep.
        seed: optional seed for a private ``random.Random`` so the labelled
            subset is reproducible; ``None`` (default) draws from the
            module-level ``random`` state exactly as before.
    """

    def __init__(self, data_path, label_path, label_percent=0.1, seed=None):
        self.data_path = data_path
        self.label_path = label_path
        self.label_percent = label_percent
        self.seed = seed
        self.load_data()

    def load_data(self):
        """Load labels and data, then keep a per-class random subset.

        Raises:
            ValueError: from ``random.sample`` if ``label_percent`` makes a
                per-class sample size negative or larger than the class.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # label_path at trusted files.
        with open(self.label_path, 'rb') as f:
            self.sample_name, self.label = pickle.load(f)
        self.data = np.load(self.data_path)

        # Group sample indices by class. Dicts keep first-seen class order,
        # which fixes the order of the sampling calls below.
        class_indices = {}
        for i, lbl in enumerate(self.label):
            class_indices.setdefault(lbl, []).append(i)

        # getattr keeps load_data usable by subclasses that never set seed.
        seed = getattr(self, 'seed', None)
        sample = random.sample if seed is None else random.Random(seed).sample

        chosen = []
        for indices in class_indices.values():
            # round() (half-to-even) can yield 0 for small classes, in which
            # case that class contributes no labelled samples.
            chosen += sample(indices, round(self.label_percent * len(indices)))
        chosen.sort()

        self.data = self.data[chosen]
        self.sample_name = [self.sample_name[i] for i in chosen]
        self.label = [self.label[i] for i in chosen]

    def __len__(self):
        """Return the number of kept (labelled) samples."""
        return len(self.label)

    def __getitem__(self, index):
        """Return a copied data sample and its label."""
        data = np.array(self.data[index])
        label = self.label[index]
        return data, label