diff --git a/environment_v2_linux_cpuonly.yml b/environment_v2_linux_cpuonly.yml new file mode 100644 index 000000000..2e1e18020 --- /dev/null +++ b/environment_v2_linux_cpuonly.yml @@ -0,0 +1,56 @@ +name: gtsfm-v2 +channels: + # for priority order, we prefer pytorch as the highest priority as it supplies + # latest stable packages for numerous deep learning based methods. conda-forge + # supplies higher versions of packages like opencv compared to the defaults + # channel. + - pytorch + - conda-forge +dependencies: + # python essentials + - python + - pip + # formatting and dev environment + - black + - coverage + - mypy + - pylint + - pytest + - flake8 + - isort + # dask and related + - dask # same as dask[complete] pip distribution + - asyncssh + - python-graphviz + # core functionality and APIs + - matplotlib + - networkx + - numpy + - nodejs + - pandas + - pillow + - scikit-learn + - seaborn + - scipy + - hydra-core + - gtsam + # 3rd party algorithms for different modules + - cpuonly # replacement of cudatoolkit for cpu only machines + - pytorch + - torchvision + - kornia + - pycolmap + - opencv + # io + - h5py + - plotly + - tabulate + - simplejson + - open3d + - colour + - pydot + - trimesh + # testing + - parameterized + # - pip: + # - pydegensac diff --git a/gtsfm/averaging/rotation/rotation_averaging_base.py b/gtsfm/averaging/rotation/rotation_averaging_base.py index a98912557..51d982a2b 100644 --- a/gtsfm/averaging/rotation/rotation_averaging_base.py +++ b/gtsfm/averaging/rotation/rotation_averaging_base.py @@ -26,12 +26,13 @@ class RotationAveragingBase(GTSFMProcess): rotations. """ + @staticmethod def get_ui_metadata() -> UiMetadata: """Returns data needed to display node and edge info for this process in the process graph.""" return UiMetadata( display_name="Rotation Averaging", - input_products=("View-Graph Relative Rotations", "Relative Pose Priors"), + input_products=("View-Graph Relative Rotations", "Relative Pose Priors", "Verified Correspondences"), output_products=("Global Rotations",), parent_plate="Sparse Reconstruction", ) diff --git a/gtsfm/averaging/rotation/shonan.py b/gtsfm/averaging/rotation/shonan.py index ad8643a43..b0de424aa 100644 --- a/gtsfm/averaging/rotation/shonan.py +++ b/gtsfm/averaging/rotation/shonan.py @@ -20,9 +20,11 @@ Rot3, ShonanAveraging3, ShonanAveragingParameters3, + Values, ) import gtsfm.utils.logger as logger_utils +import gtsfm.utils.rotation as rotation_util from gtsfm.averaging.rotation.rotation_averaging_base import RotationAveragingBase from gtsfm.common.pose_prior import PosePrior @@ -38,7 +40,10 @@ class ShonanRotationAveraging(RotationAveragingBase): """Performs Shonan rotation averaging.""" def __init__( - self, two_view_rotation_sigma: float = _DEFAULT_TWO_VIEW_ROTATION_SIGMA, weight_by_inliers: bool = True + self, + two_view_rotation_sigma: float = _DEFAULT_TWO_VIEW_ROTATION_SIGMA, + weight_by_inliers: bool = True, + use_mst_init: bool = False, ) -> None: """Initializes module. @@ -50,10 +55,11 @@ def __init__( of inlier correspondences per edge. """ super().__init__() - self._two_view_rotation_sigma = two_view_rotation_sigma self._p_min = 3 self._p_max = 64 + self._two_view_rotation_sigma = two_view_rotation_sigma self._weight_by_inliers = weight_by_inliers + self._use_mst_init = use_mst_init def __get_shonan_params(self) -> ShonanAveragingParameters3: lm_params = LevenbergMarquardtParams.CeresDefaults() @@ -108,7 +114,7 @@ def get_isotropic_noise_model_sigma(covariance: np.ndarray) -> float: return measurements def _run_with_consecutive_ordering( - self, num_connected_nodes: int, measurements: gtsam.BinaryMeasurementsRot3 + self, num_connected_nodes: int, measurements: gtsam.BinaryMeasurementsRot3, initial: Optional[Values] ) -> List[Optional[Rot3]]: """Run the rotation averaging on a connected graph w/ N keys ordered consecutively [0,...,N-1]. @@ -134,7 +140,9 @@ def _run_with_consecutive_ordering( ) shonan = ShonanAveraging3(measurements, self.__get_shonan_params()) - initial = shonan.initializeRandomly() + if initial is None: + logger.info("Using random initialization for Shonan") + initial = shonan.initializeRandomly() logger.info("Initial cost: %.5f", shonan.cost(initial)) result, _ = shonan.run(initial, self._p_min, self._p_max) logger.info("Final cost: %.5f", shonan.cost(result)) @@ -203,13 +211,28 @@ def run_rotation_averaging( if (i1, i2) in i2Ri1_dict } + # Use negative of the number of correspondences as the edge weight. + initial_values: Optional[Values] = None + if self._use_mst_init: + logger.info("Using MST initialization for Shonan") + wRi_initial_ = rotation_util.initialize_global_rotations_using_mst( + len(nodes_with_edges), + i2Ri1_dict_remapped, + edge_weights={ + (i1, i2): -num_correspondences_dict.get((i1, i2), 0) for i1, i2 in i2Ri1_dict_remapped.keys() + }, + ) + initial_values = Values() + for i, wRi in enumerate(wRi_initial_): + initial_values.insert(i, wRi) + def _create_factors_and_run() -> List[Rot3]: measurements: gtsam.BinaryMeasurementsRot3 = self.__measurements_from_2view_relative_rotations( i2Ri1_dict=i2Ri1_dict_remapped, num_correspondences_dict=num_correspondences_dict ) measurements.extend(self._measurements_from_pose_priors(i1Ti2_priors, old_to_new_idxs)) wRi_list_subset = self._run_with_consecutive_ordering( - num_connected_nodes=len(nodes_with_edges), measurements=measurements + num_connected_nodes=len(nodes_with_edges), measurements=measurements, initial=initial_values ) return wRi_list_subset diff --git a/gtsfm/utils/geometry_comparisons.py b/gtsfm/utils/geometry_comparisons.py index bdc72cea2..87289ed0c 100644 --- a/gtsfm/utils/geometry_comparisons.py +++ b/gtsfm/utils/geometry_comparisons.py @@ -30,7 +30,8 @@ def compare_rotations( Args: aTi_list: 1st list of rotations. bTi_list: 2nd list of rotations. - angular_error_threshold_degrees: the threshold for angular error between two rotations. + angular_error_threshold_degrees: Threshold for angular error between two rotations. + Returns: Result of the comparison. """ @@ -55,6 +56,7 @@ def compare_rotations( relative_rotations_angles = np.array( [compute_relative_rotation_angle(aRi, aRi_) for (aRi, aRi_) in zip(aRi_list, aRi_list_)], dtype=np.float32 ) + print(relative_rotations_angles) return np.all(relative_rotations_angles < angular_error_threshold_degrees) diff --git a/gtsfm/utils/rotation.py b/gtsfm/utils/rotation.py new file mode 100644 index 000000000..cde9b51b9 --- /dev/null +++ b/gtsfm/utils/rotation.py @@ -0,0 +1,76 @@ +"""Utility functions for rotations. + +Authors: Ayush Baid +""" + +from typing import Dict, List, Tuple + +import networkx as nx +import numpy as np +from gtsam import Rot3 + + +def random_rotation(angle_scale_factor: float = 0.1) -> Rot3: + """Sample a random rotation by generating a sample from the 4d unit sphere.""" + q = np.random.rand(4) + # make unit-length quaternion + q /= np.linalg.norm(q) + qw, qx, qy, qz = q + R = Rot3(qw, qx, qy, qz) + axis, angle = R.axisAngle() + angle = angle * angle_scale_factor + return Rot3.AxisAngle(axis.point3(), angle) + + +def initialize_global_rotations_using_mst( + num_images: int, i2Ri1_dict: Dict[Tuple[int, int], Rot3], edge_weights: Dict[Tuple[int, int], int] +) -> List[Rot3]: + """Initializes rotations using minimum spanning tree (weighted by number of correspondences). + + Args: + num_images: Number of images in the scene. + i2Ri1_dict: Dictionary of relative rotations (i1, i2): i2Ri1. + edge_weights: Weight of the edges (i1, i2). All edges in i2Ri1 must have an edge weight. + + Returns: + Global rotations wRi initialized using an MST. Randomly initialized if we have a forest. + """ + # Create a graph from the relative rotations dictionary. + graph = nx.Graph() + for i1, i2 in i2Ri1_dict.keys(): + graph.add_edge(i1, i2, weight=edge_weights[(i1, i2)]) + + if not nx.is_connected(graph): + raise ValueError("Relative rotation graph is not connected") + + # Compute the Minimum Spanning Tree (MST) + mst = nx.minimum_spanning_tree(graph) + + # MST graph. + G = nx.Graph() + G.add_edges_from(mst.edges) + + wRi_list: List[Rot3] = [Rot3()] * num_images + # Choose origin node. + origin_node = list(G.nodes)[0] + wRi_list[origin_node] = Rot3() + + # Ignore 0th node, as we already set its global pose as the origin + for dst_node in list(G.nodes)[1:]: + # Determine the path to this node from the origin. ordered from [origin_node,...,dst_node] + path = nx.shortest_path(G, source=origin_node, target=dst_node) + + # Chain relative rotations w.r.t. origin node. Initialize as identity Rot3 w.r.t origin node `i1`. + wRi1 = Rot3() + for i1, i2 in zip(path[:-1], path[1:]): + # NOTE: i1, i2 may not be in sorted order here. May need to reverse ordering. + if i1 < i2: + i1Ri2 = i2Ri1_dict[(i1, i2)].inverse() + else: + i1Ri2 = i2Ri1_dict[(i2, i1)] + # Path order is (origin -> ... -> i1 -> i2 -> ... -> dst_node). Set `i2` to be new `i1`. + wRi1 = wRi1 * i1Ri2 + + wRi_list[dst_node] = wRi1 + + return wRi_list diff --git a/tests/averaging/rotation/test_shonan.py b/tests/averaging/rotation/test_shonan.py index 20f40be64..0556b0ede 100644 --- a/tests/averaging/rotation/test_shonan.py +++ b/tests/averaging/rotation/test_shonan.py @@ -2,21 +2,27 @@ Authors: Ayush Baid, John Lambert """ + import pickle +import random import unittest from typing import Dict, List, Tuple +from pathlib import Path import dask import numpy as np from gtsam import Pose3, Rot3 import gtsfm.utils.geometry_comparisons as geometry_comparisons +import gtsfm.utils.io as io_utils +import gtsfm.utils.rotation as rotation_util import tests.data.sample_poses as sample_poses -from gtsfm.averaging.rotation.rotation_averaging_base import RotationAveragingBase from gtsfm.averaging.rotation.shonan import ShonanRotationAveraging from gtsfm.common.pose_prior import PosePrior, PosePriorType ROTATION_ANGLE_ERROR_THRESHOLD_DEG = 2 +TEST_DATA_ROOT = Path(__file__).resolve().parent.parent.parent / "data" +LARGE_PROBLEM_BAL_FILE = TEST_DATA_ROOT / "problem-394-100368-pre.txt" class TestShonanRotationAveraging(unittest.TestCase): @@ -25,17 +31,17 @@ class TestShonanRotationAveraging(unittest.TestCase): All unit test functions defined in TestRotationAveragingBase are run automatically. """ - def setUp(self): + def setUp(self) -> None: super().setUp() - self.obj: RotationAveragingBase = ShonanRotationAveraging() + self.obj = ShonanRotationAveraging() def __execute_test(self, i2Ri1_input: Dict[Tuple[int, int], Rot3], wRi_expected: List[Rot3]) -> None: """Helper function to run the averagaing and assert w/ expected. Args: - i2Ri1_input: relative rotations, which are input to the algorithm. - wRi_expected: expected global rotations. + i2Ri1_input: Relative rotations, which are input to the algorithm. + wRi_expected: Expected global rotations. """ i1Ti2_priors: Dict[Tuple[int, int], PosePrior] = {} v_corr_idxs = _create_dummy_correspondences(i2Ri1_input) @@ -72,19 +78,16 @@ def test_panorama(self): ) self.__execute_test(i2Ri1_dict, wRi_expected) - def test_simple(self): + def test_simple_three_nodes_two_measurements(self): """Test a simple case with three relative rotations.""" - i2Ri1_dict = { - (1, 0): Rot3.RzRyRx(0, np.deg2rad(30), 0), - (2, 1): Rot3.RzRyRx(0, 0, np.deg2rad(20)), - } + i0Ri1 = Rot3.RzRyRx(0, np.deg2rad(30), 0) + i1Ri2 = Rot3.RzRyRx(0, 0, np.deg2rad(20)) + i0Ri2 = i0Ri1.compose(i1Ri2) - expected_wRi_list = [ - Rot3.RzRyRx(0, 0, 0), - Rot3.RzRyRx(0, np.deg2rad(30), 0), - i2Ri1_dict[(1, 0)].compose(i2Ri1_dict[(2, 1)]), - ] + i2Ri1_dict = {(0, 1): i0Ri1.inverse(), (1, 2): i1Ri2.inverse()} + + expected_wRi_list = [Rot3(), i0Ri1, i0Ri2] self.__execute_test(i2Ri1_dict, expected_wRi_list) @@ -92,9 +95,7 @@ def test_simple_with_prior(self): """Test a simple case with 1 measurement and a single pose prior.""" expected_wRi_list = [Rot3.RzRyRx(0, 0, 0), Rot3.RzRyRx(0, np.deg2rad(30), 0), Rot3.RzRyRx(np.deg2rad(30), 0, 0)] - i2Ri1_dict = { - (1, 0): Rot3.RzRyRx(0, np.deg2rad(30), 0), - } + i2Ri1_dict = {(0, 1): expected_wRi_list[1].between(expected_wRi_list[0])} expected_0R2 = expected_wRi_list[0].between(expected_wRi_list[2]) i1Ti2_priors = { @@ -157,7 +158,7 @@ def test_nonconsecutive_indices(self): """ num_images = 4 - # assume pose 0 is orphaned in the visibility graph + # Assume pose 0 is orphaned in the visibility graph # Let wTi0's (R,t) be parameterized as identity Rot3(), and t = [1,1,0] wTi1 = Pose3(Rot3(), np.array([3, 1, 0])) wTi2 = Pose3(Rot3(), np.array([3, 3, 0])) @@ -171,6 +172,13 @@ def test_nonconsecutive_indices(self): (1, 3): wTi3.between(wTi1).rotation(), } + # Keys do not overlap with i2Ri1_dict. + v_corr_idxs = { + (1, 2): _generate_corr_idxs(200), + (1, 3): _generate_corr_idxs(500), + (0, 2): _generate_corr_idxs(0), + } + relative_pose_priors: Dict[Tuple[int, int], PosePrior] = {} v_corr_idxs = _create_dummy_correspondences(i2Ri1_input) wRi_computed = self.obj.run_rotation_averaging(num_images, i2Ri1_input, relative_pose_priors, v_corr_idxs) @@ -179,6 +187,84 @@ def test_nonconsecutive_indices(self): geometry_comparisons.compare_rotations(wRi_computed, wRi_expected, angular_error_threshold_degrees=0.1) ) + def test_initialization(self) -> None: + """Test that the result of Shonan is not dependent on the initialization.""" + i2Ri1_dict_noisefree, wRi_expected = sample_poses.convert_data_for_rotation_averaging( + sample_poses.CIRCLE_ALL_EDGES_GLOBAL_POSES, sample_poses.CIRCLE_ALL_EDGES_RELATIVE_POSES + ) + v_corr_idxs = {pair: _generate_corr_idxs(random.randint(1, 10)) for pair in i2Ri1_dict_noisefree.keys()} + + # Add noise to the relative rotations + i2Ri1_dict_noisy = { + pair: i2Ri1 * rotation_util.random_rotation() for pair, i2Ri1 in i2Ri1_dict_noisefree.items() + } + + wRi_computed_with_random_init = self.obj.run_rotation_averaging( + num_images=len(wRi_expected), + i2Ri1_dict=i2Ri1_dict_noisy, + i1Ti2_priors={}, + v_corr_idxs=v_corr_idxs, + ) + + shonan_mst_init = ShonanRotationAveraging(use_mst_init=True) + wRi_computed_with_mst_init = shonan_mst_init.run_rotation_averaging( + num_images=len(wRi_expected), + i2Ri1_dict=i2Ri1_dict_noisy, + i1Ti2_priors={}, + v_corr_idxs=v_corr_idxs, + ) + + self.assertTrue( + geometry_comparisons.compare_rotations( + wRi_computed_with_random_init, wRi_computed_with_mst_init, angular_error_threshold_degrees=0.1 + ) + ) + + def test_initialization_big(self): + """Test that the result of Shonan is not dependent on the initialization on a bigger dataset.""" + gt_data = io_utils.read_bal(str(LARGE_PROBLEM_BAL_FILE)) + poses = gt_data.get_camera_poses()[:15] + pairs: List[Tuple[int, int]] = [] + for i in range(len(poses)): + for j in range(i + 1, min(i + 5, len(poses))): + pairs.append((i, j)) + + i2Ri1_dict_noisefree, _ = sample_poses.convert_data_for_rotation_averaging( + poses, sample_poses.generate_relative_from_global(poses, pairs) + ) + v_corr_idxs = {pair: _generate_corr_idxs(random.randint(1, 10)) for pair in i2Ri1_dict_noisefree.keys()} + + # Add noise to the relative rotations + i2Ri1_dict_noisy = { + pair: i2Ri1 * rotation_util.random_rotation(angle_scale_factor=0.5) + for pair, i2Ri1 in i2Ri1_dict_noisefree.items() + } + + wRi_computed_with_random_init = self.obj.run_rotation_averaging( + num_images=len(poses), + i2Ri1_dict=i2Ri1_dict_noisy, + i1Ti2_priors={}, + v_corr_idxs=v_corr_idxs, + ) + + shonan_mst_init = ShonanRotationAveraging(use_mst_init=True) + wRi_computed_with_mst_init = shonan_mst_init.run_rotation_averaging( + num_images=len(poses), + i2Ri1_dict=i2Ri1_dict_noisy, + i1Ti2_priors={}, + v_corr_idxs=v_corr_idxs, + ) + + self.assertTrue( + geometry_comparisons.compare_rotations( + wRi_computed_with_random_init, wRi_computed_with_mst_init, angular_error_threshold_degrees=0.1 + ) + ) + + +def _generate_corr_idxs(num_corrs: int) -> np.ndarray: + return np.random.randint(low=0, high=10000, size=(num_corrs, 2)) + def _create_dummy_correspondences(i2Ri1_dict: Dict[Tuple[int, int], Rot3]) -> Dict[Tuple[int, int], np.ndarray]: """Create dummy verified correspondences for each edge in view graph.""" diff --git a/tests/data/sample_poses.py b/tests/data/sample_poses.py index 0f484a0e3..1aacd7e5f 100644 --- a/tests/data/sample_poses.py +++ b/tests/data/sample_poses.py @@ -21,8 +21,8 @@ def generate_relative_from_global( """Generate relative poses from global poses. Args: - wTi_list: global poses. - pair_indices: pairs (i1, i2) to construct relative poses for. + wTi_list: Global poses. + pair_indices: Pairs (i1, i2) to construct relative poses for. Returns: Dictionary (i1, i2) -> i2Ti1 for all requested pairs. @@ -37,7 +37,7 @@ def generate_relative_from_global( CIRCLE_TWO_EDGES_GLOBAL_POSES = SFMdata.createPoses(Cal3_S2(fx=1, fy=1, s=0, u0=0, v0=0))[::2] CIRCLE_TWO_EDGES_RELATIVE_POSES = generate_relative_from_global( - CIRCLE_TWO_EDGES_GLOBAL_POSES, [(1, 0), (2, 1), (3, 2), (0, 3)] + CIRCLE_TWO_EDGES_GLOBAL_POSES, [(0, 1), (1, 2), (2, 3), (0, 3)] ) """4 poses in the circle of radius 5m, all looking at the center of the circle. diff --git a/tests/utils/test_rotation_utils.py b/tests/utils/test_rotation_utils.py new file mode 100644 index 000000000..f7086e6ef --- /dev/null +++ b/tests/utils/test_rotation_utils.py @@ -0,0 +1,170 @@ +"""Unit tests for rotation utils. + +Authors: Ayush Baid +""" + +import unittest +from typing import Dict, List, Tuple + +import numpy as np +from gtsam import Rot3 + +import gtsfm.utils.geometry_comparisons as geometry_comparisons +import gtsfm.utils.rotation as rotation_util +import tests.data.sample_poses as sample_poses + +ROTATION_ANGLE_ERROR_THRESHOLD_DEG = 2 + + +RELATIVE_ROTATION_DICT = Dict[Tuple[int, int], Rot3] + + +def _get_ordered_chain_pose_data() -> Tuple[RELATIVE_ROTATION_DICT, np.ndarray]: + """Return data for a scenario with 5 camera poses, with ordering that follows their connectivity. + + Accordingly, we specify i1 < i2 for all edges (i1,i2). + + Graph topology: + + | 2 | 3 + o-- ... o-- + . . + . . + | | | + o-- ... --o --o + 0 1 4 + + Returns: + Tuple of mapping from image index pair to relative rotations, and expected global rotation angles. + """ + # Expected angles. + wRi_list_euler_deg_expected = np.array([0, 90, 0, 0, 90]) + + # Ground truth 3d rotations for 5 ordered poses (0,1,2,3,4) + wRi_list_gt = [Rot3.RzRyRx(np.deg2rad(Rz_deg), 0, 0) for Rz_deg in wRi_list_euler_deg_expected] + + edges = [(0, 1), (1, 2), (2, 3), (3, 4)] + i2Ri1_dict = _create_synthetic_relative_pose_measurements(wRi_list_gt, edges=edges) + + return i2Ri1_dict, wRi_list_euler_deg_expected + + +def _get_mixed_order_chain_pose_data() -> Tuple[RELATIVE_ROTATION_DICT, np.ndarray]: + """Return data for a scenario with 5 camera poses, with ordering that does NOT follow their connectivity. + + Below, we do NOT specify i1 < i2 for all edges (i1,i2). + + Graph topology: + + | 3 | 0 + o-- ... o-- + . . + . . + | | | + o-- ... --o --o + 4 1 2 + + """ + # Expected angles. + wRi_list_euler_deg_expected = np.array([0, 90, 90, 0, 0]) + + # Ground truth 2d rotations for 5 ordered poses (0,1,2,3,4) + wRi_list_gt = [Rot3.RzRyRx(np.deg2rad(Rz_deg), 0, 0) for Rz_deg in wRi_list_euler_deg_expected] + + edges = [(1, 4), (1, 3), (0, 3), (0, 2)] + i2Ri1_dict = _create_synthetic_relative_pose_measurements(wRi_list_gt=wRi_list_gt, edges=edges) + + return i2Ri1_dict, wRi_list_euler_deg_expected + + +def _create_synthetic_relative_pose_measurements( + wRi_list_gt: List[Rot3], edges: List[Tuple[int, int]] +) -> Dict[Tuple[int, int], Rot3]: + """Generate synthetic relative rotation measurements, from ground truth global rotations. + + Args: + wRi_list_gt: List of (3,3) rotation matrices. + edges: Edges as pairs of image indices. + + Returns: + Relative rotation measurements. + """ + i2Ri1_dict = {} + for i1, i2 in edges: + wRi2 = wRi_list_gt[i2] + wRi1 = wRi_list_gt[i1] + i2Ri1_dict[(i1, i2)] = wRi2.inverse() * wRi1 + + return i2Ri1_dict + + +def _wrap_angles(angles: np.ndarray) -> np.ndarray: + r"""Map angle (in degrees) from domain [-∞, ∞] to [0, 360). + + Args: + angles: Array of shape (N,) representing angles (in degrees) in any interval. + + Returns: + Array of shape (N,) representing the angles (in degrees) mapped to the interval [0, 360]. + """ + # Reduce the angle + angles = angles % 360 + + # Force it to be the positive remainder, so that 0 <= angle < 360 + angles = (angles + 360) % 360 + return angles + + +class TestRotationUtil(unittest.TestCase): + def test_mst_initialization(self): + """Test for 4 poses in a circle, with a pose connected to all others.""" + i2Ri1_dict, wRi_expected = sample_poses.convert_data_for_rotation_averaging( + sample_poses.CIRCLE_ALL_EDGES_GLOBAL_POSES, sample_poses.CIRCLE_ALL_EDGES_RELATIVE_POSES + ) + + wRi_computed = rotation_util.initialize_global_rotations_using_mst( + len(wRi_expected), + i2Ri1_dict, + edge_weights={(i1, i2): (i1 + i2) * 100 for i1, i2 in i2Ri1_dict.keys()}, + ) + self.assertTrue( + geometry_comparisons.compare_rotations(wRi_computed, wRi_expected, ROTATION_ANGLE_ERROR_THRESHOLD_DEG) + ) + + def test_greedily_construct_st_ordered_chain(self) -> None: + """Ensures that we can greedily construct a Spanning Tree for an ordered chain.""" + + i2Ri1_dict, wRi_list_euler_deg_expected = _get_ordered_chain_pose_data() + + num_images = 5 + wRi_list_computed = rotation_util.initialize_global_rotations_using_mst( + num_images, + i2Ri1_dict, + edge_weights={(i1, i2): (i1 + i2) * 100 for i1, i2 in i2Ri1_dict.keys()}, + ) + + wRi_list_euler_deg_est = [np.rad2deg(wRi.roll()) for wRi in wRi_list_computed] + assert np.allclose(wRi_list_euler_deg_est, wRi_list_euler_deg_expected) + + def test_greedily_construct_st_mixed_order_chain(self) -> None: + """Ensures that we can greedily construct a Spanning Tree for an unordered chain.""" + i2Ri1_dict, wRi_list_euler_deg_expected = _get_mixed_order_chain_pose_data() + + num_images = 5 + wRi_list_computed = rotation_util.initialize_global_rotations_using_mst( + num_images, + i2Ri1_dict, + edge_weights={(i1, i2): (i1 + i2) * 100 for i1, i2 in i2Ri1_dict.keys()}, + ) + + wRi_list_euler_deg_est = np.array([np.rad2deg(wRi.roll()) for wRi in wRi_list_computed]) + + # Make sure both lists of angles start at 0 deg. + wRi_list_euler_deg_est -= wRi_list_euler_deg_est[0] + wRi_list_euler_deg_expected -= wRi_list_euler_deg_expected[0] + + assert np.allclose(_wrap_angles(wRi_list_euler_deg_est), _wrap_angles(wRi_list_euler_deg_expected)) + + +if __name__ == "__main__": + unittest.main()