33
33
34
34
using ChainRulesCore: derivatives_given_output
35
35
36
- _print (s) = nothing
37
- # _print(s) = printstyled(s, "\n"; color=:magenta)
38
-
39
36
# Broadcast over one element is just map
40
37
function (∂⃖ₙ: :∂⃖ {N})(:: typeof (broadcasted), f, a:: Array ) where {N}
41
- _print (" path 0, order $N " )
42
38
∂⃖ₙ (map, f, a)
43
39
end
44
40
@@ -49,13 +45,11 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
49
45
TΔ = Core. Compiler. _return_type (derivatives_given_output, Tuple{T, F, map (eltype, args)... })
50
46
if T === Bool
51
47
# Trivial case: non-differentiable output, e.g. `x .> 0`
52
- _print (" path 1" )
53
48
back_1 (_) = ntuple (Returns (ZeroTangent ()), length (args)+ 2 )
54
49
return f .(args... ), back_1
55
50
elseif T <: Number && isconcretetype (TΔ)
56
51
# Fast path: just broadcast, and use x & y to find derivative.
57
52
ys = f .(args... )
58
- _print (" path 2" )
59
53
function back_2_one (dys) # For f.(x) we do not need StructArrays / unzip at all
60
54
delta = broadcast (unthunk (dys), ys, args... ) do dy, y, a
61
55
das = only (derivatives_given_output (y, f, a))
@@ -76,7 +70,6 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
76
70
# Slow path: collect all the pullbacks & apply them later.
77
71
# (Since broadcast makes no guarantee about order of calls, and un-fusing
78
72
# can change the number of calls, this does not bother to try to reverse.)
79
- _print (" path 3" )
80
73
ys, backs = splitcast (∂⃖ {1} (), f, args... )
81
74
function back_3 (dys)
82
75
deltas = splitmap (backs, unthunk (dys)) do back, dy
@@ -97,74 +90,16 @@ function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.
97
90
end
98
91
99
92
# This uses "multimap"-like constructs:
100
-
101
93
using StructArrays
102
94
splitmap (f, args... ) = StructArrays. components (StructArray (Iterators. map (f, args... )))
103
95
splitcast (f, args... ) = StructArrays. components (StructArray (Broadcast. instantiate (Broadcast. broadcasted (f, args... ))))
104
96
105
- #=
106
- # This is how you could handle CuArrays, route them to unzip(map(...)) fallback path.
107
- # Maybe 2nd derivatives too, to avoid writing a gradient for splitcast, rule for unzip is easy.
108
-
109
- function Diffractor.splitmap(f, args...)
110
- if any(a -> a isa CuArray, args)
111
- Diffractor._print("unzip splitmap")
112
- unzip(map(f, args...))
113
- else
114
- StructArrays.components(StructArray(Iterators.map(f, args...)))
115
- end
116
- end
117
- function Diffractor.splitcast(f, args...)
118
- if any(a -> a isa CuArray, args)
119
- Diffractor._print("unzip splitcast")
120
- unzip(broadcast(f, args...))
121
- else
122
- StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))))
123
- end
124
- end
125
-
126
- gradient(x -> sum(log.(x) .+ x'), cu([1,2,3]))[1]
127
- gradient(x -> sum(sqrt.(atan.(x, x'))), cu([1,2,3]))[1]
128
-
129
- =#
130
-
131
- function unzip (xs:: AbstractArray )
132
- x1 = first (xs)
133
- x1 isa Tuple || throw (ArgumentError (" unzip only accepts arrays of tuples" ))
134
- N = length (x1)
135
- unzip (xs, Val (N)) # like Zygote's unzip
136
- end
137
- @generated function unzip (xs, :: Val{N} ) where {N}
138
- each = [:(map ($ (Get (i)), xs)) for i in 1 : N]
139
- Expr (:tuple , each... )
140
- end
141
- unzip (xs:: AbstractArray{Tuple{T}} ) where {T} = (reinterpret (T, xs),) # best case, no copy
142
- @generated function unzip (xs:: AbstractArray{Ts} ) where {Ts<: Tuple }
143
- each = if count (! Base. issingletontype, Ts. parameters) < 2
144
- # good case, no copy of data, some trivial arrays
145
- [Base. issingletontype (T) ? :(similar (xs, $ T)) : :(reinterpret ($ T, xs)) for T in Ts. parameters]
146
- else
147
- [:(map ($ (Get (i)), xs)) for i in 1 : length (fieldnames (Ts))]
148
- end
149
- Expr (:tuple , each... )
150
- end
151
-
152
- struct Get{i} end
153
- Get (i) = Get {Int(i)} ()
154
- (:: Get{i} )(x) where {i} = x[i]
155
-
156
- function ChainRulesCore. rrule (:: typeof (unzip), xs:: AbstractArray )
157
- rezip (dy) = (NoTangent (), tuple .(unthunk (dy)... ))
158
- return unzip (xs), rezip
159
- end
160
-
161
97
# For certain cheap operations we can easily allow fused broadcast:
162
98
163
99
(:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (+ ), args... ) = lazy_bc_plus (args... )
164
100
(:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (+ ), arg:: Array ) = lazy_bc_plus (arg) # ambiguity
165
101
function lazy_bc_plus (xs... ) where {F}
166
102
broadcasted (+ , xs... ), Δraw -> let Δ = unthunk (Δraw)
167
- _print (" broadcast +" )
168
103
(NoTangent (), NoTangent (), map (x -> unbroadcast (x, Δ), xs)... )
169
104
end
170
105
end
173
108
174
109
function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (- ), x, y)
175
110
broadcasted (- , x, y), Δraw -> let Δ = unthunk (Δraw)
176
- _print (" broadcast -" )
177
111
(NoTangent (), NoTangent (), unbroadcast (x, Δ), - unbroadcast (y, Δ))
178
112
# Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array
179
113
end
@@ -184,7 +118,6 @@ const Numeric{T<:Number} = Union{T, AbstractArray{T}}
184
118
185
119
function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (* ), x:: Numeric , y:: Numeric )
186
120
broadcasted (* , x, y), Δraw -> let Δ = unthunk (Δraw)
187
- _print (" broadcast *" )
188
121
dx = eltype (x)== Bool ? NoTangent () : x isa Number ? dot (y, Δ) : unbroadcast (x, Δ .* conj .(y))
189
122
dy = eltype (y)== Bool ? NoTangent () : y isa Number ? dot (x, Δ) : unbroadcast (y, Δ .* conj .(x))
190
123
# When x is an array but a smaller one, instead of dot you may be able to use mapreduce()
@@ -193,19 +126,16 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeri
193
126
end
194
127
195
128
function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x, :: Val{2} )
196
- _print (" broadcast ^2" )
197
129
broadcasted (* , x, x), Δ -> begin
198
130
dx = unbroadcast (x, 2 .* unthunk (Δ) .* conj .(x))
199
131
(NoTangent (), NoTangent (), NoTangent (), dx, NoTangent ())
200
132
end
201
133
end
202
134
function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (Base. literal_pow), :: typeof (^ ), x:: Number , :: Val{2} )
203
- _print (" simple ^2" )
204
135
x^ 2 , Δ -> (NoTangent (), NoTangent (), NoTangent (), 2 * Δ * conj (x), NoTangent ())
205
136
end
206
137
207
138
function (:: ∂⃖{1 })(:: typeof (broadcasted), :: typeof (/ ), x:: Numeric , y:: Number )
208
- _print (" simple /" )
209
139
z, back = ∂⃖ {1} ()(/ , x, y)
210
140
z, dz -> begin
211
141
_, dx, dy = back (dz)
0 commit comments