Skip to content

Commit 94fb38c

Browse files
author
Vakho Tsulaia
committed
A bunch of patches and hacks to make sure single_store::at() does not get called with default context
1 parent 4abd90e commit 94fb38c

13 files changed

+37
-24
lines changed

core/include/detray/geometry/detail/volume_kernels.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ struct neighborhood_getter {
7171
DETRAY_HOST_DEVICE inline void operator()(
7272
const accel_group_t &group, const accel_index_t index,
7373
const detector_t &det, const typename detector_t::volume_type &volume,
74-
const track_t &track, const config_t &cfg, Args &&... args) const {
74+
const track_t &track, const config_t &cfg, const typename detector_t::geometry_context& ctx, Args &&... args) const {
7575

7676
decltype(auto) accel = group[index];
7777

7878
// Run over the surfaces in a single acceleration data structure
79-
for (const auto &sf : accel.search(det, volume, track, cfg)) {
79+
for (const auto &sf : accel.search(det, volume, track, cfg, ctx)) {
8080
functor_t{}(sf, std::forward<Args>(args)...);
8181
}
8282
}

core/include/detray/geometry/tracking_volume.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class tracking_volume {
4444

4545
/// Volume descriptor type
4646
using descr_t = typename detector_t::volume_type;
47+
using context_t = typename detector_t::geometry_context;
4748

4849
public:
4950
/// In case the geometry needs to be printed
@@ -154,9 +155,9 @@ class tracking_volume {
154155
int I = static_cast<int>(descr_t::object_id::e_size) - 1,
155156
typename track_t, typename config_t, typename... Args>
156157
DETRAY_HOST_DEVICE constexpr void visit_neighborhood(
157-
const track_t &track, const config_t &cfg, Args &&... args) const {
158+
const track_t &track, const config_t &cfg, const context_t &ctx, Args &&... args) const {
158159
visit_surfaces_impl<detail::neighborhood_getter<functor_t>>(
159-
m_detector, m_desc, track, cfg, std::forward<Args>(args)...);
160+
m_detector, m_desc, track, cfg, ctx, std::forward<Args>(args)...);
160161
}
161162

162163
/// Call a functor on the volume material with additional arguments.

core/include/detray/navigation/accelerators/brute_force_finder.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ class brute_force_collection {
6060
const detector_t& /*det*/,
6161
const typename detector_t::volume_type& /*volume*/,
6262
const track_t& /*track*/,
63-
const config_t& /*navigation_config*/) const {
63+
const config_t& /*navigation_config*/,
64+
const typename detector_t::geometry_context& /*ctx*/) const {
6465
return *this;
6566
}
6667

core/include/detray/navigation/intersection_kernel.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ struct intersection_initialize {
5050
{0.f, 1.f * unit<scalar_t>::mm},
5151
const scalar_t mask_tol_scalor = 0.f,
5252
const scalar_t overstep_tol = 0.f) const {
53-
53+
typename transform_container_t::context_type ctx{0}; // Hack!
5454
using mask_t = typename mask_group_t::value_type;
5555
using algebra_t = typename mask_t::algebra_type;
5656

57-
const auto &ctf = contextual_transforms.at(surface.transform());
57+
const auto &ctf = contextual_transforms.at(surface.transform(),ctx);
5858

5959
// Run over the masks that belong to the surface (only one can be hit)
6060
for (const auto &mask :
@@ -142,8 +142,8 @@ struct intersection_update {
142142

143143
using mask_t = typename mask_group_t::value_type;
144144
using algebra_t = typename mask_t::algebra_type;
145-
146-
const auto &ctf = contextual_transforms.at(sfi.sf_desc.transform());
145+
typename transform_container_t::context_type ctx{0}; // Hack!
146+
const auto &ctf = contextual_transforms.at(sfi.sf_desc.transform(),ctx);
147147

148148
// Run over the masks that belong to the surface
149149
for (const auto &mask :

core/include/detray/navigation/navigator.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ class navigator {
122122
using vector3_type = dvector3D<algebra_type>;
123123

124124
using volume_type = typename detector_type::volume_type;
125+
using context_type = typename detector_t::geometry_context;
125126
using nav_link_type = typename detector_type::surface_type::navigation_link;
126127
using intersection_type = intersection_t;
127128
using inspector_type = inspector_t;
@@ -621,7 +622,8 @@ class navigator {
621622
/// @param propagation contains the stepper and navigator states
622623
template <typename propagator_state_t>
623624
DETRAY_HOST_DEVICE inline bool init(propagator_state_t &propagation,
624-
const navigation::config &cfg) const {
625+
const navigation::config &cfg,
626+
const context_type& ctx = {}) const {
625627

626628
state &navigation = propagation._navigation;
627629
const auto &det = navigation.detector();
@@ -634,7 +636,7 @@ class navigator {
634636

635637
// Search for neighboring surfaces and fill candidates into cache
636638
volume.template visit_neighborhood<candidate_search>(
637-
track, cfg, det, track, navigation,
639+
track, cfg, ctx, det, track, navigation,
638640
std::array<scalar_type, 2u>{cfg.min_mask_tolerance,
639641
cfg.max_mask_tolerance},
640642
static_cast<scalar_type>(cfg.mask_tolerance_scalor),

core/include/detray/propagator/actors/parameter_resetter.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ struct parameter_resetter : actor {
7070

7171
using geo_cxt_t =
7272
typename propagator_state_t::detector_type::geometry_context;
73-
const geo_cxt_t ctx{};
73+
const geo_cxt_t ctx{0}; // Hack!
7474

7575
// Surface
7676
const auto sf = navigation.get_surface();

core/include/detray/propagator/actors/parameter_transporter.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ struct parameter_transporter : actor {
9999

100100
using detector_type = typename propagator_state_t::detector_type;
101101
using geo_cxt_t = typename detector_type::geometry_context;
102-
const geo_cxt_t ctx{};
102+
const geo_cxt_t ctx{0}; // Hack!
103103

104104
// Current Surface
105105
const auto sf = navigation.get_surface();

core/include/detray/propagator/actors/pointwise_material_interactor.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ struct pointwise_material_interactor : actor {
142142

143143
auto &stepping = prop_state._stepping;
144144

145-
this->update(geo_context_type{}, stepping._ptc,
145+
this->update(geo_context_type{0}, stepping._ptc, // Hack!
146146
stepping._bound_params, interactor_state,
147147
static_cast<int>(navigation.direction()),
148148
navigation.get_surface());

core/include/detray/propagator/base_stepper.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,13 @@ class base_stepper {
9191
template <typename detector_t>
9292
DETRAY_HOST_DEVICE state(
9393
const bound_track_parameters_type &bound_params,
94-
const detector_t &det)
94+
const detector_t &det,
95+
const typename detector_t::geometry_context& ctx = {})
9596
: _bound_params(bound_params) {
9697

9798
// Surface
9899
const auto sf = tracking_surface{det, bound_params.surface_link()};
99100

100-
const typename detector_t::geometry_context ctx{};
101101
sf.template visit_mask<
102102
typename parameter_resetter<algebra_t>::kernel>(
103103
sf.transform(ctx), sf.index(), *this);

core/include/detray/propagator/propagation_config.hpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "detray/definitions/detail/qualifiers.hpp"
1212
#include "detray/navigation/navigation_config.hpp"
1313
#include "detray/propagator/stepping_config.hpp"
14+
#include "detray/core/detail/data_context.hpp"
1415

1516
// System inlcudes
1617
#include <ostream>
@@ -21,6 +22,7 @@ namespace detray::propagation {
2122
struct config {
2223
navigation::config navigation{};
2324
stepping::config stepping{};
25+
geometry_context context{};
2426
};
2527

2628
/// Print the propagation configuration
@@ -31,7 +33,9 @@ inline std::ostream& operator<<(std::ostream& out,
3133
<< "----------------------------\n"
3234
<< cfg.navigation << "\nParameter Transport\n"
3335
<< "----------------------------\n"
34-
<< cfg.stepping << "\n";
36+
<< cfg.stepping << "\nGeometry Context\n"
37+
<< "----------------------------\n"
38+
<< cfg.context.get() << "\n";
3539

3640
return out;
3741
}

core/include/detray/propagator/propagator.hpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ struct propagator {
3434
using navigator_type = navigator_t;
3535
using intersection_type = typename navigator_type::intersection_type;
3636
using detector_type = typename navigator_type::detector_type;
37+
using context_type = typename detector_type::geometry_context;
3738
using actor_chain_type = actor_chain_t;
3839
using algebra_type = typename stepper_t::algebra_type;
3940
using scalar_type = dscalar<algebra_type>;
@@ -60,6 +61,7 @@ struct propagator {
6061
struct state {
6162

6263
using detector_type = typename navigator_t::detector_type;
64+
using context_type = typename detector_type::geometry_context;
6365
using navigator_state_type = typename navigator_t::state;
6466
using actor_chain_type = actor_chain_t;
6567
using scalar_type = typename navigator_t::scalar_type;
@@ -101,8 +103,9 @@ struct propagator {
101103
template <typename field_t>
102104
DETRAY_HOST_DEVICE state(const bound_track_parameters_type &param,
103105
const field_t &magnetic_field,
104-
const detector_type &det)
105-
: _stepping(param, magnetic_field, det), _navigation(det) {}
106+
const detector_type &det,
107+
const context_type &ctx = {})
108+
: _stepping(param, magnetic_field, det, ctx), _navigation(det) {}
106109

107110
/// Set the particle hypothesis
108111
DETRAY_HOST_DEVICE
@@ -135,7 +138,7 @@ struct propagator {
135138

136139
// Initialize the navigation
137140
propagation._heartbeat =
138-
m_navigator.init(propagation, m_cfg.navigation);
141+
m_navigator.init(propagation, m_cfg.navigation, m_cfg.context);
139142

140143
// Run all registered actors/aborters after init
141144
run_actors(actor_state_refs, propagation);

core/include/detray/propagator/rk_stepper.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ class rk_stepper final
6666
template <typename detector_t>
6767
DETRAY_HOST_DEVICE state(
6868
const bound_track_parameters_type& bound_params,
69-
const magnetic_field_t& mag_field, const detector_t& det)
70-
: base_type::state(bound_params, det), _magnetic_field(mag_field) {}
69+
const magnetic_field_t& mag_field, const detector_t& det, const typename detector_t::geometry_context& ctx = {})
70+
: base_type::state(bound_params, det, ctx), _magnetic_field(mag_field) {}
7171

7272
/// stepping data required for RKN4
7373
struct {

core/include/detray/utils/grid/grid.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,12 @@ class grid_impl {
273273
template <typename detector_t, typename track_t, typename config_t>
274274
DETRAY_HOST_DEVICE auto search(
275275
const detector_t &det, const typename detector_t::volume_type &volume,
276-
const track_t &track, const config_t &cfg) const {
276+
const track_t &track, const config_t &cfg, const typename detector_t::geometry_context& /*ctx*/) const {
277277

278278
// Track position in grid coordinates
279-
const auto &trf = det.transform_store().at(volume.transform());
279+
// const auto &trf = det.transform_store().at(volume.transform(),ctx);
280+
typename detector_t::geometry_context cttx{0}; // Hack!
281+
const auto &trf = det.transform_store().at(volume.transform(),cttx);
280282
const auto loc_pos = project(trf, track.pos(), track.dir());
281283

282284
// Grid lookup

0 commit comments

Comments
 (0)