Skip to content

Commit 527e52e

Browse files
committed
Extend type checker hack for @_unsafeInheritExecutor functions to methods
With the re-introduction of `@_unsafeInheritExecutor` for `TaskLocal.withValue`, we need to extend the type checker trick with `_unsafeInheritExecutor_`-prefixed functions to work with methods. Do so to make `TaskLocal.withValue` actually work this way.
1 parent 0232182 commit 527e52e

File tree

4 files changed

+81
-3
lines changed

4 files changed

+81
-3
lines changed

lib/Sema/ConstraintSystem.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,13 @@ LookupResult &ConstraintSystem::lookupMember(Type base, DeclNameRef name,
292292
result = TypeChecker::lookupMember(DC, base, name, loc,
293293
defaultMemberLookupOptions);
294294

295+
// If we are in an @_unsafeInheritExecutor context, swap out
296+
// declarations for their _unsafeInheritExecutor_ counterparts if they
297+
// exist.
298+
if (enclosingUnsafeInheritsExecutor(DC)) {
299+
introduceUnsafeInheritExecutorReplacements(DC, base, loc, *result);
300+
}
301+
295302
// If we aren't performing dynamic lookup, we're done.
296303
if (!*result || !base->isAnyObject())
297304
return *result;

lib/Sema/TypeCheckConcurrency.cpp

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2115,17 +2115,18 @@ void swift::introduceUnsafeInheritExecutorReplacements(
21152115
// Make sure at least some of the entries are functions in the _Concurrency
21162116
// module.
21172117
ModuleDecl *concurrencyModule = nullptr;
2118+
DeclBaseName baseName;
21182119
for (auto decl: decls) {
21192120
if (isReplaceable(decl)) {
21202121
concurrencyModule = decl->getDeclContext()->getParentModule();
2122+
baseName = decl->getName().getBaseName();
21212123
break;
21222124
}
21232125
}
21242126
if (!concurrencyModule)
21252127
return;
21262128

2127-
// Dig out the name.
2128-
auto baseName = decls.front()->getName().getBaseName();
2129+
// Ignore anything with a special name.
21292130
if (baseName.isSpecial())
21302131
return;
21312132

@@ -2153,6 +2154,60 @@ void swift::introduceUnsafeInheritExecutorReplacements(
21532154
}
21542155
}
21552156

2157+
void swift::introduceUnsafeInheritExecutorReplacements(
2158+
const DeclContext *dc, Type base, SourceLoc loc, LookupResult &lookup) {
2159+
if (lookup.empty())
2160+
return;
2161+
2162+
auto baseNominal = base->getAnyNominal();
2163+
if (!baseNominal || !inConcurrencyModule(baseNominal))
2164+
return;
2165+
2166+
auto isReplaceable = [&](ValueDecl *decl) {
2167+
return isa<FuncDecl>(decl) && inConcurrencyModule(decl->getDeclContext());
2168+
};
2169+
2170+
// Make sure at least some of the entries are functions in the _Concurrency
2171+
// module.
2172+
ModuleDecl *concurrencyModule = nullptr;
2173+
DeclBaseName baseName;
2174+
for (auto &result: lookup) {
2175+
auto decl = result.getValueDecl();
2176+
if (isReplaceable(decl)) {
2177+
concurrencyModule = decl->getDeclContext()->getParentModule();
2178+
baseName = decl->getBaseName();
2179+
break;
2180+
}
2181+
}
2182+
if (!concurrencyModule)
2183+
return;
2184+
2185+
// Ignore anything with a special name.
2186+
if (baseName.isSpecial())
2187+
return;
2188+
2189+
// Look for entities with the _unsafeInheritExecutor_ prefix on the name.
2190+
ASTContext &ctx = base->getASTContext();
2191+
Identifier newIdentifier = ctx.getIdentifier(
2192+
("_unsafeInheritExecutor_" + baseName.getIdentifier().str()).str());
2193+
2194+
LookupResult replacementLookup = TypeChecker::lookupMember(
2195+
const_cast<DeclContext *>(dc), base, DeclNameRef(newIdentifier), loc,
2196+
defaultMemberLookupOptions);
2197+
if (replacementLookup.innerResults().empty())
2198+
return;
2199+
2200+
// Drop all of the _Concurrency entries in favor of the ones found by this
2201+
// lookup.
2202+
lookup.filter([&](const LookupResultEntry &entry, bool) {
2203+
return !isReplaceable(entry.getValueDecl());
2204+
});
2205+
2206+
for (const auto &entry: replacementLookup.innerResults()) {
2207+
lookup.add(entry, /*isOuter=*/false);
2208+
}
2209+
}
2210+
21562211
/// Check if it is safe for the \c globalActor qualifier to be removed from
21572212
/// \c ty, when the function value of that type is isolated to that actor.
21582213
///

lib/Sema/TypeCheckConcurrency.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class EnumElementDecl;
4444
class Expr;
4545
class FuncDecl;
4646
class Initializer;
47+
class LookupResult;
4748
class PatternBindingDecl;
4849
class ProtocolConformance;
4950
class TopLevelCodeDecl;
@@ -668,6 +669,18 @@ void replaceUnsafeInheritExecutorWithDefaultedIsolationParam(
668669
void introduceUnsafeInheritExecutorReplacements(
669670
const DeclContext *dc, SourceLoc loc, SmallVectorImpl<ValueDecl *> &decls);
670671

672+
/// Replace any functions in this list that were found in the _Concurrency
673+
/// module as a member on "base" and have _unsafeInheritExecutor_-prefixed
674+
/// versions with those _unsafeInheritExecutor_-prefixed versions.
675+
///
676+
/// This function is an egregious hack that allows us to introduce the
677+
/// #isolation-based versions of functions into the concurrency library
678+
/// without breaking clients that use @_unsafeInheritExecutor. Since those
679+
/// clients can't use #isolation (it doesn't work with @_unsafeInheritExecutor),
680+
/// we route them to the @_unsafeInheritExecutor versions implicitly.
681+
void introduceUnsafeInheritExecutorReplacements(
682+
const DeclContext *dc, Type base, SourceLoc loc, LookupResult &result);
683+
671684
} // end namespace swift
672685

673686
namespace llvm {

test/Concurrency/unsafe_inherit_executor.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ func unsafeCallerAvoidsNewLoop() async throws {
113113
} onCancel: {
114114
}
115115

116-
TL.$string.withValue("hello") {
116+
await TL.$string.withValue("hello") {
117117
print(TL.string)
118118
}
119+
120+
func operation() async throws -> Int { 7 }
121+
try await TL.$string.withValue("hello", operation: operation)
119122
}

0 commit comments

Comments
 (0)