-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy path2. preprocess_all_to_imuposer_at_25fps.py
105 lines (83 loc) · 3.13 KB
/
2. preprocess_all_to_imuposer_at_25fps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
r"""
Resample the 60 fps datasets (AMASS and DIP-IMU) to 25 fps.
Smooth the resampled accelerations with a moving-average filter.
"""
# +
import torch
from imuposer.config import Config
from imuposer import math
# -
# Project configuration object; the input directory (processed_imu_poser) and
# the output directory (processed_imu_poser_25fps) used below are read from it.
# NOTE(review): "../../" is presumably resolved relative to the working
# directory the script is launched from -- confirm in imuposer.config.
config = Config(project_root_dir="../../")
# +
# Output frame rate passed to _resample for every sequence processed below.
target_fps = 25
# hop = 60 // target_fps
def smooth_avg(acc=None, s=3):
    r"""
    Smooth a sequence with a centered moving-average filter along dim 0.

    The sequence is padded with NaN frames at both ends and averaged with
    ``nanmean``, so a window that hangs over the start or the end of the
    sequence is averaged over the frames that actually exist. The output
    always has as many frames as the input.

    Args:
        acc: tensor whose first dimension is time; any number of trailing
            dimensions is accepted. Integer tensors are converted to the
            default float dtype because NaN padding needs a float dtype.
        s: window length in frames (>= 1). For an even ``s`` the window
            contains one more past frame than future frames.

    Returns:
        Tensor with the same shape as ``acc``, on the same device.

    Raises:
        TypeError: if ``acc`` is not provided.
        ValueError: if ``s`` is smaller than 1.
    """
    if acc is None:
        raise TypeError("smooth_avg() requires a tensor for 'acc'")
    if s < 1:
        raise ValueError(f"window size s must be >= 1, got {s}")
    if not (acc.is_floating_point() or acc.is_complex()):
        acc = acc.to(torch.get_default_dtype())

    n_frames = acc.shape[0]
    trailing = tuple(acc.shape[1:])
    # Asymmetric padding for even s keeps the output length equal to n_frames;
    # for odd s both sides get s // 2 frames.
    pad_before, pad_after = s // 2, (s - 1) // 2
    # new_full keeps the padding on the same device / dtype as the data.
    padded = torch.cat((
        acc.new_full((pad_before, *trailing), float("nan")),
        acc,
        acc.new_full((pad_after, *trailing), float("nan")),
    ))
    # One shifted view per window tap; stacking and nan-averaging them gives
    # the moving average while ignoring the NaN padding.
    windows = [padded[offset:offset + n_frames] for offset in range(s)]
    return torch.stack(windows).nanmean(dim=0)
def _resample(tensor, target_fps):
r"""
Resample to the target fps, assumes 60fps input
"""
indices = torch.arange(0, tensor.shape[0], 60/target_fps)
start_indices = torch.floor(indices).long()
end_indices = torch.ceil(indices).long()
end_indices[end_indices >= tensor.shape[0]] = tensor.shape[0] - 1 # handling edge cases
start = tensor[start_indices]
end = tensor[end_indices]
floats = indices - start_indices
for shape_index in range(len(tensor.shape) - 1):
floats = floats.unsqueeze(1)
weights = torch.ones_like(start) * floats
torch_lerped = torch.lerp(start, end, weights)
return torch_lerped
# -
path_to_save = config.processed_imu_poser_25fps
path_to_save.mkdir(exist_ok=True, parents=True)

# NOTE: 11 frames at 60 fps correspond to 11 * 25 / 60 ~= 4.58 frames at 25 fps.


def _convert_dataset(src_dir, out_dir, fps, acc_file, ori_file, out_prefix=""):
    r"""
    Resample every sub-directory of ``src_dir`` and save one file per entry.

    For each sub-directory the per-sequence lists stored in ``joint.pt``,
    ``pose.pt``, ``tran.pt``, ``acc_file`` and ``ori_file`` are resampled to
    ``fps`` with ``_resample``; ``shape.pt`` is copied through unchanged.
    Poses are converted from axis-angle to rotation matrices and viewed as
    ``(-1, 24, 3, 3)``; accelerations are smoothed with a 5-frame average.
    The result is saved to ``out_dir / f"{out_prefix}{name}.pt"`` as a dict
    with the keys ``joint``, ``pose``, ``shape``, ``tran``, ``acc``, ``ori``.

    Args:
        src_dir: directory whose sub-directories each hold the ``.pt`` files.
        out_dir: directory the converted files are written to.
        fps: target frame rate passed to ``_resample``.
        acc_file: file name of the acceleration list inside each entry.
        ori_file: file name of the orientation list inside each entry.
        out_prefix: string prepended to the output file name.
    """
    for seq_dir in src_dir.iterdir():
        # Skip stray files (e.g. .DS_Store); only directories hold data.
        if not seq_dir.is_dir():
            continue

        def load_resampled(fname):
            return [_resample(x, fps) for x in torch.load(seq_dir / fname)]

        joint = load_resampled("joint.pt")
        # `math` is imuposer.math here, not the standard-library module.
        pose = [
            math.axis_angle_to_rotation_matrix(x.contiguous()).view(-1, 24, 3, 3)
            for x in load_resampled("pose.pt")
        ]
        shape = torch.load(seq_dir / "shape.pt")
        tran = load_resampled("tran.pt")
        # average filter
        acc = [smooth_avg(x, s=5) for x in load_resampled(acc_file)]
        ori = load_resampled(ori_file)

        fdata = {
            "joint": joint,
            "pose": pose,
            "shape": shape,
            "tran": tran,
            "acc": acc,
            "ori": ori,
        }
        torch.save(fdata, out_dir / f"{out_prefix}{seq_dir.name}.pt")


# process AMASS first
_convert_dataset(
    config.processed_imu_poser / "AMASS", path_to_save, target_fps,
    acc_file="vacc.pt", ori_file="vrot.pt",
)

# process DIP next
_convert_dataset(
    config.processed_imu_poser / "DIP_IMU", path_to_save, target_fps,
    acc_file="accs.pt", ori_file="oris.pt", out_prefix="dip_",
)