WIP: Wirtinger support #54

Draft: wants to merge 22 commits into base: main

@simeonschaub (Member) commented Sep 22, 2019

As discussed in #40, Wirtinger support is going to be moved out of master for now. I'm going to start working on it in this branch, but might eventually decide to move this into a package.

My current plans for this so far are:

  • remove type constraints for Wirtinger (since pretty much anything can be a differential, these don't make much sense to me anymore)
  • add a function unthunk to work better for things like Thunks (instead of making wirtinger_[primal|conjugate] recursive, as originally planned)
  • introduce a type ComplexGradient, which works like Zygote's complex derivatives, to address #23 ("Special case derivative of non-holomorphic functions of type ℂ(^n)→ℝ") and make porting Zygote to ChainRules easier (Still needs docstrings)
  • introduce an abstract type AbstractWirtinger, of which both Wirtinger and ComplexGradient are subtypes (Still needs docstrings)
  • a function chain, which works mostly like *, but respects which differential is the derivative of the outer function and which is the derivative of the inner function, which is important for AbstractWirtinger differentials (Still needs docstrings)
  • simplify @scalar_rule with this chain function
  • disallow multiplication of complex numbers with AbstractWirtinger and require users to use chain, since the order of chaining matters here, too (Still needs a better error)
  • tests
  • docs

I always appreciate any feedback.

@simeonschaub (Member, Author)

I'm thinking about giving ComplexGradient two real parameters instead of one complex one, so we could make use of Zero in some places. That would complicate some of the arithmetic, though.

@nickrobinson251 (Contributor) left a comment

I really like how this simplifies the code... but I admit I've not fully understood everything here yet. I think docs for ComplexGradient (and its relation to Wirtinger) would be a big help.

_chain(unthunk(outer), unthunk(inner), swap_order)

@inline function _chain(outer, inner, swap_order)
if swap_order

Contributor

Why do we need to write things this way, rather than as e.g. `if swap_order; a, b = b, a; end` or `swap_order && _chain(inner, outer, false)`?

Member Author

The reason is that there are two orders to consider here. The first, which is why we need the chain function at all, concerns which of the differentials is the partial of the outer function and which is the partial of the inner function. For AbstractWirtinger, this difference matters, and handling it is the purpose of chain. The second is the order of multiplication, which matters if inner and outer are non-commutative objects like matrices. They might still be of type Wirtinger, only with wirtinger_primal and wirtinger_conjugate being matrices. In general, both orderings are relevant: in Base.:*, for example, we want to multiply the outer differential from the right, but this is not equivalent to chain(inner, outer). I will definitely explain this in a docstring.
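
To illustrate the first of these two orderings, here is a minimal sketch of why swapping the outer and inner differentials changes the result. It assumes scalar components only, so the second ordering (multiplication order) plays no role. The struct `W` and this `chain` method are hypothetical stand-ins written for illustration, not the definitions in this branch; the formula is the standard Wirtinger chain rule.

```julia
# Hypothetical stand-in for the Wirtinger differential; not the type from this PR.
struct W{P,C}
    primal::P      # ∂f/∂z
    conjugate::C   # ∂f/∂z̄
end

# Wirtinger chain rule for h = f ∘ g with scalar components:
#   ∂h/∂z = ∂f/∂w * ∂g/∂z + ∂f/∂w̄ * conj(∂g/∂z̄)
#   ∂h/∂z̄ = ∂f/∂w * ∂g/∂z̄ + ∂f/∂w̄ * conj(∂g/∂z)
chain(outer::W, inner::W) = W(
    outer.primal * inner.primal + outer.conjugate * conj(inner.conjugate),
    outer.primal * inner.conjugate + outer.conjugate * conj(inner.primal),
)

d_conj = W(0.0 + 0.0im, 1.0 + 0.0im)  # differential of w -> conj(w)
d_rot  = W(0.0 + 1.0im, 0.0 + 0.0im)  # differential of z -> im * z

a = chain(d_conj, d_rot)  # differential of z -> conj(im * z) = -im * conj(z)
b = chain(d_rot, d_conj)  # differential of z -> im * conj(z)
```

Here `a.conjugate` and `b.conjugate` differ in sign even though every component is a commuting scalar, so the roles of outer and inner cannot be recovered from the multiplication order alone.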

Co-Authored-By: Nick Robinson <[email protected]>
@simeonschaub (Member, Author)

I've also written down some thoughts about this in FluxML/Zygote.jl#328. I still need to finish that, though, and will then try to adapt some of it for the docs here.

Co-Authored-By: Nick Robinson <[email protected]>
@oxinabox (Member)

I have not forgotten about this and started reading it again yesterday.


# For most differentials, in most domains, this does nothing
for der in (DNE(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0)
for der in (DNE(), @thunk(23), [1 2], One(), Zero(), 0.0)

Member Author

I don't know if this test really makes a lot of sense...

All subtypes wrap real/holomorphic differentials, and should always be the outermost wrapper.
E.g., a typical differential would look like this:
```
Wirtinger(@thunk(::AbstractArray{<:Number}), @thunk(::AbstractArray{<:Number}))
```

Member Author

After thinking about this quite a bit, I think this order makes the most sense, since this way, we can avoid allocations completely for derivatives like Wirtinger(Zero(), One()).


extern(x::AbstractWirtinger) = throw(ArgumentError("`AbstractWirtinger` cannot be converted to an external type."))

Base.iterate(x::AbstractWirtinger) = (x, nothing)

Member Author

I think I am going to get rid of this, since, e.g., Wirtinger can now wrap arrays as well.

Member Author

I thought about making this behave like Wirtinger(i, j) for (i,j) in zip(primal, conjugate), but I think this would encourage people to collect this into an array, and we don't ever want to have AbstractArray{<:AbstractWirtinger}.

Member

Yeah, I think the iterate overloads can probably be removed from everything.


wirtinger_primal(::Union{AbstractThunk}) =
throw(ArgumentError("`wirtinger_primal` is not defined for `AbstractThunk`. Call `unthunk` first."))
wirtinger_conjugate(::Union{AbstractThunk}) =
throw(ArgumentError("`wirtinger_conjugate` is not defined for `AbstractThunk`. Call `unthunk` first."))

Member

Weird Union

Also not sure about this function but will wait and see

Member Author

Yes, the Union is just a relic. Probably don't need this anymore, since Thunks should now never wrap AbstractWirtinger.

Returns `@thunk body`, except for when `body` is a call to [`Wirtinger`](@ref) or [`ComplexGradient`](@ref).
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
"""
macro _thunk(body)

Member

This doesn't need to be a macro anymore, as it is only called from within a function of ASTs.

Member Author

Could you take a look at whether everything is escaped correctly now? I'm not 100% sure I know how escaping works.
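
For reference, the suggestion to replace the macro with a plain function on expressions could look roughly like the sketch below. The name `thunk_expr` is hypothetical, and the sketch only builds expressions: `@thunk` appears as an unexpanded macro call inside the returned `Expr`, and escaping is assumed to be handled by whichever macro calls this helper.

```julia
# Hypothetical helper operating on ASTs; not the `_thunk` implementation from this PR.
# Wraps `body` in `@thunk`, except that calls to Wirtinger / ComplexGradient
# get each argument wrapped individually, so the wrapper type stays outermost.
function thunk_expr(body)
    if body isa Expr && body.head === :call &&
       body.args[1] in (:Wirtinger, :ComplexGradient)
        wrapped = [:(@thunk($arg)) for arg in body.args[2:end]]
        return Expr(:call, body.args[1], wrapped...)
    end
    return :(@thunk($body))
end

ex_w = thunk_expr(:(Wirtinger(a, b)))  # a Wirtinger call whose arguments are thunked
ex_p = thunk_expr(:(x + y))            # a single thunk around the whole body
```

Because the helper never evaluates anything, it can be inspected at the REPL without `@thunk` being defined.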

@YingboMa (Member) commented Jan 7, 2020

> disallow multiplication of complex numbers with AbstractWirtinger and require users to use chain, since the order of chaining matters here, too (Still needs a better error)

Why disallow ::Wirtinger * ::Complex? How would a user interact with the Wirtinger object if there is no generic function defined?

@simeonschaub (Member, Author) commented Jan 7, 2020

Most of the time, users shouldn't interact with Wirtinger objects at all. One of the main use cases I see for them is as intermediate representations in mixed-mode AD implementations. If a library wants to expose that functionality to users, the authors can add their own abstraction on top to best fit their particular needs.
To understand why we don't want to multiply complex numbers and Wirtinger objects, we can think of Wirtinger objects as 2×2 real Jacobians in a different basis, where two real numbers just happen to be represented by one complex number. In this context, it is quite clear that these form only a real vector space, and why multiplication by a complex number is not well defined. We can still define an injective homomorphism from the complex numbers to Wirtinger objects by mapping z to Wirtinger(z, Zero()), but the images don't commute with all Wirtinger objects, which would be required for a complex vector space. JuliaDiff/ChainRules.jl#133 and JuliaDiff/ChainRules.jl#135 make it quite clear that subtypes of AbstractDifferential should form a vector space, so I believe disallowing ::Wirtinger * ::Complex is the only reasonable thing to do here.
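
The non-commutativity can be spelled out as a short worked derivation. The notation is an assumption made for illustration: W(p, c) stands for Wirtinger(p, c) with scalar components, acting on increments as dw = p dz + c dz̄, and ι(a) = W(a, 0) is the embedding of a complex number a described above.

```latex
% Composition follows from substituting dw = p_1\,dz + c_1\,d\bar z and
% d\bar w = \bar c_1\,dz + \bar p_1\,d\bar z into du = p_2\,dw + c_2\,d\bar w.
\begin{aligned}
W(p_2, c_2) \circ W(p_1, c_1)
  &= W\bigl(p_2 p_1 + c_2 \bar c_1,\; p_2 c_1 + c_2 \bar p_1\bigr) \\
\iota(a) \circ W(p, c) &= W(a p,\; a c) \\
W(p, c) \circ \iota(a) &= W(p a,\; c \bar a)
\end{aligned}
```

The last two lines agree only if c = 0 or a is real, which is why the image of ℂ does not act as a field of scalars on general Wirtinger objects.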

@nickrobinson251 added the "Complex Differentiation" label (relating to any form of complex differentiation) on Jun 17, 2020

@willtebbutt (Member)

@simeonschaub what's the status of this? Has it been completely superseded by the other complex numbers stuff?

@simeonschaub (Member, Author)

What has been superseded is ComplexGradient, now that we just use Adjoint for that. It seems like people still expressed interest in Wirtinger derivatives, so this might be something to consider for v2. I currently don't have much use for this anymore, but if someone wanted to push this forward, I would certainly offer to help wherever I can.

5 participants