Skip to content

Commit 5cf319f

Browse files
authored
Null Model (#157)
* Add a NullModel which always predicts zeros * rename en -> n
1 parent 9aa4cc7 commit 5cf319f

File tree

5 files changed

+130
-3
lines changed

5 files changed

+130
-3
lines changed

include/albatross/NullModel

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/*
2+
* Copyright (C) 2019 Swift Navigation Inc.
3+
* Contact: Swift Navigation <[email protected]>
4+
*
5+
* This source is subject to the license found in the file 'LICENSE' which must
6+
* be distributed together with this source. All other rights reserved.
7+
*
8+
* THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND,
9+
* EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED
10+
* WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE.
11+
*/
12+
13+
#ifndef ALBATROSS_NULL_MODEL_H
14+
#define ALBATROSS_NULL_MODEL_H
15+
16+
#include "Core"
17+
18+
#include <albatross/src/models/null_model.hpp>
19+
20+
#endif
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright (C) 2019 Swift Navigation Inc.
3+
* Contact: Swift Navigation <[email protected]>
4+
*
5+
* This source is subject to the license found in the file 'LICENSE' which must
6+
* be distributed together with this source. All other rights reserved.
7+
*
8+
* THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND,
9+
* EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED
10+
* WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE.
11+
*/
12+
13+
#ifndef ALBATROSS_SRC_MODELS_NULL_MODEL_HPP_
14+
#define ALBATROSS_SRC_MODELS_NULL_MODEL_HPP_
15+
16+
namespace albatross {
17+
18+
class NullModel;
19+
20+
template <> struct Fit<NullModel> {
21+
template <typename Archive>
22+
void serialize(Archive &archive, const std::uint32_t){};
23+
24+
bool operator==(const Fit<NullModel> &other) const { return true; }
25+
};
26+
27+
class NullModel : public ModelBase<NullModel> {
28+
29+
public:
30+
NullModel(){};
31+
NullModel(const ParameterStore &param_store) : params_(param_store){};
32+
33+
std::string get_name() const { return "null_model"; };
34+
35+
/*
36+
* The Gaussian Process Regression model derives its parameters from
37+
* the covariance functions.
38+
*/
39+
ParameterStore get_params() const override { return params_; }
40+
41+
void unchecked_set_param(const std::string &name,
42+
const Parameter &param) override {
43+
params_[name] = param;
44+
}
45+
46+
// If the implementing class doesn't have a fit method for this
47+
// FeatureType but the CovarianceFunction does.
48+
template <typename FeatureType>
49+
Fit<NullModel> _fit_impl(const std::vector<FeatureType> &features,
50+
const MarginalDistribution &targets) const {
51+
return {};
52+
}
53+
54+
template <typename FeatureType>
55+
auto fit_from_prediction(const std::vector<FeatureType> &features,
56+
const JointDistribution &prediction) const {
57+
const NullModel m(*this);
58+
FitModel<NullModel, Fit<NullModel>> fit_model(m, Fit<NullModel>());
59+
return fit_model;
60+
}
61+
62+
template <typename FeatureType>
63+
JointDistribution
64+
_predict_impl(const std::vector<FeatureType> &features,
65+
const Fit<NullModel> &fit,
66+
PredictTypeIdentity<JointDistribution> &&) const {
67+
const Eigen::Index n = static_cast<Eigen::Index>(features.size());
68+
const Eigen::VectorXd mean = Eigen::VectorXd::Zero(n);
69+
const Eigen::MatrixXd cov = 1.e4 * Eigen::MatrixXd::Identity(n, n);
70+
return JointDistribution(mean, cov);
71+
}
72+
73+
template <typename FeatureType>
74+
MarginalDistribution
75+
_predict_impl(const std::vector<FeatureType> &features,
76+
const Fit<NullModel> &fit,
77+
PredictTypeIdentity<MarginalDistribution> &&) const {
78+
const Eigen::Index en = static_cast<Eigen::Index>(features.size());
79+
const Eigen::VectorXd mean = Eigen::VectorXd::Zero(en);
80+
const Eigen::VectorXd diag = 1.e4 * Eigen::VectorXd::Ones(en);
81+
return MarginalDistribution(mean, diag.asDiagonal());
82+
}
83+
84+
private:
85+
ParameterStore params_;
86+
};
87+
88+
} // namespace albatross
89+
90+
#endif /* THIRD_PARTY_ALBATROSS_INCLUDE_ALBATROSS_SRC_MODELS_NULL_MODEL_HPP_ \
91+
*/

tests/test_cross_validation.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ TEST(test_cross_validation, test_fold_creation) {
4242

4343
bool is_monotonic_increasing(const Eigen::VectorXd &x) {
4444
for (Eigen::Index i = 0; i < x.size() - 1; i++) {
45-
if (x[i + 1] - x[i] <= 0.) {
45+
if (x[i + 1] - x[i] < 0.) {
4646
return false;
4747
}
4848
}
@@ -110,7 +110,9 @@ TYPED_TEST_P(RegressionModelTester, test_score_variants) {
110110
// Here we make sure the cross validated mean absolute error is reasonable.
111111
// Note that because we are running leave one out cross validation, the
112112
// RMSE for each fold is just the absolute value of the error.
113-
EXPECT_LE(cv_scores.mean(), 0.1);
113+
if (!std::is_same<decltype(model), NullModel>::value) {
114+
EXPECT_LE(cv_scores.mean(), 0.1);
115+
}
114116
}
115117

116118
REGISTER_TYPED_TEST_CASE_P(RegressionModelTester, test_loo_predict_variants,

tests/test_models.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ TYPED_TEST_P(RegressionModelTester, test_performs_reasonably_on_linear_data) {
1818
auto dataset = this->test_case.get_dataset();
1919
auto model = this->test_case.get_model();
2020

21+
if (std::is_same<decltype(model), NullModel>::value) {
22+
return;
23+
}
24+
2125
const auto fit_model = model.fit(dataset.features, dataset.targets);
2226
const auto pred = fit_model.predict(dataset.features);
2327
const auto pred_mean = pred.mean();

tests/test_models.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <albatross/GP>
1414
#include <albatross/LeastSquares>
15+
#include <albatross/NullModel>
1516
#include <albatross/Ransac>
1617
#include <gtest/gtest.h>
1718

@@ -171,6 +172,15 @@ class MakeLinearRegression {
171172
}
172173
};
173174

175+
class MakeNullModel {
176+
public:
177+
NullModel get_model() const { return NullModel(); }
178+
179+
RegressionDataset<double> get_dataset() const {
180+
return make_toy_linear_data();
181+
}
182+
};
183+
174184
template <typename ModelTestCase>
175185
class RegressionModelTester : public ::testing::Test {
176186
public:
@@ -179,7 +189,7 @@ class RegressionModelTester : public ::testing::Test {
179189

180190
typedef ::testing::Types<MakeLinearRegression, MakeGaussianProcess,
181191
MakeAdaptedGaussianProcess, MakeRansacGaussianProcess,
182-
MakeRansacAdaptedGaussianProcess>
192+
MakeRansacAdaptedGaussianProcess, MakeNullModel>
183193
ExampleModels;
184194

185195
TYPED_TEST_CASE_P(RegressionModelTester);

0 commit comments

Comments
 (0)