Skip to content

Remove One #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ export frule, rrule
export wirtinger_conjugate, wirtinger_primal, refine_differential
export @scalar_rule, @thunk
export extern, store!
export Wirtinger, Zero, One, DoesNotExist, Thunk, InplaceableThunk
export unthunk
export Wirtinger, Zero, DoesNotExist, Thunk, InplaceableThunk
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
export Wirtinger, Zero, DoesNotExist, Thunk, InplaceableThunk
export DoesNotExist, InplaceableThunk, Thunk, unthunk, Wirtinger, Zero

export NO_FIELDS

include("differentials.jl")
Expand Down
29 changes: 9 additions & 20 deletions src/differential_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger)
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
end

for T in (:Zero, :DoesNotExist, :One, :AbstractThunk, :Any)
for T in (:Zero, :DoesNotExist, :AbstractThunk, :Any)
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b

Expand All @@ -47,7 +47,7 @@ end

Base.:+(::Zero, b::Zero) = Zero()
Base.:*(::Zero, ::Zero) = Zero()
for T in (:DoesNotExist, :One, :AbstractThunk, :Any)
for T in (:DoesNotExist, :AbstractThunk, :Any)
@eval Base.:+(::Zero, b::$T) = b
@eval Base.:+(a::$T, ::Zero) = a

Expand All @@ -58,7 +58,7 @@ end

Base.:+(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
Base.:*(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
for T in (:One, :AbstractThunk, :Any)
for T in (:AbstractThunk, :Any)
@eval Base.:+(::DoesNotExist, b::$T) = b
@eval Base.:+(a::$T, ::DoesNotExist) = a

Expand All @@ -67,23 +67,12 @@ for T in (:One, :AbstractThunk, :Any)
end


Base.:+(a::One, b::One) = extern(a) + extern(b)
Base.:*(::One, ::One) = One()
for T in (:AbstractThunk, :Any)
@eval Base.:+(a::One, b::$T) = extern(a) + b
@eval Base.:+(a::$T, b::One) = a + extern(b)

@eval Base.:*(::One, b::$T) = b
@eval Base.:*(a::$T, ::One) = a
end


Base.:+(a::AbstractThunk, b::AbstractThunk) = extern(a) + extern(b)
Base.:*(a::AbstractThunk, b::AbstractThunk) = extern(a) * extern(b)
Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b)
Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b)
for T in (:Any,)
@eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b
@eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b)
@eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b
@eval Base.:+(a::$T, b::AbstractThunk) = a + unthunk(b)

@eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
@eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b
@eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b)
end
55 changes: 26 additions & 29 deletions src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ The two fields of the returned instance can be accessed generically via the
struct Wirtinger{P,C} <: AbstractDifferential
primal::P
conjugate::C
function Wirtinger(primal::Union{Number,AbstractDifferential},
conjugate::Union{Number,AbstractDifferential})
function Wirtinger(
primal::Union{Number,AbstractDifferential},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
primal::Union{Number,AbstractDifferential},
primal::Union{Number, AbstractDifferential},

conjugate::Union{Number,AbstractDifferential},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
conjugate::Union{Number,AbstractDifferential},
conjugate::Union{Number, AbstractDifferential},

)
return new{typeof(primal),typeof(conjugate)}(primal, conjugate)
end
end
Expand All @@ -75,10 +77,13 @@ wirtinger_primal(x) = x
wirtinger_conjugate(x::Wirtinger) = x.conjugate
wirtinger_conjugate(::Any) = Zero()

extern(x::Wirtinger) = throw(ArgumentError("`Wirtinger` cannot be converted to an external type."))
function extern(x::Wirtinger)
return throw(ArgumentError("`Wirtinger` cannot be converted to an external type."))
end

Base.Broadcast.broadcastable(w::Wirtinger) = Wirtinger(broadcastable(w.primal),
broadcastable(w.conjugate))
function Base.Broadcast.broadcastable(w::Wirtinger)
return Wirtinger(broadcastable(w.primal), broadcastable(w.conjugate))
end

Base.iterate(x::Wirtinger) = (x, nothing)
Base.iterate(::Wirtinger, ::Any) = nothing
Expand All @@ -104,7 +109,6 @@ Base.Broadcast.broadcastable(::Zero) = Ref(Zero())
Base.iterate(x::Zero) = (x, nothing)
Base.iterate(::Zero, ::Any) = nothing


#####
##### `DoesNotExist`
#####
Expand All @@ -127,25 +131,6 @@ Base.Broadcast.broadcastable(::DoesNotExist) = Ref(DoesNotExist())
Base.iterate(x::DoesNotExist) = (x, nothing)
Base.iterate(::DoesNotExist, ::Any) = nothing

#####
##### `One`
#####

"""
One()
The Differential which is the multiplicative identity.
Basically, this represents `1`.
"""
struct One <: AbstractDifferential end

extern(x::One) = true # true is a strong 1.

Base.Broadcast.broadcastable(::One) = Ref(One())

Base.iterate(x::One) = (x, nothing)
Base.iterate(::One, ::Any) = nothing


#####
##### `AbstractThunk
#####
Expand All @@ -164,6 +149,18 @@ end
return element, (externed, new_state)
end

"""
unthunk(x)

On `AbstractThunk`s this removes 1 layer of thunking.
On any other type, it is the identity operation.

In contrast to `extern` this is nonrecursive.
"""
@inline unthunk(x) = x

@inline extern(x::AbstractThunk) = extern(unthunk(x))

#####
##### `Thunk`
#####
Expand Down Expand Up @@ -228,9 +225,9 @@ end
# have to define this here after `@thunk` and `Thunk` is defined
Base.conj(x::AbstractThunk) = @thunk(conj(extern(x)))


(x::Thunk)() = x.f()
@inline extern(x::Thunk) = extern(x())
@inline unthunk(x::Thunk) = x()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change


Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")

Expand All @@ -252,8 +249,8 @@ struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk
add!::F
end

(x::InplaceableThunk)() = x.val()
@inline extern(x::InplaceableThunk) = extern(x.val)
@inline unthunk(x::InplaceableThunk) = unthunk(x.val)
(x::InplaceableThunk)() = unthunk(x)

function Base.show(io::IO, x::InplaceableThunk)
println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")
Expand Down
29 changes: 8 additions & 21 deletions test/differentials.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
@testset "Differentials" begin
@testset "differentials" begin
@testset "Wirtinger" begin
w = Wirtinger(1+1im, 2+2im)
@test wirtinger_primal(w) == 1+1im
@test wirtinger_conjugate(w) == 2+2im
@test w + w == Wirtinger(2+2im, 4+4im)

@test w + One() == w + 1 == w + Thunk(()->1) == Wirtinger(2+1im, 2+2im)
@test w * One() == One() * w == w
@test w + 1 == w + Thunk(()->1) == Wirtinger(2+1im, 2+2im)
@test w * 2 == 2 * w == Wirtinger(2 + 2im, 4 + 4im)

# TODO: other + methods stack overflow
Expand All @@ -33,22 +32,6 @@
@test broadcastable(z) isa Ref{Zero}
@test conj(z) == z
end
@testset "One" begin
o = One()
@test extern(o) === true
@test o + o == 2
@test o + 1 == 2
@test 1 + o == 2
@test o * o == o
@test o * 1 == 1
@test 1 * o == 1
for x in o
@test x === o
end
@test broadcastable(o) isa Ref{One}
@test conj(o) == o
end

@testset "Thunk" begin
@test @thunk(3) isa Thunk

Expand All @@ -62,11 +45,15 @@
@test extern(@thunk(@thunk(3))) == 3
end

@testset "unthunk" begin
@test unthunk(@thunk(3)) == 3
@test unthunk(@thunk(@thunk(3))) isa Thunk
end

@testset "calling thunks should call inner function" begin
@test (@thunk(3))() == 3
@test (@thunk(@thunk(3)))() isa Thunk
end

@testset "erroring thunks should include the source in the backtrack" begin
expected_line = (@__LINE__) + 2 # for testing it is at right palce
try
Expand Down Expand Up @@ -104,7 +91,7 @@
@test refine_differential(typeof([1.2]), Wirtinger(2,2)) == 4

# For most differentials, in most domains, this does nothing
for der in (DoesNotExist(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], One(), Zero(), 0.0)
for der in (DoesNotExist(), @thunk(23), @thunk(Wirtinger(2,2)), [1 2], Zero(), 0.0)
for 𝒟 in typeof.((1.0 + 1im, [1.0 + 1im], 1.2, [1.2]))
@test refine_differential(𝒟, der) === der
end
Expand Down
Loading