Skip to content

Commit 68bf7a9

Browse files
Expose methods to sum shapes into map through Python API
1 parent 759669b commit 68bf7a9

11 files changed

Lines changed: 136 additions & 34 deletions

File tree

examples/python/edit/sum_map.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@
2121
# Merge them together
2222
wave.edit.sum(your_map, your_map_translated)
2323

24+
# Set a box in the map to free
25+
box = wave.AABB(min=np.array([6.0, 6.0, -2.0]),
26+
max=np.array([10.0, 10.0, 2.0]))
27+
wave.edit.sum(your_map, box, -1.0)
28+
29+
# Set a sphere in the map to occupied
30+
sphere = wave.Sphere(center=np.array([8.0, 8.0, 0.0]), radius=1.5)
31+
wave.edit.sum(your_map, sphere, 2.0)
32+
2433
# Save the map
2534
output_map_path = os.path.join(user_home, "your_map_merged.wvmp")
2635
your_map.store(output_map_path)

library/cpp/include/wavemap/core/utils/geometry/aabb.h

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,56 +16,56 @@ struct AABB {
1616
static constexpr int kDim = dim_v<PointT>;
1717
static constexpr int kNumCorners = int_math::exp2(kDim);
1818
using PointType = PointT;
19-
using ScalarType = typename PointType::Scalar;
19+
using ScalarType = typename PointT::Scalar;
2020
using Corners = Eigen::Matrix<ScalarType, kDim, kNumCorners>;
2121

2222
static constexpr auto kInitialMin = std::numeric_limits<ScalarType>::max();
2323
static constexpr auto kInitialMax = std::numeric_limits<ScalarType>::lowest();
2424

25-
PointType min = PointType::Constant(kInitialMin);
26-
PointType max = PointType::Constant(kInitialMax);
25+
PointT min = PointT::Constant(kInitialMin);
26+
PointT max = PointT::Constant(kInitialMax);
2727

2828
AABB() = default;
2929
AABB(const PointT& min, const PointT& max) : min(min), max(max) {}
3030
AABB(PointT&& min, PointT&& max) : min(std::move(min)), max(std::move(max)) {}
3131

32-
void insert(const PointType& point) {
32+
void insert(const PointT& point) {
3333
min = min.cwiseMin(point);
3434
max = max.cwiseMax(point);
3535
}
36-
bool contains(const PointType& point) const {
36+
bool contains(const PointT& point) const {
3737
return (min.array() <= point.array() && point.array() <= max.array()).all();
3838
}
3939

40-
PointType closestPointTo(const PointType& point) const {
41-
PointType closest_point = point.cwiseMax(min).cwiseMin(max);
40+
PointT closestPointTo(const PointT& point) const {
41+
PointT closest_point = point.cwiseMax(min).cwiseMin(max);
4242
return closest_point;
4343
}
44-
PointType furthestPointFrom(const PointType& point) const {
45-
const PointType aabb_center = (min + max) / static_cast<ScalarType>(2);
46-
PointType furthest_point =
44+
PointT furthestPointFrom(const PointT& point) const {
45+
const PointT aabb_center = (min + max) / static_cast<ScalarType>(2);
46+
PointT furthest_point =
4747
(aabb_center.array() < point.array()).select(min, max);
4848
return furthest_point;
4949
}
5050

51-
PointType minOffsetTo(const PointType& point) const {
51+
PointT minOffsetTo(const PointT& point) const {
5252
return point - closestPointTo(point);
5353
}
54-
PointType maxOffsetTo(const PointType& point) const {
54+
PointT maxOffsetTo(const PointT& point) const {
5555
return point - furthestPointFrom(point);
5656
}
5757
// TODO(victorr): Check correctness with unit tests
58-
PointType minOffsetTo(const AABB& other) const {
59-
const PointType greatest_min = min.cwiseMax(other.min);
60-
const PointType smallest_max = max.cwiseMin(other.max);
58+
PointT minOffsetTo(const AABB& other) const {
59+
const PointT greatest_min = min.cwiseMax(other.min);
60+
const PointT smallest_max = max.cwiseMin(other.max);
6161
return (greatest_min - smallest_max).cwiseMax(0);
6262
}
6363
// TODO(victorr): Check correctness with unit tests. Pay particular
6464
// attention to whether the offset signs are correct.
65-
PointType maxOffsetTo(const AABB& other) const {
66-
const PointType diff_1 = min - other.max;
67-
const PointType diff_2 = max - other.min;
68-
PointType offset =
65+
PointT maxOffsetTo(const AABB& other) const {
66+
const PointT diff_1 = min - other.max;
67+
const PointT diff_2 = max - other.min;
68+
PointT offset =
6969
(diff_2.array().abs() < diff_1.array().abs()).select(diff_1, diff_2);
7070
return offset;
7171
}
@@ -92,7 +92,7 @@ struct AABB {
9292
ScalarType width() const {
9393
return max[dim] - min[dim];
9494
}
95-
PointType widths() const { return max - min; }
95+
PointT widths() const { return max - min; }
9696

9797
Corners corner_matrix() const {
9898
Eigen::Matrix<ScalarType, kDim, kNumCorners> corners;
@@ -104,8 +104,8 @@ struct AABB {
104104
return corners;
105105
}
106106

107-
PointType corner_point(int corner_idx) const {
108-
PointType corner;
107+
PointT corner_point(int corner_idx) const {
108+
PointT corner;
109109
for (int dim_idx = 0; dim_idx < kDim; ++dim_idx) {
110110
corner[dim_idx] = corner_coordinate(dim_idx, corner_idx);
111111
}

library/cpp/include/wavemap/core/utils/geometry/sphere.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ template <typename PointT>
1313
struct Sphere {
1414
static constexpr int kDim = dim_v<PointT>;
1515
using PointType = PointT;
16-
using ScalarType = typename PointType::Scalar;
16+
using ScalarType = typename PointT::Scalar;
1717

18-
PointType center;
19-
ScalarType radius;
18+
PointT center = PointT::Constant(kNaN);
19+
ScalarType radius = static_cast<ScalarType>(0);
2020

2121
Sphere() = default;
2222
Sphere(const PointT& center, ScalarType radius)
@@ -25,10 +25,14 @@ struct Sphere {
2525
: center(std::move(center)), radius(radius) {}
2626

2727
operator AABB<PointT>() const {
28-
return AABB<PointT>(center.array() - radius, center.array() + radius);
28+
if (std::isnan(center[0])) {
29+
return {};
30+
}
31+
return {center.array() - radius, center.array() + radius};
2932
}
3033

31-
bool contains(const PointType& point) const {
34+
// TODO(victorr): Add tests, incl. behavior after default construction
35+
bool contains(const PointT& point) const {
3236
return (point - center).squaredNorm() <= radius * radius;
3337
}
3438

library/python/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ nanobind_add_module(_pywavemap_bindings STABLE_ABI
5959
src/pywavemap.cc
6060
src/convert.cc
6161
src/edit.cc
62+
src/geometry.cc
6263
src/indices.cc
6364
src/logging.cc
6465
src/maps.cc
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#ifndef PYWAVEMAP_GEOMETRY_H_
2+
#define PYWAVEMAP_GEOMETRY_H_
3+
4+
#include <nanobind/nanobind.h>
5+
6+
namespace nb = nanobind;
7+
8+
namespace wavemap {
9+
void add_geometry_bindings(nb::module_& m);
10+
}
11+
12+
#endif // PYWAVEMAP_GEOMETRY_H_

library/python/src/edit.cc

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
#include <wavemap/core/utils/edit/crop.h>
1010
#include <wavemap/core/utils/edit/multiply.h>
1111
#include <wavemap/core/utils/edit/transform.h>
12+
#include <wavemap/core/utils/geometry/aabb.h>
13+
#include <wavemap/core/utils/geometry/sphere.h>
1214

1315
using namespace nb::literals; // NOLINT
1416

1517
namespace wavemap {
1618
void add_edit_module(nb::module_& m_edit) {
17-
// Map multiply methods
19+
// Multiply a map with a scalar
1820
// NOTE: Among others, this can be used to implement exponential forgetting,
1921
// by multiplying the map with a scalar between 0 and 1.
2022
m_edit.def(
@@ -30,7 +32,7 @@ void add_edit_module(nb::module_& m_edit) {
3032
},
3133
"map"_a, "multiplier"_a);
3234

33-
// Map sum methods
35+
// Sum two maps together
3436
m_edit.def(
3537
"sum",
3638
[](HashedWaveletOctree& map_A, const HashedWaveletOctree& map_B) {
@@ -45,7 +47,39 @@ void add_edit_module(nb::module_& m_edit) {
4547
},
4648
"map_A"_a, "map_B"_a);
4749

48-
// Map transformation methods
50+
// Add a scalar value to all cells within an axis aligned bounding box
51+
m_edit.def(
52+
"sum",
53+
[](HashedWaveletOctree& map, const AABB<Point3D>& aabb,
54+
FloatingPoint update) {
55+
edit::sum(map, aabb, update, std::make_shared<ThreadPool>());
56+
},
57+
"map"_a, "aabb"_a, "update"_a);
58+
m_edit.def(
59+
"sum",
60+
[](HashedChunkedWaveletOctree& map, const AABB<Point3D>& aabb,
61+
FloatingPoint update) {
62+
edit::sum(map, aabb, update, std::make_shared<ThreadPool>());
63+
},
64+
"map"_a, "aabb"_a, "update"_a);
65+
66+
// Add a scalar value to all cells within a sphere
67+
m_edit.def(
68+
"sum",
69+
[](HashedWaveletOctree& map, const Sphere<Point3D>& sphere,
70+
FloatingPoint update) {
71+
edit::sum(map, sphere, update, std::make_shared<ThreadPool>());
72+
},
73+
"map"_a, "sphere"_a, "update"_a);
74+
m_edit.def(
75+
"sum",
76+
[](HashedChunkedWaveletOctree& map, const Sphere<Point3D>& sphere,
77+
FloatingPoint update) {
78+
edit::sum(map, sphere, update, std::make_shared<ThreadPool>());
79+
},
80+
"map"_a, "sphere"_a, "update"_a);
81+
82+
// Transform a map into a different coordinate frame
4983
m_edit.def(
5084
"transform",
5185
[](HashedWaveletOctree& B_map, const Transformation3D& T_AB) {
@@ -59,7 +93,7 @@ void add_edit_module(nb::module_& m_edit) {
5993
},
6094
"map"_a, "transformation"_a);
6195

62-
// Map cropping methods
96+
// Crop a map
6397
m_edit.def(
6498
"crop_to_sphere",
6599
[](HashedWaveletOctree& map, const Point3D& t_W_center,

library/python/src/geometry.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include "pywavemap/geometry.h"
2+
3+
#include <nanobind/eigen/dense.h>
4+
#include <wavemap/core/common.h>
5+
#include <wavemap/core/utils/geometry/aabb.h>
6+
#include <wavemap/core/utils/geometry/sphere.h>
7+
8+
using namespace nb::literals; // NOLINT
9+
10+
namespace wavemap {
11+
void add_geometry_bindings(nb::module_& m) {
12+
// Axis-Aligned Bounding Box
13+
nb::class_<AABB<Point3D>>(
14+
m, "AABB", "A class representing an Axis-Aligned Bounding Box.")
15+
.def(nb::init())
16+
.def(nb::init<Point3D, Point3D>(), "min"_a, "max"_a)
17+
.def_rw("min", &AABB<Point3D>::min)
18+
.def_rw("max", &AABB<Point3D>::max)
19+
.def("insert", &AABB<Point3D>::insert,
20+
"Expand the AABB to tightly fit the new point "
21+
"and its previous self.")
22+
.def("contains", &AABB<Point3D>::contains,
23+
"Test whether the AABB contains the given point.");
24+
25+
// Axis-Aligned Bounding Box
26+
nb::class_<Sphere<Point3D>>(m, "Sphere", "A class representing a sphere.")
27+
.def(nb::init())
28+
.def(nb::init<Point3D, FloatingPoint>(), "center"_a, "radius"_a)
29+
.def_rw("center", &Sphere<Point3D>::center)
30+
.def_rw("radius", &Sphere<Point3D>::radius)
31+
.def("contains", &Sphere<Point3D>::contains,
32+
"Test whether the sphere contains the given point.");
33+
}
34+
} // namespace wavemap

library/python/src/indices.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace wavemap {
99
void add_index_bindings(nb::module_& m) {
1010
nb::class_<OctreeIndex>(m, "OctreeIndex",
1111
"A class representing indices of octree nodes.")
12-
.def(nb::init<>())
12+
.def(nb::init())
1313
.def(nb::init<OctreeIndex::Element, OctreeIndex::Position>(), "height"_a,
1414
"position"_a)
1515
.def_rw("height", &OctreeIndex::height, "height"_a = 0,

library/python/src/measurements.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ void add_measurement_bindings(nb::module_& m) {
2727

2828
// Pointclouds
2929
nb::class_<Pointcloud<>>(m, "Pointcloud", "A class to store pointclouds.")
30-
.def(nb::init<Pointcloud<>::Data>(), "point_matrix"_a);
30+
.def(nb::init<Pointcloud<>::Data>(), "point_matrix"_a)
31+
.def_prop_ro("size", &Pointcloud<>::size);
3132
nb::class_<PosedPointcloud<>>(
3233
m, "PosedPointcloud",
3334
"A class to store pointclouds with an associated pose.")
@@ -36,7 +37,9 @@ void add_measurement_bindings(nb::module_& m) {
3637

3738
// Images
3839
nb::class_<Image<>>(m, "Image", "A class to store depth images.")
39-
.def(nb::init<Image<>::Data>(), "pixel_matrix"_a);
40+
.def(nb::init<Image<>::Data>(), "pixel_matrix"_a)
41+
.def_prop_ro("size", &Image<>::size)
42+
.def_prop_ro("dimensions", &Image<>::getDimensions);
4043
nb::class_<PosedImage<>>(
4144
m, "PosedImage", "A class to store depth images with an associated pose.")
4245
.def(nb::init<Transformation3D, Image<>>(), "pose"_a, "image"_a);

library/python/src/pywavemap.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "pywavemap/convert.h"
44
#include "pywavemap/edit.h"
5+
#include "pywavemap/geometry.h"
56
#include "pywavemap/indices.h"
67
#include "pywavemap/logging.h"
78
#include "pywavemap/maps.h"
@@ -57,6 +58,9 @@ NB_MODULE(_pywavemap_bindings, m) {
5758
// Bindings for measurement types
5859
add_measurement_bindings(m);
5960

61+
// Bindings for geometric types
62+
add_geometry_bindings(m);
63+
6064
// Bindings for map types
6165
add_map_bindings(m);
6266

0 commit comments

Comments
 (0)