diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index db3e6e94f..908d32370 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -5,7 +5,8 @@ export frule, rrule export wirtinger_conjugate, wirtinger_primal, refine_differential export @scalar_rule, @thunk export extern, store! -export Wirtinger, Zero, One, DoesNotExist, Thunk, InplaceableThunk +export unthunk +export Wirtinger, Zero, DoesNotExist, Thunk, InplaceableThunk export NO_FIELDS include("differentials.jl") diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index eb06b9bac..5bd1a6e70 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger) return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate) end -for T in (:Zero, :DoesNotExist, :One, :AbstractThunk, :Any) +for T in (:Zero, :DoesNotExist, :AbstractThunk, :Any) @eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero()) @eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b @@ -47,7 +47,7 @@ end Base.:+(::Zero, b::Zero) = Zero() Base.:*(::Zero, ::Zero) = Zero() -for T in (:DoesNotExist, :One, :AbstractThunk, :Any) +for T in (:DoesNotExist, :AbstractThunk, :Any) @eval Base.:+(::Zero, b::$T) = b @eval Base.:+(a::$T, ::Zero) = a @@ -58,7 +58,7 @@ end Base.:+(::DoesNotExist, ::DoesNotExist) = DoesNotExist() Base.:*(::DoesNotExist, ::DoesNotExist) = DoesNotExist() -for T in (:One, :AbstractThunk, :Any) +for T in (:AbstractThunk, :Any) @eval Base.:+(::DoesNotExist, b::$T) = b @eval Base.:+(a::$T, ::DoesNotExist) = a @@ -67,23 +67,12 @@ for T in (:One, :AbstractThunk, :Any) end -Base.:+(a::One, b::One) = extern(a) + extern(b) -Base.:*(::One, ::One) = One() -for T in (:AbstractThunk, :Any) - @eval Base.:+(a::One, b::$T) = extern(a) + b - @eval Base.:+(a::$T, b::One) = a + extern(b) - - @eval Base.:*(::One, b::$T) = b - @eval Base.:*(a::$T, ::One) = a -end - - -Base.:+(a::AbstractThunk, b::AbstractThunk) = extern(a) + extern(b) -Base.:*(a::AbstractThunk, b::AbstractThunk) = extern(a) * extern(b) +Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b) +Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b) for T in (:Any,) - @eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b - @eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b) + @eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b + @eval Base.:+(a::$T, b::AbstractThunk) = a + unthunk(b) - @eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b - @eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b) + @eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b + @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end diff --git a/src/differentials.jl b/src/differentials.jl index 640ee02bb..eada532b4 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -63,8 +63,10 @@ The two fields of the returned instance can be accessed generically via the struct Wirtinger{P,C} <: AbstractDifferential primal::P conjugate::C - function Wirtinger(primal::Union{Number,AbstractDifferential}, - conjugate::Union{Number,AbstractDifferential}) + function Wirtinger( + primal::Union{Number,AbstractDifferential}, + conjugate::Union{Number,AbstractDifferential}, + ) return new{typeof(primal),typeof(conjugate)}(primal, conjugate) end end @@ -75,10 +77,13 @@ wirtinger_primal(x) = x wirtinger_conjugate(x::Wirtinger) = x.conjugate wirtinger_conjugate(::Any) = Zero() -extern(x::Wirtinger) = throw(ArgumentError("`Wirtinger` cannot be converted to an external type.")) +function extern(x::Wirtinger) + return throw(ArgumentError("`Wirtinger` cannot be converted to an external type.")) +end -Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal), - broadcastable(w.conjugate)) +function Base.Broadcast.broadcastable(w::Wirtinger) + return Wirtinger(broadcastable(w.primal), broadcastable(w.conjugate)) +end Base.iterate(x::Wirtinger) = (x, nothing) Base.iterate(::Wirtinger, ::Any) = nothing @@ -104,7 +109,6 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero()) Base.iterate(x::Zero) = (x, nothing) Base.iterate(::Zero, ::Any) = nothing - ##### ##### `DoesNotExist` ##### @@ -127,25 +131,6 @@ Base.Broadcast.broadcastable(::DoesNotExist) = Ref(DoesNotExist()) Base.iterate(x::DoesNotExist) = (x, nothing) Base.iterate(::DoesNotExist, ::Any) = nothing -##### -##### `One` -##### - -""" - One() -The Differential which is the multiplicative identity. -Basically, this represents `1`. -""" -struct One <: AbstractDifferential end - -extern(x::One) = true # true is a strong 1. - -Base.Broadcast.broadcastable(::One) = Ref(One()) - -Base.iterate(x::One) = (x, nothing) -Base.iterate(::One, ::Any) = nothing - - ##### ##### `AbstractThunk ##### @@ -164,6 +149,18 @@ end return element, (externed, new_state) end +""" + unthunk(x) + +On `AbstractThunk`s this removes 1 layer of thunking. +On any other type, it is the identity operation. + +In contrast to `extern` this is nonrecursive. +""" +@inline unthunk(x) = x + +@inline extern(x::AbstractThunk) = extern(unthunk(x)) + ##### ##### `Thunk` ##### @@ -228,9 +225,9 @@ end # have to define this here after `@thunk` and `Thunk` is defined Base.conj(x::AbstractThunk) = @thunk(conj(extern(x))) - (x::Thunk)() = x.f() -@inline extern(x::Thunk) = extern(x()) +@inline unthunk(x::Thunk) = x() + Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") @@ -252,8 +249,8 @@ struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk add!::F end -(x::InplaceableThunk)() = x.val() -@inline extern(x::InplaceableThunk) = extern(x.val) +@inline unthunk(x::InplaceableThunk) = unthunk(x.val) +(x::InplaceableThunk)() = unthunk(x) function Base.show(io::IO, x::InplaceableThunk) println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))") diff --git a/test/differentials.jl b/test/differentials.jl index 6ddad9b4f..1d40734a9 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -1,12 +1,11 @@ -@testset "Differentials" begin +@testset "differentials" begin @testset "Wirtinger" begin w = Wirtinger(1+1im, 2+2im) @test wirtinger_primal(w) == 1+1im @test wirtinger_conjugate(w) == 2+2im @test w + w == Wirtinger(2+2im, 4+4im) - @test w + One() == w + 1 == w + Thunk(()->1) == Wirtinger(2+1im, 2+2im) - @test w * One() == One() * w == w + @test w + 1 == w + Thunk(()->1) == Wirtinger(2+1im, 2+2im) @test w * 2 == 2 * w == Wirtinger(2 + 2im, 4 + 4im) # TODO: other + methods stack overflow @@ -33,22 +32,6 @@ @test broadcastable(z) isa Ref{Zero} @test conj(z) == z end - @testset "One" begin - o = One() - @test extern(o) === true - @test o + o == 2 - @test o + 1 == 2 - @test 1 + o == 2 - @test o * o == o - @test o * 1 == 1 - @test 1 * o == 1 - for x in o - @test x === o - end - @test broadcastable(o) isa Ref{One} - @test conj(o) == o - end - @testset "Thunk" begin @test @thunk(3) isa Thunk @@ -62,11 +45,15 @@ @test extern(@thunk(@thunk(3))) == 3 end + @testset "unthunk" begin + @test unthunk(@thunk(3)) == 3 + @test unthunk(@thunk(@thunk(3))) isa Thunk + end + @testset "calling thunks should call inner function" begin @test (@thunk(3))() == 3 @test (@thunk(@thunk(3)))() isa Thunk end - @testset "erroring thunks should include the source in the backtrack" begin expected_line = (@__LINE__) + 2 # for testing it is at right palce try @@ -104,7 +91,7 @@ @test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4 # For most differentials, in most domains, this does nothing - for der in (DoesNotExist(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0) + for der in (DoesNotExist(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], Zero(), 0.0) for 𝒟 in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2])) @test refine_differential(𝒟, der) === der end diff --git a/test/rules.jl b/test/rules.jl index e23680326..7231b5f82 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -6,146 +6,149 @@ cool(x, y) = x + y + 1 # a rule we define so we can test rules dummy_identity(x) = x -@scalar_rule(dummy_identity(x), One()) +@scalar_rule(dummy_identity(x), one(x)) ####### _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) -@testset "frule and rrule" begin - @test frule(cool, 1) === nothing - @test frule(cool, 1; iscool=true) === nothing - @test rrule(cool, 1) === nothing - @test rrule(cool, 1; iscool=true) === nothing - - # add some methods: - ChainRulesCore.@scalar_rule(Main.cool(x), one(x)) - @test hasmethod(rrule, Tuple{typeof(cool),Number}) - ChainRulesCore.@scalar_rule(Main.cool(x::String), "wow such dfdx") - @test hasmethod(rrule, Tuple{typeof(cool),String}) - # Ensure those are the *only* methods that have been defined - cool_methods = Set(m.sig for m in methods(rrule) if _second(m.sig) == typeof(cool)) - only_methods = Set([Tuple{typeof(rrule),typeof(cool),Number}, - Tuple{typeof(rrule),typeof(cool),String}]) - @test cool_methods == only_methods - - frx, cool_pushforward = frule(cool, 1) - @test frx == 2 - @test cool_pushforward(NamedTuple(), 1) == 1 - rrx, cool_pullback = rrule(cool, 1) - self, rr1 = cool_pullback(1) - @test self == NO_FIELDS - @test rrx == 2 - @test rr1 == 1 -end +@testset "rules" begin + + @testset "frule and rrule" begin + @test frule(cool, 1) === nothing + @test frule(cool, 1; iscool=true) === nothing + @test rrule(cool, 1) === nothing + @test rrule(cool, 1; iscool=true) === nothing + + # add some methods: + ChainRulesCore.@scalar_rule(Main.cool(x), one(x)) + @test hasmethod(rrule, Tuple{typeof(cool),Number}) + ChainRulesCore.@scalar_rule(Main.cool(x::String), "wow such dfdx") + @test hasmethod(rrule, Tuple{typeof(cool),String}) + # Ensure those are the *only* methods that have been defined + cool_methods = Set(m.sig for m in methods(rrule) if _second(m.sig) == typeof(cool)) + only_methods = Set([Tuple{typeof(rrule),typeof(cool),Number}, + Tuple{typeof(rrule),typeof(cool),String}]) + @test cool_methods == only_methods + + frx, cool_pushforward = frule(cool, 1) + @test frx == 2 + @test cool_pushforward(NamedTuple(), 1) == 1 + rrx, cool_pullback = rrule(cool, 1) + self, rr1 = cool_pullback(1) + @test self == NO_FIELDS + @test rrx == 2 + @test rr1 == 1 + end -@testset "Basic Wirtinger scalar_rule" begin - myabs2(x) = abs2(x) - @scalar_rule(myabs2(x), Wirtinger(x', x)) + @testset "Basic Wirtinger scalar_rule" begin + myabs2(x) = abs2(x) + @scalar_rule(myabs2(x), Wirtinger(x', x)) - @testset "real input" begin - # even though our rule was define in terms of Wirtinger, - # pushforward result will be real as real (even if seed is Compex) + @testset "real input" begin + # even though our rule was define in terms of Wirtinger, + # pushforward result will be real as real (even if seed is Compex) - x = rand(Float64) - f, myabs2_pushforward = frule(myabs2, x) - @test f === x^2 + x = rand(Float64) + f, myabs2_pushforward = frule(myabs2, x) + @test f === x^2 - Δ = One() - df = @inferred myabs2_pushforward(NamedTuple(), Δ) - @test df === x + x + Δ = 1.0 + df = @inferred myabs2_pushforward(NamedTuple(), Δ) + @test df === x + x - Δ = rand(Complex{Int64}) - df = @inferred myabs2_pushforward(NamedTuple(), Δ) - @test df === Δ * (x + x) - end + Δ = rand(Complex{Int64}) + df = @inferred myabs2_pushforward(NamedTuple(), Δ) + @test df === Δ * (x + x) + end - @testset "complex input" begin - z = rand(Complex{Float64}) - f, myabs2_pushforward = frule(myabs2, z) - @test f === abs2(z) + @testset "complex input" begin + z = rand(Complex{Float64}) + f, myabs2_pushforward = frule(myabs2, z) + @test f === abs2(z) - df = @inferred myabs2_pushforward(NamedTuple(), One()) - @test df === Wirtinger(z', z) + df = @inferred myabs2_pushforward(NamedTuple(), one(z)) + @test df === Wirtinger(z', z) - Δ = rand(Complex{Int64}) - df = @inferred myabs2_pushforward(NamedTuple(), Δ) - @test df === Wirtinger(Δ * z', Δ * z) + Δ = rand(Complex{Int64}) + df = @inferred myabs2_pushforward(NamedTuple(), Δ) + @test df === Wirtinger(Δ * z', Δ * z) + end end -end -@testset "Advanced Wirtinger @scalar_rule: abs_to_pow" begin - # This is based on SimeonSchaub excellent example: - # https://gist.github.com/simeonschaub/a6dfcd71336d863b3777093b3b8d9c97 + @testset "Advanced Wirtinger @scalar_rule: abs_to_pow" begin + # This is based on SimeonSchaub excellent example: + # https://gist.github.com/simeonschaub/a6dfcd71336d863b3777093b3b8d9c97 - # This is much more complex than the previous case - # as it has many different types - # depending on input, and the output types do not always agree + # This is much more complex than the previous case + # as it has many different types + # depending on input, and the output types do not always agree - abs_to_pow(x, p) = abs(x)^p - @scalar_rule( - abs_to_pow(x::Real, p), - ( - p == 0 ? Zero() : p * abs_to_pow(x, p-1) * sign(x), - Ω * log(abs(x)) + abs_to_pow(x, p) = abs(x)^p + @scalar_rule( + abs_to_pow(x::Real, p), + ( + p == 0 ? Zero() : p * abs_to_pow(x, p-1) * sign(x), + Ω * log(abs(x)) + ) ) - ) - - @scalar_rule( - abs_to_pow(x::Complex, p), - @setup(u = abs(x)), - ( - p == 0 ? Zero() : p * u^(p-1) * Wirtinger(x' / 2u, x / 2u), - Ω * log(abs(x)) + + @scalar_rule( + abs_to_pow(x::Complex, p), + @setup(u = abs(x)), + ( + p == 0 ? Zero() : p * u^(p-1) * Wirtinger(x' / 2u, x / 2u), + Ω * log(abs(x)) + ) + ) + + + f = abs_to_pow + @testset "f($x, $p)" for (x, p) in Iterators.product( + (2, 3.4, -2.1, -10+0im, 2.3-2im), + (0, 1, 2, 4.3, -2.1, 1+.2im) ) - ) - - - f = abs_to_pow - @testset "f($x, $p)" for (x, p) in Iterators.product( - (2, 3.4, -2.1, -10+0im, 2.3-2im), - (0, 1, 2, 4.3, -2.1, 1+.2im) - ) - expected_type_df_dx = - if iszero(p) - Zero - elseif typeof(x) <: Complex - Wirtinger - elseif typeof(p) <: Complex - Complex - else - Real - end - - expected_type_df_dp = - if typeof(p) <: Real - Real - else - Complex - end - - - res = frule(f, x, p) - @test res !== nothing # Check the rule was defined - fx, f_pushforward = res - df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp) - - df_dx::Thunk = df(One(), Zero()) - df_dp::Thunk = df(Zero(), One()) - @test fx == f(x, p) # Check we still get the normal value, right - @test df_dx() isa expected_type_df_dx - @test df_dp() isa expected_type_df_dp - - - res = rrule(f, x, p) - @test res !== nothing # Check the rule was defined - fx, f_pullback = res - dself, df_dx, df_dp = f_pullback(One()) - @test fx == f(x, p) # Check we still get the normal value, right - @test dself == NO_FIELDS - @test df_dx() isa expected_type_df_dx - @test df_dp() isa expected_type_df_dp + expected_type_df_dx = + if iszero(p) + Zero + elseif typeof(x) <: Complex + Wirtinger + elseif typeof(p) <: Complex + Complex + else + Real + end + + expected_type_df_dp = + if typeof(p) <: Real + Real + else + Complex + end + + + res = frule(f, x, p) + @test res !== nothing # Check the rule was defined + fx, f_pushforward = res + df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp) + + df_dx::Thunk = df(one(x), Zero()) + df_dp::Thunk = df(Zero(), one(p)) + @test fx == f(x, p) # Check we still get the normal value, right + @test df_dx() isa expected_type_df_dx + @test df_dp() isa expected_type_df_dp + + + res = rrule(f, x, p) + @test res !== nothing # Check the rule was defined + fx, f_pullback = res + dself, df_dx, df_dp = f_pullback(one(fx)) + @test fx == f(x, p) # Check we still get the normal value, right + @test dself == NO_FIELDS + @test df_dx() isa expected_type_df_dx + @test df_dp() isa expected_type_df_dp + end end end diff --git a/test/runtests.jl b/test/runtests.jl index bc3c9cde1..6f4d6ea2d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using ChainRulesCore using LinearAlgebra: Diagonal using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule, Wirtinger, wirtinger_primal, wirtinger_conjugate, - Zero, One, DoesNotExist, Thunk + Zero, DoesNotExist, Thunk using Base.Broadcast: broadcastable @testset "ChainRulesCore" begin