rrule(broadcasted, ...) with BroadcastStyle skips many rules  #663

Closed
@dfdx

To account for #644 in Yota, I changed all calls like this:

rrule(cfg, broadcasted, f, args...)

to this:

rrule(cfg, broadcasted, bcast_style, f, args...)

However, with this change many existing rules are no longer triggered. One example I ran into is the activation functions from NNlib, e.g.:

julia> @which rrule(YotaRuleConfig(), broadcasted, leakyrelu, x, 0.2f0)
rrule(::RuleConfig, args...) in ChainRulesCore at /home/azbs/.julia/packages/ChainRulesCore/ctmSK/src/rules.jl:134
# ^ just invokes the same without the config

julia> @which rrule(broadcasted, leakyrelu, x, 0.2f0)
rrule(::typeof(Base.Broadcast.broadcasted), ::typeof(leakyrelu), x1::Union{AbstractArray{<:T}, T} where T<:Number, x2::Number) in NNlib at /home/azbs/.julia/packages/NNlib/0QnJJ/src/activations.jl:909
# ^ correctly points to rrule defined in NNlib

julia> @which rrule(YotaRuleConfig(), broadcasted, Base.Broadcast.DefaultArrayStyle{2}(), leakyrelu, x, 0.2f0)
rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.Broadcast.broadcasted), ::Base.Broadcast.BroadcastStyle, f::F, args::Vararg{Any, N}) where {F, N} in ChainRules at /home/azbs/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:29
# ^ hits generic broadcasting, bypassing NNlib rrules

I think I can check whether an rrule() without the BroadcastStyle exists before adding the style to the call, but I was wondering whether anybody else has encountered this problem and how they solved it.
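A rough sketch of that check, in case it helps the discussion. It assumes that a missing rule shows up as `rrule` returning `nothing` (the ChainRulesCore fallback); `broadcast_rrule` is a hypothetical helper name, not something Yota or ChainRules provides:

```julia
using ChainRulesCore
using Base.Broadcast: broadcasted, BroadcastStyle

# Hypothetical helper (not part of Yota or ChainRules): prefer a rule written
# for `broadcasted(f, args...)` and only fall back to the style-aware call
# when no such rule exists.
function broadcast_rrule(cfg::RuleConfig, style::BroadcastStyle, f, args...)
    # Assumption: when no rule matches, the generic ChainRulesCore fallback
    # returns `nothing`.
    res = rrule(cfg, broadcasted, f, args...)
    res === nothing || return res
    # No style-free rule found, so try the call that carries the BroadcastStyle.
    return rrule(cfg, broadcasted, style, f, args...)
end
```

With this, something like `broadcast_rrule(YotaRuleConfig(), Base.Broadcast.DefaultArrayStyle{2}(), leakyrelu, x, 0.2f0)` should reach the NNlib rule first. The cost is an extra `rrule` lookup for every broadcast that has no dedicated rule.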
