Skip to content

Commit 1a42a0c

Browse files
authored
[AutoDiff] Support curry thunks differentiation in fragile funcs (swiftlang#77615)
Inside fragile functions, we expect function derivatives to be public, which could be achieved by either explicitly marking the functions as differentiable or having a public explicit derivative defined for them. This is obviously not possible for single and double curry thunks which are a special case of `AutoClosureExpr`. Instead of looking at the thunk itself, we unwrap it and look at the function being wrapped. While the thunk itself and its differentiability witness will not have public visibility, it's not an issue for the case where the function being wrapped (and its witness) have public visibility. Fixes swiftlang#54819 Fixes swiftlang#75776
1 parent d0e4f84 commit 1a42a0c

File tree

5 files changed

+181
-81
lines changed

5 files changed

+181
-81
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 108 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
4848
#include "llvm/ADT/APSInt.h"
4949
#include "llvm/ADT/BreadthFirstIterator.h"
50+
#include "llvm/ADT/DenseMap.h"
5051
#include "llvm/ADT/DenseSet.h"
5152
#include "llvm/ADT/SmallSet.h"
5253
#include "llvm/Support/CommandLine.h"
@@ -84,6 +85,9 @@ class DifferentiationTransformer {
8485
/// Context necessary for performing the transformations.
8586
ADContext context;
8687

88+
/// Cache used in getUnwrappedCurryThunkFunction.
89+
llvm::DenseMap<AbstractFunctionDecl *, SILFunction *> afdToSILFn;
90+
8791
/// Promotes the given `differentiable_function` instruction to a valid
8892
/// `@differentiable` function-typed value.
8993
SILValue promoteToDifferentiableFunction(DifferentiableFunctionInst *inst,
@@ -96,6 +100,25 @@ class DifferentiationTransformer {
96100
SILBuilder &builder, SILLocation loc,
97101
DifferentiationInvoker invoker);
98102

103+
/// Emits a reference to a derivative function of `original`, differentiated
104+
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
105+
/// the derivative function and the actual indices that the derivative
106+
/// function is with respect to.
107+
///
108+
/// Returns `None` on failure, signifying that a diagnostic has been emitted
109+
/// using `invoker`.
110+
std::optional<std::pair<SILValue, AutoDiffConfig>>
111+
emitDerivativeFunctionReference(
112+
SILBuilder &builder, const AutoDiffConfig &desiredConfig,
113+
AutoDiffDerivativeFunctionKind kind, SILValue original,
114+
DifferentiationInvoker invoker,
115+
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc);
116+
117+
/// If the given function corresponds to AutoClosureExpr with either
118+
/// SingleCurryThunk or DoubleCurryThunk kind, get the SILFunction
119+
/// corresponding to the function being wrapped in the thunk.
120+
SILFunction *getUnwrappedCurryThunkFunction(SILFunction *originalFn);
121+
99122
public:
100123
/// Construct an `DifferentiationTransformer` for the given module.
101124
explicit DifferentiationTransformer(SILModuleTransform &transform)
@@ -453,21 +476,63 @@ static SILValue reapplyFunctionConversion(
453476
llvm_unreachable("Unhandled function conversion instruction");
454477
}
455478

456-
/// Emits a reference to a derivative function of `original`, differentiated
457-
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
458-
/// the derivative function and the actual indices that the derivative function
459-
/// is with respect to.
460-
///
461-
/// Returns `None` on failure, signifying that a diagnostic has been emitted
462-
/// using `invoker`.
463-
static std::optional<std::pair<SILValue, AutoDiffConfig>>
464-
emitDerivativeFunctionReference(
465-
DifferentiationTransformer &transformer, SILBuilder &builder,
466-
const AutoDiffConfig &desiredConfig, AutoDiffDerivativeFunctionKind kind,
467-
SILValue original, DifferentiationInvoker invoker,
468-
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
469-
ADContext &context = transformer.getContext();
479+
SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction(
480+
SILFunction *originalFn) {
481+
auto *autoCE = dyn_cast_or_null<AutoClosureExpr>(
482+
originalFn->getDeclRef().getAbstractClosureExpr());
483+
if (autoCE == nullptr)
484+
return nullptr;
485+
486+
auto *ae = dyn_cast_or_null<ApplyExpr>(autoCE->getUnwrappedCurryThunkExpr());
487+
if (ae == nullptr)
488+
return nullptr;
470489

490+
AbstractFunctionDecl *afd = cast<AbstractFunctionDecl>(ae->getCalledValue(
491+
/*skipFunctionConversions=*/true));
492+
auto silFnIt = afdToSILFn.find(afd);
493+
if (silFnIt == afdToSILFn.end()) {
494+
assert(afdToSILFn.empty() && "Expect all 'afdToSILFn' cache entries to be "
495+
"filled at once on the first access attempt");
496+
497+
SILModule *module = getTransform().getModule();
498+
for (SILFunction &currentFunc : module->getFunctions()) {
499+
if (auto *currentAFD =
500+
currentFunc.getDeclRef().getAbstractFunctionDecl()) {
501+
// Update cache only with AFDs which might be potentially wrapped by a
502+
// curry thunk. This includes member function references and references
503+
// to functions having external property wrapper parameters (see
504+
// ExprRewriter::buildDeclRef). If new use cases of curry thunks appear
505+
// in future, the assertion after the loop will be a trigger for such
506+
// cases being unhandled here.
507+
//
508+
// FIXME: References to functions having external property wrapper
509+
// parameters are not handled since we can't now construct a test case
510+
// for that due to the crash
511+
// https://github.com/swiftlang/swift/issues/77613
512+
if (currentAFD->hasCurriedSelf()) {
513+
auto [_, wasEmplace] =
514+
afdToSILFn.try_emplace(currentAFD, &currentFunc);
515+
assert(wasEmplace && "Expect all 'afdToSILFn' cache entries to be "
516+
"filled at once on the first access attempt");
517+
}
518+
}
519+
}
520+
521+
silFnIt = afdToSILFn.find(afd);
522+
assert(silFnIt != afdToSILFn.end() &&
523+
"Expect present curry thunk to SIL function mapping after "
524+
"'afdToSILFn' cache fill");
525+
}
526+
527+
return silFnIt->second;
528+
}
529+
530+
std::optional<std::pair<SILValue, AutoDiffConfig>>
531+
DifferentiationTransformer::emitDerivativeFunctionReference(
532+
SILBuilder &builder, const AutoDiffConfig &desiredConfig,
533+
AutoDiffDerivativeFunctionKind kind, SILValue original,
534+
DifferentiationInvoker invoker,
535+
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
471536
// If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
472537
// matches the given kind and desired differentiation parameter indices,
473538
// simply extract the derivative function of its function operand, retain the
@@ -610,26 +675,36 @@ emitDerivativeFunctionReference(
610675
DifferentiabilityKind::Reverse, desiredParameterIndices,
611676
desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr,
612677
/*vjp*/ nullptr, /*isSerialized*/ false);
613-
if (transformer.canonicalizeDifferentiabilityWitness(
614-
minimalWitness, invoker, IsNotSerialized))
678+
if (canonicalizeDifferentiabilityWitness(minimalWitness, invoker,
679+
IsNotSerialized))
615680
return std::nullopt;
616681
}
617682
assert(minimalWitness);
618-
if (original->getFunction()->isSerialized() &&
619-
!hasPublicVisibility(minimalWitness->getLinkage())) {
620-
enum { Inlinable = 0, DefaultArgument = 1 };
621-
unsigned fragileKind = Inlinable;
622-
// FIXME: This is not a very robust way of determining if the function is
623-
// a default argument. Also, we have not exhaustively listed all the kinds
624-
// of fragility.
625-
if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
626-
fragileKind = DefaultArgument;
627-
context.emitNondifferentiabilityError(
628-
original, invoker, diag::autodiff_private_derivative_from_fragile,
629-
fragileKind,
630-
isa_and_nonnull<AbstractClosureExpr>(
631-
originalFRI->getLoc().getAsASTNode<Expr>()));
632-
return std::nullopt;
683+
if (original->getFunction()->isSerialized()) {
684+
// When dealing with curry thunk, look at the function being wrapped
685+
// inside implicit closure. If it has public visibility, the corresponding
686+
// differentiability witness also has public visibility. It should be OK
687+
// for implicit wrapper closure and its witness to have private linkage.
688+
SILFunction *unwrappedFn = getUnwrappedCurryThunkFunction(originalFn);
689+
bool isWitnessPublic =
690+
unwrappedFn == nullptr
691+
? hasPublicVisibility(minimalWitness->getLinkage())
692+
: hasPublicVisibility(unwrappedFn->getLinkage());
693+
if (!isWitnessPublic) {
694+
enum { Inlinable = 0, DefaultArgument = 1 };
695+
unsigned fragileKind = Inlinable;
696+
// FIXME: This is not a very robust way of determining if the function
697+
// is a default argument. Also, we have not exhaustively listed all the
698+
// kinds of fragility.
699+
if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
700+
fragileKind = DefaultArgument;
701+
context.emitNondifferentiabilityError(
702+
original, invoker, diag::autodiff_private_derivative_from_fragile,
703+
fragileKind,
704+
isa_and_nonnull<AbstractClosureExpr>(
705+
originalFRI->getLoc().getAsASTNode<Expr>()));
706+
return std::nullopt;
707+
}
633708
}
634709
// TODO(TF-482): Move generic requirement checking logic to
635710
// `getExactDifferentiabilityWitness` and
@@ -1121,8 +1196,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
11211196
for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP,
11221197
AutoDiffDerivativeFunctionKind::VJP}) {
11231198
auto derivativeFnAndIndices = emitDerivativeFunctionReference(
1124-
*this, builder, desiredConfig, derivativeFnKind, origFnOperand,
1125-
invoker, newBuffersToDealloc);
1199+
builder, desiredConfig, derivativeFnKind, origFnOperand, invoker,
1200+
newBuffersToDealloc);
11261201
// Show an error at the operator, highlight the argument, and show a note
11271202
// at the definition site of the argument.
11281203
if (!derivativeFnAndIndices)

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -771,32 +771,6 @@ public func fragileDifferentiable(_ x: Float) -> Float {
771771
implicitlyDifferentiableFromFragile(x)
772772
}
773773

774-
775-
// FIXME: Differentiable curry thunk RequirementMachine error (rdar://87429620, https://github.com/apple/swift/issues/54819).
776-
#if false
777-
// TF-1208: Test curry thunk differentiation regression.
778-
public struct Struct_54819<Scalar> {
779-
var x: Scalar
780-
}
781-
extension Struct_54819: Differentiable where Scalar: Differentiable {
782-
@differentiable(reverse)
783-
public static func id(x: Self) -> Self {
784-
return x
785-
}
786-
}
787-
@differentiable(reverse, wrt: x)
788-
public func f_54819<Scalar: Differentiable>(
789-
_ x: Struct_54819<Scalar>,
790-
// NOTE(TF-1208): This diagnostic is unexpected because `Struct_54819.id` is marked `@differentiable`.
791-
// xpected-error @+3 2 {{function is not differentiable}}
792-
// xpected-note @+2 {{differentiated functions in '@inlinable' functions must be marked '@differentiable' or have a public '@derivative'; this is not possible with a closure, make a top-level function instead}}
793-
// xpected-note @+1 {{opaque non-'@differentiable' function is not differentiable}}
794-
reduction: @differentiable(reverse) (Struct_54819<Scalar>) -> Struct_54819<Scalar> = Struct_54819.id
795-
) -> Struct_54819<Scalar> {
796-
reduction(x)
797-
}
798-
#endif
799-
800774
//===----------------------------------------------------------------------===//
801775
// Coroutines (SIL function yields, `begin_apply`) (not yet supported)
802776
//===----------------------------------------------------------------------===//
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s -o /dev/null
2+
3+
import _Differentiation
4+
5+
/// Minimal reproducer for both single and double curry thunk
6+
7+
@inlinable
8+
func caller<Thing: Differentiable & FloatingPoint>(
9+
of f: @differentiable(reverse) (_: Thing) -> Thing
10+
) -> Int where Thing.TangentVector == Thing {
11+
return 42
12+
}
13+
14+
public struct Struct<Thing: Differentiable & FloatingPoint>: Differentiable where Thing.TangentVector == Thing {
15+
@inlinable
16+
static func foo_single() -> Int {
17+
return caller(of: callee_single) // No error expected
18+
}
19+
20+
@inlinable
21+
@differentiable(reverse)
22+
static func callee_single(input: Thing) -> Thing {
23+
return input
24+
}
25+
26+
@inlinable
27+
func foo_double() -> Int {
28+
return caller(of: callee_double) // No error expected
29+
}
30+
31+
@inlinable
32+
@differentiable(reverse)
33+
func callee_double(input: Thing) -> Thing {
34+
return input
35+
}
36+
}
37+
38+
/// Reproducer from https://github.com/swiftlang/swift/issues/75776
39+
40+
public struct Solution2<Thing: Differentiable & FloatingPoint>: Differentiable where Thing.TangentVector == Thing {
41+
@inlinable
42+
public static func optimization() -> Thing {
43+
var initial = Thing.zero
44+
let (_, delta) = valueWithGradient(at: initial, of: simulationWithLoss) // No error expected
45+
initial.move(by: delta)
46+
return initial
47+
}
48+
49+
@inlinable
50+
@differentiable(reverse)
51+
static func simulationWithLoss(input: Thing) -> Thing {
52+
return input // implementation
53+
}
54+
}
55+
56+
/// Reproducer from https://github.com/swiftlang/swift/issues/54819
57+
58+
public struct TF_688_Struct<Scalar> {
59+
var x: Scalar
60+
}
61+
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
62+
@differentiable(reverse)
63+
public static func id(x: Self) -> Self {
64+
return x
65+
}
66+
}
67+
@differentiable(reverse, wrt: x)
68+
public func TF_688<Scalar: Differentiable>(
69+
_ x: TF_688_Struct<Scalar>,
70+
reduction: @differentiable(reverse) (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id // No error expected
71+
) -> TF_688_Struct<Scalar> {
72+
reduction(x)
73+
}

test/AutoDiff/SILOptimizer/generics.swift

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -250,27 +250,6 @@ extension TF_682_Proto where Self : Differentiable,
250250
}
251251
}
252252

253-
// NOTE(TF-1208): Differentiation regression due to changes in curry thunk generation.
254-
/*
255-
// TF-688: Test generic curry thunk cloning.
256-
public struct TF_688_Struct<Scalar> {
257-
var x: Scalar
258-
}
259-
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
260-
@differentiable(reverse)
261-
public static func id(x: Self) -> Self {
262-
return x
263-
}
264-
}
265-
@differentiable(reverse, wrt: x)
266-
public func TF_688<Scalar: Differentiable>(
267-
_ x: TF_688_Struct<Scalar>,
268-
reduction: @differentiable(reverse) (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
269-
) -> TF_688_Struct<Scalar> {
270-
reduction(x)
271-
}
272-
*/
273-
274253
// TF-697: Test generic requirements of generated derivative function.
275254
protocol TF_697_Module: Differentiable {
276255
associatedtype Input

test/AutoDiff/compiler_crashers/rdar87429620-differentiable-curry-thunk-reqmachine.swift renamed to test/AutoDiff/compiler_crashers_fixed/rdar87429620-differentiable-curry-thunk-reqmachine.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// RUN: %target-swift-frontend -emit-sil -verify %s
2-
// XFAIL: *
32

43
// rdar://87429620
54
// https://github.com/apple/swift/issues/54819

0 commit comments

Comments
 (0)