Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize shonan using minimum spanning tree #777

Open
wants to merge 28 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
f936169
Initialize shonan using minimum spanning tree
ayushbaid Feb 1, 2024
1dea8a6
Move out MST initialization into utility function
ayushbaid Feb 3, 2024
1a1de1f
Add typinh and docstring
ayushbaid Feb 20, 2024
d530283
Use number of inliers as edge weights for MST
ayushbaid Feb 27, 2024
91c4cc4
Remove unused import
ayushbaid Feb 29, 2024
63eeccd
Fix bugs and add docstring
ayushbaid Feb 29, 2024
f5f61d5
cleanup docstrings
Mar 11, 2024
9770142
add more tests
Mar 11, 2024
c3a6589
add different impl
Mar 11, 2024
8e0bafe
clean up notation
Mar 11, 2024
ebfff90
clean up test
Mar 11, 2024
e6d200a
python black fixes
Mar 11, 2024
83c354b
fix flake8
Mar 11, 2024
c655512
use i1 < i2 convention for pair indices
Mar 12, 2024
fbdf99a
use v_corr_idxs instead of two_view_estimation_reports to get inlier …
Mar 13, 2024
3a1ec06
python black reformat
Mar 13, 2024
429177a
fix import error
Mar 13, 2024
beef33b
Control MST init with flag plus fixes
ayushbaid May 14, 2024
c01f254
Log initialization technique
ayushbaid May 14, 2024
4d6c3a5
Add unit test comparing initializations
ayushbaid May 14, 2024
42b3e70
Log shonan optimality
ayushbaid May 14, 2024
8848d76
Remove optimality logging
ayushbaid May 14, 2024
980b94d
Remove unused import
ayushbaid May 14, 2024
f55f2b3
Merge master
ayushbaid May 21, 2024
337c6db
Add unit test for initialization on larger scene
ayushbaid May 26, 2024
dde24e8
env v2
ayushbaid May 13, 2024
725aedf
Merge branch 'master' into feature/shonan_mst_init
ayushbaid May 26, 2024
35b3a7b
Remove duplicate args plus update process meta
ayushbaid May 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions environment_v2_linux_cpuonly.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
name: gtsfm-v2
channels:
# for priority order, we prefer pytorch as the highest priority as it supplies
# latest stable packages for numerous deep learning based methods. conda-forge
# supplies higher versions of packages like opencv compared to the defaults
# channel.
- pytorch
- conda-forge
dependencies:
# python essentials
- python
- pip
# formatting and dev environment
- black
- coverage
- mypy
- pylint
- pytest
- flake8
- isort
# dask and related
- dask # same as dask[complete] pip distribution
- asyncssh
- python-graphviz
# core functionality and APIs
- matplotlib
- networkx
- numpy
- nodejs
- pandas
- pillow
- scikit-learn
- seaborn
- scipy
- hydra-core
- gtsam
# 3rd party algorithms for different modules
- cpuonly # replacement of cudatoolkit for cpu only machines
- pytorch
- torchvision
- kornia
- pycolmap
- opencv
# io
- h5py
- plotly
- tabulate
- simplejson
- open3d
- colour
- pydot
- trimesh
# testing
- parameterized
# - pip:
# - pydegensac
3 changes: 2 additions & 1 deletion gtsfm/averaging/rotation/rotation_averaging_base.py
Original file line number Diff line number Diff line change
@@ -26,12 +26,13 @@ class RotationAveragingBase(GTSFMProcess):
rotations.
"""

@staticmethod
def get_ui_metadata() -> UiMetadata:
"""Returns data needed to display node and edge info for this process in the process graph."""

return UiMetadata(
display_name="Rotation Averaging",
input_products=("View-Graph Relative Rotations", "Relative Pose Priors"),
input_products=("View-Graph Relative Rotations", "Relative Pose Priors", "Verified Correspondences"),
output_products=("Global Rotations",),
parent_plate="Sparse Reconstruction",
)
33 changes: 28 additions & 5 deletions gtsfm/averaging/rotation/shonan.py
Original file line number Diff line number Diff line change
@@ -20,9 +20,11 @@
Rot3,
ShonanAveraging3,
ShonanAveragingParameters3,
Values,
)

import gtsfm.utils.logger as logger_utils
import gtsfm.utils.rotation as rotation_util
from gtsfm.averaging.rotation.rotation_averaging_base import RotationAveragingBase
from gtsfm.common.pose_prior import PosePrior

@@ -38,7 +40,10 @@ class ShonanRotationAveraging(RotationAveragingBase):
"""Performs Shonan rotation averaging."""

def __init__(
self, two_view_rotation_sigma: float = _DEFAULT_TWO_VIEW_ROTATION_SIGMA, weight_by_inliers: bool = True
self,
two_view_rotation_sigma: float = _DEFAULT_TWO_VIEW_ROTATION_SIGMA,
weight_by_inliers: bool = True,
use_mst_init: bool = False,
) -> None:
"""Initializes module.

@@ -50,10 +55,11 @@ def __init__(
of inlier correspondences per edge.
"""
super().__init__()
self._two_view_rotation_sigma = two_view_rotation_sigma
self._p_min = 3
self._p_max = 64
self._two_view_rotation_sigma = two_view_rotation_sigma
self._weight_by_inliers = weight_by_inliers
self._use_mst_init = use_mst_init

def __get_shonan_params(self) -> ShonanAveragingParameters3:
lm_params = LevenbergMarquardtParams.CeresDefaults()
@@ -108,7 +114,7 @@ def get_isotropic_noise_model_sigma(covariance: np.ndarray) -> float:
return measurements

def _run_with_consecutive_ordering(
self, num_connected_nodes: int, measurements: gtsam.BinaryMeasurementsRot3
self, num_connected_nodes: int, measurements: gtsam.BinaryMeasurementsRot3, initial: Optional[Values]
) -> List[Optional[Rot3]]:
"""Run the rotation averaging on a connected graph w/ N keys ordered consecutively [0,...,N-1].

@@ -134,7 +140,9 @@ def _run_with_consecutive_ordering(
)
shonan = ShonanAveraging3(measurements, self.__get_shonan_params())

initial = shonan.initializeRandomly()
if initial is None:
logger.info("Using random initialization for Shonan")
initial = shonan.initializeRandomly()
logger.info("Initial cost: %.5f", shonan.cost(initial))
result, _ = shonan.run(initial, self._p_min, self._p_max)
logger.info("Final cost: %.5f", shonan.cost(result))
@@ -203,13 +211,28 @@ def run_rotation_averaging(
if (i1, i2) in i2Ri1_dict
}

# Use negative of the number of correspondences as the edge weight.
initial_values: Optional[Values] = None
if self._use_mst_init:
logger.info("Using MST initialization for Shonan")
wRi_initial_ = rotation_util.initialize_global_rotations_using_mst(
len(nodes_with_edges),
i2Ri1_dict_remapped,
edge_weights={
(i1, i2): -num_correspondences_dict.get((i1, i2), 0) for i1, i2 in i2Ri1_dict_remapped.keys()
},
)
initial_values = Values()
for i, wRi in enumerate(wRi_initial_):
initial_values.insert(i, wRi)

def _create_factors_and_run() -> List[Rot3]:
measurements: gtsam.BinaryMeasurementsRot3 = self.__measurements_from_2view_relative_rotations(
i2Ri1_dict=i2Ri1_dict_remapped, num_correspondences_dict=num_correspondences_dict
)
measurements.extend(self._measurements_from_pose_priors(i1Ti2_priors, old_to_new_idxs))
wRi_list_subset = self._run_with_consecutive_ordering(
num_connected_nodes=len(nodes_with_edges), measurements=measurements
num_connected_nodes=len(nodes_with_edges), measurements=measurements, initial=initial_values
)
return wRi_list_subset

4 changes: 3 additions & 1 deletion gtsfm/utils/geometry_comparisons.py
Original file line number Diff line number Diff line change
@@ -30,7 +30,8 @@ def compare_rotations(
Args:
aTi_list: 1st list of rotations.
bTi_list: 2nd list of rotations.
angular_error_threshold_degrees: the threshold for angular error between two rotations.
angular_error_threshold_degrees: Threshold for angular error between two rotations.

Returns:
Result of the comparison.
"""
@@ -55,6 +56,7 @@ def compare_rotations(
relative_rotations_angles = np.array(
[compute_relative_rotation_angle(aRi, aRi_) for (aRi, aRi_) in zip(aRi_list, aRi_list_)], dtype=np.float32
)
print(relative_rotations_angles)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs to be removed

return np.all(relative_rotations_angles < angular_error_threshold_degrees)


76 changes: 76 additions & 0 deletions gtsfm/utils/rotation.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would spanning_tree be a better name for this file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I indented to keep all rotations related util functions here. I feel this is not as generic to be named a spanning tree right now because the args are rotations and not a generic type.

Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Utility functions for rotations.

Authors: Ayush Baid
"""

from typing import Dict, List, Tuple

import networkx as nx
import numpy as np
from gtsam import Rot3


def random_rotation(angle_scale_factor: float = 0.1) -> Rot3:
"""Sample a random rotation by generating a sample from the 4d unit sphere."""
q = np.random.rand(4)
# make unit-length quaternion
q /= np.linalg.norm(q)
qw, qx, qy, qz = q
R = Rot3(qw, qx, qy, qz)
axis, angle = R.axisAngle()
angle = angle * angle_scale_factor
return Rot3.AxisAngle(axis.point3(), angle)


def initialize_global_rotations_using_mst(
num_images: int, i2Ri1_dict: Dict[Tuple[int, int], Rot3], edge_weights: Dict[Tuple[int, int], int]
) -> List[Rot3]:
"""Initializes rotations using minimum spanning tree (weighted by number of correspondences).

Args:
num_images: Number of images in the scene.
i2Ri1_dict: Dictionary of relative rotations (i1, i2): i2Ri1.
edge_weights: Weight of the edges (i1, i2). All edges in i2Ri1 must have an edge weight.

Returns:
Global rotations wRi initialized using an MST. Randomly initialized if we have a forest.
"""
# Create a graph from the relative rotations dictionary.
graph = nx.Graph()
for i1, i2 in i2Ri1_dict.keys():
graph.add_edge(i1, i2, weight=edge_weights[(i1, i2)])

if not nx.is_connected(graph):
raise ValueError("Relative rotation graph is not connected")

# Compute the Minimum Spanning Tree (MST)
mst = nx.minimum_spanning_tree(graph)

# MST graph.
G = nx.Graph()
G.add_edges_from(mst.edges)

wRi_list: List[Rot3] = [Rot3()] * num_images
# Choose origin node.
origin_node = list(G.nodes)[0]
wRi_list[origin_node] = Rot3()

# Ignore 0th node, as we already set its global pose as the origin
for dst_node in list(G.nodes)[1:]:
# Determine the path to this node from the origin. ordered from [origin_node,...,dst_node]
path = nx.shortest_path(G, source=origin_node, target=dst_node)

# Chain relative rotations w.r.t. origin node. Initialize as identity Rot3 w.r.t origin node `i1`.
wRi1 = Rot3()
for i1, i2 in zip(path[:-1], path[1:]):
# NOTE: i1, i2 may not be in sorted order here. May need to reverse ordering.
if i1 < i2:
i1Ri2 = i2Ri1_dict[(i1, i2)].inverse()
else:
i1Ri2 = i2Ri1_dict[(i2, i1)]
Comment on lines +67 to +70
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not guaranteed, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm why not?

# Path order is (origin -> ... -> i1 -> i2 -> ... -> dst_node). Set `i2` to be new `i1`.
wRi1 = wRi1 * i1Ri2

wRi_list[dst_node] = wRi1

return wRi_list
124 changes: 105 additions & 19 deletions tests/averaging/rotation/test_shonan.py
Original file line number Diff line number Diff line change
@@ -2,21 +2,27 @@
Authors: Ayush Baid, John Lambert
"""

import pickle
import random
import unittest
from typing import Dict, List, Tuple
from pathlib import Path

import dask
import numpy as np
from gtsam import Pose3, Rot3

import gtsfm.utils.geometry_comparisons as geometry_comparisons
import gtsfm.utils.io as io_utils
import gtsfm.utils.rotation as rotation_util
import tests.data.sample_poses as sample_poses
from gtsfm.averaging.rotation.rotation_averaging_base import RotationAveragingBase
from gtsfm.averaging.rotation.shonan import ShonanRotationAveraging
from gtsfm.common.pose_prior import PosePrior, PosePriorType

ROTATION_ANGLE_ERROR_THRESHOLD_DEG = 2
TEST_DATA_ROOT = Path(__file__).resolve().parent.parent.parent / "data"
LARGE_PROBLEM_BAL_FILE = TEST_DATA_ROOT / "problem-394-100368-pre.txt"


class TestShonanRotationAveraging(unittest.TestCase):
@@ -25,17 +31,17 @@ class TestShonanRotationAveraging(unittest.TestCase):
All unit test functions defined in TestRotationAveragingBase are run automatically.
"""

def setUp(self):
def setUp(self) -> None:
super().setUp()

self.obj: RotationAveragingBase = ShonanRotationAveraging()
self.obj = ShonanRotationAveraging()

def __execute_test(self, i2Ri1_input: Dict[Tuple[int, int], Rot3], wRi_expected: List[Rot3]) -> None:
"""Helper function to run the averagaing and assert w/ expected.
Args:
i2Ri1_input: relative rotations, which are input to the algorithm.
wRi_expected: expected global rotations.
i2Ri1_input: Relative rotations, which are input to the algorithm.
wRi_expected: Expected global rotations.
"""
i1Ti2_priors: Dict[Tuple[int, int], PosePrior] = {}
v_corr_idxs = _create_dummy_correspondences(i2Ri1_input)
@@ -72,29 +78,24 @@ def test_panorama(self):
)
self.__execute_test(i2Ri1_dict, wRi_expected)

def test_simple(self):
def test_simple_three_nodes_two_measurements(self):
"""Test a simple case with three relative rotations."""

i2Ri1_dict = {
(1, 0): Rot3.RzRyRx(0, np.deg2rad(30), 0),
(2, 1): Rot3.RzRyRx(0, 0, np.deg2rad(20)),
}
i0Ri1 = Rot3.RzRyRx(0, np.deg2rad(30), 0)
i1Ri2 = Rot3.RzRyRx(0, 0, np.deg2rad(20))
i0Ri2 = i0Ri1.compose(i1Ri2)

expected_wRi_list = [
Rot3.RzRyRx(0, 0, 0),
Rot3.RzRyRx(0, np.deg2rad(30), 0),
i2Ri1_dict[(1, 0)].compose(i2Ri1_dict[(2, 1)]),
]
i2Ri1_dict = {(0, 1): i0Ri1.inverse(), (1, 2): i1Ri2.inverse()}

expected_wRi_list = [Rot3(), i0Ri1, i0Ri2]

self.__execute_test(i2Ri1_dict, expected_wRi_list)

def test_simple_with_prior(self):
"""Test a simple case with 1 measurement and a single pose prior."""
expected_wRi_list = [Rot3.RzRyRx(0, 0, 0), Rot3.RzRyRx(0, np.deg2rad(30), 0), Rot3.RzRyRx(np.deg2rad(30), 0, 0)]

i2Ri1_dict = {
(1, 0): Rot3.RzRyRx(0, np.deg2rad(30), 0),
}
i2Ri1_dict = {(0, 1): expected_wRi_list[1].between(expected_wRi_list[0])}

expected_0R2 = expected_wRi_list[0].between(expected_wRi_list[2])
i1Ti2_priors = {
@@ -157,7 +158,7 @@ def test_nonconsecutive_indices(self):
"""
num_images = 4

# assume pose 0 is orphaned in the visibility graph
# Assume pose 0 is orphaned in the visibility graph
# Let wTi0's (R,t) be parameterized as identity Rot3(), and t = [1,1,0]
wTi1 = Pose3(Rot3(), np.array([3, 1, 0]))
wTi2 = Pose3(Rot3(), np.array([3, 3, 0]))
@@ -171,6 +172,13 @@ def test_nonconsecutive_indices(self):
(1, 3): wTi3.between(wTi1).rotation(),
}

# Keys do not overlap with i2Ri1_dict.
v_corr_idxs = {
(1, 2): _generate_corr_idxs(200),
(1, 3): _generate_corr_idxs(500),
(0, 2): _generate_corr_idxs(0),
}

relative_pose_priors: Dict[Tuple[int, int], PosePrior] = {}
v_corr_idxs = _create_dummy_correspondences(i2Ri1_input)
wRi_computed = self.obj.run_rotation_averaging(num_images, i2Ri1_input, relative_pose_priors, v_corr_idxs)
@@ -179,6 +187,84 @@ def test_nonconsecutive_indices(self):
geometry_comparisons.compare_rotations(wRi_computed, wRi_expected, angular_error_threshold_degrees=0.1)
)

def test_initialization(self) -> None:
"""Test that the result of Shonan is not dependent on the initialization."""
i2Ri1_dict_noisefree, wRi_expected = sample_poses.convert_data_for_rotation_averaging(
sample_poses.CIRCLE_ALL_EDGES_GLOBAL_POSES, sample_poses.CIRCLE_ALL_EDGES_RELATIVE_POSES
)
v_corr_idxs = {pair: _generate_corr_idxs(random.randint(1, 10)) for pair in i2Ri1_dict_noisefree.keys()}

# Add noise to the relative rotations
i2Ri1_dict_noisy = {
pair: i2Ri1 * rotation_util.random_rotation() for pair, i2Ri1 in i2Ri1_dict_noisefree.items()
}

wRi_computed_with_random_init = self.obj.run_rotation_averaging(
num_images=len(wRi_expected),
i2Ri1_dict=i2Ri1_dict_noisy,
i1Ti2_priors={},
v_corr_idxs=v_corr_idxs,
)

shonan_mst_init = ShonanRotationAveraging(use_mst_init=True)
wRi_computed_with_mst_init = shonan_mst_init.run_rotation_averaging(
num_images=len(wRi_expected),
i2Ri1_dict=i2Ri1_dict_noisy,
i1Ti2_priors={},
v_corr_idxs=v_corr_idxs,
)

self.assertTrue(
geometry_comparisons.compare_rotations(
wRi_computed_with_random_init, wRi_computed_with_mst_init, angular_error_threshold_degrees=0.1
)
)

def test_initialization_big(self):
"""Test that the result of Shonan is not dependent on the initialization on a bigger dataset."""
gt_data = io_utils.read_bal(str(LARGE_PROBLEM_BAL_FILE))
poses = gt_data.get_camera_poses()[:15]
pairs: List[Tuple[int, int]] = []
for i in range(len(poses)):
for j in range(i + 1, min(i + 5, len(poses))):
pairs.append((i, j))

i2Ri1_dict_noisefree, _ = sample_poses.convert_data_for_rotation_averaging(
poses, sample_poses.generate_relative_from_global(poses, pairs)
)
v_corr_idxs = {pair: _generate_corr_idxs(random.randint(1, 10)) for pair in i2Ri1_dict_noisefree.keys()}

# Add noise to the relative rotations
i2Ri1_dict_noisy = {
pair: i2Ri1 * rotation_util.random_rotation(angle_scale_factor=0.5)
for pair, i2Ri1 in i2Ri1_dict_noisefree.items()
}

wRi_computed_with_random_init = self.obj.run_rotation_averaging(
num_images=len(poses),
i2Ri1_dict=i2Ri1_dict_noisy,
i1Ti2_priors={},
v_corr_idxs=v_corr_idxs,
)

shonan_mst_init = ShonanRotationAveraging(use_mst_init=True)
wRi_computed_with_mst_init = shonan_mst_init.run_rotation_averaging(
num_images=len(poses),
i2Ri1_dict=i2Ri1_dict_noisy,
i1Ti2_priors={},
v_corr_idxs=v_corr_idxs,
)

self.assertTrue(
geometry_comparisons.compare_rotations(
wRi_computed_with_random_init, wRi_computed_with_mst_init, angular_error_threshold_degrees=0.1
)
)


def _generate_corr_idxs(num_corrs: int) -> np.ndarray:
return np.random.randint(low=0, high=10000, size=(num_corrs, 2))


def _create_dummy_correspondences(i2Ri1_dict: Dict[Tuple[int, int], Rot3]) -> Dict[Tuple[int, int], np.ndarray]:
"""Create dummy verified correspondences for each edge in view graph."""
6 changes: 3 additions & 3 deletions tests/data/sample_poses.py
Original file line number Diff line number Diff line change
@@ -21,8 +21,8 @@ def generate_relative_from_global(
"""Generate relative poses from global poses.
Args:
wTi_list: global poses.
pair_indices: pairs (i1, i2) to construct relative poses for.
wTi_list: Global poses.
pair_indices: Pairs (i1, i2) to construct relative poses for.
Returns:
Dictionary (i1, i2) -> i2Ti1 for all requested pairs.
@@ -37,7 +37,7 @@ def generate_relative_from_global(
CIRCLE_TWO_EDGES_GLOBAL_POSES = SFMdata.createPoses(Cal3_S2(fx=1, fy=1, s=0, u0=0, v0=0))[::2]

CIRCLE_TWO_EDGES_RELATIVE_POSES = generate_relative_from_global(
CIRCLE_TWO_EDGES_GLOBAL_POSES, [(1, 0), (2, 1), (3, 2), (0, 3)]
CIRCLE_TWO_EDGES_GLOBAL_POSES, [(0, 1), (1, 2), (2, 3), (0, 3)]
)

"""4 poses in the circle of radius 5m, all looking at the center of the circle.
170 changes: 170 additions & 0 deletions tests/utils/test_rotation_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Unit tests for rotation utils.
Authors: Ayush Baid
"""

import unittest
from typing import Dict, List, Tuple

import numpy as np
from gtsam import Rot3

import gtsfm.utils.geometry_comparisons as geometry_comparisons
import gtsfm.utils.rotation as rotation_util
import tests.data.sample_poses as sample_poses

ROTATION_ANGLE_ERROR_THRESHOLD_DEG = 2


RELATIVE_ROTATION_DICT = Dict[Tuple[int, int], Rot3]


def _get_ordered_chain_pose_data() -> Tuple[RELATIVE_ROTATION_DICT, np.ndarray]:
"""Return data for a scenario with 5 camera poses, with ordering that follows their connectivity.
Accordingly, we specify i1 < i2 for all edges (i1,i2).
Graph topology:
| 2 | 3
o-- ... o--
. .
. .
| | |
o-- ... --o --o
0 1 4
Returns:
Tuple of mapping from image index pair to relative rotations, and expected global rotation angles.
"""
# Expected angles.
wRi_list_euler_deg_expected = np.array([0, 90, 0, 0, 90])

# Ground truth 3d rotations for 5 ordered poses (0,1,2,3,4)
wRi_list_gt = [Rot3.RzRyRx(np.deg2rad(Rz_deg), 0, 0) for Rz_deg in wRi_list_euler_deg_expected]

edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
i2Ri1_dict = _create_synthetic_relative_pose_measurements(wRi_list_gt, edges=edges)

return i2Ri1_dict, wRi_list_euler_deg_expected


def _get_mixed_order_chain_pose_data() -> Tuple[RELATIVE_ROTATION_DICT, np.ndarray]:
"""Return data for a scenario with 5 camera poses, with ordering that does NOT follow their connectivity.
Below, we do NOT specify i1 < i2 for all edges (i1,i2).
Graph topology:
| 3 | 0
o-- ... o--
. .
. .
| | |
o-- ... --o --o
4 1 2
"""
# Expected angles.
wRi_list_euler_deg_expected = np.array([0, 90, 90, 0, 0])

# Ground truth 2d rotations for 5 ordered poses (0,1,2,3,4)
wRi_list_gt = [Rot3.RzRyRx(np.deg2rad(Rz_deg), 0, 0) for Rz_deg in wRi_list_euler_deg_expected]

edges = [(1, 4), (1, 3), (0, 3), (0, 2)]
i2Ri1_dict = _create_synthetic_relative_pose_measurements(wRi_list_gt=wRi_list_gt, edges=edges)

return i2Ri1_dict, wRi_list_euler_deg_expected


def _create_synthetic_relative_pose_measurements(
wRi_list_gt: List[Rot3], edges: List[Tuple[int, int]]
) -> Dict[Tuple[int, int], Rot3]:
"""Generate synthetic relative rotation measurements, from ground truth global rotations.
Args:
wRi_list_gt: List of (3,3) rotation matrices.
edges: Edges as pairs of image indices.
Returns:
Relative rotation measurements.
"""
i2Ri1_dict = {}
for i1, i2 in edges:
wRi2 = wRi_list_gt[i2]
wRi1 = wRi_list_gt[i1]
i2Ri1_dict[(i1, i2)] = wRi2.inverse() * wRi1

return i2Ri1_dict


def _wrap_angles(angles: np.ndarray) -> np.ndarray:
r"""Map angle (in degrees) from domain [-∞, ∞] to [0, 360).
Args:
angles: Array of shape (N,) representing angles (in degrees) in any interval.
Returns:
Array of shape (N,) representing the angles (in degrees) mapped to the interval [0, 360].
"""
# Reduce the angle
angles = angles % 360

# Force it to be the positive remainder, so that 0 <= angle < 360
angles = (angles + 360) % 360
return angles


class TestRotationUtil(unittest.TestCase):
def test_mst_initialization(self):
"""Test for 4 poses in a circle, with a pose connected to all others."""
i2Ri1_dict, wRi_expected = sample_poses.convert_data_for_rotation_averaging(
sample_poses.CIRCLE_ALL_EDGES_GLOBAL_POSES, sample_poses.CIRCLE_ALL_EDGES_RELATIVE_POSES
)

wRi_computed = rotation_util.initialize_global_rotations_using_mst(
len(wRi_expected),
i2Ri1_dict,
edge_weights={(i1, i2): (i1 + i2) * 100 for i1, i2 in i2Ri1_dict.keys()},
)
self.assertTrue(
geometry_comparisons.compare_rotations(wRi_computed, wRi_expected, ROTATION_ANGLE_ERROR_THRESHOLD_DEG)
)

def test_greedily_construct_st_ordered_chain(self) -> None:
"""Ensures that we can greedily construct a Spanning Tree for an ordered chain."""

i2Ri1_dict, wRi_list_euler_deg_expected = _get_ordered_chain_pose_data()

num_images = 5
wRi_list_computed = rotation_util.initialize_global_rotations_using_mst(
num_images,
i2Ri1_dict,
edge_weights={(i1, i2): (i1 + i2) * 100 for i1, i2 in i2Ri1_dict.keys()},
)

wRi_list_euler_deg_est = [np.rad2deg(wRi.roll()) for wRi in wRi_list_computed]
assert np.allclose(wRi_list_euler_deg_est, wRi_list_euler_deg_expected)

def test_greedily_construct_st_mixed_order_chain(self) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ayushbaid I added a test that fails with the current implementation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to change the lookup logic? Or do you have a fix in mind?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a fix -- traversing the shortest path in the MST between the origin and destination, and chaining the poses together

"""Ensures that we can greedily construct a Spanning Tree for an unordered chain."""
i2Ri1_dict, wRi_list_euler_deg_expected = _get_mixed_order_chain_pose_data()

num_images = 5
wRi_list_computed = rotation_util.initialize_global_rotations_using_mst(
num_images,
i2Ri1_dict,
edge_weights={(i1, i2): (i1 + i2) * 100 for i1, i2 in i2Ri1_dict.keys()},
)

wRi_list_euler_deg_est = np.array([np.rad2deg(wRi.roll()) for wRi in wRi_list_computed])

# Make sure both lists of angles start at 0 deg.
wRi_list_euler_deg_est -= wRi_list_euler_deg_est[0]
wRi_list_euler_deg_expected -= wRi_list_euler_deg_expected[0]

assert np.allclose(_wrap_angles(wRi_list_euler_deg_est), _wrap_angles(wRi_list_euler_deg_expected))


if __name__ == "__main__":
unittest.main()