Skip to content

WIP: Wirtinger support #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
1548cbc
remove type constraints for Wirtinger
simeonschaub Sep 22, 2019
88bb756
introduce `AbstractWirtinger` and `ComplexGradient`
simeonschaub Sep 23, 2019
e3ce538
add `chain` function
simeonschaub Sep 23, 2019
186009a
make `swap_order` in `chain` a positional arg
simeonschaub Sep 23, 2019
2618146
introduce a function `unwrap_wirtinger`
simeonschaub Sep 23, 2019
aa7a84a
stop using types before they are defined
simeonschaub Sep 23, 2019
20e7134
fix tests
simeonschaub Sep 25, 2019
c42a953
rename `unwrap_wirtinger` -> `unthunk`
simeonschaub Sep 27, 2019
2ac8801
fix `chain` function
simeonschaub Sep 27, 2019
61a60ec
use the new `chain` function in `@scalar_rule`
simeonschaub Sep 27, 2019
41f1071
fix tests accordingly
simeonschaub Sep 27, 2019
93a7b14
Update src/differentials.jl
simeonschaub Oct 3, 2019
c904407
add chain(::Real, ::ComplexGradient)
simeonschaub Oct 5, 2019
a83da2f
Update src/differentials.jl
simeonschaub Oct 18, 2019
4d21e1a
special case `AbstractWirtinger` in `at_thunk`
simeonschaub Oct 19, 2019
71e0c7b
add `refine_differential` for `ComplexGradient`
simeonschaub Oct 19, 2019
491f781
overload `Base.real` for some differentials
simeonschaub Oct 19, 2019
06ccb14
add some docstrings
simeonschaub Oct 19, 2019
6ce400c
move `at_thunk`-magic into separate macro
simeonschaub Oct 19, 2019
b960a09
make at_scalar_rule detect wrong Wirtinger rules
simeonschaub Oct 19, 2019
3bc7e22
use `_thunk` as a function, not macro
simeonschaub Oct 19, 2019
e3bf56d
implement some of @oxinabox's suggestions
simeonschaub Oct 19, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad
export frule, rrule
export wirtinger_conjugate, wirtinger_primal, refine_differential
export @scalar_rule, @thunk
export extern, cast, store!
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk
export extern, chain, cast, store!
export Wirtinger, ComplexGradient, Zero, One, Casted, DNE, Thunk, InplaceableThunk
export NO_FIELDS

include("differentials.jl")
Expand Down
78 changes: 71 additions & 7 deletions src/differential_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ subtypes, as we know the full set that might be encountered.
Thus we can avoid any ambiguities.

Notice:
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
The precidence goes: (:AbstractWirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
Thus each of the @eval loops creating definitions of + and *
defines the combination this type with all types of lower precidence.
This means each eval loops is 1 item smaller than the previous.
==#


function Base.:*(a::Wirtinger, b::Wirtinger)
function Base.:*(a::Union{Complex,AbstractWirtinger},
b::Union{Complex,AbstractWirtinger})
error("""
Cannot multiply two Wirtinger objects; this error likely means a
`WirtingerRule` was inappropriately defined somewhere. Multiplication
Expand All @@ -32,18 +33,33 @@ function Base.:*(a::Wirtinger, b::Wirtinger)
""")
end

function Base.:+(a::Wirtinger, b::Wirtinger)
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
function Base.:+(a::AbstractWirtinger, b::AbstractWirtinger)
return Wirtinger(wirtinger_primal(a) + wirtinger_primal(b),
wirtinger_conjugate(a) + wirtinger_conjugate(b))
end

for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b
Base.:+(a::ComplexGradient, b::ComplexGradient) = ComplexGradient(a.val + b.val)

for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk)
@eval Base.:+(a::AbstractWirtinger, b::$T) = a + Wirtinger(b, Zero())
@eval Base.:+(a::$T, b::AbstractWirtinger) = Wirtinger(a, Zero()) + b

@eval Base.:*(a::Wirtinger, b::$T) = Wirtinger(a.primal * b, a.conjugate * b)
@eval Base.:*(a::$T, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate)

@eval Base.:*(a::ComplexGradient, b::$T) = ComplexGradient(a.val * b)
@eval Base.:*(a::$T, b::ComplexGradient) = ComplexGradient(a * b.val)
end

Base.:+(a::AbstractWirtinger, b) = a + Wirtinger(b, Zero())
Base.:+(a, b::AbstractWirtinger) = Wirtinger(a, Zero()) + b

Base.:*(a::Wirtinger, b::Real) = Wirtinger(a.primal * b, a.conjugate * b)
Base.:*(a::Real, b::Wirtinger) = Wirtinger(a * b.primal, a * b.conjugate)

Base.:*(a::ComplexGradient, b::Real) = ComplexGradient(a.val * b)
Base.:*(a::Real, b::ComplexGradient) = ComplexGradient(a * b.val)


Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value))
Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value))
Expand Down Expand Up @@ -98,3 +114,51 @@ for T in (:Any,)
@eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
end

@inline chain(outer, inner, swap_order=false) =
_chain(unthunk(outer), unthunk(inner), swap_order)

@inline function _chain(outer, inner, swap_order)
if swap_order
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to write things this way, rather than as e.g. if swap_order; a, b = b, a; end or swap_order && _chain(inner, outer, false)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason is that there's two orders here to consider. The first one, why we need the chain function at all, is concerning, which of the differentials is the partial of the outer and which one is the partial of the inner function. For AbstractWirtinger, this difference does matter, which is the purpose of chain. The second one is the order of multiplication, which matters if inner and outer are non-commutative objects like matrices. They might still be of type Wirtinger, only wirtinger_primal and wirtinger_conjugate are matrices. In general, both orderings are relevant, in Base.:* for example, we want to multiply the outer differential from the right, but this is not equivalent to chain(inner, outer). I will definitely explain this in a docstring.

return Wirtinger(
wirtinger_primal(inner) * wirtinger_primal(outer) +
conj(wirtinger_conjugate(inner)) * wirtinger_conjugate(outer),
wirtinger_conjugate(inner) * wirtinger_primal(outer) +
conj(wirtinger_primal(inner) * wirtinger_conjugate(outer))
) |> refine_differential
end
return Wirtinger(
wirtinger_primal(outer) * wirtinger_primal(inner) +
wirtinger_conjugate(outer) * conj(wirtinger_conjugate(inner)),
wirtinger_primal(outer) * wirtinger_conjugate(inner) +
wirtinger_conjugate(outer) * conj(wirtinger_primal(inner))
) |> refine_differential
end

@inline function _chain(outer::ComplexGradient, inner, swap_order)
if swap_order
return ComplexGradient(
(wirtinger_conjugate(inner) + conj(wirtinger_primal(inner))) *
outer.val
)
end
return ComplexGradient(
outer.val *
(wirtinger_conjugate(inner) + conj(wirtinger_primal(inner)))
)
end

@inline function _chain(outer::Real, inner::ComplexGradient, swap_order)
if swap_order
return ComplexGradient(inner.val * outer)
end
return ComplexGradient(outer * inner.val)
end

# don't know if we actually need this, shouldn't really occur in actual code
@inline function _chain(outer::ComplexGradient, inner::ComplexGradient, swap_order)
if swap_order
return ComplexGradient(conj(inner.val) * outer.val)
end
return ComplexGradient(outer.val * conj(inner.val))
end
99 changes: 81 additions & 18 deletions src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,41 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.

@inline Base.conj(x::AbstractDifferential) = x

#####
##### `AbstractWirtinger`
#####

"""
AbstractWirtinger <: AbstractDifferential

Represents the differential of a non-holomorphic function taking complex input.

All subtypes implement [`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref).

All subtypes wrap real/holomorphic differentials, and should always be the outermost wrapper.
E.g., a typical differential would look like this:
```
Wirtinger(@thunk(::AbstractArray{Number}), @thunk(::AbstractArray{<:Number}))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After thinking about this quite a bit, I think this order makes the most sense, since this way, we can avoid allocations completely for derivatives like Wirtinger(Zero(), One()).

```
`@thunk` and `AbstractArray` are, of course, optional.
"""
abstract type AbstractWirtinger <: AbstractDifferential end

wirtinger_primal(x) = x
wirtinger_conjugate(::Any) = Zero()

extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type."))

# `conj` is not defined for `AbstractWirtinger`.
# Need this method to override the definition of `conj` for `AbstractDifferential`.
Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x))

#####
##### `Wirtinger`
#####

"""
Wirtinger(primal::Union{Number,AbstractDifferential},
conjugate::Union{Number,AbstractDifferential})
Wirtinger(primal, conjugate) <: [`AbstractWirtinger`](@ref)

Returns a `Wirtinger` instance representing the complex differential:

Expand All @@ -60,32 +88,40 @@ where `primal` corresponds to `∂f/∂z * dz` and `conjugate` corresponds to `
The two fields of the returned instance can be accessed generically via the
[`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref) methods.
"""
struct Wirtinger{P,C} <: AbstractDifferential
struct Wirtinger{P,C} <: AbstractWirtinger
primal::P
conjugate::C
function Wirtinger(primal::Union{Number,AbstractDifferential},
conjugate::Union{Number,AbstractDifferential})
return new{typeof(primal),typeof(conjugate)}(primal, conjugate)
end
end

wirtinger_primal(x::Wirtinger) = x.primal
wirtinger_primal(x) = x

wirtinger_conjugate(x::Wirtinger) = x.conjugate
wirtinger_conjugate(::Any) = Zero()

extern(x::Wirtinger) = throw(ArgumentError("`Wirtinger` cannot be converted to an external type."))

Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal),
broadcastable(w.conjugate))

Base.iterate(x::Wirtinger) = (x, nothing)
Base.iterate(::Wirtinger, ::Any) = nothing
#####
##### `ComplexGradient`
#####

"""
ComplexGradient(val) <: [`AbstractWirtinger`](@ref)

Returns a `ComplexGradient` instance representing the complex differential:

```
df = ∂f/∂Re(z) * dRe(z) + im * ∂f/∂Im(z) * dIm(z)
```

where `f` is a `ℂ(^n) -> ℝ(^m)` function and `val` corresponds to `df`.
"""
struct ComplexGradient{T} <: AbstractWirtinger
val::T
end

# TODO: define `conj` for` `Wirtinger`
Base.conj(x::Wirtinger) = throw(MethodError(conj, x))
wirtinger_primal(x::ComplexGradient) = conj(wirtinger_conjugate(x))
wirtinger_conjugate(x::ComplexGradient) = (1//2) * x.val

Base.Broadcast.broadcastable(x::ComplexGradient) = ComplexGradient(broadcastable(x.val))

#####
##### `Casted`
Expand Down Expand Up @@ -131,6 +167,7 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero())
Base.iterate(x::Zero) = (x, nothing)
Base.iterate(::Zero, ::Any) = nothing

Base.real(::Zero) = Zero()

#####
##### `DNE`
Expand Down Expand Up @@ -172,6 +209,7 @@ Base.Broadcast.broadcastable(::One) = Ref(One())
Base.iterate(x::One) = (x, nothing)
Base.iterate(::One, ::Any) = nothing

Base.real(::One) = One()

#####
##### `AbstractThunk
Expand All @@ -191,6 +229,16 @@ end
return element, (externed, new_state)
end

unthunk(x) = x
unthunk(x::AbstractThunk) = unthunk(x())

wirtinger_primal(::AbstractThunk) =
throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first."))
wirtinger_conjugate(::AbstractThunk) =
throw(ArgumentError("`wirtinger_conjugate` is not defined for `AbstractThunk`. Call `unthunk` first."))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weird Union

Also not sure about this function but will wait and see

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the Union is just a relic. Probably don't need this anymore, since Thunks should now never wrap AbstractWirtinger.


Base.real(x::AbstractThunk) = real(x())

#####
##### `Thunk`
#####
Expand Down Expand Up @@ -239,6 +287,11 @@ struct Thunk{F} <: AbstractThunk
f::F
end

"""
@thunk body

Returns `Thunk(() -> body)`
"""
macro thunk(body)
return :(Thunk(() -> $(esc(body))))
end
Expand Down Expand Up @@ -291,14 +344,24 @@ function itself, when that function is not a closure.
const NO_FIELDS = DNE()

"""
refine_differential(𝒟::Type, der)
refine_differential([𝒟::Type, ]der)

Converts, if required, a differential object `der`
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
to another differential that is more suited for the domain given by the type 𝒟.
Often this will behave as the identity function on `der`.
"""
function refine_differential end

function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger)
w = refine_differential(w)
return wirtinger_primal(w) + wirtinger_conjugate(w)
end
refine_differential(::Any, der) = der # most of the time leave it alone.
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, g::ComplexGradient)
g = refine_differential(g.val)
return real(g)
end
refine_differential(::Any, der) = refine_differential(der) # most of the time leave it alone.

refine_differential(w::Wirtinger{<:Any,Zero}) = w.primal
refine_differential(der::Any) = der
80 changes: 35 additions & 45 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ function scalar_frule_expr(𝒟, f, call, setup_stmts, inputs, partials)
Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs]
pushforward_returns = map(1:n_outputs) do output_i
∂s = partials[output_i].args
propagation_expr(𝒟, Δs, ∂s)
frule_propagation_expr(𝒟, Δs, ∂s)
end
if n_outputs > 1
# For forward-mode we only return a tuple if output actually a tuple.
Expand Down Expand Up @@ -193,7 +193,7 @@ function scalar_rrule_expr(𝒟, f, call, setup_stmts, inputs, partials)
# 1 partial derivative per input
pullback_returns = map(1:n_inputs) do input_i
∂s = [partial.args[input_i] for partial in partials]
propagation_expr(𝒟, Δs, ∂s)
rrule_propagation_expr(𝒟, Δs, ∂s)
end

pullback = quote
Expand Down Expand Up @@ -222,56 +222,46 @@ end
if it is taken at `1+1im` it returns `Complex{Int}`.
At present it is ignored for non-Wirtinger derivatives.
"""
function propagation_expr(𝒟, Δs, ∂s)
wirtinger_indices = findall(∂s) do ex
Meta.isexpr(ex, :call) && ex.args[1] === :Wirtinger
end
function frule_propagation_expr(𝒟, Δs, ∂s)
∂s = map(esc, ∂s)
if isempty(wirtinger_indices)
return standard_propagation_expr(Δs, ∂s)
else
return wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
end
∂_mul_Δs = [:(chain($(_thunk(∂s[i])), $(Δs[i]))) for i in 1:length(∂s)]
return :(refine_differential($𝒟, +($(∂_mul_Δs...))))
end

function standard_propagation_expr(Δs, ∂s)
# This is basically Δs ⋅ ∂s

# Notice: the thunking of `∂s[i] (potentially) saves us some computation
# if `Δs[i]` is a `AbstractDifferential` otherwise it is computed as soon
# as the pullback is evaluated
∂_mul_Δs = [:(@thunk($(∂s[i])) * $(Δs[i])) for i in 1:length(∂s)]
return :(+($(∂_mul_Δs...)))
function rrule_propagation_expr(𝒟, Δs, ∂s)
∂s = map(esc, ∂s)
∂_mul_Δs = [:(chain($(Δs[i]), $(_thunk(∂s[i])))) for i in 1:length(∂s)]
return :(refine_differential($𝒟, +($(∂_mul_Δs...))))
end

function wirtinger_propagation_expr(𝒟, wirtinger_indices, Δs, ∂s)
∂_mul_Δs_primal = Any[]
∂_mul_Δs_conjugate = Any[]
∂_wirtinger_defs = Any[]
for i in 1:length(∂s)
if i in wirtinger_indices
Δi = Δs[i]
∂i = Symbol(string(:∂, i))
push!(∂_wirtinger_defs, :($∂i = $(∂s[i])))
∂f∂i_mul_Δ = :(wirtinger_primal($∂i) * wirtinger_primal($Δi))
∂f∂ī_mul_Δ̄ = :(conj(wirtinger_conjugate($∂i)) * wirtinger_conjugate($Δi))
∂f̄∂i_mul_Δ = :(wirtinger_conjugate($∂i) * wirtinger_primal($Δi))
∂f̄∂ī_mul_Δ̄ = :(conj(wirtinger_primal($∂i)) * wirtinger_conjugate($Δi))
push!(∂_mul_Δs_primal, :($∂f∂i_mul_Δ + $∂f∂ī_mul_Δ̄))
push!(∂_mul_Δs_conjugate, :($∂f̄∂i_mul_Δ + $∂f̄∂ī_mul_Δ̄))
else
∂_mul_Δ = :(@thunk($(∂s[i])) * $(Δs[i]))
push!(∂_mul_Δs_primal, ∂_mul_Δ)
push!(∂_mul_Δs_conjugate, ∂_mul_Δ)
"""
_thunk(body)

Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
"""
function _thunk(body)
if Meta.isexpr(body, :call)
fname = body.args[1]
if fname in (:Wirtinger, :ComplexGradient)
return :($fname($(thunk_assert_no_wirtinger.(body.args[2:end])...)))
end
elseif Meta.isexpr(body, :escape)
return Expr(:escape, _thunk(body.args[1]))
end
primal_sum = :(+($(∂_mul_Δs_primal...)))
conjugate_sum = :(+($(∂_mul_Δs_conjugate...)))
return quote # This will be a block, so will have value equal to last statement
$(∂_wirtinger_defs...)
w = Wirtinger($primal_sum, $conjugate_sum)
refine_differential($𝒟, w)
end
return thunk_assert_no_wirtinger(body)
end

thunk_assert_no_wirtinger(body) = quote
Thunk(
function()
res = $body
res isa ChainRulesCore.AbstractWirtinger && error("""
Couldn't automatically handle `AbstractWirtinger` in `@scalar_rule.
Make sure `Wirtinger`/`ComplexGradient` is the outermost function call or write the rule manually.""")
return res
end
)
end

"""
Expand Down
Loading