Skip to content

Commit

Permalink
Merge pull request #1961 from borglab/serialize-tablefactor
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Jan 3, 2025
2 parents 05d8030 + 3718cb1 commit e9e52ad
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 0 deletions.
40 changes: 40 additions & 0 deletions gtsam/base/MatrixSerialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <gtsam/base/Matrix.h>

#include <Eigen/Sparse>
#include <boost/serialization/array.hpp>
#include <boost/serialization/nvp.hpp>
#include <boost/serialization/split_free.hpp>
Expand Down Expand Up @@ -87,6 +88,45 @@ void serialize(Archive& ar, gtsam::Matrix& m, const unsigned int version) {
split_free(ar, m, version);
}

/******************************************************************************/
/// Customized functions for serializing Eigen::SparseVector
template <class Archive, typename _Scalar, int _Options, typename _Index>
void save(Archive& ar, const Eigen::SparseVector<_Scalar, _Options, _Index>& m,
const unsigned int /*version*/) {
_Index size = m.size();

std::vector<std::pair<Eigen::Index, _Scalar>> data;
for (typename Eigen::SparseVector<_Scalar, _Options, _Index>::InnerIterator
it(m);
it; ++it)
data.push_back({it.index(), it.value()});

ar << BOOST_SERIALIZATION_NVP(size);
ar << BOOST_SERIALIZATION_NVP(data);
}

template <class Archive, typename _Scalar, int _Options, typename _Index>
void load(Archive& ar, Eigen::SparseVector<_Scalar, _Options, _Index>& m,
const unsigned int /*version*/) {
_Index size;
ar >> BOOST_SERIALIZATION_NVP(size);
m.resize(size);

std::vector<std::pair<Eigen::Index, _Scalar>> data;
ar >> BOOST_SERIALIZATION_NVP(data);

for (auto&& d : data) {
m.coeffRef(d.first) = d.second;
}
}

template <class Archive, typename _Scalar, int _Options, typename _Index>
void serialize(Archive& ar, Eigen::SparseVector<_Scalar, _Options, _Index>& m,
const unsigned int version) {
split_free(ar, m, version);
}
/******************************************************************************/

} // namespace serialization
} // namespace boost
#endif
19 changes: 19 additions & 0 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
#include <utility>
#include <vector>

#if GTSAM_ENABLE_BOOST_SERIALIZATION
#include <gtsam/base/MatrixSerialization.h>

#include <boost/serialization/nvp.hpp>
#endif

namespace gtsam {

class DiscreteConditional;
Expand Down Expand Up @@ -342,6 +348,19 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
double error(const HybridValues& values) const override;

/// @}

private:
#if GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(sparse_table_);
ar& BOOST_SERIALIZATION_NVP(denominators_);
ar& BOOST_SERIALIZATION_NVP(sorted_dkeys_);
}
#endif
};

// traits
Expand Down
15 changes: 15 additions & 0 deletions gtsam/discrete/tests/testSerializationDiscrete.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/Symbol.h>

using namespace std;
Expand All @@ -32,6 +33,7 @@ BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf")
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")

BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
BOOST_CLASS_EXPORT_GUID(TableFactor, "gtsam_TableFactor");

using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
Expand Down Expand Up @@ -79,6 +81,19 @@ TEST(DiscreteSerialization, DecisionTreeFactor) {
EXPECT(equalsBinary<DecisionTreeFactor>(f));
}

/* ************************************************************************* */
// Check serialization for TableFactor
TEST(DiscreteSerialization, TableFactor) {
using namespace serializationTestHelpers;

DiscreteKey A(Symbol('x', 1), 3);
TableFactor tf(A, "1 2 2");

EXPECT(equalsObj<TableFactor>(tf));
EXPECT(equalsXML<TableFactor>(tf));
EXPECT(equalsBinary<TableFactor>(tf));
}

/* ************************************************************************* */
// Check serialization for DiscreteConditional & DiscreteDistribution
TEST(DiscreteSerialization, DiscreteConditional) {
Expand Down

0 comments on commit e9e52ad

Please sign in to comment.