Skip to content

guard against trait type piracy in a dependent package #157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ abstract type TreeIterator{T} end
_iterator_eltype(::NodeTypeUnknown) = EltypeUnknown()
_iterator_eltype(::HasNodeType) = HasEltype()

Base.IteratorEltype(::Type{<:TreeIterator{Union{}}}) = throw(not_supported_exc)
Base.IteratorEltype(::Type{<:TreeIterator{T}}) where {T} = _iterator_eltype(NodeType(T))

Base.eltype(::Type{<:TreeIterator{Union{}}}) = throw(not_supported_exc)
Base.eltype(::Type{<:TreeIterator{T}}) where {T} = nodetype(T)
Base.eltype(ti::TreeIterator) = eltype(typeof(ti))

Expand Down
9 changes: 9 additions & 0 deletions src/traits.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
const not_supported_exc = ArgumentError("not supported")

"""
ParentLinks(::Type{T})
Expand Down Expand Up @@ -42,6 +43,7 @@ the tree structure and cannot be inferred through a single node.
"""
struct ImplicitParents <: ParentLinks; end

ParentLinks(::Type{Union{}}) = throw(not_supported_exc)
ParentLinks(::Type) = ImplicitParents()
ParentLinks(tree) = ParentLinks(typeof(tree))

Expand Down Expand Up @@ -84,6 +86,7 @@ from the tree structure.
"""
struct ImplicitSiblings <: SiblingLinks; end

SiblingLinks(::Type{Union{}}) = throw(not_supported_exc)
SiblingLinks(::Type) = ImplicitSiblings()
SiblingLinks(tree) = SiblingLinks(typeof(tree))

Expand Down Expand Up @@ -126,6 +129,7 @@ class of indexable trees consisting of arrays.
"""
struct NonIndexedChildren <: ChildIndexing end

ChildIndexing(::Type{Union{}}) = throw(not_supported_exc)
ChildIndexing(::Type) = NonIndexedChildren()
ChildIndexing(node) = ChildIndexing(typeof(node))

Expand All @@ -143,6 +147,7 @@ If the `childrentype` can be inferred from the type of the node alone, the type
**OPTIONAL**: In most cases, [`childtype`](@ref) is used instead. If `childtype` is not defined it will fall back
to `eltype ∘ childrentype`.
"""
childrentype(::Type{Union{}}) = throw(not_supported_exc)
childrentype(::Type{T}) where {T} = Base._return_type(children, Tuple{T})
childrentype(node) = typeof(children(node))

Expand All @@ -159,6 +164,7 @@ If `childtype` can be inferred from the type of the node alone, the type `::Type
can be type-stable. If `childrentype` is defined and can be known from the node type alone, this function will
fall back to `eltype(childrentype(T))`. If this gives a correct result it's not necessary to define `childtype`.
"""
childtype(::Type{Union{}}) = throw(not_supported_exc)
childtype(::Type{T}) where {T} = eltype(childrentype(T))
childtype(node) = eltype(childrentype(node))

Expand All @@ -172,6 +178,7 @@ traversal is type stable.

**OPTIONAL**: Type inference is used to attempt to
"""
childstatetype(::Type{Union{}}) = throw(not_supported_exc)
childstatetype(::Type{T}) where {T} = Iterators.approx_iter_type(childrentype(T))
childstatetype(node) = childstatetype(typeof(node))

Expand Down Expand Up @@ -204,6 +211,7 @@ type.
"""
struct NodeTypeUnknown <: NodeType end

NodeType(::Type{Union{}}) = throw(not_supported_exc)
NodeType(::Type) = NodeTypeUnknown()
NodeType(node) = NodeType(typeof(node))

Expand All @@ -214,5 +222,6 @@ NodeType(node) = NodeType(typeof(node))
Returns a type which must be a parent type of all nodes in the tree connected to `node`. This can be used to,
for example, specify the `eltype` of any `TreeIterator` on `node`.
"""
nodetype(::Type{Union{}}) = throw(not_supported_exc)
nodetype(::Type) = Any
nodetype(node) = nodetype(typeof(node))
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using AbstractTrees, Test
using Aqua

if Base.VERSION >= v"1.9"
# tests use `parentmodule(::Method)`, only supported on v1.9 and up
@testset "Traits" begin include("traits.jl") end
end
@testset "Builtins" begin include("builtins.jl") end
@testset "Custom tree types" begin include("trees.jl") end
if Base.VERSION >= v"1.6"
Expand Down
75 changes: 75 additions & 0 deletions test/traits.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
module TestTraits

using AbstractTrees
using Test

function is_owned_by(m::Module, n::Module)
ret = false
while m != Main
if m == n
ret = true
break
end
m = parentmodule(m)
end
ret
end

function is_owned_by(m::Method, n::Module)
is_owned_by(parentmodule(m), n)
end

function we_own_the_method(m::Method)
is_owned_by(m, AbstractTrees)
end

const traits = (
ParentLinks, SiblingLinks, ChildIndexing, childrentype, childtype, AbstractTrees.childstatetype, NodeType, nodetype,
)

const base_traits = (
eltype, Base.IteratorEltype,
)

struct T end

for func ∈ traits
f = nameof(func)
@eval begin
function AbstractTrees.$f(::Type{<:T})
# This method should not ever get called, it just serves to test dispatch/type piracy.
throw(ArgumentError("this is not the method you're looking for"))
end
end
end

for func ∈ base_traits
f = nameof(func)
@eval begin
function Base.$f(::Type{<:AbstractTrees.TreeIterator{<:T}})
# This method should not ever get called, it just serves to test dispatch/type piracy.
throw(ArgumentError("this is not the method you're looking for"))
end
end
end

@testset "Traits" begin
@testset "traits should not make dependents vulnerable to commiting type piracy" begin
@testset "AbstractTrees traits" begin
@testset "func: $func" for func ∈ traits
arg = Union{}
@test_throws Exception func(arg)
@test all(we_own_the_method, methods(func, Tuple{Type{arg}}))
end
end
@testset "Base traits" begin
@testset "func: $func" for func ∈ base_traits
arg = AbstractTrees.TreeIterator{Union{}}
@test_throws Exception func(arg)
@test all(we_own_the_method, methods(func, Tuple{Type{arg}}))
end
end
end
end

end
Loading