Skip to content

Commit d24f22e

Browse files
committed
[ConstraintSystem] Refactor spots that create ApplicationFunction directly
This is required to make it easier to add "use dc" to the application.
1 parent f4101c4 commit d24f22e

File tree

7 files changed

+83
-44
lines changed

7 files changed

+83
-44
lines changed

include/swift/Sema/Constraint.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
582582
static Constraint *createApplicableFunction(
583583
ConstraintSystem &cs, Type argumentFnType, Type calleeType,
584584
std::optional<TrailingClosureMatching> trailingClosureMatching,
585-
ConstraintLocator *locator);
585+
DeclContext *useDC, ConstraintLocator *locator);
586586

587587
static Constraint *createSyntacticElement(ConstraintSystem &cs,
588588
ASTNode node,

include/swift/Sema/ConstraintSystem.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3565,6 +3565,11 @@ class ConstraintSystem {
35653565
void addConstraint(Requirement req, ConstraintLocatorBuilder locator,
35663566
bool isFavored = false);
35673567

3568+
void addApplicationConstraint(
3569+
FunctionType *appliedFn, Type calleeType,
3570+
std::optional<TrailingClosureMatching> trailingClosureMatching,
3571+
DeclContext *useDC, ConstraintLocatorBuilder locator);
3572+
35683573
/// Add the appropriate constraint for a contextual conversion.
35693574
void addContextualConversionConstraint(Expr *expr, Type conversionType,
35703575
ContextualTypePurpose purpose,
@@ -4962,6 +4967,7 @@ class ConstraintSystem {
49624967
SolutionKind simplifyApplicableFnConstraint(
49634968
Type type1, Type type2,
49644969
std::optional<TrailingClosureMatching> trailingClosureMatching,
4970+
DeclContext *useDC,
49654971
TypeMatchOptions flags, ConstraintLocatorBuilder locator);
49664972

49674973
/// Attempt to simplify the DynamicCallableApplicableFunction constraint.

lib/Sema/CSGen.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -395,10 +395,9 @@ namespace {
395395

396396
// Add the constraint that the index expression's type be convertible
397397
// to the input type of the subscript operator.
398-
CS.addConstraint(ConstraintKind::ApplicableFunction,
399-
FunctionType::get(params, outputTy),
400-
memberTy,
401-
fnLocator);
398+
CS.addApplicationConstraint(FunctionType::get(params, outputTy), memberTy,
399+
/*trailingClosureMatching=*/std::nullopt,
400+
CurDC, fnLocator);
402401

403402
Type fixedOutputType =
404403
CS.getFixedTypeRecursive(outputTy, /*wantRValue=*/false);
@@ -721,9 +720,9 @@ namespace {
721720
CS.getConstraintLocator(expr, ConstraintLocator::FunctionResult),
722721
TVO_CanBindToNoEscape);
723722

724-
CS.addConstraint(ConstraintKind::ApplicableFunction,
725-
FunctionType::get(params, resultType), memberType,
726-
fnLoc);
723+
CS.addApplicationConstraint(
724+
FunctionType::get(params, resultType), memberType,
725+
/*trailingClosureMatching=*/std::nullopt, CurDC, fnLoc);
727726

728727
if (constr->isFailable())
729728
return OptionalType::get(witnessType);
@@ -2545,10 +2544,10 @@ namespace {
25452544
SmallVector<AnyFunctionType::Param, 8> params;
25462545
getMatchingParams(expr->getArgs(), params);
25472546

2548-
CS.addConstraint(ConstraintKind::ApplicableFunction,
2549-
FunctionType::get(params, resultType, extInfo),
2550-
CS.getType(fnExpr),
2551-
CS.getConstraintLocator(expr, ConstraintLocator::ApplyFunction));
2547+
CS.addApplicationConstraint(
2548+
FunctionType::get(params, resultType, extInfo), CS.getType(fnExpr),
2549+
/*trailingClosureMatching=*/std::nullopt, CurDC,
2550+
CS.getConstraintLocator(expr, ConstraintLocator::ApplyFunction));
25522551

25532552
// If we ended up resolving the result type variable to a concrete type,
25542553
// set it as the favored type for this expression.
@@ -3296,10 +3295,11 @@ namespace {
32963295
CS.getConstraintLocator(expr, ConstraintLocator::FunctionResult),
32973296
TVO_CanBindToNoEscape);
32983297

3299-
CS.addConstraint(
3300-
ConstraintKind::ApplicableFunction,
3298+
CS.addApplicationConstraint(
33013299
FunctionType::get(params, resultType),
33023300
macroRefType,
3301+
/*trailingClosureMatching=*/std::nullopt,
3302+
CurDC,
33033303
CS.getConstraintLocator(
33043304
expr, ConstraintLocator::ApplyFunction));
33053305

lib/Sema/CSSimplify.cpp

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8364,8 +8364,9 @@ ConstraintSystem::simplifyConstructionConstraint(
83648364
paramTypeVar, locator);
83658365
}
83668366

8367-
addConstraint(ConstraintKind::ApplicableFunction, fnType, memberType,
8368-
fnLocator);
8367+
addApplicationConstraint(fnType, memberType,
8368+
/*trailingClosureMatching=*/std::nullopt, useDC,
8369+
fnLocator);
83698370

83708371
return SolutionKind::Solved;
83718372
}
@@ -13108,6 +13109,7 @@ createImplicitRootForCallAsFunction(ConstraintSystem &cs, Type refType,
1310813109
ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1310913110
Type type1, Type type2,
1311013111
std::optional<TrailingClosureMatching> trailingClosureMatching,
13112+
DeclContext *useDC,
1311113113
TypeMatchOptions flags, ConstraintLocatorBuilder locator) {
1311213114
auto &ctx = getASTContext();
1311313115

@@ -13171,7 +13173,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1317113173
auto formUnsolved = [&](bool activate = false) {
1317213174
if (flags.contains(TMF_GenerateConstraints)) {
1317313175
auto *application = Constraint::createApplicableFunction(
13174-
*this, type1, type2, trailingClosureMatching,
13176+
*this, type1, type2, trailingClosureMatching, useDC,
1317513177
getConstraintLocator(locator));
1317613178

1317713179
addUnsolvedConstraint(application);
@@ -13232,8 +13234,8 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1323213234
/*outerAlternatives*/ {}, memberLoc);
1323313235
// Add new applicable function constraint based on the member type
1323413236
// variable.
13235-
addConstraint(ConstraintKind::ApplicableFunction, func1, memberTy,
13236-
locator);
13237+
addApplicationConstraint(func1, memberTy, trailingClosureMatching, useDC,
13238+
locator);
1323713239
return SolutionKind::Solved;
1323813240
}
1323913241

@@ -13348,9 +13350,9 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1334813350
// Form an unsolved constraint to apply trailing closures to a
1334913351
// callable type produced by `.init`. This constraint would become
1335013352
// active when `callableType` is bound.
13351-
addUnsolvedConstraint(Constraint::create(
13352-
*this, ConstraintKind::ApplicableFunction, callAsFunctionArguments,
13353-
callableType,
13353+
addUnsolvedConstraint(Constraint::createApplicableFunction(
13354+
*this, callAsFunctionArguments, callableType,
13355+
trailingClosureMatching, useDC,
1335413356
getConstraintLocator(implicitRef,
1335513357
ConstraintLocator::ApplyFunction)));
1335613358
break;
@@ -13368,12 +13370,13 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1336813370

1336913371
auto applyLocator = getConstraintLocator(locator);
1337013372
auto forwardConstraint = Constraint::createApplicableFunction(
13371-
*this, type1, type2, TrailingClosureMatching::Forward, applyLocator);
13373+
*this, type1, type2, TrailingClosureMatching::Forward, useDC,
13374+
applyLocator);
1337213375
auto backwardConstraint = Constraint::createApplicableFunction(
13373-
*this, type1, type2, TrailingClosureMatching::Backward,
13376+
*this, type1, type2, TrailingClosureMatching::Backward, useDC,
1337413377
applyLocator);
13375-
addDisjunctionConstraint(
13376-
{ forwardConstraint, backwardConstraint}, applyLocator);
13378+
addDisjunctionConstraint({forwardConstraint, backwardConstraint},
13379+
applyLocator);
1337713380
break;
1337813381
}
1337913382

@@ -13450,7 +13453,7 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyApplicableFnConstraint(
1345013453
// Construct the instance from the input arguments.
1345113454
auto simplified = simplifyConstructionConstraint(
1345213455
instance2, func1, subflags,
13453-
/*FIXME?*/ DC, FunctionRefInfo::singleBaseNameApply(),
13456+
useDC, FunctionRefInfo::singleBaseNameApply(),
1345413457
getConstraintLocator(outerLocator));
1345513458

1345613459
// Record any fixes we attempted to get to the correct solution.
@@ -15673,15 +15676,6 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
1567315676
case ConstraintKind::BridgingConversion:
1567415677
return simplifyBridgingConstraint(first, second, subflags, locator);
1567515678

15676-
case ConstraintKind::ApplicableFunction: {
15677-
// First try to simplify the overload set for the function being applied.
15678-
if (simplifyAppliedOverloads(second, first->castTo<FunctionType>(),
15679-
locator)) {
15680-
return SolutionKind::Error;
15681-
}
15682-
return simplifyApplicableFnConstraint(first, second, std::nullopt, subflags,
15683-
locator);
15684-
}
1568515679
case ConstraintKind::DynamicCallableApplicableFunction:
1568615680
return simplifyDynamicCallableApplicableFnConstraint(first, second,
1568715681
subflags, locator);
@@ -15762,6 +15756,7 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first,
1576215756
case ConstraintKind::KeyPathApplication:
1576315757
case ConstraintKind::FallbackType:
1576415758
case ConstraintKind::SyntacticElement:
15759+
case ConstraintKind::ApplicableFunction:
1576515760
llvm_unreachable("Use the correct addConstraint()");
1576615761
}
1576715762

@@ -15990,6 +15985,41 @@ void ConstraintSystem::addConstraint(ConstraintKind kind, Type first,
1599015985
}
1599115986
}
1599215987

15988+
void ConstraintSystem::addApplicationConstraint(
15989+
FunctionType *appliedFn, Type calleeType,
15990+
std::optional<TrailingClosureMatching> trailingClosureMatching,
15991+
DeclContext *useDC,
15992+
ConstraintLocatorBuilder locator) {
15993+
auto recordFailure = [&]() {
15994+
if (shouldRecordFailedConstraint()) {
15995+
auto *c = Constraint::createApplicableFunction(
15996+
*this, appliedFn, calleeType, trailingClosureMatching, useDC,
15997+
getConstraintLocator(locator));
15998+
recordFailedConstraint(c);
15999+
}
16000+
};
16001+
16002+
// First try to simplify the overload set for the function being applied.
16003+
if (simplifyAppliedOverloads(calleeType, appliedFn, locator)) {
16004+
recordFailure();
16005+
return;
16006+
}
16007+
16008+
switch (simplifyApplicableFnConstraint(appliedFn, calleeType,
16009+
trailingClosureMatching, useDC,
16010+
TMF_GenerateConstraints, locator)) {
16011+
case SolutionKind::Error:
16012+
recordFailure();
16013+
break;
16014+
16015+
case SolutionKind::Unsolved:
16016+
llvm_unreachable("should have generated constraints");
16017+
16018+
case SolutionKind::Solved:
16019+
return;
16020+
}
16021+
}
16022+
1599316023
void ConstraintSystem::addContextualConversionConstraint(
1599416024
Expr *expr, Type conversionType, ContextualTypePurpose purpose,
1599516025
ConstraintLocator *locator) {
@@ -16175,7 +16205,8 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) {
1617516205
case ConstraintKind::ApplicableFunction:
1617616206
return simplifyApplicableFnConstraint(
1617716207
constraint.getFirstType(), constraint.getSecondType(),
16178-
constraint.getTrailingClosureMatching(), std::nullopt,
16208+
constraint.getTrailingClosureMatching(),
16209+
/*FIXME*/DC, /*flags=*/std::nullopt,
1617916210
constraint.getLocator());
1618016211

1618116212
case ConstraintKind::DynamicCallableApplicableFunction:

lib/Sema/Constraint.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,7 @@ Constraint *Constraint::createConjunction(
988988
Constraint *Constraint::createApplicableFunction(
989989
ConstraintSystem &cs, Type argumentFnType, Type calleeType,
990990
std::optional<TrailingClosureMatching> trailingClosureMatching,
991-
ConstraintLocator *locator) {
991+
DeclContext *useDC, ConstraintLocator *locator) {
992992
// Collect type variables.
993993
SmallPtrSet<TypeVariableType *, 4> typeVars;
994994
if (argumentFnType->hasTypeVariable())

lib/Sema/TypeOfReference.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,8 +1933,9 @@ void ConstraintSystem::bindOverloadType(
19331933
{FunctionType::Param(argTy, ctx.Id_dynamicMember)}, resultTy);
19341934

19351935
ConstraintLocatorBuilder builder(callLoc);
1936-
addConstraint(ConstraintKind::ApplicableFunction, callerTy, fnTy,
1937-
builder.withPathElement(ConstraintLocator::ApplyFunction));
1936+
addApplicationConstraint(
1937+
callerTy, fnTy, /*trailingClosureMatching=*/std::nullopt, useDC,
1938+
builder.withPathElement(ConstraintLocator::ApplyFunction));
19381939

19391940
if (isExpr<KeyPathExpr>(locator->getAnchor())) {
19401941
auto paramTy = fnTy->getParams()[0].getParameterType();
@@ -2088,8 +2089,9 @@ void ConstraintSystem::bindOverloadType(
20882089
// Add a constraint for the inner application that uses the args of the
20892090
// original call-site, and a fresh type var result equal to the leaf type.
20902091
ConstraintLocatorBuilder kpLocBuilder(keyPathLoc);
2091-
addConstraint(
2092-
ConstraintKind::ApplicableFunction, adjustedFnTy, memberTy,
2092+
addApplicationConstraint(
2093+
adjustedFnTy, memberTy, /*trailingClosureMatching=*/std::nullopt,
2094+
useDC,
20932095
kpLocBuilder.withPathElement(ConstraintLocator::ApplyFunction));
20942096

20952097
addConstraint(ConstraintKind::Equal, subscriptResultTy, leafTy,

unittests/Sema/ConstraintSimplificationTests.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ TEST_F(SemaTest, TestTrailingClosureMatchRecordingForIdenticalFunctions) {
3030
auto func = FunctionType::get(
3131
{FunctionType::Param(intType), FunctionType::Param(intType)}, floatType);
3232

33-
cs.addConstraint(
34-
ConstraintKind::ApplicableFunction, func, func,
33+
cs.addApplicationConstraint(
34+
func, func, /*trailingClosureMatching=*/std::nullopt, DC,
3535
cs.getConstraintLocator({}, ConstraintLocator::ApplyFunction));
3636

3737
SmallVector<Solution, 2> solutions;

0 commit comments

Comments
 (0)