
*, One, extern, and accumulate + friends #62

Closed
@willtebbutt

Description

I almost named this Getting rid of * and One, in line with most of our other open issues but, alas, that wasn't an appropriate title. This issue is more of a commentary on the value of properly deciding what operations make sense for Differentials all of the time, what operations only make sense some of the time, and what operations never make sense.

Operations on Differentials

It is currently the case that * is defined between any two Differentials. It's not actually clear that this should always be the case. Quite possibly we should consider only defining * in certain cases e.g. for scalars. Similarly, the role of One is unclear in general, and there's some confusion between + and accumulate. The aim of this issue is to resolve these issues, and improve our collective understanding of what the things in the package actually (should) do.

First consider the things that you need to be able to do with Differentials to be able to perform AD:

  1. Add Differentials during the accumulation phase of reverse-mode. So we definitely need addition (accumulation in Zygote-speak) to work, always.
  2. Apply linear maps to differentials. AD essentially takes in your original programme and returns a linear programme that accepts Differentials and returns some other Differentials.

These are the only two things necessary for ChainRules(Core) to provide for AD systems. With this in mind, we plough ahead.
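To make those two requirements concrete, here is a minimal hand-written reverse pass for `y = sin(x) * x`, in plain Julia rather than the ChainRules API: `x` appears twice, so step 2 applies one linear map per use, and step 1 accumulates the two contributions with `+`.

```julia
x = 2.0
y = sin(x) * x

ȳ = 1.0                        # seed differential for y
x̄_from_sin = ȳ * x * cos(x)   # linear map for the sin(x) use of x
x̄_from_arg = ȳ * sin(x)       # linear map for the bare x use
x̄ = x̄_from_sin + x̄_from_arg  # accumulation phase: just +
# x̄ == x * cos(x) + sin(x), the derivative of sin(x) * x
```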

+(differential, differential)

This should always be defined, and it's clear how to do it with all of our currently implemented Differential types, with the exception of One. More on that later.

Not necessary to define +(primal, differential) all of the time

Notably, it's not necessary to be able to add Differentials to their primal type. Although it's often possible to do this (and you need to be able to do it for objects that you want to gradient-descend on, for example), automatic differentiation does not require that you can.

For example, the differential of a Vector{Float64} will typically be represented as a Vector{Float64}, which we can clearly add. It could also be a Zero, or a Thunk. Now, it happens to be the case that we know how to add these to their primals, but what if we have the differential w.r.t. the following struct:

```julia
struct Foo
    x
    y
    Foo(x) = new(x, x)  # both fields are always set from a single argument
end
```

The differential will be represented as a Composite (see @oxinabox's PR) with two fields, x and y. It's entirely unclear how to usefully add these two objects: the result would need to be of type Foo, but Foo's only constructor sets both fields from a single argument, so there is no way to increment its two fields independently.
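A hypothetical illustration of the problem, reusing the Foo struct defined above. Composite here is a minimal stand-in modelled on the referenced PR, not the real type:

```julia
struct Composite{P}
    backing::NamedTuple
end
Composite{P}(; kwargs...) where {P} = Composite{P}((; kwargs...))

foo = Foo(1.0)                           # foo.x == foo.y == 1.0
dfoo = Composite{Foo}(x = 2.0, y = 3.0)  # one differential per field
# There is no sensible foo + dfoo: the only constructor is Foo(x) = new(x, x),
# so no Foo exists with x == 1.0 + 2.0 and y == 1.0 + 3.0.
```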

To summarise,

  1. addition between a given differential and some primal type needn't be defined all of the time
  2. addition will be defined between differentials and primals most of the time. e.g. you should expect that you can add the differential of a Vector or Float64 to said Vector or Float64. Similarly, addition between a primal and Zero is always defined (just don't modify the primal). But there exist some primal-differential pairs that don't admit a useful / sensible definition of +, and that is okay.
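The Zero case from point 2 can be sketched as follows (Zero here is a stand-in, not the actual ChainRulesCore definitions):

```julia
struct Zero end

Base.:+(x, ::Zero) = x            # adding Zero returns the primal unmodified
Base.:+(z::Zero, x) = x + z
Base.:+(::Zero, ::Zero) = Zero()  # resolve the dispatch ambiguity

v = [1.0, 2.0]
v + Zero() === v   # true: the very same primal comes back
```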

One

See here

In short, it doesn't make sense as a differential, but it might not be a totally useless construct.

Multiplication

Multiplication doesn't make sense between all differential types, but we are currently required to define it for all differential types. For example, what does it mean to multiply two composites? You could define it to be some notion of elementwise multiplication, but it's not clear what we would gain by doing that.

In the context of ChainRules(Core), multiplication is best thought of as a data-parametrised linear map that is well-defined for certain types of differentials (scalars, vectors, matrices, etc) but not others (e.g. Composites). This observation resolves the issue with Wirtingers whereby it's not at all clear how to multiply them together. We could remove the method with the massive error message entirely since there would no longer be any expectation on the part of the user that one should be able to multiply them together in an unambiguous manner.

As such, it clearly has value and I believe that we should

  • define it for Zero, because anything * Zero is Zero (currently implemented)
  • define it recursively for AbstractThunk (currently implemented)
  • not require it to be defined for Composite, because how would it be usefully defined?
  • remove the definition of * for Wirtingers entirely, as there's no longer any expectation that it should work.
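The first two bullets can be sketched with stand-in types (simplified illustrations, not the actual ChainRulesCore definitions):

```julia
struct Zero end
struct Thunk{F}
    f::F
end
unthunk(t::Thunk) = t.f()

Base.:*(::Zero, x) = Zero()            # anything * Zero is Zero
Base.:*(x, ::Zero) = Zero()
Base.:*(::Zero, ::Zero) = Zero()       # resolve the dispatch ambiguity
Base.:*(t::Thunk, x) = unthunk(t) * x  # recurse through thunks
Base.:*(x, t::Thunk) = x * unthunk(t)

Thunk(() -> 3.0) * 2.0   # == 6.0
Zero() * [1.0, 2.0]      # == Zero()
```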

Accumulate + friends

Given the centrality of +, it makes sense to revisit accumulate at this point.

  • accumulate literally just implements +. We have +, so there's no need for accumulate.
  • accumulate! is the in-place version of +. This method lets you do A = A + B without allocating, where A and B are differentials. It might make sense for us to rename this add! for consistency with +. It's important to note, however, that it's not always possible to in-place add differentials, and this will only be well-defined some of the time. A more useful way of thinking about this functionality is as maybe_add_inplace.
  • Similar comment for store!. It only really makes sense some of the time, specifically when dealing with dense arrays, so it's not clear to me how much of a help this really is. Probably requires further thought.

To my mind, the conclusion here is that we just remove accumulate, and think further about what we actually want to get out of accumulate! and store! in a separate issue. They're not top-priority, but there are definitely use-cases for them in big neural-network-y applications where you typically handle a lot of StridedArrays.
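The maybe_add_inplace idea can be sketched like this, with add!! an invented name used purely for illustration: mutate when the differential type supports it, and fall back to ordinary + otherwise.

```julia
add!!(A::Array, B::Array) = (A .+= B; A)  # mutate A in place, no allocation
add!!(A, B) = A + B                       # fall back to out-of-place +

A = [1.0, 2.0]
add!!(A, [3.0, 4.0])   # A is now [4.0, 6.0]
add!!(1.0, 2.0)        # == 3.0 (a Float64 can't be mutated, so allocate)
```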
