diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 118e7f841..e0a134385 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -4,8 +4,8 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad export frule, rrule export wirtinger_conjugate, wirtinger_primal, refine_differential export @scalar_rule, @thunk -export extern, cast, store! -export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk +export extern, chain, cast, store! +export Wirtinger, ComplexGradient, Zero, One, Casted, DNE, Thunk, InplaceableThunk export NO_FIELDS include("differentials.jl") diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index e65748d34..29206f75f 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -7,14 +7,15 @@ subtypes, as we know the full set that might be encountered. Thus we can avoid any ambiguities. Notice: - The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) + The precidence goes: (:AbstractWirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) Thus each of the @eval loops creating definitions of + and * defines the combination this type with all types of lower precidence. This means each eval loops is 1 item smaller than the previous. ==# -function Base.:*(a::Wirtinger, b::Wirtinger) +function Base.:*(a::Union{Complex,AbstractWirtinger}, + b::Union{Complex,AbstractWirtinger}) error(""" Cannot multiply two Wirtinger objects; this error likely means a `WirtingerRule` was inappropriately defined somewhere. Multiplication @@ -32,18 +33,33 @@ function Base.:*(a::Wirtinger, b::Wirtinger) """) end -function Base.:+(a::Wirtinger, b::Wirtinger) - return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate) +function Base.:+(a::AbstractWirtinger, b::AbstractWirtinger) + return Wirtinger(wirtinger_primal(a) + wirtinger_primal(b), + wirtinger_conjugate(a) + wirtinger_conjugate(b)) end -for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any) - @eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero()) - @eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b +Base.:+(a::ComplexGradient, b::ComplexGradient) = ComplexGradient(a.val + b.val) + +for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk) + @eval Base.:+(a::AbstractWirtinger, b::$T) = a + Wirtinger(b, Zero()) + @eval Base.:+(a::$T, b::AbstractWirtinger) = Wirtinger(a, Zero()) + b @eval Base.:*(a::Wirtinger, b::$T) = Wirtinger(a.primal * b, a.conjugate * b) @eval Base.:*(a::$T, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate) + + @eval Base.:*(a::ComplexGradient, b::$T) = ComplexGradient(a.val * b) + @eval Base.:*(a::$T, b::ComplexGradient) = ComplexGradient(a * b.val) end +Base.:+(a::AbstractWirtinger, b) = a + Wirtinger(b, Zero()) +Base.:+(a, b::AbstractWirtinger) = Wirtinger(a, Zero()) + b + +Base.:*(a::Wirtinger, b::Real) = Wirtinger(a.primal * b, a.conjugate * b) +Base.:*(a::Real, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate) + +Base.:*(a::ComplexGradient, b::Real) = ComplexGradient(a.val * b) +Base.:*(a::Real, b::ComplexGradient) = ComplexGradient(a * b.val) + Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value)) Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value)) @@ -98,3 +114,51 @@ for T in (:Any,) @eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b @eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b) end + +@inline chain(outer, inner, swap_order=false) = + _chain(unthunk(outer), unthunk(inner), swap_order) + +@inline function _chain(outer, inner, swap_order) + if swap_order + return Wirtinger( + wirtinger_primal(inner) * wirtinger_primal(outer) + + conj(wirtinger_conjugate(inner)) * wirtinger_conjugate(outer), + wirtinger_conjugate(inner) * wirtinger_primal(outer) + + conj(wirtinger_primal(inner) * wirtinger_conjugate(outer)) + ) |> refine_differential + end + return Wirtinger( + wirtinger_primal(outer) * wirtinger_primal(inner) + + wirtinger_conjugate(outer) * conj(wirtinger_conjugate(inner)), + wirtinger_primal(outer) * wirtinger_conjugate(inner) + + wirtinger_conjugate(outer) * conj(wirtinger_primal(inner)) + ) |> refine_differential +end + +@inline function _chain(outer::ComplexGradient, inner, swap_order) + if swap_order + return ComplexGradient( + (wirtinger_conjugate(inner) + conj(wirtinger_primal(inner))) * + outer.val + ) + end + return ComplexGradient( + outer.val * + (wirtinger_conjugate(inner) + conj(wirtinger_primal(inner))) + ) +end + +@inline function _chain(outer::Real, inner::ComplexGradient, swap_order) + if swap_order + return ComplexGradient(inner.val * outer) + end + return ComplexGradient(outer * inner.val) +end + +# don't know if we actually need this, shouldn't really occur in actual code +@inline function _chain(outer::ComplexGradient, inner::ComplexGradient, swap_order) + if swap_order + return ComplexGradient(conj(inner.val) * outer.val) + end + return ComplexGradient(outer.val * conj(inner.val)) +end diff --git a/src/differentials.jl b/src/differentials.jl index 5ad2f8818..1b69ee191 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -41,13 +41,41 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself. @inline Base.conj(x::AbstractDifferential) = x +##### +##### `AbstractWirtinger` +##### + +""" + AbstractWirtinger <: AbstractDifferential + +Represents the differential of a non-holomorphic function taking complex input. + +All subtypes implement [`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref). + +All subtypes wrap real/holomorphic differentials, and should always be the outermost wrapper. +E.g., a typical differential would look like this: +``` +Wirtinger(@thunk(::AbstractArray{Number}), @thunk(::AbstractArray{<:Number})) +``` +`@thunk` and `AbstractArray` are, of course, optional. +""" +abstract type AbstractWirtinger <: AbstractDifferential end + +wirtinger_primal(x) = x +wirtinger_conjugate(::Any) = Zero() + +extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type.")) + +# `conj` is not defined for `AbstractWirtinger`. +# Need this method to override the definition of `conj` for `AbstractDifferential`. +Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x)) + ##### ##### `Wirtinger` ##### """ - Wirtinger(primal::Union{Number,AbstractDifferential}, - conjugate::Union{Number,AbstractDifferential}) + Wirtinger(primal, conjugate) <: [`AbstractWirtinger`](@ref) Returns a `Wirtinger` instance representing the complex differential: @@ -60,32 +88,40 @@ where `primal` corresponds to `∂f/∂z * dz` and `conjugate` corresponds to ` The two fields of the returned instance can be accessed generically via the [`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref) methods. """ -struct Wirtinger{P,C} <: AbstractDifferential +struct Wirtinger{P,C} <: AbstractWirtinger primal::P conjugate::C - function Wirtinger(primal::Union{Number,AbstractDifferential}, - conjugate::Union{Number,AbstractDifferential}) - return new{typeof(primal),typeof(conjugate)}(primal, conjugate) - end end wirtinger_primal(x::Wirtinger) = x.primal -wirtinger_primal(x) = x - wirtinger_conjugate(x::Wirtinger) = x.conjugate -wirtinger_conjugate(::Any) = Zero() - -extern(x::Wirtinger) = throw(ArgumentError("`Wirtinger` cannot be converted to an external type.")) Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal), broadcastable(w.conjugate)) -Base.iterate(x::Wirtinger) = (x, nothing) -Base.iterate(::Wirtinger, ::Any) = nothing +##### +##### `ComplexGradient` +##### + +""" + ComplexGradient(val) <: [`AbstractWirtinger`](@ref) + +Returns a `ComplexGradient` instance representing the complex differential: + +``` +df = ∂f/∂Re(z) * dRe(z) + im * ∂f/∂Im(z) * dIm(z) +``` + +where `f` is a `ℂ(^n) -> ℝ(^m)` function and `val` corresponds to `df`. +""" +struct ComplexGradient{T} <: AbstractWirtinger + val::T +end -# TODO: define `conj` for` `Wirtinger` -Base.conj(x::Wirtinger) = throw(MethodError(conj, x)) +wirtinger_primal(x::ComplexGradient) = conj(wirtinger_conjugate(x)) +wirtinger_conjugate(x::ComplexGradient) = (1//2) * x.val +Base.Broadcast.broadcastable(x::ComplexGradient) = ComplexGradient(broadcastable(x.val)) ##### ##### `Casted` @@ -131,6 +167,7 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero()) Base.iterate(x::Zero) = (x, nothing) Base.iterate(::Zero, ::Any) = nothing +Base.real(::Zero) = Zero() ##### ##### `DNE` @@ -172,6 +209,7 @@ Base.Broadcast.broadcastable(::One) = Ref(One()) Base.iterate(x::One) = (x, nothing) Base.iterate(::One, ::Any) = nothing +Base.real(::One) = One() ##### ##### `AbstractThunk @@ -191,6 +229,16 @@ end return element, (externed, new_state) end +unthunk(x) = x +unthunk(x::AbstractThunk) = unthunk(x()) + +wirtinger_primal(::AbstractThunk) = + throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first.")) +wirtinger_conjugate(::AbstractThunk) = + throw(ArgumentError("`wirtinger_conjugate` is not defined for `AbstractThunk`. Call `unthunk` first.")) + +Base.real(x::AbstractThunk) = real(x()) + ##### ##### `Thunk` ##### @@ -239,6 +287,11 @@ struct Thunk{F} <: AbstractThunk f::F end +""" + @thunk body + +Returns `Thunk(() -> body)` +""" macro thunk(body) return :(Thunk(() -> $(esc(body)))) end @@ -291,14 +344,24 @@ function itself, when that function is not a closure. const NO_FIELDS = DNE() """ - refine_differential(𝒟::Type, der) + refine_differential([𝒟::Type, ]der) Converts, if required, a differential object `der` (e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.), to another differential that is more suited for the domain given by the type 𝒟. Often this will behave as the identity function on `der`. """ +function refine_differential end + function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger) + w = refine_differential(w) return wirtinger_primal(w) + wirtinger_conjugate(w) end -refine_differential(::Any, der) = der # most of the time leave it alone. +function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, g::ComplexGradient) + g = refine_differential(g.val) + return real(g) +end +refine_differential(::Any, der) = refine_differential(der) # most of the time leave it alone. + +refine_differential(w::Wirtinger{<:Any,Zero}) = w.primal +refine_differential(der::Any) = der diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index a06820e64..a9f55eee9 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -156,7 +156,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials) Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs] pushforward_returns = map(1:n_outputs) do output_i ∂s = partials[output_i].args - propagation_expr(𝒟, Δs, ∂s) + frule_propagation_expr(𝒟, Δs, ∂s) end if n_outputs > 1 # For forward-mode we only return a tuple if output actually a tuple. @@ -193,7 +193,7 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials) # 1 partial derivative per input pullback_returns = map(1:n_inputs) do input_i ∂s = [partial.args[input_i] for partial in partials] - propagation_expr(𝒟, Δs, ∂s) + rrule_propagation_expr(𝒟, Δs, ∂s) end pullback = quote @@ -222,56 +222,46 @@ end if it is taken at `1+1im` it returns `Complex{Int}`. At present it is ignored for non-Wirtinger derivatives. """ -function propagation_expr(𝒟, Δs, ∂s) - wirtinger_indices = findall(∂s) do ex - Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger - end +function frule_propagation_expr(𝒟, Δs, ∂s) ∂s = map(esc, ∂s) - if isempty(wirtinger_indices) - return standard_propagation_expr(Δs, ∂s) - else - return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) - end + ∂_mul_Δs = [:(chain($(_thunk(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)] + return :(refine_differential($𝒟, +($(∂_mul_Δs...)))) end -function standard_propagation_expr(Δs, ∂s) - # This is basically Δs ⋅ ∂s - - # Notice: the thunking of `∂s[i] (potentially) saves us some computation - # if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon - # as the pullback is evaluated - ∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)] - return :(+($(∂_mul_Δs...))) +function rrule_propagation_expr(𝒟, Δs, ∂s) + ∂s = map(esc, ∂s) + ∂_mul_Δs = [:(chain($(Δs[i]), $(_thunk(∂s[i])))) for i in 1:length(∂s)] + return :(refine_differential($𝒟, +($(∂_mul_Δs...)))) end -function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s) - ∂_mul_Δs_primal = Any[] - ∂_mul_Δs_conjugate = Any[] - ∂_wirtinger_defs = Any[] - for i in 1:length(∂s) - if i in wirtinger_indices - Δi = Δs[i] - ∂i = Symbol(string(:∂, i)) - push!(∂_wirtinger_defs, :($∂i = $(∂s[i]))) - ∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi)) - ∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi)) - ∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi)) - ∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi)) - push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄)) - push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄)) - else - ∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i])) - push!(∂_mul_Δs_primal, ∂_mul_Δ) - push!(∂_mul_Δs_conjugate, ∂_mul_Δ) +""" + _thunk(body) + +Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref). +In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`. +""" +function _thunk(body) + if Meta.isexpr(body, :call) + fname = body.args[1] + if fname in (:Wirtinger, :ComplexGradient) + return :($fname($(thunk_assert_no_wirtinger.(body.args[2:end])...))) end + elseif Meta.isexpr(body, :escape) + return Expr(:escape, _thunk(body.args[1])) end - primal_sum = :(+($(∂_mul_Δs_primal...))) - conjugate_sum = :(+($(∂_mul_Δs_conjugate...))) - return quote # This will be a block, so will have value equal to last statement - $(∂_wirtinger_defs...) - w = Wirtinger($primal_sum, $conjugate_sum) - refine_differential($𝒟, w) - end + return thunk_assert_no_wirtinger(body) +end + +thunk_assert_no_wirtinger(body) = quote + Thunk( + function() + res = $body + res isa ChainRulesCore.AbstractWirtinger && error(""" + Couldn't automatically handle `AbstractWirtinger` in `@scalar_rule. + Make sure `Wirtinger`/`ComplexGradient` is the outermost function call or write the rule manually.""") + return res + end + ) end """ diff --git a/test/differentials.jl b/test/differentials.jl index 570b09d88..389e26cdc 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -12,9 +12,6 @@ # TODO: other + methods stack overflow @test_throws ErrorException w*w @test_throws ArgumentError extern(w) - for x in w - @test x === w - end @test broadcastable(w) == w @test_throws MethodError conj(w) end @@ -82,14 +79,33 @@ @testset "Refine Differential" begin - @test refine_differential(typeof(1.0 + 1im), Wirtinger(2,2)) == Wirtinger(2,2) - @test refine_differential(typeof([1.0 + 1im]), Wirtinger(2,2)) == Wirtinger(2,2) + for (p, c) in ( + (2, -3), + (2.0 + im, 5.0 - 3.0im), + ([1+im, 2-im], [-3+im, 4+im]), + (@thunk(1+2), @thunk(4-3)), + ) + w = Wirtinger(p, c) + @testset "$w" begin + @test refine_differential(typeof(1.0 + 1im), w) === w + @test refine_differential(typeof([1.0 + 1im]), w) === w + + @test refine_differential(typeof(1.2), w) == p + c + @test refine_differential(typeof([1.2]), w) == p + c + end - @test refine_differential(typeof(1.2), Wirtinger(2,2)) == 4 - @test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4 + g = ComplexGradient(c) + @testset "$g" begin + @test refine_differential(typeof(1.0 + 1im), g) === g + @test refine_differential(typeof([1.0 + 1im]), g) === g + + @test refine_differential(typeof(1.2), g) == real(c) + @test refine_differential(typeof([1.2]), g) == real(c) + end + end # For most differentials, in most domains, this does nothing - for der in (DNE(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0) + for der in (DNE(), @thunk(23), [1 2], One(), Zero(), 0.0) for 𝒟 in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2])) @test refine_differential(𝒟, der) === der end diff --git a/test/rules.jl b/test/rules.jl index e23680326..9cab840c0 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -46,9 +46,9 @@ end @testset "real input" begin # even though our rule was define in terms of Wirtinger, - # pushforward result will be real as real (even if seed is Compex) + # pushforward result will be real as real (even if seed is Complex) - x = rand(Float64) + x = 5.0 f, myabs2_pushforward = frule(myabs2, x) @test f === x^2 @@ -56,22 +56,22 @@ end df = @inferred myabs2_pushforward(NamedTuple(), Δ) @test df === x + x - Δ = rand(Complex{Int64}) + Δ = 2.0 + 3.0im df = @inferred myabs2_pushforward(NamedTuple(), Δ) - @test df === Δ * (x + x) + @test df === (Δ + conj(Δ)) * x end @testset "complex input" begin - z = rand(Complex{Float64}) + z = 5.0 + 7.0im f, myabs2_pushforward = frule(myabs2, z) @test f === abs2(z) df = @inferred myabs2_pushforward(NamedTuple(), One()) @test df === Wirtinger(z', z) - Δ = rand(Complex{Int64}) + Δ = 2.0 + 3.0im df = @inferred myabs2_pushforward(NamedTuple(), Δ) - @test df === Wirtinger(Δ * z', Δ * z) + @test df === Wirtinger(Δ * conj(z), conj(Δ) * z) end end @@ -97,7 +97,9 @@ end abs_to_pow(x::Complex, p), @setup(u = abs(x)), ( - p == 0 ? Zero() : p * u^(p-1) * Wirtinger(x' / 2u, x / 2u), + p == 0 ? Zero() : let v = p * u^(p-1) / 2u + Wirtinger(x' * v, x * v) + end, Ω * log(abs(x)) ) ) @@ -132,11 +134,11 @@ end fx, f_pushforward = res df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp) - df_dx::Thunk = df(One(), Zero()) - df_dp::Thunk = df(Zero(), One()) + df_dx = df(One(), Zero()) + df_dp = df(Zero(), One()) @test fx == f(x, p) # Check we still get the normal value, right - @test df_dx() isa expected_type_df_dx - @test df_dp() isa expected_type_df_dp + @test df_dx isa expected_type_df_dx + @test df_dp isa expected_type_df_dp res = rrule(f, x, p) @@ -145,7 +147,7 @@ end dself, df_dx, df_dp = f_pullback(One()) @test fx == f(x, p) # Check we still get the normal value, right @test dself == NO_FIELDS - @test df_dx() isa expected_type_df_dx - @test df_dp() isa expected_type_df_dp + @test df_dx isa expected_type_df_dx + @test df_dp isa expected_type_df_dp end end