-
Notifications
You must be signed in to change notification settings - Fork 64
WIP: Wirtinger support #54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This, together with adding `@inline` makes constant-propagation possible. Also fix a bug from before.
Thinking about making |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I really like how this simplifies the code... but I admit i've not fully understood everything here yet. I think docs for ComplexGradient
(and it's relation to Wirtinger
) would be a big help.
_chain(unthunk(outer), unthunk(inner), swap_order) | ||
|
||
@inline function _chain(outer, inner, swap_order) | ||
if swap_order |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need to write things this way, rather than as e.g. if swap_order; a, b = b, a; end
or swap_order && _chain(inner, outer, false)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason is that there's two orders here to consider. The first one, why we need the chain
function at all, is concerning, which of the differentials is the partial of the outer and which one is the partial of the inner function. For AbstractWirtinger
, this difference does matter, which is the purpose of chain
. The second one is the order of multiplication, which matters if inner
and outer
are non-commutative objects like matrices. They might still be of type Wirtinger
, only wirtinger_primal
and wirtinger_conjugate
are matrices. In general, both orderings are relevant, in Base.:*
for example, we want to multiply the outer differential from the right, but this is not equivalent to chain(inner, outer)
. I will definitely explain this in a docstring.
Co-Authored-By: Nick Robinson <[email protected]>
I've written down some thoughts about this in FluxML/Zygote.jl#328 also. I still need to finish that though and will then try to adapt some of that in the docs here. |
Co-Authored-By: Nick Robinson <[email protected]>
I have not forgotten about this and started reading it again yesterday. |
b63f6ee
to
06ccb14
Compare
|
||
# For most differentials, in most domains, this does nothing | ||
for der in (DNE(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0) | ||
for der in (DNE(), @thunk(23), [1 2], One(), Zero(), 0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if this test really makes a lot of sense...
All subtypes wrap real/holomorphic differentials, and should always be the outermost wrapper. | ||
E.g., a typical differential would look like this: | ||
``` | ||
Wirtinger(@thunk(::AbstractArray{Number}), @thunk(::AbstractArray{<:Number})) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After thinking about this quite a bit, I think this order makes the most sense, since this way, we can avoid allocations completely for derivatives like Wirtinger(Zero(), One())
.
src/differentials.jl
Outdated
|
||
extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type.")) | ||
|
||
Base.iterate(x::AbstractWirtinger) = (x, nothing) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think, I am going to get rid of this, since e.g. Wirtinger
can now wrap arrays as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought about making this behave like Wirtinger(i, j) for (i,j) in zip(primal, conjugate)
, but I think this would encourage people to collect this into an array, and we don't ever want to have AbstractArray{<:AbstractWirtinger}
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I think the iterate Overloads can probably be removed from everything
src/differentials.jl
Outdated
|
||
extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type.")) | ||
|
||
Base.iterate(x::AbstractWirtinger) = (x, nothing) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I think the iterate Overloads can probably be removed from everything
wirtinger_primal(::Union{AbstractThunk}) = | ||
throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first.")) | ||
wirtinger_conjugate(::Union{AbstractThunk}) = | ||
throw(ArgumentError("`wirtinger_conjugate` is not defined for `AbstractThunk`. Call `unthunk` first.")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Weird Union
Also not sure about this function but will wait and see
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the Union
is just a relic. Probably don't need this anymore, since Thunk
s should now never wrap AbstractWirtinger
.
src/rule_definition_tools.jl
Outdated
Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref). | ||
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`. | ||
""" | ||
macro _thunk(body) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't need to be a macro anymore as it is only called from with in a function of ASTs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you take a look, whether everything is escaped correctly now? I'm not 100% sure, I know how escaping works.
Tests will break for now, and that's good.
8774ada
to
e3bf56d
Compare
Why disallowing |
Most of the time, users shouldn't interact with Wirtinger objects at all. One of the main use case I see for them is as intermediary representations in mixed-mode AD implementations. If a library wants to expose that functionality to users, the authors can add their own abstraction on top to best fit their particular needs. |
@simeonschaub what's the status of this? Has it been completely superceded by the other complex numbers stuff? |
What has been superseded is |
As discussed in #40, Wirtinger support is going to be moved out of master for now. I'm going to start working on it in this branch, but might eventually decide to move this into a package.
My current plans for this so far are:
makeadd a functionwirtinger_[primal|conjugate]
recursive, to work better for things likeThunk
sunthunk
for this insteadComplexGradient
, which works like Zygote's complex derivatives to address Special case derivative of non-holomorphic functions of type ℂ(^n)→ℝ #23 and make porting Zygote to ChainRules easier (Still needs docstrings)AbstractWirtinger
, whichWirtinger
as well asComplexGradient
are a subtype of (Still needs docstrings)chain
, which works mostly like*
, but respects which function is the derivative of the outer/inner function, which is important forAbstractWirtinger
differentials (Still needs docstrings)@scalar_rule
with thischain
functionAbstractWirtinger
and require users to usechain
, since the order of chaining matters here, too (Still needs a better error)I always appreciate any feedback.