Skip to content

Commit c199109

Browse files
committed
add Unthunk
1 parent 40e8193 commit c199109

File tree

5 files changed

+34
-16
lines changed

5 files changed

+34
-16
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad
44
export frule, rrule
55
export wirtinger_conjugate, wirtinger_primal, refine_differential
66
export @scalar_rule, @thunk
7+
export unthunk,
78
export extern, cast, store!
89
export Wirtinger, Zero, DoesNotExist, Thunk, InplaceableThunk
910
export NO_FIELDS

src/differential_arithmetic.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ for T in (:AbstractThunk, :Any)
6767
end
6868

6969

70-
Base.:+(a::AbstractThunk, b::AbstractThunk) = extern(a) + extern(b)
71-
Base.:*(a::AbstractThunk, b::AbstractThunk) = extern(a) * extern(b)
70+
Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b)
71+
Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b)
7272
for T in (:Any,)
73-
@eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b
74-
@eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b)
73+
@eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b
74+
@eval Base.:+(a::$T, b::AbstractThunk) = a + unthunk(b)
7575

76-
@eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b
77-
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
76+
@eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b
77+
@eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b)
7878
end

src/differentials.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,18 @@ end
149149
return element, (externed, new_state)
150150
end
151151

152+
"""
153+
unthunk(x)
154+
155+
On `AbstractThunk`s this removes 1 layer of thunking.
156+
On any other type, it is the identity operation.
157+
158+
In contrast to `extern` this is nonrecursive.
159+
"""
160+
@inline unthunk(x) = x
161+
162+
@inline extern(x::AbstractThunk) = extern(unthunk(x))
163+
152164
#####
153165
##### `Thunk`
154166
#####
@@ -213,9 +225,9 @@ end
213225
# have to define this here after `@thunk` and `Thunk` is defined
214226
Base.conj(x::AbstractThunk) = @thunk(conj(extern(x)))
215227

216-
217228
(x::Thunk)() = x.f()
218-
@inline extern(x::Thunk) = extern(x())
229+
@inline unthunk(x::Thunk) = x()
230+
219231

220232
Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")
221233

@@ -237,8 +249,8 @@ struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk
237249
add!::F
238250
end
239251

240-
(x::InplaceableThunk)() = x.val()
241-
@inline extern(x::InplaceableThunk) = extern(x.val)
252+
@inline unthunk(x::InplaceableThunk) = unthunk(x.val)
253+
(x::InplaceableThunk)() = unthunk(x)
242254

243255
function Base.show(io::IO, x::InplaceableThunk)
244256
println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")

test/differentials.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@
4545
@test extern(@thunk(@thunk(3))) == 3
4646
end
4747

48+
@testset "unthunk" begin
49+
@test unthunk(@thunk(3)) == 3
50+
@test unthunk(@thunk(@thunk(3))) isa Thunk
51+
end
52+
4853
@testset "calling thunks should call inner function" begin
4954
@test (@thunk(3))() == 3
5055
@test (@thunk(@thunk(3)))() isa Thunk

test/rules.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,11 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
134134
fx, f_pushforward = res
135135
df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp)
136136

137-
df_dx::Thunk = df(one(x), Zero())
138-
df_dp::Thunk = df(Zero(), one(p))
137+
df_dx = df(one(x), Zero())
138+
df_dp = df(Zero(), one(p))
139139
@test fx == f(x, p) # Check we still get the normal value, right
140-
@test df_dx() isa expected_type_df_dx
141-
@test df_dp() isa expected_type_df_dp
140+
@test df_dx isa expected_type_df_dx
141+
@test df_dp isa expected_type_df_dp
142142

143143

144144
res = rrule(f, x, p)
@@ -147,8 +147,8 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
147147
dself, df_dx, df_dp = f_pullback(One())
148148
@test fx == f(x, p) # Check we still get the normal value, right
149149
@test dself == NO_FIELDS
150-
@test df_dx() isa expected_type_df_dx
151-
@test df_dp() isa expected_type_df_dp
150+
@test df_dx isa expected_type_df_dx
151+
@test df_dp isa expected_type_df_dp
152152
end
153153
end
154154
end

0 commit comments

Comments
 (0)