From fcda1536c6c97bac6c514717342ce0c7fbef3b98 Mon Sep 17 00:00:00 2001
From: Frank Dellaert
Date: Fri, 22 Dec 2023 14:53:45 -0800
Subject: [PATCH] Cleaner version of eliminate

---
 gtsam/hybrid/HybridGaussianFactorGraph.cpp | 102 +++++++++++----------
 1 file changed, 56 insertions(+), 46 deletions(-)

diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
index 2029b48e09..7eaefbf85b 100644
--- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp
+++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
@@ -96,7 +96,6 @@ static GaussianFactorGraphTree addGaussian(
 // TODO(dellaert): it's probably more efficient to first collect the discrete
 // keys, and then loop over all assignments to populate a vector.
 GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
-
   GaussianFactorGraphTree result;
 
   for (auto &f : factors_) {
@@ -198,6 +197,51 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
 }
 
 /* ************************************************************************ */
+using Result = std::pair<std::shared_ptr<GaussianConditional>,
+                         GaussianMixtureFactor::sharedFactor>;
+
+// Integrate the probability mass in the last continuous conditional using
+// the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
+// discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
+static std::shared_ptr<Factor> createDiscreteFactor(
+    const DecisionTree<Key, Result> &eliminationResults,
+    const DiscreteKeys &discreteSeparator) {
+  auto probability = [&](const Result &pair) -> double {
+    const auto &[conditional, factor] = pair;
+    static const VectorValues kEmpty;
+    // If the factor is not null, it has no keys, just contains the residual.
+    if (!factor) return 1.0;  // TODO(dellaert): not loving this.
+    return exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
+  };
+
+  DecisionTree<Key, double> probabilities(eliminationResults, probability);
+
+  return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
+}
+
+// Create GaussianMixtureFactor on the separator, taking care to correct
+// for conditional constants.
+static std::shared_ptr<Factor> createGaussianMixtureFactor(
+    const DecisionTree<Key, Result> &eliminationResults,
+    const KeyVector &continuousSeparator,
+    const DiscreteKeys &discreteSeparator) {
+  // Correct for the normalization constant used up by the conditional
+  auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr {
+    const auto &[conditional, factor] = pair;
+    if (factor) {
+      auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
+      if (!hf) throw std::runtime_error("Expected HessianFactor!");
+      hf->constantTerm() += 2.0 * conditional->logNormalizationConstant();
+    }
+    return factor;
+  };
+  DecisionTree<Key, GaussianFactor::shared_ptr> newFactors(eliminationResults,
+                                                           correct);
+
+  return std::make_shared<GaussianMixtureFactor>(continuousSeparator,
+                                                 discreteSeparator, newFactors);
+}
+
 static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
 hybridElimination(const HybridGaussianFactorGraph &factors,
                   const Ordering &frontalKeys,
@@ -217,9 +261,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   // FG has a nullptr as we're looping over the factors.
   factorGraphTree = removeEmpty(factorGraphTree);
 
-  using Result = std::pair<std::shared_ptr<GaussianConditional>,
-                           GaussianMixtureFactor::sharedFactor>;
-
   // This is the elimination method on the leaf nodes
   auto eliminate = [&](const GaussianFactorGraph &graph) -> Result {
     if (graph.empty()) {
@@ -234,53 +275,22 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   // Perform elimination!
   DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
 
-  // Separate out decision tree into conditionals and remaining factors.
-  const auto [conditionals, newFactors] = unzip(eliminationResults);
+  // If there are no more continuous parents we create a DiscreteFactor with the
+  // error for each discrete choice. Otherwise, create a GaussianMixtureFactor
+  // on the separator, taking care to correct for conditional constants.
+  auto newFactor =
+      continuousSeparator.empty()
+          ? createDiscreteFactor(eliminationResults, discreteSeparator)
+          : createGaussianMixtureFactor(eliminationResults, continuousSeparator,
+                                        discreteSeparator);
 
   // Create the GaussianMixture from the conditionals
+  GaussianMixture::Conditionals conditionals(
+      eliminationResults, [](const Result &pair) { return pair.first; });
   auto gaussianMixture = std::make_shared<GaussianMixture>(
       frontalKeys, continuousSeparator, discreteSeparator, conditionals);
 
-  if (continuousSeparator.empty()) {
-    // If there are no more continuous parents, then we create a
-    // DiscreteFactor here, with the error for each discrete choice.
-
-    // Integrate the probability mass in the last continuous conditional using
-    // the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
-    // discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
-    auto probability = [&](const Result &pair) -> double {
-      static const VectorValues kEmpty;
-      // If the factor is not null, it has no keys, just contains the residual.
-      const auto &factor = pair.second;
-      if (!factor) return 1.0;  // TODO(dellaert): not loving this.
-      return exp(-factor->error(kEmpty)) / pair.first->normalizationConstant();
-    };
-
-    DecisionTree<Key, double> probabilities(eliminationResults, probability);
-
-    return {
-        std::make_shared<HybridConditional>(gaussianMixture),
-        std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
-  } else {
-    // Otherwise, we create a resulting GaussianMixtureFactor on the separator,
-    // taking care to correct for conditional constant.
-
-    // Correct for the normalization constant used up by the conditional
-    auto correct = [&](const Result &pair) {
-      const auto &factor = pair.second;
-      if (!factor) return;
-      auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
-      if (!hf) throw std::runtime_error("Expected HessianFactor!");
-      hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
-    };
-    eliminationResults.visit(correct);
-
-    const auto mixtureFactor = std::make_shared<GaussianMixtureFactor>(
-        continuousSeparator, discreteSeparator, newFactors);
-
-    return {std::make_shared<HybridConditional>(gaussianMixture),
-            mixtureFactor};
-  }
+  return {std::make_shared<HybridConditional>(gaussianMixture), newFactor};
 }
 
 /* ************************************************************************
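
A minimal standalone sketch (scalar case, no GTSAM dependencies, illustrative numbers) of the identity that the new createDiscreteFactor() relies on: the leftover keyless factor's value exp(-error(μ;m)) at the mean, divided by the conditional's normalization constant 1/sqrt(2π σ_m²), equals the total probability mass of mode m. The mode count, means, sigmas, and constant error terms below are made up for illustration only.

// Scalar sanity check of: discrete_probability = exp(-error(mu;m)) * sqrt(2*pi*sigma_m^2).
// The mode-m leaf is an unnormalized Gaussian q_m(x) = exp(-0.5*(x-mu_m)^2/sigma_m^2 - c_m);
// its value at the mean times sqrt(2*pi*sigma_m^2) should equal its integral over x.
#include <cmath>
#include <cstdio>

int main() {
  const double kPi = 3.14159265358979323846;
  const double mu[2] = {1.0, -2.0};    // per-mode means (illustrative)
  const double sigma[2] = {0.5, 2.0};  // per-mode standard deviations (illustrative)
  const double c[2] = {0.3, 1.1};      // per-mode constant error left after elimination

  for (int m = 0; m < 2; ++m) {
    // exp(-error(mu;m)): value of the leftover (keyless) factor at the mean.
    const double q_at_mean = std::exp(-c[m]);

    // Closed form used by the patch: divide by the conditional's normalization
    // constant 1/sqrt(2*pi*sigma^2), i.e. multiply by sqrt(2*pi*sigma^2).
    const double closed_form = q_at_mean * std::sqrt(2.0 * kPi * sigma[m] * sigma[m]);

    // Brute-force integral of q_m(x) over x for comparison (Riemann sum).
    double mass = 0.0;
    const double dx = 1e-4;
    for (double x = mu[m] - 10.0 * sigma[m]; x < mu[m] + 10.0 * sigma[m]; x += dx) {
      const double error = 0.5 * (x - mu[m]) * (x - mu[m]) / (sigma[m] * sigma[m]) + c[m];
      mass += std::exp(-error) * dx;
    }

    std::printf("mode %d: closed form %.6f, numeric integral %.6f\n", m, closed_form, mass);
  }
  return 0;
}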