Skip to content

Commit 9b56ede

Browse files
committed
It's Perfect. It's flawless. Really something.
1 parent 7df36ef commit 9b56ede

File tree

7 files changed

+10
-137
lines changed

7 files changed

+10
-137
lines changed

src/jet.jl

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,8 @@ function (∂⃖ₙ::∂⃖{N})(::typeof(map), f, a::Array) where {N}
187187
∂f = ∂☆{N}()(ZeroBundle{N}(f),
188188
TaylorBundle{N}(x,
189189
(one(x), (zero(x) for i = 1:(N-1))...,)))
190-
@assert isa(∂f, TaylorBundle) || isa(∂f, ExplicitTangentBundle{1})
191-
Jet{typeof(x), typeof(x), N}(x, ∂f.primal,
192-
isa(∂f, ExplicitTangentBundle) ? ∂f.tangent.partials : ∂f.tangent.coeffs)
190+
@assert isa(∂f, TaylorBundle)
191+
Jet{typeof(x), typeof(x), N}(x, ∂f.primal, ∂f.tangent.coeffs)
193192
end
194193
∂⃖ₙ(mapev, js, a)
195194
end
@@ -248,13 +247,3 @@ end
248247
($((:(jet_taylor_ev(Val{$i}(), coeffs, j)) for i = 1:O)...),))
249248
end
250249
end
251-
252-
function (j::Jet{S, T, 1} where {S,T})(x::ExplicitTangentBundle{1})
253-
domain_check(j, x.primal)
254-
coeffs = x.tangent.partials
255-
ExplicitTangentBundle{1}(j[0], (jet_taylor_ev(Val{1}(), coeffs, j),))
256-
end
257-
258-
function (j::Jet{S, T, N} where T)(x::ExplicitTangentBundle{N, M}) where {S, N, M}
259-
error("TODO")
260-
end

src/stage1/forward.jl

Lines changed: 4 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
partial(x::TangentBundle, i) = partial(getfield(x, :tangent), i)
2-
partial(x::ExplicitTangent, i) = getfield(getfield(x, :partials), i)
32
partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
43
partial(x::UniformTangent, i) = getfield(x, :val)
54
partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors)))
@@ -25,15 +24,6 @@ my_frule(::ZeroBundle{1, typeof(my_frule)}, args::ATB{1}...) = nothing
2524
shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U} =
2625
UniformBundle{N-1, <:Any, U}(UniformBundle{1, B, U}(b.primal, b.tangent.val), b.tangent.val)
2726

28-
function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
29-
# N.B: This depends on the special properties of the canonical tangent index order
30-
ExplicitTangentBundle{N-1}(
31-
ExplicitTangentBundle{1}(b.primal, (partial(b, 1),)),
32-
ntuple(1<<(N-1)-1) do i
33-
ExplicitTangentBundle{1}(partial(b, 2*i), (partial(b, 2*i+1),))
34-
end)
35-
end
36-
3727
function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
3828
TaylorBundle{N-1}(
3929
TaylorBundle{1}(b.primal, (b.tangent.coeffs[1],)),
@@ -58,31 +48,12 @@ function shuffle_up(r::CompositeBundle{1})
5848
return TaylorBundle{2}(z₀, (z₁, z₁₂))
5949
end
6050

61-
function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
62-
primal(b) === a[TaylorTangentIndex(1)] || return false
63-
return all(1:(N-1)) do i
64-
b[TaylorTangentIndex(i)] === a[TaylorTangentIndex(i+1)]
65-
end
66-
end
67-
68-
# Check whether the tangent bundle element is taylor-like
69-
isswifty(::TaylorBundle) = true
70-
isswifty(::UniformBundle) = true
71-
isswifty(b::CompositeBundle) = all(isswifty, b.tup)
72-
isswifty(::Any) = false
73-
7451
function shuffle_up(r::CompositeBundle{N}) where {N}
7552
a, b = r.tup
76-
if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
77-
return TaylorBundle{N+1}(primal(a),
78-
ntuple(i->i == N+1 ?
79-
b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)],
80-
N+1))
81-
else
82-
return TangentBundle{N+1}(r.tup[1].primal,
83-
(r.tup[1].tangent.partials..., primal(b),
84-
ntuple(i->partial(b,i), 1<<(N+1)-1)...))
85-
end
53+
return TaylorBundle{N+1}(primal(a),
54+
ntuple(i->i == N+1 ?
55+
b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)],
56+
N+1))
8657
end
8758

8859
function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
@@ -134,18 +105,6 @@ end
134105
(::∂☆{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆internal{N}()(args...)
135106

136107
# Special case rules for performance
137-
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ExplicitTangentBundle{N}, s::AbstractTangentBundle{N}) where {N}
138-
s = primal(s)
139-
ExplicitTangentBundle{N}(getfield(primal(x), s),
140-
map(x->lifted_getfield(x, s), x.tangent.partials))
141-
end
142-
143-
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::ExplicitTangentBundle{N}, s::ATB{N}, inbounds::ATB{N}) where {N}
144-
s = primal(s)
145-
ExplicitTangentBundle{N}(getfield(primal(x), s, primal(inbounds)),
146-
map(x->lifted_getfield(x, s), x.tangent.partials))
147-
end
148-
149108
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::TaylorBundle{N}, s::AbstractTangentBundle{N}) where {N}
150109
s = primal(s)
151110
TaylorBundle{N}(getfield(primal(x), s),

src/stage1/mixed.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,8 @@ function (∂⃖ₙ::∂⃖{N})(∂☆ₘ::∂☆{M}, ::ZeroBundle{M, typeof(map
9595
∂f = ∂☆{N+M}()(ZeroBundle{N+M}(primal(f)),
9696
TaylorBundle{N+M}(x,
9797
(one(x), (zero(x) for i = 1:(N+M-1))...,)))
98-
@assert isa(∂f, TaylorBundle) || isa(∂f, ExplicitTangentBundle{1})
99-
Jet{typeof(x), N+M}(x, ∂f.primal,
100-
isa(∂f, ExplicitTangentBundle) ? ∂f.tangent.partials : ∂f.tangent.coeffs)
98+
@assert isa(∂f, TaylorBundle)
99+
Jet{typeof(x), N+M}(x, ∂f.primal, ∂f.tangent.coeffs)
101100
end
102101
∂⃖ₙ(mapev_unbundled, ∂☆ₘ, js, a)
103102
end

src/stage1/recurse_fwd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct ∂☆new{N}; end
1212
(::∂☆new{N})(B::Type, a::AbstractTangentBundle{N}...) where {N} =
1313
CompositeBundle{N, B}(a)
1414

15-
@generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B))))
15+
(::∂☆new{N})(B::Type) where {N} = return ZeroBundle{N}(B)
1616

1717
# Sometimes we don't know whether or not we need to the ZeroBundle when doing
1818
# the transform, so this can happen - allow it for now.

src/tangent.jl

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -80,16 +80,6 @@ end
8080

8181
abstract type AbstractTangentSpace; end
8282

83-
"""
84-
struct ExplicitTangent{P}
85-
86-
A fully explicit coordinate representation of the tangent space,
87-
represented by a vector of `2^(N-1)` partials.
88-
"""
89-
struct ExplicitTangent{P <: Tuple} <: AbstractTangentSpace
90-
partials::P
91-
end
92-
9383
struct TaylorTangent{C <: Tuple} <: AbstractTangentSpace
9484
coeffs::C
9585
end
@@ -151,46 +141,9 @@ struct TangentBundle{N, B, P <: AbstractTangentSpace} <: AbstractTangentBundle{N
151141
TangentBundle{N}(B, P) where {N} = new{N, typeof(B), typeof(P)}(B,P)
152142
end
153143

154-
const ExplicitTangentBundle{N, B, P} = TangentBundle{N, B, ExplicitTangent{P}}
155-
156144
check_tangent_invariant(lp, N) = @assert lp == 2^N - 1
157145
@ChainRulesCore.non_differentiable check_tangent_invariant(lp, N)
158146

159-
function ExplicitTangentBundle{N}(primal::B, partials::P) where {N, B, P}
160-
check_tangent_invariant(length(partials), N)
161-
TangentBundle{N}(primal, ExplicitTangent{P}(partials))
162-
end
163-
164-
function ExplicitTangentBundle{N,B}(primal::B, partials::P) where {N, B, P}
165-
check_tangent_invariant(length(partials), N)
166-
TangentBundle{N}(primal, ExplicitTangent{P}(partials))
167-
end
168-
169-
function ExplicitTangentBundle{N,B,P}(primal::B, partials::P) where {N, B, P}
170-
check_tangent_invariant(length(partials), N)
171-
TangentBundle{N}(primal, ExplicitTangent{P}(partials))
172-
end
173-
174-
function Base.show(io::IO, x::ExplicitTangentBundle)
175-
print(io, x.primal)
176-
print(io, " + ")
177-
x = x.tangent
178-
print(io, x.partials[1], " ∂₁")
179-
length(x.partials) >= 2 && print(io, " + ", x.partials[2], " ∂₂")
180-
length(x.partials) >= 3 && print(io, " + ", x.partials[3], " ∂₁ ∂₂")
181-
length(x.partials) >= 4 && print(io, " + ", x.partials[4], " ∂₃")
182-
length(x.partials) >= 5 && print(io, " + ", x.partials[5], " ∂₁ ∂₃")
183-
length(x.partials) >= 6 && print(io, " + ", x.partials[6], " ∂₂ ∂₃")
184-
length(x.partials) >= 7 && print(io, " + ", x.partials[7], " ∂₁ ∂₂ ∂₃")
185-
end
186-
187-
function Base.getindex(a::ExplicitTangentBundle{N}, b::TaylorTangentIndex) where {N}
188-
if b.i === N
189-
return a.tangent.partials[end]
190-
end
191-
error("$(typeof(a)) is not taylor-like. Taylor indexing is ambiguous")
192-
end
193-
194147
const TaylorBundle{N, B, P} = TangentBundle{N, B, TaylorTangent{P}}
195148

196149
function TaylorBundle{N, B}(primal::B, coeffs) where {N, B}
@@ -268,24 +221,6 @@ end
268221
expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...)
269222
expand_singleton_to_array(asize, a::AbstractArray) = a
270223

271-
function unbundle(atb::ExplicitTangentBundle{Order, A}) where {Order, Dim, T, A<:AbstractArray{T, Dim}}
272-
asize = size(atb.primal)
273-
StructArray{ExplicitTangentBundle{Order, T}}((atb.primal, map(a->expand_singleton_to_array(asize, a), atb.tangent.partials)...))
274-
end
275-
276-
function StructArrays.staticschema(::Type{<:ExplicitTangentBundle{N, B, T}}) where {N, B, T}
277-
Tuple{B, T.parameters...}
278-
end
279-
280-
function StructArrays.component(m::ExplicitTangentBundle{N, B, T}, i::Int) where {N, B, T}
281-
i == 1 && return m.primal
282-
return m.tangent.partials[i - 1]
283-
end
284-
285-
function StructArrays.createinstance(T::Type{<:ExplicitTangentBundle}, args...)
286-
T(first(args), Base.tail(args))
287-
end
288-
289224
function unbundle(atb::TaylorBundle{Order, A}) where {Order, Dim, T, A<:AbstractArray{T, Dim}}
290225
StructArray{TaylorBundle{Order, T}}((atb.primal, atb.tangent.coeffs...))
291226
end
@@ -323,14 +258,6 @@ function StructArrays.createinstance(T::Type{<:ZeroBundle}, args...)
323258
T(args[1], args[2])
324259
end
325260

326-
function rebundle(A::AbstractArray{<:ExplicitTangentBundle{N}}) where {N}
327-
ExplicitTangentBundle{N}(
328-
map(x->x.primal, A),
329-
ntuple(2^N-1) do i
330-
map(x->x.tangent.partials[i], A)
331-
end)
332-
end
333-
334261
function rebundle(A::AbstractArray{<:TaylorBundle{N}}) where {N}
335262
TaylorBundle{N}(
336263
map(x->x.primal, A),

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent()
5050

5151
# Minimal 2-nd order forward smoke test
5252
@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin),
53-
Diffractor.TaylorBundle{2}(1.0, (1.0 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0)
53+
Diffractor.TaylorBundle{2}(1.0, (1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0)
5454

5555
function simple_control_flow(b, x)
5656
if b

test/stage2_fwd.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ module stage2_fwd
1414

1515
self_minus(a) = myminus(a, a)
1616
let self_minus′′ = Diffractor.dontuse_nth_order_forward_stage2(Tuple{typeof(self_minus), Float64}, 2)
17-
# TODO: The IR for this currently contains Union{Diffractor.TangentBundle{2, Float64, Diffractor.ExplicitTangent{Tuple{Float64, Float64, Float64}}}, Diffractor.TangentBundle{2, Float64, Diffractor.TaylorTangent{Tuple{Float64, Float64}}}}
1817
# We should have Diffractor be able to prove uniformity
1918
@test_broken isa(self_minus′′, Core.OpaqueClosure{Tuple{Float64}, Float64})
2019
@test self_minus′′(1.0) == 0.

0 commit comments

Comments
 (0)