## Description

I almost named this *Getting rid of `*` and `One`*, in line with most of our other open issues, but, alas, that wasn't an appropriate title. This issue is more of a commentary on the value of properly deciding which operations make sense for `Differential`s all of the time, which operations only make sense some of the time, and which operations never make sense.
## Operations on Differentials
It is currently the case that `*` is defined between any two `Differential`s. It's not actually clear that this should always be the case. Quite possibly we should consider only defining `*` in certain cases, e.g. for scalars. Similarly, the role of `One` is unclear in general, and there's some confusion between `+` and `accumulate`. The aim of this issue is to resolve these issues, and improve our collective understanding of what the things in the package actually (should) do.
First consider the things that you need to be able to do with `Differential`s to be able to perform AD:

- Add `Differential`s during the accumulation phase of reverse-mode. So we definitely need addition (accumulation in Zygote-speak) to work, always.
- Apply linear maps to differentials. AD essentially takes in your original programme and returns a linear programme that accepts `Differential`s and returns some other `Differential`s.
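The two requirements above can be sketched in plain Julia. This is an illustrative toy, not the ChainRules API: `accumulate_grad` and `pullback_sin` are made-up names.

```julia
# Toy illustration of the two operations AD needs from differentials.
# `accumulate_grad` and `pullback_sin` are hypothetical names, not ChainRules API.

# 1. Accumulation: add the differentials flowing in from different paths.
accumulate_grad(da, db) = da + db

# 2. Linear maps: for y = sin(x), AD's pullback is the linear map
#    dy -> dy * cos(x), parametrised by the primal input x.
pullback_sin(x) = dy -> dy * cos(x)

x = 0.3
dy1, dy2 = 1.0, 2.0  # differentials from two uses of y
dx = pullback_sin(x)(accumulate_grad(dy1, dy2))  # == 3.0 * cos(0.3)
```

Note that the pullback only ever needs to *apply* a linear map to a differential; nothing here requires multiplying two arbitrary differentials together.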
These are the only two things that `ChainRules(Core)` needs to provide for AD systems. With this in mind, we plough ahead.
## `+(differential, differential)`
This should always be defined, and it's clear how to do it with all of our currently implemented `Differential` types, with the exception of `One`. More on that later.
## Not necessary to define `+(primal, differential)` all of the time
Notably, it's not necessary to be able to add `Differential`s to their primal type. Although it's often possible to do this (and you need to be able to do it for objects that you want to gradient-descend on, for example), automatic differentiation does not require that you can.
For example, the differential of a `Vector{Float64}` will typically be represented as a `Vector{Float64}`, which we can clearly add. It could also be a `Zero`, or a `Thunk`. Now, it happens to be the case that we know how to add these to their primals, but what if we have the differential w.r.t. the following `struct`:
```julia
struct Foo
    x
    y
    Foo(x) = new(x, x)
end
```
The differential will be represented as a `Composite` (see @oxinabox's PR) with two fields, `x` and `y`. It's entirely unclear how to usefully add these two objects. Clearly the result needs to be of type `Foo`, but the inner constructor forces both fields to be equal, so it's not possible to increment the fields of `Foo` independently.
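To make the problem concrete, here is a sketch using a `NamedTuple` as a stand-in for the `Composite` differential (the `dfoo` value is made up for illustration):

```julia
struct Foo   # as above: the inner constructor ties the two fields together
    x
    y
    Foo(x) = new(x, x)
end

foo  = Foo(1.0)             # foo.x == foo.y == 1.0, by construction
dfoo = (x = 0.1, y = 0.2)   # NamedTuple standing in for the Composite

# There is no two-argument `Foo` constructor, so there is no way to build a
# `Foo` whose fields have been incremented independently:
# Foo(foo.x + dfoo.x, foo.y + dfoo.y)  # MethodError
```

AD can still propagate `dfoo` through pullbacks perfectly well; it's only the primal-plus-differential operation that has no sensible meaning here.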
To summarise:

- addition between a given differential and some primal type needn't be defined all of the time;
- addition will be defined between differentials and primals most of the time. e.g. you should expect that you can add the differential of a `Vector` or `Float64` to said `Vector` or `Float64`. Similarly, addition between a primal and `Zero` is always defined (just don't modify the primal). But there exist some primal-differential pairs that don't admit a useful / sensible definition of `+`, and that is okay.
## `One`

See here.

In short, it doesn't make sense as a differential, but it might not be a totally useless construct.
## Multiplication
Multiplication doesn't make sense between all differential types, but we are currently required to define it for all differential types. For example, what does it mean to multiply two composites? You could define it to be some notion of elementwise multiplication, but it's not clear what we would gain by doing that.
In the context of `ChainRules(Core)`, multiplication is best thought of as a data-parametrised linear map that is well-defined for certain types of differentials (scalars, vectors, matrices, etc.) but not others (e.g. `Composite`s). This observation resolves the issue with `Wirtinger`s, whereby it's not at all clear how to multiply them together. We could remove the method with the massive error message entirely, since there would no longer be any expectation on the part of the user that one should be able to multiply them together in an unambiguous manner.
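As a sketch of the data-parametrised linear map view (toy code, not the package's definitions): for `Y = A * X`, the pullback w.r.t. `X` is the linear map `dY -> A' * dY`, which only makes sense because `dY` is a matrix.

```julia
# `*` as a data-parametrised linear map (toy example; names are made up).
A = [1.0 2.0; 3.0 4.0]
pullback_X(dY) = A' * dY   # linear in dY, parametrised by the data A

dY = ones(2, 2)
dX = pullback_X(dY)        # a 2×2 matrix

# No analogous `*` exists for a Composite-style struct differential, which is
# why requiring `*` between every pair of differential types is too strong.
```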
As such, it clearly has value, and I believe that we should:

- define it for `Zero`, because anything `*` `Zero` is `Zero` (currently implemented);
- define it recursively for `AbstractThunk` (currently implemented);
- not require it to be defined for `Composite`, because how would it be usefully defined?
- remove the definition of `*` for `Wirtinger`s entirely, as there's no longer any expectation that it should work.
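A minimal sketch of what those `*` methods might look like, using locally defined stand-ins for `Zero` and the thunk type (not the package's actual definitions):

```julia
struct Zero end                 # stand-in for ChainRulesCore's Zero
struct Thunk{F}; f::F; end      # stand-in for AbstractThunk
unthunk(t::Thunk) = t.f()

Base.:*(::Zero, ::Zero) = Zero()
Base.:*(::Zero, x) = Zero()     # anything * Zero is Zero
Base.:*(x, ::Zero) = Zero()
Base.:*(t::Thunk, x) = unthunk(t) * x   # recurse through the thunk
Base.:*(x, t::Thunk) = x * unthunk(t)
# (a real implementation would also need a Thunk * Zero method to avoid
# method ambiguity)

Zero() * 5.0            # == Zero()
Thunk(() -> 2.0) * 3.0  # == 6.0
```

Nothing analogous is written for `Composite`, and under the proposal above nothing needs to be.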
## Accumulate + friends
Given the centrality of `+`, it makes sense to re-visit `accumulate` at this point.
- `accumulate` literally just implements `+`. We have `+`, so there's no need for `accumulate`.
- `accumulate!` is the in-place version of `+`. This method lets you do `A = A + B` without allocating, where `A` and `B` are differentials. It might make sense for us to rename this `add!` for consistency with `+`. It's important to note, however, that it's not always possible to in-place add differentials, so this will only be well-defined some of the time. A more useful way of thinking about this functionality is as `maybe_add_inplace`.
- Similar comment for `store!`. It only really makes sense some of the time, specifically when dealing with dense arrays, so it's not clear to me how much of a help this really is. Probably requires further thought.
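A sketch of the `maybe_add_inplace` idea (the `add!` name and its fallback are hypothetical, not the package's API):

```julia
# Hypothetical add!: in-place when the destination supports mutation,
# falling back to out-of-place `+` otherwise.
add!(A::Array, B) = (A .+= B; A)   # dense arrays: no allocation
add!(A, B) = A + B                 # scalars, immutables, etc.: allocate

A = [1.0, 2.0]
add!(A, [0.5, 0.5])   # mutates A to [1.5, 2.5]
add!(1.0, 2.0)        # == 3.0, returns a new value
```

The dispatch-based fallback is what makes "maybe in-place" workable: callers use one name and the method table decides whether mutation is possible.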
To my mind, the conclusion here is that we just remove `accumulate`, and think further about what we actually want to get out of `accumulate!` and `store!` in a separate issue. They're not top-priority, but there are definitely use-cases for them in big neural-network-y applications where you typically handle a lot of `StridedArray` types.