-
Notifications
You must be signed in to change notification settings - Fork 32
Remove Composite Bundle #216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7c0641c
be091aa
c205e34
cf137a0
13762e7
f4f996d
50383b6
bfd99c2
121ec20
b257eeb
b9f74ce
34cedf8
9794666
ba5841e
40da34b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,14 +4,6 @@ partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i) | |
partial(x::UniformTangent, i) = getfield(x, :val) | ||
partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors))) | ||
partial(x::AbstractZero, i) = x | ||
partial(x::CompositeBundle{N, B}, i) where {N, B<:Tuple} = Tangent{B}(map(x->partial(x, i), getfield(x, :tup))...) | ||
function partial(x::CompositeBundle{N, B}, i) where {N, B} | ||
# This is tangent for a struct, but fields partials are each stored in a plain tuple | ||
# so we add the names back using the primal `B` | ||
# TODO: If required this can be done as a `@generated` function so it is type-stable | ||
backing = NamedTuple{fieldnames(B)}(map(x->partial(x, i), getfield(x, :tup))) | ||
return Tangent{B, typeof(backing)}(backing) | ||
end | ||
|
||
|
||
primal(x::AbstractTangentBundle) = x.primal | ||
|
@@ -42,20 +34,12 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B} | |
ntuple(_sdown, N-1)) | ||
end | ||
|
||
function shuffle_down(b::CompositeBundle{N, B}) where {N, B} | ||
z = CompositeBundle{N-1, CompositeBundle{1, B}}( | ||
(CompositeBundle{N-1, Tuple}( | ||
map(shuffle_down, b.tup) | ||
),) | ||
) | ||
z | ||
end | ||
|
||
function shuffle_up(r::CompositeBundle{1}) | ||
z₀ = primal(r.tup[1]) | ||
z₁ = partial(r.tup[1], 1) | ||
z₂ = primal(r.tup[2]) | ||
z₁₂ = partial(r.tup[2], 1) | ||
function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2} | ||
z₀ = primal(r)[1] | ||
z₁ = partial(r, 1)[1] | ||
z₂ = primal(r)[2] | ||
z₁₂ = partial(r, 1)[2] | ||
if z₁ == z₂ | ||
return TaylorBundle{2}(z₀, (z₁, z₁₂)) | ||
else | ||
|
@@ -70,26 +54,33 @@ function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N} | |
end | ||
end | ||
|
||
# Check whether the tangent bundle element is taylor-like | ||
isswifty(::TaylorBundle) = true | ||
isswifty(::UniformBundle) = true | ||
isswifty(b::CompositeBundle) = all(isswifty, b.tup) | ||
isswifty(::Any) = false | ||
|
||
function shuffle_up(r::CompositeBundle{N}) where {N} | ||
a, b = r.tup | ||
if isswifty(a) && isswifty(b) && taylor_compatible(a, b) | ||
return TaylorBundle{N+1}(primal(a), | ||
ntuple(i->i == N+1 ? | ||
b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)], | ||
N+1)) | ||
function taylor_compatible(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2} | ||
partial(r, 1)[1] == primal(r)[2] || return false | ||
return all(1:N-1) do i | ||
partial(r, i+1)[1] == partial(r, i)[2] | ||
end | ||
end | ||
function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2} | ||
the_primal = primal(r)[1] | ||
if taylor_compatible(r) | ||
the_partials = ntuple(N+1) do i | ||
if i <= N | ||
partial(r, i)[1] # == `partial(r,i-1)[2]` (except first which is primal(r)[2]) | ||
else # ii = N+1 | ||
partial(r, i-1)[2] | ||
end | ||
end | ||
return TaylorBundle{N+1}(the_primal, the_partials) | ||
else | ||
return TangentBundle{N+1}(r.tup[1].primal, | ||
(r.tup[1].tangent.partials..., primal(b), | ||
ntuple(i->partial(b,i), 1<<(N+1)-1)...)) | ||
#XXX: am dubious of the correctness of this | ||
a_partials = ntuple(i->partial(r, i)[1], N) | ||
b_partials = ntuple(i->partial(r, i)[2], N) | ||
the_partials = (a_partials..., primal_b, b_partials...) | ||
return TangentBundle{N+1}(the_primal, the_partials) | ||
end | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For reference as a part way step to determining this , I first refactored the CompositeBundle version into function shuffle_up(r::CompositeBundle{N}) where {N}
a, b = r.tup
if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
the_partials = ntuple(N+1) do i
if ii <= N
a[TaylorTangentIndex(i)] # == b[TaylorTangentIndex(i-1)] (except first which is b.primal)
else # ii = N+1
b[TaylorTangentIndex(i-1)]
end
end
return TaylorBundle{N+1}(primal(a), the_partials)
else
the_primal = r.tup[1].primal
a_partials = r.tup[1].tangent.partials
b_partials = ntuple(i->partial(b,i), 1<<(N+1)-1)
the_partials = (a_partials..., primal_b, b_partials...)
return TangentBundle{N+1}(the_primal, the_partials)
end
end |
||
|
||
|
||
function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U} | ||
(a, b) = primal(r) | ||
if r.tangent.val === b | ||
|
@@ -185,13 +176,6 @@ end | |
map(y->lifted_getfield(y, s), x.tangent.coeffs)) | ||
end | ||
|
||
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N} | ||
x.tup[primal(s)] | ||
end | ||
|
||
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B} | ||
x.tup[Base.fieldindex(B, primal(s))] | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this shouldn't be needed as we already have cases for |
||
|
||
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U} | ||
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.tangent.val) | ||
|
@@ -210,8 +194,8 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}} | |
end | ||
(f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...) | ||
|
||
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::CompositeBundle{N, <:Tuple}) where {N} | ||
∂vararg{N}()(map(FwdMap(f), tup.tup)...) | ||
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N} | ||
∂vararg{N}()(map(FwdMap(f), destructure(tup))...) | ||
end | ||
|
||
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N} | ||
|
@@ -254,35 +238,37 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate | |
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...) | ||
end | ||
|
||
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}) where {N} | ||
r = iterate(t.tup) | ||
|
||
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N} | ||
r = iterate(destructure(t)) | ||
r === nothing && return ZeroBundle{N}(nothing) | ||
∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) | ||
end | ||
|
||
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N} | ||
r = iterate(t.tup, primal(a), map(primal, args)...) | ||
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N} | ||
r = iterate(destructure(t), primal(a), map(primal, args)...) | ||
r === nothing && return ZeroBundle{N}(nothing) | ||
∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) | ||
end | ||
|
||
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}) where {N} | ||
r = Base.indexed_iterate(t.tup, primal(i)) | ||
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N} | ||
r = Base.indexed_iterate(destructure(t), primal(i)) | ||
∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) | ||
end | ||
|
||
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N} | ||
r = Base.indexed_iterate(t.tup, primal(i), primal(st1), map(primal, st)...) | ||
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N} | ||
r = Base.indexed_iterate(destructure(t), primal(i), primal(st1), map(primal, st)...) | ||
∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) | ||
end | ||
|
||
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N} | ||
∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1)) | ||
end | ||
|
||
|
||
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::CompositeBundle{N, <:Tuple}, i::ZeroBundle) where {N} | ||
t.tup[primal(i)] | ||
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::ZeroBundle) where {N} | ||
field_ind = primal(i) | ||
the_partials = ntuple(order_ind->partial(t, order_ind)[field_ind], N) | ||
TaylorBundle{N}(primal(t)[field_ind], the_partials) | ||
end | ||
|
||
function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -208,9 +208,7 @@ end | |
|
||
function check_taylor_invariants(coeffs, primal, N) | ||
@assert length(coeffs) == N | ||
if isa(primal, TangentBundle) | ||
@assert isa(coeffs[1], TangentBundle) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this check is just wrong AFAICT. But this was never hit before because we were not making |
||
end | ||
|
||
end | ||
@ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N) | ||
|
||
|
@@ -230,6 +228,18 @@ function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex) | |
tb.tangent.coeffs[count_ones(tti.i)] | ||
end | ||
|
||
"for a TaylorTangent{N, <:Tuple} this breaks it up unto 1 TaylorTangent{N} for each element of the primal tuple" | ||
function destructure(r::TaylorBundle{N, B}) where {N, B<:Tuple} | ||
return ntuple(fieldcount(B)) do field_ii | ||
the_primal = primal(r)[field_ii] | ||
the_partials = ntuple(N) do order_ii | ||
partial(r, order_ii)[field_ii] | ||
end | ||
return TaylorBundle{N}(the_primal, the_partials) | ||
end | ||
end | ||
Comment on lines
+231
to
+240
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this, or can we just use |
||
|
||
|
||
function truncate(tt::TaylorTangent, order::Val{N}) where {N} | ||
TaylorTangent(tt.coeffs[1:N]) | ||
end | ||
|
@@ -290,33 +300,6 @@ end | |
|
||
Base.getindex(u::UniformBundle, ::TaylorTangentIndex) = u.tangent.val | ||
|
||
""" | ||
CompositeBundle{N, B, B <: Tuple} | ||
|
||
Represents the tagent bundle where the base space is some tuple or struct type. | ||
Mathematically, this tangent bundle is the product bundle of the individual | ||
element bundles. | ||
""" | ||
struct CompositeBundle{N, B, T<:Tuple{Vararg{AbstractTangentBundle{N}}}} <: AbstractTangentBundle{N, B} | ||
tup::T | ||
end | ||
CompositeBundle{N, B}(tup::T) where {N, B, T} = CompositeBundle{N, B, T}(tup) | ||
|
||
function Base.getindex(tb::CompositeBundle{N, B} where N, tti::TaylorTangentIndex) where {B} | ||
B <: SArray && error() | ||
return partial(tb, tti.i) | ||
end | ||
|
||
primal(b::CompositeBundle{N, <:Tuple} where N) = map(primal, b.tup) | ||
function primal(b::CompositeBundle{N, T} where N) where T<:CompositeBundle | ||
T(map(primal, b.tup)...) | ||
end | ||
@generated primal(b::CompositeBundle{N, B} where N) where {B} = | ||
quote | ||
x = map(primal, b.tup) | ||
$(Expr(:splatnew, B, :x)) | ||
end | ||
|
||
expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...) | ||
expand_singleton_to_array(asize, a::AbstractArray) = a | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.