Skip to content

Commit e73dd0c

Browse files
gldubcjosevalim
authored andcommitted
Perf optimizations and inferred intersections (#14605)
1 parent f4037a3 commit e73dd0c

File tree

6 files changed

+181
-74
lines changed

6 files changed

+181
-74
lines changed

lib/elixir/lib/module/types/apply.ex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ defmodule Module.Types.Apply do
487487
{union(type, fun_from_non_overlapping_clauses(clauses)), fallback?, context}
488488

489489
{{:infer, _, clauses}, context} when length(clauses) <= @max_clauses ->
490-
{union(type, fun_from_overlapping_clauses(clauses)), fallback?, context}
490+
{union(type, fun_from_inferred_clauses(clauses)), fallback?, context}
491491

492492
{_, context} ->
493493
{type, true, context}
@@ -705,7 +705,7 @@ defmodule Module.Types.Apply do
705705
result =
706706
case info do
707707
{:infer, _, clauses} when length(clauses) <= @max_clauses ->
708-
fun_from_overlapping_clauses(clauses)
708+
fun_from_inferred_clauses(clauses)
709709

710710
_ ->
711711
dynamic(fun(arity))

lib/elixir/lib/module/types/descr.ex

Lines changed: 101 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ defmodule Module.Types.Descr do
4646
@not_non_empty_list Map.delete(@term, :list)
4747
@not_list Map.replace!(@not_non_empty_list, :bitmap, @bit_top - @bit_empty_list)
4848

49-
@empty_intersection [0, @none, []]
50-
@empty_difference [0, []]
49+
@empty_intersection [0, @none, [], :fun_bottom]
50+
@empty_difference [0, [], :fun_bottom]
5151

5252
defguard is_descr(descr) when is_map(descr) or descr == :term
5353

@@ -135,16 +135,17 @@ defmodule Module.Types.Descr do
135135
@doc """
136136
Creates a function from overlapping function clauses.
137137
"""
138-
def fun_from_overlapping_clauses(args_clauses) do
138+
def fun_from_inferred_clauses(args_clauses) do
139139
domain_clauses =
140140
Enum.reduce(args_clauses, [], fn {args, return}, acc ->
141-
pivot_overlapping_clause(args_to_domain(args), return, acc)
141+
domain = args |> Enum.map(&upper_bound/1) |> args_to_domain()
142+
pivot_overlapping_clause(domain, upper_bound(return), acc)
142143
end)
143144

144145
funs =
145146
for {domain, return} <- domain_clauses,
146147
args <- domain_to_args(domain),
147-
do: fun(args, return)
148+
do: fun(args, dynamic(return))
148149

149150
Enum.reduce(funs, &intersection/2)
150151
end
@@ -198,19 +199,19 @@ defmodule Module.Types.Descr do
198199
def domain_to_args(descr) do
199200
case :maps.take(:dynamic, descr) do
200201
:error ->
201-
tuple_elim_negations_static(descr, &Function.identity/1)
202+
unwrap_domain_tuple(descr, fn {:closed, elems} -> elems end)
202203

203204
{dynamic, static} ->
204-
tuple_elim_negations_static(static, &Function.identity/1) ++
205-
tuple_elim_negations_static(dynamic, fn elems -> Enum.map(elems, &dynamic/1) end)
205+
unwrap_domain_tuple(static, fn {:closed, elems} -> elems end) ++
206+
unwrap_domain_tuple(dynamic, fn {:closed, elems} -> Enum.map(elems, &dynamic/1) end)
206207
end
207208
end
208209

209-
defp tuple_elim_negations_static(%{tuple: dnf} = descr, transform) when map_size(descr) == 1 do
210-
Enum.map(dnf, fn {:closed, elements} -> transform.(elements) end)
210+
defp unwrap_domain_tuple(%{tuple: dnf} = descr, transform) when map_size(descr) == 1 do
211+
Enum.map(dnf, transform)
211212
end
212213

213-
defp tuple_elim_negations_static(descr, _transform) when descr == %{}, do: []
214+
defp unwrap_domain_tuple(descr, _transform) when descr == %{}, do: []
214215

215216
defp domain_to_flat_args(domain, arity) do
216217
case domain_to_args(domain) do
@@ -1173,6 +1174,7 @@ defmodule Module.Types.Descr do
11731174

11741175
static_arrows == [] ->
11751176
# TODO: We need to validate this within the theory
1177+
arguments = Enum.map(arguments, &upper_bound/1)
11761178
{:ok, dynamic(fun_apply_static(arguments, dynamic_arrows, false))}
11771179

11781180
true ->
@@ -1327,9 +1329,9 @@ defmodule Module.Types.Descr do
13271329
if subtype?(rets_reached, result), do: result, else: union(result, rets_reached)
13281330
end
13291331

1330-
defp aux_apply(result, input, returns_reached, [{dom, ret} | arrow_intersections]) do
1332+
defp aux_apply(result, input, returns_reached, [{args, ret} | arrow_intersections]) do
13311333
# Calculate the part of the input not covered by this arrow's domain
1332-
dom_subtract = difference(input, args_to_domain(dom))
1334+
dom_subtract = difference(input, args_to_domain(args))
13331335

13341336
# Refine the return type by intersecting with this arrow's return type
13351337
ret_refine = intersection(returns_reached, ret)
@@ -1426,7 +1428,7 @@ defmodule Module.Types.Descr do
14261428
# determines emptiness.
14271429
length(neg_arguments) == positive_arity and
14281430
subtype?(args_to_domain(neg_arguments), positive_domain) and
1429-
phi_starter(neg_arguments, negation(neg_return), positives)
1431+
phi_starter(neg_arguments, neg_return, positives)
14301432
end)
14311433
end
14321434
end
@@ -1464,27 +1466,75 @@ defmodule Module.Types.Descr do
14641466
#
14651467
# See [Castagna and Lanvin (2024)](https://arxiv.org/abs/2408.14345), Theorem 4.2.
14661468
defp phi_starter(arguments, return, positives) do
1467-
n = length(arguments)
1468-
# Arity mismatch: if there is one positive function with a different arity,
1469-
# then it cannot be a subtype of the (arguments->type) functions.
1470-
if Enum.any?(positives, fn {args, _ret} -> length(args) != n end) do
1471-
false
1469+
# Optimization: When all positive functions have non-empty domains,
1470+
# we can simplify the phi function check to a direct subtyping test.
1471+
# This avoids the expensive recursive phi computation by checking only that applying the
1472+
# input to the positive intersection yields a subtype of the return
1473+
if all_non_empty_domains?([{arguments, return} | positives]) do
1474+
fun_apply_static(arguments, [positives], false)
1475+
|> subtype?(return)
14721476
else
1473-
arguments = Enum.map(arguments, &{false, &1})
1474-
phi(arguments, {false, return}, positives)
1477+
n = length(arguments)
1478+
# Arity mismatch: functions with different arities cannot be subtypes
1479+
# of the target function type (arguments -> return)
1480+
if Enum.any?(positives, fn {args, _ret} -> length(args) != n end) do
1481+
false
1482+
else
1483+
# Initialize memoization cache for the recursive phi computation
1484+
arguments = Enum.map(arguments, &{false, &1})
1485+
{result, _cache} = phi(arguments, {false, negation(return)}, positives, %{})
1486+
result
1487+
end
14751488
end
14761489
end
14771490

1478-
defp phi(args, {b, t}, []) do
1479-
Enum.any?(args, fn {bool, typ} -> bool and empty?(typ) end) or (b and empty?(t))
1491+
defp phi(args, {b, t}, [], cache) do
1492+
{Enum.any?(args, fn {bool, typ} -> bool and empty?(typ) end) or (b and empty?(t)), cache}
14801493
end
14811494

1482-
defp phi(args, {b, ret}, [{arguments, return} | rest_positive]) do
1483-
phi(args, {true, intersection(ret, return)}, rest_positive) and
1484-
Enum.all?(Enum.with_index(arguments), fn {type, index} ->
1485-
List.update_at(args, index, fn {_, arg} -> {true, difference(arg, type)} end)
1486-
|> phi({b, ret}, rest_positive)
1487-
end)
1495+
defp phi(args, {b, ret}, [{arguments, return} | rest_positive], cache) do
1496+
# Create cache key from function arguments
1497+
cache_key = {args, {b, ret}, [{arguments, return} | rest_positive]}
1498+
1499+
case Map.get(cache, cache_key) do
1500+
nil ->
1501+
# Compute result and cache it
1502+
{result1, cache} = phi(args, {true, intersection(ret, return)}, rest_positive, cache)
1503+
1504+
if not result1 do
1505+
# Store false result in cache
1506+
cache = Map.put(cache, cache_key, false)
1507+
{false, cache}
1508+
else
1509+
# This doesn't stop if one intermediate result is false?
1510+
{result2, cache} =
1511+
Enum.with_index(arguments)
1512+
|> Enum.reduce_while({true, cache}, fn {type, index}, {acc_result, acc_cache} ->
1513+
{new_result, new_cache} =
1514+
List.update_at(args, index, fn {_, arg} -> {true, difference(arg, type)} end)
1515+
|> phi({b, ret}, rest_positive, acc_cache)
1516+
1517+
if new_result do
1518+
{:cont, {acc_result and new_result, new_cache}}
1519+
else
1520+
{:halt, {false, new_cache}}
1521+
end
1522+
end)
1523+
1524+
result = result1 and result2
1525+
# Store result in cache
1526+
cache = Map.put(cache, cache_key, result)
1527+
{result, cache}
1528+
end
1529+
1530+
cached_result ->
1531+
# Return cached result
1532+
{cached_result, cache}
1533+
end
1534+
end
1535+
1536+
defp all_non_empty_domains?(positives) do
1537+
Enum.all?(positives, fn {args, _ret} -> not empty?(args_to_domain(args)) end)
14881538
end
14891539

14901540
defp fun_union(bdd1, bdd2) do
@@ -1831,6 +1881,10 @@ defmodule Module.Types.Descr do
18311881
# b) If only the last type differs, subtracts it
18321882
# 3. Base case: adds dnf2 type to negations of dnf1 type
18331883
# The result may be larger than the initial dnf1, which is maintained in the accumulator.
1884+
defp list_difference(_, dnf) when dnf == @non_empty_list_top do
1885+
0
1886+
end
1887+
18341888
defp list_difference(dnf1, dnf2) do
18351889
Enum.reduce(dnf2, dnf1, fn {t2, last2, negs2}, acc_dnf1 ->
18361890
last2 = list_tail_unfold(last2)
@@ -1858,6 +1912,8 @@ defmodule Module.Types.Descr do
18581912
end)
18591913
end
18601914

1915+
defp list_empty?(@non_empty_list_top), do: false
1916+
18611917
defp list_empty?(dnf) do
18621918
Enum.all?(dnf, fn {list_type, last_type, negs} ->
18631919
last_type = list_tail_unfold(last_type)
@@ -2118,9 +2174,6 @@ defmodule Module.Types.Descr do
21182174

21192175
defp dynamic_to_quoted(descr, opts) do
21202176
cond do
2121-
descr == %{} ->
2122-
[]
2123-
21242177
# We check for :term literally instead of using term_type?
21252178
# because we check for term_type? in to_quoted before we
21262179
# compute the difference(dynamic, static).
@@ -2130,6 +2183,9 @@ defmodule Module.Types.Descr do
21302183
single = indivisible_bitmap(descr, opts) ->
21312184
[single]
21322185

2186+
empty?(descr) ->
2187+
[]
2188+
21332189
true ->
21342190
case non_term_type_to_quoted(descr, opts) do
21352191
{:none, _meta, []} = none -> [none]
@@ -2398,6 +2454,10 @@ defmodule Module.Types.Descr do
23982454
if empty?(type), do: throw(:empty), else: type
23992455
end
24002456

2457+
defp map_difference(_, dnf) when dnf == @map_top do
2458+
0
2459+
end
2460+
24012461
defp map_difference(dnf1, dnf2) do
24022462
Enum.reduce(dnf2, dnf1, fn
24032463
# Optimization: we are removing an open map with one field.
@@ -3048,10 +3108,15 @@ defmodule Module.Types.Descr do
30483108
zip_non_empty_intersection!(rest1, rest2, [non_empty_intersection!(type1, type2) | acc])
30493109
end
30503110

3111+
defp tuple_difference(_, dnf) when dnf == @tuple_top do
3112+
0
3113+
end
3114+
30513115
defp tuple_difference(dnf1, dnf2) do
30523116
Enum.reduce(dnf2, dnf1, fn {tag2, elements2}, dnf1 ->
30533117
Enum.reduce(dnf1, [], fn {tag1, elements1}, acc ->
3054-
tuple_eliminate_single_negation(tag1, elements1, {tag2, elements2}) ++ acc
3118+
tuple_eliminate_single_negation(tag1, elements1, {tag2, elements2})
3119+
|> tuple_union(acc)
30553120
end)
30563121
end)
30573122
end
@@ -3066,8 +3131,10 @@ defmodule Module.Types.Descr do
30663131
if (tag == :closed and n < m) or (neg_tag == :closed and n > m) do
30673132
[{tag, elements}]
30683133
else
3069-
tuple_elim_content([], tag, elements, neg_elements) ++
3134+
tuple_union(
3135+
tuple_elim_content([], tag, elements, neg_elements),
30703136
tuple_elim_size(n, m, tag, elements, neg_tag)
3137+
)
30713138
end
30723139
end
30733140

lib/elixir/lib/module/types/expr.ex

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ defmodule Module.Types.Expr do
355355
add_inferred(acc, args, body)
356356
end)
357357

358-
{fun_from_overlapping_clauses(acc), context}
358+
{fun_from_inferred_clauses(acc), context}
359359
end
360360
end
361361

@@ -476,7 +476,11 @@ defmodule Module.Types.Expr do
476476
{args_types, context} =
477477
Enum.map_reduce(args, context, &of_expr(&1, @pending, &1, stack, &2))
478478

479-
Apply.fun_apply(fun_type, args_types, call, stack, context)
479+
if stack.mode == :traversal do
480+
{dynamic(), context}
481+
else
482+
Apply.fun_apply(fun_type, args_types, call, stack, context)
483+
end
480484
end
481485

482486
def of_expr({{:., _, [callee, key_or_fun]}, meta, []} = call, expected, expr, stack, context)

0 commit comments

Comments
 (0)