Skip to content

Commit 041054e

Browse files
authored
Merge pull request #895 from beomki-yeo/test-for-backward-propagation
Enable backward propagation
2 parents b4fd39a + b205db6 commit 041054e

File tree

4 files changed

+174
-18
lines changed

4 files changed

+174
-18
lines changed

core/include/detray/navigation/navigator.hpp

+21-10
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,9 @@ class navigator {
231231
/// Scalar representation of the navigation state,
232232
/// @returns distance to next
233233
DETRAY_HOST_DEVICE
234-
scalar_type operator()() const { return target().path; }
234+
scalar_type operator()() const {
235+
return static_cast<scalar_type>(direction()) * target().path;
236+
}
235237

236238
/// @returns current volume (index) - const
237239
DETRAY_HOST_DEVICE
@@ -612,8 +614,12 @@ class navigator {
612614
const auto sf = tracking_surface{det, sf_descr};
613615

614616
sf.template visit_mask<intersection_initialize<ray_intersector>>(
615-
nav_state, detail::ray(track), sf_descr, det.transform_store(),
616-
ctx,
617+
nav_state,
618+
detail::ray<algebra_type>(
619+
track.pos(),
620+
static_cast<scalar_type>(nav_state.direction()) *
621+
track.dir()),
622+
sf_descr, det.transform_store(), ctx,
617623
sf.is_portal() ? std::array<scalar_type, 2>{0.f, 0.f}
618624
: mask_tol,
619625
mask_tol_scalor, overstep_tol);
@@ -775,7 +781,8 @@ class navigator {
775781
// - do this only when the navigation state is still coherent
776782
if (navigation.trust_level() == navigation::trust_level::e_high) {
777783
// Update next candidate: If not reachable, 'high trust' is broken
778-
if (!update_candidate(navigation.target(), track, det, cfg, ctx)) {
784+
if (!update_candidate(navigation.direction(), navigation.target(),
785+
track, det, cfg, ctx)) {
779786
navigation.m_status = navigation::status::e_unknown;
780787
navigation.set_fair_trust();
781788
} else {
@@ -797,7 +804,8 @@ class navigator {
797804

798805
// Else: Track is on module.
799806
// Ready the next candidate after the current module
800-
if (update_candidate(navigation.target(), track, det, cfg,
807+
if (update_candidate(navigation.direction(),
808+
navigation.target(), track, det, cfg,
801809
ctx)) {
802810
return false;
803811
}
@@ -815,7 +823,8 @@ class navigator {
815823

816824
for (auto &candidate : navigation) {
817825
// Disregard this candidate if it is not reachable
818-
if (!update_candidate(candidate, track, det, cfg, ctx)) {
826+
if (!update_candidate(navigation.direction(), candidate, track,
827+
det, cfg, ctx)) {
819828
// Forcefully set dist to numeric max for sorting
820829
candidate.path = std::numeric_limits<scalar_type>::max();
821830
}
@@ -897,9 +906,9 @@ class navigator {
897906
/// @returns whether the track can reach this candidate.
898907
template <typename track_t>
899908
DETRAY_HOST_DEVICE inline bool update_candidate(
900-
intersection_type &candidate, const track_t &track,
901-
const detector_type &det, const navigation::config &cfg,
902-
const context_type &ctx) const {
909+
const navigation::direction nav_dir, intersection_type &candidate,
910+
const track_t &track, const detector_type &det,
911+
const navigation::config &cfg, const context_type &ctx) const {
903912

904913
if (candidate.sf_desc.barcode().is_invalid()) {
905914
return false;
@@ -909,7 +918,9 @@ class navigator {
909918

910919
// Check whether this candidate is reachable by the track
911920
return sf.template visit_mask<intersection_update<ray_intersector>>(
912-
detail::ray(track), candidate, det.transform_store(), ctx,
921+
detail::ray<algebra_type>(
922+
track.pos(), static_cast<scalar_type>(nav_dir) * track.dir()),
923+
candidate, det.transform_store(), ctx,
913924
sf.is_portal() ? std::array<scalar_type, 2>{0.f, 0.f}
914925
: std::array<scalar_type, 2>{cfg.min_mask_tolerance,
915926
cfg.max_mask_tolerance},

tests/integration_tests/cpu/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ macro(detray_add_cpu_test algebra)
1313
"builders/material_map_builder.cpp"
1414
"builders/volume_builder.cpp"
1515
"material/material_interaction.cpp"
16+
"propagator/backward_propagation.cpp"
1617
"propagator/covariance_transport.cpp"
1718
"propagator/guided_navigator.cpp"
1819
"propagator/propagator.cpp"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/** Detray library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2022-2024 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
// Project include(s).
9+
#include "detray/definitions/units.hpp"
10+
#include "detray/detectors/bfield.hpp"
11+
#include "detray/geometry/barcode.hpp"
12+
#include "detray/geometry/shapes/rectangle2D.hpp"
13+
#include "detray/navigation/detail/ray.hpp"
14+
#include "detray/navigation/navigator.hpp"
15+
#include "detray/propagator/actor_chain.hpp"
16+
#include "detray/propagator/actors/parameter_resetter.hpp"
17+
#include "detray/propagator/actors/parameter_transporter.hpp"
18+
#include "detray/propagator/propagator.hpp"
19+
#include "detray/propagator/rk_stepper.hpp"
20+
#include "detray/tracks/tracks.hpp"
21+
22+
// Detray test include(s)
23+
#include "detray/test/utils/detectors/build_telescope_detector.hpp"
24+
#include "detray/test/utils/types.hpp"
25+
26+
// Vecmem include(s)
27+
#include <vecmem/memory/host_memory_resource.hpp>
28+
29+
// google-test include(s).
30+
#include <gtest/gtest.h>
31+
32+
using namespace detray;
33+
34+
// Algebra types
35+
using algebra_t = test::algebra;
36+
using point2 = test::point2;
37+
using vector3 = test::vector3;
38+
using matrix_operator = test::matrix_operator;
39+
40+
constexpr test::scalar tol{5e-3f};
41+
42+
GTEST_TEST(detray_propagator, backward_propagation) {
43+
44+
vecmem::host_memory_resource host_mr;
45+
46+
// Build in x-direction from given module positions
47+
detail::ray<algebra_t> traj{{0.f, 0.f, 0.f}, 0.f, {1.f, 0.f, 0.f}, -1.f};
48+
std::vector<test::scalar> positions = {0.f, 10.f, 20.f, 30.f, 40.f, 50.f,
49+
60.f, 70.f, 80.f, 90.f, 100.f};
50+
51+
tel_det_config<rectangle2D> tel_cfg{200.f * unit<test::scalar>::mm,
52+
200.f * unit<test::scalar>::mm};
53+
tel_cfg.positions(positions).pilot_track(traj);
54+
55+
// Build telescope detector with rectangular planes
56+
const auto [det, names] = build_telescope_detector(host_mr, tel_cfg);
57+
58+
// Create b field
59+
using bfield_t = bfield::const_field_t;
60+
vector3 B{1.f * unit<test::scalar>::T, 1.f * unit<test::scalar>::T,
61+
1.f * unit<test::scalar>::T};
62+
const bfield_t hom_bfield = bfield::create_const_field(B);
63+
64+
using navigator_t = navigator<decltype(det)>;
65+
using rk_stepper_t = rk_stepper<bfield_t::view_t, algebra_t>;
66+
using actor_chain_t = actor_chain<dtuple, parameter_transporter<algebra_t>,
67+
parameter_resetter<algebra_t>>;
68+
using propagator_t = propagator<rk_stepper_t, navigator_t, actor_chain_t>;
69+
70+
// Bound vector
71+
bound_parameters_vector<algebra_t> bound_vector{};
72+
bound_vector.set_theta(constant<test::scalar>::pi_2);
73+
bound_vector.set_qop(-1.f);
74+
75+
// Bound covariance
76+
typename bound_track_parameters<algebra_t>::covariance_type bound_cov =
77+
matrix_operator().template identity<e_bound_size, e_bound_size>();
78+
79+
// Bound track parameter
80+
const bound_track_parameters<algebra_t> bound_param0(
81+
geometry::barcode{}.set_index(0u), bound_vector, bound_cov);
82+
83+
// Actors
84+
parameter_transporter<algebra_t>::state bound_updater{};
85+
parameter_resetter<algebra_t>::state rst{};
86+
87+
propagation::config prop_cfg{};
88+
prop_cfg.stepping.rk_error_tol = 1e-12f * unit<float>::mm;
89+
prop_cfg.navigation.overstep_tolerance = -100.f * unit<float>::um;
90+
propagator_t p{prop_cfg};
91+
92+
// Forward state
93+
propagator_t::state fw_state(bound_param0, hom_bfield, det,
94+
prop_cfg.context);
95+
fw_state.do_debug = true;
96+
97+
// Run propagator
98+
p.propagate(fw_state, detray::tie(bound_updater, rst));
99+
100+
// Print the debug stream
101+
// std::cout << fw_state.debug_stream.str() << std::endl;
102+
103+
// Bound state after propagation
104+
const auto& bound_param1 = fw_state._stepping.bound_params();
105+
106+
// Check if the track reaches the final surface
107+
EXPECT_EQ(bound_param0.surface_link().volume(), 4095u);
108+
EXPECT_EQ(bound_param0.surface_link().index(), 0u);
109+
EXPECT_EQ(bound_param1.surface_link().volume(), 0u);
110+
EXPECT_EQ(bound_param1.surface_link().id(), surface_id::e_sensitive);
111+
EXPECT_EQ(bound_param1.surface_link().index(), 10u);
112+
113+
// Backward state
114+
propagator_t::state bw_state(bound_param1, hom_bfield, det,
115+
prop_cfg.context);
116+
bw_state.do_debug = true;
117+
bw_state._navigation.set_direction(navigation::direction::e_backward);
118+
119+
// Run propagator
120+
p.propagate(bw_state, detray::tie(bound_updater, rst));
121+
122+
// Print the debug stream
123+
// std::cout << bw_state.debug_stream.str() << std::endl;
124+
125+
// Bound state after propagation
126+
const auto& bound_param2 = bw_state._stepping.bound_params();
127+
128+
// Check if the track reaches the initial surface
129+
EXPECT_EQ(bound_param2.surface_link().volume(), 0u);
130+
EXPECT_EQ(bound_param2.surface_link().id(), surface_id::e_sensitive);
131+
EXPECT_EQ(bound_param2.surface_link().index(), 0u);
132+
133+
const auto bound_vec0 = bound_param0.vector();
134+
const auto bound_vec2 = bound_param2.vector();
135+
136+
// Check vector
137+
for (unsigned int i = 0u; i < e_bound_size; i++) {
138+
EXPECT_NEAR(matrix_operator().element(bound_vec0, i, 0),
139+
matrix_operator().element(bound_vec2, i, 0), tol);
140+
}
141+
142+
const auto bound_cov0 = bound_param0.covariance();
143+
const auto bound_cov2 = bound_param2.covariance();
144+
145+
// Check covaraince
146+
for (unsigned int i = 0u; i < e_bound_size; i++) {
147+
for (unsigned int j = 0u; j < e_bound_size; j++) {
148+
EXPECT_NEAR(matrix_operator().element(bound_cov0, i, j),
149+
matrix_operator().element(bound_cov2, i, j), tol);
150+
}
151+
}
152+
}

tests/unit_tests/cpu/propagator/covariance_transport.cpp

-8
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,4 @@ GTEST_TEST(detray_propagator, covariance_transport) {
122122
matrix_operator().element(bound_cov1, i, j), tol);
123123
}
124124
}
125-
126-
// Check covaraince
127-
for (unsigned int i = 0u; i < e_bound_size; i++) {
128-
for (unsigned int j = 0u; j < e_bound_size; j++) {
129-
EXPECT_NEAR(matrix_operator().element(bound_cov0, i, j),
130-
matrix_operator().element(bound_cov1, i, j), tol);
131-
}
132-
}
133125
}

0 commit comments

Comments
 (0)