Skip to content

Commit d0ae100

Browse files
author
Vakho Tsulaia
committed
Patched unit tests to adapt to the changes in detray core (passing geometry context to single store)
1 parent 416baea commit d0ae100

File tree

8 files changed

+29
-16
lines changed

8 files changed

+29
-16
lines changed

tests/include/detray/test/validation/detector_scanner.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ struct brute_force_scan {
5656
using trajectory_type = trajectory_t;
5757

5858
template <typename detector_t>
59-
inline auto operator()(const typename detector_t::geometry_context,
59+
inline auto operator()(const typename detector_t::geometry_context ctx,
6060
const detector_t &detector, const trajectory_t &traj,
6161
const std::array<typename detector_t::scalar_type, 2>
6262
mask_tolerance = {0.f, 0.f},
@@ -89,7 +89,7 @@ struct brute_force_scan {
8989
// Retrieve candidate(s) from the surface
9090
const auto sf = tracking_surface{detector, sf_desc};
9191
sf.template visit_mask<intersection_kernel_t>(
92-
intersections, traj, sf_desc, trf_store,
92+
intersections, traj, sf_desc, trf_store, ctx,
9393
sf.is_portal() ? std::array<scalar_t, 2>{0.f, 0.f}
9494
: mask_tolerance);
9595

tests/unit_tests/cpu/detectors/telescope_detector.cpp

+9-6
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,17 @@ using algebra_t = test::algebra;
4646
// dummy propagator state
4747
template <typename stepping_t, typename navigation_t>
4848
struct prop_state {
49+
using context_t = typename navigation_t::detector_type::geometry_context;
4950

5051
stepping_t _stepping;
5152
navigation_t _navigation;
53+
context_t _context;
5254

5355
template <typename track_t, typename field_type>
5456
prop_state(const track_t &t_in, const field_type &field,
55-
const typename navigation_t::detector_type &det)
56-
: _stepping(t_in, field), _navigation(det) {}
57+
const typename navigation_t::detector_type &det,
58+
const context_t &ctx = {})
59+
: _stepping(t_in, field), _navigation(det), _context(ctx) {}
5760
};
5861

5962
inline constexpr bool verbose_check = true;
@@ -193,9 +196,9 @@ GTEST_TEST(detray_detectors, telescope_detector) {
193196
navigation_state_t &navigation_x = propgation_x._navigation;
194197

195198
// propagate all telescopes
196-
navigator_z1.init(stepping_z1(), navigation_z1, prop_cfg.navigation);
197-
navigator_z2.init(stepping_z2(), navigation_z2, prop_cfg.navigation);
198-
navigator_x.init(stepping_x(), navigation_x, prop_cfg.navigation);
199+
navigator_z1.init(stepping_z1(), navigation_z1, prop_cfg.navigation, prop_cfg.context);
200+
navigator_z2.init(stepping_z2(), navigation_z2, prop_cfg.navigation, prop_cfg.context);
201+
navigator_x.init(stepping_x(), navigation_x, prop_cfg.navigation, prop_cfg.context);
199202

200203
bool heartbeat_z1 = navigation_z1.is_alive();
201204
bool heartbeat_z2 = navigation_z2.is_alive();
@@ -292,7 +295,7 @@ GTEST_TEST(detray_detectors, telescope_detector) {
292295
stepping_state_t &tel_stepping = tel_propagation._stepping;
293296

294297
// run propagation
295-
tel_navigator.init(tel_stepping(), tel_navigation, prop_cfg.navigation);
298+
tel_navigator.init(tel_stepping(), tel_navigation, prop_cfg.navigation, prop_cfg.context);
296299
bool heartbeat_tel = tel_navigation.is_alive();
297300

298301
bool do_reset_tel{true};

tests/unit_tests/cpu/navigation/brute_force_finder.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ GTEST_TEST(detray_navigation, brute_force_search) {
103103
const auto [det, names] = build_toy_detector(host_mr);
104104

105105
using detector_t = decltype(det);
106-
106+
using context_t = detector_t::geometry_context;
107+
context_t ctx{};
108+
107109
struct navigation_cfg {
108110
std::array<dindex, 2> search_window;
109111
};
@@ -116,6 +118,6 @@ GTEST_TEST(detray_navigation, brute_force_search) {
116118
detail::ray<typename detector_t::algebra_type> trk({0.f, 0.f, 0.f}, 0.f,
117119
{1.f, 0.f, 0.f}, -1.f);
118120

119-
vol.template visit_neighborhood<neighbor_visit_test>(trk, navigation_cfg{},
121+
vol.template visit_neighborhood<neighbor_visit_test>(trk, navigation_cfg{}, ctx,
120122
test_vol_idx);
121123
}

tests/unit_tests/cpu/navigation/intersection/intersection_kernel.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ GTEST_TEST(detray_intersection, intersection_kernel_ray) {
161161
for (const auto &surface : surfaces) {
162162
mask_store.visit<intersection_initialize<ray_intersector>>(
163163
surface.mask(), sfi_init, detail::ray(track), surface,
164-
transform_store, std::array<scalar_t, 2>{tol, tol});
164+
transform_store, static_context, std::array<scalar_t, 2>{tol, tol});
165165
}
166166

167167
ASSERT_TRUE(expected_points.size() == sfi_init.size());
@@ -291,7 +291,7 @@ GTEST_TEST(detray_intersection, intersection_kernel_helix) {
291291
// Try the intersections - with automated dispatching via the kernel
292292
for (const auto [sf_idx, surface] : detray::views::enumerate(surfaces)) {
293293
mask_store.visit<intersection_initialize<helix_intersector>>(
294-
surface.mask(), sfi_helix, h, surface, transform_store,
294+
surface.mask(), sfi_helix, h, surface, transform_store, static_context,
295295
std::array<scalar_t, 2>{0.f, 0.f}, scalar_t{0.f}, scalar_t{0.f});
296296

297297
vector3 global;

tests/unit_tests/cpu/navigation/navigator.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ constexpr std::size_t cache_size{navigation::default_cache_size};
3939
// dummy propagator state
4040
template <typename stepping_t, typename navigation_t>
4141
struct prop_state {
42+
using context_t = typename navigation_t::detector_type::geometry_context;
4243
stepping_t _stepping;
4344
navigation_t _navigation;
45+
context_t _context{};
4446
};
4547

4648
/// Checks for a correct 'towards_surface' state
@@ -163,6 +165,7 @@ GTEST_TEST(detray_navigation, navigator_toy_geometry) {
163165
stepper_t::state{traj}, navigator_t::state(toy_det)};
164166
navigator_t::state &navigation = propagation._navigation;
165167
stepper_t::state &stepping = propagation._stepping;
168+
const auto& ctx = propagation._context;
166169

167170
// Check that the state is unitialized
168171
// Default volume is zero
@@ -180,7 +183,7 @@ GTEST_TEST(detray_navigation, navigator_toy_geometry) {
180183

181184
// Initialize navigation
182185
// Test that the navigator has a heartbeat
183-
nav.init(stepping(), navigation, nav_cfg);
186+
nav.init(stepping(), navigation, nav_cfg, ctx);
184187
ASSERT_TRUE(navigation.is_alive());
185188
// The status is towards beampipe
186189
// Two candidates: beampipe and portal
@@ -358,6 +361,7 @@ GTEST_TEST(detray_navigation, navigator_wire_chamber) {
358361
stepper_t::state{traj}, navigator_t::state(wire_det)};
359362
navigator_t::state &navigation = propagation._navigation;
360363
stepper_t::state &stepping = propagation._stepping;
364+
const auto& ctx = propagation._context;
361365

362366
// Check that the state is unitialized
363367
// Default volume is zero
@@ -375,7 +379,7 @@ GTEST_TEST(detray_navigation, navigator_wire_chamber) {
375379

376380
// Initialize navigation
377381
// Test that the navigator has a heartbeat
378-
nav.init(stepping(), navigation, nav_cfg);
382+
nav.init(stepping(), navigation, nav_cfg, ctx);
379383
ASSERT_TRUE(navigation.is_alive());
380384
// The status is towards portal
381385
// One candidates: barrel cylinder portal

tests/unit_tests/device/cuda/navigator_cuda.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,10 @@ TEST(navigator_cuda, navigator) {
7373

7474
navigator_host_t::state& navigation = propagation._navigation;
7575
stepper_t::state& stepping = propagation._stepping;
76+
const auto& ctx = propagation._context;
7677

7778
// Start propagation and record volume IDs
78-
nav.init(stepping(), navigation, nav_cfg);
79+
nav.init(stepping(), navigation, nav_cfg, ctx);
7980
bool heartbeat = navigation.is_alive();
8081
bool do_reset{true};
8182

tests/unit_tests/device/cuda/navigator_cuda_kernel.cu

+2-1
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,13 @@ __global__ void navigator_test_kernel(
4141

4242
navigator_device_t::state& navigation = propagation._navigation;
4343
stepper_t::state& stepping = propagation._stepping;
44+
const auto& ctx = propagation._context;
4445

4546
// Set initial volume
4647
navigation.set_volume(0u);
4748

4849
// Start propagation and record volume IDs
49-
nav.init(stepping(), navigation, nav_cfg);
50+
nav.init(stepping(), navigation, nav_cfg, ctx);
5051
bool heartbeat = navigation.is_alive();
5152
bool do_reset{true};
5253

tests/unit_tests/device/cuda/navigator_cuda_kernel.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ constexpr dscalar<algebra_t> pos_diff_tolerance{1e-3f};
4747
// dummy propagator state
4848
template <typename navigation_t>
4949
struct prop_state {
50+
using context_t = typename navigation_t::detector_type::geometry_context;
5051
stepper_t::state _stepping;
5152
navigation_t _navigation;
53+
context_t _context{};
5254
};
5355

5456
/// test function for navigator with single state

0 commit comments

Comments
 (0)