Structural tangents are cool too, and SArrays can probably be primitives #441

Open
@willtebbutt

Natural tangents are helpful because, in a number of situations, they're more intelligible to humans than structural tangents, and they play nicely with generic linear operations written in rrules, which makes it more straightforward to write generic rrules. However, this is not always the case, and they can cause complications.

The purpose of this issue is to document situations in which structural tangents are either preferable or a necessity, and the complications that can occur when structural and natural tangents interact. Some of it is obvious, and I'm intentionally not proposing to do anything; I just needed to get this stuff out of my head so that I can focus on other work for the next couple of weeks without it bothering me.

Maybe we should link this in the docs that @oxinabox recently merged?

In the following, please assume that the primal is an AbstractArray, the structural tangent a Tangent, and the natural tangent another AbstractArray.

Interactions between structurals and naturals

Suppose we have code in which a hand-written pullback produces a natural tangent for a primal, and an automatically generated pullback produces a structural tangent for the same primal (say, because it hit getfield). At some point, these will need to be added together. This means that code must exist to convert one to the other, and it must be hit. To my knowledge, we don't really deal with this at the minute.
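A minimal sketch of that conversion step, assuming a hypothetical wrapper `MyDiag` (standing in for something like `Diagonal`) and using a plain `NamedTuple` as a stand-in for a `Tangent`; none of the helper names below are ChainRulesCore API. Since `MyDiag(d)` represents `diagm(d)`, a natural (dense matrix) tangent `N` corresponds to the structural tangent `(d = diag(N),)`, after which the two tangents can be added field by field.

```julia
using LinearAlgebra

# Hypothetical wrapper standing in for a type like `Diagonal`: represents diagm(d).
struct MyDiag{T} <: AbstractMatrix{T}
    d::Vector{T}
end
Base.size(X::MyDiag) = (length(X.d), length(X.d))
Base.getindex(X::MyDiag{T}, i::Int, j::Int) where {T} = i == j ? X.d[i] : zero(T)

# Structural tangents are modelled as NamedTuples with one entry per field
# (a stand-in for a `Tangent`, not its actual API).
# Only the diagonal of a natural tangent affects the field `d`.
natural_to_structural(::MyDiag, N::AbstractMatrix) = (d = diag(N),)

# Accumulate a natural and a structural tangent for the same primal by
# converting the natural one first, then adding field by field.
function accumulate_tangents(X::MyDiag, N::AbstractMatrix, S::NamedTuple{(:d,)})
    return map(+, natural_to_structural(X, N), S)
end

X = MyDiag([1.0, 2.0])
N = [10.0 3.0; 4.0 20.0]  # e.g. produced by a hand-written rule (natural)
S = (d = [1.0, 1.0],)     # e.g. produced via getfield in generated code (structural)
accumulate_tangents(X, N, S)  # (d = [11.0, 21.0],)
```

Converting in the natural-to-structural direction is the easy one here; going the other way requires choosing which matrix a `(d = ...,)` tangent should become, which is exactly the kind of decision the natural representation forces on us.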

Sometimes structurals are more natural

Clarity is in the eye of the beholder, but to my eye it's not uncommon for the structural tangent to be much more straightforwardly interpretable than the natural tangent. Moreover, it's always obvious what the structural tangent is, whereas one has to think (sometimes for an extended period of time, and consult with others) to determine an appropriate natural tangent. This problem is compounded by our current lack of a rigorous definition for the natural tangent.

One example of this is Fill arrays. If we think about a Fill as a struct in the context of AD, it's incredibly clear what's going on and what its structural tangent represents. Conversely, if we think about it as an array, its natural tangent requires a good deal of thought. Again, the lack of a proper definition of a natural tangent is possibly the culprit.
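To make the Fill comparison concrete, here is a sketch using a hypothetical `MyFill` (not the actual FillArrays definition) that stores a single `value` and a length. Viewed as a struct, the gradient of `sum(F) == F.n * F.value` is simply `(value = n,)`. Viewed as an array, the gradient is a dense vector of ones, and getting back to the structural tangent requires summing the natural tangent's entries, because every element is the same `value`. `nothing` stands in for "no tangent" on the integer field.

```julia
# Hypothetical Fill-like vector (not FillArrays.Fill): every element equals `value`.
struct MyFill{T} <: AbstractVector{T}
    value::T
    n::Int
end
Base.size(F::MyFill) = (F.n,)
Base.getindex(F::MyFill, i::Int) = (checkbounds(F, i); F.value)

# Gradient of sum(F) read off from the struct, since sum(F) == F.n * F.value.
# `nothing` stands in for "no tangent" on the integer field `n`.
structural_grad_sum(F::MyFill) = (value = oftype(F.value, F.n), n = nothing)

# Gradient of sum(F) if F is treated as an ordinary array: one entry per element.
natural_grad_sum(F::MyFill) = ones(eltype(F), length(F))

# Natural -> structural: each element depends on `value` with derivative 1,
# so the tangent of `value` is the sum of the natural tangent's entries.
natural_to_structural(::MyFill, N::AbstractVector) = (value = sum(N), n = nothing)

F = MyFill(2.5, 4)
natural_to_structural(F, natural_grad_sum(F)) == structural_grad_sum(F)  # true
```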

Another good example is the WoodburyPDMat. I haven't the foggiest idea what an appropriate natural tangent would be, nor do I particularly want to. The structural doesn't suffer this problem. In this sense, it's much more natural to think about the structural tangent. Later in this issue, I'm going to use this as an example of a situation in which using the structural tangent is necessary, and the natural is undefined. If anyone wants to argue that I must derive a natural tangent for this matrix, then we will need to have a separate conversation.

Sometimes the natural is simply redundant

Returning to the WoodburyPDMat example above, we intentionally implement only a few high-level linear algebra operations. There's simply never any need for a natural tangent, because all of the code has been written in an AD-friendly manner (non-mutating, with a small number of differentiable high-level operations).
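A sketch of that style, using a hypothetical `LowRankPlusDiag` type (not the actual WoodburyPDMat definition) that represents `A * A' + Diagonal(d)`. Its operations are written directly in terms of its fields, so a structural tangent (a `NamedTuple` stand-in for `Tangent`, one entry per field) is all that is ever needed, and such tangents add and scale field by field without any notion of a natural tangent.

```julia
using LinearAlgebra

# Hypothetical type (not the actual WoodburyPDMat) representing A * A' + Diagonal(d).
struct LowRankPlusDiag{T}
    A::Matrix{T}
    d::Vector{T}
end

# A high-level operation written directly in terms of the fields, using
# tr(A * A' + Diagonal(d)) == sum(abs2, A) + sum(d).
tr_lrd(W::LowRankPlusDiag) = sum(abs2, W.A) + sum(W.d)

# Its gradient as a structural tangent: one entry per field.
tr_lrd_grad(W::LowRankPlusDiag) = (A = 2 .* W.A, d = ones(eltype(W.d), length(W.d)))

# Structural tangents support the linear operations AD needs, field by field.
add_tangents(t1::NamedTuple, t2::NamedTuple) = map(+, t1, t2)
scale_tangent(a::Real, t::NamedTuple) = map(x -> a .* x, t)

W = LowRankPlusDiag([1.0 2.0; 3.0 4.0], [1.0, 1.0])
tr_lrd(W)  # 32.0, the same as tr(W.A * W.A' + Diagonal(W.d))
```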

Other examples are ColVecs and RowVecs in the JuliaGPs ecosystem. They're thin wrappers around an AbstractMatrix which are really only designed to make its interpretation in a particular context clear. While getindex is defined, it's considered a bug if it's hit inside AD. Instead, the use of getfield is central to ColVecs and RowVecs usage. Consequently, AD ought always to be able to derive pullbacks automatically in practice -- certainly we don't want to have to write any rules, or define ProjectTo.
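A minimal sketch of the pattern, using a hypothetical `MyColVecs` rather than the actual JuliaGPs definition: `getindex` exists so the wrapper reads as a vector of columns, but operations reach straight through to the wrapped matrix via the field. AD therefore only sees `getfield` followed by ordinary matrix operations, and the resulting structural tangent is simply `(X = ∂X,)`.

```julia
# Hypothetical ColVecs-like wrapper: a vector whose elements are the columns of X.
struct MyColVecs{T} <: AbstractVector{Vector{T}}
    X::Matrix{T}
end
Base.size(x::MyColVecs) = (size(x.X, 2),)
Base.getindex(x::MyColVecs, n::Int) = x.X[:, n]  # defined, but not meant to be hit inside AD

# Sum of squared column norms, written via the field rather than via getindex.
sum_sq_norms(x::MyColVecs) = sum(abs2, x.X)

# Its gradient as a structural tangent: d/dX sum(abs2, X) == 2X.
sum_sq_norms_grad(x::MyColVecs) = (X = 2 .* x.X,)

x = MyColVecs([1.0 2.0 3.0; 4.0 5.0 6.0])
sum_sq_norms(x) == sum(v -> sum(abs2, v), x)  # true: same result as iterating over columns
```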

Structural tangents are sometimes a necessity

Symmetric matrices are a good example of something that can use a natural differential a decent chunk of the time, but sometimes simply cannot.

If a Symmetric wraps a primitive (e.g. a Matrix{Float64}), then we'll often want to use the natural differential. For example, this rule for svd. This is fine.

In other situations, it's easier to think about rules producing structural tangents. For example, Iain Murray's derivation of the rrule for the Cholesky factorisation uses only the upper triangle of a Symmetric, so it's easier to think in terms of the structural when writing that rrule.
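The conversion between the two representations is mechanical once written down. A sketch under one possible convention, assuming `Symmetric(data, :U)` (so only the upper triangle of `data` is read), taking the structural tangent to be the matrix `Δdata` for the `data` field, and taking the natural tangent to be a symmetric matrix `N`. A perturbation `E` of `data` changes the symmetric matrix by `Symmetric(E, :U)`, so requiring `dot(N, Symmetric(E, :U)) == dot(Δdata, E)` for every `E` gives the two maps below. The helper names are hypothetical.

```julia
using LinearAlgebra

# Natural tangent N (symmetric) -> structural tangent of the `data` field of
# Symmetric(data, :U). Off-diagonal entries appear twice in the symmetric matrix,
# so they are doubled; the lower triangle of `data` is never read, so it gets zero.
natural_to_structural_U(N::AbstractMatrix) = triu(2 .* N) - Diagonal(diag(N))

# Structural tangent Δdata (only its upper triangle is meaningful) -> natural tangent.
function structural_to_natural_U(Δdata::AbstractMatrix)
    U = triu(Δdata)
    return (U + U') / 2
end

N = [1 2; 2 3]
natural_to_structural_U(N)  # [1 4; 0 3]
```

A rule written in the style of the Cholesky derivation above naturally produces something shaped like `Δdata` (upper triangular), and the second map is what would be needed to hand it to code expecting a natural tangent.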

A more extreme example is if one were to put a WoodburyPDMat inside a Symmetric and do AD. The tangent of a WoodburyPDMat will generally be a Tangent, which cannot itself be stored inside of a Symmetric. Therefore, the tangent of a Symmetric{<:Number, <:WoodburyPDMat} must be a Tangent.

A condition under which a given AbstractArray can be treated as a primitive

By calling a type T primitive, I mean that the tangents of T are always of type T themselves.

Assuming that we define the pullback of getfield on a given struct to return a Tangent, the above yields the following condition for the possibility of treating a particular AbstractArray (which is a struct) as a primitive:

Each field must be able to hold any of its tangents, i.e. every tangent of a field must be representable as a value of that field's type.

Interestingly, this seems to preclude treating a number of very common AbstractArrays as primitives:

  • Symmetric (wraps an AbstractMatrix, whose tangent may be a Tangent)
  • Diagonal (for the same reason)
  • Transpose / Adjoint (for the same reason)
  • SArray (wraps a Tuple, whose tangent is always a Tangent)

This is, of course, not to say that natural tangents ought not to be used for these types; it is simply to say that it's not possible to rule out the need for a structural tangent some of the time.

Wait, we can't treat SArray as a primitive?

The above claim was made under the assumption that the pullback in the rrule for getfield returns a Tangent. This is crucial. As @oxinabox pointed out the other day, we don't always have to define getfield's rrule this way.

If for an SArray we make said pullback return another SArray, the problem disappears. Indeed, in this case, we can safely treat SArrays as primitives and never have to worry about the structural tangent. For example, the rrule for getfield might look something like

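using ChainRulesCore, StaticArrays

function ChainRulesCore.rrule(::typeof(getfield), S::SArray, field) # SArray has only one field, `data`
    # S.data is a Tuple, so its tangent Δdata is a Tangent{<:Tuple}; repack its entries into an SArray of the same size
    getfield_pullback(Δdata::Tangent) = (NoTangent(), SArray{Tuple{size(S)...}}(ChainRulesCore.backing(Δdata)), NoTangent())
    return getfield(S, field), getfield_pullback
end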

We could also ensure that all other operations on SArrays return SArrays.

It's not obvious that the same can be done for Symmetric / Diagonal / Transpose / Adjoint etc. The thing that made it work for SArray was the ability to meaningfully convert the structural tangent of its field (a Tangent) into the type of its field (a Tuple).

Ultimately, this seems to come down to how we define getfield's rrule. If we are willing to put in the work to ensure that any tangent for a field can be converted into something that can be put back into the primitive (while still having sensible semantics, such as + and multiplication by a Real working), then we might be able to treat a given type as primitive. This simply isn't possible for lots of arrays because, as discussed above, we can't generally rule out the need for a structural tangent, which rules out Symmetric etc.
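A pure-Julia sketch of that idea, using a hypothetical tuple-backed `MyStaticVec` in place of an actual `SArray`, and a plain `Tuple` as the stand-in for the structural tangent of its `data` field. The pullback for accessing the field repacks the field's tangent into a `MyStaticVec`, and `+` and multiplication by a `Real` are defined so that tangents built from them stay `MyStaticVec`s. None of these names are StaticArrays or ChainRulesCore API.

```julia
# Hypothetical tuple-backed static vector, standing in for SArray.
struct MyStaticVec{N,T} <: AbstractVector{T}
    data::NTuple{N,T}
end
Base.size(::MyStaticVec{N}) where {N} = (N,)
Base.getindex(v::MyStaticVec, i::Int) = v.data[i]

# Linear operations return MyStaticVec, so tangents built from them stay in the type.
Base.:+(a::MyStaticVec{N}, b::MyStaticVec{N}) where {N} = MyStaticVec(map(+, a.data, b.data))
Base.:*(s::Real, v::MyStaticVec) = MyStaticVec(map(x -> s * x, v.data))

# Field access with a pullback: the field's tangent (a plain Tuple here, standing in
# for a Tangent{<:Tuple}) is repacked into a MyStaticVec, i.e. a natural tangent for
# the primal, rather than being passed back as a structural tangent.
function getdata_with_pullback(v::MyStaticVec{N}) where {N}
    getdata_pullback(Δdata::NTuple{N}) = MyStaticVec(Δdata)
    return v.data, getdata_pullback
end

v = MyStaticVec((1.0, 2.0, 3.0))
data, back = getdata_with_pullback(v)
Δv = back((0.5, 0.0, 1.0))  # a MyStaticVec wrapping (0.5, 0.0, 1.0)
```

The conversion works here only because a tuple of per-element tangents maps one-to-one onto the wrapped tuple; the same trick has no obvious analogue when the wrapped field's tangent can be an arbitrary `Tangent`.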

This doesn't tie down exactly what we can and can't treat as a primitive, but I think it gets us closer, and it at least suggests that we can ensure the non-existence of structural tangents for SArrays if someone is willing to put in the work (and figure out where the code should live!)

Metadata

Assignees

No one assigned

    Labels

    Structural Tangent: Related to the `Tangent` type for structured (composite) values

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
