Skip to content

Commit 0ba886d

Browse files
authored
[AutoDiff] Fix adjoint for move_value (swiftlang#78286)
Since `move_value` is a destroying operation, the adjoint of `y = move_value x` should be `adj[x] += adj[y]; adj[y] = 0` instead of just `adj[x] += adj[y]`.
1 parent 26b2f7b commit 0ba886d

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,13 +1736,21 @@ class PullbackCloner::Implementation final
17361736
/// move_value, begin_borrow.
17371737
/// Original: y = copy_value x
17381738
/// Adjoint: adj[x] += adj[y]
1739-
void visitValueOwnershipInst(SingleValueInstruction *svi) {
1739+
void visitValueOwnershipInst(SingleValueInstruction *svi,
1740+
bool needZeroResAdj = false) {
17401741
assert(svi->getNumOperands() == 1);
17411742
auto *bb = svi->getParent();
17421743
switch (getTangentValueCategory(svi)) {
17431744
case SILValueCategory::Object: {
17441745
auto adj = getAdjointValue(bb, svi);
17451746
addAdjointValue(bb, svi->getOperand(0), adj, svi->getLoc());
1747+
if (needZeroResAdj) {
1748+
assert(svi->getNumResults() == 1);
1749+
SILValue val = svi->getResult(0);
1750+
setAdjointValue(
1751+
bb, val,
1752+
makeZeroAdjointValue(getRemappedTangentType(val->getType())));
1753+
}
17461754
break;
17471755
}
17481756
case SILValueCategory::Address: {
@@ -1768,8 +1776,16 @@ class PullbackCloner::Implementation final
17681776

17691777
/// Handle `move_value` instruction.
17701778
/// Original: y = move_value x
1771-
/// Adjoint: adj[x] += adj[y]
1772-
void visitMoveValueInst(MoveValueInst *mvi) { visitValueOwnershipInst(mvi); }
1779+
/// Adjoint: adj[x] += adj[y]; adj[y] = 0
1780+
void visitMoveValueInst(MoveValueInst *mvi) {
1781+
switch (getTangentValueCategory(mvi)) {
1782+
case SILValueCategory::Address:
1783+
llvm::report_fatal_error("AutoDiff does not support move_value with "
1784+
"SILValueCategory::Address");
1785+
case SILValueCategory::Object:
1786+
visitValueOwnershipInst(mvi, /*needZeroResAdj=*/true);
1787+
}
1788+
}
17731789

17741790
void visitEndInitLetRefInst(EndInitLetRefInst *eir) { visitValueOwnershipInst(eir); }
17751791

test/AutoDiff/SILOptimizer/pullback_generation.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,25 @@ func f4(a: NonTrivial) -> Float {
198198
// CHECK: %95 = metatype $@thick Float.Type
199199
// CHECK: %96 = apply %94<Float>(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
200200
// CHECK: destroy_value %67 : $NonTrivial
201+
202+
@differentiable(reverse)
203+
func move_value(x: Float) -> Float {
204+
var result = x
205+
repeat {
206+
let temp = result
207+
result = temp
208+
} while 0 == 1
209+
return result
210+
}
211+
212+
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation10move_value1xS2f_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> Float {
213+
// CHECK: bb3(%[[#]] : $Float, %[[#]] : $Float, %[[#]] : $Float, %[[#]] : $(predecessor: _AD__$s19pullback_generation10move_value1xS2f_tF_bb1__Pred__src_0_wrt_0)):
214+
// CHECK: %[[#]] = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
215+
// CHECK: %[[#T1:]] = alloc_stack $Float
216+
// CHECK: %[[#T2:]] = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
217+
// CHECK: %[[#T3:]] = metatype $@thick Float.Type
218+
// CHECK: %[[#]] = apply %[[#T2]]<Float>(%[[#T1]], %[[#T3]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
219+
// CHECK: %[[#T4:]] = load [trivial] %[[#T1]]
220+
// CHECK: dealloc_stack %[[#T1]]
221+
// CHECK: bb4(%113 : $Builtin.RawPointer):
222+
// CHECK: br bb5(%[[#]] : $Float, %[[#]] : $Float, %[[#T4]] : $Float, %[[#]] : $(predecessor: _AD__$s19pullback_generation10move_value1xS2f_tF_bb2__Pred__src_0_wrt_0))

0 commit comments

Comments
 (0)