-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathquaternion_averaging.py
60 lines (39 loc) · 2.32 KB
/
quaternion_averaging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Mert Asim Karaoglu, 2020
# Heavily inspired by Dr. Tolga Birdal's work: https://github.com/tolgabirdal/averaging_quaternions
# Based on
# Markley, F. Landis, Yang Cheng, John Lucas Crassidis, and Yaakov Oshman.
# "Averaging quaternions." Journal of Guidance, Control, and Dynamics 30,
# no. 4 (2007): 1193-1197.
import torch
def quaternion_average(a: torch.Tensor) -> torch.Tensor:
    r"""Average unit quaternions with the eigenvector method of Markley et al.

    Markley, F. Landis, Yang Cheng, John Lucas Crassidis, and Yaakov Oshman.
    "Averaging quaternions." Journal of Guidance, Control, and Dynamics 30,
    no. 4 (2007): 1193-1197.

    The average is the eigenvector belonging to the largest eigenvalue of the
    4 x 4 matrix :math:`M = \frac{1}{N} \sum_i q_i q_i^T`.

    Args:
        a: N x 4 tensor, each row representing a different data point,
            assumed to be a unit-quaternion in [x, y, z, w] order (scalar part
            last). A single quaternion of shape (4,) is also accepted. The
            tensor is NOT modified.

    Returns:
        torch.Tensor: 1 x 4 tensor holding the average unit-quaternion in
        [x, y, z, w] order, with its sign chosen so that w >= 0.

    Raises:
        ValueError: if ``a`` contains no quaternions.
    """
    # reshape (not view) so non-contiguous inputs work; never write into `a`.
    q = a.reshape(-1, 4)
    if q.shape[0] == 0:
        raise ValueError("quaternion_average needs at least one quaternion")
    # q and -q encode the same rotation and q q^T == (-q)(-q)^T, so the inputs
    # need no antipodal sign canonicalisation before building M.
    # q^T q equals sum_i q_i q_i^T without materialising N separate 4x4 products.
    m = q.transpose(0, 1).matmul(q).div(q.shape[0])
    # M is symmetric, so eigh applies (real eigenvalues, orthonormal vectors);
    # it reads only the lower triangle, so rounding asymmetry is harmless.
    eigen_values, eigen_vectors = torch.linalg.eigh(m)
    out = eigen_vectors[:, eigen_values.argmax()].reshape(1, 4)
    # Eigenvectors are defined only up to sign; return the w >= 0 representative.
    return torch.where(out[:, 3:] < 0, -out, out)
def weighted_quaternion_average(a: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    r"""Weighted average of unit quaternions (eigenvector method of Markley et al.).

    Markley, F. Landis, Yang Cheng, John Lucas Crassidis, and Yaakov Oshman.
    "Averaging quaternions." Journal of Guidance, Control, and Dynamics 30,
    no. 4 (2007): 1193-1197.

    The average is the eigenvector belonging to the largest eigenvalue of the
    4 x 4 matrix :math:`M = \frac{1}{\sum_i w_i} \sum_i w_i q_i q_i^T`.

    Args:
        a: N x 4 tensor, each row representing a different data point,
            assumed to be a unit-quaternion in [x, y, z, w] order (scalar part
            last). The tensor is NOT modified.
        w: N x 1 (or N) tensor of float weights, one per quaternion. It is
            cast to ``a``'s dtype and device.

    Returns:
        torch.Tensor: 1 x 4 tensor holding the weighted average
        unit-quaternion in [x, y, z, w] order, with its sign chosen so that
        w >= 0.

    Raises:
        ValueError: if ``a`` contains no quaternions or the weights sum to zero.
    """
    # reshape (not view) so non-contiguous inputs work; never write into `a`.
    q = a.reshape(-1, 4)
    if q.shape[0] == 0:
        raise ValueError("weighted_quaternion_average needs at least one quaternion")
    # Flatten (N, 1) / (N,) weights and match q's dtype/device so mixed
    # precision inputs do not make the matmul below fail.
    weights = w.reshape(-1).to(dtype=q.dtype, device=q.device)
    total = weights.sum()
    if total == 0:
        # Normalising by a zero sum would fill M with inf/NaN.
        raise ValueError("weights must not sum to zero")
    # q and -q encode the same rotation and q q^T == (-q)(-q)^T, so the inputs
    # need no antipodal sign canonicalisation before building M.
    # (w * q)^T q equals sum_i w_i q_i q_i^T in one 4x4 product.
    m = (q * weights.unsqueeze(-1)).transpose(0, 1).matmul(q).div(total)
    # M is symmetric, so eigh applies (real eigenvalues, orthonormal vectors);
    # it reads only the lower triangle, so rounding asymmetry is harmless.
    eigen_values, eigen_vectors = torch.linalg.eigh(m)
    out = eigen_vectors[:, eigen_values.argmax()].reshape(1, 4)
    # Eigenvectors are defined only up to sign; return the w >= 0 representative.
    return torch.where(out[:, 3:] < 0, -out, out)