Skip to content

Commit 491f781

Browse files
committed
overload Base.real for some differentials
1 parent 71e0c7b commit 491f781

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

src/differentials.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero())
148148
Base.iterate(x::Zero) = (x, nothing)
149149
Base.iterate(::Zero, ::Any) = nothing
150150

151+
Base.real(::Zero) = Zero()
151152

152153
#####
153154
##### `DNE`
@@ -189,6 +190,7 @@ Base.Broadcast.broadcastable(::One) = Ref(One())
189190
Base.iterate(x::One) = (x, nothing)
190191
Base.iterate(::One, ::Any) = nothing
191192

193+
Base.real(::One) = One()
192194

193195
#####
194196
##### `AbstractThunk
@@ -216,6 +218,8 @@ wirtinger_primal(::Union{AbstractThunk}) =
216218
wirtinger_conjugate(::Union{AbstractThunk}) =
217219
throw(ArgumentError("`wirtinger_conjugate` is not defined for `AbstractThunk`. Call `unthunk` first."))
218220

221+
Base.real(x::AbstractThunk) = real(x())
222+
219223
#####
220224
##### `Thunk`
221225
#####

test/differentials.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@
102102
@test refine_differential(typeof(1.0 + 1im), g) === g
103103
@test refine_differential(typeof([1.0 + 1im]), g) === g
104104

105-
c isa Thunk && continue
106105
@test refine_differential(typeof(1.2), g) == real(c)
107106
@test refine_differential(typeof([1.2]), g) == real(c)
108107
end

0 commit comments

Comments
 (0)