34
34
using ChainRulesCore: derivatives_given_output
35
35
36
36
# Broadcast over one element is just map
37
- function (∂⃖ₙ: :∂⃖ {N})(:: typeof (broadcasted), f, a:: Array ) where {N}
38
- ∂⃖ₙ (map, f, a)
39
- end
37
+ # function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
38
+ # ∂⃖ₙ(map, f, a)
39
+ # end
40
+
41
# Rule for `copy(bc)` on a lazy `Broadcasted`: the primal is `copy(bc)`, and the
# pullback hands the incoming cotangent back unchanged as the tangent for `bc`
# (with `NoTangent()` for `copy` itself).
function (::∂⃖{1})(::typeof(copy), bc::Broadcast.Broadcasted)
    copy_back(Δ) = (NoTangent(), Δ)
    return copy(bc), copy_back
end
40
42
41
- (:: ∂⃖{1 })(:: typeof (broadcasted), f, args... ) = split_bc_rule (f, args... )
42
- (:: ∂⃖{1 })(:: typeof (broadcasted), f, arg:: Array ) = split_bc_rule (f, arg) # ambiguity
43
# Generic reverse-mode rule for `broadcasted(f, args...)`: everything is delegated
# to `split_bc_rule`. `f` is only passed through here, so the `where {F}` is
# presumably there to make Julia specialise on the function type.
function (::∂⃖{1})(::typeof(broadcasted), f::F, args...) where {F}
    return split_bc_rule(f, args...)
end
44
+ # (::∂⃖{1})(::typeof(broadcasted), f::F , arg::Array) where {F} = split_bc_rule(f, arg) # ambiguity
43
45
function split_bc_rule (f:: F , args:: Vararg{Any,N} ) where {F,N}
44
46
T = Broadcast. combine_eltypes (f, args)
45
47
TΔ = Core. Compiler. _return_type (derivatives_given_output, Tuple{T, F, map (eltype, args)... })
@@ -48,17 +50,17 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
48
50
back_1 (_) = ntuple (Returns (ZeroTangent ()), length (args)+ 2 )
49
51
return f .(args... ), back_1
50
52
elseif T <: Number && isconcretetype (TΔ)
51
- # Fast path: just broadcast, and use x & y to find derivative .
53
+ # Fast path: just broadcast, and use arguments & result to find derivatives .
52
54
ys = f .(args... )
53
55
function back_2_one (dys) # For f.(x) we do not need StructArrays / unzip at all
54
56
delta = broadcast (unthunk (dys), ys, args... ) do dy, y, a
55
57
das = only (derivatives_given_output (y, f, a))
56
- dy * conj (only (das))
58
+ dy * conj (only (das)) # possibly this * should be made nan-safe.
57
59
end
58
60
(NoTangent (), NoTangent (), unbroadcast (only (args), delta))
59
61
end
60
62
function back_2_many (dys)
61
- deltas = splitcast (unthunk (dys), ys, args... ) do dy, y, as...
63
+ deltas = tuplecast (unthunk (dys), ys, args... ) do dy, y, as...
62
64
das = only (derivatives_given_output (y, f, as... ))
63
65
map (da -> dy * conj (da), das)
64
66
end
@@ -70,62 +72,76 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
70
72
# Slow path: collect all the pullbacks & apply them later.
71
73
# (Since broadcast makes no guarantee about order of calls, and un-fusing
72
74
# can change the number of calls, this does not bother to try to reverse.)
73
- ys , backs = splitcast (∂⃖ {1} (), f, args... )
75
+ ys3 , backs = tuplecast (∂⃖ {1} (), f, args... )
74
76
function back_3 (dys)
75
- deltas = splitmap (backs, unthunk (dys)) do back, dy
77
+ deltas = tuplecast (backs, unthunk (dys)) do back, dy # could be map, sizes match
76
78
map (unthunk, back (dy))
77
79
end
78
- dargs = map (unbroadcast, args, Base. tail (deltas)) # no real need to close over args here
80
+ dargs = map (unbroadcast, args, Base. tail (deltas))
79
81
(NoTangent (), sum (first (deltas)), dargs... )
80
82
end
81
83
back_3 (:: AbstractZero ) = (NoTangent (), map (Returns (ZeroTangent ()), args)... )
82
- return ys , back_3
84
+ return ys3 , back_3
83
85
end
84
86
end
85
87
88
+ # Don't run broadcasting on scalars
89
function split_bc_rule(f::F, args::Number...) where {F}
    # Every argument is a scalar, so differentiate the plain call `f(args...)`
    # instead of going through the broadcasting machinery.
    y, scalar_back = ∂⃖{1}()(f, args...)
    # Prepend one extra `NoTangent()` to whatever `scalar_back` returns
    # (presumably the slot for `broadcasted` itself).
    bc_back(dy) = (NoTangent(), scalar_back(dy)...)
    return y, bc_back
end
93
+
94
# `identity.(x)` is returned as `x` itself; the cotangent is passed through untouched.
function split_bc_rule(::typeof(identity), x)
    return x, Δ -> (NoTangent(), NoTangent(), Δ)
end

# Same rule for a scalar `x`; without it, `(identity, ::Number)` would be ambiguous
# between the method above and `split_bc_rule(f, args::Number...)`.
function split_bc_rule(::typeof(identity), x::Number)
    return x, Δ -> (NoTangent(), NoTangent(), Δ)
end
96
+
86
97
# Skip AD'ing through the axis computation
87
98
function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted)
    # `instantiate` is called directly rather than differentiated through; the
    # pullback returns `NoTangent()` for `instantiate` and Δ unchanged for `bc`.
    instantiate_back(Δ) = Core.tuple(NoTangent(), Δ)
    primal = Base.Broadcast.instantiate(bc)
    return primal, instantiate_back
end
91
102
92
- # This uses "multimap"-like constructs:
93
103
using StructArrays
94
- splitmap (f, args... ) = StructArrays. components (StructArray (Iterators. map (f, args... )))
95
- splitcast (f, args... ) = StructArrays. components (StructArray (Broadcast. instantiate (Broadcast. broadcasted (f, args... ))))
104
+
105
# Broadcast `f` over `args` and return `StructArrays.components` of the result,
# i.e. the broadcast of a tuple-returning function split into its components.
# Throws `ArgumentError` when the inferred element type is concrete but not a `Tuple`;
# a non-concrete inferred type is let through unchecked.
function tuplecast(f::F, args...) where {F}
    eltype_inferred = Broadcast.combine_eltypes(f, args)
    if isconcretetype(eltype_inferred) && !(eltype_inferred <: Tuple)
        throw(ArgumentError(" tuplecast(f, args) only works on functions returning a tuple."))
    end
    lazy = Broadcast.broadcasted(f, args...)
    return StructArrays.components(StructArray(Broadcast.instantiate(lazy)))
end
96
113
97
114
# For certain cheap operations we can easily allow fused broadcast:
115
+ const NumericOrBroadcast = Union{Number, AbstractArray{<: Number }, NTuple{<: Any ,Number}, Broadcast. Broadcasted}
98
116
99
- (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (+ ), args... ) = lazy_bc_plus (args... )
100
- (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (+ ), arg :: Array ) = lazy_bc_plus (arg) # ambiguity
117
# `+` over numbers, numeric arrays, numeric tuples or lazy broadcasts stays fused.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args::NumericOrBroadcast...) = lazy_bc_plus(args...)
# Purely scalar `+` goes through the scalar `split_bc_rule`, like `-`, `*` and `/`.
# This has to be a vararg (`args::Number...`): with a plain `args::Number` only the
# one-argument call matched, and `x .+ y` on two scalars fell into the lazy rule above.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args::Number...) = split_bc_rule(+, args...)
101
119
# Fused rule for `broadcasted(+, xs...)`: returns the lazy broadcast together with a
# pullback that reduces the cotangent to each argument via `unbroadcast`.
# The pullback returns two `NoTangent()`s followed by one tangent per element of `xs`.
# (The original `where {F}` declared a type variable that was never used.)
function lazy_bc_plus(xs...)
    function plus_back(Δraw)
        Δ = unthunk(Δraw)
        return (NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...)
    end
    return broadcasted(+, xs...), plus_back
end
106
124
107
- (:: ∂⃖{1 })(:: typeof (copy), bc:: Broadcast.Broadcasted ) = copy (bc), Δ -> (NoTangent (), Δ)
108
-
109
- function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (- ), x, y)
125
# Scalar minus scalar: handled by the scalar `split_bc_rule`.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = split_bc_rule(-, x, y)

# Fused subtraction: the primal stays a lazy `broadcasted(-, x, y)`; the pullback
# reduces Δ to `x` and the negated reduction of Δ to `y`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast)
    function minus_back(Δraw)
        Δ = unthunk(Δraw)
        return (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ))
    end
    return broadcasted(-, x, y), minus_back
end
115
131
116
132
using LinearAlgebra: dot
117
- const Numeric{T<: Number } = Union{T, AbstractArray{T}}
118
133
119
- function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (* ), x:: Numeric , y:: Numeric )
134
# Scalar times scalar: handled by the scalar `split_bc_rule`.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = split_bc_rule(*, x, y)

# Fused multiplication: the primal stays a lazy `broadcasted(*, x, y)`; the
# per-argument gradients are computed by `_back_star`, which dispatches on the
# argument being differentiated.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast)
    function times_back(Δraw)
        Δ = unthunk(Δraw)
        return (NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ))
    end
    return broadcasted(*, x, y), times_back
end
140
# `_back_star(x, y, Δ)`: gradient of `x .* y` with respect to its first argument `x`,
# given the other factor `y` and the cotangent `Δ`.
# General case: elementwise product with `conj(y)`, reduced to the shape of `x`.
_back_star(x, y, Δ) = unbroadcast(x, Δ .* conj.(y))
# Scalar `x`: a single reduction; `dot` conjugates its first argument.
_back_star(x::Number, y, Δ) = dot(y, Δ)
# Booleans are not differentiable. The code this replaced tested `eltype(x) == Bool`,
# which also covered Bool arrays, so both the scalar and the array case get a method.
_back_star(x::Bool, y, Δ) = NoTangent()
_back_star(x::AbstractArray{Bool}, y, Δ) = NoTangent()
127
143
128
- function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x, :: Val{2} )
144
+ function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x:: NumericOrBroadcast , :: Val{2} )
129
145
broadcasted (* , x, x), Δ -> begin
130
146
dx = unbroadcast (x, 2 .* unthunk (Δ) .* conj .(x))
131
147
(NoTangent (), NoTangent (), NoTangent (), dx, NoTangent ())
@@ -135,41 +151,40 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::type
135
151
x^ 2 , Δ -> (NoTangent (), NoTangent (), NoTangent (), 2 * Δ * conj (x), NoTangent ())
136
152
end
137
153
138
- function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (/ ), x:: Numeric , y:: Number )
139
- z, back = ∂⃖ {1} ()(/ , x, y)
140
- z, dz -> begin
141
- _, dx, dy = back (dz)
154
# Scalar divided by scalar: handled by the scalar `split_bc_rule`.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = split_bc_rule(/, x, y)

# Array (or lazy broadcast) divided by a scalar. The quotient is computed eagerly so
# that the pullback can reuse it in a single `dot` for the gradient of `y`.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number)
    quotient = broadcast(/, x, y)
    function divide_back(Δth)
        Δ = unthunk(Δth)
        dx = unbroadcast(x, Δ ./ conj.(y))
        dy = -dot(quotient, Δ) / (conj(y))
        return (NoTangent(), NoTangent(), dx, dy)
    end
    return quotient, divide_back
end
145
163
146
- (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (identity), x) = x, identity_pullback
147
- (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (identity), x:: Array ) = x, identity_pullback # ambiguity
148
- identity_pullback (Δ) = (NoTangent (), NoTangent (), Δ)
164
# `identity.(x)`: delegate to the dedicated `split_bc_rule(identity, x)` method.
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x)
    return split_bc_rule(identity, x)
end
165
+ # (::∂⃖{1})(::typeof(broadcasted), ::typeof(identity), x::Array) = split_bc_rule(identity, x) # ambiguity
149
166
150
- (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (conj), x:: AbstractArray{Real} ) = x, identity_pullback
151
- (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (conj), x:: Array{Real} ) = x, identity_pullback
167
# `conj` of a real array is treated as `identity`.
# The element type must be written `<:Real`: type parameters are invariant, so
# `AbstractArray{Real}` would not match e.g. `Vector{Float64}`.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::AbstractArray{<:Real}) = split_bc_rule(identity, x)
# Needed so that real `Array`s are not ambiguous between the method above and the
# `x::Array` method below.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array{<:Real}) = split_bc_rule(identity, x)

# General case: keep the broadcast lazy and conjugate the cotangent.
# The pullback returns three tangents (for `broadcasted`, `conj` and `x`), matching
# the other one-argument rules; the original returned only two.
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x) =
    broadcasted(conj, x), Δ -> (NoTangent(), NoTangent(), conj(unthunk(Δ)))
(::∂⃖{1})(::typeof(broadcasted), ::typeof(conj), x::Array) =
    broadcasted(conj, x), Δ -> (NoTangent(), NoTangent(), conj(unthunk(Δ)))
156
173
157
- # All broadcasts use `unbroadcast` to reduce to correct shape:
158
-
174
+ # Reverse mode broadcasting uses `unbroadcast` to reduce to correct shape:
159
175
# Reduce the cotangent `dx` of a broadcast result back to something matching the
# argument `x`, finishing with `ProjectTo(x)`.
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx)
    nd = ndims(dx)
    if length(x) != length(dx)
        # Sum over every dimension in which `x` has size 1 (or does not exist).
        # Other positions get the dummy dimension `nd + 1`, so `dims` is always
        # an `nd`-tuple: a hack to get type-stable `dims`.
        sumdims = ntuple(nd) do d
            get(size(x), d, 1) == 1 ? d : nd + 1
        end
        return ProjectTo(x)(sum(dx; dims=sumdims))  # ideally this sum might be thunked?
    end
    # Same number of elements: no reduction needed. ProjectTo handles trivial
    # reshapes, offsets, structured matrices, row vectors.
    return ProjectTo(x)(dx)
end
168
184
# A zero-like cotangent (`AbstractZero`) is returned as is; nothing to reduce.
function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero)
    return dx
end
169
185
170
186
# One-element tuple argument: the whole cotangent is summed into a single value,
# wrapped in a `Tangent{T}` and projected onto `x`.
function unbroadcast(x::T, dx) where {T<:Tuple{Any}}
    total = sum(dx)
    return ProjectTo(x)(Tangent{T}(total))
end
171
187
function unbroadcast (x:: T , dx) where {T<: Tuple{Vararg{Any,N}} } where {N}
172
- _print (" unbroadcast tuple" )
173
188
val = if length (x) == length (dx)
174
189
dx
175
190
else
0 commit comments