Skip to content

Commit 6af616a

Browse files
committed
add variant_tail_sender
1 parent e80f3a6 commit 6af616a

File tree

1 file changed

+242
-0
lines changed

1 file changed

+242
-0
lines changed

include/exec/variant_tail_sender.hpp

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
/*
2+
* Copyright (c) 2021-2022 NVIDIA Corporation
3+
*
4+
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
5+
* (the "License"); you may not use this file except in compliance with
6+
* the License. You may obtain a copy of the License at
7+
*
8+
* https://llvm.org/LICENSE.txt
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "../stdexec/execution.hpp"
19+
#include "env.hpp"
20+
21+
#include <variant>
22+
23+
namespace exec {
24+
using namespace stdexec;
25+
26+
template <typename... _TailSenderN>
27+
struct __variant_tail_sender : private std::variant<_TailSenderN...> {
28+
static_assert(sizeof...(_TailSenderN) >= 1, "variant_tail_sender requires at least one sender");
29+
static_assert((tail_sender<_TailSenderN> && ...), "variant_tail_sender requires all senders to be tail_sender");
30+
31+
using __senders_t = std::variant<_TailSenderN...>;
32+
33+
using __senders_t::__senders_t;
34+
using __senders_t::operator=;
35+
using __senders_t::index;
36+
using __senders_t::emplace;
37+
using __senders_t::swap;
38+
39+
// __variant_tail_sender() = default;
40+
// template<class... _OtherTailSenders>
41+
// __variant_tail_sender(const __variant_tail_sender<_OtherTailSenders...>& __o)
42+
// : __senders_t(static_cast<const std::variant<_OtherTailSenders...>&>(__o)) {}
43+
// template<class... _OtherTailSenders>
44+
// __variant_tail_sender(__variant_tail_sender<_OtherTailSenders...>&& __o)
45+
// : __senders_t(static_cast<std::variant<_OtherTailSenders...>&&>(__o)) {}
46+
// template<class... _OtherTailSenders>
47+
// __variant_tail_sender& operator=(const __variant_tail_sender<_OtherTailSenders...>& __o) {
48+
// __senders_t::operator=(__o);
49+
// return *this;
50+
// }
51+
// template<class... _OtherTailSenders>
52+
// __variant_tail_sender& operator=(__variant_tail_sender<_OtherTailSenders...>&& __o) {
53+
// __senders_t::operator=(std::move(__o));
54+
// return *this;
55+
// }
56+
57+
// template<class... _An>
58+
// requires (__is_not_instance_of<_An, __variant_tail_sender> && ...)
59+
// __variant_tail_sender(_An&&... __an) : __senders_t((_An&&)__an...) {}
60+
// template<class _A>
61+
// requires __is_not_instance_of<_A, __variant_tail_sender>
62+
// __variant_tail_sender& operator=(_A&& __a) {
63+
// __senders_t::operator=((_A&&)__a);
64+
// return *this;
65+
// }
66+
67+
template <class _TailReceiver>
68+
struct op {
69+
using __opn_t = __variant<std::monostate, std::optional<connect_result_t<_TailSenderN, _TailReceiver>>...>;
70+
using __start_result_t = __variant_tail_sender<next_tail_from_sender_to_t<_TailSenderN, _TailReceiver>...>;
71+
72+
op(const op&) = delete;
73+
op(op&&) = delete;
74+
op& operator=(const op&) = delete;
75+
op& operator=(op&&) = delete;
76+
77+
explicit op(__senders_t&& __t, _TailReceiver __r) {
78+
std::visit([&, this](auto&& __t) -> void {
79+
using _T = std::remove_cvref_t<decltype(__t)>;
80+
if constexpr (tail_sender<_T>) {
81+
static_assert(tail_sender_to<_T, _TailReceiver>, "variant-tail-sender member cannot connect");
82+
using op_t = connect_result_t<_T, _TailReceiver>;
83+
using opt_t = std::optional<op_t>;
84+
__opn_.template emplace<opt_t>();
85+
opt_t& opt = std::get<opt_t>(__opn_);
86+
opt.~opt_t();
87+
new (&opt) opt_t{stdexec::__conv{
88+
[&] () -> op_t {
89+
return stdexec::connect((decltype(__t)&&)__t, __r);
90+
}
91+
}};
92+
} else {
93+
std::terminate();
94+
}
95+
}, (__senders_t&&)__t);
96+
}
97+
98+
operator bool() const noexcept {
99+
return std::visit([&](auto&& __op) -> bool {
100+
using _Opt = std::decay_t<decltype(__op)>;
101+
if constexpr (__is_instance_of_<_Opt, std::optional>) {
102+
auto& op = *__op;
103+
using _Op = std::decay_t<decltype(op)>;
104+
if constexpr (__nullable_tail_operation_state<_Op>) {
105+
return !!op;
106+
}
107+
return true;
108+
} else {
109+
std::terminate();
110+
}
111+
}, __opn_);
112+
}
113+
114+
[[nodiscard]]
115+
friend __start_result_t tag_invoke(start_t, op& __self) noexcept {
116+
return std::visit([&](auto&& __op) -> __start_result_t {
117+
using _Opt = std::decay_t<decltype(__op)>;
118+
if constexpr (__is_instance_of_<_Opt, std::optional>) {
119+
auto& op = *__op;
120+
using _Op = std::decay_t<decltype(op)>;
121+
if constexpr (__nullable_tail_operation_state<_Op>) {
122+
if (!op) {
123+
return __start_result_t{};
124+
}
125+
}
126+
if constexpr (__terminal_tail_operation_state<_Op>) {
127+
stdexec::start(op);
128+
return __start_result_t{};
129+
} else {
130+
return result_from<__start_result_t>(stdexec::start(op));
131+
}
132+
} else {
133+
std::terminate();
134+
}
135+
}, __self.__opn_);
136+
}
137+
138+
friend void tag_invoke(unwind_t, op& __self) noexcept {
139+
return std::visit([&](auto&& __op) -> void {
140+
using _Opt = std::decay_t<decltype(__op)>;
141+
if constexpr (__is_instance_of_<_Opt, std::optional>) {
142+
exec::unwind(*__op);
143+
} else {
144+
std::terminate();
145+
}
146+
}, __self.__opn_);
147+
}
148+
__opn_t __opn_;
149+
};
150+
151+
using completion_signatures = completion_signatures<set_value_t(), set_stopped_t()>;
152+
153+
template <class _TailReceiver>
154+
[[nodiscard]]
155+
friend auto tag_invoke(connect_t, __variant_tail_sender&& __self, _TailReceiver&& __r) noexcept
156+
-> op<std::decay_t<_TailReceiver>> {
157+
return op<std::decay_t<_TailReceiver>>{(__variant_tail_sender&&)__self, (_TailReceiver&&)__r};
158+
}
159+
160+
template<class _Env>
161+
friend constexpr bool tag_invoke(
162+
exec::always_completes_inline_t, exec::c_t<__variant_tail_sender>, exec::c_t<_Env>) noexcept {
163+
return true;
164+
}
165+
166+
private:
167+
template <typename... _OtherTailSenderN>
168+
friend struct __variant_tail_sender;
169+
170+
template <typename _To>
171+
friend constexpr _To variant_cast(__variant_tail_sender __f) noexcept {
172+
return std::visit([]<class _U>(_U&& __u) -> _To {
173+
if constexpr (stdexec::__v<__mapply<__contains<_U>, _To>>) {
174+
return _To{(_U&&) __u};
175+
} else {
176+
printf("variant_cast\n"); fflush(stdout);
177+
std::terminate();
178+
}
179+
}, std::move(static_cast<__senders_t&>(__f)));
180+
}
181+
182+
template<tail_sender _To>
183+
friend constexpr std::decay_t<_To> get(__variant_tail_sender __f) noexcept {
184+
static_assert(stdexec::__v<__mapply<__contains<std::decay_t<_To>>, __variant_tail_sender>>, "get does not have _To as an alternative");
185+
if (!holds_alternative<std::decay_t<_To>>(__f)) {
186+
printf("get\n"); fflush(stdout);
187+
std::terminate();
188+
}
189+
return std::get<std::decay_t<_To>>(std::move(static_cast<__senders_t&>(__f)));
190+
}
191+
192+
template< class T>
193+
friend inline constexpr bool holds_alternative( const __variant_tail_sender& v ) noexcept {
194+
return std::holds_alternative<T>(v);
195+
}
196+
};
197+
198+
template <template<class...> class _T>
199+
struct __mflattener_of {
200+
201+
template <class _Continuation>
202+
struct __push_back_flatten;
203+
204+
template <class _Continuation = __q<__types>>
205+
struct __mflatten {
206+
template <class... _Ts>
207+
using __f =
208+
__mapply<
209+
_Continuation,
210+
__minvoke<__fold_right<__types<>, __push_back_flatten<__q<__types>>>, _Ts...>>;
211+
};
212+
213+
template <class _Continuation>
214+
struct __push_back_flatten {
215+
216+
template <bool _IsInstance, class _List, class _Item>
217+
struct __f_;
218+
template <template<class...> class _List, class... _ListItems, template<class...> class _Instance, class... _InstanceItems>
219+
struct __f_<true, _List<_ListItems...>, _Instance<_InstanceItems...>> {
220+
using __t = __minvoke<__mflatten<_Continuation>, _ListItems..., _InstanceItems...>;
221+
};
222+
template <template<class...> class _List, class... _ListItems, class _Item>
223+
struct __f_<false, _List<_ListItems...>, _Item> {
224+
using __t = __minvoke<_Continuation, _ListItems..., _Item>;
225+
};
226+
template <class _List, class _Item>
227+
using __f = __t<__f_<__is_instance_of<_Item, _T>, _List, _Item>>;
228+
};
229+
};
230+
231+
template <typename... _TailSenderN>
232+
using variant_tail_sender =
233+
__minvoke<
234+
__if_c<
235+
sizeof...(_TailSenderN) != 0,
236+
__transform<__q<decay_t>,
237+
__mflattener_of<__variant_tail_sender>::__mflatten<
238+
__munique<__q<__variant_tail_sender>>>>,
239+
__mconst<__not_a_variant>>,
240+
_TailSenderN...>;
241+
242+
} // namespace exec

0 commit comments

Comments
 (0)