Skip to content

Commit 6180ca0

Browse files
authored
De-duplicate edges in typeinfer instead of gf.c (#58117)
Without this PR, my system spends ~1.07 seconds just running `store_backedges` when doing `using CairoMakie`. With this change, that drops to 0.641 seconds. That's still not fast enough for me, but we do call this function 236,092 times, so maybe it's understandable.
2 parents 5af63ff + 7525568 commit 6180ca0

File tree

4 files changed

+74
-68
lines changed

4 files changed

+74
-68
lines changed

Compiler/src/typeinfer.jl

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -687,12 +687,17 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter, cycleid::
687687
nothing
688688
end
689689

690-
# record the backedges
691-
function store_backedges(caller::CodeInstance, edges::SimpleVector)
692-
isa(caller.def.def, Method) || return # don't add backedges to toplevel method instance
693-
i = 1
694-
while true
695-
i > length(edges) && return nothing
690+
# Iterate a series of back-edges that need registering, based on the provided forward edge list.
691+
# Back-edges are returned as (invokesig, item), where the item is a Binding, MethodInstance, or
692+
# MethodTable.
693+
struct ForwardToBackedgeIterator
694+
forward_edges::SimpleVector
695+
end
696+
697+
function Base.iterate(it::ForwardToBackedgeIterator, i::Int = 1)
698+
edges = it.forward_edges
699+
i > length(edges) && return nothing
700+
    while i ≤ length(edges)
696701
item = edges[i]
697702
if item isa Int
698703
i += 2
@@ -702,34 +707,55 @@ function store_backedges(caller::CodeInstance, edges::SimpleVector)
702707
i += 1
703708
continue
704709
elseif isa(item, Core.Binding)
705-
i += 1
706-
maybe_add_binding_backedge!(item, caller)
707-
continue
710+
return ((nothing, item), i + 1)
708711
end
709712
if isa(item, CodeInstance)
710-
item = item.def
711-
end
712-
if isa(item, MethodInstance) # regular dispatch
713-
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), item, nothing, caller)
714-
i += 1
713+
item = get_ci_mi(item)
714+
return ((nothing, item), i + 1)
715+
elseif isa(item, MethodInstance) # regular dispatch
716+
return ((nothing, item), i + 1)
715717
else
718+
invokesig = item
716719
callee = edges[i+1]
717-
if isa(callee, MethodTable) # abstract dispatch (legacy style edges)
718-
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), callee, item, caller)
719-
i += 2
720-
continue
721-
elseif isa(callee, Method)
722-
# ignore `Method`-edges (from e.g. failed `abstract_call_method`)
723-
i += 2
724-
continue
725-
# `invoke` edge
726-
elseif isa(callee, CodeInstance)
727-
callee = get_ci_mi(callee)
720+
isa(callee, Method) && (i += 2; continue) # ignore `Method`-edges (from e.g. failed `abstract_call_method`)
721+
if isa(callee, MethodTable)
722+
# abstract dispatch (legacy style edges)
723+
return ((invokesig, callee), i + 2)
728724
else
729-
callee = callee::MethodInstance
725+
# `invoke` edge
726+
callee = isa(callee, CodeInstance) ? get_ci_mi(callee) : callee::MethodInstance
727+
return ((invokesig, callee), i + 2)
728+
end
729+
end
730+
end
731+
return nothing
732+
end
733+
734+
# record the backedges
735+
function store_backedges(caller::CodeInstance, edges::SimpleVector)
736+
isa(caller.def.def, Method) || return # don't add backedges to toplevel method instance
737+
738+
backedges = ForwardToBackedgeIterator(edges)
739+
for (i, (invokesig, item)) in enumerate(backedges)
740+
# check for any duplicate edges we've already registered
741+
duplicate_found = false
742+
for (i′, (invokesig′, item′)) in enumerate(backedges)
743+
i == i′ && break
744+
if item′ === item && invokesig′ == invokesig
745+
duplicate_found = true
746+
break
747+
end
748+
end
749+
750+
if !duplicate_found
751+
if item isa Core.Binding
752+
maybe_add_binding_backedge!(item, caller)
753+
elseif item isa MethodTable
754+
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), item, invokesig, caller)
755+
else
756+
item::MethodInstance
757+
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), item, invokesig, caller)
730758
end
731-
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, item, caller)
732-
i += 2
733759
end
734760
end
735761
nothing

Compiler/test/invalidation.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -177,25 +177,24 @@ begin
177177
end
178178

179179
# Verify that adding the backedge again does not actually add a new backedge
180-
let mi1 = Base.method_instance(deduped_caller1, (Int,)),
181-
mi2 = Base.method_instance(deduped_caller2, (Int,)),
182-
ci1 = mi1.cache
183-
ci2 = mi2.cache
180+
let mi = Base.method_instance(deduped_caller1, (Int,)),
181+
ci = mi.cache
184182

185183
callee_mi = Base.method_instance(deduped_callee, (Int,))
186184

187185
# Inference should have added the callers to the callee's backedges
188-
@test ci1 in callee_mi.backedges
189-
@test ci2 in callee_mi.backedges
186+
@test ci in callee_mi.backedges
190187

188+
# In practice, inference will never end up calling `store_backedges`
189+
# twice on the same CodeInstance like this - we only need to check
190+
# that de-duplication works for a single invocation
191191
N = length(callee_mi.backedges)
192-
Core.Compiler.store_backedges(ci1, Core.svec(callee_mi))
193-
Core.Compiler.store_backedges(ci2, Core.svec(callee_mi))
192+
Core.Compiler.store_backedges(ci, Core.svec(callee_mi, callee_mi))
194193
N′ = length(callee_mi.backedges)
195194

196-
# The number of backedges should not be affected by an additional store,
197-
# since de-duplication should have noticed the edge is already tracked
198-
@test N == N′
195+
# A single `store_backedges` invocation should de-duplicate any of the
196+
# edges it is adding.
197+
    @test N′ - N == 1
199198
end
200199
end
201200

src/gf.c

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1997,7 +1997,6 @@ JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee,
19971997
assert(invokesig == NULL || jl_is_type(invokesig));
19981998
JL_LOCK(&callee->def.method->writelock);
19991999
if (jl_atomic_load_relaxed(&allow_new_worlds)) {
2000-
int found = 0;
20012000
jl_array_t *backedges = jl_mi_get_backedges(callee);
20022001
// TODO: use jl_cache_type_(invokesig) like cache_method does to save memory
20032002
if (!backedges) {
@@ -2006,25 +2005,7 @@ JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee,
20062005
callee->backedges = backedges;
20072006
jl_gc_wb(callee, backedges);
20082007
}
2009-
else {
2010-
size_t i = 0, l = jl_array_nrows(backedges);
2011-
for (i = 0; i < l; i++) {
2012-
// optimized version of while (i < l) i = get_next_edge(callee->backedges, i, &invokeTypes, &mi);
2013-
jl_value_t *ciedge = jl_array_ptr_ref(backedges, i);
2014-
if (ciedge != (jl_value_t*)caller)
2015-
continue;
2016-
jl_value_t *invokeTypes = i > 0 ? jl_array_ptr_ref(backedges, i - 1) : NULL;
2017-
if (invokeTypes && jl_is_code_instance(invokeTypes))
2018-
invokeTypes = NULL;
2019-
if ((invokesig == NULL && invokeTypes == NULL) ||
2020-
(invokesig && invokeTypes && jl_types_equal(invokesig, invokeTypes))) {
2021-
found = 1;
2022-
break;
2023-
}
2024-
}
2025-
}
2026-
if (!found)
2027-
push_edge(backedges, invokesig, caller);
2008+
push_edge(backedges, invokesig, caller);
20282009
}
20292010
JL_UNLOCK(&callee->def.method->writelock);
20302011
}
@@ -2047,14 +2028,6 @@ JL_DLLEXPORT void jl_method_table_add_backedge(jl_methtable_t *mt, jl_value_t *t
20472028
else {
20482029
// check if the edge is already present and avoid adding a duplicate
20492030
size_t i, l = jl_array_nrows(mt->backedges);
2050-
for (i = 1; i < l; i += 2) {
2051-
if (jl_array_ptr_ref(mt->backedges, i) == (jl_value_t*)caller) {
2052-
if (jl_types_equal(jl_array_ptr_ref(mt->backedges, i - 1), typ)) {
2053-
JL_UNLOCK(&mt->writelock);
2054-
return;
2055-
}
2056-
}
2057-
}
20582031
// reuse an already cached instance of this type, if possible
20592032
// TODO: use jl_cache_type_(tt) like cache_method does, instead of this linear scan?
20602033
for (i = 1; i < l; i += 2) {

src/staticdata.c

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,12 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
785785
}
786786
goto done_fields; // for now
787787
}
788+
if (s->incremental && jl_is_mtable(v)) {
789+
jl_methtable_t *mt = (jl_methtable_t *)v;
790+
// Any back-edges will be re-validated and added by staticdata.jl, so
791+
// drop them from the image here
792+
record_field_change((jl_value_t**)&mt->backedges, NULL);
793+
}
788794
if (jl_is_method_instance(v)) {
789795
jl_method_instance_t *mi = (jl_method_instance_t*)v;
790796
if (s->incremental) {
@@ -800,12 +806,14 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
800806
// we only need 3 specific fields of this (the rest are restored afterward, if valid)
801807
// in particular, cache is repopulated by jl_mi_cache_insert for all foreign function,
802808
// so must not be present here
803-
record_field_change((jl_value_t**)&mi->backedges, NULL);
804809
record_field_change((jl_value_t**)&mi->cache, NULL);
805810
}
806811
else {
807812
assert(!needs_recaching(v, s->query_cache));
808813
}
814+
// Any back-edges will be re-validated and added by staticdata.jl, so
815+
// drop them from the image here
816+
record_field_change((jl_value_t**)&mi->backedges, NULL);
809817
// n.b. opaque closures cannot be inspected and relied upon like a
810818
// normal method since they can get improperly introduced by generated
811819
// functions, so if they appeared at all, we will probably serialize

0 commit comments

Comments
 (0)