writing rules for <:AbstractArray #582

Open
@maartenvd

How should one write "proper" rules for methods that work for generic AbstractArray objects?

As an example, take this function:

function _setindex(a::AbstractArray,v,args...)
    b::typeof(a) = copy(a);
    b[args...] = v
    b
end
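
For reference, a minimal usage sketch of this function on a plain `Vector` (the function is repeated so the snippet runs on its own; the variable names are only illustrative). It shows the property the rule below relies on: the input array is left untouched and a modified copy is returned.

```julia
# Non-mutating setindex, as defined above.
function _setindex(a::AbstractArray, v, args...)
    b::typeof(a) = copy(a)
    b[args...] = v
    b
end

a = [1.0, 2.0, 3.0]

# Scalar index: only the copy is modified.
b = _setindex(a, 10.0, 2)

# Range index with an array value of matching length.
c = _setindex(a, [7.0, 8.0], 2:3)

println(a)  # the original array is unchanged
println(b)
println(c)
```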

This method seems pretty tame, and I think it should be generically correct for any AbstractArray object. The backward rule looks simple:

function ChainRulesCore.rrule(::typeof(_setindex),a::AbstractArray,tv,args...) 
    t = _setindex(a,tv,args...);
    
    function toret(v)
        backwards_tv = v[args...];
        backwards_a = copy(v);
        backwards_a[args...] = zero.(backwards_a[args...])
        (NoTangent(),backwards_a,backwards_tv,fill(ZeroTangent(),length(args))...)
    end
    t,toret
end

This doesn't work of course: v can be a ZeroTangent, which cannot be indexed or copied like an array. Let's correct for this case:

function ChainRulesCore.rrule(::typeof(_setindex),a::AbstractArray,tv,args...) 
    t = _setindex(a,tv,args...);
    
    function toret(v)
        if iszero(v)
            backwards_tv = ZeroTangent();
            backwards_a = ZeroTangent();
        else
            backwards_tv = v[args...];
            backwards_a = copy(v);
            backwards_a[args...] = zero.(backwards_a[args...])
        end
        (NoTangent(),backwards_a,backwards_tv,fill(ZeroTangent(),length(args))...)
    end
    t,toret
end

But this rule is still incorrect! When working with arrays, the tangent can sometimes be an array from FillArrays. These don't define setindex!, so the in-place zeroing fails, but they can be converted to the type of a:

function ChainRulesCore.rrule(::typeof(_setindex),a::AbstractArray,tv,args...) 
    t = _setindex(a,tv,args...);
    
    function toret(v)
        if iszero(v)
            backwards_tv = ZeroTangent();
            backwards_a = ZeroTangent();
        else
            v = convert(typeof(a),v);
            backwards_tv = v[args...];
            backwards_a = copy(v);
            backwards_a[args...] = zero.(backwards_a[args...])
        end
        (NoTangent(),backwards_a,backwards_tv,fill(ZeroTangent(),length(args))...)
    end
    t,toret
end

Still wrong of course, as v can also be a Tangent. A Tangent cannot be copied or converted, but an object of the primal type can be constructed from it:

function ChainRulesCore.rrule(::typeof(_setindex),a::AbstractArray,tv,args...) 
    t = _setindex(a,tv,args...);
    
    function toret(v)
        if iszero(v)
            backwards_tv = ZeroTangent();
            backwards_a = ZeroTangent();
        else
            v = v isa Tangent ? construct(typeof(a),v) : v;
            v = convert(typeof(a),v);
            backwards_tv = v[args...];
            backwards_a = copy(v);
            backwards_a[args...] = zero.(backwards_a[args...])
        end
        (NoTangent(),backwards_a,backwards_tv,fill(ZeroTangent(),length(args))...)
    end
    t,toret
end

In short, my rrule essentially has to be a spaghetti of if statements, and at the end I still have no way of knowing whether my implementation will work in practice. There is no list of possible tangent types, nor a formal interface that they should all satisfy, so whatever operations I do may end up being undefined.

I have read the documentation, and I just don't understand how I am supposed to write this backward rule. I also don't understand how I am supposed to hook up my own types so that they play nicely with ChainRules.

The year-old PR #446 seems like a step in the right direction, but even that wouldn't solve the issue completely. ProjectTo is defined in such a way that, when faced with a type it doesn't know, it falls back to just returning the same Tangent type.
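
To make the discussion concrete, here is a sketch of how the rule might look when it leans on `unthunk`, `AbstractZero`, and `ProjectTo` from ChainRulesCore instead of hand-written conversions. It is not a confirmed or recommended solution: it assumes the incoming cotangent is either an `AbstractZero` or some `AbstractArray` that `collect` can turn into a dense `Array`, and it deliberately does not handle a structural `Tangent`, which is exactly the case described above. The names `_setindex_pullback`, `project_a`, `project_tv`, and `dargs` are illustrative only.

```julia
using ChainRulesCore

# Same primal function as in the issue, repeated so the sketch is self-contained.
function _setindex(a::AbstractArray, v, args...)
    b = copy(a)
    b[args...] = v
    return b
end

function ChainRulesCore.rrule(::typeof(_setindex), a::AbstractArray, tv, args...)
    y = _setindex(a, tv, args...)
    # Projectors map a cotangent back onto the space of the corresponding primal.
    project_a = ProjectTo(a)
    project_tv = ProjectTo(tv)

    function _setindex_pullback(ybar)
        dy = unthunk(ybar)                    # resolve a thunk, if one was passed
        dargs = map(_ -> NoTangent(), args)   # indices are not differentiable
        # Zero cotangents pass straight through without any array operations.
        dy isa AbstractZero && return (NoTangent(), dy, dy, dargs...)
        # ASSUMPTION: dy is an AbstractArray here. collect yields a fresh, mutable
        # Array, so immutable array types (such as lazily filled arrays) are fine.
        # A structural Tangent would fail at this line.
        da = collect(dy)
        dtv = da[args...]          # cotangent of the inserted value
        da[args...] = zero(dtv)    # overwritten entries of a get no gradient
        return (NoTangent(), project_a(da), project_tv(dtv), dargs...)
    end
    return y, _setindex_pullback
end

# Usage on plain vectors.
y, pb = rrule(_setindex, [1.0, 2.0, 3.0], 10.0, 2)
_, da, dtv, di = pb([1.0, 1.0, 1.0])
```

With the dense cotangent above, `da` has a zero at the overwritten position and `dtv` picks up the cotangent entry at that position; calling `pb(ZeroTangent())` returns zero tangents for both `a` and `tv`. The open question from this issue remains: nothing in this sketch says what to do when `dy` is a `Tangent` for a custom array type.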
