Description
How should one write "proper" rules for methods that work for generic AbstractArray objects?
As an example, take this function:
function _setindex(a::AbstractArray, v, args...)
    b::typeof(a) = copy(a)
    b[args...] = v
    b
end
This method seems pretty tame, and I think it should be generically correct for any AbstractArray object. The backward rule looks simple:
function ChainRulesCore.rrule(::typeof(_setindex), a::AbstractArray, tv, args...)
    t = _setindex(a, tv, args...)
    function toret(v)
        backwards_tv = v[args...]
        backwards_a = copy(v)
        backwards_a[args...] = zero.(backwards_a[args...])
        (NoTangent(), backwards_a, backwards_tv, fill(ZeroTangent(), length(args))...)
    end
    t, toret
end
This doesn't work, of course: v can be a ZeroTangent! Let's correct for this case:
function ChainRulesCore.rrule(::typeof(_setindex), a::AbstractArray, tv, args...)
    t = _setindex(a, tv, args...)
    function toret(v)
        if iszero(v)
            backwards_tv = ZeroTangent()
            backwards_a = ZeroTangent()
        else
            backwards_tv = v[args...]
            backwards_a = copy(v)
            backwards_a[args...] = zero.(backwards_a[args...])
        end
        (NoTangent(), backwards_a, backwards_tv, fill(ZeroTangent(), length(args))...)
    end
    t, toret
end
But this rule is still incorrect! When working with arrays, the tangent type can sometimes be a FillArray. FillArrays don't define setindex!, but they can be converted.
function ChainRulesCore.rrule(::typeof(_setindex), a::AbstractArray, tv, args...)
    t = _setindex(a, tv, args...)
    function toret(v)
        if iszero(v)
            backwards_tv = ZeroTangent()
            backwards_a = ZeroTangent()
        else
            v = convert(typeof(a), v)
            backwards_tv = v[args...]
            backwards_a = copy(v)
            backwards_a[args...] = zero.(backwards_a[args...])
        end
        (NoTangent(), backwards_a, backwards_tv, fill(ZeroTangent(), length(args))...)
    end
    t, toret
end
Still wrong, of course: v can also be a Tangent, which cannot be copied or converted, but it can be constructed!
function ChainRulesCore.rrule(::typeof(_setindex), a::AbstractArray, tv, args...)
    t = _setindex(a, tv, args...)
    function toret(v)
        if iszero(v)
            backwards_tv = ZeroTangent()
            backwards_a = ZeroTangent()
        else
            v = v isa Tangent ? construct(typeof(a), v) : v
            v = convert(typeof(a), v)
            backwards_tv = v[args...]
            backwards_a = copy(v)
            backwards_a[args...] = zero.(backwards_a[args...])
        end
        (NoTangent(), backwards_a, backwards_tv, fill(ZeroTangent(), length(args))...)
    end
    t, toret
end
In short, my rrule essentially has to be a spaghetti of if statements, and at the end I have no way of knowing whether my implementation will work in practice. There is no list of possible tangent types, nor a formal interface that they should all satisfy, so whatever operations I perform may end up being undefined.
I have read the documentation, and I just don't understand how I am supposed to write this backward rule. I also don't understand how I am supposed to hook up my own types so that they play nicely with ChainRules.
This year-old PR (#446) seems like a step in the right direction, but even that wouldn't solve the issue completely. ProjectTo is defined in such a way that, when faced with a type it doesn't know, it falls back to just returning the same Tangent type.
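To make the problem concrete, here is a sketch of how far the branching above could be collapsed using only exported ChainRulesCore names (unthunk, AbstractZero, NoTangent). It is not a complete answer. Several things in it are assumptions rather than anything taken from the ChainRules documentation: the pullback name _setindex_pullback is made up, collect is assumed to be an acceptable way to get a mutable copy of an array-like cotangent such as a Fill, and the index arguments are given NoTangent() instead of ZeroTangent(). A structural Tangent is deliberately not handled, which is exactly the gap described above.

```julia
using ChainRulesCore

# Same function as in the post, repeated so the sketch is self-contained.
function _setindex(a::AbstractArray, v, args...)
    b = copy(a)
    b[args...] = v
    return b
end

function ChainRulesCore.rrule(::typeof(_setindex), a::AbstractArray, tv, args...)
    t = _setindex(a, tv, args...)
    # Hypothetical pullback name, chosen only for this sketch.
    function _setindex_pullback(v̄)
        # Materialize the cotangent in case it arrives as a thunk.
        v = unthunk(v̄)
        # Assumption: indices are treated as non-differentiable.
        index_tangents = map(_ -> NoTangent(), args)
        # ZeroTangent / NoTangent cotangents are passed straight through.
        v isa AbstractZero && return (NoTangent(), v, v, index_tangents...)
        # From here on v is assumed to be array-like; a Tangent would fail here.
        backwards_tv = v[args...]
        # collect gives a dense, mutable copy even for immutable array types.
        backwards_a = collect(v)
        backwards_a[args...] = zero.(backwards_tv)
        return (NoTangent(), backwards_a, backwards_tv, index_tangents...)
    end
    return t, _setindex_pullback
end

# Usage: overwrite the second entry, then pull back a cotangent of ones.
a = [1.0, 2.0, 3.0]
y, back = rrule(_setindex, a, 10.0, 2)
_, da, dtv, _ = back([1.0, 1.0, 1.0])
```

Even in this form the rule rests on an unstated contract (that any non-zero cotangent supports getindex and collect), so the underlying question stands.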