Skip to content

Commit 3914545

Browse files
committed
fix fwd ret
1 parent 01e3f39 commit 3914545

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,7 @@ inline bool is_value_needed_in_reverse(
179179
}
180180
}
181181

182-
if (!TR.allFloat(const_cast<Value *>(inst)) &&
183-
gutils->mode != DerivativeMode::ForwardMode &&
184-
gutils->mode != DerivativeMode::ForwardModeError)
182+
if (!TR.allFloat(const_cast<Value *>(inst)))
185183
if (auto IVI = dyn_cast<Instruction>(user)) {
186184
bool inserted = false;
187185
if (auto II = dyn_cast<InsertValueInst>(IVI))
@@ -217,14 +215,32 @@ inline bool is_value_needed_in_reverse(
217215
}
218216

219217
bool partial = false;
220-
if (!gutils->isConstantValue(const_cast<Instruction *>(cur))) {
221-
partial = is_value_needed_in_reverse<QueryType::Shadow>(
222-
gutils, u, mode, seen, oldUnreachable);
223-
} else if (VT == QueryType::Shadow) {
224-
partial = is_value_needed_in_reverse<
225-
QueryType::ShadowByConstPrimal>(gutils, u, mode, seen,
226-
oldUnreachable);
218+
if (auto UI = dyn_cast<Instruction>(u)) {
219+
if (!gutils->isConstantValue(
220+
const_cast<Instruction *>(cur))) {
221+
bool recursiveUse = false;
222+
if (is_use_directly_needed_in_reverse(
223+
gutils, cur, mode, UI, oldUnreachable,
224+
QueryType::Shadow, &recursiveUse)) {
225+
partial = true;
226+
} else if (recursiveUse && !OneLevel) {
227+
partial = is_value_needed_in_reverse<QueryType::Shadow>(
228+
gutils, UI, mode, seen, oldUnreachable);
229+
}
230+
} else if (VT == QueryType::Shadow) {
231+
bool recursiveUse = false;
232+
if (is_use_directly_needed_in_reverse(
233+
gutils, cur, mode, UI, oldUnreachable,
234+
QueryType::ShadowByConstPrimal, &recursiveUse)) {
235+
partial = true;
236+
} else if (recursiveUse && !OneLevel) {
237+
partial = is_value_needed_in_reverse<
238+
QueryType::ShadowByConstPrimal>(gutils, UI, mode,
239+
seen, oldUnreachable);
240+
}
241+
}
227242
}
243+
228244
if (partial) {
229245

230246
if (EnzymePrintDiffUse)

0 commit comments

Comments
 (0)