Skip to content

Commit 97ba96b

Browse files
authored
[AutoDiff] Implement active Optional differentiation (swiftlang#74977)
Fixes swiftlang#74972
1 parent f8141e2 commit 97ba96b

File tree

4 files changed

+60
-9
lines changed

4 files changed

+60
-9
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1803,6 +1803,57 @@ class PullbackCloner::Implementation final
18031803
builder.emitZeroIntoBuffer(uccai->getLoc(), adjDest, IsInitialization);
18041804
}
18051805

1806+
/// Handle `enum` instruction.
1807+
/// Original: y = enum $Enum, #Enum.some!enumelt, x
1808+
/// Adjoint: adj[x] += adj[y]
1809+
void visitEnumInst(EnumInst *ei) {
1810+
SILBasicBlock *bb = ei->getParent();
1811+
SILLocation loc = ei->getLoc();
1812+
auto *optionalEnumDecl = getASTContext().getOptionalDecl();
1813+
1814+
// Only `Optional`-typed operands are supported for now. Diagnose all other
1815+
// enum operand types.
1816+
if (ei->getType().getEnumOrBoundGenericEnum() != optionalEnumDecl) {
1817+
LLVM_DEBUG(getADDebugStream()
1818+
<< "Unsupported enum type in PullbackCloner: " << *ei);
1819+
getContext().emitNondifferentiabilityError(
1820+
ei, getInvoker(),
1821+
diag::autodiff_expression_not_differentiable_note);
1822+
errorOccurred = true;
1823+
return;
1824+
}
1825+
1826+
auto adjOpt = getAdjointValue(bb, ei);
1827+
auto adjStruct = materializeAdjointDirect(adjOpt, loc);
1828+
StructDecl *adjStructDecl =
1829+
adjStruct->getType().getStructOrBoundGenericStruct();
1830+
1831+
VarDecl *adjOptVar = nullptr;
1832+
if (adjStructDecl) {
1833+
ArrayRef<VarDecl *> properties = adjStructDecl->getStoredProperties();
1834+
adjOptVar = properties.size() == 1 ? properties[0] : nullptr;
1835+
}
1836+
1837+
EnumDecl *adjOptDecl =
1838+
adjOptVar ? adjOptVar->getTypeInContext()->getEnumOrBoundGenericEnum()
1839+
: nullptr;
1840+
1841+
// Optional<T>.TangentVector should be a struct with a single
1842+
// Optional<T.TangentVector> property. This is an implementation detail of
1843+
// OptionalDifferentiation.swift
1844+
// TODO: Maybe it would be better to have getters / setters here that we
1845+
// can call and hide this implementation detail?
1846+
if (!adjOptDecl || adjOptDecl != optionalEnumDecl)
1847+
llvm_unreachable("Unexpected type of Optional.TangentVector");
1848+
1849+
auto *adjVal = builder.createStructExtract(loc, adjStruct, adjOptVar);
1850+
1851+
EnumElementDecl *someElemDecl = getASTContext().getOptionalSomeDecl();
1852+
auto *adjData = builder.createUncheckedEnumData(loc, adjVal, someElemDecl);
1853+
1854+
addAdjointValue(bb, ei->getOperand(), makeConcreteAdjointValue(adjData), loc);
1855+
}
1856+
18061857
/// Handle a sequence of `init_enum_data_addr` and `inject_enum_addr`
18071858
/// instructions.
18081859
///

test/AutoDiff/SILOptimizer/activity_analysis.swift

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,8 @@ func checked_cast_addr_nonactive_result<T: Differentiable>(_ x: T) -> T {
197197
// CHECK: checked_cast_addr_br take_always T in %3 : $*T to Float in %5 : $*Float, bb1, bb2
198198
// CHECK: }
199199

200-
// expected-error @+1 {{function is not differentiable}}
201200
@differentiable(reverse)
202-
// expected-note @+1 {{when differentiating this function definition}}
203201
func checked_cast_addr_active_result<T: Differentiable>(x: T) -> T {
204-
// expected-note @+1 {{expression is not differentiable}}
205202
if let y = x as? Float {
206203
// Use `y: Float?` value in an active way.
207204
return y as! T
@@ -777,12 +774,9 @@ func testClassModifyAccessor(_ c: inout C) {
777774
// Enum differentiation
778775
//===----------------------------------------------------------------------===//
779776

780-
// expected-error @+1 {{function is not differentiable}}
781777
@differentiable(reverse)
782-
// expected-note @+1 {{when differentiating this function definition}}
783778
func testActiveOptional(_ x: Float) -> Float {
784779
var maybe: Float? = 10
785-
// expected-note @+1 {{expression is not differentiable}}
786780
maybe = x
787781
return maybe!
788782
}

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,9 @@ class C<T: Differentiable>: Differentiable {
147147
// Enum differentiation
148148
//===----------------------------------------------------------------------===//
149149

150-
// expected-error @+1 {{function is not differentiable}}
151150
@differentiable(reverse)
152-
// expected-note @+1 {{when differentiating this function definition}}
153151
func usesOptionals(_ x: Float) -> Float {
154152
var maybe: Float? = 10
155-
// expected-note @+1 {{expression is not differentiable}}
156153
maybe = x
157154
return maybe!
158155
}

test/AutoDiff/validation-test/optional.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ func optional_nil_coalescing(_ maybeX: Float?) -> Float {
2121
}
2222
*/
2323

24+
OptionalTests.test("Active") {
25+
@differentiable(reverse)
26+
func square(y: Float) -> Float? {
27+
return y * y
28+
}
29+
30+
expectEqual(gradient(at: 10, of: {y in square(y:y)!}), .init(20.0))
31+
}
32+
2433
OptionalTests.test("Let") {
2534
@differentiable(reverse)
2635
func optional_let(_ maybeX: Float?) -> Float {

0 commit comments

Comments
 (0)