Skip to content

Commit

Permalink
Cleaner version of eliminate
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Dec 22, 2023
1 parent 9d70605 commit fcda153
Showing 1 changed file with 56 additions and 46 deletions.
102 changes: 56 additions & 46 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ static GaussianFactorGraphTree addGaussian(
// TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {

GaussianFactorGraphTree result;

for (auto &f : factors_) {
Expand Down Expand Up @@ -198,6 +197,51 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
}

/* ************************************************************************ */
using Result = std::pair<std::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::sharedFactor>;

// Integrate the probability mass in the last continuous conditional using
// the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
// discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
static std::shared_ptr<Factor> createDiscreteFactor(
const DecisionTree<Key, Result> &eliminationResults,
const DiscreteKeys &discreteSeparator) {
auto probability = [&](const Result &pair) -> double {
const auto &[conditional, factor] = pair;
static const VectorValues kEmpty;
// If the factor is not null, it has no keys, just contains the residual.
if (!factor) return 1.0; // TODO(dellaert): not loving this.
return exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
};

DecisionTree<Key, double> probabilities(eliminationResults, probability);

return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
}

// Create GaussianMixtureFactor on the separator, taking care to correct
// for conditional constants.
static std::shared_ptr<Factor> createGaussianMixtureFactor(
const DecisionTree<Key, Result> &eliminationResults,
const KeyVector &continuousSeparator,
const DiscreteKeys &discreteSeparator) {
// Correct for the normalization constant used up by the conditional
auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr {
const auto &[conditional, factor] = pair;
if (factor) {
auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
if (!hf) throw std::runtime_error("Expected HessianFactor!");
hf->constantTerm() += 2.0 * conditional->logNormalizationConstant();
}
return factor;
};
DecisionTree<Key, GaussianFactor::shared_ptr> newFactors(eliminationResults,
correct);

return std::make_shared<GaussianMixtureFactor>(continuousSeparator,
discreteSeparator, newFactors);
}

static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
hybridElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys,
Expand All @@ -217,9 +261,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// FG has a nullptr as we're looping over the factors.
factorGraphTree = removeEmpty(factorGraphTree);

using Result = std::pair<std::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::sharedFactor>;

// This is the elimination method on the leaf nodes
auto eliminate = [&](const GaussianFactorGraph &graph) -> Result {
if (graph.empty()) {
Expand All @@ -234,53 +275,22 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// Perform elimination!
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);

// Separate out decision tree into conditionals and remaining factors.
const auto [conditionals, newFactors] = unzip(eliminationResults);
// If there are no more continuous parents we create a DiscreteFactor with the
// error for each discrete choice. Otherwise, create a GaussianMixtureFactor
// on the separator, taking care to correct for conditional constants.
auto newFactor =
continuousSeparator.empty()
? createDiscreteFactor(eliminationResults, discreteSeparator)
: createGaussianMixtureFactor(eliminationResults, continuousSeparator,
discreteSeparator);

// Create the GaussianMixture from the conditionals
GaussianMixture::Conditionals conditionals(
eliminationResults, [](const Result &pair) { return pair.first; });
auto gaussianMixture = std::make_shared<GaussianMixture>(
frontalKeys, continuousSeparator, discreteSeparator, conditionals);

if (continuousSeparator.empty()) {
// If there are no more continuous parents, then we create a
// DiscreteFactor here, with the error for each discrete choice.

// Integrate the probability mass in the last continuous conditional using
// the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
// discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
auto probability = [&](const Result &pair) -> double {
static const VectorValues kEmpty;
// If the factor is not null, it has no keys, just contains the residual.
const auto &factor = pair.second;
if (!factor) return 1.0; // TODO(dellaert): not loving this.
return exp(-factor->error(kEmpty)) / pair.first->normalizationConstant();
};

DecisionTree<Key, double> probabilities(eliminationResults, probability);

return {
std::make_shared<HybridConditional>(gaussianMixture),
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
} else {
// Otherwise, we create a resulting GaussianMixtureFactor on the separator,
// taking care to correct for conditional constant.

// Correct for the normalization constant used up by the conditional
auto correct = [&](const Result &pair) {
const auto &factor = pair.second;
if (!factor) return;
auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
if (!hf) throw std::runtime_error("Expected HessianFactor!");
hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
};
eliminationResults.visit(correct);

const auto mixtureFactor = std::make_shared<GaussianMixtureFactor>(
continuousSeparator, discreteSeparator, newFactors);

return {std::make_shared<HybridConditional>(gaussianMixture),
mixtureFactor};
}
return {std::make_shared<HybridConditional>(gaussianMixture), newFactor};
}

/* ************************************************************************
Expand Down

0 comments on commit fcda153

Please sign in to comment.