Open
Description
Description
For some functions that use ternary operators, the derivative is computed incorrectly (see the `ternary` function in the reproducer). When the ternary operators are replaced with `if` expressions, the derivative becomes correct (see the `ifelse` function in the reproducer). The function values themselves are computed correctly in both cases.
In particular, for the reproducer with $x_0 \le x_1$ the expected gradient is $(0, 0)$, but we get $(1/x_0, 0)$.
Reproduction
import _Differentiation
// Reproducer for the incorrect derivative. The function is piecewise:
//   x0 >  x1 && x1 <  0:  1 - x1 / x0
//   x0 >  x1 && x1 >= 0:  1 + x1 / x0
//   x0 <= x1:             1
// Its true gradient on the x0 <= x1 region is therefore (0, 0) (see
// `true_gradient`), but the reverse-mode derivative of this ternary-based
// version yields (1 / x0, 0) there (modelled by `actual_gradient`).
// Keep the statements exactly as written: the issue reports that rewriting
// the ternaries as `if` expressions (as in `ifelse`) gives the correct result.
@differentiable(reverse)
func ternary(_ x0: Float, _ x1: Float) -> Float {
let t1 = x1 + x0;
let t2 = x0 - x1;
// x0 + x1 when x0 > x1, otherwise x0.
let t4 = x0 > x1 ? t1 : x0;
let t6 = 1 / x0;
// Mathematically, x0 > t4 holds only when x0 > x1 and x1 < 0
// (then t4 = x0 + x1 < x0). On the x0 <= x1 branch t4 == x0, so t5 == x0.
let t5 = x0 > t4 ? t2 : t4;
// On the x0 <= x1 branch this is x0 * (1 / x0) == 1, whose gradient is (0, 0).
let t7 = t5 * t6;
return t7;
}
/// Closed-form model of the gradient currently produced for `ternary`,
/// matching the observed program output.
///
/// On the `x0 > x1` region this agrees with the analytic gradient. On the
/// `x0 <= x1` region it returns `(1 / x0, 0)`, reproducing the incorrect
/// compiler-generated derivative; the mathematically correct value there
/// is `(0, 0)`.
func actual_gradient(_ x0: Float, _ x1: Float) -> (Float, Float) {
    guard x0 > x1 else {
        // FIXME: the compiler-generated gradient here should be (0, 0).
        return (1 / x0, 0)
    }
    let x0Squared = x0 * x0
    return x1 < 0
        ? (x1 / x0Squared, -1 / x0)
        : (-x1 / x0Squared, 1 / x0)
}
// Same piecewise function as `ternary`, written with `if` expressions instead
// of the ternary operator. The issue reports that this version's gradient is
// correct: (0, 0) on the x0 <= x1 region, up to floating-point rounding
// (the observed x0 partial at (5, 9) is 1.4901161e-08).
// Keep the statements exactly as written so the two versions stay comparable.
@differentiable(reverse)
func ifelse(_ x0: Float, _ x1: Float) -> Float {
let t1 = x1 + x0;
let t2 = x0 - x1;
// Mirrors `t4` in `ternary`: x0 + x1 when x0 > x1, otherwise x0.
let t4 : Float = if x0 > x1 { t1 } else { x0 }
let t6 = 1 / x0;
// Mirrors `t5` in `ternary`: x0 - x1 when x0 > t4, otherwise t4.
let t5 : Float = if x0 > t4 { t2 } else { t4 }
let t7 = t5 * t6;
return t7;
}
/// Analytic gradient `(d/dx0, d/dx1)` of the piecewise function computed by
/// `ternary` and `ifelse`:
/// - `x0 > x1`, `x1 < 0`:  `1 - x1 / x0`  gives `(x1 / x0², -1 / x0)`
/// - `x0 > x1`, `x1 >= 0`: `1 + x1 / x0`  gives `(-x1 / x0², 1 / x0)`
/// - `x0 <= x1`:           constant `1`   gives `(0, 0)`
func true_gradient(_ x0: Float, _ x1: Float) -> (Float, Float) {
    switch (x0 > x1, x1 < 0) {
    case (false, _):
        return (0, 0)
    case (true, true):
        return (x1 / (x0 * x0), -1 / x0)
    case (true, false):
        return (-x1 / (x0 * x0), 1 / x0)
    }
}
// For each sample point, print the compiler-generated gradient of each
// function next to its reference value. Both points lie on the x0 <= x1
// region, where the true gradient is (0, 0).
let samplePoints: [(Float, Float)] = [(5, 9), (-2, 1)]
for point in samplePoints {
    let (x0, x1) = point
    print(" ternary'(\(x0), \(x1)) = \(gradient(at: x0, x1, of: ternary))")
    print("actual_gradient(\(x0), \(x1)) = \(actual_gradient(x0, x1))")
    print(" ifelse'(\(x0), \(x1)) = \(gradient(at: x0, x1, of: ifelse))")
    print(" true_gradient(\(x0), \(x1)) = \(true_gradient(x0, x1))")
    print("")
}
Output:
ternary'(5.0, 9.0) = (0.2, 0.0)
actual_gradient(5.0, 9.0) = (0.2, 0.0)
ifelse'(5.0, 9.0) = (1.4901161e-08, 0.0)
true_gradient(5.0, 9.0) = (0.0, 0.0)
ternary'(-2.0, 1.0) = (-0.5, 0.0)
actual_gradient(-2.0, 1.0) = (-0.5, 0.0)
ifelse'(-2.0, 1.0) = (0.0, 0.0)
true_gradient(-2.0, 1.0) = (0.0, 0.0)
Expected behavior
Gradient value for ternary
and ifelse
should be equal (particularly, for x0<=x1
, we should have (0,0) gradient value for both).
Environment
Swift version 6.2-dev (LLVM dc6a0c133fea15e, Swift efecad888e1731c)
Target: x86_64-unknown-linux-gnu
Build config: +assertions
Additional information
No response