@@ -1736,13 +1736,21 @@ class PullbackCloner::Implementation final
1736
1736
// / move_value, begin_borrow.
1737
1737
// / Original: y = copy_value x
1738
1738
// / Adjoint: adj[x] += adj[y]
1739
- void visitValueOwnershipInst (SingleValueInstruction *svi) {
1739
+ void visitValueOwnershipInst (SingleValueInstruction *svi,
1740
+ bool needZeroResAdj = false ) {
1740
1741
assert (svi->getNumOperands () == 1 );
1741
1742
auto *bb = svi->getParent ();
1742
1743
switch (getTangentValueCategory (svi)) {
1743
1744
case SILValueCategory::Object: {
1744
1745
auto adj = getAdjointValue (bb, svi);
1745
1746
addAdjointValue (bb, svi->getOperand (0 ), adj, svi->getLoc ());
1747
+ if (needZeroResAdj) {
1748
+ assert (svi->getNumResults () == 1 );
1749
+ SILValue val = svi->getResult (0 );
1750
+ setAdjointValue (
1751
+ bb, val,
1752
+ makeZeroAdjointValue (getRemappedTangentType (val->getType ())));
1753
+ }
1746
1754
break ;
1747
1755
}
1748
1756
case SILValueCategory::Address: {
@@ -1768,8 +1776,16 @@ class PullbackCloner::Implementation final
1768
1776
1769
1777
// / Handle `move_value` instruction.
1770
1778
// / Original: y = move_value x
1771
- // / Adjoint: adj[x] += adj[y]
1772
- void visitMoveValueInst (MoveValueInst *mvi) { visitValueOwnershipInst (mvi); }
1779
+ // / Adjoint: adj[x] += adj[y]; adj[y] = 0
1780
+ void visitMoveValueInst (MoveValueInst *mvi) {
1781
+ switch (getTangentValueCategory (mvi)) {
1782
+ case SILValueCategory::Address:
1783
+ llvm::report_fatal_error (" AutoDiff does not support move_value with "
1784
+ " SILValueCategory::Address" );
1785
+ case SILValueCategory::Object:
1786
+ visitValueOwnershipInst (mvi, /* needZeroResAdj=*/ true );
1787
+ }
1788
+ }
1773
1789
1774
1790
void visitEndInitLetRefInst (EndInitLetRefInst *eir) { visitValueOwnershipInst (eir); }
1775
1791
0 commit comments