Skip to content

Commit c971665

Browse files
committed
fixup
1 parent 84989e6 commit c971665

File tree

3 files changed

+108
-94
lines changed

3 files changed

+108
-94
lines changed

src/extra_rules.jl

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -138,27 +138,17 @@ struct NonDiffOdd{N, O, P}; end
138138
(::NonDiffEven{N, O, O})(Δ...) where {N, O} = error()
139139

140140
# WARNING: Method definition rrule(typeof(Core.apply_type), Any, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/kkDLd/src/rulesets/Core/core.jl:10 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:140.
141-
# @Base.pure function ChainRulesCore.rrule(::typeof(Core.apply_type), head, args...)
142-
# Core.apply_type(head, args...), NonDiffOdd{plus1(plus1(length(args))), 1, 1}()
143-
# end
141+
# Reverse rule for `Core.apply_type(head, args...)`: the primal is the applied type, and the
# pullback slot holds a `NonDiffOdd` instance whose first type parameter is `plus1` applied
# twice to `length(args)`.
@Base.pure function ChainRulesCore.rrule(::typeof(Core.apply_type), head, args...)
    applied = Core.apply_type(head, args...)
    return applied, NonDiffOdd{plus1(plus1(length(args))), 1, 1}()
end
144144

145145
# Reverse rule for `Core.tuple`: the pullback splats the incoming tuple cotangent back
# onto the arguments, preceded by `NoTangent()` for the function itself.
function ChainRulesCore.rrule(::typeof(Core.tuple), args...)
    tuple_pullback(Δ) = Core.tuple(NoTangent(), Δ...)
    return Core.tuple(args...), tuple_pullback
end
148148

149-
# TODO: What to do about these integer rules
150-
# @ChainRulesCore.non_differentiable Base.rem(a::Integer, b::Type) # now in CR 1.18
151-
152149
ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent()
153150
ChainRulesCore.canonicalize(::NoTangent) = NoTangent()
154151

155-
# # Skip AD'ing through the axis computation
156-
# function ChainRules.rrule(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
157-
# return Base.Broadcast.instantiate(bc), Δ->begin
158-
# Core.tuple(NoTangent(), Δ)
159-
# end
160-
# end
161-
162152

163153
using StaticArrays
164154

@@ -201,22 +191,6 @@ function ChainRules.rrule(::typeof(map), ::typeof(+), A::AbstractVector, B::Abst
201191
map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ)
202192
end
203193

204-
# https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/array.jl#L7
205-
# function ChainRules.rrule(AT::Type{<:Array{T,N}}, x::AbstractArray{S,N}) where {T,S,N}
206-
# # We're leaving these in the eltype that the cotangent vector already has.
207-
# # There isn't really a good reason to believe we should convert to the
208-
# # original array type, so don't unless explicitly requested.
209-
# AT(x), Δ->(NoTangent(), Δ)
210-
# end
211-
212-
# WARNING: Method definition rrule(Type{var"#s260"} where var"#s260"<:(Array{T, N} where N where T), UndefInitializer, Any...) in module ChainRules at /Users/me/.julia/packages/ChainRules/kkDLd/src/rulesets/Base/array.jl:5 overwritten in module Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:209.
213-
# function ChainRules.rrule(AT::Type{<:Array}, undef::UndefInitializer, args...)
214-
# # We're leaving these in the eltype that the cotangent vector already has.
215-
# # There isn't really a good reason to believe we should convert to the
216-
# # original array type, so don't unless explicitly requested.
217-
# AT(undef, args...), Δ->(NoTangent(), NoTangent(), ntuple(_->NoTangent(), length(args))...)
218-
# end
219-
220194
# Split a tuple of indexable items into two tuples: the first collects element `1` of
# every item, the second collects element `2`.
function unzip_tuple(t::Tuple)
    firsts = map(Base.Fix2(getindex, 1), t)
    seconds = map(Base.Fix2(getindex, 2), t)
    return firsts, seconds
end
@@ -256,7 +230,6 @@ function ChainRules.frule(_, ::Type{Vector{T}}, undef::UndefInitializer, dims::I
256230
Vector{T}(undef, dims...), zeros(T, dims...)
257231
end
258232

259-
# @ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer) CR#558
260233
@ChainRules.non_differentiable Base.throw(err)
261234
@ChainRules.non_differentiable Core.Compiler.return_type(args...)
262235

@@ -273,3 +246,13 @@ end
273246
# ERROR: ArgumentError: Tangent for the primal Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}} should be backed by a AbstractDict type, not by NamedTuple{(:data,), Tuple{ChainRulesCore.ZeroTangent}}.
274247
ChainRulesCore._backing_error(::Type{<:Base.Pairs{Symbol}}, ::Type{<:NamedTuple}, _) = nothing # solves that!
275248

249+
# Rather than have a rule for broadcasted 3-arg *, just send it to the efficient path:
250+
ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number, z::Number) = ((y*z, x*z, x*y),)
251+
# Partial derivatives of the 4-argument product `x*y*z*w` with respect to each argument,
# wrapped in a 1-tuple. The output `Ω` is not needed.
function ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number, z::Number, w::Number)
    # Share the two pairwise products between the four partials.
    xy = x * y
    zw = z * w
    return ((y * zw, x * zw, xy * w, xy * z),)
end
256+
257+
# Fixes @test (x->sum(isa_control_flow(Matrix{Float64}, x)))'(Float32[1 2;]) == [1.0 1.0;]
258+
(project::ProjectTo{<:AbstractArray})(th::InplaceableThunk) = project(unthunk(th))

src/stage1/broadcast.jl

Lines changed: 57 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ end
3434
using ChainRulesCore: derivatives_given_output
3535

3636
# Broadcast over one element is just map
37-
function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
38-
∂⃖ₙ(map, f, a)
39-
end
37+
# function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
38+
# ∂⃖ₙ(map, f, a)
39+
# end
40+
41+
(::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ)
4042

41-
(::∂⃖{1})(::typeof(broadcasted), f, args...) = split_bc_rule(f, args...)
42-
(::∂⃖{1})(::typeof(broadcasted), f, arg::Array) = split_bc_rule(f, arg) # ambiguity
43+
(::∂⃖{1})(::typeof(broadcasted), f::F, args...) where {F} = split_bc_rule(f, args...)
44+
# (::∂⃖{1})(::typeof(broadcasted), f::F, arg::Array) where {F} = split_bc_rule(f, arg) # ambiguity
4345
function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
4446
T = Broadcast.combine_eltypes(f, args)
4547
TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...})
@@ -48,17 +50,17 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
4850
back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2)
4951
return f.(args...), back_1
5052
elseif T <: Number && isconcretetype(TΔ)
51-
# Fast path: just broadcast, and use x & y to find derivative.
53+
# Fast path: just broadcast, and use arguments & result to find derivatives.
5254
ys = f.(args...)
5355
function back_2_one(dys) # For f.(x) we do not need StructArrays / unzip at all
5456
delta = broadcast(unthunk(dys), ys, args...) do dy, y, a
5557
das = only(derivatives_given_output(y, f, a))
56-
dy * conj(only(das))
58+
dy * conj(only(das)) # possibly this * should be made nan-safe.
5759
end
5860
(NoTangent(), NoTangent(), unbroadcast(only(args), delta))
5961
end
6062
function back_2_many(dys)
61-
deltas = splitcast(unthunk(dys), ys, args...) do dy, y, as...
63+
deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as...
6264
das = only(derivatives_given_output(y, f, as...))
6365
map(da -> dy * conj(da), das)
6466
end
@@ -70,62 +72,76 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
7072
# Slow path: collect all the pullbacks & apply them later.
7173
# (Since broadcast makes no guarantee about order of calls, and un-fusing
7274
# can change the number of calls, this does not bother to try to reverse.)
73-
ys, backs = splitcast(∂⃖{1}(), f, args...)
75+
ys3, backs = tuplecast(∂⃖{1}(), f, args...)
7476
function back_3(dys)
75-
deltas = splitmap(backs, unthunk(dys)) do back, dy
77+
deltas = tuplecast(backs, unthunk(dys)) do back, dy # could be map, sizes match
7678
map(unthunk, back(dy))
7779
end
78-
dargs = map(unbroadcast, args, Base.tail(deltas)) # no real need to close over args here
80+
dargs = map(unbroadcast, args, Base.tail(deltas))
7981
(NoTangent(), sum(first(deltas)), dargs...)
8082
end
8183
back_3(::AbstractZero) = (NoTangent(), map(Returns(ZeroTangent()), args)...)
82-
return ys, back_3
84+
return ys3, back_3
8385
end
8486
end
8587

88+
# Don't run broadcasting on scalars
89+
# All-scalar arguments: differentiate the plain call `f(args...)` instead of broadcasting,
# and prepend one `NoTangent()` to whatever the inner pullback returns
# (presumably the slot for `broadcasted` itself — confirm against callers).
function split_bc_rule(f::F, args::Number...) where {F}
    z, back = ∂⃖{1}()(f, args...)
    scalar_bc_pullback(dz) = (NoTangent(), back(dz)...)
    return z, scalar_bc_pullback
end
93+
94+
split_bc_rule(::typeof(identity), x) = x, Δ -> (NoTangent(), NoTangent(), Δ)
95+
split_bc_rule(::typeof(identity), x::Number) = x, Δ -> (NoTangent(), NoTangent(), Δ)
96+
8697
# Skip AD'ing through the axis computation
8798
function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
8899
uninstantiate(Δ) = Core.tuple(NoTangent(), Δ)
89100
return Base.Broadcast.instantiate(bc), uninstantiate
90101
end
91102

92-
# This uses "multimap"-like constructs:
93103
using StructArrays
94-
splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...)))
95-
splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))))
104+
105+
# Broadcast a tuple-returning `f` over `args` and return the components of the resulting
# `StructArray`, so callers can destructure the result position by position.
#
# Throws `ArgumentError` when the inferred element type is concrete but not a `Tuple`.
function tuplecast(f::F, args...) where {F}
    T = Broadcast.combine_eltypes(f, args)
    # Only reject when inference yields a concrete non-tuple type; a non-concrete `T`
    # is passed through to `StructArray` unchecked.
    if isconcretetype(T) && !(T <: Tuple)
        throw(ArgumentError("tuplecast(f, args) only works on functions returning a tuple."))
    end
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    return StructArrays.components(StructArray(bc))
end
96113

97114
# For certain cheap operations we can easily allow fused broadcast:
115+
const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted}
98116

99-
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = lazy_bc_plus(args...)
100-
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = lazy_bc_plus(arg) # ambiguity
117+
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args::NumericOrBroadcast...) = lazy_bc_plus(args...)
118+
# All-scalar `+`: nothing to fuse, so use the scalar method of `split_bc_rule`.
# `args` must be a vararg (`Number...`) so that this matches any number of scalars.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args::Number...) = split_bc_rule(+, args...)
101119
# Lazy rule for broadcast `+`: the primal stays an unmaterialised `broadcasted(+, xs...)`.
# The pullback unthunks the incoming cotangent once and passes it through `unbroadcast`
# for each argument, after two `NoTangent()` entries.
function lazy_bc_plus(xs...)
    function lazy_plus_pullback(Δraw)
        Δ = unthunk(Δraw)
        return (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...)
    end
    return broadcasted(+, xs...), lazy_plus_pullback
end
106124

107-
(::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ)
108-
109-
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y)
125+
(::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = split_bc_rule(-, x, y)
126+
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast)
110127
broadcasted(-, x, y), Δraw -> let Δ = unthunk(Δraw)
111128
(NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ))
112-
# Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array
113129
end
114130
end
115131

116132
using LinearAlgebra: dot
117-
const Numeric{T<:Number} = Union{T, AbstractArray{T}}
118133

119-
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric)
134+
(::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = split_bc_rule(*, x, y)
135+
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast)
120136
broadcasted(*, x, y), Δraw -> let Δ = unthunk(Δraw)
121-
dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δ) : unbroadcast(x, Δ .* conj.(y))
122-
dy = eltype(y)==Bool ? NoTangent() : y isa Number ? dot(x, Δ) : unbroadcast(y, Δ .* conj.(x))
123-
# When x is an array but a smaller one, instead of dot you may be able to use mapreduce()
124-
(NoTangent(), NoTangent(), dx, dy)
137+
(NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ))
125138
end
126139
end
140+
_back_star(x, y, Δ) = unbroadcast(x, Δ .* conj.(y))
141+
_back_star(x::Number, y, Δ) = dot(y, Δ)
142+
_back_star(x::Bool, y, Δ) = NoTangent()
127143

128-
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x, ::Val{2})
144+
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2})
129145
broadcasted(*, x, x), Δ -> begin
130146
dx = unbroadcast(x, 2 .* unthunk(Δ) .* conj.(x))
131147
(NoTangent(), NoTangent(), NoTangent(), dx, NoTangent())
@@ -135,41 +151,40 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::type
135151
x^2, Δ -> (NoTangent(), NoTangent(), NoTangent(), 2 * Δ * conj(x), NoTangent())
136152
end
137153

138-
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::Numeric, y::Number)
139-
z, back = ∂⃖{1}()(/, x, y)
140-
z, dz -> begin
141-
_, dx, dy = back(dz)
154+
(::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = split_bc_rule(/, x, y)
155+
# Rule for `x ./ y` with a scalar denominator. The quotient is computed eagerly
# (`broadcast`, not `broadcasted`) so the pullback can reuse it in a single `dot`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number)
    z = broadcast(/, x, y)
    function bc_divide_pullback(Δth)
        Δ = unthunk(Δth)
        dx = unbroadcast(x, Δ ./ conj.(y))
        dy = -dot(z, Δ) / (conj(y)) # the reason to be eager is to allow dot here
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return z, bc_divide_pullback
end
145163

146-
(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x) = x, identity_pullback
147-
(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x::Array) = x, identity_pullback # ambiguity
148-
identity_pullback(Δ) = (NoTangent(), NoTangent(), Δ)
164+
(::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x) = split_bc_rule(identity, x)
165+
# (::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x::Array) = split_bc_rule(identity, x) # ambiguity
149166

150-
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{Real}) = x, identity_pullback
151-
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array{Real}) = x, identity_pullback
167+
# `conj` is the identity on real arrays, so reuse the `identity` rule. The element type is
# written `<:Real` because type parameters are invariant: `AbstractArray{Real}` matches only
# arrays whose eltype is exactly `Real`, never e.g. `Vector{Float64}`.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{<:Real}) = split_bc_rule(identity, x)
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array{<:Real}) = split_bc_rule(identity, x) # resolves ambiguity with the `x::Array` method
168+
# (::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array{Real}) = split_bc_rule(identity, x) # ambiguity
152169
# Generic broadcast `conj`: keep the primal lazy and conjugate the incoming cotangent.
# The pullback returns one entry per call argument (`broadcasted`, `conj`, `x`), matching
# the three-entry pullback of the `identity` rule.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x) =
    broadcasted(conj, x), Δ -> (NoTangent(), NoTangent(), conj(unthunk(Δ)))
154171
# `Array`-specific broadcast `conj`: keep the primal lazy and conjugate the incoming
# cotangent. The pullback returns one entry per call argument (`broadcasted`, `conj`, `x`).
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array) =
    broadcasted(conj, x), Δ -> (NoTangent(), NoTangent(), conj(unthunk(Δ)))
156173

157-
# All broadcasts use `unbroadcast` to reduce to correct shape:
158-
174+
# Reverse mode broadcasting uses `unbroadcast` to reduce to correct shape:
159175
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
160176
N = ndims(dx)
161177
if length(x) == length(dx)
162178
ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors
163179
else
164-
dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) # awful hack to get type-stable `dims`
165-
ProjectTo(x)(sum(dx; dims))
180+
dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) # hack to get type-stable `dims`
181+
ProjectTo(x)(sum(dx; dims)) # ideally this sum might be thunked?
166182
end
167183
end
168184
unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx
169185

170186
unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx)))
171187
function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
172-
_print("unbroadcast tuple")
173188
val = if length(x) == length(dx)
174189
dx
175190
else

0 commit comments

Comments
 (0)