Skip to content

Commit

Permalink
Merge branch 'working-hybrid' into model-selection-bayestree
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Aug 25, 2024
2 parents 1725def + 88c2ad5 commit f06777f
Show file tree
Hide file tree
Showing 11 changed files with 662 additions and 31 deletions.
39 changes: 35 additions & 4 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {
Expand Down Expand Up @@ -318,19 +319,49 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
return std::numeric_limits<double>::max();
}
};
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
return error_tree;
}

/* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const {
// Check if discrete keys in discrete assignment are
// present in the GaussianMixture
KeyVector dKeys = this->discreteKeys_.indices();
bool valid_assignment = false;
for (auto &&kv : values.discrete()) {
if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) {
valid_assignment = true;
break;
}
}

// The discrete assignment is not valid so we throw an error.
if (!valid_assignment) {
throw std::runtime_error(
"Invalid discrete values in values. Not all discrete keys specified.");
}

// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous()) + //
logConstant_ - conditional->logNormalizationConstant();
if (conditional) {
return conditional->error(values.continuous()) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
return std::numeric_limits<double>::max();
}
}

/* *******************************************************************************/
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class GTSAM_EXPORT GaussianMixture
double logConstant_; ///< log of the normalization constant.

/**
* @brief Convert a DecisionTree of factors into
* @brief Convert a GaussianMixture of conditionals into
* a DecisionTree of Gaussian factor graphs.
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
Expand Down
5 changes: 3 additions & 2 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
/* *******************************************************************************/
void GaussianMixtureFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
std::cout << (s.empty() ? "" : s + "\n");
std::cout << "GaussianMixtureFactor" << std::endl;
HybridFactor::print("", formatter);
std::cout << "{\n";
if (factors_.empty()) {
std::cout << " empty" << std::endl;
Expand Down Expand Up @@ -117,6 +119,5 @@ double GaussianMixtureFactor::error(const HybridValues &values) const {
const sharedFactor gf = factors_(values.discrete());
return gf->error(values.continuous());
}
/* *******************************************************************************/

} // namespace gtsam
9 changes: 4 additions & 5 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @param continuousKeys A vector of keys representing continuous variables.
* @param discreteKeys A vector of keys representing discrete variables and
* their cardinalities.
* @param factors The decision tree of Gaussian factors stored as the mixture
* density.
* @param factors The decision tree of Gaussian factors stored
* as the mixture density.
*/
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
Expand All @@ -107,9 +107,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {

bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

void print(
const std::string &s = "GaussianMixtureFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
void print(const std::string &s = "", const KeyFormatter &formatter =
DefaultKeyFormatter) const override;

/// @}
/// @name Standard API
Expand Down
7 changes: 4 additions & 3 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,15 +220,16 @@ GaussianBayesNet HybridBayesNet::choose(
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Collect all the discrete factors to compute MPE
DiscreteBayesNet discrete_bn;
DiscreteFactorGraph discrete_fg;

for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscrete());
discrete_fg.push_back(conditional->asDiscrete());
}
}

// Solve for the MPE
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
DiscreteValues mpe = discrete_fg.optimize();

// Given the MPE, compute the optimal continuous values.
return HybridValues(optimize(mpe), mpe);
Expand Down
1 change: 1 addition & 0 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* @file HybridFactor.h
* @date Mar 11, 2022
* @author Fan Jiang
* @author Varun Agrawal
*/

#pragma once
Expand Down
40 changes: 34 additions & 6 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
for (auto &f : factors) {
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
dfg.push_back(df);
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
// Case where we have a GaussianMixtureFactor with no continuous keys.
// In this case, compute discrete probabilities.
auto probability =
[&](const GaussianFactor::shared_ptr &factor) -> double {
if (!factor) return 0.0;
return exp(-factor->error(VectorValues()));
};
dfg.emplace_shared<DecisionTreeFactor>(
gmf->discreteKeys(),
DecisionTree<Key, double>(gmf->factors(), probability));

} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
// Ignore orphaned clique.
// TODO(dellaert): is this correct? If so explain here.
Expand Down Expand Up @@ -279,21 +291,37 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
using Result = std::pair<std::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::sharedFactor>;

// Integrate the probability mass in the last continuous conditional using
// the unnormalized probability q(μ;m) = exp(-error(μ;m)) at the mean.
// discrete_probability = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
/**
* Compute the probability q(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m)
* from the residual error at the mean μ.
* The residual error contains no keys, and only
* depends on the discrete separator if present.
*/
static std::shared_ptr<Factor> createDiscreteFactor(
const DecisionTree<Key, Result> &eliminationResults,
const DiscreteKeys &discreteSeparator) {
auto probability = [&](const Result &pair) -> double {
auto logProbability = [&](const Result &pair) -> double {
const auto &[conditional, factor] = pair;
static const VectorValues kEmpty;
// If the factor is not null, it has no keys, just contains the residual.
if (!factor) return 1.0; // TODO(dellaert): not loving this.
return exp(-factor->error(kEmpty)) / conditional->normalizationConstant();

// Logspace version of:
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
// We take negative of the logNormalizationConstant `log(1/k)`
// to get `log(k)`.
return -factor->error(kEmpty) + (-conditional->logNormalizationConstant());
};

DecisionTree<Key, double> probabilities(eliminationResults, probability);
AlgebraicDecisionTree<Key> logProbabilities(
DecisionTree<Key, double>(eliminationResults, logProbability));

// Perform normalization
double max_log = logProbabilities.max();
AlgebraicDecisionTree probabilities = DecisionTree<Key, double>(
logProbabilities,
[&max_log](const double x) { return exp(x - max_log); });
probabilities = probabilities.normalize(probabilities.sum());

return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
}
Expand Down
Loading

0 comments on commit f06777f

Please sign in to comment.