Skip to content

Commit 21e7b23

Browse files
committed
add Unthunk
fix extra comma
1 parent d42ef9f commit 21e7b23

File tree

4 files changed

+28
-10
lines changed

4 files changed

+28
-10
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ export frule, rrule
55
export wirtinger_conjugate, wirtinger_primal, refine_differential
66
export @scalar_rule, @thunk
77
export extern, store!
8+
export unthunk
89
export Wirtinger, Zero, One, DoesNotExist, Thunk, InplaceableThunk
910
export NO_FIELDS
1011

src/differential_arithmetic.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,12 @@ for T in (:AbstractThunk, :Any)
7878
end
7979

8080

81-
Base.:+(a::AbstractThunk, b::AbstractThunk) = extern(a) + extern(b)
82-
Base.:*(a::AbstractThunk, b::AbstractThunk) = extern(a) * extern(b)
81+
Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b)
82+
Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b)
8383
for T in (:Any,)
84-
@eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b
85-
@eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b)
84+
@eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b
85+
@eval Base.:+(a::$T, b::AbstractThunk) = a + unthunk(b)
8686

87-
@eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b
88-
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
87+
@eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b
88+
@eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b)
8989
end

src/differentials.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,18 @@ end
164164
return element, (externed, new_state)
165165
end
166166

167+
"""
168+
unthunk(x)
169+
170+
On `AbstractThunk`s this removes 1 layer of thunking.
171+
On any other type, it is the identity operation.
172+
173+
In contrast to `extern` this is nonrecursive.
174+
"""
175+
@inline unthunk(x) = x
176+
177+
@inline extern(x::AbstractThunk) = extern(unthunk(x))
178+
167179
#####
168180
##### `Thunk`
169181
#####
@@ -228,9 +240,9 @@ end
228240
# have to define this here after `@thunk` and `Thunk` is defined
229241
Base.conj(x::AbstractThunk) = @thunk(conj(extern(x)))
230242

231-
232243
(x::Thunk)() = x.f()
233-
@inline extern(x::Thunk) = extern(x())
244+
@inline unthunk(x::Thunk) = x()
245+
234246

235247
Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")
236248

@@ -252,8 +264,8 @@ struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk
252264
add!::F
253265
end
254266

255-
(x::InplaceableThunk)() = x.val()
256-
@inline extern(x::InplaceableThunk) = extern(x.val)
267+
@inline unthunk(x::InplaceableThunk) = unthunk(x.val)
268+
(x::InplaceableThunk)() = unthunk(x)
257269

258270
function Base.show(io::IO, x::InplaceableThunk)
259271
println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")

test/differentials.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@
6262
@test extern(@thunk(@thunk(3))) == 3
6363
end
6464

65+
@testset "unthunk" begin
66+
@test unthunk(@thunk(3)) == 3
67+
@test unthunk(@thunk(@thunk(3))) isa Thunk
68+
end
69+
6570
@testset "calling thunks should call inner function" begin
6671
@test (@thunk(3))() == 3
6772
@test (@thunk(@thunk(3)))() isa Thunk

0 commit comments

Comments
 (0)