|
| 1 | +/* |
| 2 | + * Copyright (c) 2021-2022 NVIDIA Corporation |
| 3 | + * |
| 4 | + * Licensed under the Apache License Version 2.0 with LLVM Exceptions |
| 5 | + * (the "License"); you may not use this file except in compliance with |
| 6 | + * the License. You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * https://llvm.org/LICENSE.txt |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | +#pragma once |
| 17 | + |
| 18 | +#include "../stdexec/execution.hpp" |
| 19 | +#include "env.hpp" |
| 20 | + |
| 21 | +#include <variant> |
| 22 | + |
| 23 | +namespace exec { |
| 24 | + using namespace stdexec; |
| 25 | + |
| 26 | + template <typename... _TailSenderN> |
| 27 | + struct __variant_tail_sender : private std::variant<_TailSenderN...> { |
| 28 | + static_assert(sizeof...(_TailSenderN) >= 1, "variant_tail_sender requires at least one sender"); |
| 29 | + static_assert((tail_sender<_TailSenderN> && ...), "variant_tail_sender requires all senders to be tail_sender"); |
| 30 | + |
| 31 | + using __senders_t = std::variant<_TailSenderN...>; |
| 32 | + |
| 33 | + using __senders_t::__senders_t; |
| 34 | + using __senders_t::operator=; |
| 35 | + using __senders_t::index; |
| 36 | + using __senders_t::emplace; |
| 37 | + using __senders_t::swap; |
| 38 | + |
| 39 | + // __variant_tail_sender() = default; |
| 40 | + // template<class... _OtherTailSenders> |
| 41 | + // __variant_tail_sender(const __variant_tail_sender<_OtherTailSenders...>& __o) |
| 42 | + // : __senders_t(static_cast<const std::variant<_OtherTailSenders...>&>(__o)) {} |
| 43 | + // template<class... _OtherTailSenders> |
| 44 | + // __variant_tail_sender(__variant_tail_sender<_OtherTailSenders...>&& __o) |
| 45 | + // : __senders_t(static_cast<std::variant<_OtherTailSenders...>&&>(__o)) {} |
| 46 | + // template<class... _OtherTailSenders> |
| 47 | + // __variant_tail_sender& operator=(const __variant_tail_sender<_OtherTailSenders...>& __o) { |
| 48 | + // __senders_t::operator=(__o); |
| 49 | + // return *this; |
| 50 | + // } |
| 51 | + // template<class... _OtherTailSenders> |
| 52 | + // __variant_tail_sender& operator=(__variant_tail_sender<_OtherTailSenders...>&& __o) { |
| 53 | + // __senders_t::operator=(std::move(__o)); |
| 54 | + // return *this; |
| 55 | + // } |
| 56 | + |
| 57 | + // template<class... _An> |
| 58 | + // requires (__is_not_instance_of<_An, __variant_tail_sender> && ...) |
| 59 | + // __variant_tail_sender(_An&&... __an) : __senders_t((_An&&)__an...) {} |
| 60 | + // template<class _A> |
| 61 | + // requires __is_not_instance_of<_A, __variant_tail_sender> |
| 62 | + // __variant_tail_sender& operator=(_A&& __a) { |
| 63 | + // __senders_t::operator=((_A&&)__a); |
| 64 | + // return *this; |
| 65 | + // } |
| 66 | + |
| 67 | + template <class _TailReceiver> |
| 68 | + struct op { |
| 69 | + using __opn_t = __variant<std::monostate, std::optional<connect_result_t<_TailSenderN, _TailReceiver>>...>; |
| 70 | + using __start_result_t = __variant_tail_sender<next_tail_from_sender_to_t<_TailSenderN, _TailReceiver>...>; |
| 71 | + |
| 72 | + op(const op&) = delete; |
| 73 | + op(op&&) = delete; |
| 74 | + op& operator=(const op&) = delete; |
| 75 | + op& operator=(op&&) = delete; |
| 76 | + |
| 77 | + explicit op(__senders_t&& __t, _TailReceiver __r) { |
| 78 | + std::visit([&, this](auto&& __t) -> void { |
| 79 | + using _T = std::remove_cvref_t<decltype(__t)>; |
| 80 | + if constexpr (tail_sender<_T>) { |
| 81 | + static_assert(tail_sender_to<_T, _TailReceiver>, "variant-tail-sender member cannot connect"); |
| 82 | + using op_t = connect_result_t<_T, _TailReceiver>; |
| 83 | + using opt_t = std::optional<op_t>; |
| 84 | + __opn_.template emplace<opt_t>(); |
| 85 | + opt_t& opt = std::get<opt_t>(__opn_); |
| 86 | + opt.~opt_t(); |
| 87 | + new (&opt) opt_t{stdexec::__conv{ |
| 88 | + [&] () -> op_t { |
| 89 | + return stdexec::connect((decltype(__t)&&)__t, __r); |
| 90 | + } |
| 91 | + }}; |
| 92 | + } else { |
| 93 | + std::terminate(); |
| 94 | + } |
| 95 | + }, (__senders_t&&)__t); |
| 96 | + } |
| 97 | + |
| 98 | + operator bool() const noexcept { |
| 99 | + return std::visit([&](auto&& __op) -> bool { |
| 100 | + using _Opt = std::decay_t<decltype(__op)>; |
| 101 | + if constexpr (__is_instance_of_<_Opt, std::optional>) { |
| 102 | + auto& op = *__op; |
| 103 | + using _Op = std::decay_t<decltype(op)>; |
| 104 | + if constexpr (__nullable_tail_operation_state<_Op>) { |
| 105 | + return !!op; |
| 106 | + } |
| 107 | + return true; |
| 108 | + } else { |
| 109 | + std::terminate(); |
| 110 | + } |
| 111 | + }, __opn_); |
| 112 | + } |
| 113 | + |
| 114 | + [[nodiscard]] |
| 115 | + friend __start_result_t tag_invoke(start_t, op& __self) noexcept { |
| 116 | + return std::visit([&](auto&& __op) -> __start_result_t { |
| 117 | + using _Opt = std::decay_t<decltype(__op)>; |
| 118 | + if constexpr (__is_instance_of_<_Opt, std::optional>) { |
| 119 | + auto& op = *__op; |
| 120 | + using _Op = std::decay_t<decltype(op)>; |
| 121 | + if constexpr (__nullable_tail_operation_state<_Op>) { |
| 122 | + if (!op) { |
| 123 | + return __start_result_t{}; |
| 124 | + } |
| 125 | + } |
| 126 | + if constexpr (__terminal_tail_operation_state<_Op>) { |
| 127 | + stdexec::start(op); |
| 128 | + return __start_result_t{}; |
| 129 | + } else { |
| 130 | + return result_from<__start_result_t>(stdexec::start(op)); |
| 131 | + } |
| 132 | + } else { |
| 133 | + std::terminate(); |
| 134 | + } |
| 135 | + }, __self.__opn_); |
| 136 | + } |
| 137 | + |
| 138 | + friend void tag_invoke(unwind_t, op& __self) noexcept { |
| 139 | + return std::visit([&](auto&& __op) -> void { |
| 140 | + using _Opt = std::decay_t<decltype(__op)>; |
| 141 | + if constexpr (__is_instance_of_<_Opt, std::optional>) { |
| 142 | + exec::unwind(*__op); |
| 143 | + } else { |
| 144 | + std::terminate(); |
| 145 | + } |
| 146 | + }, __self.__opn_); |
| 147 | + } |
| 148 | + __opn_t __opn_; |
| 149 | + }; |
| 150 | + |
| 151 | + using completion_signatures = completion_signatures<set_value_t(), set_stopped_t()>; |
| 152 | + |
| 153 | + template <class _TailReceiver> |
| 154 | + [[nodiscard]] |
| 155 | + friend auto tag_invoke(connect_t, __variant_tail_sender&& __self, _TailReceiver&& __r) noexcept |
| 156 | + -> op<std::decay_t<_TailReceiver>> { |
| 157 | + return op<std::decay_t<_TailReceiver>>{(__variant_tail_sender&&)__self, (_TailReceiver&&)__r}; |
| 158 | + } |
| 159 | + |
| 160 | + template<class _Env> |
| 161 | + friend constexpr bool tag_invoke( |
| 162 | + exec::always_completes_inline_t, exec::c_t<__variant_tail_sender>, exec::c_t<_Env>) noexcept { |
| 163 | + return true; |
| 164 | + } |
| 165 | + |
| 166 | + private: |
| 167 | + template <typename... _OtherTailSenderN> |
| 168 | + friend struct __variant_tail_sender; |
| 169 | + |
| 170 | + template <typename _To> |
| 171 | + friend constexpr _To variant_cast(__variant_tail_sender __f) noexcept { |
| 172 | + return std::visit([]<class _U>(_U&& __u) -> _To { |
| 173 | + if constexpr (stdexec::__v<__mapply<__contains<_U>, _To>>) { |
| 174 | + return _To{(_U&&) __u}; |
| 175 | + } else { |
| 176 | + printf("variant_cast\n"); fflush(stdout); |
| 177 | + std::terminate(); |
| 178 | + } |
| 179 | + }, std::move(static_cast<__senders_t&>(__f))); |
| 180 | + } |
| 181 | + |
| 182 | + template<tail_sender _To> |
| 183 | + friend constexpr std::decay_t<_To> get(__variant_tail_sender __f) noexcept { |
| 184 | + static_assert(stdexec::__v<__mapply<__contains<std::decay_t<_To>>, __variant_tail_sender>>, "get does not have _To as an alternative"); |
| 185 | + if (!holds_alternative<std::decay_t<_To>>(__f)) { |
| 186 | + printf("get\n"); fflush(stdout); |
| 187 | + std::terminate(); |
| 188 | + } |
| 189 | + return std::get<std::decay_t<_To>>(std::move(static_cast<__senders_t&>(__f))); |
| 190 | + } |
| 191 | + |
| 192 | + template< class T> |
| 193 | + friend inline constexpr bool holds_alternative( const __variant_tail_sender& v ) noexcept { |
| 194 | + return std::holds_alternative<T>(v); |
| 195 | + } |
| 196 | + }; |
| 197 | + |
| 198 | + template <template<class...> class _T> |
| 199 | + struct __mflattener_of { |
| 200 | + |
| 201 | + template <class _Continuation> |
| 202 | + struct __push_back_flatten; |
| 203 | + |
| 204 | + template <class _Continuation = __q<__types>> |
| 205 | + struct __mflatten { |
| 206 | + template <class... _Ts> |
| 207 | + using __f = |
| 208 | + __mapply< |
| 209 | + _Continuation, |
| 210 | + __minvoke<__fold_right<__types<>, __push_back_flatten<__q<__types>>>, _Ts...>>; |
| 211 | + }; |
| 212 | + |
| 213 | + template <class _Continuation> |
| 214 | + struct __push_back_flatten { |
| 215 | + |
| 216 | + template <bool _IsInstance, class _List, class _Item> |
| 217 | + struct __f_; |
| 218 | + template <template<class...> class _List, class... _ListItems, template<class...> class _Instance, class... _InstanceItems> |
| 219 | + struct __f_<true, _List<_ListItems...>, _Instance<_InstanceItems...>> { |
| 220 | + using __t = __minvoke<__mflatten<_Continuation>, _ListItems..., _InstanceItems...>; |
| 221 | + }; |
| 222 | + template <template<class...> class _List, class... _ListItems, class _Item> |
| 223 | + struct __f_<false, _List<_ListItems...>, _Item> { |
| 224 | + using __t = __minvoke<_Continuation, _ListItems..., _Item>; |
| 225 | + }; |
| 226 | + template <class _List, class _Item> |
| 227 | + using __f = __t<__f_<__is_instance_of<_Item, _T>, _List, _Item>>; |
| 228 | + }; |
| 229 | + }; |
| 230 | + |
| 231 | + template <typename... _TailSenderN> |
| 232 | + using variant_tail_sender = |
| 233 | + __minvoke< |
| 234 | + __if_c< |
| 235 | + sizeof...(_TailSenderN) != 0, |
| 236 | + __transform<__q<decay_t>, |
| 237 | + __mflattener_of<__variant_tail_sender>::__mflatten< |
| 238 | + __munique<__q<__variant_tail_sender>>>>, |
| 239 | + __mconst<__not_a_variant>>, |
| 240 | + _TailSenderN...>; |
| 241 | + |
| 242 | +} // namespace exec |
0 commit comments