diff --git a/Project.toml b/Project.toml index ae6c77687..9321d78d0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.4.0" +version = "0.5.0-DEV" [compat] julia = "^1.0" diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index db3e6e94f..3a8942e65 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -4,7 +4,7 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad export frule, rrule export wirtinger_conjugate, wirtinger_primal, refine_differential export @scalar_rule, @thunk -export extern, store! +export extern, store!, unthunk export Wirtinger, Zero, One, DoesNotExist, Thunk, InplaceableThunk export NO_FIELDS diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index eb06b9bac..54d142812 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -78,12 +78,12 @@ for T in (:AbstractThunk, :Any) end -Base.:+(a::AbstractThunk, b::AbstractThunk) = extern(a) + extern(b) -Base.:*(a::AbstractThunk, b::AbstractThunk) = extern(a) * extern(b) +Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b) +Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b) for T in (:Any,) - @eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b - @eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b) + @eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b + @eval Base.:+(a::$T, b::AbstractThunk) = a + unthunk(b) - @eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b - @eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b) + @eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b + @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end diff --git a/src/differentials.jl b/src/differentials.jl index 640ee02bb..bfe6ba91d 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -164,6 +164,18 @@ end return element, (externed, new_state) end +""" + unthunk(x) + +On `AbstractThunk`s this removes 1 layer of thunking. +On any other type, it is the identity operation. + +In contrast to `extern` this is nonrecursive. +""" +@inline unthunk(x) = x + +@inline extern(x::AbstractThunk) = extern(unthunk(x)) + ##### ##### `Thunk` ##### @@ -228,9 +240,9 @@ end # have to define this here after `@thunk` and `Thunk` is defined Base.conj(x::AbstractThunk) = @thunk(conj(extern(x))) - (x::Thunk)() = x.f() -@inline extern(x::Thunk) = extern(x()) +@inline unthunk(x::Thunk) = x() + Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") @@ -252,8 +264,8 @@ struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk add!::F end -(x::InplaceableThunk)() = x.val() -@inline extern(x::InplaceableThunk) = extern(x.val) +@inline unthunk(x::InplaceableThunk) = unthunk(x.val) +(x::InplaceableThunk)() = unthunk(x) function Base.show(io::IO, x::InplaceableThunk) println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))") diff --git a/test/differentials.jl b/test/differentials.jl index 6ddad9b4f..bbe84bd0a 100644 --- a/test/differentials.jl +++ b/test/differentials.jl @@ -62,6 +62,11 @@ @test extern(@thunk(@thunk(3))) == 3 end + @testset "unthunk" begin + @test unthunk(@thunk(3)) == 3 + @test unthunk(@thunk(@thunk(3))) isa Thunk + end + @testset "calling thunks should call inner function" begin @test (@thunk(3))() == 3 @test (@thunk(@thunk(3)))() isa Thunk