Skip to content

Commit 0a58178

Browse files
author
Vakho Tsulaia
committed
Patched unit tests and benchmarks to adapt them to the changes in detray core
The changes required for passing geometry context to the transform store
1 parent 416baea commit 0a58178

File tree

9 files changed

+30
-17
lines changed

9 files changed

+30
-17
lines changed

tests/benchmarks/cpu/intersect_all.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void BM_INTERSECT_ALL(benchmark::State &state) {
8080
const auto sf = tracking_surface{d, sf_desc};
8181
sf.template visit_mask<
8282
intersection_initialize<ray_intersector>>(
83-
intersections, detail::ray(track), sf_desc, transforms,
83+
intersections, detail::ray(track), sf_desc, transforms, geo_context,
8484
std::array<scalar_t, 2>{1.f * unit<scalar_t>::um,
8585
1.f * unit<scalar_t>::mm},
8686
scalar_t{0.f});

tests/include/detray/test/validation/detector_scanner.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ struct brute_force_scan {
5656
using trajectory_type = trajectory_t;
5757

5858
template <typename detector_t>
59-
inline auto operator()(const typename detector_t::geometry_context,
59+
inline auto operator()(const typename detector_t::geometry_context ctx,
6060
const detector_t &detector, const trajectory_t &traj,
6161
const std::array<typename detector_t::scalar_type, 2>
6262
mask_tolerance = {0.f, 0.f},
@@ -89,7 +89,7 @@ struct brute_force_scan {
8989
// Retrieve candidate(s) from the surface
9090
const auto sf = tracking_surface{detector, sf_desc};
9191
sf.template visit_mask<intersection_kernel_t>(
92-
intersections, traj, sf_desc, trf_store,
92+
intersections, traj, sf_desc, trf_store, ctx,
9393
sf.is_portal() ? std::array<scalar_t, 2>{0.f, 0.f}
9494
: mask_tolerance);
9595

tests/unit_tests/cpu/detectors/telescope_detector.cpp

+9-6
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,17 @@ using algebra_t = test::algebra;
4646
// dummy propagator state
4747
template <typename stepping_t, typename navigation_t>
4848
struct prop_state {
49+
using context_t = typename navigation_t::detector_type::geometry_context;
4950

5051
stepping_t _stepping;
5152
navigation_t _navigation;
53+
context_t _context;
5254

5355
template <typename track_t, typename field_type>
5456
prop_state(const track_t &t_in, const field_type &field,
55-
const typename navigation_t::detector_type &det)
56-
: _stepping(t_in, field), _navigation(det) {}
57+
const typename navigation_t::detector_type &det,
58+
const context_t &ctx = {})
59+
: _stepping(t_in, field), _navigation(det), _context(ctx) {}
5760
};
5861

5962
inline constexpr bool verbose_check = true;
@@ -193,9 +196,9 @@ GTEST_TEST(detray_detectors, telescope_detector) {
193196
navigation_state_t &navigation_x = propgation_x._navigation;
194197

195198
// propagate all telescopes
196-
navigator_z1.init(stepping_z1(), navigation_z1, prop_cfg.navigation);
197-
navigator_z2.init(stepping_z2(), navigation_z2, prop_cfg.navigation);
198-
navigator_x.init(stepping_x(), navigation_x, prop_cfg.navigation);
199+
navigator_z1.init(stepping_z1(), navigation_z1, prop_cfg.navigation, prop_cfg.context);
200+
navigator_z2.init(stepping_z2(), navigation_z2, prop_cfg.navigation, prop_cfg.context);
201+
navigator_x.init(stepping_x(), navigation_x, prop_cfg.navigation, prop_cfg.context);
199202

200203
bool heartbeat_z1 = navigation_z1.is_alive();
201204
bool heartbeat_z2 = navigation_z2.is_alive();
@@ -292,7 +295,7 @@ GTEST_TEST(detray_detectors, telescope_detector) {
292295
stepping_state_t &tel_stepping = tel_propagation._stepping;
293296

294297
// run propagation
295-
tel_navigator.init(tel_stepping(), tel_navigation, prop_cfg.navigation);
298+
tel_navigator.init(tel_stepping(), tel_navigation, prop_cfg.navigation, prop_cfg.context);
296299
bool heartbeat_tel = tel_navigation.is_alive();
297300

298301
bool do_reset_tel{true};

tests/unit_tests/cpu/navigation/brute_force_finder.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ GTEST_TEST(detray_navigation, brute_force_search) {
103103
const auto [det, names] = build_toy_detector(host_mr);
104104

105105
using detector_t = decltype(det);
106-
106+
using context_t = detector_t::geometry_context;
107+
context_t ctx{};
108+
107109
struct navigation_cfg {
108110
std::array<dindex, 2> search_window;
109111
};
@@ -116,6 +118,6 @@ GTEST_TEST(detray_navigation, brute_force_search) {
116118
detail::ray<typename detector_t::algebra_type> trk({0.f, 0.f, 0.f}, 0.f,
117119
{1.f, 0.f, 0.f}, -1.f);
118120

119-
vol.template visit_neighborhood<neighbor_visit_test>(trk, navigation_cfg{},
121+
vol.template visit_neighborhood<neighbor_visit_test>(trk, navigation_cfg{}, ctx,
120122
test_vol_idx);
121123
}

tests/unit_tests/cpu/navigation/intersection/intersection_kernel.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ GTEST_TEST(detray_intersection, intersection_kernel_ray) {
161161
for (const auto &surface : surfaces) {
162162
mask_store.visit<intersection_initialize<ray_intersector>>(
163163
surface.mask(), sfi_init, detail::ray(track), surface,
164-
transform_store, std::array<scalar_t, 2>{tol, tol});
164+
transform_store, static_context, std::array<scalar_t, 2>{tol, tol});
165165
}
166166

167167
ASSERT_TRUE(expected_points.size() == sfi_init.size());
@@ -291,7 +291,7 @@ GTEST_TEST(detray_intersection, intersection_kernel_helix) {
291291
// Try the intersections - with automated dispatching via the kernel
292292
for (const auto [sf_idx, surface] : detray::views::enumerate(surfaces)) {
293293
mask_store.visit<intersection_initialize<helix_intersector>>(
294-
surface.mask(), sfi_helix, h, surface, transform_store,
294+
surface.mask(), sfi_helix, h, surface, transform_store, static_context,
295295
std::array<scalar_t, 2>{0.f, 0.f}, scalar_t{0.f}, scalar_t{0.f});
296296

297297
vector3 global;

tests/unit_tests/cpu/navigation/navigator.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ constexpr std::size_t cache_size{navigation::default_cache_size};
3939
// dummy propagator state
4040
template <typename stepping_t, typename navigation_t>
4141
struct prop_state {
42+
using context_t = typename navigation_t::detector_type::geometry_context;
4243
stepping_t _stepping;
4344
navigation_t _navigation;
45+
context_t _context{};
4446
};
4547

4648
/// Checks for a correct 'towards_surface' state
@@ -163,6 +165,7 @@ GTEST_TEST(detray_navigation, navigator_toy_geometry) {
163165
stepper_t::state{traj}, navigator_t::state(toy_det)};
164166
navigator_t::state &navigation = propagation._navigation;
165167
stepper_t::state &stepping = propagation._stepping;
168+
const auto& ctx = propagation._context;
166169

167170
// Check that the state is unitialized
168171
// Default volume is zero
@@ -180,7 +183,7 @@ GTEST_TEST(detray_navigation, navigator_toy_geometry) {
180183

181184
// Initialize navigation
182185
// Test that the navigator has a heartbeat
183-
nav.init(stepping(), navigation, nav_cfg);
186+
nav.init(stepping(), navigation, nav_cfg, ctx);
184187
ASSERT_TRUE(navigation.is_alive());
185188
// The status is towards beampipe
186189
// Two candidates: beampipe and portal
@@ -358,6 +361,7 @@ GTEST_TEST(detray_navigation, navigator_wire_chamber) {
358361
stepper_t::state{traj}, navigator_t::state(wire_det)};
359362
navigator_t::state &navigation = propagation._navigation;
360363
stepper_t::state &stepping = propagation._stepping;
364+
const auto& ctx = propagation._context;
361365

362366
// Check that the state is unitialized
363367
// Default volume is zero
@@ -375,7 +379,7 @@ GTEST_TEST(detray_navigation, navigator_wire_chamber) {
375379

376380
// Initialize navigation
377381
// Test that the navigator has a heartbeat
378-
nav.init(stepping(), navigation, nav_cfg);
382+
nav.init(stepping(), navigation, nav_cfg, ctx);
379383
ASSERT_TRUE(navigation.is_alive());
380384
// The status is towards portal
381385
// One candidates: barrel cylinder portal

tests/unit_tests/device/cuda/navigator_cuda.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,10 @@ TEST(navigator_cuda, navigator) {
7373

7474
navigator_host_t::state& navigation = propagation._navigation;
7575
stepper_t::state& stepping = propagation._stepping;
76+
const auto& ctx = propagation._context;
7677

7778
// Start propagation and record volume IDs
78-
nav.init(stepping(), navigation, nav_cfg);
79+
nav.init(stepping(), navigation, nav_cfg, ctx);
7980
bool heartbeat = navigation.is_alive();
8081
bool do_reset{true};
8182

tests/unit_tests/device/cuda/navigator_cuda_kernel.cu

+2-1
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,13 @@ __global__ void navigator_test_kernel(
4141

4242
navigator_device_t::state& navigation = propagation._navigation;
4343
stepper_t::state& stepping = propagation._stepping;
44+
const auto& ctx = propagation._context;
4445

4546
// Set initial volume
4647
navigation.set_volume(0u);
4748

4849
// Start propagation and record volume IDs
49-
nav.init(stepping(), navigation, nav_cfg);
50+
nav.init(stepping(), navigation, nav_cfg, ctx);
5051
bool heartbeat = navigation.is_alive();
5152
bool do_reset{true};
5253

tests/unit_tests/device/cuda/navigator_cuda_kernel.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ constexpr dscalar<algebra_t> pos_diff_tolerance{1e-3f};
4747
// dummy propagator state
4848
template <typename navigation_t>
4949
struct prop_state {
50+
using context_t = typename navigation_t::detector_type::geometry_context;
5051
stepper_t::state _stepping;
5152
navigation_t _navigation;
53+
context_t _context{};
5254
};
5355

5456
/// test function for navigator with single state

0 commit comments

Comments
 (0)