From 1b79e8800ff4a8e06635f6d3a4872e38c23845af Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 20:08:09 -0500 Subject: [PATCH] add deadModeThreshold argument to HybridBayesNet::prune --- gtsam/hybrid/HybridBayesNet.cpp | 4 ++-- gtsam/hybrid/HybridBayesNet.h | 6 +++++- gtsam/hybrid/HybridSmoother.cpp | 8 +++++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index d27b1026e0..66661e845d 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { // TODO(Frank): This can be quite expensive *unless* the factors have already // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, - bool removeDeadModes) const { +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, + double deadModeThreshold) const { // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 5d3270f4cd..0546a74222 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -219,9 +219,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @param maxNrLeaves Continuous values at which to compute the error. * @param removeDeadModes Flag to enable removal of modes which only have a * single possible assignment. + * @param deadModeThreshold The threshold to check the mode marginals against. + * If greater than this threshold, the mode gets assigned that value and is + * considered "dead" for hybrid elimination. * @return A pruned HybridBayesNet */ - HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const; + HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false, + double deadModeThreshold = 0.99) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index a50be28baa..34f28ff803 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -74,6 +74,11 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, ordering = *given_ordering; } + // graph.print("Original GRAPH"); + // GTSAM_PRINT(updatedGraph); + // GTSAM_PRINT(hybridBayesNet_); + // GTSAM_PRINT(ordering); + // Eliminate. HybridBayesNet bayesNetFragment = *updatedGraph.eliminateSequential(ordering); @@ -81,7 +86,8 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, if (maxNrLeaves) { // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // all the conditionals with the same keys in bayesNetFragment. - bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, removeDeadModes_); + bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, removeDeadModes_, + deadModeThreshold_); } // Add the partial bayes net to the posterior bayes net.