47
47
#include " swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
48
48
#include " llvm/ADT/APSInt.h"
49
49
#include " llvm/ADT/BreadthFirstIterator.h"
50
+ #include " llvm/ADT/DenseMap.h"
50
51
#include " llvm/ADT/DenseSet.h"
51
52
#include " llvm/ADT/SmallSet.h"
52
53
#include " llvm/Support/CommandLine.h"
@@ -84,6 +85,9 @@ class DifferentiationTransformer {
84
85
// / Context necessary for performing the transformations.
85
86
ADContext context;
86
87
88
+ // / Cache used in getUnwrappedCurryThunkFunction.
89
+ llvm::DenseMap<AbstractFunctionDecl *, SILFunction *> afdToSILFn;
90
+
87
91
// / Promotes the given `differentiable_function` instruction to a valid
88
92
// / `@differentiable` function-typed value.
89
93
SILValue promoteToDifferentiableFunction (DifferentiableFunctionInst *inst,
@@ -96,6 +100,25 @@ class DifferentiationTransformer {
96
100
SILBuilder &builder, SILLocation loc,
97
101
DifferentiationInvoker invoker);
98
102
103
+ // / Emits a reference to a derivative function of `original`, differentiated
104
+ // / with respect to a superset of `desiredIndices`. Returns the `SILValue` for
105
+ // / the derivative function and the actual indices that the derivative
106
+ // / function is with respect to.
107
+ // /
108
+ // / Returns `None` on failure, signifying that a diagnostic has been emitted
109
+ // / using `invoker`.
110
+ std::optional<std::pair<SILValue, AutoDiffConfig>>
111
+ emitDerivativeFunctionReference (
112
+ SILBuilder &builder, const AutoDiffConfig &desiredConfig,
113
+ AutoDiffDerivativeFunctionKind kind, SILValue original,
114
+ DifferentiationInvoker invoker,
115
+ SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc);
116
+
117
+ // / If the given function corresponds to AutoClosureExpr with either
118
+ // / SingleCurryThunk or DoubleCurryThunk kind, get the SILFunction
119
+ // / corresponding to the function being wrapped in the thunk.
120
+ SILFunction *getUnwrappedCurryThunkFunction (SILFunction *originalFn);
121
+
99
122
public:
100
123
// / Construct an `DifferentiationTransformer` for the given module.
101
124
explicit DifferentiationTransformer (SILModuleTransform &transform)
@@ -453,21 +476,63 @@ static SILValue reapplyFunctionConversion(
453
476
llvm_unreachable (" Unhandled function conversion instruction" );
454
477
}
455
478
456
- // / Emits a reference to a derivative function of `original`, differentiated
457
- // / with respect to a superset of `desiredIndices`. Returns the `SILValue` for
458
- // / the derivative function and the actual indices that the derivative function
459
- // / is with respect to.
460
- // /
461
- // / Returns `None` on failure, signifying that a diagnostic has been emitted
462
- // / using `invoker`.
463
- static std::optional<std::pair<SILValue, AutoDiffConfig>>
464
- emitDerivativeFunctionReference (
465
- DifferentiationTransformer &transformer, SILBuilder &builder,
466
- const AutoDiffConfig &desiredConfig, AutoDiffDerivativeFunctionKind kind,
467
- SILValue original, DifferentiationInvoker invoker,
468
- SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
469
- ADContext &context = transformer.getContext ();
479
+ SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction (
480
+ SILFunction *originalFn) {
481
+ auto *autoCE = dyn_cast_or_null<AutoClosureExpr>(
482
+ originalFn->getDeclRef ().getAbstractClosureExpr ());
483
+ if (autoCE == nullptr )
484
+ return nullptr ;
485
+
486
+ auto *ae = dyn_cast_or_null<ApplyExpr>(autoCE->getUnwrappedCurryThunkExpr ());
487
+ if (ae == nullptr )
488
+ return nullptr ;
470
489
490
+ AbstractFunctionDecl *afd = cast<AbstractFunctionDecl>(ae->getCalledValue (
491
+ /* skipFunctionConversions=*/ true ));
492
+ auto silFnIt = afdToSILFn.find (afd);
493
+ if (silFnIt == afdToSILFn.end ()) {
494
+ assert (afdToSILFn.empty () && " Expect all 'afdToSILFn' cache entries to be "
495
+ " filled at once on the first access attempt" );
496
+
497
+ SILModule *module = getTransform ().getModule ();
498
+ for (SILFunction ¤tFunc : module ->getFunctions ()) {
499
+ if (auto *currentAFD =
500
+ currentFunc.getDeclRef ().getAbstractFunctionDecl ()) {
501
+ // Update cache only with AFDs which might be potentially wrapped by a
502
+ // curry thunk. This includes member function references and references
503
+ // to functions having external property wrapper parameters (see
504
+ // ExprRewriter::buildDeclRef). If new use cases of curry thunks appear
505
+ // in future, the assertion after the loop will be a trigger for such
506
+ // cases being unhandled here.
507
+ //
508
+ // FIXME: References to functions having external property wrapper
509
+ // parameters are not handled since we can't now construct a test case
510
+ // for that due to the crash
511
+ // https://github.com/swiftlang/swift/issues/77613
512
+ if (currentAFD->hasCurriedSelf ()) {
513
+ auto [_, wasEmplace] =
514
+ afdToSILFn.try_emplace (currentAFD, ¤tFunc);
515
+ assert (wasEmplace && " Expect all 'afdToSILFn' cache entries to be "
516
+ " filled at once on the first access attempt" );
517
+ }
518
+ }
519
+ }
520
+
521
+ silFnIt = afdToSILFn.find (afd);
522
+ assert (silFnIt != afdToSILFn.end () &&
523
+ " Expect present curry thunk to SIL function mapping after "
524
+ " 'afdToSILFn' cache fill" );
525
+ }
526
+
527
+ return silFnIt->second ;
528
+ }
529
+
530
+ std::optional<std::pair<SILValue, AutoDiffConfig>>
531
+ DifferentiationTransformer::emitDerivativeFunctionReference (
532
+ SILBuilder &builder, const AutoDiffConfig &desiredConfig,
533
+ AutoDiffDerivativeFunctionKind kind, SILValue original,
534
+ DifferentiationInvoker invoker,
535
+ SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
471
536
// If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
472
537
// matches the given kind and desired differentiation parameter indices,
473
538
// simply extract the derivative function of its function operand, retain the
@@ -610,26 +675,36 @@ emitDerivativeFunctionReference(
610
675
DifferentiabilityKind::Reverse, desiredParameterIndices,
611
676
desiredResultIndices, derivativeConstrainedGenSig, /* jvp*/ nullptr ,
612
677
/* vjp*/ nullptr , /* isSerialized*/ false );
613
- if (transformer. canonicalizeDifferentiabilityWitness (
614
- minimalWitness, invoker, IsNotSerialized))
678
+ if (canonicalizeDifferentiabilityWitness (minimalWitness, invoker,
679
+ IsNotSerialized))
615
680
return std::nullopt;
616
681
}
617
682
assert (minimalWitness);
618
- if (original->getFunction ()->isSerialized () &&
619
- !hasPublicVisibility (minimalWitness->getLinkage ())) {
620
- enum { Inlinable = 0 , DefaultArgument = 1 };
621
- unsigned fragileKind = Inlinable;
622
- // FIXME: This is not a very robust way of determining if the function is
623
- // a default argument. Also, we have not exhaustively listed all the kinds
624
- // of fragility.
625
- if (original->getFunction ()->getLinkage () == SILLinkage::PublicNonABI)
626
- fragileKind = DefaultArgument;
627
- context.emitNondifferentiabilityError (
628
- original, invoker, diag::autodiff_private_derivative_from_fragile,
629
- fragileKind,
630
- isa_and_nonnull<AbstractClosureExpr>(
631
- originalFRI->getLoc ().getAsASTNode <Expr>()));
632
- return std::nullopt;
683
+ if (original->getFunction ()->isSerialized ()) {
684
+ // When dealing with curry thunk, look at the function being wrapped
685
+ // inside implicit closure. If it has public visibility, the corresponding
686
+ // differentiability witness also has public visibility. It should be OK
687
+ // for implicit wrapper closure and its witness to have private linkage.
688
+ SILFunction *unwrappedFn = getUnwrappedCurryThunkFunction (originalFn);
689
+ bool isWitnessPublic =
690
+ unwrappedFn == nullptr
691
+ ? hasPublicVisibility (minimalWitness->getLinkage ())
692
+ : hasPublicVisibility (unwrappedFn->getLinkage ());
693
+ if (!isWitnessPublic) {
694
+ enum { Inlinable = 0 , DefaultArgument = 1 };
695
+ unsigned fragileKind = Inlinable;
696
+ // FIXME: This is not a very robust way of determining if the function
697
+ // is a default argument. Also, we have not exhaustively listed all the
698
+ // kinds of fragility.
699
+ if (original->getFunction ()->getLinkage () == SILLinkage::PublicNonABI)
700
+ fragileKind = DefaultArgument;
701
+ context.emitNondifferentiabilityError (
702
+ original, invoker, diag::autodiff_private_derivative_from_fragile,
703
+ fragileKind,
704
+ isa_and_nonnull<AbstractClosureExpr>(
705
+ originalFRI->getLoc ().getAsASTNode <Expr>()));
706
+ return std::nullopt;
707
+ }
633
708
}
634
709
// TODO(TF-482): Move generic requirement checking logic to
635
710
// `getExactDifferentiabilityWitness` and
@@ -1121,8 +1196,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
1121
1196
for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP,
1122
1197
AutoDiffDerivativeFunctionKind::VJP}) {
1123
1198
auto derivativeFnAndIndices = emitDerivativeFunctionReference (
1124
- * this , builder, desiredConfig, derivativeFnKind, origFnOperand,
1125
- invoker, newBuffersToDealloc);
1199
+ builder, desiredConfig, derivativeFnKind, origFnOperand, invoker ,
1200
+ newBuffersToDealloc);
1126
1201
// Show an error at the operator, highlight the argument, and show a note
1127
1202
// at the definition site of the argument.
1128
1203
if (!derivativeFnAndIndices)
0 commit comments