Skip to content

Commit

Permalink
Allow cancellation to be propagated to child coroutines (#721)
Browse files Browse the repository at this point in the history
  • Loading branch information
oldnewthing authored Sep 18, 2020
1 parent e504a0e commit c20a75b
Show file tree
Hide file tree
Showing 6 changed files with 346 additions and 15 deletions.
41 changes: 39 additions & 2 deletions strings/base_coroutine_foundation.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,21 @@ namespace winrt::impl
};

template <typename Async>
struct await_adapter
struct await_adapter : enable_await_cancellation
{
await_adapter(Async const& async) : async(async) { }

Async const& async;
Windows::Foundation::AsyncStatus status = Windows::Foundation::AsyncStatus::Started;

void enable_cancellation(cancellable_promise* promise)
{
promise->set_canceller([](void* parameter)
{
cancel_asynchronously(reinterpret_cast<await_adapter*>(parameter)->async);
}, this);
}

bool await_ready() const noexcept
{
return false;
Expand All @@ -153,6 +163,19 @@ namespace winrt::impl
check_status_canceled(status);
return async.GetResults();
}

private:
static fire_and_forget cancel_asynchronously(Async async)
{
co_await winrt::resume_background();
try
{
async.Cancel();
}
catch (hresult_error const&)
{
}
}
};

template <typename D>
Expand Down Expand Up @@ -278,6 +301,11 @@ namespace winrt::impl
m_promise->cancellation_callback(std::move(cancel));
}

bool enable_propagation(bool value = true) const noexcept
{
return m_promise->enable_cancellation_propagation(value);
}

private:

Promise* m_promise;
Expand Down Expand Up @@ -414,6 +442,8 @@ namespace winrt::impl
{
cancel();
}

m_cancellable.cancel();
}

void Close() const noexcept
Expand Down Expand Up @@ -536,7 +566,7 @@ namespace winrt::impl
throw winrt::hresult_canceled();
}

return notify_awaiter<Expression>{ static_cast<Expression&&>(expression) };
return notify_awaiter<Expression>{ static_cast<Expression&&>(expression), m_propagate_cancellation ? &m_cancellable : nullptr };
}

cancellation_token<Derived> await_transform(get_cancellation_token_t) noexcept
Expand Down Expand Up @@ -567,6 +597,11 @@ namespace winrt::impl
}
}

bool enable_cancellation_propagation(bool value) noexcept
{
return std::exchange(m_propagate_cancellation, value);
}

#if defined(_DEBUG) && !defined(WINRT_NO_MAKE_DETECTION)
void use_make_function_to_create_this_object() final
{
Expand All @@ -587,8 +622,10 @@ namespace winrt::impl
slim_mutex m_lock;
async_completed_handler_t<AsyncInterface> m_completed;
winrt::delegate<> m_cancel;
cancellable_promise m_cancellable;
std::atomic<AsyncStatus> m_status;
bool m_completed_assigned{ false };
bool m_propagate_cancellation{ false };
};
}

Expand Down
171 changes: 162 additions & 9 deletions strings/base_coroutine_threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,81 @@ namespace winrt::impl
static constexpr bool has_co_await_member = find_co_await_member<T&&>(0);
static constexpr bool has_co_await_free = find_co_await_free<T&&>(0);
};
}

WINRT_EXPORT namespace winrt
{
struct cancellable_promise
{
using canceller_t = void(*)(void*);

void set_canceller(canceller_t canceller, void* context)
{
m_context = context;
m_canceller.store(canceller, std::memory_order_release);
}

void revoke_canceller()
{
while (m_canceller.exchange(nullptr, std::memory_order_acquire) == cancelling_ptr)
{
std::this_thread::yield();
}
}

void cancel()
{
auto canceller = m_canceller.exchange(cancelling_ptr, std::memory_order_acquire);
struct unique_cancellation_lock
{
cancellable_promise* promise;
~unique_cancellation_lock()
{
promise->m_canceller.store(nullptr, std::memory_order_release);
}
} lock{ this };

if ((canceller != nullptr) && (canceller != cancelling_ptr))
{
canceller(m_context);
}
}

private:
static inline auto const cancelling_ptr = reinterpret_cast<canceller_t>(1);

std::atomic<canceller_t> m_canceller{ nullptr };
void* m_context{ nullptr };
};

struct enable_await_cancellation
{
enable_await_cancellation() noexcept = default;
enable_await_cancellation(enable_await_cancellation const&) = delete;

~enable_await_cancellation()
{
if (m_promise)
{
m_promise->revoke_canceller();
}
}

void operator=(enable_await_cancellation const&) = delete;

void set_cancellable_promise(cancellable_promise* promise) noexcept
{
m_promise = promise;
}

private:

cancellable_promise* m_promise = nullptr;
};
}

namespace winrt::impl
{
template <typename T>
decltype(auto) get_awaiter(T&& value) noexcept
{
Expand All @@ -149,8 +223,16 @@ namespace winrt::impl
{
decltype(get_awaiter(std::declval<T&&>())) awaitable;

notify_awaiter(T&& awaitable) : awaitable(get_awaiter(static_cast<T&&>(awaitable)))
notify_awaiter(T&& awaitable_arg, cancellable_promise* promise = nullptr) : awaitable(get_awaiter(static_cast<T&&>(awaitable_arg)))
{
if constexpr (std::is_convertible_v<std::remove_reference_t<decltype(awaitable)>&, enable_await_cancellation&>)
{
if (promise)
{
static_cast<enable_await_cancellation&>(awaitable).set_cancellable_promise(promise);
awaitable.enable_cancellation(promise);
}
}
}

bool await_ready()
Expand Down Expand Up @@ -271,34 +353,67 @@ WINRT_EXPORT namespace winrt

[[nodiscard]] inline auto resume_after(Windows::Foundation::TimeSpan duration) noexcept
{
struct awaitable
struct awaitable : enable_await_cancellation
{
explicit awaitable(Windows::Foundation::TimeSpan duration) noexcept :
m_duration(duration)
{
}

void enable_cancellation(cancellable_promise* promise)
{
promise->set_canceller([](void* context)
{
auto that = static_cast<awaitable*>(context);
if (that->m_state.exchange(state::canceled, std::memory_order_acquire) == state::pending)
{
that->fire_immediately();
}
}, this);
}

bool await_ready() const noexcept
{
return m_duration.count() <= 0;
}

void await_suspend(std::experimental::coroutine_handle<> handle)
{
m_timer.attach(check_pointer(WINRT_IMPL_CreateThreadpoolTimer(callback, handle.address(), nullptr)));
m_handle = handle;
m_timer.attach(check_pointer(WINRT_IMPL_CreateThreadpoolTimer(callback, this, nullptr)));
int64_t relative_count = -m_duration.count();
WINRT_IMPL_SetThreadpoolTimer(m_timer.get(), &relative_count, 0, 0);
WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), &relative_count, 0, 0);

state expected = state::idle;
if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release))
{
fire_immediately();
}
}

void await_resume() const noexcept
void await_resume()
{
if (m_state.exchange(state::idle, std::memory_order_relaxed) == state::canceled)
{
throw hresult_canceled();
}
}

private:

void fire_immediately() noexcept
{
if (WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), nullptr, 0, 0))
{
int64_t now = 0;
WINRT_IMPL_SetThreadpoolTimerEx(m_timer.get(), &now, 0, 0);
}
}

static void __stdcall callback(void*, void* context, void*) noexcept
{
std::experimental::coroutine_handle<>::from_address(context)();
auto that = reinterpret_cast<awaitable*>(context);
that->m_handle();
}

struct timer_traits
Expand All @@ -316,8 +431,12 @@ WINRT_EXPORT namespace winrt
}
};

enum class state { idle, pending, canceled };

handle_type<timer_traits> m_timer;
Windows::Foundation::TimeSpan m_duration;
std::experimental::coroutine_handle<> m_handle;
std::atomic<state> m_state{ state::idle };
};

return awaitable{ duration };
Expand All @@ -332,13 +451,25 @@ WINRT_EXPORT namespace winrt

[[nodiscard]] inline auto resume_on_signal(void* handle, Windows::Foundation::TimeSpan timeout = {}) noexcept
{
struct awaitable
struct awaitable : enable_await_cancellation
{
awaitable(void* handle, Windows::Foundation::TimeSpan timeout) noexcept :
m_timeout(timeout),
m_handle(handle)
{}

void enable_cancellation(cancellable_promise* promise)
{
promise->set_canceller([](void* context)
{
auto that = static_cast<awaitable*>(context);
if (that->m_state.exchange(state::canceled, std::memory_order_acquire) == state::pending)
{
that->fire_immediately();
}
}, this);
}

bool await_ready() const noexcept
{
return WINRT_IMPL_WaitForSingleObject(m_handle, 0) == 0;
Expand All @@ -350,16 +481,35 @@ WINRT_EXPORT namespace winrt
m_wait.attach(check_pointer(WINRT_IMPL_CreateThreadpoolWait(callback, this, nullptr)));
int64_t relative_count = -m_timeout.count();
int64_t* file_time = relative_count != 0 ? &relative_count : nullptr;
WINRT_IMPL_SetThreadpoolWait(m_wait.get(), m_handle, file_time);
WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), m_handle, file_time, nullptr);

state expected = state::idle;
if (!m_state.compare_exchange_strong(expected, state::pending, std::memory_order_release))
{
fire_immediately();
}
}

bool await_resume() const noexcept
bool await_resume()
{
if (m_state.exchange(state::idle, std::memory_order_relaxed) == state::canceled)
{
throw hresult_canceled();
}
return m_result == 0;
}

private:

void fire_immediately() noexcept
{
if (WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), nullptr, nullptr, nullptr))
{
int64_t now = 0;
WINRT_IMPL_SetThreadpoolWaitEx(m_wait.get(), WINRT_IMPL_GetCurrentProcess(), &now, nullptr);
}
}

static void __stdcall callback(void*, void* context, void*, uint32_t result) noexcept
{
auto that = static_cast<awaitable*>(context);
Expand All @@ -382,11 +532,14 @@ WINRT_EXPORT namespace winrt
}
};

enum class state { idle, pending, canceled };

handle_type<wait_traits> m_wait;
Windows::Foundation::TimeSpan m_timeout;
void* m_handle;
uint32_t m_result{};
std::experimental::coroutine_handle<> m_resume{ nullptr };
std::atomic<state> m_state{ state::idle };
};

return awaitable{ handle, timeout };
Expand Down
8 changes: 4 additions & 4 deletions strings/base_extern.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ extern "C"

int32_t __stdcall WINRT_IMPL_TrySubmitThreadpoolCallback(void(__stdcall *callback)(void*, void* context), void* context, void*) noexcept;
winrt::impl::ptp_timer __stdcall WINRT_IMPL_CreateThreadpoolTimer(void(__stdcall *callback)(void*, void* context, void*), void* context, void*) noexcept;
void __stdcall WINRT_IMPL_SetThreadpoolTimer(winrt::impl::ptp_timer timer, void* time, uint32_t period, uint32_t window) noexcept;
int32_t __stdcall WINRT_IMPL_SetThreadpoolTimerEx(winrt::impl::ptp_timer timer, void* time, uint32_t period, uint32_t window) noexcept;
void __stdcall WINRT_IMPL_CloseThreadpoolTimer(winrt::impl::ptp_timer timer) noexcept;
winrt::impl::ptp_wait __stdcall WINRT_IMPL_CreateThreadpoolWait(void(__stdcall *callback)(void*, void* context, void*, uint32_t result), void* context, void*) noexcept;
void __stdcall WINRT_IMPL_SetThreadpoolWait(winrt::impl::ptp_wait wait, void* handle, void* timeout) noexcept;
int32_t __stdcall WINRT_IMPL_SetThreadpoolWaitEx(winrt::impl::ptp_wait wait, void* handle, void* timeout, void* reserved) noexcept;
void __stdcall WINRT_IMPL_CloseThreadpoolWait(winrt::impl::ptp_wait wait) noexcept;
winrt::impl::ptp_io __stdcall WINRT_IMPL_CreateThreadpoolIo(void* object, void(__stdcall *callback)(void*, void* context, void* overlapped, uint32_t result, std::size_t bytes, void*) noexcept, void* context, void*) noexcept;
void __stdcall WINRT_IMPL_StartThreadpoolIo(winrt::impl::ptp_io io) noexcept;
Expand Down Expand Up @@ -147,10 +147,10 @@ WINRT_IMPL_LINK(WaitForSingleObject, 8)

WINRT_IMPL_LINK(TrySubmitThreadpoolCallback, 12)
WINRT_IMPL_LINK(CreateThreadpoolTimer, 12)
WINRT_IMPL_LINK(SetThreadpoolTimer, 16)
WINRT_IMPL_LINK(SetThreadpoolTimerEx, 16)
WINRT_IMPL_LINK(CloseThreadpoolTimer, 4)
WINRT_IMPL_LINK(CreateThreadpoolWait, 12)
WINRT_IMPL_LINK(SetThreadpoolWait, 12)
WINRT_IMPL_LINK(SetThreadpoolWaitEx, 16)
WINRT_IMPL_LINK(CloseThreadpoolWait, 4)
WINRT_IMPL_LINK(CreateThreadpoolIo, 16)
WINRT_IMPL_LINK(StartThreadpoolIo, 4)
Expand Down
1 change: 1 addition & 0 deletions strings/base_includes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <stdexcept>
#include <string_view>
#include <string>
#include <thread>
#include <tuple>
#include <type_traits>
#include <unordered_map>
Expand Down
Loading

0 comments on commit c20a75b

Please sign in to comment.