To account for #644 in Yota, I changed all calls like this:
rrule(cfg, broadcasted, f, args...)
to this:
rrule(cfg, broadcasted, bcast_style, f, args...)
However, with this change many rules are no longer triggered. One example I ran into is the activation functions from NNlib, e.g.:
julia> @which rrule(YotaRuleConfig(), broadcasted, leakyrelu, x, 0.2f0)
rrule(::RuleConfig, args...) in ChainRulesCore at /home/azbs/.julia/packages/ChainRulesCore/ctmSK/src/rules.jl:134
# ^ generic fallback: just calls rrule with the same arguments minus the config
julia> @which rrule(broadcasted, leakyrelu, x, 0.2f0)
rrule(::typeof(Base.Broadcast.broadcasted), ::typeof(leakyrelu), x1::Union{AbstractArray{<:T}, T} where T<:Number, x2::Number) in NNlib at /home/azbs/.julia/packages/NNlib/0QnJJ/src/activations.jl:909
# ^ correctly points to rrule defined in NNlib
julia> @which rrule(YotaRuleConfig(), broadcasted, Base.Broadcast.DefaultArrayStyle{2}(), leakyrelu, x, 0.2f0)
rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.Broadcast.broadcasted), ::Base.Broadcast.BroadcastStyle, f::F, args::Vararg{Any, N}) where {F, N} in ChainRules at /home/azbs/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:29
# ^ hits generic broadcasting, bypassing NNlib rrules
I think I could check whether an rrule() without the BroadcastStyle exists before adding the style to the call, but I was wondering whether anyone else has run into this problem and how they solved it.
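A minimal sketch of that check, assuming ChainRulesCore's catch-all methods behave as the `@which` output above shows. Because `rrule(::RuleConfig, args...)` always matches, `hasmethod` or `which` cannot tell whether a real rule exists. Instead, the sketch calls the style-free `rrule` and inspects the result, since the final fallback returns `nothing`. `broadcast_rrule`, `MyCfg` and `myact` (a stand-in for NNlib's `leakyrelu`) are hypothetical names, and `Base.Broadcast.combine_styles` is an internal Base helper rather than public API.

```julia
using ChainRulesCore
using Base.Broadcast: broadcasted, combine_styles

# Hypothetical Yota-like config, only for illustration.
struct MyCfg <: RuleConfig{Union{HasReverseMode,NoForwardsMode}} end

# Try rrule(cfg, broadcasted, f, args...) first; fall back to the
# style-aware signature only when no style-free rule is defined.
function broadcast_rrule(cfg::RuleConfig, f, args...)
    # ChainRulesCore's final fallback returns `nothing`, so a non-`nothing`
    # result means some package (e.g. NNlib) defined a style-free rule.
    res = rrule(cfg, broadcasted, f, args...)
    res === nothing || return res
    # Otherwise compute the style roughly the way Base does (without
    # applying `broadcastable`) and use the style-aware signature.
    style = combine_styles(args...)
    return rrule(cfg, broadcasted, style, f, args...)
end

# Hypothetical stand-in for NNlib's leakyrelu, with a style-free rule.
myact(x, a) = x > 0 ? x : a * x

function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(myact), x, a::Number)
    y = myact.(x, a)
    function myact_pullback(dy)
        # d/dx is 1 for positive inputs and `a` otherwise; `a` is treated as constant
        dx = unthunk(dy) .* ifelse.(x .> 0, one(a), a)
        return (NoTangent(), NoTangent(), dx, NoTangent())
    end
    return y, myact_pullback
end

x = Float32[-1, 2]
y, pb = broadcast_rrule(MyCfg(), myact, x, 0.2f0)  # picks the style-free rule
```

The first call only does forward work when a rule actually matches, so the extra dispatch is cheap when it returns `nothing`. With ChainRules loaded, the second call would reach its generic broadcast rule (the `broadcast.jl:29` method above). One limitation: if some package defines a style-free rule that matches everything, the style-aware path is never taken.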