Description
Natural tangents are helpful because they're sometimes more intelligible to humans than structural tangents, and can play nicely with generic linear operations written in rrules in a number of situations, making it more straightforward to write generic rrules. However, this is not always the case, and they can cause complications.
The purpose of this issue is to document situations in which structural tangents are either preferable or a necessity, and the complications that can occur when structural and natural tangents interact. Some of it is obvious, and I'm intentionally not proposing to do anything; I just needed to get this stuff out of my head so that I can focus on other work for the next couple of weeks without this bothering me.
Maybe we should link this in the docs that @oxinabox recently merged?
In the following, please assume that the primal is an AbstractArray, the structural tangent a Tangent, and the natural tangent another AbstractArray.
Interactions between structurals and naturals
Suppose we have code in which a hand-written pullback produces a natural tangent for a primal, and an automatically generated pullback produces a structural tangent for the same primal (say, because it hit getfield). At some point, these will need to be added together. This means that code must exist to convert one to the other, and it must be hit. To my knowledge, we don't really deal with this at the minute.
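As a rough sketch of the kind of conversion code this implies, here is the Diagonal case. To keep it runnable with only the standard library, a NamedTuple stands in for a Tangent, and `to_natural` / `to_structural` are hypothetical helper names, not ChainRulesCore functions.

```julia
using LinearAlgebra

# Hypothetical helpers (not part of ChainRulesCore). A NamedTuple stands in
# for the structural Tangent of a Diagonal, which has the single field `diag`.
to_natural(::Type{<:Diagonal}, t::NamedTuple{(:diag,)}) = Diagonal(t.diag)
to_structural(d::Diagonal) = (diag = d.diag,)

natural = Diagonal([1.0, 2.0])       # e.g. produced by a hand-written pullback
structural = (diag = [0.5, 0.5],)    # e.g. produced by a pullback that hit getfield

# Accumulate in whichever representation is preferred; both need a conversion.
total_natural = natural + to_natural(Diagonal, structural)
total_structural = (diag = to_structural(natural).diag + structural.diag,)
```

Either direction works for Diagonal because the two representations carry the same information; the point is that one of these conversions has to exist and be called before the addition.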
Sometimes structurals are more natural
Clarity is in the eye of the beholder, but to my eye it's not uncommon for the structural tangent to be much more straightforwardly interpretable than the natural tangent. Moreover, it's always obvious what the structural tangent is, whereas one has to think (sometimes for an extended period of time, and consult with others) to determine an appropriate natural tangent. This problem is compounded by our current lack of a rigorous definition for the natural tangent.
One example of this is Fill arrays. If we think about a Fill as a struct in the context of AD, it's incredibly clear what's going on and what its structural tangent represents. Conversely, if we think about it as an array, its natural tangent requires a good deal of thought. Again, the lack of a proper definition of a natural tangent is possibly the culprit.
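To make the struct view concrete, here is a small sketch using a hypothetical `MyFill` stand-in (not the FillArrays type), with a NamedTuple playing the role of the structural Tangent. It assumes nothing beyond Base.

```julia
# Hypothetical stand-in for a Fill array: `n` copies of `value`.
struct MyFill{T} <: AbstractVector{T}
    value::T
    n::Int
end
Base.size(x::MyFill) = (x.n,)
Base.getindex(x::MyFill, i::Int) = x.value

# Hand-written reverse pass for sum(x) = n * value.
# The structural tangent has one entry per field; `n` is not differentiable.
function sum_with_pullback(x::MyFill)
    pullback(dy) = (value = x.n * dy, n = nothing)
    return x.n * x.value, pullback
end

x = MyFill(2.0, 3)
y, pb = sum_with_pullback(x)
structural = pb(1.0)

# Viewed as an array, the cotangent of `sum` is a vector of ones; collapsing
# it onto the single `value` field means summing its entries.
dense = ones(3)
```

The structural entry for `value` equals `sum(dense)`. That summation is exactly the step one has to remember when reasoning about the array-style tangent, and it comes for free in the struct view.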
Another good example is the WoodburyPDMat. I haven't the foggiest idea what an appropriate natural tangent would be, nor do I particularly want to. The structural doesn't suffer this problem. In this sense, it's much more natural to think about the structural tangent. Later in this issue, I'm going to use this as an example of a situation in which using the structural tangent is necessary, and the natural is undefined. If anyone wants to argue that I must derive a natural tangent for this matrix, then we will need to have a separate conversation.
Sometimes the natural is simply redundant
In the WoodburyPDMat example above, we intentionally only implement a few high-level linear algebra operations. There's simply never any need for a natural tangent because all of the code has been written in an AD-friendly manner (non-mutating, small number of differentiable high-level operations).
Other examples are ColVecs and RowVecs in the JuliaGPs ecosystem. They're thin wrappers around an AbstractMatrix which are really only designed to make its interpretation in a particular context clear. While getindex is defined, it's considered a bug if it's hit inside AD. Instead, the use of getfield is central to ColVecs and RowVecs usage. Consequently, AD ought always to be able to derive pullbacks automatically in practice; certainly we don't want to have to write any rules, or define ProjectTo.
Structural tangents are sometimes a necessity
Symmetric matrices are a good example of something that can use a natural differential a decent chunk of the time, but sometimes simply cannot.
If what a Symmetric wraps is a primitive, then we'll often want to use the natural differential. For example, this rule for svd. This is fine.
In other situations, it's easier to think about rules producing structural tangents. For example, Iain Murray's derivation of the rrule for the Cholesky factorisation uses only the upper triangle of a Symmetric, so it's easier to think in terms of the structural when writing that rrule.
A more extreme example is if one were to put a WoodburyPDMat inside a Symmetric and do AD. The tangent of a WoodburyPDMat will generally be a Tangent, which cannot itself be stored inside of a Symmetric. Therefore, the tangent of a Symmetric{<:Number, <:WoodburyPDMat} must be a Tangent.
A condition under which a given AbstractArray can be treated as a primitive
By calling a type T primitive, I mean that the tangents of T are always of type T themselves.
Assuming that we define the pullback of getfield on a given struct to return a Tangent, the above yields the following condition for the possibility of treating a particular AbstractArray (which is a struct) as a primitive:
Each field must be able to take the type of any of its tangents.
Interestingly, this seems to preclude treating a number of very common AbstractArrays as primitives:
Symmetric (wraps an AbstractMatrix, whose tangent may be a Tangent)
Diagonal (for the same reason)
Transpose / Adjoint (for the same reason)
SArray (wraps a Tuple, whose tangent is always a Tangent)
This is, of course, not to say that natural tangents ought not to be used for these types, it is simply to say that it's not possible to preclude the need for the use of a structural tangent some of the time.
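A minimal way to see the condition fail for Symmetric, assuming a NamedTuple as a stand-in for a Tangent. `fits_in_symmetric` is a hypothetical helper, and the field names `A` and `D` are placeholders rather than the actual WoodburyPDMat fields.

```julia
using LinearAlgebra

# Hypothetical helper: Symmetric requires its `data` field to be an
# AbstractMatrix, so a tangent type can be stored in that field only if it is
# itself an AbstractMatrix.
fits_in_symmetric(::Type{T}) where {T} = T <: AbstractMatrix

# A natural tangent for the wrapped matrix.
natural_T = Matrix{Float64}

# A NamedTuple standing in for the structural Tangent of a struct-like matrix.
structural_T = typeof((A = zeros(2, 2), D = zeros(2)))

natural_fits = fits_in_symmetric(natural_T)
structural_fits = fits_in_symmetric(structural_T)
```

Here `natural_fits` is true and `structural_fits` is false, which is the whole problem: as soon as the wrapped matrix can have a structural tangent, the wrapper cannot always hold it.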
Wait, we can't treat SArray as a primitive?
The above claim was made under the assumption that we make the pullback for the rrule for getfield return a Tangent. This is crucial. As @oxinabox pointed out the other day, we don't always have to define getfield this way.
If for an SArray we make said pullback return another SArray, the problem disappears. Indeed, in this case, we can safely treat SArrays as primitives and never have to worry about the structural tangent. For example, the rrule for getfield might be something like
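function rrule(::typeof(getfield), X::SArray, ::Any) # there's only one field, so nothing else to access
    # X.data is a tuple, so its tangent TY must be a Tangent; wrap it back up as an SArray.
    # getfield itself and the field name are not differentiable.
    getfield_pullback(TY::Tangent) = (NoTangent(), typeof(X)(TY...), NoTangent())
    return X.data, getfield_pullback
end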
We could also ensure that all other operations on SArrays return SArrays.
It's not obvious that the same can be done for Symmetric / Diagonal / Transpose / Adjoint etc. The thing that made it work for SArray was the ability to meaningfully convert the structural tangent of its field (a Tangent) into the type of its field (a Tuple).
Ultimately, this seems to come down to how we define getfield's rrule. If we are willing to put the work in to ensure that any tangent for a field can be converted into a thing that can be put back into the primitive (and still have sensible semantics, like + and Real * working) then we might be able to treat a given type as primitive. This simply isn't possible for lots of arrays because, as discussed above, we can't generally rule out the need for a structural tangent, which rules out Symmetric etc.
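A self-contained sketch of what "put the work in" could look like, using a hypothetical tuple-backed `MySVec` in place of an SArray and a plain Tuple in place of the Tangent of its field. None of these names come from StaticArrays or ChainRulesCore.

```julia
# Hypothetical stand-in for an SArray: a vector backed by a Tuple.
struct MySVec{N,T} <: AbstractVector{T}
    data::NTuple{N,T}
end
Base.size(::MySVec{N}) where {N} = (N,)
Base.getindex(x::MySVec, i::Int) = x.data[i]

# The "sensible semantics" a primitive tangent needs: + and Real *.
Base.:+(a::MySVec{N}, b::MySVec{N}) where {N} = MySVec(map(+, a.data, b.data))
Base.:*(c::Real, a::MySVec) = MySVec(map(v -> c * v, a.data))

# Reverse pass for accessing the only field. The cotangent of the Tuple is
# tuple-like, so it can be wrapped straight back up into a MySVec.
function data_with_pullback(x::MySVec)
    pullback(t::Tuple) = MySVec(t)
    return getfield(x, :data), pullback
end

x = MySVec((1.0, 2.0))
data, pb = data_with_pullback(x)
dx = pb((1.0, 2.0))        # tangent of x, of the same type as x
combined = dx + 2 * dx     # accumulation and scaling stay within MySVec
```

The conversion is trivial here only because the field is a Tuple of numbers. For a wrapper like Symmetric, the field's tangent may be an arbitrary struct-like object, and there is no analogous way to wrap it back up.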
This doesn't seem to have tied down exactly what we can / can't treat as a primitive, but I think it gets us closer, and it at least suggests that we can ensure the non-existence of structural tangents for SArrays if someone is willing to put the work in (and figure out where the code should live!)