Skip to content

Commit

Permalink
Merge pull request #1963 from borglab/discrete-multiply
Browse files Browse the repository at this point in the history
DiscreteFactor multiply method
  • Loading branch information
varunagrawal authored Jan 6, 2025
2 parents ffd04fd + f043ac4 commit 47074bd
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 5 deletions.
27 changes: 26 additions & 1 deletion gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
*/

#include <gtsam/base/FastSet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h>

#include <utility>

Expand Down Expand Up @@ -62,6 +63,30 @@ namespace gtsam {
return error(values.discrete());
}

/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// If f is a TableFactor, we convert `this` to a TableFactor since this
// conversion is cheaper than converting `f` to a DecisionTreeFactor. We
// then return a TableFactor.
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));

} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, simply call operator*.
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));

} else {
// Simulate double dispatch in C++
// Useful for other classes which inherit from DiscreteFactor and have
// only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
// need to be updated.
result = std::make_shared<DecisionTreeFactor>(f->operator*(*this));
}
return result;
}

/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
Expand Down
17 changes: 17 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,23 @@ namespace gtsam {
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;

/**
* @brief Multiply factors, DiscreteFactor::shared_ptr edition.
*
* This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
* dispatch and specializations to perform the most efficient
* multiplication.
*
* While converting a DecisionTreeFactor to a TableFactor is efficient, the
* reverse is not. Hence we specialize the code to return a TableFactor if
* `f` is a TableFactor, and DecisionTreeFactor otherwise.
*
* @param f The factor to multiply with.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;

/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, Ring::mul);
Expand Down
10 changes: 10 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

/**
* @brief Multiply in a DiscreteFactor and return the result as
* DiscreteFactor, both via shared pointers.
*
* @param df DiscreteFactor shared_ptr
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const = 0;

virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;

/// @}
Expand Down
15 changes: 11 additions & 4 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,18 @@ namespace gtsam {

/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
for (const sharedFactor& factor : *this) {
if (factor) result = (*factor) * result;
DiscreteFactor::shared_ptr result;
for (auto it = this->begin(); it != this->end(); ++it) {
if (*it) {
if (result) {
result = result->multiply(*it);
} else {
// Assign to the first non-null factor
result = *it;
}
}
}
return result;
return result->toDecisionTreeFactor();
}

/* ************************************************************************ */
Expand Down
26 changes: 26 additions & 0 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,32 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}

/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::multiply(
const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
// If `f` is a TableFactor, we can simply call `operator*`.
result = std::make_shared<TableFactor>(this->operator*(*tf));

} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If `f` is a DecisionTreeFactor, we convert to a TableFactor which is
// cheaper than converting `this` to a DecisionTreeFactor.
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));

} else {
// Simulate double dispatch in C++
// Useful for other classes which inherit from DiscreteFactor and have
// only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't
// need to be updated to know about TableFactor.
// Those classes can be specialized to use TableFactor
// if efficiency is a problem.
result = std::make_shared<DecisionTreeFactor>(
f->operator*(this->toDecisionTreeFactor()));
}
return result;
}

/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
Expand Down
17 changes: 17 additions & 0 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,23 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;

/**
* @brief Multiply factors, DiscreteFactor::shared_ptr edition.
*
* This method accepts `DiscreteFactor::shared_ptr` and uses dynamic
* dispatch and specializations to perform the most efficient
* multiplication.
*
* While converting a DecisionTreeFactor to a TableFactor is efficient, the
* reverse is not.
* Hence we specialize the code to return a TableFactor always.
*
* @param f The factor to multiply with.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;

static double safe_div(const double& a, const double& b);

/// divide by factor f (safely)
Expand Down
8 changes: 8 additions & 0 deletions gtsam_unstable/discrete/Constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {

/// Partially apply known values, domain version
virtual shared_ptr partiallyApply(const Domains&) const = 0;

/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
return std::make_shared<DecisionTreeFactor>(
this->operator*(df->toDecisionTreeFactor()));
}

/// @}
/// @name Wrapper support
/// @{
Expand Down

0 comments on commit 47074bd

Please sign in to comment.