Skip to content

Commit 9946ff3

Browse files
committed
rm printing etc.
1 parent 2eb6868 commit 9946ff3

File tree

1 file changed

+0
-70
lines changed

1 file changed

+0
-70
lines changed

src/stage1/broadcast.jl

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,8 @@ end
3333

3434
using ChainRulesCore: derivatives_given_output
3535

36-
_print(s) = nothing
37-
# _print(s) = printstyled(s, "\n"; color=:magenta)
38-
3936
# Broadcast over one element is just map
4037
function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N}
41-
_print("path 0, order $N")
4238
∂⃖ₙ(map, f, a)
4339
end
4440

@@ -49,13 +45,11 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
4945
= Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...})
5046
if T === Bool
5147
# Trivial case: non-differentiable output, e.g. `x .> 0`
52-
_print("path 1")
5348
back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2)
5449
return f.(args...), back_1
5550
elseif T <: Number && isconcretetype(TΔ)
5651
# Fast path: just broadcast, and use x & y to find derivative.
5752
ys = f.(args...)
58-
_print("path 2")
5953
function back_2_one(dys) # For f.(x) we do not need StructArrays / unzip at all
6054
delta = broadcast(unthunk(dys), ys, args...) do dy, y, a
6155
das = only(derivatives_given_output(y, f, a))
@@ -76,7 +70,6 @@ function split_bc_rule(f::F, args::Vararg{Any,N}) where {F,N}
7670
# Slow path: collect all the pullbacks & apply them later.
7771
# (Since broadcast makes no guarantee about order of calls, and un-fusing
7872
# can change the number of calls, this does not bother to try to reverse.)
79-
_print("path 3")
8073
ys, backs = splitcast(∂⃖{1}(), f, args...)
8174
function back_3(dys)
8275
deltas = splitmap(backs, unthunk(dys)) do back, dy
@@ -97,74 +90,16 @@ function (::∂⃖{1})(::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.
9790
end
9891

9992
# This uses "multimap"-like constructs:
100-
10193
using StructArrays
10294
splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...)))
10395
splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))))
10496

105-
#=
106-
# This is how you could handle CuArrays, route them to unzip(map(...)) fallback path.
107-
# Maybe 2nd derivatives too, to avoid writing a gradient for splitcast, rule for unzip is easy.
108-
109-
function Diffractor.splitmap(f, args...)
110-
if any(a -> a isa CuArray, args)
111-
Diffractor._print("unzip splitmap")
112-
unzip(map(f, args...))
113-
else
114-
StructArrays.components(StructArray(Iterators.map(f, args...)))
115-
end
116-
end
117-
function Diffractor.splitcast(f, args...)
118-
if any(a -> a isa CuArray, args)
119-
Diffractor._print("unzip splitcast")
120-
unzip(broadcast(f, args...))
121-
else
122-
StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))))
123-
end
124-
end
125-
126-
gradient(x -> sum(log.(x) .+ x'), cu([1,2,3]))[1]
127-
gradient(x -> sum(sqrt.(atan.(x, x'))), cu([1,2,3]))[1]
128-
129-
=#
130-
131-
function unzip(xs::AbstractArray)
132-
x1 = first(xs)
133-
x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays of tuples"))
134-
N = length(x1)
135-
unzip(xs, Val(N)) # like Zygote's unzip
136-
end
137-
@generated function unzip(xs, ::Val{N}) where {N}
138-
each = [:(map($(Get(i)), xs)) for i in 1:N]
139-
Expr(:tuple, each...)
140-
end
141-
unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best case, no copy
142-
@generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple}
143-
each = if count(!Base.issingletontype, Ts.parameters) < 2
144-
# good case, no copy of data, some trivial arrays
145-
[Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters]
146-
else
147-
[:(map($(Get(i)), xs)) for i in 1:length(fieldnames(Ts))]
148-
end
149-
Expr(:tuple, each...)
150-
end
151-
152-
struct Get{i} end
153-
Get(i) = Get{Int(i)}()
154-
(::Get{i})(x) where {i} = x[i]
155-
156-
function ChainRulesCore.rrule(::typeof(unzip), xs::AbstractArray)
157-
rezip(dy) = (NoTangent(), tuple.(unthunk(dy)...))
158-
return unzip(xs), rezip
159-
end
160-
16197
# For certain cheap operations we can easily allow fused broadcast:
16298

16399
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), args...) = lazy_bc_plus(args...)
164100
(::∂⃖{1})(::typeof(broadcasted), ::typeof(+), arg::Array) = lazy_bc_plus(arg) # ambiguity
165101
function lazy_bc_plus(xs...) where {F}
166102
broadcasted(+, xs...), Δraw -> let Δ = unthunk(Δraw)
167-
_print("broadcast +")
168103
(NoTangent(), NoTangent(), map(x -> unbroadcast(x, Δ), xs)...)
169104
end
170105
end
@@ -173,7 +108,6 @@ end
173108

174109
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(-), x, y)
175110
broadcasted(-, x, y), Δraw -> let Δ = unthunk(Δraw)
176-
_print("broadcast -")
177111
(NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ))
178112
# Ideally you could fuse the - into unbroadcast, mapreduce() not sum, when y is a smaller array
179113
end
@@ -184,7 +118,6 @@ const Numeric{T<:Number} = Union{T, AbstractArray{T}}
184118

185119
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric)
186120
broadcasted(*, x, y), Δraw -> let Δ = unthunk(Δraw)
187-
_print("broadcast *")
188121
dx = eltype(x)==Bool ? NoTangent() : x isa Number ? dot(y, Δ) : unbroadcast(x, Δ .* conj.(y))
189122
dy = eltype(y)==Bool ? NoTangent() : y isa Number ? dot(x, Δ) : unbroadcast(y, Δ .* conj.(x))
190123
# When x is an array but a smaller one, instead of dot you may be able to use mapreduce()
@@ -193,19 +126,16 @@ function (::∂⃖{1})(::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeri
193126
end
194127

195128
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x, ::Val{2})
196-
_print("broadcast ^2")
197129
broadcasted(*, x, x), Δ -> begin
198130
dx = unbroadcast(x, 2 .* unthunk(Δ) .* conj.(x))
199131
(NoTangent(), NoTangent(), NoTangent(), dx, NoTangent())
200132
end
201133
end
202134
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2})
203-
_print("simple ^2")
204135
x^2, Δ -> (NoTangent(), NoTangent(), NoTangent(), 2 * Δ * conj(x), NoTangent())
205136
end
206137

207138
function (::∂⃖{1})(::typeof(broadcasted), ::typeof(/), x::Numeric, y::Number)
208-
_print("simple /")
209139
z, back = ∂⃖{1}()(/, x, y)
210140
z, dz -> begin
211141
_, dx, dy = back(dz)

0 commit comments

Comments
 (0)