Skip to content

Commit

Permalink
Resuming neutral context from STA should force background thread (#662)
Browse files Browse the repository at this point in the history
  • Loading branch information
oldnewthing authored Jun 16, 2020
1 parent 8c0832f commit a0b1889
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 36 deletions.
4 changes: 2 additions & 2 deletions strings/base_coroutine_foundation.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace winrt::impl
{
// Note: A blocking wait on the UI thread for an asynchronous operation can cause a deadlock.
// See https://docs.microsoft.com/windows/uwp/cpp-and-winrt-apis/concurrency#block-the-calling-thread
WINRT_ASSERT(!is_sta());
WINRT_ASSERT(!is_sta_thread());
}

template <typename T, typename H>
Expand Down Expand Up @@ -119,7 +119,7 @@ namespace winrt::impl

private:
std::experimental::coroutine_handle<> m_handle;
com_ptr<IContextCallback> m_context = apartment_context();
resume_apartment_context m_context;

void Complete()
{
Expand Down
99 changes: 72 additions & 27 deletions strings/base_coroutine_threadpool.h
Original file line number Diff line number Diff line change
@@ -1,63 +1,108 @@

namespace winrt::impl
{
inline auto submit_threadpool_callback(void(__stdcall* callback)(void*, void* context), void* context)
{
if (!WINRT_IMPL_TrySubmitThreadpoolCallback(callback, context, nullptr))
{
throw_last_error();
}
}

inline void __stdcall resume_background_callback(void*, void* context) noexcept
{
std::experimental::coroutine_handle<>::from_address(context)();
};

inline auto resume_background(std::experimental::coroutine_handle<> handle)
{
if (!WINRT_IMPL_TrySubmitThreadpoolCallback(resume_background_callback, handle.address(), nullptr))
{
throw_last_error();
}
submit_threadpool_callback(resume_background_callback, handle.address());
}

inline bool is_sta() noexcept
inline std::pair<int32_t, int32_t> get_apartment_type() noexcept
{
int32_t aptType;
int32_t aptTypeQualifier;
return (0 == WINRT_IMPL_CoGetApartmentType(&aptType, &aptTypeQualifier)) && ((aptType == 0 /*APTTYPE_STA*/) || (aptType == 3 /*APTTYPE_MAINSTA*/));
if (0 == WINRT_IMPL_CoGetApartmentType(&aptType, &aptTypeQualifier))
{
return { aptType, aptTypeQualifier };
}
else
{
return { 1 /* APTTYPE_MTA */, 1 /* APTTYPEQUALIFIER_IMPLICIT_MTA */ };
}
}

inline bool requires_apartment_context() noexcept
inline bool is_sta_thread() noexcept
{
int32_t aptType;
int32_t aptTypeQualifier;
return (0 == WINRT_IMPL_CoGetApartmentType(&aptType, &aptTypeQualifier)) && ((aptType == 0 /*APTTYPE_STA*/) || (aptType == 2 /*APTTYPE_NA*/) || (aptType == 3 /*APTTYPE_MAINSTA*/));
auto type = get_apartment_type();
switch (type.first)
{
case 0: /* APTTYPE_STA */
case 3: /* APTTYPE_MAINSTA */
return true;
case 2: /* APTTYPE_NA */
return type.second == 3 /* APTTYPEQUALIFIER_NA_ON_STA */ ||
type.second == 5 /* APTTYPEQUALIFIER_NA_ON_MAINSTA */;
}
return false;
}

inline auto apartment_context()
struct resume_apartment_context
{
return requires_apartment_context() ? capture<IContextCallback>(WINRT_IMPL_CoGetObjectContext) : nullptr;
}
com_ptr<IContextCallback> m_context = try_capture<IContextCallback>(WINRT_IMPL_CoGetObjectContext);
int32_t m_context_type = get_apartment_type().first;
};

inline int32_t __stdcall resume_apartment_callback(com_callback_args* args) noexcept
{
std::experimental::coroutine_handle<>::from_address(args->data)();
return 0;
};

inline auto resume_apartment(com_ptr<IContextCallback> const& context, std::experimental::coroutine_handle<> handle)
inline void resume_apartment_sync(com_ptr<IContextCallback> const& context, std::experimental::coroutine_handle<> handle)
{
com_callback_args args{};
args.data = handle.address();

check_hresult(context->ContextCallback(resume_apartment_callback, &args, guid_of<ICallbackWithNoReentrancyToApplicationSTA>(), 5, nullptr));
}

inline void resume_apartment_on_threadpool(com_ptr<IContextCallback> const& context, std::experimental::coroutine_handle<> handle)
{
if (context)
struct threadpool_resume
{
com_callback_args args{};
args.data = handle.address();
threadpool_resume(com_ptr<IContextCallback> const& context, std::experimental::coroutine_handle<> handle) :
m_context(context), m_handle(handle) { }
com_ptr<IContextCallback> m_context;
std::experimental::coroutine_handle<> m_handle;
};
auto state = std::make_unique<threadpool_resume>(context, handle);
submit_threadpool_callback([](void*, void* p)
{
std::unique_ptr<threadpool_resume> state{ static_cast<threadpool_resume*>(p) };
resume_apartment_sync(state->m_context, state->m_handle);
}, state.get());
state.release();
}

check_hresult(context->ContextCallback(resume_apartment_callback, &args, guid_of<ICallbackWithNoReentrancyToApplicationSTA>(), 5, nullptr));
inline auto resume_apartment(resume_apartment_context const& context, std::experimental::coroutine_handle<> handle)
{
if ((context.m_context == nullptr) || (context.m_context == try_capture<IContextCallback>(WINRT_IMPL_CoGetObjectContext)))
{
handle();
}
else if (context.m_context_type == 1 /* APTTYPE_MTA */)
{
resume_background(handle);
}
else if ((context.m_context_type == 2 /* APTTYPE_NTA */) && is_sta_thread())
{
resume_apartment_on_threadpool(context.m_context, handle);
}
else
{
if (requires_apartment_context())
{
resume_background(handle);
}
else
{
handle();
}
resume_apartment_sync(context.m_context, handle);
}
}

Expand Down Expand Up @@ -294,7 +339,7 @@ WINRT_EXPORT namespace winrt
impl::resume_apartment(context, handle);
}

com_ptr<impl::IContextCallback> context = impl::apartment_context();
impl::resume_apartment_context context;
};

[[nodiscard]] inline auto resume_after(Windows::Foundation::TimeSpan duration) noexcept
Expand Down
49 changes: 49 additions & 0 deletions test/old_tests/UnitTests/apartment_context.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include "pch.h"
#include "catch.hpp"
#include <ctxtcall.h>

using namespace winrt;
using namespace Windows::Foundation;
using namespace Windows::System;

namespace
{
Expand All @@ -12,9 +14,56 @@ namespace

co_await context;
}

template<typename TLambda>
void InvokeInContext(IContextCallback* context, TLambda&& lambda)
{
ComCallData data;
data.pUserDefined = &lambda;
check_hresult(context->ContextCallback([](ComCallData* data) -> HRESULT
{
auto& lambda = *reinterpret_cast<TLambda*>(data->pUserDefined);
lambda();
return S_OK;
}, &data, IID_ICallbackWithNoReentrancyToApplicationSTA, 5, nullptr));
}

auto get_winrt_apartment_context_for_com_context(com_ptr<::IContextCallback> const& com_context)
{
std::optional<decltype(apartment_context())> context;
InvokeInContext(com_context.get(), [&] {
context = apartment_context();
});
return context.value();
}

bool is_nta_on_mta()
{
APTTYPE type;
APTTYPEQUALIFIER qualifier;
check_hresult(CoGetApartmentType(&type, &qualifier));
return (type == APTTYPE_NA) && (qualifier == APTTYPEQUALIFIER_NA_ON_MTA || qualifier == APTTYPEQUALIFIER_NA_ON_IMPLICIT_MTA);
}

IAsyncAction TestNeutralApartmentContext()
{
auto controller = DispatcherQueueController::CreateOnDedicatedThread();
co_await resume_foreground(controller.DispatcherQueue());

// Entering neutral apartment from STA should resume on explicit background thread.
auto nta = get_winrt_apartment_context_for_com_context(capture<::IContextCallback>(CoGetDefaultContext, APTTYPE_NA));
co_await nta;

REQUIRE(is_nta_on_mta());
}
}

TEST_CASE("apartment_context coverage")
{
Async().get();
}

TEST_CASE("apartment_context nta")
{
TestNeutralApartmentContext().get();
}
22 changes: 15 additions & 7 deletions test/test/await_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ using namespace Windows::System;

namespace
{
bool is_sta()
{
APTTYPE type;
APTTYPEQUALIFIER qualifier;
check_hresult(CoGetApartmentType(&type, &qualifier));
return (type == APTTYPE_STA) || (type == APTTYPE_MAINSTA);
}

static handle signal{ CreateEventW(nullptr, false, false, nullptr) };

IAsyncAction OtherForegroundAsync()
Expand All @@ -29,9 +37,9 @@ namespace

IAsyncAction ForegroundAsync(DispatcherQueue dispatcher)
{
REQUIRE(!impl::is_sta());
REQUIRE(!is_sta());
co_await resume_foreground(dispatcher);
REQUIRE(impl::is_sta());
REQUIRE(is_sta());

// This exercises one STA thread waiting on another thus one context callback
// completing on another.
Expand All @@ -48,9 +56,9 @@ namespace

fire_and_forget SignalFromForeground(DispatcherQueue dispatcher)
{
REQUIRE(!impl::is_sta());
REQUIRE(!is_sta());
co_await resume_foreground(dispatcher);
REQUIRE(impl::is_sta());
REQUIRE(is_sta());

// Previously, this signal was never raised because the foreground thread
// was always blocked waiting for ContextCallback to return.
Expand All @@ -61,19 +69,19 @@ namespace
{
// Switch to a background (MTA) thread.
co_await resume_background();
REQUIRE(!impl::is_sta());
REQUIRE(!is_sta());

// This exercises one MTA thread waiting on another and just completing
// directly without the overhead of a context switch.
co_await OtherBackgroundAsync();
REQUIRE(!impl::is_sta());
REQUIRE(!is_sta());

// Wait for a coroutine that completes on a foreground (STA) thread.
co_await ForegroundAsync(dispatcher);

// Resumption should automatically switch to a background (MTA) thread
// without blocking the Completed handler (which would in turn block the foreground thread).
REQUIRE(!impl::is_sta());
REQUIRE(!is_sta());

// Attempt to signal from the foreground thread under the assumption
// that the foreground thread is not blocked.
Expand Down

0 comments on commit a0b1889

Please sign in to comment.