Description
Consider a case where we have a function f: ℝᵐ → ℂʳ → ℂˢ → ℝⁿ = ℝᵐ → ℝⁿ, which we can write as f = f₃ ∘ f₂ ∘ f₁.
Typically f₁ will produce a complex output by adding, subtracting, multiplying or dividing the real by a complex number, or by calling `promote`, `complex`, `Complex` or `cis`.
Typically f₃ will produce a real output by calling a non-holomorphic function like `real`, `imag`, `abs`, `abs2`, `hypot`, or `angle`.
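
As a quick illustration of the two kinds of boundary functions described above, the following sketch checks the output types of a few f₁-style and f₃-style calls in plain Julia. The variable names `x`, `w`, `complex_outputs` and `real_outputs` are only for illustration.

```julia
x = 0.5            # a real input
w = 1.0 + 2.0im    # a complex number to combine it with

# f₁-style calls: real input, complex output
complex_outputs = (x + w, x - w, x * w, x / w,
                   first(promote(x, w)), complex(x), Complex(x), cis(x))

# f₃-style calls: complex input, real output
real_outputs = (real(w), imag(w), abs(w), abs2(w), angle(w))

@show all(v -> v isa Complex, complex_outputs)
@show all(v -> v isa Real, real_outputs)
```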
From #167, the fact that there are complex intermediates to f is just an implementation detail. We could have defined f: ℝᵐ → ℝ²ʳ → ℝ²ˢ → ℝⁿ, and the pushforwards and pullbacks of this new f should behave the same.
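
To make the "implementation detail" point concrete, here is a small sketch in plain Julia. The functions `f1`, `f2`, `f3` and their tuple-based counterparts `g1`, `g2`, `g3` are made-up examples, not taken from any package, and the derivative is only estimated by a central finite difference rather than by an AD system.

```julia
# Complex-intermediate version: ℝ → ℂ → ℂ → ℝ
f1(x::Real) = x * (1 + 2im)
f2(z::Complex) = z^2
f3(z::Complex) = real(z)
f(x::Real) = f3(f2(f1(x)))

# The same computation with each complex number stored as a (re, im) tuple
g1(x::Real) = (x * 1, x * 2)
g2((a, b)) = (a^2 - b^2, 2 * a * b)   # (a + b*im)^2 written out in reals
g3((a, b)) = a
g(x::Real) = g3(g2(g1(x)))

# Central finite difference, used here only as a sanity check
central_diff(h, x; ϵ=1e-6) = (h(x + ϵ) - h(x - ϵ)) / (2ϵ)

x0 = 0.7
@show f(x0) ≈ g(x0)
@show isapprox(central_diff(f, x0), central_diff(g, x0); atol=1e-6)
```

Both versions compute -3x², so their values and derivatives agree; an AD system would ideally give the same real derivative for either formulation.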
Since in general tangents are derivatives of a primal with respect to a real, and cotangents are derivatives of a real with respect to a primal, the pushforward through f₁: ℝᵐ → ℂʳ should produce a complex tangent, while the pushforward through f₃: ℂˢ → ℝⁿ should produce a real tangent.
Conversely, the pullback through f₃ should produce a complex cotangent, and the pullback through f₁ should produce a real cotangent.
The pushforward case is pretty easy to handle. We can 1) assume that a nonsensical tangent will not be passed and do nothing special (i.e. assume upstream AD did the right thing) or 2) define custom `frule`s that ensure that the produced tangent of unary functions `f₃(::Complex)::Real` is real.
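
As a rough sketch of option 2, a hand-written pushforward for `abs2` might look like the following. The name `pushforward_abs2` is hypothetical and the function is plain Julia, not an actual `frule` from any package; it only illustrates that the output tangent is real by construction.

```julia
# Hypothetical hand-written pushforward for y = abs2(z) = z * conj(z).
# If z moves along the complex tangent Δz, then y moves along
# 2 * real(conj(z) * Δz), which is real by construction.
function pushforward_abs2(z::Complex, Δz::Complex)
    y = abs2(z)
    Δy = 2 * real(conj(z) * Δz)
    return y, Δy
end

z = 1.0 + 2.0im
Δz = 0.5 - 1.0im
y, Δy = pushforward_abs2(z, Δz)

# Finite-difference check along the direction Δz
ϵ = 1e-6
fd = (abs2(z + ϵ * Δz) - abs2(z - ϵ * Δz)) / (2ϵ)

@show Δy isa Real
@show isapprox(Δy, fd; atol=1e-6)
```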
The pullback case is more complicated. Right now, e.g. in Zygote, unless you create a complex number from reals by calling `complex`, you'll end up pulling back complex numbers through the initial real part of your program, which is not only wasteful but could also break assumptions of the `rrule`s of upstream functions. For the binary functions f₁, I propose adding custom `rrule`s for `f₁(::Real, ::Complex)::Complex` and `f₁(::Complex, ::Real)::Complex` to ensure that the cotangent pulled back to a real primal is actually real.
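
A minimal sketch of what such a rule could compute for multiplication, written in plain Julia rather than against any particular rule-definition API. The names `mul_with_pullback` and `mul_pullback` are hypothetical, and the sketch assumes the convention that the cotangent of a complex value pairs with a tangent through `real(conj(cotangent) * tangent)`.

```julia
# Hypothetical pullback for y = x * z with x real and z complex.
# It has the shape of a reverse rule (primal output plus pullback closure)
# but does not use any AD package.
function mul_with_pullback(x::Real, z::Complex)
    y = x * z
    function mul_pullback(ybar)
        xbar = real(ybar * conj(z))   # keep only the real part for the real input
        zbar = ybar * x               # x is real, so conj(x) == x
        return xbar, zbar
    end
    return y, mul_pullback
end

x = 2.0
z = 1.0 + 3.0im
y, back = mul_with_pullback(x, z)

# Seed with the cotangent of the loss real(y), which is 1 + 0im
ybar = 1.0 + 0.0im
xbar, zbar = back(ybar)

# Without the projection, the cotangent reaching x would be complex
naive_xbar = ybar * conj(z)

@show xbar isa Real
@show naive_xbar
@show xbar ≈ real(z)   # d real(x * z) / dx == real(z)
```

The same projection would apply to the `(::Complex, ::Real)` argument order, and analogously to addition, subtraction and division.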
This came up as a point of discussion in JuliaDiff/ChainRules.jl#196, and I would appreciate feedback so we can clarify our conventions here.