Skip to content

Commit ecd0e9b

Browse files
committed
guard against trait type piracy in a dependent package
Prevent dependent packages from commiting type piracy easily and unintentionally when defining an AbstractTrees trait for all subtypes of a type. Specifically: the bottom type, `Union{}`, is a subtype of each type, so add a method for the bottom type for each AbstractTrees trait. This method will take precedence over sane methods in dependent packages, thus preventing spurious type piracy.
1 parent 4fe5615 commit ecd0e9b

File tree

4 files changed

+93
-0
lines changed

4 files changed

+93
-0
lines changed

src/iteration.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,10 @@ abstract type TreeIterator{T} end
6363
_iterator_eltype(::NodeTypeUnknown) = EltypeUnknown()
6464
_iterator_eltype(::HasNodeType) = HasEltype()
6565

66+
Base.IteratorEltype(::Type{<:TreeIterator{Union{}}}) = throw_not_supported_exc() # prevent type piracy in dependent package
6667
Base.IteratorEltype(::Type{<:TreeIterator{T}}) where {T} = _iterator_eltype(NodeType(T))
6768

69+
Base.eltype(::Type{<:TreeIterator{Union{}}}) = throw_not_supported_exc() # prevent type piracy in dependent package
6870
Base.eltype(::Type{<:TreeIterator{T}}) where {T} = nodetype(T)
6971
Base.eltype(ti::TreeIterator) = eltype(typeof(ti))
7072

src/traits.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
function throw_not_supported_exc()
2+
@noinline
3+
throw(ArgumentError("not supported"))
4+
end
15

26
"""
37
ParentLinks(::Type{T})
@@ -42,6 +46,7 @@ the tree structure and cannot be inferred through a single node.
4246
"""
4347
struct ImplicitParents <: ParentLinks; end
4448

49+
ParentLinks(::Type{Union{}}) = throw_not_supported_exc() # prevent type piracy in dependent package
4550
ParentLinks(::Type) = ImplicitParents()
4651
ParentLinks(tree) = ParentLinks(typeof(tree))
4752

@@ -84,6 +89,7 @@ from the tree structure.
8489
"""
8590
struct ImplicitSiblings <: SiblingLinks; end
8691

92+
SiblingLinks(::Type{Union{}}) = throw_not_supported_exc() # prevent type piracy in dependent package
8793
SiblingLinks(::Type) = ImplicitSiblings()
8894
SiblingLinks(tree) = SiblingLinks(typeof(tree))
8995

@@ -126,6 +132,7 @@ class of indexable trees consisting of arrays.
126132
"""
127133
struct NonIndexedChildren <: ChildIndexing end
128134

135+
ChildIndexing(::Type{Union{}}) = throw_not_supported_exc() # prevent type piracy in dependent package
129136
ChildIndexing(::Type) = NonIndexedChildren()
130137
ChildIndexing(node) = ChildIndexing(typeof(node))
131138

@@ -143,6 +150,7 @@ If the `childrentype` can be inferred from the type of the node alone, the type
143150
**OPTIONAL**: In most cases, [`childtype`](@ref) is used instead. If `childtype` is not defined it will fall back
144151
to `eltype ∘ childrentype`.
145152
"""
153+
childrentype(::Type{Union{}}) = throw_not_supported_exc() # prevent type piracy in dependent package
146154
childrentype(::Type{T}) where {T} = Base._return_type(children, Tuple{T})
147155
childrentype(node) = typeof(children(node))
148156

@@ -159,6 +167,7 @@ If `childtype` can be inferred from the type of the node alone, the type `::Type
159167
can be type-stable. If `childrentype` is defined and can be known from the node type alone, this function will
160168
fall back to `eltype(childrentype(T))`. If this gives a correct result it's not necessary to define `childtype`.
161169
"""
170+
childtype(::Type{Union{}}) = throw_not_supported_exc() # prevent type piracy in dependent package
162171
childtype(::Type{T}) where {T} = eltype(childrentype(T))
163172
childtype(node) = eltype(childrentype(node))
164173

@@ -172,6 +181,7 @@ traversal is type stable.
172181
173182
**OPTIONAL**: Type inference is used to attempt to
174183
"""
184+
childstatetype(::Type{Union{}}) = throw_not_supported_exc() # prevent type piracy in dependent package
175185
childstatetype(::Type{T}) where {T} = Iterators.approx_iter_type(childrentype(T))
176186
childstatetype(node) = childstatetype(typeof(node))
177187

@@ -204,6 +214,7 @@ type.
204214
"""
205215
struct NodeTypeUnknown <: NodeType end
206216

217+
NodeType(::Type{Union{}}) = throw_not_supported_exc() # prevent type piracy in dependent package
207218
NodeType(::Type) = NodeTypeUnknown()
208219
NodeType(node) = NodeType(typeof(node))
209220

@@ -214,5 +225,6 @@ NodeType(node) = NodeType(typeof(node))
214225
Returns a type which must be a parent type of all nodes in the tree connected to `node`. This can be used to,
215226
for example, specify the `eltype` of any `TreeIterator` on `node`.
216227
"""
228+
nodetype(::Type{Union{}}) = throw_not_supported_exc() # prevent type piracy in dependent package
217229
nodetype(::Type) = Any
218230
nodetype(node) = nodetype(typeof(node))

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
using AbstractTrees, Test
22
using Aqua
33

4+
if applicable(parentmodule, which(sin, Tuple{Float64})::Method)
5+
# tests use `parentmodule(::Method)`, only supported on v1.10 and up
6+
@testset "Traits" begin include("traits.jl") end
7+
end
48
@testset "Builtins" begin include("builtins.jl") end
59
@testset "Custom tree types" begin include("trees.jl") end
610
if Base.VERSION >= v"1.6"

test/traits.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
module TestTraits
2+
3+
using AbstractTrees
4+
using Test
5+
6+
function is_owned_by(m::Module, n::Module)
7+
ret = false
8+
while m != Main
9+
if m == n
10+
ret = true
11+
break
12+
end
13+
m = parentmodule(m)
14+
end
15+
ret
16+
end
17+
18+
function is_owned_by(m::Method, n::Module)
19+
is_owned_by(parentmodule(m), n)
20+
end
21+
22+
function we_own_the_method(m::Method)
23+
is_owned_by(m, AbstractTrees)
24+
end
25+
26+
const traits = (
27+
ParentLinks, SiblingLinks, ChildIndexing, childrentype, childtype, AbstractTrees.childstatetype, NodeType, nodetype,
28+
)
29+
30+
const base_traits = (
31+
eltype, Base.IteratorEltype,
32+
)
33+
34+
struct T end
35+
36+
for func traits
37+
f = nameof(func)
38+
@eval begin
39+
function AbstractTrees.$f(::Type{<:T})
40+
# This method should not ever get called, it just serves to test dispatch/type piracy.
41+
throw(ArgumentError("this is not the method you're looking for"))
42+
end
43+
end
44+
end
45+
46+
for func base_traits
47+
f = nameof(func)
48+
@eval begin
49+
function Base.$f(::Type{<:AbstractTrees.TreeIterator{<:T}})
50+
# This method should not ever get called, it just serves to test dispatch/type piracy.
51+
throw(ArgumentError("this is not the method you're looking for"))
52+
end
53+
end
54+
end
55+
56+
@testset "Traits" begin
57+
@testset "traits should not make dependents vulnerable to commiting type piracy" begin
58+
@testset "AbstractTrees traits" begin
59+
@testset "func: $func" for func traits
60+
arg = Union{}
61+
@test_throws Exception func(arg)
62+
@test all(we_own_the_method, methods(func, Tuple{Type{arg}}))
63+
end
64+
end
65+
@testset "Base traits" begin
66+
@testset "func: $func" for func base_traits
67+
arg = AbstractTrees.TreeIterator{Union{}}
68+
@test_throws Exception func(arg)
69+
@test all(we_own_the_method, methods(func, Tuple{Type{arg}}))
70+
end
71+
end
72+
end
73+
end
74+
75+
end

0 commit comments

Comments
 (0)