Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/refactor cartesian product #1998

Merged
merged 3 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions gtsam/discrete/Assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,25 +85,31 @@ class Assignment : public std::map<L, size_t> {
* variables with each having cardinalities 4, we get 4096 possible
* configurations!!
*/
template <typename Derived = Assignment<L>>
static std::vector<Derived> CartesianProduct(
template <typename AssignmentType = Assignment<L>>
static std::vector<AssignmentType> CartesianProduct(
const std::vector<std::pair<L, size_t>>& keys) {
std::vector<Derived> allPossValues;
Derived values;
typedef std::pair<L, size_t> DiscreteKey;
for (const DiscreteKey& key : keys)
values[key.first] = 0; // Initialize from 0
while (1) {
allPossValues.push_back(values);
std::vector<AssignmentType> allPossValues;
AssignmentType assignment;
for (const auto [idx, _] : keys) assignment[idx] = 0; // Initialize from 0

const size_t nrKeys = keys.size();
while (true) {
allPossValues.push_back(assignment);

// Increment the assignment. This generalizes incrementing a binary number
size_t j = 0;
for (j = 0; j < keys.size(); j++) {
L idx = keys[j].first;
values[idx]++;
if (values[idx] < keys[j].second) break;
// Wrap condition
values[idx] = 0;
for (j = 0; j < nrKeys; j++) {
auto [idx, cardinality] = keys[j];
// Most of the time, we just increment the value for the first key, j=0:
assignment[idx]++;
// But if this key is done, we increment next key.
const bool carry = (assignment[idx] == cardinality);
if (!carry) break;
assignment[idx] = 0; // wrap on carry, and continue to next variable
}
if (j == keys.size()) break;

// If we propagated carry past the last key, exit:
if (j == nrKeys) break;
}
return allPossValues;
}
Expand Down
3 changes: 2 additions & 1 deletion gtsam/inference/BayesTree-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ namespace gtsam {

/* ************************************************************************* */
// Find the lowest common ancestor of two cliques
// TODO(Varun): consider implementing this as a Range Minimum Query
template <class CLIQUE>
static std::shared_ptr<CLIQUE> findLowestCommonAncestor(
const std::shared_ptr<CLIQUE>& C1, const std::shared_ptr<CLIQUE>& C2) {
Expand All @@ -360,7 +361,7 @@ namespace gtsam {

/* ************************************************************************* */
// Given the clique P(F:S) and the ancestor clique B
// Return the Bayes tree P(S\B | S \cap B)
// Return the Bayes tree P(S\B | S \cap B), where \cap is intersection
template <class CLIQUE>
static auto factorInto(
const std::shared_ptr<CLIQUE>& p_F_S, const std::shared_ptr<CLIQUE>& B,
Expand Down
5 changes: 3 additions & 2 deletions gtsam/inference/BayesTreeCliqueBase-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ namespace gtsam {
// The shortcut density is a conditional P(S|B) of the separator of this
// clique on the root or common ancestor B. We can compute it recursively from
// the parent shortcut P(Sp|B) as \int P(Fp|Sp) P(Sp|B), where Fp are the
// frontal nodes in p
// frontal nodes in the parent p, and Sp the separator of the parent.
/* *************************************************************************
*/
template <class DERIVED, class FACTORGRAPH>
Expand Down Expand Up @@ -141,7 +141,8 @@ namespace gtsam {
/* *********************************************************************** */
// Separator marginal, uses separator marginal of parent recursively
// Calculates P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
// if P(Sp) is not cached, it will call separatorMarginal on the parent
// if P(Sp) is not cached, it will call separatorMarginal on the parent.
// Here again, Fp and Sp are the frontal nodes and separator in the parent p.
/* *********************************************************************** */
template <class DERIVED, class FACTORGRAPH>
typename BayesTreeCliqueBase<DERIVED, FACTORGRAPH>::FactorGraphType
Expand Down
Loading