Skip to content

Commit

Permalink
Merge pull request #1955 from borglab/hybrid-custom-discrete
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Jan 3, 2025
2 parents 49b74af + 7440c19 commit 73f98d8
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 26 deletions.
2 changes: 1 addition & 1 deletion gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
#include <gtsam/hybrid/HybridValues.h>

#include <algorithm>
#include <cassert>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <cassert>

using namespace std;
using std::pair;
Expand Down
18 changes: 0 additions & 18 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,7 @@ namespace gtsam {
static DecisionTreeFactor DiscreteProduct(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
DecisionTreeFactor product = factors.product();
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif

#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
Expand Down Expand Up @@ -229,13 +223,7 @@ namespace gtsam {
DecisionTreeFactor product = DiscreteProduct(factors);

// sum out frontals, this is the factor on the separator
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif

// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
Expand All @@ -245,14 +233,8 @@ namespace gtsam {
sum->keys().end());

// now divide product/sum to get conditional
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional);
#endif
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional);
#endif

return {conditional, sum};
}
Expand Down
86 changes: 79 additions & 7 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,55 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
return std::make_shared<TableFactor>(discreteKeys, potentials);
}

/**
* @brief Multiply all the `factors` using the machinery of the TableFactor.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return TableFactor
*/
static TableFactor TableProduct(const DiscreteFactorGraph &factors) {
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
TableFactor product;
for (auto &&factor : factors) {
if (factor) {
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
product = product * (*f);
} else if (auto dtf =
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
product = product * TableFactor(*dtf);
}
}
}
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif

#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size());
// Normalize the product factor to prevent underflow.
product = product / (*denominator);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif

return product;
}

/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
static DiscreteFactorGraph CollectDiscreteFactors(
const HybridGaussianFactorGraph &factors) {
DiscreteFactorGraph dfg;

for (auto &f : factors) {
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
dfg.push_back(df);

} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Case where we have a HybridGaussianFactor with no continuous keys.
// In this case, compute a discrete factor from the remaining error.
Expand Down Expand Up @@ -296,16 +336,48 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
}
}

return dfg;
}

/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);

#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete);
#endif
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);
// Check if separator is empty.
// This is the same as checking if the number of frontal variables
// is the same as the number of variables in the DiscreteFactorGraph.
// If the separator is empty, we have a clique of all the discrete variables
// so we can use the TableFactor for efficiency.
if (frontalKeys.size() == dfg.keys().size()) {
// Get product factor
TableFactor product = TableProduct(dfg);

#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
auto conditional = std::make_shared<DiscreteConditional>(
frontalKeys.size(), product.toDecisionTreeFactor());
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif

return {std::make_shared<HybridConditional>(result.first), result.second};
TableFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif

return {std::make_shared<HybridConditional>(conditional), sum};

} else {
// Perform sum-product.
auto result = EliminateDiscrete(dfg, frontalKeys);
return {std::make_shared<HybridConditional>(result.first), result.second};
}
}

/* ************************************************************************ */
Expand Down
1 change: 1 addition & 0 deletions gtsam/hybrid/tests/testGaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
}
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down

0 comments on commit 73f98d8

Please sign in to comment.