From aec3dd238fd39e994febb8f2d1c97abf529c536a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 7 Jul 2022 23:58:35 -0600 Subject: [PATCH 01/20] broadcasting, adapted from Diffractor PR68 --- Project.toml | 1 + src/ChainRules.jl | 5 + src/rulesets/Base/broadcast.jl | 248 +++++++++++++++++++++++++++++ src/rulesets/Base/fastmath_able.jl | 24 ++- src/tuplecast.jl | 107 +++++++++++++ test/rulesets/Base/broadcast.jl | 101 ++++++++++++ test/runtests.jl | 2 + test/tuplecast.jl | 47 ++++++ 8 files changed, 530 insertions(+), 5 deletions(-) create mode 100644 src/rulesets/Base/broadcast.jl create mode 100644 src/tuplecast.jl create mode 100644 test/rulesets/Base/broadcast.jl create mode 100644 test/tuplecast.jl diff --git a/Project.toml b/Project.toml index a04d6e52d..e79287e25 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] ChainRulesCore = "1.15.3" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index e323f7b6d..2ab8d4baa 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -11,6 +11,7 @@ using Random using RealDot: realdot using SparseArrays using Statistics +using StructArrays # Basically everything this package does is overloading these, so we make an exception # to the normal rule of only overload via `ChainRulesCore.rrule`. @@ -22,6 +23,9 @@ using ChainRulesCore: derivatives_given_output # numbers that we know commute under multiplication const CommutativeMulNumber = Union{Real,Complex} +# StructArrays +include("tuplecast.jl") + include("rulesets/Core/core.jl") include("rulesets/Base/utils.jl") @@ -34,6 +38,7 @@ include("rulesets/Base/arraymath.jl") include("rulesets/Base/indexing.jl") include("rulesets/Base/sort.jl") include("rulesets/Base/mapreduce.jl") +include("rulesets/Base/broadcast.jl") include("rulesets/Distributed/nondiff.jl") diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl new file mode 100644 index 000000000..023d0a6e2 --- /dev/null +++ b/src/rulesets/Base/broadcast.jl @@ -0,0 +1,248 @@ +using Base.Broadcast: Broadcast, broadcasted, Broadcasted +const RCR = RuleConfig{>:HasReverseMode} + +rrule(::typeof(copy), bc::Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ) + +# Skip AD'ing through the axis computation +function rrule(::typeof(Broadcast.instantiate), bc::Broadcasted) + uninstantiate(Δ) = Core.tuple(NoTangent(), Δ) + return Broadcast.instantiate(bc), uninstantiate +end + +_print(args...) = nothing # println(join(args, " ")) + +##### +##### Split broadcasting +##### + +function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Vararg{Any,N}) where {F,N} + # = split_bc_rule(cfg, f, args...) + # function split_bc_rule(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} + T = Broadcast.combine_eltypes(f, args) + TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...}) + if T === Bool + # 1: Trivial case: non-differentiable output, e.g. `x .> 0` + _print("split_bc_rule 1 ", f) + back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2) + return f.(args...), back_1 + elseif T <: Number && isconcretetype(TΔ) + # 2: Fast path: just broadcast, and use arguments & result to find derivatives. + _print("split_bc_rule 2", f, N) + ys = f.(args...) + function back_2_one(dys) # For f.(x) we do not need StructArrays / unzip at all + delta = broadcast(unthunk(dys), ys, args...) do dy, y, a + das = only(derivatives_given_output(y, f, a)) + dy * conj(only(das)) # possibly this * should be made nan-safe. + end + (NoTangent(), NoTangent(), ProjectTo(only(args))(delta)) + end + back_2_one(z::AbstractZero) = (NoTangent(), NoTangent(), z) + function back_2_many(dys) + deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as... + das = only(derivatives_given_output(y, f, as...)) + map(da -> dy * conj(da), das) + end + dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of tuplecast? + (NoTangent(), NoTangent(), dargs...) + end + back_2_many(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...) + return ys, N==1 ? back_2_one : back_2_many + else + _print("split_bc_rule 3", f, N) + # 3: Slow path: collect all the pullbacks & apply them later. + # (Since broadcast makes no guarantee about order of calls, and un-fusing + # can change the number of calls, don't bother to try to reverse the iteration.) + ys3, backs = tuplecast(args...) do a... + rrule_via_ad(cfg, f, a...) + end + function back_3(dys) + deltas = tuplecast(backs, unthunk(dys)) do back, dy # could be map, sizes match + map(unthunk, back(dy)) + end + dargs = map(unbroadcast, args, Base.tail(deltas)) + (NoTangent(), ProjectTo(f)(sum(first(deltas))), dargs...) + end + back_3(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...) + return ys3, back_3 + end +end + +# Don't run broadcasting on scalars +function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Number...) where {F} +# function split_bc_rule(cfg::RCR, f::F, args::Number...) where {F} + _print("split_bc_rule scalar", f) + z, back = rrule_via_ad(cfg, f, args...) + return z, dz -> (NoTangent(), back(dz)...) +end + +# using StructArrays +# +# function tuplecast(f::F, args...) where {F} +# T = Broadcast.combine_eltypes(f, args) +# if isconcretetype(T) +# T <: Tuple || throw(ArgumentError("tuplecast(f, args) only works on functions returning a tuple.")) +# end +# bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...)) +# StructArrays.components(StructArray(bc)) +# end + +##### +##### Fused broadcasting +##### + +# For certain cheap operations we can easily allow fused broadcast. +# These all have `RuleConfig{>:HasReverseMode}` as otherwise the split rule matches first & they are not used. +# They accept `Broadcasted` because they produce it; it has no eltype but is assumed to contain `Number`s. +const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted} + +function rrule(::RCR, ::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...) + _print("plus", length(xs)) + function bc_plus_back(dy_raw) + dy = unthunk(dy_raw) + (NoTangent(), NoTangent(), map(x -> unbroadcast(x, dy), xs)...) + end + return broadcasted(+, xs...), bc_plus_back +end + +function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast) + _print("minus 2") + bc_minus_back(Δraw) = let Δ = unthunk(Δraw) + (NoTangent(), NoTangent(), @thunk(unbroadcast(x, Δ)), @thunk(-unbroadcast(y, Δ))) + end + return broadcasted(-, x, y), bc_minus_back +end + +function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast) + _print("minus 1") + bc_minus_back(dy) = (NoTangent(), NoTangent(), @thunk -unthunk(dy)) + return broadcasted(-, x), bc_minus_back +end + +using LinearAlgebra: dot + +function rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast) + _print("times") + function bc_times_back(Δraw) + Δ = unthunk(Δraw) + (NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ)) + end + return broadcasted(*, x, y), bc_times_back +end +_back_star(x, y, Δ) = @thunk unbroadcast(x, Δ .* conj.(y)) +_back_star(x::Number, y, Δ) = @thunk dot(y, Δ) +_back_star(x::Bool, y, Δ) = NoTangent() +_back_star(x::Complex{Bool}, y, Δ) = NoTangent() # e.g. for fun.(im.*x) + +# TODO check what happens for A * B * C + +function rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2}) + _print("square") + function bc_square_back(dy_raw) + dx = @thunk ProjectTo(x)(2 .* unthunk(dy_raw) .* conj.(x)) + (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent()) + end + return broadcasted(Base.literal_pow, ^, x, Val(2)), bc_square_back +end + +function rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number) + _print("divide") + z = broadcast(/, x, y) + function bc_divide_back(Δraw) + Δ = unthunk(Δraw) + dx = @thunk unbroadcast(x, Δ ./ conj.(y)) + dy = @thunk -dot(z, Δ) / (conj(y)) # the reason to be eager is to allow dot here + (NoTangent(), NoTangent(), dx, dy) + end + return z, bc_divide_back +end + +# For the same functions, send accidental broadcasting over numbers directly to `rrule`. +# Could perhaps move all to @scalar_rule? + +function _prepend_zero((y, back)) + extra_back(dy) = (NoTangent(), back(dy)...) + return y, extra_back +end + +rrule(::RCR, ::typeof(broadcasted), ::typeof(+), args::Number...) = rrule(+, args...) |> _prepend_zero +rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = rrule(-, x, y) |> _prepend_zero +rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number) = rrule(-, x) |> _prepend_zero +rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = rrule(*, x, y) |> _prepend_zero +rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2}) = + rrule(Base.literal_pow, ^, x, Val(2)) |> _prepend_zero +rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = rrule(/, x, y) |> _prepend_zero + +# A few more cheap functions + +rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast) = rrule(identity, x) |> _prepend_zero +rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::Number) = rrule(identity, x) |> _prepend_zero # ambiguity + +function rrule(::RCR, ::typeof(broadcasted), ::typeof(conj), x::NumericOrBroadcast) + bc_conj_back(dx) = (NoTangent(), NoTangent(), conj(unthunk(dx))) + return broadcasted(conj, x), bc_conj_back +end +rrule(::RCR, ::typeof(broadcasted), ::typeof(conj), x::Number) = rrule(conj, x) |> _prepend_zero +rrule(::RCR, ::typeof(broadcasted), ::typeof(conj), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero + +# TODO real, imag + +##### +##### Shape fixing +##### + +# Reverse mode broadcasting uses `unbroadcast` to reduce to correct shape: + +function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx) + N = ndims(dx) + if length(x) == length(dx) + ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors + else + dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) # hack to get type-stable `dims` + ProjectTo(x)(sum(dx; dims)) # ideally this sum might be thunked? + end +end +unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx + +unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx))) +function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N} + val = if length(x) == length(dx) + dx + else + sum(dx; dims=2:ndims(dx)) + end + ProjectTo(x)(NTuple{length(x)}(val)) # Tangent +end + +unbroadcast(f::Function, df) = sum(df) +unbroadcast(x::Number, dx) = ProjectTo(x)(sum(dx)) +unbroadcast(x::Base.RefValue, dx) = ProjectTo(x)(Ref(sum(dx))) + +unbroadcast(::Bool, dx) = NoTangent() +unbroadcast(::AbstractArray{Bool}, dx) = NoTangent() +unbroadcast(::AbstractArray{Bool}, dx::AbstractZero) = dx # ambiguity +unbroadcast(::Val, dx) = NoTangent() + +function unbroadcast(x, dx) + p = ProjectTo(x) + if dx isa AbstractZero || p isa ProjectTo{<:AbstractZero} + return NoTangent() + end + b = Broadcast.broadcastable(x) + if b isa Ref # then x is scalar under broadcast + return p(sum(dx)) + else + error("don't know how to handle broadcast gradient for x::$(typeof(x))") + end +end + +##### +##### For testing +##### + +function rrule(cfg::RCR, ::typeof(copy∘broadcasted), f, args...) + y, back = rrule(cfg, broadcasted, f, args...) + return _maybe_copy(y), back +end + +_maybe_copy(y) = copy(y) +_maybe_copy(y::Tuple) = y diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 5a9a9b08c..d3af247bd 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -167,6 +167,16 @@ let @scalar_rule x + y (true, true) @scalar_rule x - y (true, -1) @scalar_rule x / y (one(x) / y, -(Ω / y)) + + ## many-arg + + function frule((_, Δx, Δy...), ::typeof(+), x::Number, ys::Number...) + +(x, ys...), +(Δx, Δy...) + end + + function rrule(::typeof(+), x::Number, ys::Number...) + plus_back(dz) = (NoTangent(), dz, map(Returns(dz), ys)...) + +(x, ys...), plus_back + end ## power # literal_pow is in base.jl @@ -276,6 +286,10 @@ let return Ω4, times_pullback4 end rrule(::typeof(*), x::Number) = rrule(identity, x) + + # This is used to choose a faster path in some broadcasting operations: + ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number) = tuple((y', x')) + ChainRulesCore.derivatives_given_output(Ω, ::typeof(*), x::Number, y::Number, z::Number) = tuple((y'z', x'z', x'y')) end # fastable_ast # Rewrite everything to use fast_math functions, including the type-constraints @@ -288,12 +302,12 @@ let non_transformed_definitions = intersect(fastable_ast.args, fast_ast.args) filter!(expr->!(expr isa LineNumberNode), non_transformed_definitions) if !isempty(non_transformed_definitions) - error( - "Non-FastMath compatible rules defined in fastmath_able.jl. \n Definitions:\n" * - join(non_transformed_definitions, "\n") - ) + # error( + # "Non-FastMath compatible rules defined in fastmath_able.jl. \n Definitions:\n" * + # join(non_transformed_definitions, "\n") + # ) # This error() may not play well with Revise. But a wanring @error does: - # @error "Non-FastMath compatible rules defined in fastmath_able.jl." non_transformed_definitions + @error "Non-FastMath compatible rules defined in fastmath_able.jl." non_transformed_definitions end eval(fast_ast) diff --git a/src/tuplecast.jl b/src/tuplecast.jl new file mode 100644 index 000000000..2130285a2 --- /dev/null +++ b/src/tuplecast.jl @@ -0,0 +1,107 @@ + +""" + tuplecast(f, args...) + +For a function `f` which returns a tuple, this is `== unzip(broadcast(f, args...))`, +but performed using `StructArrays` for efficiency. +""" +function tuplecast(f::F, args...) where {F} + T = Broadcast.combine_eltypes(f, args) + if isconcretetype(T) + T <: Tuple || throw(ArgumentError("""tuplecast(f, args) only works on functions returning a tuple, + but f = $(sprint(show, f)) returns type T = $T""")) + end + # if any(a -> a isa CuArray, args) + # return unzip(broadcast(f, args...)) + # end + bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...)) + StructArrays.components(StructArray(bc)) +end + +function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplecast), f::F, args...) where {F} + y, back = rrule_via_ad(cfg, broadcasted, f, args...) + z = unzip(y) + function untuplecast(dz) + dy = StructArray(map(unthunk, dz)) + db, df, dargs... = back(dy) + (db, sum(df), map(unbroadcast, args, dargs)...) + end + return z, untuplecast +end + +# function rrule(cfg::RCR, ::typeof(collect∘tuplecast), f, args...) +# y, back = rrule(cfg, tuplecast, f, args...) +# return collect(y), back +# end + +""" + tuplemap(f, args...) + +For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`, +but performed using `StructArrays` for efficiency. +""" +function tuplemap(f::F, args...) where {F} + T = Broadcast.combine_eltypes(f, args) + if isconcretetype(T) + T <: Tuple || throw(ArgumentError("""tuplemap(f, args) only works on functions returning a tuple, + but f = $(sprint(show, f)) returns type T = $T""")) + end + # if any(a -> a isa CuArray, args) + # return unzip(map(f, args...)) + # end + StructArrays.components(StructArray(Iterators.map(f, args...))) +end + +# function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplemap), f::F, args...) where {F} +# y, back = rrule(cfg, map, f, xs...) # won't work, but also, you want the lazier fwd +# z = unzip(y) +# function untuplemap(dz) +# dy = StructArray(map(unthunk, dz)) +# back(dy) +# end +# return unzip(xs), untuplemap +# end + +""" + unzip(A) + +Converts an array of tuples into a tuple of arrays. +Eager. Will work by `reinterpret` when possible. +""" +function unzip(xs::AbstractArray) + x1 = first(xs) + x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays of tuples")) + N = length(x1) + unzip(xs, Val(N)) # like Zygote's unzip, here this is the fallback case. +end + +@generated function unzip(xs, ::Val{N}) where {N} + each = [:(map($(Get(i)), xs)) for i in 1:N] + Expr(:tuple, each...) +end + +unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best case, no copy + +@generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple} + each = if count(!Base.issingletontype, Ts.parameters) < 2 + # good case, no copy of data, some trivial arrays + [Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters] + else + [:(map($(Get(i)), xs)) for i in 1:length(fieldnames(Ts))] + end + Expr(:tuple, each...) +end + +struct Get{i} end +Get(i) = Get{Int(i)}() +(::Get{i})(x) where {i} = x[i] + +function ChainRulesCore.rrule(::typeof(unzip), xs::AbstractArray{T}) where {T <: Tuple} + function rezip(dy) + dxs = map(unthunk.(dy)...) do ys... + Tangent{T}(ys...) + end + (NoTangent(), dxs) + end + return unzip(xs), rezip +end diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl new file mode 100644 index 000000000..a51319c8e --- /dev/null +++ b/test/rulesets/Base/broadcast.jl @@ -0,0 +1,101 @@ +using Base.Broadcast: broadcasted + +@testset "Broadcasting" begin + @testset "generic 1: trivial path" begin + # test_rrule(copy∘broadcasted, >, rand(3), rand(3)) # MethodError: no method matching eps(::UInt64) inside FiniteDifferences + y1, bk1 = rrule(CFG, copy∘broadcasted, >, rand(3), rand(3)) + @test y1 isa AbstractArray{Bool} + @test all(d -> d isa AbstractZero, bk1(99)) + + y2, bk2 = rrule(CFG, copy∘broadcasted, isinteger, Tuple(rand(3))) + @test y2 isa Tuple{Bool,Bool,Bool} + @test all(d -> d isa AbstractZero, bk2(99)) + end + + @testset "generic 2: fast path" begin + test_rrule(copy∘broadcasted, log, rand(3)) + test_rrule(copy∘broadcasted, log, Tuple(rand(3))) + + # Two args uses StructArrays + test_rrule(copy∘broadcasted, atan, rand(3), rand(3)) + test_rrule(copy∘broadcasted, atan, rand(3), rand(4)') + test_rrule(copy∘broadcasted, atan, rand(3), rand()) + test_rrule(copy∘broadcasted, atan, rand(3), Tuple(rand(1))) + test_rrule(copy∘broadcasted, atan, Tuple(rand(3)), Tuple(rand(3))) + + # Protected by Ref/Tuple: + test_rrule(copy∘broadcasted, *, rand(3), Ref(rand())) + test_rrule(copy∘broadcasted, *, rand(3), Ref(rand(2))) + end + + @testset "generic 3: slow path" begin + test_rrule(copy∘broadcasted, sin∘cos, rand(3), check_inferred=false) + test_rrule(copy∘broadcasted, sin∘atan, rand(3), rand(3)', check_inferred=false) + test_rrule(copy∘broadcasted, sin∘atan, rand(), rand(3), check_inferred=false) + test_rrule(copy∘broadcasted, ^, rand(3), 3.0, check_inferred=false) + + # From test_helpers.jl + test_rrule(copy∘broadcasted, Multiplier(rand()), rand(3), check_inferred=false) + test_rrule(copy∘broadcasted, Multiplier(rand()), rand(3), rand(4)', check_inferred=false) + @test_skip test_rrule(copy∘broadcasted, Multiplier(rand()), rand(3), 5.0im, check_inferred=false) # ProjectTo(f) fails to correct this + test_rrule(copy∘broadcasted, make_two_vec, rand(3), check_inferred=false) + + # Non-diff components + test_rrule(copy∘broadcasted, first∘tuple, rand(3), :sym, rand(4)', check_inferred=false) + test_rrule(copy∘broadcasted, last∘tuple, rand(3), nothing, rand(4)', check_inferred=false) + test_rrule(copy∘broadcasted, |>, rand(3), sin, check_inferred=false) + _call(f, x...) = f(x...) + test_rrule(copy∘broadcasted, _call, atan, rand(3), rand(4)', check_inferred=false) + + # Protected by Ref/Tuple: + test_rrule(copy∘broadcasted, conj∘*, rand(3), Ref(rand() + im), check_inferred=false) + test_rrule(copy∘broadcasted, conj∘*, rand(3), Ref(rand(2) .+ im), check_inferred=false) + test_rrule(copy∘broadcasted, /, (rand(2),), rand(3), check_inferred=false) + end + + @testset "lazy rules" begin + test_rrule(copy∘broadcasted, +, rand(3), rand(3)) + test_rrule(copy∘broadcasted, +, rand(3), rand(4)') + test_rrule(copy∘broadcasted, +, rand(3), rand(1), rand()) + test_rrule(copy∘broadcasted, +, rand(3), 1.0*im) + test_rrule(copy∘broadcasted, +, rand(3), true) + test_rrule(copy∘broadcasted, +, rand(3), Tuple(rand(3))) + + test_rrule(copy∘broadcasted, -, rand(3), rand(3)) + test_rrule(copy∘broadcasted, -, rand(3), rand(4)') + test_rrule(copy∘broadcasted, -, rand(3)) + # test_rrule(copy∘broadcasted, -, Tuple(rand(3))) # MethodError: (::ChainRulesTestUtils.var"#test_approx##kw")(::NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}, ::typeof(test_approx), ::Thunk{ChainRules.var"#1614#1616"{Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}}}, ::Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}) is ambiguous. + + test_rrule(copy∘broadcasted, *, rand(3), rand(3)) + test_rrule(copy∘broadcasted, *, rand(3), rand()) + test_rrule(copy∘broadcasted, *, rand(), rand(3)) + + test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3), Val(2)) + + test_rrule(copy∘broadcasted, /, rand(3), rand()) + + test_rrule(copy∘broadcasted, identity, rand(3)) + + test_rrule(copy∘broadcasted, conj, rand(3)) + test_rrule(copy∘broadcasted, conj, rand(3) .+ im) + end + + @testset "scalar rules" begin + test_rrule(copy∘broadcasted, sin, rand()) + test_rrule(copy∘broadcasted, atan, rand(), rand()) + # test_rrule(copy∘broadcasted, >, rand(), rand()) # DimensionMismatch from FiniteDifferences + + # Functions with lazy rules + test_rrule(copy∘broadcasted, +, rand(), rand(), rand()) + test_rrule(copy∘broadcasted, +, rand()) + test_rrule(copy∘broadcasted, -, rand(), rand()) + test_rrule(copy∘broadcasted, -, rand()) + test_rrule(copy∘broadcasted, *, rand(), rand()) + test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(), Val(2)) + test_rrule(copy∘broadcasted, /, rand(), rand()) + + test_rrule(copy∘broadcasted, identity, rand()) + test_rrule(copy∘broadcasted, conj, rand()) + test_rrule(copy∘broadcasted, conj, rand() + im) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 24c1d85b9..ba8e84d94 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,6 +44,7 @@ end @testset "ChainRules" begin # One overall @testset ensures it keeps going after failures include("test_helpers.jl") + include("tuplecast.jl") println() test_method_tables() # Check the global method tables are consistent @@ -57,6 +58,7 @@ end include_test("rulesets/Base/indexing.jl") include_test("rulesets/Base/mapreduce.jl") include_test("rulesets/Base/sort.jl") + include_test("rulesets/Base/broadcast.jl") println() diff --git a/test/tuplecast.jl b/test/tuplecast.jl new file mode 100644 index 000000000..6863d9723 --- /dev/null +++ b/test/tuplecast.jl @@ -0,0 +1,47 @@ + +using ChainRules: tuplecast, unzip # tuplemap, + +@testset "tuplecast" begin + @testset "basics: $(sprint(show, fun))" for fun in [tuplecast, unzip∘broadcast] # [tuplemap, tuplecast, unzip∘map, unzip∘broadcast] + @test_throws Exception fun(sqrt, 1:3) + + @test fun(tuple, 1:3, 4:6) == ([1, 2, 3], [4, 5, 6]) + @test fun(tuple, [1, 10, 100]) == ([1, 10, 100],) + @test fun(tuple, 1:3, fill(nothing, 3)) == (1:3, fill(nothing, 3)) + @test fun(tuple, [1, 10, 100], fill(nothing, 3)) == ([1, 10, 100], fill(nothing, 3)) + @test fun(tuple, fill(nothing, 3), fill(nothing, 3)) == (fill(nothing, 3), fill(nothing, 3)) + + if contains(string(fun), "map") + @test fun(tuple, 1:3, 4:999) == ([1, 2, 3], [4, 5, 6]) + else + @test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) + end + end + + # tuplemap(tuple, (1,2,3), (4,5,6)) == ([1, 2, 3], [4, 5, 6]) + + @testset "unzip" begin + @test unzip([(1,2), (3,4), (5,6)]) == ([1, 3, 5], [2, 4, 6]) + @test unzip([(nothing,2), (3,4), (5,6)]) == ([nothing, 3, 5], [2, 4, 6]) + @test unzip([(missing,2), (missing,4), (missing,6)])[2] isa Base.ReinterpretArray + + y, bk = rrule(unzip, [(1,2), (3,4), (5,6)]) + @test y == ([1, 3, 5], [2, 4, 6]) + @test bk(Tangent{Tuple}([1,1,1], [10,100,1000]))[2] isa Vector{<:Tangent{<:Tuple}} + end + + @testset "rrules" begin + # These exist to allow for second derivatives + + # test_rrule(collect∘tuplecast, tuple, [1,2,3.], [4,5,6.], check_inferred=false) + y1, bk1 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4,5,6.0]) + @test y1 == ([1, 2, 3], [4, 5, 6]) + @test bk1(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100] + + y2, bk2 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4 5.0], 6.0) + @test y2 == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5], [6 6; 6 6; 6 6]) + @test bk2(y2)[5] ≈ 36 + + test_rrule(unzip, [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)], check_inferred=false) + end +end \ No newline at end of file From 74981cdddaa46ce3f70285a4cb82e439a957b4c7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 11 Jul 2022 23:22:47 -0400 Subject: [PATCH 02/20] many small upgrades --- src/rulesets/Base/base.jl | 2 + src/rulesets/Base/broadcast.jl | 137 ++++++++++++++++++++++---------- src/tuplecast.jl | 30 +++++++ test/rulesets/Base/base.jl | 1 + test/rulesets/Base/broadcast.jl | 122 +++++++++++++++++++--------- test/tuplecast.jl | 2 +- 6 files changed, 212 insertions(+), 82 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 6bea6e06c..c10ba6e71 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -72,6 +72,8 @@ function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex} return (T(x, y), Complex_pullback) end +@scalar_rule complex(x) true + # `hypot` @scalar_rule hypot(x::Real) sign(x) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 023d0a6e2..24ca1f373 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -1,15 +1,18 @@ using Base.Broadcast: Broadcast, broadcasted, Broadcasted const RCR = RuleConfig{>:HasReverseMode} -rrule(::typeof(copy), bc::Broadcasted) = copy(bc), Δ -> (NoTangent(), Δ) +function rrule(::typeof(copy), bc::Broadcasted) + uncopy(Δ) = (NoTangent(), Δ) + return copy(bc), uncopy +end # Skip AD'ing through the axis computation function rrule(::typeof(Broadcast.instantiate), bc::Broadcasted) - uninstantiate(Δ) = Core.tuple(NoTangent(), Δ) + uninstantiate(Δ) = (NoTangent(), Δ) return Broadcast.instantiate(bc), uninstantiate end -_print(args...) = nothing # println(join(args, " ")) +_print(args...) = nothing # println(join(args, " ")) # ##### ##### Split broadcasting @@ -69,45 +72,37 @@ end # Don't run broadcasting on scalars function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Number...) where {F} -# function split_bc_rule(cfg::RCR, f::F, args::Number...) where {F} - _print("split_bc_rule scalar", f) + _print("split_bc_scalar", f) z, back = rrule_via_ad(cfg, f, args...) return z, dz -> (NoTangent(), back(dz)...) end -# using StructArrays -# -# function tuplecast(f::F, args...) where {F} -# T = Broadcast.combine_eltypes(f, args) -# if isconcretetype(T) -# T <: Tuple || throw(ArgumentError("tuplecast(f, args) only works on functions returning a tuple.")) -# end -# bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...)) -# StructArrays.components(StructArray(bc)) -# end - ##### ##### Fused broadcasting ##### -# For certain cheap operations we can easily allow fused broadcast. -# These all have `RuleConfig{>:HasReverseMode}` as otherwise the split rule matches first & they are not used. -# They accept `Broadcasted` because they produce it; it has no eltype but is assumed to contain `Number`s. +# For certain cheap operations we can easily allow fused broadcast; the forward pass may be run twice. +# These all have `RuleConfig{>:HasReverseMode}` only for dispatch, to beat the split rule above. +# Accept `x::Broadcasted` because they produce it; can't dispatch on eltype but `x` is assumed to contain `Number`s. + const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted} +##### Arithmetic: +, -, *, ^2, / + function rrule(::RCR, ::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...) _print("plus", length(xs)) function bc_plus_back(dy_raw) dy = unthunk(dy_raw) - (NoTangent(), NoTangent(), map(x -> unbroadcast(x, dy), xs)...) + return (NoTangent(), NoTangent(), map(x -> unbroadcast(x, dy), xs)...) # no copies, this may return dx2 === dx3 end return broadcasted(+, xs...), bc_plus_back end function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast) _print("minus 2") - bc_minus_back(Δraw) = let Δ = unthunk(Δraw) - (NoTangent(), NoTangent(), @thunk(unbroadcast(x, Δ)), @thunk(-unbroadcast(y, Δ))) + function bc_minus_back(dz_raw) + dz = unthunk(dz_raw) + return (NoTangent(), NoTangent(), @thunk(unbroadcast(x, dz)), @thunk(-unbroadcast(y, dz))) end return broadcasted(-, x, y), bc_minus_back end @@ -118,46 +113,59 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast) return broadcasted(-, x), bc_minus_back end -using LinearAlgebra: dot - function rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast) _print("times") function bc_times_back(Δraw) Δ = unthunk(Δraw) - (NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ)) + return (NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ)) end return broadcasted(*, x, y), bc_times_back end -_back_star(x, y, Δ) = @thunk unbroadcast(x, Δ .* conj.(y)) -_back_star(x::Number, y, Δ) = @thunk dot(y, Δ) +_back_star(x, y, Δ) = @thunk unbroadcast(x, Δ .* conj.(y)) # this case probably isn't better than generic +_back_star(x::Number, y, Δ) = @thunk LinearAlgebra.dot(y, Δ) # ... but this is why the rule exists _back_star(x::Bool, y, Δ) = NoTangent() _back_star(x::Complex{Bool}, y, Δ) = NoTangent() # e.g. for fun.(im.*x) -# TODO check what happens for A * B * C +#= +# This works, but not sure it improves any benchmarks. +function rrule(cfg::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast, zs::NumericOrBroadcast...) + _print("times", 2 + length(zs)) + xy, back1 = rrule(cfg, broadcasted, *, x, y) + xyz, back2 = rrule(cfg, broadcasted, *, xy, zs...) + function bc_times3_back(dxyz) + _, _, dxy, dzs... = back2(dxyz) + _, _, dx, dy = back1(dxy) + return (NoTangent(), NoTangent(), dx, dy, dzs...) + end + xyz, bc_times3_back +end +=# function rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2}) _print("square") function bc_square_back(dy_raw) dx = @thunk ProjectTo(x)(2 .* unthunk(dy_raw) .* conj.(x)) - (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent()) + return (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent()) end return broadcasted(Base.literal_pow, ^, x, Val(2)), bc_square_back end function rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number) _print("divide") - z = broadcast(/, x, y) - function bc_divide_back(Δraw) - Δ = unthunk(Δraw) - dx = @thunk unbroadcast(x, Δ ./ conj.(y)) - dy = @thunk -dot(z, Δ) / (conj(y)) # the reason to be eager is to allow dot here + # z = broadcast(/, x, y) + z = broadcasted(/, x, y) + function bc_divide_back(dz_raw) + dz = unthunk(dz_raw) + dx = @thunk unbroadcast(x, dz ./ conj.(y)) + # dy = @thunk -LinearAlgebra.dot(z, dz) / conj(y) # the reason to be eager is to allow dot here + dy = @thunk -sum(Broadcast.instantiate(broadcasted(*, broadcasted(conj, z), dz))) / conj(y) # complete sum is fast? (NoTangent(), NoTangent(), dx, dy) end return z, bc_divide_back end # For the same functions, send accidental broadcasting over numbers directly to `rrule`. -# Could perhaps move all to @scalar_rule? +# (Could perhaps move all to @scalar_rule?) function _prepend_zero((y, back)) extra_back(dy) = (NoTangent(), back(dy)...) @@ -172,25 +180,66 @@ rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x:: rrule(Base.literal_pow, ^, x, Val(2)) |> _prepend_zero rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = rrule(/, x, y) |> _prepend_zero -# A few more cheap functions +##### Identity, number types rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast) = rrule(identity, x) |> _prepend_zero rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::Number) = rrule(identity, x) |> _prepend_zero # ambiguity -function rrule(::RCR, ::typeof(broadcasted), ::typeof(conj), x::NumericOrBroadcast) - bc_conj_back(dx) = (NoTangent(), NoTangent(), conj(unthunk(dx))) - return broadcasted(conj, x), bc_conj_back +function rrule(::RCR, ::typeof(broadcasted), ::Type{T}, x::NumericOrBroadcast) where {T<:Number} + _print("bc type", T) + bc_type_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz)))) + return broadcasted(T, x), bc_type_back +end +rrule(::RCR, ::typeof(broadcasted), ::Type{T}, x::Number) where {T<:Number} = rrule(T, x) |> _prepend_zero + +function rrule(::RCR, ::typeof(broadcasted), ::typeof(float), x::NumericOrBroadcast) + _print("bc float") + bc_float_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz)))) + return broadcasted(float, x), bc_float_back end -rrule(::RCR, ::typeof(broadcasted), ::typeof(conj), x::Number) = rrule(conj, x) |> _prepend_zero -rrule(::RCR, ::typeof(broadcasted), ::typeof(conj), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero +rrule(::RCR, ::typeof(broadcasted), ::typeof(float), x::Number) = rrule(float, x) |> _prepend_zero -# TODO real, imag +##### Complex: conj, real, imag + +for conj in [:conj, :adjoint] # identical as we know eltype <: Number + @eval begin + function rrule(::RCR, ::typeof(broadcasted), ::typeof($conj), x::NumericOrBroadcast) + bc_conj_back(dx) = (NoTangent(), NoTangent(), conj(unthunk(dx))) + return broadcasted($conj, x), bc_conj_back + end + rrule(::RCR, ::typeof(broadcasted), ::typeof($conj), x::Number) = rrule($conj, x) |> _prepend_zero + rrule(::RCR, ::typeof(broadcasted), ::typeof($conj), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero + # This `AbstractArray{<:Real}` rule won't catch `conj.(x.+1)` with lazy `.+` rule. + # Could upgrade to infer eltype of the `Broadcasted`? + end +end + +function rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::NumericOrBroadcast) + _print("real") + bc_real_back(dz) = (NoTangent(), NoTangent(), @thunk(real(unthunk(dz)))) + return broadcasted(real, x), bc_real_back +end +rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::Number) = rrule(real, x) |> _prepend_zero +rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero + +function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::NumericOrBroadcast) + _print("imag") + bc_imag_back(dz) = (NoTangent(), NoTangent(), @thunk(im .* real.(unthunk(dz)))) + return broadcasted(imag, x), bc_imag_back +end +rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::Number) = rrule(imag, x) |> _prepend_zero +function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:Real}) + _print("imag(real)") + bc_imag_back_2(dz) = (NoTangent(), NoTangent(), ZeroTangent()) + return broadcasted(imag, x), bc_imag_back_2 +end ##### ##### Shape fixing ##### -# Reverse mode broadcasting uses `unbroadcast` to reduce to correct shape: +# When sizes disagree, broadcasting gradient uses `unbroadcast` to reduce to correct shape. +# It's sometimes a little wasteful to allocate a too-large `dx`, but difficult to make more efficient. function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx) N = ndims(dx) @@ -198,7 +247,7 @@ function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx) ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors else dims = ntuple(d -> get(size(x), d, 1) == 1 ? d : N+1, N) # hack to get type-stable `dims` - ProjectTo(x)(sum(dx; dims)) # ideally this sum might be thunked? + ProjectTo(x)(sum(dx; dims)) end end unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx diff --git a/src/tuplecast.jl b/src/tuplecast.jl index 2130285a2..e7ade2b7e 100644 --- a/src/tuplecast.jl +++ b/src/tuplecast.jl @@ -4,6 +4,21 @@ For a function `f` which returns a tuple, this is `== unzip(broadcast(f, args...))`, but performed using `StructArrays` for efficiency. + +# Examples +``` +julia> using ChainRules: tuplecast, unzip + +julia> tuplecast(x -> (x,2x), 1:3) +([1, 2, 3], [2, 4, 6]) + +julia> mats = @btime tuplecast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000)); # 2 arrays, each 7.63 MiB + min 1.776 ms, mean 20.421 ms (4 allocations, 15.26 MiB) + +julia> mats == @btime unzip(broadcast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000))) # intermediate matrix of tuples + min 2.660 ms, mean 40.007 ms (6 allocations, 30.52 MiB) +true +``` """ function tuplecast(f::F, args...) where {F} T = Broadcast.combine_eltypes(f, args) @@ -67,6 +82,21 @@ end Converts an array of tuples into a tuple of arrays. Eager. Will work by `reinterpret` when possible. + +```jldoctest +julia> ChainRules.unzip([(1,2), (3,4), (5,6)]) # makes two new Arrays: +([1, 3, 5], [2, 4, 6]) + +julia> typeof(ans) +Tuple{Vector{Int64}, Vector{Int64}} + +julia> ChainRules.unzip([(1,nothing) (3,nothing) (5,nothing)]) # this can reinterpret: +([1 3 5], [nothing nothing nothing]) + +julia> ans[1] +1×3 reinterpret(Int64, ::Matrix{Tuple{Int64, Nothing}}): + 1 3 5 +``` """ function unzip(xs::AbstractArray) x1 = first(xs) diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index a65881747..36452da1e 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -77,6 +77,7 @@ for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im) test_scalar(real, x) test_scalar(imag, x) + test_scalar(complex, x) test_scalar(hypot, x) test_scalar(adjoint, x) end diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index a51319c8e..485e5bf4d 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -48,54 +48,102 @@ using Base.Broadcast: broadcasted test_rrule(copy∘broadcasted, _call, atan, rand(3), rand(4)', check_inferred=false) # Protected by Ref/Tuple: + test_rrule(copy∘broadcasted, *, rand(3), Ref(rand(2)), check_inferred=false) test_rrule(copy∘broadcasted, conj∘*, rand(3), Ref(rand() + im), check_inferred=false) test_rrule(copy∘broadcasted, conj∘*, rand(3), Ref(rand(2) .+ im), check_inferred=false) test_rrule(copy∘broadcasted, /, (rand(2),), rand(3), check_inferred=false) end @testset "lazy rules" begin - test_rrule(copy∘broadcasted, +, rand(3), rand(3)) - test_rrule(copy∘broadcasted, +, rand(3), rand(4)') - test_rrule(copy∘broadcasted, +, rand(3), rand(1), rand()) - test_rrule(copy∘broadcasted, +, rand(3), 1.0*im) - test_rrule(copy∘broadcasted, +, rand(3), true) - test_rrule(copy∘broadcasted, +, rand(3), Tuple(rand(3))) + @testset "arithmetic" begin + test_rrule(copy∘broadcasted, +, rand(3), rand(3)) + test_rrule(copy∘broadcasted, +, rand(3), rand(4)') + test_rrule(copy∘broadcasted, +, rand(3), rand(1), rand()) + test_rrule(copy∘broadcasted, +, rand(3), 1.0*im) + test_rrule(copy∘broadcasted, +, rand(3), true) + test_rrule(copy∘broadcasted, +, rand(3), Tuple(rand(3))) - test_rrule(copy∘broadcasted, -, rand(3), rand(3)) - test_rrule(copy∘broadcasted, -, rand(3), rand(4)') - test_rrule(copy∘broadcasted, -, rand(3)) - # test_rrule(copy∘broadcasted, -, Tuple(rand(3))) # MethodError: (::ChainRulesTestUtils.var"#test_approx##kw")(::NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}, ::typeof(test_approx), ::Thunk{ChainRules.var"#1614#1616"{Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}}}, ::Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}) is ambiguous. + test_rrule(copy∘broadcasted, -, rand(3), rand(3)) + test_rrule(copy∘broadcasted, -, rand(3), rand(4)') + test_rrule(copy∘broadcasted, -, rand(3)) + # test_rrule(copy∘broadcasted, -, Tuple(rand(3))) # MethodError: (::ChainRulesTestUtils.var"#test_approx##kw")(::NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}, ::typeof(test_approx), ::Thunk{ChainRules.var"#1614#1616"{Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}}}, ::Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}) is ambiguous. - test_rrule(copy∘broadcasted, *, rand(3), rand(3)) - test_rrule(copy∘broadcasted, *, rand(3), rand()) - test_rrule(copy∘broadcasted, *, rand(), rand(3)) - - test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3), Val(2)) - - test_rrule(copy∘broadcasted, /, rand(3), rand()) + test_rrule(copy∘broadcasted, *, rand(3), rand(3)) + test_rrule(copy∘broadcasted, *, rand(3), rand()) + test_rrule(copy∘broadcasted, *, rand(), rand(3)) + + test_rrule(copy∘broadcasted, *, rand(3) .+ im, rand(3) .+ 2im) + test_rrule(copy∘broadcasted, *, rand(3) .+ im, rand() + 3im) + test_rrule(copy∘broadcasted, *, rand() + im, rand(3) .+ 4im) + + # test_rrule(copy∘broadcasted, *, im, rand(3)) # MethodError: no method matching randn(::Random._GLOBAL_RNG, ::Type{Complex{Bool}}) + # test_rrule(copy∘broadcasted, *, rand(3), im) + y4, bk4 = rrule(CFG, copy∘broadcasted, *, im, [1,2,3.0]) + @test y4 == [im, 2im, 3im] + @test unthunk(bk4([4, 5im, 6+7im])[4]) == [0,5,7] + + test_rrule(copy∘broadcasted, *, rand(3), rand(3), rand(3), rand(3), rand(3)) + test_rrule(copy∘broadcasted, *, rand(), rand(), rand(3), rand(3) .+ im, rand(4)') - test_rrule(copy∘broadcasted, identity, rand(3)) + test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3), Val(2)) + test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3) .+ im, Val(2)) - test_rrule(copy∘broadcasted, conj, rand(3)) - test_rrule(copy∘broadcasted, conj, rand(3) .+ im) - end + test_rrule(copy∘broadcasted, /, rand(3), rand()) + test_rrule(copy∘broadcasted, /, rand(3) .+ im, rand() + 3im) + end + @testset "identity etc" begin + test_rrule(copy∘broadcasted, identity, rand(3)) + + test_rrule(copy∘broadcasted, Float32, rand(3), rtol=1e-4) + test_rrule(copy∘broadcasted, ComplexF32, rand(3), rtol=1e-4) + + test_rrule(copy∘broadcasted, float, rand(3)) + end + @testset "complex" begin + test_rrule(copy∘broadcasted, conj, rand(3)) + test_rrule(copy∘broadcasted, conj, rand(3) .+ im) + test_rrule(copy∘broadcasted, adjoint, rand(3)) + test_rrule(copy∘broadcasted, adjoint, rand(3) .+ im) - @testset "scalar rules" begin - test_rrule(copy∘broadcasted, sin, rand()) - test_rrule(copy∘broadcasted, atan, rand(), rand()) - # test_rrule(copy∘broadcasted, >, rand(), rand()) # DimensionMismatch from FiniteDifferences + test_rrule(copy∘broadcasted, real, rand(3)) + test_rrule(copy∘broadcasted, real, rand(3) .+ im) - # Functions with lazy rules - test_rrule(copy∘broadcasted, +, rand(), rand(), rand()) - test_rrule(copy∘broadcasted, +, rand()) - test_rrule(copy∘broadcasted, -, rand(), rand()) - test_rrule(copy∘broadcasted, -, rand()) - test_rrule(copy∘broadcasted, *, rand(), rand()) - test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(), Val(2)) - test_rrule(copy∘broadcasted, /, rand(), rand()) - - test_rrule(copy∘broadcasted, identity, rand()) - test_rrule(copy∘broadcasted, conj, rand()) - test_rrule(copy∘broadcasted, conj, rand() + im) + test_rrule(copy∘broadcasted, imag, rand(3)) + test_rrule(copy∘broadcasted, imag, rand(3) .+ im .* rand.()) + + test_rrule(copy∘broadcasted, complex, rand(3)) + end + end + + @testset "scalar rules" begin + @testset "generic" begin + test_rrule(copy∘broadcasted, sin, rand()) + test_rrule(copy∘broadcasted, atan, rand(), rand()) + # test_rrule(copy∘broadcasted, >, rand(), rand()) # DimensionMismatch from FiniteDifferences + end + # Functions with lazy broadcasting rules: + @testset "arithmetic" begin + test_rrule(copy∘broadcasted, +, rand(), rand(), rand()) + test_rrule(copy∘broadcasted, +, rand()) + test_rrule(copy∘broadcasted, -, rand(), rand()) + test_rrule(copy∘broadcasted, -, rand()) + test_rrule(copy∘broadcasted, *, rand(), rand()) + test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(), Val(2)) + test_rrule(copy∘broadcasted, /, rand(), rand()) + end + @testset "identity etc" begin + test_rrule(copy∘broadcasted, identity, rand()) + test_rrule(copy∘broadcasted, Float32, rand(), rtol=1e-4) + test_rrule(copy∘broadcasted, float, rand()) + end + @testset "complex" begin + test_rrule(copy∘broadcasted, conj, rand()) + test_rrule(copy∘broadcasted, conj, rand() + im) + test_rrule(copy∘broadcasted, real, rand()) + test_rrule(copy∘broadcasted, real, rand() + im) + test_rrule(copy∘broadcasted, imag, rand()) + test_rrule(copy∘broadcasted, imag, rand() + im) + test_rrule(copy∘broadcasted, complex, rand()) + end end end \ No newline at end of file diff --git a/test/tuplecast.jl b/test/tuplecast.jl index 6863d9723..c1796db97 100644 --- a/test/tuplecast.jl +++ b/test/tuplecast.jl @@ -1,7 +1,7 @@ using ChainRules: tuplecast, unzip # tuplemap, -@testset "tuplecast" begin +@testset "tuplecast.jl" begin @testset "basics: $(sprint(show, fun))" for fun in [tuplecast, unzip∘broadcast] # [tuplemap, tuplecast, unzip∘map, unzip∘broadcast] @test_throws Exception fun(sqrt, 1:3) From 74a6a8b1769efd45601415ced65e48aefae822b3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 12 Jul 2022 08:56:21 -0400 Subject: [PATCH 03/20] fixup tuplecast --- src/tuplecast.jl | 85 ++++++++++++++++++++++++++++++++++------------- test/tuplecast.jl | 69 +++++++++++++++++++++++++++++--------- 2 files changed, 114 insertions(+), 40 deletions(-) diff --git a/src/tuplecast.jl b/src/tuplecast.jl index e7ade2b7e..7da9134cb 100644 --- a/src/tuplecast.jl +++ b/src/tuplecast.jl @@ -26,28 +26,36 @@ function tuplecast(f::F, args...) where {F} T <: Tuple || throw(ArgumentError("""tuplecast(f, args) only works on functions returning a tuple, but f = $(sprint(show, f)) returns type T = $T""")) end + # TODO allow GPU arrays, possibly just as a fallback unzip, but see also: + # https://github.com/JuliaArrays/StructArrays.jl/issues/150 # if any(a -> a isa CuArray, args) # return unzip(broadcast(f, args...)) # end bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...)) - StructArrays.components(StructArray(bc)) + if Broadcast.BroadcastStyle(typeof(bc)) isa Broadcast.AbstractArrayStyle + return StructArrays.components(StructArray(bc)) + else + return unzip(broadcast(f, args...)) # e.g. tuples + end end function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplecast), f::F, args...) where {F} - y, back = rrule_via_ad(cfg, broadcasted, f, args...) + y, back = rrule_via_ad(cfg, broadcast, f, args...) z = unzip(y) function untuplecast(dz) - dy = StructArray(map(unthunk, dz)) + # dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent())) + dy = broadcast(tuple, map(unthunk, dz)...) db, df, dargs... = back(dy) - (db, sum(df), map(unbroadcast, args, dargs)...) + return (db, sum(df), map(unbroadcast, args, dargs)...) end + untuplecast(dz::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(dz), args)) return z, untuplecast end -# function rrule(cfg::RCR, ::typeof(collect∘tuplecast), f, args...) -# y, back = rrule(cfg, tuplecast, f, args...) -# return collect(y), back -# end +function rrule(cfg::RCR, ::typeof(collect∘tuplecast), f, args...) # for testing, but doesn't work? + y, back = rrule(cfg, tuplecast, f, args...) + return collect(y), back +end """ tuplemap(f, args...) @@ -64,18 +72,19 @@ function tuplemap(f::F, args...) where {F} # if any(a -> a isa CuArray, args) # return unzip(map(f, args...)) # end - StructArrays.components(StructArray(Iterators.map(f, args...))) + return StructArrays.components(StructArray(Iterators.map(f, args...))) end -# function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplemap), f::F, args...) where {F} -# y, back = rrule(cfg, map, f, xs...) # won't work, but also, you want the lazier fwd -# z = unzip(y) -# function untuplemap(dz) -# dy = StructArray(map(unthunk, dz)) -# back(dy) -# end -# return unzip(xs), untuplemap -# end +function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplemap), f::F, xs...) where {F} + y, back = rrule_via_ad(cfg, map, f, xs...) + z = unzip(y) + function untuplemap(dz) + # dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent())) + dy = broadcast(tuple, map(unthunk, dz)...) + return back(dy) + end + return z, untuplemap +end """ unzip(A) @@ -84,8 +93,8 @@ Converts an array of tuples into a tuple of arrays. Eager. Will work by `reinterpret` when possible. ```jldoctest -julia> ChainRules.unzip([(1,2), (3,4), (5,6)]) # makes two new Arrays: -([1, 3, 5], [2, 4, 6]) +julia> ChainRules.unzip([(1,2), (30,40), (500,600)]) # makes two new Arrays: +([1, 30, 500], [2, 40, 600]) julia> typeof(ans) Tuple{Vector{Int64}, Vector{Int64}} @@ -102,7 +111,7 @@ function unzip(xs::AbstractArray) x1 = first(xs) x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays of tuples")) N = length(x1) - unzip(xs, Val(N)) # like Zygote's unzip, here this is the fallback case. + return unzip(xs, Val(N)) # like Zygote's unzip, here this is the fallback case. end @generated function unzip(xs, ::Val{N}) where {N} @@ -122,16 +131,44 @@ unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best cas Expr(:tuple, each...) end +""" + unzip(t) + +Also works on a tuple of tuples: + +```jldoctest +julia> unzip(((1,2), (30,40), (500,600))) +((1, 30, 500), (2, 40, 600)) +``` +""" +function unzip(xs::Tuple) + x1 = first(xs) + x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays or tuples of tuples")) + return ntuple(i -> map(Get(i), xs),length(x1)) +end + struct Get{i} end Get(i) = Get{Int(i)}() (::Get{i})(x) where {i} = x[i] function ChainRulesCore.rrule(::typeof(unzip), xs::AbstractArray{T}) where {T <: Tuple} function rezip(dy) - dxs = map(unthunk.(dy)...) do ys... - Tangent{T}(ys...) + dxs = broadcast(xs, unthunk.(dy)...) do x, ys... + ProjectTo(x)(Tangent{T}(ys...)) end - (NoTangent(), dxs) + return (NoTangent(), dxs) end + rezip(dz::AbstractZero) = (NoTangent(), dz) return unzip(xs), rezip end + +function ChainRulesCore.rrule(::typeof(unzip), xs::Tuple) + function rezip_2(dy) + dxs = broadcast(xs, unthunk.(dy)...) do x, ys... + Tangent{typeof(x)}(ys...) + end + return (NoTangent(), ProjectTo(xs)(dxs)) + end + rezip_2(dz::AbstractZero) = (NoTangent(), dz) + return unzip(xs), rezip_2 +end diff --git a/test/tuplecast.jl b/test/tuplecast.jl index c1796db97..458a51fa6 100644 --- a/test/tuplecast.jl +++ b/test/tuplecast.jl @@ -1,8 +1,8 @@ -using ChainRules: tuplecast, unzip # tuplemap, +using ChainRules: tuplecast, unzip, tuplemap @testset "tuplecast.jl" begin - @testset "basics: $(sprint(show, fun))" for fun in [tuplecast, unzip∘broadcast] # [tuplemap, tuplecast, unzip∘map, unzip∘broadcast] + @testset "basics: $(sprint(show, fun))" for fun in [tuplemap, tuplecast, unzip∘map, unzip∘broadcast] @test_throws Exception fun(sqrt, 1:3) @test fun(tuple, 1:3, 4:6) == ([1, 2, 3], [4, 5, 6]) @@ -16,32 +16,69 @@ using ChainRules: tuplecast, unzip # tuplemap, else @test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) end + + if fun == tuplemap + @test_broken fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) + elseif fun == unzip∘map + @test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) + else + @test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) + @test fun(tuple, (1,2,3), (7,)) == ((1, 2, 3), (7, 7, 7)) + @test fun(tuple, (1,2,3), 8) == ((1, 2, 3), (8, 8, 8)) + end + @test fun(tuple, (1,2,3), [4,5,6]) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector end + + @testset "rrules" begin + # These exist to allow for second derivatives - # tuplemap(tuple, (1,2,3), (4,5,6)) == ([1, 2, 3], [4, 5, 6]) + # test_rrule(collect∘tuplecast, tuple, [1,2,3.], [4,5,6.], collectheck_inferred=false) # return type Tuple{NoTangent, NoTangent, Vector{Float64}, Vector{Float64}} does not match inferred return type NTuple{4, Any} + + y1, bk1 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4,5,6.0]) + @test y1 == ([1, 2, 3], [4, 5, 6]) + @test bk1(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100] + + # bk1(([1,10,100.0], NoTangent())) # DimensionMismatch in FiniteDifferences + + y2, bk2 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4 5.0], 6.0) + @test y2 == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5], [6 6; 6 6; 6 6]) + @test bk2(y2)[5] ≈ 36 + y4, bk4 = rrule(CFG, tuplemap, tuple, [1,2,3.0], [4,5,6.0]) + @test y4 == ([1, 2, 3], [4, 5, 6]) + @test bk4(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100] + end + @testset "unzip" begin @test unzip([(1,2), (3,4), (5,6)]) == ([1, 3, 5], [2, 4, 6]) + @test unzip(Any[(1,2), (3,4), (5,6)]) == ([1, 3, 5], [2, 4, 6]) + @test unzip([(nothing,2), (3,4), (5,6)]) == ([nothing, 3, 5], [2, 4, 6]) @test unzip([(missing,2), (missing,4), (missing,6)])[2] isa Base.ReinterpretArray + @test unzip([(1,), (3,), (5,)]) == ([1, 3, 5],) + @test unzip([(1,), (3,), (5,)])[1] isa Base.ReinterpretArray + + @test unzip(((1,2), (3,4), (5,6))) == ((1, 3, 5), (2, 4, 6)) + + # test_rrule(unzip, [(1,2), (3,4), (5.0,6.0)], check_inferred=false) # DimensionMismatch: second dimension of A, 6, does not match length of x, 2 + y, bk = rrule(unzip, [(1,2), (3,4), (5,6)]) @test y == ([1, 3, 5], [2, 4, 6]) @test bk(Tangent{Tuple}([1,1,1], [10,100,1000]))[2] isa Vector{<:Tangent{<:Tuple}} - end - - @testset "rrules" begin - # These exist to allow for second derivatives - # test_rrule(collect∘tuplecast, tuple, [1,2,3.], [4,5,6.], check_inferred=false) - y1, bk1 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4,5,6.0]) - @test y1 == ([1, 2, 3], [4, 5, 6]) - @test bk1(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100] + y3, bk3 = rrule(unzip, [(1,ZeroTangent()), (3,ZeroTangent()), (5,ZeroTangent())]) + @test y3 == ([1, 3, 5], [ZeroTangent(), ZeroTangent(), ZeroTangent()]) + dx3 = bk3(Tangent{Tuple}([1,1,1], [10,100,1000]))[2] + @test dx3 isa Vector{<:Tangent{<:Tuple}} + @test Tuple(dx3[1]) == (1.0, NoTangent()) - y2, bk2 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4 5.0], 6.0) - @test y2 == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5], [6 6; 6 6; 6 6]) - @test bk2(y2)[5] ≈ 36 - - test_rrule(unzip, [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)], check_inferred=false) + y5, bk5 = rrule(unzip, ((1,2), (3,4), (5,6))) + @test y5 == ((1, 3, 5), (2, 4, 6)) + @test bk5(y5)[2] isa Tangent{<:Tuple} + @test Tuple(bk5(y5)[2][2]) == (3, 4) + dx5 = bk5(((1,10,100), ZeroTangent())) + @test dx5[2] isa Tangent{<:Tuple} + @test Tuple(dx5[2][2]) == (10, ZeroTangent()) end end \ No newline at end of file From 3319600efa657288fd56751a4a8f2a2b0642e79d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 12 Jul 2022 08:57:27 -0400 Subject: [PATCH 04/20] re-organise split bc, add forward mode --- src/rulesets/Base/broadcast.jl | 162 +++++++++++++++++++++++--------- test/rulesets/Base/broadcast.jl | 18 ++-- test/test_helpers.jl | 14 +++ 3 files changed, 143 insertions(+), 51 deletions(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 24ca1f373..31b822d5f 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -18,56 +18,121 @@ _print(args...) = nothing # println(join(args, " ")) # ##### Split broadcasting ##### +# For `z = g.(f.(xs))`, this finds `y = f.(x)` eagerly because the rules for either `f` or `g` may need it, +# and we don't know whether re-computing `y` is cheap. +# (We could check `f` first like `sum(f, x)` does, but checking whether `g` needs `y` is tricky.) + function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Vararg{Any,N}) where {F,N} - # = split_bc_rule(cfg, f, args...) - # function split_bc_rule(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} T = Broadcast.combine_eltypes(f, args) - TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...}) - if T === Bool + if T === Bool # TODO use nondifftype here # 1: Trivial case: non-differentiable output, e.g. `x .> 0` - _print("split_bc_rule 1 ", f) - back_1(_) = ntuple(Returns(ZeroTangent()), length(args)+2) - return f.(args...), back_1 - elseif T <: Number && isconcretetype(TΔ) - # 2: Fast path: just broadcast, and use arguments & result to find derivatives. - _print("split_bc_rule 2", f, N) - ys = f.(args...) - function back_2_one(dys) # For f.(x) we do not need StructArrays / unzip at all - delta = broadcast(unthunk(dys), ys, args...) do dy, y, a - das = only(derivatives_given_output(y, f, a)) - dy * conj(only(das)) # possibly this * should be made nan-safe. - end - (NoTangent(), NoTangent(), ProjectTo(only(args))(delta)) + _print("split_bc_trivial", f) + bc_trivial_back(_) = (NoTangent(), NoTangent(), ntuple(Returns(ZeroTangent()), length(args))...) + return f.(args...), bc_trivial_back + elseif T <: Number && may_bc_derivatives(T, f, args...) + # 2: Fast path: use arguments & result to find derivatives. + return split_bc_derivatives(f, args...) + elseif T <: Number && may_bc_forwards(cfg, f, args...) + # 3: Future path: use `frule_via_ad`? + return split_bc_forwards(cfg, f, args...) + else + # 4: Slow path: collect all the pullbacks & apply them later. + return split_bc_pullbacks(cfg, f, args...) + end +end + +# Path 2: This is roughly what `derivatives_given_output` is designed for, should be fast. + +function may_bc_derivatives(::Type{T}, f::F, args::Vararg{Any,N}) where {T,F,N} + TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(_eltype, args)...}) + return isconcretetype(TΔ) +end + +_eltype(x) = eltype(x) # ... but try harder to avoid `eltype(Broadcast.broadcasted(+, [1,2,3], 4.5)) == Any`: +_eltype(bc::Broadcast.Broadcasted) = Broadcast.combine_eltypes(bc.f, bc.args) + +function split_bc_derivatives(f::F, arg) where {F} + _print("split_bc_derivative", f) + ys = f.(arg) + function bc_one_back(dys) # For f.(x) we do not need StructArrays / unzip at all + delta = broadcast(unthunk(dys), ys, arg) do dy, y, a + das = only(derivatives_given_output(y, f, a)) + dy * conj(only(das)) # possibly this * should be made nan-safe. end - back_2_one(z::AbstractZero) = (NoTangent(), NoTangent(), z) - function back_2_many(dys) - deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as... - das = only(derivatives_given_output(y, f, as...)) - map(da -> dy * conj(da), das) - end - dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of tuplecast? - (NoTangent(), NoTangent(), dargs...) + return (NoTangent(), NoTangent(), ProjectTo(arg)(delta)) + end + bc_one_back(z::AbstractZero) = (NoTangent(), NoTangent(), z) + return ys, bc_one_back +end +function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N} + _print("split_bc_derivatives", f, N) + ys = f.(args...) + function bc_many_back(dys) + deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as... + das = only(derivatives_given_output(y, f, as...)) + map(da -> dy * conj(da), das) # possibly this * should be made nan-safe. end - back_2_many(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...) - return ys, N==1 ? back_2_one : back_2_many - else - _print("split_bc_rule 3", f, N) - # 3: Slow path: collect all the pullbacks & apply them later. - # (Since broadcast makes no guarantee about order of calls, and un-fusing - # can change the number of calls, don't bother to try to reverse the iteration.) - ys3, backs = tuplecast(args...) do a... - rrule_via_ad(cfg, f, a...) + dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of tuplecast? + return (NoTangent(), NoTangent(), dargs...) + end + bc_many_back(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...) + return ys, bc_many_back +end + +# Path 3: Use forward mode, or an `frule` if one exists. +# To allow `args...` we need either chunked forward mode, with `adot::Tuple` perhaps: +# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/92 +# https://github.com/JuliaDiff/Diffractor.jl/pull/54 +# Or else we need to call the `f` multiple times, and maybe that's OK: +# We do know that `f` doesn't have parameters, so maybe it's pure enough, +# and split broadcasting may anyway change N^2 executions into N, e.g. `g.(v ./ f.(v'))`. +# We don't know `f` is cheap, but `split_bc_pullbacks` tends to be very slow. + +function may_bc_forwards(cfg::C, f::F, args::Vararg{Any,N}) where {C,F,N} + Base.issingletontype(F) || return false + N==1 || return false # Could weaken this to 1 differentiable + cfg isa RuleConfig{>:HasForwardsMode} && return true # allows frule_via_ad + TA = map(_eltype, args) + TF = Core.Compiler._return_type(frule, Tuple{C, Tuple{NoTangent, TA...}, F, TA...}) + return isconcretetype(TF) && TF <: Tuple +end + +split_bc_forwards(cfg::RuleConfig{>:HasForwardsMode}, f::F, arg) where {F} = split_bc_inner(frule_via_ad, cfg, f, arg) +split_bc_forwards(cfg::RuleConfig, f::F, arg) where {F} = split_bc_inner(frule, cfg, f, arg) +function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F} + _print("split_bc_forwards", frule_fun, f) + ys, ydots = tuplecast(arg) do a + frule_fun(cfg, (NoTangent(), one(a)), f, a) + end + function back_forwards(dys) + delta = broadcast(ydots, unthunk(dys), arg) do ydot, dy, a + ProjectTo(a)(conj(ydot) * dy) # possibly this * should be made nan-safe. end - function back_3(dys) - deltas = tuplecast(backs, unthunk(dys)) do back, dy # could be map, sizes match - map(unthunk, back(dy)) - end - dargs = map(unbroadcast, args, Base.tail(deltas)) - (NoTangent(), ProjectTo(f)(sum(first(deltas))), dargs...) + return (NoTangent(), NoTangent(), ProjectTo(arg)(delta)) + end + back_forwards(z::AbstractZero) = (NoTangent(), NoTangent(), z) + return ys, back_forwards +end + +# Path 4: The most generic, save all the pullbacks. Can be 1000x slower. +# Since broadcast makes no guarantee about order of calls, and un-fusing +# can change the number of calls, don't bother to try to reverse the iteration. + +function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} + _print("split_bc_generic", f, N) + ys3, backs = tuplecast(args...) do a... + rrule_via_ad(cfg, f, a...) + end + function back_generic(dys) + deltas = tuplecast(backs, unthunk(dys)) do back, dy # (could be map, sizes match) + map(unthunk, back(dy)) end - back_3(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...) - return ys3, back_3 + dargs = map(unbroadcast, args, Base.tail(deltas)) + df = ProjectTo(f)(sum(first(deltas))) + return (NoTangent(), df, dargs...) end + back_generic(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...) + return ys3, back_generic end # Don't run broadcasting on scalars @@ -158,8 +223,8 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, dz = unthunk(dz_raw) dx = @thunk unbroadcast(x, dz ./ conj.(y)) # dy = @thunk -LinearAlgebra.dot(z, dz) / conj(y) # the reason to be eager is to allow dot here - dy = @thunk -sum(Broadcast.instantiate(broadcasted(*, broadcasted(conj, z), dz))) / conj(y) # complete sum is fast? - (NoTangent(), NoTangent(), dx, dy) + dy = @thunk -sum(Broadcast.instantiate(broadcasted(*, broadcasted(conj, z), dz))) / conj(y) # complete sum is fast + return (NoTangent(), NoTangent(), dx, dy) end return z, bc_divide_back end @@ -234,6 +299,13 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<: return broadcasted(imag, x), bc_imag_back_2 end +function rrule(::RCR, ::typeof(broadcasted), ::typeof(complex), x::NumericOrBroadcast) + _print("bc complex") + bc_complex_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz)))) + return broadcasted(complex, x), bc_complex_back +end +rrule(::RCR, ::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) |> _prepend_zero + ##### ##### Shape fixing ##### @@ -259,7 +331,7 @@ function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N} else sum(dx; dims=2:ndims(dx)) end - ProjectTo(x)(NTuple{length(x)}(val)) # Tangent + return ProjectTo(x)(NTuple{length(x)}(val)) # Tangent end unbroadcast(f::Function, df) = sum(df) diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 485e5bf4d..52c630cf8 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -1,7 +1,7 @@ using Base.Broadcast: broadcasted @testset "Broadcasting" begin - @testset "generic 1: trivial path" begin + @testset "split 1: trivial path" begin # test_rrule(copy∘broadcasted, >, rand(3), rand(3)) # MethodError: no method matching eps(::UInt64) inside FiniteDifferences y1, bk1 = rrule(CFG, copy∘broadcasted, >, rand(3), rand(3)) @test y1 isa AbstractArray{Bool} @@ -12,7 +12,7 @@ using Base.Broadcast: broadcasted @test all(d -> d isa AbstractZero, bk2(99)) end - @testset "generic 2: fast path" begin + @testset "split 2: derivatives" begin test_rrule(copy∘broadcasted, log, rand(3)) test_rrule(copy∘broadcasted, log, Tuple(rand(3))) @@ -23,16 +23,22 @@ using Base.Broadcast: broadcasted test_rrule(copy∘broadcasted, atan, rand(3), Tuple(rand(1))) test_rrule(copy∘broadcasted, atan, Tuple(rand(3)), Tuple(rand(3))) - # Protected by Ref/Tuple: test_rrule(copy∘broadcasted, *, rand(3), Ref(rand())) - test_rrule(copy∘broadcasted, *, rand(3), Ref(rand(2))) + end + + @testset "split 3: forwards" begin + test_rrule(copy∘broadcasted, flog, rand(3)) + test_rrule(copy∘broadcasted, flog, rand(3) .+ im) + # Also, `sin∘cos` may use this path as CFG uses frule_via_ad end - @testset "generic 3: slow path" begin + @testset "split 4: generic" begin test_rrule(copy∘broadcasted, sin∘cos, rand(3), check_inferred=false) test_rrule(copy∘broadcasted, sin∘atan, rand(3), rand(3)', check_inferred=false) test_rrule(copy∘broadcasted, sin∘atan, rand(), rand(3), check_inferred=false) - test_rrule(copy∘broadcasted, ^, rand(3), 3.0, check_inferred=false) + test_rrule(copy∘broadcasted, ^, rand(3), 3.0, check_inferred=false) # NoTangent vs. Union{NoTangent, ZeroTangent} + # Many have quite small inference failures, like: + # return type Tuple{NoTangent, NoTangent, Vector{Float64}, Float64} does not match inferred return type Tuple{NoTangent, Union{NoTangent, ZeroTangent}, Vector{Float64}, Float64} # From test_helpers.jl test_rrule(copy∘broadcasted, Multiplier(rand()), rand(3), check_inferred=false) diff --git a/test/test_helpers.jl b/test/test_helpers.jl index b347e789b..e58b3f240 100644 --- a/test/test_helpers.jl +++ b/test/test_helpers.jl @@ -176,6 +176,14 @@ function ChainRulesCore.rrule(::typeof(make_two_vec), x) return make_two_vec(x), make_two_vec_pullback end +"A version of `*` with only an `frule` defined" +fstar(A, B) = A * B +ChainRulesCore.frule((_, ΔA, ΔB), ::typeof(fstar), A, B) = A * B, muladd(ΔA, B, A * ΔB) + +"A version of `log` with only an `frule` defined" +flog(x:::Number) = log(x) +ChainRulesCore.frule((_, xdot), ::typeof(flog), x::Number) = log(x), inv(x) * xdot + @testset "test_helpers.jl" begin @testset "Multiplier" begin @@ -204,5 +212,11 @@ end @testset "make_two_vec" begin test_rrule(make_two_vec, 1.5) end + + @testset "fstar, flog" begin + test_frule(fstar, 1.2, 3.4 + 5im) + test_frule(flog, 6.7) + test_frule(flog, 8.9 + im) + end end From 070d4b77271aed2d41eff3bb9475c09fb58cda3f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 12 Jul 2022 09:53:41 -0400 Subject: [PATCH 05/20] fix tests --- src/rulesets/Base/broadcast.jl | 2 +- src/tuplecast.jl | 7 +++++-- test/rulesets/Base/broadcast.jl | 8 +++++--- test/rulesets/Base/mapreduce.jl | 2 -- test/runtests.jl | 5 +++-- test/test_helpers.jl | 5 +++-- 6 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 31b822d5f..7dc4a3a77 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -240,7 +240,7 @@ end rrule(::RCR, ::typeof(broadcasted), ::typeof(+), args::Number...) = rrule(+, args...) |> _prepend_zero rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = rrule(-, x, y) |> _prepend_zero rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number) = rrule(-, x) |> _prepend_zero -rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = rrule(*, x, y) |> _prepend_zero +rrule(::RCR, ::typeof(broadcasted), ::typeof(*), args::Number...) = rrule(*, args...) |> _prepend_zero rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2}) = rrule(Base.literal_pow, ^, x, Val(2)) |> _prepend_zero rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = rrule(/, x, y) |> _prepend_zero diff --git a/src/tuplecast.jl b/src/tuplecast.jl index 7da9134cb..9108b669b 100644 --- a/src/tuplecast.jl +++ b/src/tuplecast.jl @@ -3,7 +3,7 @@ tuplecast(f, args...) For a function `f` which returns a tuple, this is `== unzip(broadcast(f, args...))`, -but performed using `StructArrays` for efficiency. +but performed using `StructArrays` for efficiency. Used in the gradient of broadcasting. # Examples ``` @@ -52,7 +52,8 @@ function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplec return z, untuplecast end -function rrule(cfg::RCR, ::typeof(collect∘tuplecast), f, args...) # for testing, but doesn't work? +# This is for testing, but the tests using it don't work. +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect∘tuplecast), f, args...) y, back = rrule(cfg, tuplecast, f, args...) return collect(y), back end @@ -62,6 +63,8 @@ end For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`, but performed using `StructArrays` for efficiency. + +Not in use at present, but see `tuplecast`. """ function tuplemap(f::F, args...) where {F} T = Broadcast.combine_eltypes(f, args) diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 52c630cf8..d9e13fe7e 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -60,7 +60,7 @@ using Base.Broadcast: broadcasted test_rrule(copy∘broadcasted, /, (rand(2),), rand(3), check_inferred=false) end - @testset "lazy rules" begin + @testset "fused rules" begin @testset "arithmetic" begin test_rrule(copy∘broadcasted, +, rand(3), rand(3)) test_rrule(copy∘broadcasted, +, rand(3), rand(4)') @@ -88,8 +88,9 @@ using Base.Broadcast: broadcasted @test y4 == [im, 2im, 3im] @test unthunk(bk4([4, 5im, 6+7im])[4]) == [0,5,7] - test_rrule(copy∘broadcasted, *, rand(3), rand(3), rand(3), rand(3), rand(3)) - test_rrule(copy∘broadcasted, *, rand(), rand(), rand(3), rand(3) .+ im, rand(4)') + test_rrule(copy∘broadcasted, *, rand(3), rand(3), rand(3), rand(3), rand(3), check_inferred=false) # Union{NoTangent, ZeroTangent} + test_rrule(copy∘broadcasted, *, rand(), rand(), rand(3), rand(3) .+ im, rand(4)', check_inferred=false) # Union{NoTangent, ZeroTangent} + # (These two may infer with vararg rrule) test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3), Val(2)) test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3) .+ im, Val(2)) @@ -134,6 +135,7 @@ using Base.Broadcast: broadcasted test_rrule(copy∘broadcasted, -, rand(), rand()) test_rrule(copy∘broadcasted, -, rand()) test_rrule(copy∘broadcasted, *, rand(), rand()) + test_rrule(copy∘broadcasted, *, rand(), rand(), rand(), rand()) test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(), Val(2)) test_rrule(copy∘broadcasted, /, rand(), rand()) end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 4fed0988f..c4b985df3 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -2,8 +2,6 @@ Base.sum(xs::AbstractArray, weights::AbstractArray) = dot(xs, weights) struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end -const CFG = ChainRulesTestUtils.ADviaRuleConfig() - @testset "Reductions" begin @testset "sum(::Tuple)" begin test_frule(sum, Tuple(rand(5))) diff --git a/test/runtests.jl b/test/runtests.jl index ba8e84d94..4f4845e7b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,8 +43,7 @@ else end @testset "ChainRules" begin # One overall @testset ensures it keeps going after failures - include("test_helpers.jl") - include("tuplecast.jl") + include("test_helpers.jl") # This can't be skipped println() test_method_tables() # Check the global method tables are consistent @@ -60,6 +59,8 @@ end include_test("rulesets/Base/sort.jl") include_test("rulesets/Base/broadcast.jl") + include_test("tuplecast.jl") # used primarily for broadcast + println() include_test("rulesets/Statistics/statistics.jl") diff --git a/test/test_helpers.jl b/test/test_helpers.jl index e58b3f240..e06759054 100644 --- a/test/test_helpers.jl +++ b/test/test_helpers.jl @@ -98,6 +98,7 @@ function _gpu_test(::typeof(frule), f::Function, g::Function, xs...; kw...) # s _gpu_test(frule, xdots, f, g, xs...; kw...) end +const CFG = ChainRulesTestUtils.TestConfig() """ Multiplier(x) @@ -181,8 +182,8 @@ fstar(A, B) = A * B ChainRulesCore.frule((_, ΔA, ΔB), ::typeof(fstar), A, B) = A * B, muladd(ΔA, B, A * ΔB) "A version of `log` with only an `frule` defined" -flog(x:::Number) = log(x) -ChainRulesCore.frule((_, xdot), ::typeof(flog), x::Number) = log(x), inv(x) * xdot +flog(x::Number) = log(x) +ChainRulesCore.frule((_, Δx), ::typeof(flog), x::Number) = log(x), inv(x) * Δx @testset "test_helpers.jl" begin From 62d51454da8bab67e2af98bf638797be81f50700 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 12 Jul 2022 11:02:06 -0400 Subject: [PATCH 06/20] add Yota to downstream tests --- .github/workflows/IntegrationTest.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index aabc25d3a..4e6b6461e 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -16,6 +16,7 @@ jobs: os: [ubuntu-latest] package: # - {user: dpsanders, repo: ReversePropagation.jl} + - {user: dfdx, repo: Yota.jl} - {user: FluxML, repo: Zygote.jl} # Diffractor needs to run on Julia nightly # include: From 818147a85746703b2b848759254b9b092afd7ffd Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 12 Jul 2022 17:48:01 -0400 Subject: [PATCH 07/20] fix an ambiguity --- src/rulesets/Base/broadcast.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 7dc4a3a77..ee330f4e6 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -241,6 +241,7 @@ rrule(::RCR, ::typeof(broadcasted), ::typeof(+), args::Number...) = rrule(+, arg rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = rrule(-, x, y) |> _prepend_zero rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number) = rrule(-, x) |> _prepend_zero rrule(::RCR, ::typeof(broadcasted), ::typeof(*), args::Number...) = rrule(*, args...) |> _prepend_zero +rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = rrule(*, x, y) |> _prepend_zero # ambiguity rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2}) = rrule(Base.literal_pow, ^, x, Val(2)) |> _prepend_zero rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = rrule(/, x, y) |> _prepend_zero From 3d4d9b2e2666230ffdb11e3ef0fd49f7695a17ca Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 12 Jul 2022 21:49:18 -0400 Subject: [PATCH 08/20] fix tests on 1.6 --- test/rulesets/Base/broadcast.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index d9e13fe7e..2fc589ea7 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -13,15 +13,15 @@ using Base.Broadcast: broadcasted end @testset "split 2: derivatives" begin - test_rrule(copy∘broadcasted, log, rand(3)) - test_rrule(copy∘broadcasted, log, Tuple(rand(3))) + test_rrule(copy∘broadcasted, log, rand(3) .+ 1) + test_rrule(copy∘broadcasted, log, Tuple(rand(3) .+ 1)) # Two args uses StructArrays test_rrule(copy∘broadcasted, atan, rand(3), rand(3)) test_rrule(copy∘broadcasted, atan, rand(3), rand(4)') test_rrule(copy∘broadcasted, atan, rand(3), rand()) test_rrule(copy∘broadcasted, atan, rand(3), Tuple(rand(1))) - test_rrule(copy∘broadcasted, atan, Tuple(rand(3)), Tuple(rand(3))) + test_rrule(copy∘broadcasted, atan, Tuple(rand(3)), Tuple(rand(3)), check_inferred = VERSION > v"1.7") test_rrule(copy∘broadcasted, *, rand(3), Ref(rand())) end From 5c35f9c3df46d443d33a212177b0c0e0e33ee422 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 14 Jul 2022 09:13:23 -0400 Subject: [PATCH 09/20] testing --- src/rulesets/Base/broadcast.jl | 2 +- test/rulesets/Base/broadcast.jl | 4 ++++ test/rulesets/Base/mapreduce.jl | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index ee330f4e6..b5d04d850 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -12,7 +12,7 @@ function rrule(::typeof(Broadcast.instantiate), bc::Broadcasted) return Broadcast.instantiate(bc), uninstantiate end -_print(args...) = nothing # println(join(args, " ")) # +_print(args...) = printstyled("CR: ", join(args, " "), "\n", color=:magenta) # nothing # ##### ##### Split broadcasting diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 2fc589ea7..b2308819d 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -52,6 +52,10 @@ using Base.Broadcast: broadcasted test_rrule(copy∘broadcasted, |>, rand(3), sin, check_inferred=false) _call(f, x...) = f(x...) test_rrule(copy∘broadcasted, _call, atan, rand(3), rand(4)', check_inferred=false) + + test_rrule(copy∘broadcasted, getindex, [rand(3) for _ in 1:2], [3,1], check_inferred=false) + # test_rrule(copy∘broadcasted, getindex, [rand(3) for _ in 1:2], (3,1), check_inferred=false) + # test_rrule(copy∘broadcasted, getindex, [rand(3) for _ in 1:2], Ref(CartesianIndex(2)), check_inferred=false) # Protected by Ref/Tuple: test_rrule(copy∘broadcasted, *, rand(3), Ref(rand(2)), check_inferred=false) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index c4b985df3..89f41c933 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -83,6 +83,7 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end # inference fails for array of arrays test_rrule(sum, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false) + test_rrule(sum, norm, collect.(eachcol(rand(3,4))); check_inferred=false) # dims kwarg test_rrule(sum, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1)) From 54bfee2d4eb6acd20c0e3a14c72967e484f1b093 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 5 Aug 2022 18:46:38 -0600 Subject: [PATCH 10/20] improve unbroadcast --- src/rulesets/Base/broadcast.jl | 42 ++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index b5d04d850..1c0df2daa 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -325,32 +325,60 @@ function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx) end unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx -unbroadcast(x::T, dx) where {T<:Tuple{Any}} = ProjectTo(x)(Tangent{T}(sum(dx))) function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N} val = if length(x) == length(dx) dx else sum(dx; dims=2:ndims(dx)) end + eltype(val) <: AbstractZero && return NoTangent() return ProjectTo(x)(NTuple{length(x)}(val)) # Tangent end +unbroadcast(x::Tuple, dx::AbstractZero) = dx + +# Scalar types -unbroadcast(f::Function, df) = sum(df) unbroadcast(x::Number, dx) = ProjectTo(x)(sum(dx)) -unbroadcast(x::Base.RefValue, dx) = ProjectTo(x)(Ref(sum(dx))) + +function unbroadcast(x::T, dx) where {T<:Tuple{Any}} + p1 = ProjectTo(only(x)) + p1 isa ProjectTo{<:AbstractZero} && return NoTangent() + dx1 = p1(sum(dx)) + dx1 isa AbstractZero && return dx1 + return Tangent{T}(dx1) +end +unbroadcast(x::Tuple{Any}, dx::AbstractZero) = dx + +function unbroadcast(x::Base.RefValue, dx) + p1 = ProjectTo(x.x) + p1 isa ProjectTo{<:AbstractZero} && return NoTangent() + dx1 = p1(sum(dx)) + dx1 isa AbstractZero && return dx1 + return Tangent{typeof(x)}(; x = dx1) +end +unbroadcast(x::Base.RefValue, dx::AbstractZero) = dx + +# Zero types unbroadcast(::Bool, dx) = NoTangent() unbroadcast(::AbstractArray{Bool}, dx) = NoTangent() unbroadcast(::AbstractArray{Bool}, dx::AbstractZero) = dx # ambiguity unbroadcast(::Val, dx) = NoTangent() +function unbroadcast(f::Function, df) + Base.issingletontype(typeof(f)) && return NoTangent() + return sum(df) +end + +# Fallback + function unbroadcast(x, dx) + @info "last unbroadcast method!" x dx + dx isa AbstractZero && return dx p = ProjectTo(x) - if dx isa AbstractZero || p isa ProjectTo{<:AbstractZero} + if p isa ProjectTo{<:AbstractZero} return NoTangent() - end - b = Broadcast.broadcastable(x) - if b isa Ref # then x is scalar under broadcast + elseif Broadcast.broadcastable(x) isa Ref # then x is scalar under broadcast return p(sum(dx)) else error("don't know how to handle broadcast gradient for x::$(typeof(x))") From 888ddc1f85de023caa4e0cadef4766861f1d1666 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 5 Aug 2022 21:36:12 -0600 Subject: [PATCH 11/20] change generic rule to use BroadcastStyle --- src/rulesets/Base/broadcast.jl | 115 +++++++++++++++++--------------- test/rulesets/Base/broadcast.jl | 97 +++++++++++++++------------ 2 files changed, 117 insertions(+), 95 deletions(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 1c0df2daa..daf73e975 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -1,5 +1,6 @@ -using Base.Broadcast: Broadcast, broadcasted, Broadcasted +using Base.Broadcast: Broadcast, broadcasted, Broadcasted, BroadcastStyle const RCR = RuleConfig{>:HasReverseMode} +const TRI_NO = (NoTangent(), NoTangent(), NoTangent()) function rrule(::typeof(copy), bc::Broadcasted) uncopy(Δ) = (NoTangent(), Δ) @@ -22,12 +23,16 @@ _print(args...) = printstyled("CR: ", join(args, " "), "\n", color=:magenta) # n # and we don't know whether re-computing `y` is cheap. # (We could check `f` first like `sum(f, x)` does, but checking whether `g` needs `y` is tricky.) -function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Vararg{Any,N}) where {F,N} +# This rule has `::BroadcastStyle` in part becuase Zygote's generic rule does, to avoid ambiguities. +# It applies one step later in AD, and all args have `broadcastable(x)` thus many have `Ref(x)`, complicating some tests. +# But it also means that the lazy rules below do not need `::RuleConfig{>:HasReverseMode}` just for dispatch. + +function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Vararg{Any,N}) where {F,N} T = Broadcast.combine_eltypes(f, args) if T === Bool # TODO use nondifftype here # 1: Trivial case: non-differentiable output, e.g. `x .> 0` _print("split_bc_trivial", f) - bc_trivial_back(_) = (NoTangent(), NoTangent(), ntuple(Returns(ZeroTangent()), length(args))...) + bc_trivial_back(_) = (TRI_NO..., ntuple(Returns(ZeroTangent()), length(args))...) return f.(args...), bc_trivial_back elseif T <: Number && may_bc_derivatives(T, f, args...) # 2: Fast path: use arguments & result to find derivatives. @@ -59,9 +64,9 @@ function split_bc_derivatives(f::F, arg) where {F} das = only(derivatives_given_output(y, f, a)) dy * conj(only(das)) # possibly this * should be made nan-safe. end - return (NoTangent(), NoTangent(), ProjectTo(arg)(delta)) + return (TRI_NO..., ProjectTo(arg)(delta)) end - bc_one_back(z::AbstractZero) = (NoTangent(), NoTangent(), z) + bc_one_back(z::AbstractZero) = (TRI_NO..., z) return ys, bc_one_back end function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N} @@ -73,9 +78,9 @@ function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N} map(da -> dy * conj(da), das) # possibly this * should be made nan-safe. end dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of tuplecast? - return (NoTangent(), NoTangent(), dargs...) + return (TRI_NO..., dargs...) end - bc_many_back(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...) + bc_many_back(z::AbstractZero) = (TRI_NO..., map(Returns(z), args)...) return ys, bc_many_back end @@ -108,9 +113,9 @@ function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F} delta = broadcast(ydots, unthunk(dys), arg) do ydot, dy, a ProjectTo(a)(conj(ydot) * dy) # possibly this * should be made nan-safe. end - return (NoTangent(), NoTangent(), ProjectTo(arg)(delta)) + return (TRI_NO..., ProjectTo(arg)(delta)) end - back_forwards(z::AbstractZero) = (NoTangent(), NoTangent(), z) + back_forwards(z::AbstractZero) = (TRI_NO..., z) return ys, back_forwards end @@ -129,17 +134,17 @@ function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} end dargs = map(unbroadcast, args, Base.tail(deltas)) df = ProjectTo(f)(sum(first(deltas))) - return (NoTangent(), df, dargs...) + return (NoTangent(), NoTangent(), df, dargs...) end - back_generic(z::AbstractZero) = (NoTangent(), NoTangent(), map(Returns(z), args)...) + back_generic(z::AbstractZero) = (TRI_NO..., map(Returns(z), args)...) return ys3, back_generic end # Don't run broadcasting on scalars -function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Number...) where {F} +function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Number...) where {F} _print("split_bc_scalar", f) z, back = rrule_via_ad(cfg, f, args...) - return z, dz -> (NoTangent(), back(dz)...) + return z, dz -> (NoTangent(), NoTangent(), back(dz)...) end ##### @@ -147,14 +152,13 @@ end ##### # For certain cheap operations we can easily allow fused broadcast; the forward pass may be run twice. -# These all have `RuleConfig{>:HasReverseMode}` only for dispatch, to beat the split rule above. # Accept `x::Broadcasted` because they produce it; can't dispatch on eltype but `x` is assumed to contain `Number`s. const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted} ##### Arithmetic: +, -, *, ^2, / -function rrule(::RCR, ::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...) +function rrule(::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...) _print("plus", length(xs)) function bc_plus_back(dy_raw) dy = unthunk(dy_raw) @@ -163,7 +167,7 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast return broadcasted(+, xs...), bc_plus_back end -function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast) +function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast) _print("minus 2") function bc_minus_back(dz_raw) dz = unthunk(dz_raw) @@ -172,13 +176,13 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, return broadcasted(-, x, y), bc_minus_back end -function rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast) +function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast) _print("minus 1") bc_minus_back(dy) = (NoTangent(), NoTangent(), @thunk -unthunk(dy)) return broadcasted(-, x), bc_minus_back end -function rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast) +function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast) _print("times") function bc_times_back(Δraw) Δ = unthunk(Δraw) @@ -191,12 +195,11 @@ _back_star(x::Number, y, Δ) = @thunk LinearAlgebra.dot(y, Δ) # ... but this i _back_star(x::Bool, y, Δ) = NoTangent() _back_star(x::Complex{Bool}, y, Δ) = NoTangent() # e.g. for fun.(im.*x) -#= -# This works, but not sure it improves any benchmarks. -function rrule(cfg::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast, zs::NumericOrBroadcast...) +# This works, but not sure it improves any benchmarks. Needs corresponding scalar rule to avoid ambiguities. +function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast, zs::NumericOrBroadcast...) _print("times", 2 + length(zs)) - xy, back1 = rrule(cfg, broadcasted, *, x, y) - xyz, back2 = rrule(cfg, broadcasted, *, xy, zs...) + xy, back1 = rrule(broadcasted, *, x, y) + xyz, back2 = rrule(broadcasted, *, xy, zs...) function bc_times3_back(dxyz) _, _, dxy, dzs... = back2(dxyz) _, _, dx, dy = back1(dxy) @@ -204,9 +207,8 @@ function rrule(cfg::RCR, ::typeof(broadcasted), ::typeof(*), x::NumericOrBroadca end xyz, bc_times3_back end -=# -function rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2}) +function rrule(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2}) _print("square") function bc_square_back(dy_raw) dx = @thunk ProjectTo(x)(2 .* unthunk(dy_raw) .* conj.(x)) @@ -215,7 +217,7 @@ function rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeo return broadcasted(Base.literal_pow, ^, x, Val(2)), bc_square_back end -function rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number) +function rrule(::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number) _print("divide") # z = broadcast(/, x, y) z = broadcasted(/, x, y) @@ -237,75 +239,76 @@ function _prepend_zero((y, back)) return y, extra_back end -rrule(::RCR, ::typeof(broadcasted), ::typeof(+), args::Number...) = rrule(+, args...) |> _prepend_zero -rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = rrule(-, x, y) |> _prepend_zero -rrule(::RCR, ::typeof(broadcasted), ::typeof(-), x::Number) = rrule(-, x) |> _prepend_zero -rrule(::RCR, ::typeof(broadcasted), ::typeof(*), args::Number...) = rrule(*, args...) |> _prepend_zero -rrule(::RCR, ::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = rrule(*, x, y) |> _prepend_zero # ambiguity -rrule(::RCR, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2}) = +rrule(::typeof(broadcasted), ::typeof(+), args::Number...) = rrule(+, args...) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(-), x::Number, y::Number) = rrule(-, x, y) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(-), x::Number) = rrule(-, x) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(*), args::Number...) = rrule(*, args...) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(*), x::Number, y::Number) = rrule(*, x, y) |> _prepend_zero # ambiguity +rrule(::typeof(broadcasted), ::typeof(*), x::Number, y::Number, zs::Number...) = rrule(*, x, y, zs...) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{2}) = rrule(Base.literal_pow, ^, x, Val(2)) |> _prepend_zero -rrule(::RCR, ::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = rrule(/, x, y) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(/), x::Number, y::Number) = rrule(/, x, y) |> _prepend_zero ##### Identity, number types -rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast) = rrule(identity, x) |> _prepend_zero -rrule(::RCR, ::typeof(broadcasted), ::typeof(identity), x::Number) = rrule(identity, x) |> _prepend_zero # ambiguity +rrule(::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast) = rrule(identity, x) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(identity), x::Number) = rrule(identity, x) |> _prepend_zero # ambiguity -function rrule(::RCR, ::typeof(broadcasted), ::Type{T}, x::NumericOrBroadcast) where {T<:Number} +function rrule(::typeof(broadcasted), ::Type{T}, x::NumericOrBroadcast) where {T<:Number} _print("bc type", T) bc_type_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz)))) return broadcasted(T, x), bc_type_back end -rrule(::RCR, ::typeof(broadcasted), ::Type{T}, x::Number) where {T<:Number} = rrule(T, x) |> _prepend_zero +rrule(::typeof(broadcasted), ::Type{T}, x::Number) where {T<:Number} = rrule(T, x) |> _prepend_zero -function rrule(::RCR, ::typeof(broadcasted), ::typeof(float), x::NumericOrBroadcast) +function rrule(::typeof(broadcasted), ::typeof(float), x::NumericOrBroadcast) _print("bc float") bc_float_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz)))) return broadcasted(float, x), bc_float_back end -rrule(::RCR, ::typeof(broadcasted), ::typeof(float), x::Number) = rrule(float, x) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(float), x::Number) = rrule(float, x) |> _prepend_zero ##### Complex: conj, real, imag for conj in [:conj, :adjoint] # identical as we know eltype <: Number @eval begin - function rrule(::RCR, ::typeof(broadcasted), ::typeof($conj), x::NumericOrBroadcast) + function rrule(::typeof(broadcasted), ::typeof($conj), x::NumericOrBroadcast) bc_conj_back(dx) = (NoTangent(), NoTangent(), conj(unthunk(dx))) return broadcasted($conj, x), bc_conj_back end - rrule(::RCR, ::typeof(broadcasted), ::typeof($conj), x::Number) = rrule($conj, x) |> _prepend_zero - rrule(::RCR, ::typeof(broadcasted), ::typeof($conj), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero + rrule(::typeof(broadcasted), ::typeof($conj), x::Number) = rrule($conj, x) |> _prepend_zero + rrule(::typeof(broadcasted), ::typeof($conj), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero # This `AbstractArray{<:Real}` rule won't catch `conj.(x.+1)` with lazy `.+` rule. # Could upgrade to infer eltype of the `Broadcasted`? end end -function rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::NumericOrBroadcast) +function rrule(::typeof(broadcasted), ::typeof(real), x::NumericOrBroadcast) _print("real") bc_real_back(dz) = (NoTangent(), NoTangent(), @thunk(real(unthunk(dz)))) return broadcasted(real, x), bc_real_back end -rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::Number) = rrule(real, x) |> _prepend_zero -rrule(::RCR, ::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(real), x::Number) = rrule(real, x) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero -function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::NumericOrBroadcast) +function rrule(::typeof(broadcasted), ::typeof(imag), x::NumericOrBroadcast) _print("imag") bc_imag_back(dz) = (NoTangent(), NoTangent(), @thunk(im .* real.(unthunk(dz)))) return broadcasted(imag, x), bc_imag_back end -rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::Number) = rrule(imag, x) |> _prepend_zero -function rrule(::RCR, ::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:Real}) +rrule(::typeof(broadcasted), ::typeof(imag), x::Number) = rrule(imag, x) |> _prepend_zero +function rrule(::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:Real}) _print("imag(real)") bc_imag_back_2(dz) = (NoTangent(), NoTangent(), ZeroTangent()) return broadcasted(imag, x), bc_imag_back_2 end -function rrule(::RCR, ::typeof(broadcasted), ::typeof(complex), x::NumericOrBroadcast) +function rrule(::typeof(broadcasted), ::typeof(complex), x::NumericOrBroadcast) _print("bc complex") bc_complex_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz)))) return broadcasted(complex, x), bc_complex_back end -rrule(::RCR, ::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) |> _prepend_zero +rrule(::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) |> _prepend_zero ##### ##### Shape fixing @@ -389,8 +392,16 @@ end ##### For testing ##### -function rrule(cfg::RCR, ::typeof(copy∘broadcasted), f, args...) - y, back = rrule(cfg, broadcasted, f, args...) +function rrule(cfg::RCR, ::typeof(copy∘broadcasted), f_args...) + tmp = rrule(cfg, broadcasted, f_args...) + isnothing(tmp) && throw("rrule gave nothing") + y, back = tmp + return _maybe_copy(y), back +end +function rrule(::typeof(copy∘broadcasted), f_args...) + tmp = rrule(broadcasted, f_args...) + isnothing(tmp) && throw("rrule gave nothing") + y, back = tmp return _maybe_copy(y), back end diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index b2308819d..4d865fa28 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -1,67 +1,78 @@ using Base.Broadcast: broadcasted +BS0 = Broadcast.BroadcastStyle(Float64) +BS1 = Broadcast.BroadcastStyle(Vector) +BS2 = Broadcast.BroadcastStyle(Matrix) + +BT1 = Broadcast.BroadcastStyle(Tuple) + @testset "Broadcasting" begin @testset "split 1: trivial path" begin # test_rrule(copy∘broadcasted, >, rand(3), rand(3)) # MethodError: no method matching eps(::UInt64) inside FiniteDifferences - y1, bk1 = rrule(CFG, copy∘broadcasted, >, rand(3), rand(3)) + y1, bk1 = rrule(CFG, copy∘broadcasted, BS1, >, rand(3), rand(3)) @test y1 isa AbstractArray{Bool} @test all(d -> d isa AbstractZero, bk1(99)) - y2, bk2 = rrule(CFG, copy∘broadcasted, isinteger, Tuple(rand(3))) + y2, bk2 = rrule(CFG, copy∘broadcasted, BT1, isinteger, Tuple(rand(3))) @test y2 isa Tuple{Bool,Bool,Bool} @test all(d -> d isa AbstractZero, bk2(99)) end @testset "split 2: derivatives" begin - test_rrule(copy∘broadcasted, log, rand(3) .+ 1) - test_rrule(copy∘broadcasted, log, Tuple(rand(3) .+ 1)) + test_rrule(copy∘broadcasted, BS1, log, rand(3) .+ 1) + test_rrule(copy∘broadcasted, BT1, log, Tuple(rand(3) .+ 1)) # Two args uses StructArrays - test_rrule(copy∘broadcasted, atan, rand(3), rand(3)) - test_rrule(copy∘broadcasted, atan, rand(3), rand(4)') - test_rrule(copy∘broadcasted, atan, rand(3), rand()) - test_rrule(copy∘broadcasted, atan, rand(3), Tuple(rand(1))) - test_rrule(copy∘broadcasted, atan, Tuple(rand(3)), Tuple(rand(3)), check_inferred = VERSION > v"1.7") + test_rrule(copy∘broadcasted, BS1, atan, rand(3), rand(3)) + test_rrule(copy∘broadcasted, BS2, atan, rand(3), rand(4)') + test_rrule(copy∘broadcasted, BS1, atan, rand(3), rand()) + test_rrule(copy∘broadcasted, BT1, atan, rand(3), Tuple(rand(1))) + test_rrule(copy∘broadcasted, BT1, atan, Tuple(rand(3)), Tuple(rand(3)), check_inferred = VERSION > v"1.7") - test_rrule(copy∘broadcasted, *, rand(3), Ref(rand())) + # test_rrule(copy∘broadcasted, *, BS1, rand(3), Ref(rand())) # don't know what I was testing end @testset "split 3: forwards" begin - test_rrule(copy∘broadcasted, flog, rand(3)) - test_rrule(copy∘broadcasted, flog, rand(3) .+ im) + # In test_helpers.jl, `flog` and `fstar` have only `frule`s defined, nothing else. + test_rrule(copy∘broadcasted, BS1, flog, rand(3)) + test_rrule(copy∘broadcasted, BS1, flog, rand(3) .+ im) # Also, `sin∘cos` may use this path as CFG uses frule_via_ad + # TODO use different CFGs, https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/255 end @testset "split 4: generic" begin - test_rrule(copy∘broadcasted, sin∘cos, rand(3), check_inferred=false) - test_rrule(copy∘broadcasted, sin∘atan, rand(3), rand(3)', check_inferred=false) - test_rrule(copy∘broadcasted, sin∘atan, rand(), rand(3), check_inferred=false) - test_rrule(copy∘broadcasted, ^, rand(3), 3.0, check_inferred=false) # NoTangent vs. Union{NoTangent, ZeroTangent} + test_rrule(copy∘broadcasted, BS1, sin∘cos, rand(3), check_inferred=false) + test_rrule(copy∘broadcasted, BS2, sin∘atan, rand(3), rand(3)', check_inferred=false) + test_rrule(copy∘broadcasted, BS1, sin∘atan, rand(), rand(3), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, ^, rand(3), 3.0, check_inferred=false) # NoTangent vs. Union{NoTangent, ZeroTangent} # Many have quite small inference failures, like: - # return type Tuple{NoTangent, NoTangent, Vector{Float64}, Float64} does not match inferred return type Tuple{NoTangent, Union{NoTangent, ZeroTangent}, Vector{Float64}, Float64} + # return type Tuple{NoTangent, NoTangent, Vector{Float64}, Float64} does not match inferred + # return type Tuple{NoTangent, Union{NoTangent, ZeroTangent}, Vector{Float64}, Float64} # From test_helpers.jl - test_rrule(copy∘broadcasted, Multiplier(rand()), rand(3), check_inferred=false) - test_rrule(copy∘broadcasted, Multiplier(rand()), rand(3), rand(4)', check_inferred=false) - @test_skip test_rrule(copy∘broadcasted, Multiplier(rand()), rand(3), 5.0im, check_inferred=false) # ProjectTo(f) fails to correct this - test_rrule(copy∘broadcasted, make_two_vec, rand(3), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, Multiplier(rand()), rand(3), check_inferred=false) + test_rrule(copy∘broadcasted, BS2, Multiplier(rand()), rand(3), rand(4)', check_inferred=false) # Union{ZeroTangent, Tangent{Multiplier{... + @test_skip test_rrule(copy∘broadcasted, BS1, Multiplier(rand()), rand(3), 5.0im, check_inferred=false) # ProjectTo(f) fails to remove the imaginary part of Multiplier's gradient + test_rrule(copy∘broadcasted, BS1, make_two_vec, rand(3), check_inferred=false) - # Non-diff components - test_rrule(copy∘broadcasted, first∘tuple, rand(3), :sym, rand(4)', check_inferred=false) - test_rrule(copy∘broadcasted, last∘tuple, rand(3), nothing, rand(4)', check_inferred=false) - test_rrule(copy∘broadcasted, |>, rand(3), sin, check_inferred=false) + # Non-diff components -- note that with BroadcastStyle, Ref is from e.g. Broadcast.broadcastable(nothing) + test_rrule(copy∘broadcasted, BS2, first∘tuple, rand(3), Ref(:sym), rand(4)', check_inferred=false) + test_rrule(copy∘broadcasted, BS2, last∘tuple, rand(3), Ref(nothing), rand(4)', check_inferred=false) + test_rrule(copy∘broadcasted, BS1, |>, rand(3), Ref(sin), check_inferred=false) _call(f, x...) = f(x...) - test_rrule(copy∘broadcasted, _call, atan, rand(3), rand(4)', check_inferred=false) + test_rrule(copy∘broadcasted, BS2, _call, Ref(atan), rand(3), rand(4)', check_inferred=false) - test_rrule(copy∘broadcasted, getindex, [rand(3) for _ in 1:2], [3,1], check_inferred=false) - # test_rrule(copy∘broadcasted, getindex, [rand(3) for _ in 1:2], (3,1), check_inferred=false) - # test_rrule(copy∘broadcasted, getindex, [rand(3) for _ in 1:2], Ref(CartesianIndex(2)), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, getindex, [rand(3) for _ in 1:2], [3,1], check_inferred=false) + test_rrule(copy∘broadcasted, BS1, getindex, [rand(3) for _ in 1:2], (3,1), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, getindex, [rand(3) for _ in 1:2], Ref(CartesianIndex(2)), check_inferred=false) + test_rrule(copy∘broadcasted, BT1, getindex, Tuple([rand(3) for _ in 1:2]), (3,1), check_inferred=false) + test_rrule(copy∘broadcasted, BT1, getindex, Tuple([Tuple(rand(3)) for _ in 1:2]), (3,1), check_inferred=false) # Protected by Ref/Tuple: - test_rrule(copy∘broadcasted, *, rand(3), Ref(rand(2)), check_inferred=false) - test_rrule(copy∘broadcasted, conj∘*, rand(3), Ref(rand() + im), check_inferred=false) - test_rrule(copy∘broadcasted, conj∘*, rand(3), Ref(rand(2) .+ im), check_inferred=false) - test_rrule(copy∘broadcasted, /, (rand(2),), rand(3), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, *, rand(3), Ref(rand(2)), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, conj∘*, rand(3), Ref(rand() + im), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, conj∘*, rand(3), Ref(rand(2) .+ im), check_inferred=false) + test_rrule(copy∘broadcasted, BS1, /, (rand(2),), rand(3), check_inferred=false) end @testset "fused rules" begin @@ -76,7 +87,7 @@ using Base.Broadcast: broadcasted test_rrule(copy∘broadcasted, -, rand(3), rand(3)) test_rrule(copy∘broadcasted, -, rand(3), rand(4)') test_rrule(copy∘broadcasted, -, rand(3)) - # test_rrule(copy∘broadcasted, -, Tuple(rand(3))) # MethodError: (::ChainRulesTestUtils.var"#test_approx##kw")(::NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}, ::typeof(test_approx), ::Thunk{ChainRules.var"#1614#1616"{Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}}}, ::Tangent{Tuple{Float64, Float64, Float64}, Tuple{Float64, Float64, Float64}}) is ambiguous. + test_rrule(copy∘broadcasted, -, Tuple(rand(3))) test_rrule(copy∘broadcasted, *, rand(3), rand(3)) test_rrule(copy∘broadcasted, *, rand(3), rand()) @@ -86,16 +97,16 @@ using Base.Broadcast: broadcasted test_rrule(copy∘broadcasted, *, rand(3) .+ im, rand() + 3im) test_rrule(copy∘broadcasted, *, rand() + im, rand(3) .+ 4im) - # test_rrule(copy∘broadcasted, *, im, rand(3)) # MethodError: no method matching randn(::Random._GLOBAL_RNG, ::Type{Complex{Bool}}) - # test_rrule(copy∘broadcasted, *, rand(3), im) + @test_skip test_rrule(copy∘broadcasted, *, im, rand(3)) # MethodError: no method matching randn(::Random._GLOBAL_RNG, ::Type{Complex{Bool}}) + @test_skip test_rrule(copy∘broadcasted, *, rand(3), im) # MethodError: no method matching randn(::Random._GLOBAL_RNG, ::Type{Complex{Bool}}) y4, bk4 = rrule(CFG, copy∘broadcasted, *, im, [1,2,3.0]) @test y4 == [im, 2im, 3im] @test unthunk(bk4([4, 5im, 6+7im])[4]) == [0,5,7] - test_rrule(copy∘broadcasted, *, rand(3), rand(3), rand(3), rand(3), rand(3), check_inferred=false) # Union{NoTangent, ZeroTangent} - test_rrule(copy∘broadcasted, *, rand(), rand(), rand(3), rand(3) .+ im, rand(4)', check_inferred=false) # Union{NoTangent, ZeroTangent} - # (These two may infer with vararg rrule) - + # These two test vararg rrule * rule: + test_rrule(copy∘broadcasted, *, rand(3), rand(3), rand(3), rand(3), rand(3)) + test_rrule(copy∘broadcasted, *, rand(), rand(), rand(3), rand(3) .+ im, rand(4)') + test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3), Val(2)) test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3) .+ im, Val(2)) @@ -128,9 +139,9 @@ using Base.Broadcast: broadcasted @testset "scalar rules" begin @testset "generic" begin - test_rrule(copy∘broadcasted, sin, rand()) - test_rrule(copy∘broadcasted, atan, rand(), rand()) - # test_rrule(copy∘broadcasted, >, rand(), rand()) # DimensionMismatch from FiniteDifferences + test_rrule(copy∘broadcasted, BS0, sin, rand()) + test_rrule(copy∘broadcasted, BS0, atan, rand(), rand()) + # test_rrule(copy∘broadcasted, BS0, >, rand(), rand()) # DimensionMismatch from FiniteDifferences end # Functions with lazy broadcasting rules: @testset "arithmetic" begin From d4e4e51a81540282db205da14cdfc3125fd87f58 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 5 Aug 2022 22:02:16 -0600 Subject: [PATCH 12/20] debug --- src/rulesets/Base/broadcast.jl | 39 +++++++++++++++++----------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index daf73e975..2c41b7849 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -13,7 +13,6 @@ function rrule(::typeof(Broadcast.instantiate), bc::Broadcasted) return Broadcast.instantiate(bc), uninstantiate end -_print(args...) = printstyled("CR: ", join(args, " "), "\n", color=:magenta) # nothing # ##### ##### Split broadcasting @@ -31,7 +30,7 @@ function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Va T = Broadcast.combine_eltypes(f, args) if T === Bool # TODO use nondifftype here # 1: Trivial case: non-differentiable output, e.g. `x .> 0` - _print("split_bc_trivial", f) + @debug("split broadcasting trivial", f, T) bc_trivial_back(_) = (TRI_NO..., ntuple(Returns(ZeroTangent()), length(args))...) return f.(args...), bc_trivial_back elseif T <: Number && may_bc_derivatives(T, f, args...) @@ -57,7 +56,7 @@ _eltype(x) = eltype(x) # ... but try harder to avoid `eltype(Broadcast.broadcas _eltype(bc::Broadcast.Broadcasted) = Broadcast.combine_eltypes(bc.f, bc.args) function split_bc_derivatives(f::F, arg) where {F} - _print("split_bc_derivative", f) + @debug("split broadcasting derivative", f) ys = f.(arg) function bc_one_back(dys) # For f.(x) we do not need StructArrays / unzip at all delta = broadcast(unthunk(dys), ys, arg) do dy, y, a @@ -70,7 +69,7 @@ function split_bc_derivatives(f::F, arg) where {F} return ys, bc_one_back end function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N} - _print("split_bc_derivatives", f, N) + @debug("split broadcasting derivatives", f, N) ys = f.(args...) function bc_many_back(dys) deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as... @@ -105,7 +104,7 @@ end split_bc_forwards(cfg::RuleConfig{>:HasForwardsMode}, f::F, arg) where {F} = split_bc_inner(frule_via_ad, cfg, f, arg) split_bc_forwards(cfg::RuleConfig, f::F, arg) where {F} = split_bc_inner(frule, cfg, f, arg) function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F} - _print("split_bc_forwards", frule_fun, f) + @debug("split broadcasting forwards", frule_fun, f) ys, ydots = tuplecast(arg) do a frule_fun(cfg, (NoTangent(), one(a)), f, a) end @@ -124,7 +123,7 @@ end # can change the number of calls, don't bother to try to reverse the iteration. function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} - _print("split_bc_generic", f, N) + @debug("split broadcasting generic", f, N) ys3, backs = tuplecast(args...) do a... rrule_via_ad(cfg, f, a...) end @@ -142,7 +141,7 @@ end # Don't run broadcasting on scalars function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Number...) where {F} - _print("split_bc_scalar", f) + @debug("split broadcasting scalar", f) z, back = rrule_via_ad(cfg, f, args...) return z, dz -> (NoTangent(), NoTangent(), back(dz)...) end @@ -159,7 +158,7 @@ const NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,N ##### Arithmetic: +, -, *, ^2, / function rrule(::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...) - _print("plus", length(xs)) + @debug("broadcasting: plus", length(xs)) function bc_plus_back(dy_raw) dy = unthunk(dy_raw) return (NoTangent(), NoTangent(), map(x -> unbroadcast(x, dy), xs)...) # no copies, this may return dx2 === dx3 @@ -168,7 +167,7 @@ function rrule(::typeof(broadcasted), ::typeof(+), xs::NumericOrBroadcast...) end function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::NumericOrBroadcast) - _print("minus 2") + @debug("broadcasting: minus 2") function bc_minus_back(dz_raw) dz = unthunk(dz_raw) return (NoTangent(), NoTangent(), @thunk(unbroadcast(x, dz)), @thunk(-unbroadcast(y, dz))) @@ -177,13 +176,13 @@ function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast, y::Num end function rrule(::typeof(broadcasted), ::typeof(-), x::NumericOrBroadcast) - _print("minus 1") + @debug("broadcasting: minus 1") bc_minus_back(dy) = (NoTangent(), NoTangent(), @thunk -unthunk(dy)) return broadcasted(-, x), bc_minus_back end function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast) - _print("times") + @debug("broadcasting: times") function bc_times_back(Δraw) Δ = unthunk(Δraw) return (NoTangent(), NoTangent(), _back_star(x, y, Δ), _back_star(y, x, Δ)) @@ -197,7 +196,7 @@ _back_star(x::Complex{Bool}, y, Δ) = NoTangent() # e.g. for fun.(im.*x) # This works, but not sure it improves any benchmarks. Needs corresponding scalar rule to avoid ambiguities. function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::NumericOrBroadcast, zs::NumericOrBroadcast...) - _print("times", 2 + length(zs)) + @debug("broadcasting: times", 2 + length(zs)) xy, back1 = rrule(broadcasted, *, x, y) xyz, back2 = rrule(broadcasted, *, xy, zs...) function bc_times3_back(dxyz) @@ -209,7 +208,7 @@ function rrule(::typeof(broadcasted), ::typeof(*), x::NumericOrBroadcast, y::Num end function rrule(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::NumericOrBroadcast, ::Val{2}) - _print("square") + @debug("broadcasting: square") function bc_square_back(dy_raw) dx = @thunk ProjectTo(x)(2 .* unthunk(dy_raw) .* conj.(x)) return (NoTangent(), NoTangent(), NoTangent(), dx, NoTangent()) @@ -218,7 +217,7 @@ function rrule(::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x end function rrule(::typeof(broadcasted), ::typeof(/), x::NumericOrBroadcast, y::Number) - _print("divide") + @debug("broadcasting: divide") # z = broadcast(/, x, y) z = broadcasted(/, x, y) function bc_divide_back(dz_raw) @@ -255,14 +254,14 @@ rrule(::typeof(broadcasted), ::typeof(identity), x::NumericOrBroadcast) = rrule( rrule(::typeof(broadcasted), ::typeof(identity), x::Number) = rrule(identity, x) |> _prepend_zero # ambiguity function rrule(::typeof(broadcasted), ::Type{T}, x::NumericOrBroadcast) where {T<:Number} - _print("bc type", T) + @debug("broadcasting: type", T) bc_type_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz)))) return broadcasted(T, x), bc_type_back end rrule(::typeof(broadcasted), ::Type{T}, x::Number) where {T<:Number} = rrule(T, x) |> _prepend_zero function rrule(::typeof(broadcasted), ::typeof(float), x::NumericOrBroadcast) - _print("bc float") + @debug("broadcasting: float") bc_float_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz)))) return broadcasted(float, x), bc_float_back end @@ -284,7 +283,7 @@ for conj in [:conj, :adjoint] # identical as we know eltype <: Number end function rrule(::typeof(broadcasted), ::typeof(real), x::NumericOrBroadcast) - _print("real") + @debug("broadcasting: real") bc_real_back(dz) = (NoTangent(), NoTangent(), @thunk(real(unthunk(dz)))) return broadcasted(real, x), bc_real_back end @@ -292,19 +291,19 @@ rrule(::typeof(broadcasted), ::typeof(real), x::Number) = rrule(real, x) |> _pre rrule(::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real}) = rrule(identity, x) |> _prepend_zero function rrule(::typeof(broadcasted), ::typeof(imag), x::NumericOrBroadcast) - _print("imag") + @debug("broadcasting: imag") bc_imag_back(dz) = (NoTangent(), NoTangent(), @thunk(im .* real.(unthunk(dz)))) return broadcasted(imag, x), bc_imag_back end rrule(::typeof(broadcasted), ::typeof(imag), x::Number) = rrule(imag, x) |> _prepend_zero function rrule(::typeof(broadcasted), ::typeof(imag), x::AbstractArray{<:Real}) - _print("imag(real)") + @debug("broadcasting: imag(real)") bc_imag_back_2(dz) = (NoTangent(), NoTangent(), ZeroTangent()) return broadcasted(imag, x), bc_imag_back_2 end function rrule(::typeof(broadcasted), ::typeof(complex), x::NumericOrBroadcast) - _print("bc complex") + @debug("broadcasting: complex") bc_complex_back(dz) = (NoTangent(), NoTangent(), @thunk(unbroadcast(x, unthunk(dz)))) return broadcasted(complex, x), bc_complex_back end From fca9a75b2be29efadded8e3c78080b410012eef5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 5 Aug 2022 22:03:00 -0600 Subject: [PATCH 13/20] rename with unzip --- src/ChainRules.jl | 2 +- src/rulesets/Base/broadcast.jl | 10 ++++---- src/{tuplecast.jl => unzipped.jl} | 38 +++++++++++++++++------------- test/runtests.jl | 2 +- test/{tuplecast.jl => unzipped.jl} | 16 ++++++------- 5 files changed, 36 insertions(+), 32 deletions(-) rename src/{tuplecast.jl => unzipped.jl} (84%) rename test/{tuplecast.jl => unzipped.jl} (82%) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 2ab8d4baa..2fb8d32f9 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -24,7 +24,7 @@ using ChainRulesCore: derivatives_given_output const CommutativeMulNumber = Union{Real,Complex} # StructArrays -include("tuplecast.jl") +include("unzipped.jl") include("rulesets/Core/core.jl") diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 2c41b7849..ee12767c6 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -72,11 +72,11 @@ function split_bc_derivatives(f::F, args::Vararg{Any,N}) where {F,N} @debug("split broadcasting derivatives", f, N) ys = f.(args...) function bc_many_back(dys) - deltas = tuplecast(unthunk(dys), ys, args...) do dy, y, as... + deltas = unzip_broadcast(unthunk(dys), ys, args...) do dy, y, as... das = only(derivatives_given_output(y, f, as...)) map(da -> dy * conj(da), das) # possibly this * should be made nan-safe. end - dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of tuplecast? + dargs = map(unbroadcast, args, deltas) # ideally sum in unbroadcast could be part of unzip_broadcast? return (TRI_NO..., dargs...) end bc_many_back(z::AbstractZero) = (TRI_NO..., map(Returns(z), args)...) @@ -105,7 +105,7 @@ split_bc_forwards(cfg::RuleConfig{>:HasForwardsMode}, f::F, arg) where {F} = spl split_bc_forwards(cfg::RuleConfig, f::F, arg) where {F} = split_bc_inner(frule, cfg, f, arg) function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F} @debug("split broadcasting forwards", frule_fun, f) - ys, ydots = tuplecast(arg) do a + ys, ydots = unzip_broadcast(arg) do a frule_fun(cfg, (NoTangent(), one(a)), f, a) end function back_forwards(dys) @@ -124,11 +124,11 @@ end function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} @debug("split broadcasting generic", f, N) - ys3, backs = tuplecast(args...) do a... + ys3, backs = unzip_broadcast(args...) do a... rrule_via_ad(cfg, f, a...) end function back_generic(dys) - deltas = tuplecast(backs, unthunk(dys)) do back, dy # (could be map, sizes match) + deltas = unzip_broadcast(backs, unthunk(dys)) do back, dy # (could be map, sizes match) map(unthunk, back(dy)) end dargs = map(unbroadcast, args, Base.tail(deltas)) diff --git a/src/tuplecast.jl b/src/unzipped.jl similarity index 84% rename from src/tuplecast.jl rename to src/unzipped.jl index 9108b669b..191c58696 100644 --- a/src/tuplecast.jl +++ b/src/unzipped.jl @@ -1,18 +1,18 @@ """ - tuplecast(f, args...) + unzip_broadcast(f, args...) For a function `f` which returns a tuple, this is `== unzip(broadcast(f, args...))`, but performed using `StructArrays` for efficiency. Used in the gradient of broadcasting. # Examples ``` -julia> using ChainRules: tuplecast, unzip +julia> using ChainRules: unzip_broadcast, unzip -julia> tuplecast(x -> (x,2x), 1:3) +julia> unzip_broadcast(x -> (x,2x), 1:3) ([1, 2, 3], [2, 4, 6]) -julia> mats = @btime tuplecast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000)); # 2 arrays, each 7.63 MiB +julia> mats = @btime unzip_broadcast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000)); # 2 arrays, each 7.63 MiB min 1.776 ms, mean 20.421 ms (4 allocations, 15.26 MiB) julia> mats == @btime unzip(broadcast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000))) # intermediate matrix of tuples @@ -20,10 +20,10 @@ julia> mats == @btime unzip(broadcast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1 true ``` """ -function tuplecast(f::F, args...) where {F} +function unzip_broadcast(f::F, args...) where {F} T = Broadcast.combine_eltypes(f, args) if isconcretetype(T) - T <: Tuple || throw(ArgumentError("""tuplecast(f, args) only works on functions returning a tuple, + T <: Tuple || throw(ArgumentError("""unzip_broadcast(f, args) only works on functions returning a tuple, but f = $(sprint(show, f)) returns type T = $T""")) end # TODO allow GPU arrays, possibly just as a fallback unzip, but see also: @@ -39,7 +39,7 @@ function tuplecast(f::F, args...) where {F} end end -function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplecast), f::F, args...) where {F} +function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_broadcast), f::F, args...) where {F} y, back = rrule_via_ad(cfg, broadcast, f, args...) z = unzip(y) function untuplecast(dz) @@ -53,23 +53,25 @@ function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplec end # This is for testing, but the tests using it don't work. -function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect∘tuplecast), f, args...) - y, back = rrule(cfg, tuplecast, f, args...) +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect∘unzip_broadcast), f, args...) + y, back = rrule(cfg, unzip_broadcast, f, args...) return collect(y), back end +#= + """ - tuplemap(f, args...) + unzip_map(f, args...) For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`, but performed using `StructArrays` for efficiency. -Not in use at present, but see `tuplecast`. +Not in use at present, but see `unzip_broadcast`. """ -function tuplemap(f::F, args...) where {F} +function unzip_map(f::F, args...) where {F} T = Broadcast.combine_eltypes(f, args) if isconcretetype(T) - T <: Tuple || throw(ArgumentError("""tuplemap(f, args) only works on functions returning a tuple, + T <: Tuple || throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple, but f = $(sprint(show, f)) returns type T = $T""")) end # if any(a -> a isa CuArray, args) @@ -78,17 +80,19 @@ function tuplemap(f::F, args...) where {F} return StructArrays.components(StructArray(Iterators.map(f, args...))) end -function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(tuplemap), f::F, xs...) where {F} +function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_map), f::F, xs...) where {F} y, back = rrule_via_ad(cfg, map, f, xs...) z = unzip(y) - function untuplemap(dz) + function ununzip_map(dz) # dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent())) dy = broadcast(tuple, map(unthunk, dz)...) return back(dy) end - return z, untuplemap + return z, ununzip_map end +=# + """ unzip(A) @@ -114,7 +118,7 @@ function unzip(xs::AbstractArray) x1 = first(xs) x1 isa Tuple || throw(ArgumentError("unzip only accepts arrays of tuples")) N = length(x1) - return unzip(xs, Val(N)) # like Zygote's unzip, here this is the fallback case. + return unzip(xs, Val(N)) # like Zygote's unzip. Here this is the fallback case. end @generated function unzip(xs, ::Val{N}) where {N} diff --git a/test/runtests.jl b/test/runtests.jl index 4f4845e7b..9ac5c5981 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -59,7 +59,7 @@ end include_test("rulesets/Base/sort.jl") include_test("rulesets/Base/broadcast.jl") - include_test("tuplecast.jl") # used primarily for broadcast + include_test("unzipped.jl") # used primarily for broadcast println() diff --git a/test/tuplecast.jl b/test/unzipped.jl similarity index 82% rename from test/tuplecast.jl rename to test/unzipped.jl index 458a51fa6..0d616b3f2 100644 --- a/test/tuplecast.jl +++ b/test/unzipped.jl @@ -1,8 +1,8 @@ -using ChainRules: tuplecast, unzip, tuplemap +using ChainRules: unzip_broadcast, unzip #, unzip_map -@testset "tuplecast.jl" begin - @testset "basics: $(sprint(show, fun))" for fun in [tuplemap, tuplecast, unzip∘map, unzip∘broadcast] +@testset "unzip_broadcast.jl" begin + @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast] # unzip_map, @test_throws Exception fun(sqrt, 1:3) @test fun(tuple, 1:3, 4:6) == ([1, 2, 3], [4, 5, 6]) @@ -17,7 +17,7 @@ using ChainRules: tuplecast, unzip, tuplemap @test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) end - if fun == tuplemap + if fun == unzip_map @test_broken fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) elseif fun == unzip∘map @test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) @@ -32,19 +32,19 @@ using ChainRules: tuplecast, unzip, tuplemap @testset "rrules" begin # These exist to allow for second derivatives - # test_rrule(collect∘tuplecast, tuple, [1,2,3.], [4,5,6.], collectheck_inferred=false) # return type Tuple{NoTangent, NoTangent, Vector{Float64}, Vector{Float64}} does not match inferred return type NTuple{4, Any} + # test_rrule(collect∘unzip_broadcast, tuple, [1,2,3.], [4,5,6.], collectheck_inferred=false) # return type Tuple{NoTangent, NoTangent, Vector{Float64}, Vector{Float64}} does not match inferred return type NTuple{4, Any} - y1, bk1 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4,5,6.0]) + y1, bk1 = rrule(CFG, unzip_broadcast, tuple, [1,2,3.0], [4,5,6.0]) @test y1 == ([1, 2, 3], [4, 5, 6]) @test bk1(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100] # bk1(([1,10,100.0], NoTangent())) # DimensionMismatch in FiniteDifferences - y2, bk2 = rrule(CFG, tuplecast, tuple, [1,2,3.0], [4 5.0], 6.0) + y2, bk2 = rrule(CFG, unzip_broadcast, tuple, [1,2,3.0], [4 5.0], 6.0) @test y2 == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5], [6 6; 6 6; 6 6]) @test bk2(y2)[5] ≈ 36 - y4, bk4 = rrule(CFG, tuplemap, tuple, [1,2,3.0], [4,5,6.0]) + y4, bk4 = rrule(CFG, unzip_map, tuple, [1,2,3.0], [4,5,6.0]) @test y4 == ([1, 2, 3], [4, 5, 6]) @test bk4(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100] end From 91b54057d01b5867b04401f2aa6bbece392bc9f5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 6 Aug 2022 05:32:31 -0600 Subject: [PATCH 14/20] fix for 1.6 --- test/rulesets/Base/broadcast.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 4d865fa28..640a4badb 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -1,7 +1,10 @@ using Base.Broadcast: broadcasted +if VERSION < v"1.7" + Base.ndims(::Type{<:AbstractArray{<:Any,N}}) where {N} = N +end BS0 = Broadcast.BroadcastStyle(Float64) -BS1 = Broadcast.BroadcastStyle(Vector) +BS1 = Broadcast.BroadcastStyle(Vector) # without ndims method, error on 1.6 BS2 = Broadcast.BroadcastStyle(Matrix) BT1 = Broadcast.BroadcastStyle(Tuple) From 9d5dad7a1ab58948967d286ff737ef092b863dbf Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 6 Aug 2022 05:35:47 -0600 Subject: [PATCH 15/20] test bugs --- src/rulesets/Base/broadcast.jl | 4 ++-- test/unzipped.jl | 14 ++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index ee12767c6..025e9e936 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -393,13 +393,13 @@ end function rrule(cfg::RCR, ::typeof(copy∘broadcasted), f_args...) tmp = rrule(cfg, broadcasted, f_args...) - isnothing(tmp) && throw("rrule gave nothing") + isnothing(tmp) && return nothing y, back = tmp return _maybe_copy(y), back end function rrule(::typeof(copy∘broadcasted), f_args...) tmp = rrule(broadcasted, f_args...) - isnothing(tmp) && throw("rrule gave nothing") + isnothing(tmp) && return nothing y, back = tmp return _maybe_copy(y), back end diff --git a/test/unzipped.jl b/test/unzipped.jl index 0d616b3f2..ae1ea7a14 100644 --- a/test/unzipped.jl +++ b/test/unzipped.jl @@ -1,7 +1,7 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map -@testset "unzip_broadcast.jl" begin +@testset "unzipped.jl" begin @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast] # unzip_map, @test_throws Exception fun(sqrt, 1:3) @@ -16,10 +16,8 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map else @test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) end - - if fun == unzip_map - @test_broken fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) - elseif fun == unzip∘map + + if fun == unzip∘map @test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) else @test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) @@ -44,9 +42,9 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map @test y2 == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5], [6 6; 6 6; 6 6]) @test bk2(y2)[5] ≈ 36 - y4, bk4 = rrule(CFG, unzip_map, tuple, [1,2,3.0], [4,5,6.0]) - @test y4 == ([1, 2, 3], [4, 5, 6]) - @test bk4(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100] + # y4, bk4 = rrule(CFG, unzip_map, tuple, [1,2,3.0], [4,5,6.0]) + # @test y4 == ([1, 2, 3], [4, 5, 6]) + # @test bk4(([1,10,100.0], [7,8,9.0]))[3] ≈ [1,10,100] end @testset "unzip" begin From 45102c509b3d6ca828d110b74891dc0f1eab420d Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 8 Aug 2022 13:46:56 -0700 Subject: [PATCH 16/20] version --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e79287e25..b33544714 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.42.0" +version = "1.43.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -26,6 +26,7 @@ JLArrays = "0.1" JuliaInterpreter = "0.8,0.9" RealDot = "0.1" StaticArrays = "1.2" +StructArrays = "0.6.11" julia = "1.6" [extras] From ccbe5613f5bb96ee8cd92325e8685cdd29076e26 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 8 Aug 2022 14:59:12 -0700 Subject: [PATCH 17/20] tidy unzipped --- src/ChainRules.jl | 1 + src/unzipped.jl | 55 ++++++++++++++++------------------------------- test/unzipped.jl | 4 ++-- 3 files changed, 21 insertions(+), 39 deletions(-) diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 2fb8d32f9..b314d7be7 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -4,6 +4,7 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad using ChainRulesCore using Compat using Distributed +using GPUArraysCore: AbstractGPUArrayStyle using IrrationalConstants: logtwo, logten using LinearAlgebra using LinearAlgebra.BLAS diff --git a/src/unzipped.jl b/src/unzipped.jl index 191c58696..fe5875e6f 100644 --- a/src/unzipped.jl +++ b/src/unzipped.jl @@ -1,3 +1,6 @@ +##### +##### broadcast +##### """ unzip_broadcast(f, args...) @@ -26,17 +29,18 @@ function unzip_broadcast(f::F, args...) where {F} T <: Tuple || throw(ArgumentError("""unzip_broadcast(f, args) only works on functions returning a tuple, but f = $(sprint(show, f)) returns type T = $T""")) end - # TODO allow GPU arrays, possibly just as a fallback unzip, but see also: - # https://github.com/JuliaArrays/StructArrays.jl/issues/150 - # if any(a -> a isa CuArray, args) - # return unzip(broadcast(f, args...)) - # end bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...)) - if Broadcast.BroadcastStyle(typeof(bc)) isa Broadcast.AbstractArrayStyle + bcs = Broadcast.BroadcastStyle(typeof(bc)) + if bcs isa AbstractGPUArrayStyle + # This is a crude way to allow GPU arrays, not currently tested, TODO. + # See also https://github.com/JuliaArrays/StructArrays.jl/issues/150 + return unzip(broadcast(f, args...)) + elseif bcs isa Broadcast.AbstractArrayStyle return StructArrays.components(StructArray(bc)) else return unzip(broadcast(f, args...)) # e.g. tuples end + # TODO maybe this if-else can be replaced by methods of `unzip(:::Broadcast.Broadcasted)`? end function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_broadcast), f::F, args...) where {F} @@ -58,40 +62,17 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect∘unzip_broad return collect(y), back end -#= +##### +##### map +##### -""" - unzip_map(f, args...) - -For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`, -but performed using `StructArrays` for efficiency. - -Not in use at present, but see `unzip_broadcast`. -""" -function unzip_map(f::F, args...) where {F} - T = Broadcast.combine_eltypes(f, args) - if isconcretetype(T) - T <: Tuple || throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple, - but f = $(sprint(show, f)) returns type T = $T""")) - end - # if any(a -> a isa CuArray, args) - # return unzip(map(f, args...)) - # end - return StructArrays.components(StructArray(Iterators.map(f, args...))) -end +# `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`, +# will be useful for the gradient of `map` etc. -function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_map), f::F, xs...) where {F} - y, back = rrule_via_ad(cfg, map, f, xs...) - z = unzip(y) - function ununzip_map(dz) - # dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent())) - dy = broadcast(tuple, map(unthunk, dz)...) - return back(dy) - end - return z, ununzip_map -end -=# +##### +##### unzip +##### """ unzip(A) diff --git a/test/unzipped.jl b/test/unzipped.jl index ae1ea7a14..9138cf845 100644 --- a/test/unzipped.jl +++ b/test/unzipped.jl @@ -17,7 +17,7 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map @test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) end - if fun == unzip∘map + if contains(string(fun), "map") @test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) else @test fun(tuple, (1,2,3), (4,5,6)) == ((1, 2, 3), (4, 5, 6)) @@ -26,7 +26,7 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map end @test fun(tuple, (1,2,3), [4,5,6]) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector end - + @testset "rrules" begin # These exist to allow for second derivatives From 609196a7f4ccf34d0b1a43921c9a04a6df4a5bd1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 8 Aug 2022 15:28:50 -0700 Subject: [PATCH 18/20] add some GPU tests --- test/rulesets/Base/broadcast.jl | 37 +++++++++++++++++---------------- test/unzipped.jl | 15 +++++++++++++ 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 640a4badb..68d47a7d4 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -80,21 +80,21 @@ BT1 = Broadcast.BroadcastStyle(Tuple) @testset "fused rules" begin @testset "arithmetic" begin - test_rrule(copy∘broadcasted, +, rand(3), rand(3)) - test_rrule(copy∘broadcasted, +, rand(3), rand(4)') - test_rrule(copy∘broadcasted, +, rand(3), rand(1), rand()) - test_rrule(copy∘broadcasted, +, rand(3), 1.0*im) - test_rrule(copy∘broadcasted, +, rand(3), true) - test_rrule(copy∘broadcasted, +, rand(3), Tuple(rand(3))) + @gpu test_rrule(copy∘broadcasted, +, rand(3), rand(3)) + @gpu test_rrule(copy∘broadcasted, +, rand(3), rand(4)') + @gpu test_rrule(copy∘broadcasted, +, rand(3), rand(1), rand()) + @gpu test_rrule(copy∘broadcasted, +, rand(3), 1.0*im) + @gpu test_rrule(copy∘broadcasted, +, rand(3), true) + @gpu_broken test_rrule(copy∘broadcasted, +, rand(3), Tuple(rand(3))) - test_rrule(copy∘broadcasted, -, rand(3), rand(3)) - test_rrule(copy∘broadcasted, -, rand(3), rand(4)') - test_rrule(copy∘broadcasted, -, rand(3)) + @gpu test_rrule(copy∘broadcasted, -, rand(3), rand(3)) + @gpu test_rrule(copy∘broadcasted, -, rand(3), rand(4)') + @gpu test_rrule(copy∘broadcasted, -, rand(3)) test_rrule(copy∘broadcasted, -, Tuple(rand(3))) - test_rrule(copy∘broadcasted, *, rand(3), rand(3)) - test_rrule(copy∘broadcasted, *, rand(3), rand()) - test_rrule(copy∘broadcasted, *, rand(), rand(3)) + @gpu test_rrule(copy∘broadcasted, *, rand(3), rand(3)) + @gpu test_rrule(copy∘broadcasted, *, rand(3), rand()) + @gpu test_rrule(copy∘broadcasted, *, rand(), rand(3)) test_rrule(copy∘broadcasted, *, rand(3) .+ im, rand(3) .+ 2im) test_rrule(copy∘broadcasted, *, rand(3) .+ im, rand() + 3im) @@ -107,14 +107,15 @@ BT1 = Broadcast.BroadcastStyle(Tuple) @test unthunk(bk4([4, 5im, 6+7im])[4]) == [0,5,7] # These two test vararg rrule * rule: - test_rrule(copy∘broadcasted, *, rand(3), rand(3), rand(3), rand(3), rand(3)) - test_rrule(copy∘broadcasted, *, rand(), rand(), rand(3), rand(3) .+ im, rand(4)') + @gpu test_rrule(copy∘broadcasted, *, rand(3), rand(3), rand(3), rand(3), rand(3)) + @gpu_broken test_rrule(copy∘broadcasted, *, rand(), rand(), rand(3), rand(3) .+ im, rand(4)') + # GPU error from dot(x::JLArray{Float32, 1}, y::JLArray{ComplexF32, 2}) - test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3), Val(2)) - test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3) .+ im, Val(2)) + @gpu test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3), Val(2)) + @gpu test_rrule(copy∘broadcasted, Base.literal_pow, ^, rand(3) .+ im, Val(2)) - test_rrule(copy∘broadcasted, /, rand(3), rand()) - test_rrule(copy∘broadcasted, /, rand(3) .+ im, rand() + 3im) + @gpu test_rrule(copy∘broadcasted, /, rand(3), rand()) + @gpu test_rrule(copy∘broadcasted, /, rand(3) .+ im, rand() + 3im) end @testset "identity etc" begin test_rrule(copy∘broadcasted, identity, rand(3)) diff --git a/test/unzipped.jl b/test/unzipped.jl index 9138cf845..97aaa23f5 100644 --- a/test/unzipped.jl +++ b/test/unzipped.jl @@ -79,4 +79,19 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map @test dx5[2] isa Tangent{<:Tuple} @test Tuple(dx5[2][2]) == (10, ZeroTangent()) end + + @testset "JLArray tests" begin # fake GPU testing + (y1, y2), bk = rrule(CFG, unzip_broadcast, tuple, [1,2,3.0], [4 5.0]) + (y1jl, y2jl), bk_jl = rrule(CFG, unzip_broadcast, tuple, jl([1,2,3.0]), jl([4 5.0])) + @test y1 == Array(y1jl) + # TODO invent some tests of this rrule's pullback function + + @test unzip(jl([(1,2), (3,4), (5,6)])) == (jl([1, 3, 5]), jl([2, 4, 6])) + + @test unzip(jl([(missing,2), (missing,4), (missing,6)]))[2] == jl([2, 4, 6]) + @test unzip(jl([(missing,2), (missing,4), (missing,6)]))[2] isa Base.ReinterpretArray + + @test unzip(jl([(1,), (3,), (5,)]))[1] == jl([1, 3, 5]) + @test unzip(jl([(1,), (3,), (5,)]))[1] isa Base.ReinterpretArray + end end \ No newline at end of file From 85aa2face4b7a897e45f9f815d0cb1fd5d2503eb Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 8 Aug 2022 15:38:36 -0700 Subject: [PATCH 19/20] remove fallback unbroadcast method --- src/rulesets/Base/broadcast.jl | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 025e9e936..be11eb76a 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -372,21 +372,6 @@ function unbroadcast(f::Function, df) return sum(df) end -# Fallback - -function unbroadcast(x, dx) - @info "last unbroadcast method!" x dx - dx isa AbstractZero && return dx - p = ProjectTo(x) - if p isa ProjectTo{<:AbstractZero} - return NoTangent() - elseif Broadcast.broadcastable(x) isa Ref # then x is scalar under broadcast - return p(sum(dx)) - else - error("don't know how to handle broadcast gradient for x::$(typeof(x))") - end -end - ##### ##### For testing ##### From b7e63f5bc974c97d9bce5cdd748a98277fe92f78 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 8 Aug 2022 15:39:37 -0700 Subject: [PATCH 20/20] re-instate the error which breaks Revise --- src/rulesets/Base/fastmath_able.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index d3af247bd..d8afb630a 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -302,11 +302,11 @@ let non_transformed_definitions = intersect(fastable_ast.args, fast_ast.args) filter!(expr->!(expr isa LineNumberNode), non_transformed_definitions) if !isempty(non_transformed_definitions) - # error( - # "Non-FastMath compatible rules defined in fastmath_able.jl. \n Definitions:\n" * - # join(non_transformed_definitions, "\n") - # ) - # This error() may not play well with Revise. But a wanring @error does: + error( + "Non-FastMath compatible rules defined in fastmath_able.jl. \n Definitions:\n" * + join(non_transformed_definitions, "\n") + ) + # This error() may not play well with Revise. But a wanring @error does, we should change it: @error "Non-FastMath compatible rules defined in fastmath_able.jl." non_transformed_definitions end