Skip to content

Commit ca800b0

Browse files
authored
Fix state tuple type and remove empty states (#940)
This PR ensures that the local state tuple type in the actor chain also contains the states of all observing actors in a flattened tuple. At the same time, it makes sure that empty default states are not included and therefore also do not need to be passed to the propagator. Consequently, the empty states were removed everywhere now. In the propagator the actor state tuple type is auto deduced, so that we do not depend on the order of actors in the chain anymore (this will simplify the KF in traccc).
1 parent db05347 commit ca800b0

22 files changed

+157
-123
lines changed

core/include/detray/propagator/actor_chain.hpp

+26-21
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "detray/definitions/containers.hpp"
1212
#include "detray/definitions/detail/qualifiers.hpp"
1313
#include "detray/propagator/base_actor.hpp"
14+
#include "detray/utils/tuple.hpp"
1415
#include "detray/utils/tuple_helpers.hpp"
1516

1617
// System include(s)
@@ -32,11 +33,14 @@ class actor_chain {
3233

3334
public:
3435
/// Types of the actors that are registered in the chain
35-
using actor_list_type = dtuple<actors_t...>;
36-
// Tuple of actor states
37-
using state_tuple = dtuple<typename actors_t::state...>;
38-
// Type of states tuple that is used in the propagator
39-
using state = dtuple<typename actors_t::state &...>;
36+
using actor_tuple = dtuple<actors_t...>;
37+
38+
// Tuple of actor states (including states of observing actors, if present)
39+
using state_tuple = detail::tuple_cat_t<detail::state_tuple_t<actors_t>...>;
40+
41+
// Tuple of state references that is used in the propagator
42+
using state_ref_tuple =
43+
detail::tuple_cat_t<detail::state_ref_tuple_t<actors_t>...>;
4044

4145
/// Call all actors in the chain.
4246
///
@@ -50,27 +54,26 @@ class actor_chain {
5054
}
5155

5256
/// @returns the actor list
53-
DETRAY_HOST_DEVICE const actor_list_type &actors() const {
57+
DETRAY_HOST_DEVICE constexpr const actor_tuple &actors() const {
5458
return m_actors;
5559
}
5660

5761
/// @returns a tuple of default constructible actor states
5862
DETRAY_HOST_DEVICE
59-
static constexpr auto make_actor_states() {
63+
static constexpr auto make_default_actor_states() {
6064
// Only possible if each state is default initializable
61-
if constexpr ((std::default_initializable<typename actors_t::state> &&
62-
...)) {
63-
return dtuple<typename actors_t::state...>{};
65+
if constexpr (std::default_initializable<state_tuple>) {
66+
return state_tuple{};
6467
} else {
6568
return std::nullopt;
6669
}
6770
}
6871

6972
/// @returns a tuple of reference for every state in the tuple @param t
70-
DETRAY_HOST_DEVICE static constexpr state setup_actor_states(
71-
dtuple<typename actors_t::state...> &t) {
73+
DETRAY_HOST_DEVICE static constexpr state_ref_tuple setup_actor_states(
74+
state_tuple &t) {
7275
return setup_actor_states(
73-
t, std::make_index_sequence<sizeof...(actors_t)>{});
76+
t, std::make_index_sequence<detail::tuple_size_v<state_tuple>>{});
7477
}
7578

7679
private:
@@ -99,7 +102,7 @@ class actor_chain {
99102

100103
/// Resolve the actor calls.
101104
///
102-
/// @param states states of all actors (only bare actors)
105+
/// @param states states of all actors
103106
/// @param p_state the state of the propagator (stepper and navigator)
104107
template <typename actor_states_t, typename propagator_state_t,
105108
std::size_t... indices>
@@ -111,22 +114,24 @@ class actor_chain {
111114

112115
/// @returns a tuple of reference for every state in the tuple @param t
113116
template <std::size_t... indices>
114-
DETRAY_HOST_DEVICE static constexpr state setup_actor_states(
115-
dtuple<typename actors_t::state...> &t,
116-
std::index_sequence<indices...> /*ids*/) {
117+
DETRAY_HOST_DEVICE static constexpr state_ref_tuple setup_actor_states(
118+
state_tuple &t, std::index_sequence<indices...> /*ids*/) {
117119
return detray::tie(detail::get<indices>(t)...);
118120
}
119121

120122
/// Tuple of actors
121-
actor_list_type m_actors = {};
123+
[[no_unique_address]] actor_tuple m_actors = {};
122124
};
123125

124126
/// Empty actor chain (placeholder)
125127
template <>
126128
class actor_chain<> {
127129

128130
public:
131+
using actor_tuple = dtuple<>;
129132
using state_tuple = dtuple<>;
133+
using state_ref_tuple = dtuple<>;
134+
130135
/// Empty states replaces a real actor states container
131136
struct state {};
132137

@@ -135,13 +140,13 @@ class actor_chain<> {
135140
/// @param states the states of the actors.
136141
/// @param p_state the propagation state.
137142
template <typename actor_states_t, typename propagator_state_t>
138-
DETRAY_HOST_DEVICE void operator()(actor_states_t & /*states*/,
139-
propagator_state_t & /*p_state*/) const {
143+
DETRAY_HOST_DEVICE constexpr void operator()(
144+
actor_states_t & /*states*/, propagator_state_t & /*p_state*/) const {
140145
/*Do nothing*/
141146
}
142147

143148
/// @returns an empty state
144-
DETRAY_HOST_DEVICE static constexpr state setup_actor_states(
149+
DETRAY_HOST_DEVICE static constexpr state_ref_tuple setup_actor_states(
145150
const state_tuple &) {
146151
return {};
147152
}

core/include/detray/propagator/base_actor.hpp

+63-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/** Detray library, part of the ACTS project (R&D line)
22
*
3-
* (c) 2022-2024 CERN for the benefit of the ACTS project
3+
* (c) 2022-2025 CERN for the benefit of the ACTS project
44
*
55
* Mozilla Public License Version 2.0
66
*/
@@ -28,26 +28,81 @@ struct actor {
2828
struct state {};
2929
};
3030

31+
namespace detail {
32+
/// Extrac the tuple of actor states from an actor type
33+
/// @{
34+
// Simple actor: No observers
35+
template <typename actor_t>
36+
struct get_state_tuple {
37+
private:
38+
using state_t = typename actor_t::state;
39+
40+
// Remove empty default state of base actor type from tuple
41+
using principal = std::conditional_t<std::same_as<state_t, actor::state>,
42+
dtuple<>, dtuple<state_t>>;
43+
using principal_ref =
44+
std::conditional_t<std::same_as<state_t, actor::state>, dtuple<>,
45+
dtuple<state_t &>>;
46+
47+
public:
48+
using type = principal;
49+
using ref_type = principal_ref;
50+
};
51+
52+
// Composite actor: Has observers
53+
template <typename actor_t>
54+
requires(!std::same_as<typename std::remove_cvref_t<actor_t>::observer_states,
55+
void>) struct get_state_tuple<actor_t> {
56+
private:
57+
using principal_actor_t = typename actor_t::actor_type;
58+
59+
using principal = typename get_state_tuple<principal_actor_t>::type;
60+
using principal_ref = typename get_state_tuple<principal_actor_t>::ref_type;
61+
62+
using observers = typename actor_t::observer_states;
63+
using observer_refs = typename actor_t::observer_state_refs;
64+
65+
public:
66+
using type = detail::tuple_cat_t<principal, observers>;
67+
using ref_type = detail::tuple_cat_t<principal_ref, observer_refs>;
68+
};
69+
70+
/// Tuple of state types
71+
template <typename actor_t>
72+
using state_tuple_t = get_state_tuple<actor_t>::type;
73+
74+
/// Tuple of references
75+
template <typename actor_t>
76+
using state_ref_tuple_t = get_state_tuple<actor_t>::ref_type;
77+
/// @}
78+
79+
} // namespace detail
80+
3181
/// Composition of actors
3282
///
3383
/// The composition represents an actor together with its observers. In
3484
/// addition to running its own implementation, it notifies its observing actors
3585
///
36-
/// @tparam actor_impl_t the actor the compositions implements itself.
86+
/// @tparam principal_actor_t the actor the compositions implements itself.
3787
/// @tparam observers a pack of observing actors that get called on the updated
3888
/// actor state of the compositions actor implementation.
39-
template <class actor_impl_t = actor, typename... observers>
40-
class composite_actor final : public actor_impl_t {
89+
template <class principal_actor_t = actor, typename... observers>
90+
class composite_actor final : public principal_actor_t {
4191

4292
public:
4393
/// Tag whether this is a composite type (hides the def in the actor)
4494
struct is_comp_actor : public std::true_type {};
4595

46-
/// The composite is an actor in itself. For simplicity, it cannot be
47-
/// derived from another composition (final).
48-
using actor_type = actor_impl_t;
96+
/// The composite is an actor in itself.
97+
using actor_type = principal_actor_t;
4998
using state = typename actor_type::state;
5099

100+
/// Tuple of states of observing actors
101+
using observer_states =
102+
detail::tuple_cat_t<detail::state_tuple_t<observers>...>;
103+
using observer_state_refs =
104+
detail::tuple_cat_t<detail::state_ref_tuple_t<observers>...>;
105+
51106
/// Call to the implementation of the actor (the actor possibly being an
52107
/// observer itself)
53108
///
@@ -133,7 +188,7 @@ class composite_actor final : public actor_impl_t {
133188
}
134189

135190
/// Keep the observers (might be composites again)
136-
dtuple<observers...> m_observers = {};
191+
[[no_unique_address]] dtuple<observers...> m_observers = {};
137192
};
138193

139194
} // namespace detray

core/include/detray/propagator/propagator.hpp

+8-6
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,9 @@ struct propagator {
150150
///
151151
/// @note If the return value of this function is true, a propagation step
152152
/// can be taken afterwards.
153+
template <typename actor_states_t>
153154
DETRAY_HOST_DEVICE void propagate_init(
154-
state &propagation,
155-
typename actor_chain_t::state actor_state_refs) const {
155+
state &propagation, actor_states_t actor_state_refs) const {
156156
auto &navigation = propagation._navigation;
157157
auto &stepping = propagation._stepping;
158158
auto &context = propagation._context;
@@ -179,9 +179,10 @@ struct propagator {
179179
///
180180
/// @note If the return value of this function is true, another step can
181181
/// be taken afterwards.
182+
template <typename actor_states_t>
182183
DETRAY_HOST_DEVICE bool propagate_step(
183184
state &propagation, bool is_init,
184-
typename actor_chain_t::state actor_state_refs) const {
185+
actor_states_t actor_state_refs) const {
185186
auto &navigation = propagation._navigation;
186187
auto &stepping = propagation._stepping;
187188
auto &context = propagation._context;
@@ -242,9 +243,10 @@ struct propagator {
242243
/// @param actor_state_refs tuple containing refences to the actor states
243244
///
244245
/// @return propagation success.
246+
template <typename actor_states_t>
245247
DETRAY_HOST_DEVICE bool propagate(
246248
state &propagation,
247-
typename actor_chain_t::state actor_state_refs) const {
249+
actor_states_t actor_state_refs = dtuple<>{}) const {
248250

249251
propagate_init(propagation, actor_state_refs);
250252
bool is_init = true;
@@ -278,9 +280,9 @@ struct propagator {
278280
/// @param actor_states the actor state
279281
///
280282
/// @return propagation success.
283+
template <typename actor_states_t>
281284
DETRAY_HOST_DEVICE bool propagate_sync(
282-
state &propagation,
283-
typename actor_chain_t::state actor_state_refs) const {
285+
state &propagation, actor_states_t actor_state_refs) const {
284286

285287
propagate_init(propagation, actor_state_refs);
286288
bool is_init = true;

core/include/detray/utils/tuple_helpers.hpp

+31
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,35 @@ template <typename T, class tuple_t>
168168
constexpr bool has_type_v = has_type<T, tuple_t>::value;
169169
///@}
170170

171+
/// Concatenate tuple types
172+
/// @{
173+
template <typename... tuple_ts>
174+
struct tuple_cat_type {};
175+
176+
template <typename... Args>
177+
struct tuple_cat_type<std::tuple<Args...>> {
178+
using type = std::tuple<Args...>;
179+
};
180+
181+
template <typename... Args1, typename... Args2, typename... tuple_ts>
182+
struct tuple_cat_type<std::tuple<Args1...>, std::tuple<Args2...>, tuple_ts...> {
183+
using type = typename tuple_cat_type<std::tuple<Args1..., Args2...>,
184+
tuple_ts...>::type;
185+
};
186+
187+
template <typename... Args>
188+
struct tuple_cat_type<dtuple<Args...>> {
189+
using type = dtuple<Args...>;
190+
};
191+
192+
template <typename... Args1, typename... Args2, typename... tuple_ts>
193+
struct tuple_cat_type<dtuple<Args1...>, dtuple<Args2...>, tuple_ts...> {
194+
using type =
195+
typename tuple_cat_type<dtuple<Args1..., Args2...>, tuple_ts...>::type;
196+
};
197+
198+
template <typename... tuple_ts>
199+
using tuple_cat_t = typename tuple_cat_type<tuple_ts...>::type;
200+
/// @}
201+
171202
} // namespace detray::detail

tests/benchmarks/cpu/propagation.cpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,9 @@ int main(int argc, char** argv) {
112112

113113
dtuple<> empty_state{};
114114

115-
parameter_transporter<test_algebra>::state transporter_state{};
116115
pointwise_material_interactor<test_algebra>::state interactor_state{};
117-
parameter_resetter<test_algebra>::state resetter_state{};
118116

119-
auto actor_states = detail::make_tuple<dtuple>(
120-
transporter_state, interactor_state, resetter_state);
117+
auto actor_states = detail::make_tuple<dtuple>(interactor_state);
121118

122119
//
123120
// Register benchmarks

tests/benchmarks/cuda/propagation.cpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,9 @@ int main(int argc, char** argv) {
107107

108108
dtuple<> empty_state{};
109109

110-
parameter_transporter<test_algebra>::state transporter_state{};
111110
pointwise_material_interactor<test_algebra>::state interactor_state{};
112-
parameter_resetter<test_algebra>::state resetter_state{};
113111

114-
auto actor_states = detail::make_tuple<dtuple>(
115-
transporter_state, interactor_state, resetter_state);
112+
auto actor_states = detail::make_tuple<dtuple>(interactor_state);
116113

117114
//
118115
// Register benchmarks

tests/benchmarks/include/detray/benchmarks/cpu/propagation_benchmark.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ struct host_propagation_bm : public benchmark_base {
105105
// Fresh copy of actor states
106106
actor_states_t actor_states(*input_actor_states);
107107
// Tuple of references to pass to the propagator
108-
typename actor_chain_t::state actor_state_refs =
108+
typename actor_chain_t::state_ref_tuple actor_state_refs =
109109
actor_chain_t::setup_actor_states(actor_states);
110110

111111
typename propagator_t::state p_state(track, *bfield, *det);

tests/include/detray/test/device/cuda/material_validation.cu

+1-4
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,11 @@ __global__ void material_validation_kernel(
6262

6363
// Create the actor states
6464
typename pathlimit_aborter_t::state aborter_state{cfg.stepping.path_limit};
65-
typename parameter_transporter<algebra_t>::state transporter_state{};
66-
typename parameter_resetter<algebra_t>::state resetter_state{};
6765
typename pointwise_material_interactor<algebra_t>::state interactor_state{};
6866
typename material_tracer_t::state mat_tracer_state{mat_steps.at(trk_id)};
6967

7068
auto actor_states =
71-
::detray::tie(aborter_state, transporter_state, resetter_state,
72-
interactor_state, mat_tracer_state);
69+
::detray::tie(aborter_state, interactor_state, mat_tracer_state);
7370

7471
// Run propagation
7572
typename navigator_t::state::view_type nav_view{};

tests/include/detray/test/device/propagator_test.hpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,9 @@ inline auto run_propagation_host(vecmem::memory_resource *mr,
134134
tracer_state.collect_only_on_surface(true);
135135
typename pathlimit_aborter_t::state pathlimit_state{
136136
cfg.stepping.path_limit};
137-
parameter_transporter<test_algebra>::state transporter_state{};
138137
pointwise_material_interactor<test_algebra>::state interactor_state{};
139-
parameter_resetter<test_algebra>::state resetter_state{};
140138
auto actor_states =
141-
detray::tie(tracer_state, pathlimit_state, transporter_state,
142-
interactor_state, resetter_state);
139+
detray::tie(tracer_state, pathlimit_state, interactor_state);
143140

144141
typename propagator_host_t::state state(trk, field, det);
145142

tests/include/detray/test/utils/simulation/random_scatterer.hpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,11 @@ struct random_scatterer : actor {
6767
/// Material store visitor
6868
struct kernel {
6969

70-
using state = typename random_scatterer::state;
71-
7270
template <typename mat_group_t, typename index_t>
7371
DETRAY_HOST_DEVICE inline bool operator()(
7472
[[maybe_unused]] const mat_group_t& material_group,
7573
[[maybe_unused]] const index_t& mat_index,
76-
[[maybe_unused]] state& s,
74+
[[maybe_unused]] typename random_scatterer::state& s,
7775
[[maybe_unused]] const pdg_particle<scalar_type>& ptc,
7876
[[maybe_unused]] const bound_track_parameters<algebra_t>&
7977
bound_params,

0 commit comments

Comments
 (0)