-
Notifications
You must be signed in to change notification settings - Fork 64
WIP: Wirtinger support #54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1548cbc
88bb756
e3ce538
186009a
2618146
aa7a84a
20e7134
c42a953
2ac8801
61a60ec
41f1071
93a7b14
c904407
a83da2f
4d21e1a
71e0c7b
491f781
06ccb14
6ce400c
b960a09
3bc7e22
e3bf56d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,13 +41,41 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself. | |
|
||
@inline Base.conj(x::AbstractDifferential) = x | ||
|
||
##### | ||
##### `AbstractWirtinger` | ||
##### | ||
|
||
""" | ||
AbstractWirtinger <: AbstractDifferential | ||
|
||
Represents the differential of a non-holomorphic function taking complex input. | ||
|
||
All subtypes implement [`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref). | ||
|
||
All subtypes wrap real/holomorphic differentials, and should always be the outermost wrapper. | ||
E.g., a typical differential would look like this: | ||
``` | ||
Wirtinger(@thunk(::AbstractArray{Number}), @thunk(::AbstractArray{<:Number})) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After thinking about this quite a bit, I think this order makes the most sense, since this way, we can avoid allocations completely for derivatives like |
||
``` | ||
`@thunk` and `AbstractArray` are, of course, optional. | ||
""" | ||
abstract type AbstractWirtinger <: AbstractDifferential end | ||
|
||
wirtinger_primal(x) = x | ||
wirtinger_conjugate(::Any) = Zero() | ||
|
||
extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type.")) | ||
|
||
# `conj` is not defined for `AbstractWirtinger`. | ||
# Need this method to override the definition of `conj` for `AbstractDifferential`. | ||
Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x)) | ||
|
||
##### | ||
##### `Wirtinger` | ||
##### | ||
|
||
""" | ||
Wirtinger(primal::Union{Number,AbstractDifferential}, | ||
conjugate::Union{Number,AbstractDifferential}) | ||
Wirtinger(primal, conjugate) <: [`AbstractWirtinger`](@ref) | ||
|
||
Returns a `Wirtinger` instance representing the complex differential: | ||
|
||
|
@@ -60,32 +88,40 @@ where `primal` corresponds to `∂f/∂z * dz` and `conjugate` corresponds to ` | |
The two fields of the returned instance can be accessed generically via the | ||
[`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref) methods. | ||
""" | ||
struct Wirtinger{P,C} <: AbstractDifferential | ||
struct Wirtinger{P,C} <: AbstractWirtinger | ||
primal::P | ||
conjugate::C | ||
function Wirtinger(primal::Union{Number,AbstractDifferential}, | ||
conjugate::Union{Number,AbstractDifferential}) | ||
return new{typeof(primal),typeof(conjugate)}(primal, conjugate) | ||
end | ||
end | ||
|
||
wirtinger_primal(x::Wirtinger) = x.primal | ||
wirtinger_primal(x) = x | ||
|
||
wirtinger_conjugate(x::Wirtinger) = x.conjugate | ||
wirtinger_conjugate(::Any) = Zero() | ||
|
||
extern(x::Wirtinger) = throw(ArgumentError("`Wirtinger` cannot be converted to an external type.")) | ||
|
||
Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal), | ||
broadcastable(w.conjugate)) | ||
|
||
Base.iterate(x::Wirtinger) = (x, nothing) | ||
Base.iterate(::Wirtinger, ::Any) = nothing | ||
##### | ||
##### `ComplexGradient` | ||
##### | ||
|
||
""" | ||
ComplexGradient(val) <: [`AbstractWirtinger`](@ref) | ||
|
||
Returns a `ComplexGradient` instance representing the complex differential: | ||
|
||
``` | ||
df = ∂f/∂Re(z) * dRe(z) + im * ∂f/∂Im(z) * dIm(z) | ||
``` | ||
|
||
where `f` is a `ℂ(^n) -> ℝ(^m)` function and `val` corresponds to `df`. | ||
""" | ||
struct ComplexGradient{T} <: AbstractWirtinger | ||
simeonschaub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
val::T | ||
end | ||
|
||
# TODO: define `conj` for` `Wirtinger` | ||
Base.conj(x::Wirtinger) = throw(MethodError(conj, x)) | ||
wirtinger_primal(x::ComplexGradient) = conj(wirtinger_conjugate(x)) | ||
wirtinger_conjugate(x::ComplexGradient) = (1//2) * x.val | ||
|
||
Base.Broadcast.broadcastable(x::ComplexGradient) = ComplexGradient(broadcastable(x.val)) | ||
|
||
##### | ||
##### `Casted` | ||
|
@@ -131,6 +167,7 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero()) | |
Base.iterate(x::Zero) = (x, nothing) | ||
Base.iterate(::Zero, ::Any) = nothing | ||
|
||
Base.real(::Zero) = Zero() | ||
|
||
##### | ||
##### `DNE` | ||
|
@@ -172,6 +209,7 @@ Base.Broadcast.broadcastable(::One) = Ref(One()) | |
Base.iterate(x::One) = (x, nothing) | ||
Base.iterate(::One, ::Any) = nothing | ||
|
||
Base.real(::One) = One() | ||
|
||
##### | ||
##### `AbstractThunk | ||
|
@@ -191,6 +229,16 @@ end | |
return element, (externed, new_state) | ||
end | ||
|
||
unthunk(x) = x | ||
unthunk(x::AbstractThunk) = unthunk(x()) | ||
simeonschaub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
wirtinger_primal(::AbstractThunk) = | ||
throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first.")) | ||
simeonschaub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
wirtinger_conjugate(::AbstractThunk) = | ||
throw(ArgumentError("`wirtinger_conjugate` is not defined for `AbstractThunk`. Call `unthunk` first.")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Weird Also not sure about this function but will wait and see There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, the |
||
|
||
Base.real(x::AbstractThunk) = real(x()) | ||
|
||
##### | ||
##### `Thunk` | ||
##### | ||
|
@@ -239,6 +287,11 @@ struct Thunk{F} <: AbstractThunk | |
f::F | ||
end | ||
|
||
""" | ||
@thunk body | ||
|
||
Returns `Thunk(() -> body)` | ||
""" | ||
macro thunk(body) | ||
return :(Thunk(() -> $(esc(body)))) | ||
end | ||
|
@@ -291,14 +344,24 @@ function itself, when that function is not a closure. | |
const NO_FIELDS = DNE() | ||
|
||
""" | ||
refine_differential(𝒟::Type, der) | ||
refine_differential([𝒟::Type, ]der) | ||
|
||
Converts, if required, a differential object `der` | ||
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.), | ||
to another differential that is more suited for the domain given by the type 𝒟. | ||
Often this will behave as the identity function on `der`. | ||
""" | ||
function refine_differential end | ||
|
||
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger) | ||
w = refine_differential(w) | ||
return wirtinger_primal(w) + wirtinger_conjugate(w) | ||
end | ||
refine_differential(::Any, der) = der # most of the time leave it alone. | ||
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, g::ComplexGradient) | ||
g = refine_differential(g.val) | ||
return real(g) | ||
end | ||
refine_differential(::Any, der) = refine_differential(der) # most of the time leave it alone. | ||
|
||
refine_differential(w::Wirtinger{<:Any,Zero}) = w.primal | ||
simeonschaub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
refine_differential(der::Any) = der | ||
simeonschaub marked this conversation as resolved.
Show resolved
Hide resolved
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need to write things this way, rather than as e.g.
if swap_order; a, b = b, a; end
orswap_order && _chain(inner, outer, false)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason is that there's two orders here to consider. The first one, why we need the
chain
function at all, is concerning, which of the differentials is the partial of the outer and which one is the partial of the inner function. ForAbstractWirtinger
, this difference does matter, which is the purpose ofchain
. The second one is the order of multiplication, which matters ifinner
andouter
are non-commutative objects like matrices. They might still be of typeWirtinger
, onlywirtinger_primal
andwirtinger_conjugate
are matrices. In general, both orderings are relevant, inBase.:*
for example, we want to multiply the outer differential from the right, but this is not equivalent tochain(inner, outer)
. I will definitely explain this in a docstring.