Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MatrixEquations = "99c1a7ee-ab34-5fd5-8076-27c950a045f4"
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

Expand All @@ -31,13 +29,18 @@ DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[extensions]
ManifoldsBoundaryValueDiffEqExt = "BoundaryValueDiffEq"
ManifoldsNLsolveExt = "NLsolve"
ManifoldsOrdinaryDiffEqDiffEqCallbacksExt = ["DiffEqCallbacks", "OrdinaryDiffEq"]
ManifoldsOrdinaryDiffEqExt = "OrdinaryDiffEq"
ManifoldsOrdinaryDiffEqDiffEqCallbacksRecursiveArrayToolsExt = ["DiffEqCallbacks", "OrdinaryDiffEq", "RecursiveArrayTools"]
ManifoldsOrdinaryDiffEqStaticArraysExt = ["OrdinaryDiffEq", "StaticArrays"]
ManifoldsRecursiveArrayToolsExt = ["RecursiveArrayTools", "StaticArrays"]
ManifoldsRecursiveArrayToolsStaticArraysExt = ["RecursiveArrayTools", "StaticArrays"]
ManifoldsStaticArraysExt = ["StaticArrays"]
ManifoldsRecipesBaseExt = ["Colors", "RecipesBase"]
ManifoldsTestExt = "Test"

Expand Down Expand Up @@ -88,4 +91,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
VisualRegressionTests = "34922c18-7c2a-561c-bac1-01e79b2c4c92"

[targets]
test = ["Test", "BoundaryValueDiffEq", "Colors", "DiffEqCallbacks", "DoubleFloats", "FiniteDifferences", "Gtk", "ImageIO", "ImageMagick", "OrdinaryDiffEq", "NLsolve", "Plots", "PythonPlot", "Quaternions", "QuartzImageIO", "RecipesBase"]
test = ["Test", "BoundaryValueDiffEq", "Colors", "DiffEqCallbacks", "DoubleFloats", "FiniteDifferences", "Gtk", "ImageIO", "ImageMagick", "OrdinaryDiffEq", "NLsolve", "Plots", "PythonPlot", "Quaternions", "QuartzImageIO", "RecipesBase", "StaticArrays"]
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ if isdefined(Base, :get_extension)
transition_map_diff!
import Manifolds: solve_chart_exp_ode, solve_chart_parallel_transport_ode
using ManifoldsBase

using RecursiveArrayTools: ArrayPartition
using DiffEqCallbacks
using OrdinaryDiffEq: OrdinaryDiffEq, SciMLBase, Rodas5, AutoVern9, ODEProblem, solve
else
Expand All @@ -27,6 +27,7 @@ else
using ..ManifoldsBase

using ..DiffEqCallbacks
using ..RecursiveArrayTools: ArrayPartition
using ..OrdinaryDiffEq: OrdinaryDiffEq, SciMLBase, Rodas5, AutoVern9, ODEProblem, solve
end

Expand Down
90 changes: 90 additions & 0 deletions ext/ManifoldsRecursiveArrayToolsExt/FiberBundleRATExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@

@inline function allocate_result(M::FiberBundle, f::TF) where {TF}
p = allocate_result(M.manifold, f)
X = allocate_result(Fiber(M.manifold, p, M.type), f)
return ArrayPartition(p, X)
end

function get_vector(M::FiberBundle, p, X, B::AbstractBasis)
n = manifold_dimension(M.manifold)
xp1, xp2 = submanifold_components(M, p)
F = Fiber(M.manifold, xp1, M.type)
return ArrayPartition(
get_vector(M.manifold, xp1, X[1:n], B),
get_vector(F, xp2, X[(n + 1):end], B),
)
end
function get_vector(
M::FiberBundle,
p,
X,
B::CachedBasis{𝔽,<:AbstractBasis{𝔽},<:FiberBundleBasisData},
) where {𝔽}
n = manifold_dimension(M.manifold)
xp1, xp2 = submanifold_components(M, p)
F = Fiber(M.manifold, xp1, M.type)
return ArrayPartition(
get_vector(M.manifold, xp1, X[1:n], B.data.base_basis),
get_vector(F, xp2, X[(n + 1):end], B.data.fiber_basis),
)
end

function get_vectors(
M::FiberBundle,
p::ArrayPartition,
B::CachedBasis{𝔽,<:AbstractBasis{𝔽},<:FiberBundleBasisData},
) where {𝔽}
xp1, xp2 = submanifold_components(M, p)
zero_m = zero_vector(M.manifold, xp1)
F = Fiber(M.manifold, xp1, M.type)
zero_f = zero_vector(F, xp1)
vs = typeof(ArrayPartition(zero_m, zero_f))[]
for bv in get_vectors(M.manifold, xp1, B.data.base_basis)
push!(vs, ArrayPartition(bv, zero_f))
end
for bv in get_vectors(F, xp2, B.data.fiber_basis)
push!(vs, ArrayPartition(zero_m, bv))
end
return vs
end

"""
getindex(p::ArrayPartition, M::FiberBundle, s::Symbol)
p[M::FiberBundle, s]

Access the element(s) at index `s` of a point `p` on a [`FiberBundle`](@ref) `M` by
using the symbols `:point` and `:vector` or `:fiber` for the base and vector or fiber
component, respectively.
"""
@inline function getindex(p::ArrayPartition, M::FiberBundle, s::Symbol)
(s === :point) && return p.x[1]
(s === :vector || s === :fiber) && return p.x[2]
return throw(DomainError(s, "unknown component $s on $M."))
end

"""
setindex!(p::ArrayPartition, val, M::FiberBundle, s::Symbol)
p[M::VectorBundle, s] = val

Set the element(s) at index `s` of a point `p` on a [`FiberBundle`](@ref) `M` to `val` by
using the symbols `:point` and `:fiber` or `:vector` for the base and fiber or vector
component, respectively.

!!! note

The *content* of element of `p` is replaced, not the element itself.
"""
@inline function setindex!(x::ArrayPartition, val, M::FiberBundle, s::Symbol)
if s === :point
return copyto!(x.x[1], val)
elseif s === :vector || s === :fiber
return copyto!(x.x[2], val)
else
throw(DomainError(s, "unknown component $s on $M."))
end
end
@inline function view(x::ArrayPartition, M::FiberBundle, s::Symbol)
(s === :point) && return x.x[1]
(s === :vector || s === :fiber) && return x.x[2]
throw(DomainError(s, "unknown component $s on $M."))
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
module ManifoldsRecursiveArrayToolsExt

if isdefined(Base, :get_extension)
using Base: @propagate_inbounds
using Manifolds
using Manifolds: submanifold_components
using RecursiveArrayTools: ArrayPartition
import Base: getindex, setindex!, view
import Manifolds:
ProductFVectorDistribution,
adjoint_Jacobi_field,
allocate,
allocate_result,
apply,
apply!,
apply_diff,
apply_diff_group!,
compose,
_compose,
exp,
exp_lie,
get_coordinates,
get_vector,
get_vectors,
hat,
identity_element,
inverse_apply,
inverse_apply_diff,
inverse_translate,
inverse_translate_diff,
isapprox,
jacobi_field,
lie_bracket,
log,
optimal_alignment,
project,
rand,
_rand!,
translate,
translate_diff,
_vector_transport_direction,
_vector_transport_to,
vee,
else
# imports need to be relative for Requires.jl-based workflows:
# https://github.com/JuliaArrays/ArrayInterface.jl/pull/387
using ..Manifolds
using ..RecursiveArrayTools
import Base: getindex, setindex!, view
import Manifolds:
ProductFVectorDistribution,
adjoint_Jacobi_field,
allocate,
allocate_result,
apply,
apply!,
apply_diff,
apply_diff_group!,
_compose,
exp,
exp_lie,
get_vector,
get_vectors,
identity_element,
inverse_apply,
inverse_apply_diff,
inverse_translate,
inverse_translate_diff,
isapprox,
jacobi_field,
log,
optimal_alignment,
rand,
_rand!,
translate,
translate_diff,
_vector_transport_direction,
_vector_transport_to
end

include("FiberBundleRATExt.jl")
include("ProductGroupRATExt.jl")
include("ProductManifoldRATExt.jl")
include("rotation_translation_actionRATExt.jl")
include("semidirect_product_groupRATExt.jl")
include("VectorBundleRATExt.jl")
end
131 changes: 131 additions & 0 deletions ext/ManifoldsRecursiveArrayToolsExt/ProductGroupRATExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
function _compose(M::ProductManifold, p::ArrayPartition, q::ArrayPartition)
return ArrayPartition(
map(
compose,
M.manifolds,
submanifold_components(M, p),
submanifold_components(M, q),
)...,
)
end

function exp(M::ProductGroup, p::Identity{ProductOperation}, X::ArrayPartition)
return ArrayPartition(
map(
exp,
M.manifold.manifolds,
submanifold_components(M, p),
submanifold_components(M, X),
)...,
)
end

function exp_lie(G::ProductGroup, X)
M = G.manifold
return ArrayPartition(map(exp_lie, M.manifolds, submanifold_components(G, X))...)
end

Base.@propagate_inbounds function Base.getindex(
p::ArrayPartition,
M::ProductGroup,
i::Union{Integer,Colon,AbstractVector,Val},
)
return getindex(p, base_manifold(M), i)
end

function identity_element(G::ProductGroup)
M = G.manifold
return ArrayPartition(map(identity_element, M.manifolds))
end

function inverse_translate(G::ProductGroup, p, q, conv::ActionDirectionAndSide)
M = G.manifold
return ArrayPartition(
map(
inverse_translate,
M.manifolds,
submanifold_components(G, p),
submanifold_components(G, q),
repeated(conv),
)...,
)
end

function inverse_translate_diff(G::ProductGroup, p, q, X, conv::ActionDirectionAndSide)
M = G.manifold
return ArrayPartition(
map(
inverse_translate_diff,
M.manifolds,
submanifold_components(G, p),
submanifold_components(G, q),
submanifold_components(G, X),
repeated(conv),
)...,
)
end

# these isapprox methods are here just to reduce time-to-first-isapprox
function isapprox(G::ProductGroup, p::ArrayPartition, q::ArrayPartition; kwargs...)
return isapprox(G.manifold, p, q; kwargs...)
end
function isapprox(
G::ProductGroup,
p::ArrayPartition,
X::ArrayPartition,
Y::ArrayPartition;
kwargs...,
)
return isapprox(G.manifold, p, X, Y; kwargs...)
end

function Base.log(M::ProductGroup, p::Identity{ProductOperation}, q::ArrayPartition)
return ArrayPartition(
map(
log,
M.manifold.manifolds,
submanifold_components(M, p),
submanifold_components(M, q),
)...,
)
end

Base.@propagate_inbounds function Base.setindex!(
q::ArrayPartition,
p,
M::ProductGroup,
i::Union{Integer,Colon,AbstractVector,Val},
)
return setindex!(q, p, base_manifold(M), i)
end

function translate(
M::ProductGroup,
p::ArrayPartition,
q::ArrayPartition,
conv::ActionDirectionAndSide,
)
return ArrayPartition(
map(
translate,
M.manifold.manifolds,
submanifold_components(M, p),
submanifold_components(M, q),
repeated(conv),
)...,
)
end

function translate_diff(G::ProductGroup, p, q, X, conv::ActionDirectionAndSide)
M = G.manifold
return ArrayPartition(
map(
translate_diff,
M.manifolds,
submanifold_components(G, p),
submanifold_components(G, q),
submanifold_components(G, X),
repeated(conv),
)...,
)
end
Loading