-
Notifications
You must be signed in to change notification settings - Fork 52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initialize shonan using minimum spanning tree #777
base: master
Are you sure you want to change the base?
Changes from all commits
f936169
1dea8a6
1a1de1f
d530283
91c4cc4
63eeccd
f5f61d5
9770142
c3a6589
8e0bafe
ebfff90
e6d200a
83c354b
c655512
fbdf99a
3a1ec06
429177a
beef33b
c01f254
4d6c3a5
42b3e70
8848d76
980b94d
f55f2b3
337c6db
dde24e8
725aedf
35b3a7b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
name: gtsfm-v2 | ||
channels: | ||
# for priority order, we prefer pytorch as the highest priority as it supplies | ||
# latest stable packages for numerous deep learning based methods. conda-forge | ||
# supplies higher versions of packages like opencv compared to the defaults | ||
# channel. | ||
- pytorch | ||
- conda-forge | ||
dependencies: | ||
# python essentials | ||
- python | ||
- pip | ||
# formatting and dev environment | ||
- black | ||
- coverage | ||
- mypy | ||
- pylint | ||
- pytest | ||
- flake8 | ||
- isort | ||
# dask and related | ||
- dask # same as dask[complete] pip distribution | ||
- asyncssh | ||
- python-graphviz | ||
# core functionality and APIs | ||
- matplotlib | ||
- networkx | ||
- numpy | ||
- nodejs | ||
- pandas | ||
- pillow | ||
- scikit-learn | ||
- seaborn | ||
- scipy | ||
- hydra-core | ||
- gtsam | ||
# 3rd party algorithms for different modules | ||
- cpuonly # replacement of cudatoolkit for cpu only machines | ||
- pytorch | ||
- torchvision | ||
- kornia | ||
- pycolmap | ||
- opencv | ||
# io | ||
- h5py | ||
- plotly | ||
- tabulate | ||
- simplejson | ||
- open3d | ||
- colour | ||
- pydot | ||
- trimesh | ||
# testing | ||
- parameterized | ||
# - pip: | ||
# - pydegensac |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: would spanning_tree be a better name for this file? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I indented to keep all rotations related util functions here. I feel this is not as generic to be named a spanning tree right now because the args are rotations and not a generic type. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
"""Utility functions for rotations. | ||
|
||
Authors: Ayush Baid | ||
""" | ||
|
||
from typing import Dict, List, Tuple | ||
|
||
import networkx as nx | ||
import numpy as np | ||
from gtsam import Rot3 | ||
|
||
|
||
def random_rotation(angle_scale_factor: float = 0.1) -> Rot3: | ||
"""Sample a random rotation by generating a sample from the 4d unit sphere.""" | ||
q = np.random.rand(4) | ||
# make unit-length quaternion | ||
q /= np.linalg.norm(q) | ||
qw, qx, qy, qz = q | ||
R = Rot3(qw, qx, qy, qz) | ||
axis, angle = R.axisAngle() | ||
angle = angle * angle_scale_factor | ||
return Rot3.AxisAngle(axis.point3(), angle) | ||
|
||
|
||
def initialize_global_rotations_using_mst( | ||
num_images: int, i2Ri1_dict: Dict[Tuple[int, int], Rot3], edge_weights: Dict[Tuple[int, int], int] | ||
) -> List[Rot3]: | ||
"""Initializes rotations using minimum spanning tree (weighted by number of correspondences). | ||
|
||
Args: | ||
num_images: Number of images in the scene. | ||
i2Ri1_dict: Dictionary of relative rotations (i1, i2): i2Ri1. | ||
edge_weights: Weight of the edges (i1, i2). All edges in i2Ri1 must have an edge weight. | ||
|
||
Returns: | ||
Global rotations wRi initialized using an MST. Randomly initialized if we have a forest. | ||
""" | ||
# Create a graph from the relative rotations dictionary. | ||
graph = nx.Graph() | ||
for i1, i2 in i2Ri1_dict.keys(): | ||
graph.add_edge(i1, i2, weight=edge_weights[(i1, i2)]) | ||
|
||
if not nx.is_connected(graph): | ||
raise ValueError("Relative rotation graph is not connected") | ||
|
||
# Compute the Minimum Spanning Tree (MST) | ||
mst = nx.minimum_spanning_tree(graph) | ||
|
||
# MST graph. | ||
G = nx.Graph() | ||
G.add_edges_from(mst.edges) | ||
|
||
wRi_list: List[Rot3] = [Rot3()] * num_images | ||
# Choose origin node. | ||
origin_node = list(G.nodes)[0] | ||
wRi_list[origin_node] = Rot3() | ||
|
||
# Ignore 0th node, as we already set its global pose as the origin | ||
for dst_node in list(G.nodes)[1:]: | ||
# Determine the path to this node from the origin. ordered from [origin_node,...,dst_node] | ||
path = nx.shortest_path(G, source=origin_node, target=dst_node) | ||
|
||
# Chain relative rotations w.r.t. origin node. Initialize as identity Rot3 w.r.t origin node `i1`. | ||
wRi1 = Rot3() | ||
for i1, i2 in zip(path[:-1], path[1:]): | ||
# NOTE: i1, i2 may not be in sorted order here. May need to reverse ordering. | ||
if i1 < i2: | ||
i1Ri2 = i2Ri1_dict[(i1, i2)].inverse() | ||
else: | ||
i1Ri2 = i2Ri1_dict[(i2, i1)] | ||
Comment on lines
+67
to
+70
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not guaranteed, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm why not? |
||
# Path order is (origin -> ... -> i1 -> i2 -> ... -> dst_node). Set `i2` to be new `i1`. | ||
wRi1 = wRi1 * i1Ri2 | ||
|
||
wRi_list[dst_node] = wRi1 | ||
|
||
return wRi_list |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
"""Unit tests for rotation utils. | ||
Authors: Ayush Baid | ||
""" | ||
|
||
import unittest | ||
from typing import Dict, List, Tuple | ||
|
||
import numpy as np | ||
from gtsam import Rot3 | ||
|
||
import gtsfm.utils.geometry_comparisons as geometry_comparisons | ||
import gtsfm.utils.rotation as rotation_util | ||
import tests.data.sample_poses as sample_poses | ||
|
||
ROTATION_ANGLE_ERROR_THRESHOLD_DEG = 2 | ||
|
||
|
||
RELATIVE_ROTATION_DICT = Dict[Tuple[int, int], Rot3] | ||
|
||
|
||
def _get_ordered_chain_pose_data() -> Tuple[RELATIVE_ROTATION_DICT, np.ndarray]: | ||
"""Return data for a scenario with 5 camera poses, with ordering that follows their connectivity. | ||
Accordingly, we specify i1 < i2 for all edges (i1,i2). | ||
Graph topology: | ||
| 2 | 3 | ||
o-- ... o-- | ||
. . | ||
. . | ||
| | | | ||
o-- ... --o --o | ||
0 1 4 | ||
Returns: | ||
Tuple of mapping from image index pair to relative rotations, and expected global rotation angles. | ||
""" | ||
# Expected angles. | ||
wRi_list_euler_deg_expected = np.array([0, 90, 0, 0, 90]) | ||
|
||
# Ground truth 3d rotations for 5 ordered poses (0,1,2,3,4) | ||
wRi_list_gt = [Rot3.RzRyRx(np.deg2rad(Rz_deg), 0, 0) for Rz_deg in wRi_list_euler_deg_expected] | ||
|
||
edges = [(0, 1), (1, 2), (2, 3), (3, 4)] | ||
i2Ri1_dict = _create_synthetic_relative_pose_measurements(wRi_list_gt, edges=edges) | ||
|
||
return i2Ri1_dict, wRi_list_euler_deg_expected | ||
|
||
|
||
def _get_mixed_order_chain_pose_data() -> Tuple[RELATIVE_ROTATION_DICT, np.ndarray]: | ||
"""Return data for a scenario with 5 camera poses, with ordering that does NOT follow their connectivity. | ||
Below, we do NOT specify i1 < i2 for all edges (i1,i2). | ||
Graph topology: | ||
| 3 | 0 | ||
o-- ... o-- | ||
. . | ||
. . | ||
| | | | ||
o-- ... --o --o | ||
4 1 2 | ||
""" | ||
# Expected angles. | ||
wRi_list_euler_deg_expected = np.array([0, 90, 90, 0, 0]) | ||
|
||
# Ground truth 2d rotations for 5 ordered poses (0,1,2,3,4) | ||
wRi_list_gt = [Rot3.RzRyRx(np.deg2rad(Rz_deg), 0, 0) for Rz_deg in wRi_list_euler_deg_expected] | ||
|
||
edges = [(1, 4), (1, 3), (0, 3), (0, 2)] | ||
i2Ri1_dict = _create_synthetic_relative_pose_measurements(wRi_list_gt=wRi_list_gt, edges=edges) | ||
|
||
return i2Ri1_dict, wRi_list_euler_deg_expected | ||
|
||
|
||
def _create_synthetic_relative_pose_measurements( | ||
wRi_list_gt: List[Rot3], edges: List[Tuple[int, int]] | ||
) -> Dict[Tuple[int, int], Rot3]: | ||
"""Generate synthetic relative rotation measurements, from ground truth global rotations. | ||
Args: | ||
wRi_list_gt: List of (3,3) rotation matrices. | ||
edges: Edges as pairs of image indices. | ||
Returns: | ||
Relative rotation measurements. | ||
""" | ||
i2Ri1_dict = {} | ||
for i1, i2 in edges: | ||
wRi2 = wRi_list_gt[i2] | ||
wRi1 = wRi_list_gt[i1] | ||
i2Ri1_dict[(i1, i2)] = wRi2.inverse() * wRi1 | ||
|
||
return i2Ri1_dict | ||
|
||
|
||
def _wrap_angles(angles: np.ndarray) -> np.ndarray: | ||
r"""Map angle (in degrees) from domain [-∞, ∞] to [0, 360). | ||
Args: | ||
angles: Array of shape (N,) representing angles (in degrees) in any interval. | ||
Returns: | ||
Array of shape (N,) representing the angles (in degrees) mapped to the interval [0, 360]. | ||
""" | ||
# Reduce the angle | ||
angles = angles % 360 | ||
|
||
# Force it to be the positive remainder, so that 0 <= angle < 360 | ||
angles = (angles + 360) % 360 | ||
return angles | ||
|
||
|
||
class TestRotationUtil(unittest.TestCase): | ||
def test_mst_initialization(self): | ||
"""Test for 4 poses in a circle, with a pose connected to all others.""" | ||
i2Ri1_dict, wRi_expected = sample_poses.convert_data_for_rotation_averaging( | ||
sample_poses.CIRCLE_ALL_EDGES_GLOBAL_POSES, sample_poses.CIRCLE_ALL_EDGES_RELATIVE_POSES | ||
) | ||
|
||
wRi_computed = rotation_util.initialize_global_rotations_using_mst( | ||
len(wRi_expected), | ||
i2Ri1_dict, | ||
edge_weights={(i1, i2): (i1 + i2) * 100 for i1, i2 in i2Ri1_dict.keys()}, | ||
) | ||
self.assertTrue( | ||
geometry_comparisons.compare_rotations(wRi_computed, wRi_expected, ROTATION_ANGLE_ERROR_THRESHOLD_DEG) | ||
) | ||
|
||
def test_greedily_construct_st_ordered_chain(self) -> None: | ||
"""Ensures that we can greedily construct a Spanning Tree for an ordered chain.""" | ||
|
||
i2Ri1_dict, wRi_list_euler_deg_expected = _get_ordered_chain_pose_data() | ||
|
||
num_images = 5 | ||
wRi_list_computed = rotation_util.initialize_global_rotations_using_mst( | ||
num_images, | ||
i2Ri1_dict, | ||
edge_weights={(i1, i2): (i1 + i2) * 100 for i1, i2 in i2Ri1_dict.keys()}, | ||
) | ||
|
||
wRi_list_euler_deg_est = [np.rad2deg(wRi.roll()) for wRi in wRi_list_computed] | ||
assert np.allclose(wRi_list_euler_deg_est, wRi_list_euler_deg_expected) | ||
|
||
def test_greedily_construct_st_mixed_order_chain(self) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ayushbaid I added a test that fails with the current implementation There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to change the lookup logic? Or do you have a fix in mind? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a fix -- traversing the shortest path in the MST between the origin and destination, and chaining the poses together |
||
"""Ensures that we can greedily construct a Spanning Tree for an unordered chain.""" | ||
i2Ri1_dict, wRi_list_euler_deg_expected = _get_mixed_order_chain_pose_data() | ||
|
||
num_images = 5 | ||
wRi_list_computed = rotation_util.initialize_global_rotations_using_mst( | ||
num_images, | ||
i2Ri1_dict, | ||
edge_weights={(i1, i2): (i1 + i2) * 100 for i1, i2 in i2Ri1_dict.keys()}, | ||
) | ||
|
||
wRi_list_euler_deg_est = np.array([np.rad2deg(wRi.roll()) for wRi in wRi_list_computed]) | ||
|
||
# Make sure both lists of angles start at 0 deg. | ||
wRi_list_euler_deg_est -= wRi_list_euler_deg_est[0] | ||
wRi_list_euler_deg_expected -= wRi_list_euler_deg_expected[0] | ||
|
||
assert np.allclose(_wrap_angles(wRi_list_euler_deg_est), _wrap_angles(wRi_list_euler_deg_expected)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs to be removed