Skip to content

Commit 9dd7419

Browse files
authored
fix: To make traccc update work (#725)
Some fixes to make the current traccc update work: widening the angle space for the track generators and fixing a resulting bug in the quadratic equation. Additionally fixed a bug in the navigation that could lead to a navigation abort. Also harmonizing the track generator config interface with traccc.
1 parent b14f854 commit 9dd7419

File tree

13 files changed

+185
-97
lines changed

13 files changed

+185
-97
lines changed

core/include/detray/grids/grid2.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ class grid2 {
209209
*
210210
* @return the const reference to the value in this bin
211211
**/
212-
template <typename point2_t>
212+
template <typename point2_t,
213+
std::enable_if_t<!std::is_scalar_v<point2_t>, bool> = true>
213214
DETRAY_HOST_DEVICE typename serialized_storage::const_reference bin(
214215
const point2_t &p2) const {
215216
return _data_serialized[_serializer.template serialize<axis_p0_type,
@@ -223,7 +224,8 @@ class grid2 {
223224
*
224225
* @return the const reference to the value in this bin
225226
**/
226-
template <typename point2_t>
227+
template <typename point2_t,
228+
std::enable_if_t<!std::is_scalar_v<point2_t>, bool> = true>
227229
DETRAY_HOST_DEVICE typename serialized_storage::reference bin(
228230
const point2_t &p2) {
229231
return _data_serialized[_serializer.template serialize<axis_p0_type,

core/include/detray/navigation/detail/helix.hpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class helix {
5454
///
5555
/// @param pos the the origin of the helix
5656
/// @param time the time parameter
57-
/// @param dir the initial direction for the helix
57+
/// @param dir the initial normalized direction for the helix
5858
/// @param q the charge of the particle
5959
/// @param mag_field the magnetic field vector
6060
DETRAY_HOST_DEVICE
@@ -67,11 +67,14 @@ class helix {
6767
_h0 = vector::normalize(*_mag_field);
6868

6969
// Normalized tangent vector
70-
_t0 = vector::normalize(dir);
70+
_t0 = dir;
71+
72+
assert((math::abs(getter::norm(_t0) - 1.f) < 1e-5f) &&
73+
"The helix direction must be normalized");
7174

7275
// Momentum
7376
const vector3_type mom =
74-
1.f / static_cast<scalar_type>(math::abs(qop)) * dir;
77+
1.f / static_cast<scalar_type>(math::abs(qop)) * _t0;
7578

7679
// Normalized _h0 X _t0
7780
_n0 = vector::normalize(vector::cross(_h0, _t0));

core/include/detray/navigation/intersection/helix_cylinder_intersector.hpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,15 @@ struct helix_intersector_impl<cylindrical2D<algebra_t>, algebra_t>
117117

118118
// Obtain both possible solutions by looping over the (different)
119119
// starting positions
120-
unsigned int n_runs =
121-
math::abs(paths[0] - paths[1]) < convergence_tolerance ? 1u : 2u;
120+
unsigned int n_runs = static_cast<unsigned int>(qe.solutions());
121+
122+
// Even if the ray is parallel to the cylinder, the helix might still
123+
// hit it
124+
if (qe.solutions() == 0) {
125+
n_runs = 2u;
126+
paths[0] = r;
127+
paths[1] = -r;
128+
}
122129
for (unsigned int i = 0u; i < n_runs; ++i) {
123130

124131
scalar_type &s = paths[i];

core/include/detray/navigation/intersection/intersection.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ struct intersection2D {
101101
const intersection2D &is) {
102102
out_stream << "dist:" << is.path
103103
<< "\tsurface: " << is.sf_desc.barcode()
104+
<< ", type: " << static_cast<int>(is.sf_desc.mask().id())
104105
<< ", links to vol:" << is.volume_link << ")"
105106
<< ", loc [" << is.local[0] << ", " << is.local[1] << ", "
106107
<< is.local[2] << "], ";

core/include/detray/navigation/navigator.hpp

+26-23
Original file line numberDiff line numberDiff line change
@@ -610,33 +610,34 @@ class navigator {
610610
// Update next candidate: If not reachable, 'high trust' is broken
611611
if (not update_candidate(*navigation.next(), track, det, cfg)) {
612612
navigation.m_status = navigation::status::e_unknown;
613-
navigation.set_no_trust();
614-
return;
615-
}
613+
navigation.set_fair_trust();
614+
} else {
616615

617-
// Update navigation flow on the new candidate information
618-
update_navigation_state(cfg, propagation);
616+
// Update navigation flow on the new candidate information
617+
update_navigation_state(cfg, propagation);
619618

620-
navigation.run_inspector(cfg, "Update complete: high trust: ");
619+
navigation.run_inspector(cfg, "Update complete: high trust: ");
621620

622-
// The work is done if: the track has not reached a surface yet or
623-
// trust is gone (portal was reached or the cache is broken).
624-
if (navigation.status() == navigation::status::e_towards_object or
625-
navigation.trust_level() ==
626-
navigation::trust_level::e_no_trust) {
627-
return;
628-
}
621+
// The work is done if: the track has not reached a surface yet
622+
// or trust is gone (portal was reached or the cache is broken).
623+
if (navigation.status() ==
624+
navigation::status::e_towards_object or
625+
navigation.trust_level() ==
626+
navigation::trust_level::e_no_trust) {
627+
return;
628+
}
629629

630-
// Else: Track is on module.
631-
// Ready the next candidate after the current module
632-
if (update_candidate(*navigation.next(), track, det, cfg)) {
633-
return;
634-
}
630+
// Else: Track is on module.
631+
// Ready the next candidate after the current module
632+
if (update_candidate(*navigation.next(), track, det, cfg)) {
633+
return;
634+
}
635635

636-
// If next candidate is not reachable, don't 'return', but
637-
// escalate the trust level.
638-
// This will run into the fair trust case below.
639-
navigation.set_fair_trust();
636+
// If next candidate is not reachable, don't 'return', but
637+
// escalate the trust level.
638+
// This will run into the fair trust case below.
639+
navigation.set_fair_trust();
640+
}
640641
}
641642

642643
// Re-evaluate all currently available candidates and sort again
@@ -660,7 +661,9 @@ class navigator {
660661

661662
navigation.run_inspector(cfg, "Update complete: fair trust: ");
662663

663-
return;
664+
if (!navigation.is_exhausted()) {
665+
return;
666+
}
664667
}
665668

666669
// Actor flagged cache as broken (other cases of 'no trust' are

core/include/detray/utils/quadratic_equation.hpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,12 @@ class quadratic_equation {
3939
const scalar_t tolerance = std::numeric_limits<scalar_t>::epsilon()) {
4040
// linear case
4141
if (math::abs(a) <= tolerance) {
42-
m_solutions = 1;
43-
m_values[0] = -c / b;
42+
if (math::abs(b) <= tolerance) {
43+
m_solutions = 0;
44+
} else {
45+
m_solutions = 1;
46+
m_values[0] = -c / b;
47+
}
4448
} else {
4549
const scalar_t discriminant{b * b - 4.f * a * c};
4650
// If there is more than one solution, then a != 0 and q != 0

tests/integration_tests/cpu/detectors/toy_detector_navigation.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ int main(int argc, char **argv) {
6262
cfg_hel_scan.name("toy_detector_helix_scan");
6363
cfg_hel_scan.whiteboard(white_board);
6464
cfg_hel_scan.track_generator().n_tracks(10000u);
65+
cfg_hel_scan.track_generator().eta_range(-4.f, 4.f);
6566
cfg_hel_scan.track_generator().p_T(1.f * unit<scalar_t>::GeV);
6667

6768
detail::register_checks<test::helix_scan>(toy_det, toy_names, cfg_hel_scan);
@@ -83,8 +84,6 @@ int main(int argc, char **argv) {
8384
cfg_hel_nav.name("toy_detector_helix_navigation");
8485
cfg_hel_nav.whiteboard(white_board);
8586
cfg_hel_nav.propagation().navigation.search_window = {3u, 3u};
86-
// For one surface the toy detector seems to need a stricter tolerance
87-
cfg_hel_nav.propagation().navigation.min_mask_tolerance = 1e-5f;
8887

8988
detail::register_checks<test::helix_navigation>(toy_det, toy_names,
9089
cfg_hel_nav);

tests/integration_tests/cpu/detectors/wire_chamber_navigation.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,17 @@ int main(int argc, char **argv) {
6262
cfg_hel_scan.name("wire_chamber_helix_scan");
6363
cfg_hel_scan.whiteboard(white_board);
6464
cfg_hel_scan.track_generator().n_tracks(10000u);
65+
cfg_hel_scan.track_generator().eta_range(-1.f, 1.f);
6566
// TODO: Fails for smaller momenta
66-
cfg_hel_scan.track_generator().p_T(3.f * unit<scalar_t>::GeV);
67+
cfg_hel_scan.track_generator().p_T(5.f * unit<scalar_t>::GeV);
6768

6869
detail::register_checks<test::helix_scan>(det, names, cfg_hel_scan);
6970

7071
// Comparison of straight line navigation with ray scan
7172
test::straight_line_navigation<wire_chamber_t>::config cfg_str_nav{};
7273
cfg_str_nav.name("wire_chamber_straight_line_navigation");
7374
cfg_str_nav.whiteboard(white_board);
74-
cfg_str_nav.propagation().navigation.search_window = {2u, 2u};
75+
cfg_str_nav.propagation().navigation.search_window = {3u, 3u};
7576
auto mask_tolerance = cfg_ray_scan.mask_tolerance();
7677
cfg_str_nav.propagation().navigation.min_mask_tolerance = mask_tolerance[0];
7778
cfg_str_nav.propagation().navigation.max_mask_tolerance = mask_tolerance[1];

tests/integration_tests/device/cuda/propagator_cuda.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ TEST_P(CudaPropConstBFieldMng, propagator) {
3535
propagator_test_config cfg{};
3636
cfg.track_generator.phi_steps(20).theta_steps(20);
3737
cfg.track_generator.p_tot(10.f * unit<scalar_t>::GeV);
38+
cfg.track_generator.eta_range(-3.f, 3.f);
3839
cfg.propagation.navigation.search_window = {3u, 3u};
3940
// Configuration for non-z-aligned B-fields
4041
cfg.propagation.navigation.overstep_tolerance = std::get<0>(GetParam());
@@ -69,6 +70,7 @@ TEST_P(CudaPropConstBFieldCpy, propagator) {
6970
propagator_test_config cfg{};
7071
cfg.track_generator.phi_steps(20u).theta_steps(20u);
7172
cfg.track_generator.p_tot(10.f * unit<scalar_t>::GeV);
73+
cfg.track_generator.eta_range(-3.f, 3.f);
7274
cfg.propagation.navigation.search_window = {3u, 3u};
7375
// Configuration for non-z-aligned B-fields
7476
cfg.propagation.navigation.overstep_tolerance = std::get<0>(GetParam());

tests/unit_tests/cpu/simulation/detector_scanner.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ using namespace detray;
2626
using algebra_t = test::algebra;
2727
using vector3 = test::vector3;
2828

29-
constexpr const scalar tol{1e-3f};
29+
constexpr const scalar tol{1e-7f};
3030

3131
/// Brute force test: Intersect toy geometry and compare between ray and helix
3232
/// without B-field
@@ -74,10 +74,13 @@ GTEST_TEST(detray_simulation, detector_scanner) {
7474

7575
// Should have encountered the same number of tracks (vulnerable to
7676
// floating point errors)
77-
EXPECT_EQ(expected[n_tracks].size(), intersection_trace.size());
77+
EXPECT_EQ(expected[n_tracks].size(), intersection_trace.size())
78+
<< test_helix;
7879

7980
// Check every single recorded intersection
80-
for (std::size_t i = 0u; i < intersection_trace.size(); ++i) {
81+
for (std::size_t i = 0u;
82+
i < std::min(expected[n_tracks].size(), intersection_trace.size());
83+
++i) {
8184
if (expected[n_tracks][i].vol_idx !=
8285
intersection_trace[i].vol_idx) {
8386
// Intersection record at portal bound might be flipped

tests/unit_tests/cpu/simulation/track_generators.cpp

+17-12
Original file line numberDiff line numberDiff line change
@@ -29,26 +29,24 @@ GTEST_TEST(detray_simulation, uniform_track_generator) {
2929
uniform_track_generator<free_track_parameters<algebra_t>>;
3030

3131
constexpr const scalar_t tol{1e-5f};
32-
constexpr const scalar_t epsilon{generator_t::configuration::epsilon};
32+
constexpr const scalar_t max_pi{generator_t::configuration::k_max_pi};
3333

3434
constexpr std::size_t phi_steps{50u};
3535
constexpr std::size_t theta_steps{50u};
3636

3737
std::array<vector3, phi_steps * theta_steps> momenta{};
3838

39-
// Loop over theta values ]0,pi[
39+
// Loop over theta values [0,pi)
4040
for (std::size_t itheta{0u}; itheta < theta_steps; ++itheta) {
41-
const scalar_t theta{epsilon +
42-
static_cast<scalar_t>(itheta) *
43-
(constant<scalar_t>::pi - 2.f * epsilon) /
44-
static_cast<scalar_t>(theta_steps - 1u)};
41+
const scalar_t theta{static_cast<scalar_t>(itheta) * max_pi /
42+
static_cast<scalar_t>(theta_steps - 1u)};
4543

46-
// Loop over phi values [-pi, pi]
44+
// Loop over phi values [-pi, pi)
4745
for (std::size_t iphi{0u}; iphi < phi_steps; ++iphi) {
4846
// The direction
4947
const scalar_t phi{-constant<scalar_t>::pi +
5048
static_cast<scalar_t>(iphi) *
51-
(2.f * constant<scalar_t>::pi) /
49+
(constant<scalar_t>::pi + max_pi) /
5250
static_cast<scalar_t>(phi_steps)};
5351

5452
// intialize a track
@@ -125,6 +123,7 @@ GTEST_TEST(detray_simulation, uniform_track_generator_eta) {
125123
uniform_track_generator<free_track_parameters<algebra_t>>;
126124

127125
constexpr const scalar_t tol{1e-5f};
126+
constexpr const scalar_t max_pi{generator_t::configuration::k_max_pi};
128127

129128
constexpr std::size_t phi_steps{50u};
130129
constexpr std::size_t eta_steps{50u};
@@ -137,12 +136,12 @@ GTEST_TEST(detray_simulation, uniform_track_generator_eta) {
137136
static_cast<scalar_t>(eta_steps - 1u)};
138137
const scalar_t theta{2.f * std::atan(std::exp(-eta))};
139138

140-
// Loop over phi values [-pi, pi]
139+
// Loop over phi values [-pi, pi)
141140
for (std::size_t iphi{0u}; iphi < phi_steps; ++iphi) {
142141
// The direction
143142
const scalar_t phi{-constant<scalar_t>::pi +
144143
static_cast<scalar_t>(iphi) *
145-
(2.f * constant<scalar_t>::pi) /
144+
(constant<scalar_t>::pi + max_pi) /
146145
static_cast<scalar_t>(phi_steps)};
147146

148147
// intialize a track
@@ -224,7 +223,7 @@ GTEST_TEST(detray_simulation, random_track_generator_uniform) {
224223
random_track_generator<free_track_parameters<algebra_t>, uniform_gen_t>;
225224

226225
// Tolerance depends on sample size
227-
constexpr scalar_t tol{0.05f};
226+
constexpr scalar_t tol{0.02f};
228227

229228
// Track counter
230229
std::size_t n_tracks{0u};
@@ -236,6 +235,7 @@ GTEST_TEST(detray_simulation, random_track_generator_uniform) {
236235
trk_gen_cfg.seed(42u);
237236
trk_gen_cfg.phi_range(-0.9f * constant<scalar_t>::pi,
238237
0.8f * constant<scalar_t>::pi);
238+
trk_gen_cfg.eta_range(-4.f, 4.f);
239239
trk_gen_cfg.mom_range(1.f * unit<scalar_t>::GeV, 2.f * unit<scalar_t>::GeV);
240240
trk_gen_cfg.origin_stddev({0.1f * unit<scalar_t>::mm,
241241
0.f * unit<scalar_t>::mm,
@@ -268,6 +268,8 @@ GTEST_TEST(detray_simulation, random_track_generator_uniform) {
268268
const auto& ori_stddev = trk_gen_cfg.origin_stddev();
269269
const auto& phi_range = trk_gen_cfg.phi_range();
270270
const auto& theta_range = trk_gen_cfg.theta_range();
271+
ASSERT_NEAR(theta_range[0], 0.0366f, tol);
272+
ASSERT_NEAR(theta_range[1], 3.105f, tol);
271273
const auto& mom_range = trk_gen_cfg.mom_range();
272274

273275
// Mean
@@ -304,7 +306,7 @@ GTEST_TEST(detray_simulation, random_track_generator_normal) {
304306
random_track_generator<free_track_parameters<algebra_t>, normal_gen_t>;
305307

306308
// Tolerance depends on sample size
307-
constexpr scalar_t tol{0.05f};
309+
constexpr scalar_t tol{0.02f};
308310

309311
// Track counter
310312
std::size_t n_tracks{0u};
@@ -316,6 +318,7 @@ GTEST_TEST(detray_simulation, random_track_generator_normal) {
316318
trk_gen_cfg.seed(42u);
317319
trk_gen_cfg.phi_range(-0.9f * constant<scalar_t>::pi,
318320
0.8f * constant<scalar_t>::pi);
321+
trk_gen_cfg.eta_range(-4.f, 4.f);
319322
trk_gen_cfg.mom_range(1.f * unit<scalar_t>::GeV, 2.f * unit<scalar_t>::GeV);
320323
trk_gen_cfg.origin({0.f, 0.f, 0.f});
321324
trk_gen_cfg.origin_stddev({0.1f * unit<scalar_t>::mm,
@@ -349,6 +352,8 @@ GTEST_TEST(detray_simulation, random_track_generator_normal) {
349352
const auto& ori_stddev = trk_gen_cfg.origin_stddev();
350353
const auto& phi_range = trk_gen_cfg.phi_range();
351354
const auto& theta_range = trk_gen_cfg.theta_range();
355+
ASSERT_NEAR(theta_range[0], 0.0366f, tol);
356+
ASSERT_NEAR(theta_range[1], 3.105f, tol);
352357
const auto& mom_range = trk_gen_cfg.mom_range();
353358

354359
// Mean

0 commit comments

Comments
 (0)