Skip to content

Commit 8b9901b

Browse files
Match on enum arm collection using tree structure
This is a preparation step to supporting enum matching on inner enum variants. To support multi level matching we need to address different enum variants will have different parameters, therefor a tree structure, EnumPaths, is introduced. Coverage of all variants is done by checking if all tree leaves are full and not empty, where empty cases are missing subvariants. commit-id:978a80be
1 parent 64b88f0 commit 8b9901b

File tree

2 files changed

+526
-45
lines changed

2 files changed

+526
-45
lines changed

crates/cairo-lang-lowering/src/lower/lower_match.rs

+249-44
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ use cairo_lang_debug::DebugWithDb;
22
use cairo_lang_defs::ids::NamedLanguageElementId;
33
use cairo_lang_filesystem::flag::Flag;
44
use cairo_lang_filesystem::ids::FlagId;
5-
use cairo_lang_semantic::{self as semantic, GenericArgumentId, corelib};
5+
use cairo_lang_semantic::{self as semantic, ConcreteVariant, GenericArgumentId, corelib};
66
use cairo_lang_syntax::node::TypedStablePtr;
77
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
8+
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
89
use cairo_lang_utils::unordered_hash_map::{Entry, UnorderedHashMap};
910
use cairo_lang_utils::{LookupIntern, try_extract_matches};
1011
use itertools::{Itertools, zip_eq};
@@ -26,7 +27,7 @@ use super::{
2627
alloc_empty_block, generators, lower_expr_block, lower_expr_literal, lower_tail_expr,
2728
lowered_expr_to_block_scope_end, recursively_call_loop_func,
2829
};
29-
use crate::diagnostic::LoweringDiagnosticKind::*;
30+
use crate::diagnostic::LoweringDiagnosticKind::{self, *};
3031
use crate::diagnostic::{LoweringDiagnosticsBuilder, MatchDiagnostic, MatchError, MatchKind};
3132
use crate::ids::{LocationId, SemanticFunctionIdEx};
3233
use crate::lower::context::VarRequest;
@@ -116,7 +117,7 @@ fn extract_concrete_enum_tuple(
116117
}
117118

118119
/// The arm and pattern indices of a pattern in a match arm with an or list.
119-
#[derive(Debug, Clone)]
120+
#[derive(Debug, Clone, Copy)]
120121
struct PatternPath {
121122
arm_index: usize,
122123
pattern_index: Option<usize>,
@@ -188,61 +189,256 @@ fn get_underscore_pattern_path_and_mark_unreachable(
188189
Some(otherwise_variant)
189190
}
190191

191-
/// Returns a map from variants to their corresponding pattern path in a match statement.
192-
fn get_variant_to_arm_map<'a>(
193-
ctx: &mut LoweringContext<'_, '_>,
194-
arms: impl Iterator<Item = &'a MatchArmWrapper>,
195-
concrete_enum_id: semantic::ConcreteEnumId,
196-
match_type: MatchKind,
197-
) -> LoweringResult<UnorderedHashMap<semantic::ConcreteVariant, PatternPath>> {
198-
let mut map = UnorderedHashMap::default();
199-
for (arm_index, arm) in arms.enumerate() {
200-
for (pattern_index, pattern) in arm.patterns.iter().enumerate() {
201-
let pattern = &ctx.function_body.arenas.patterns[*pattern];
192+
/// A sparse tree that records which enum‑variant paths are *already*
193+
/// covered by user code during `match` lowering.
194+
///
195+
/// Each node captures the coverage state for a single variant path:
196+
/// [`VariantMatchTree::Mapping`] The current variant is itself an enum; a `Vec` entry is kept for
197+
/// every child variant match tree.
198+
/// ['VariantMatchTree::Full'] A concrete pattern fully covers this path. Additional patterns
199+
/// reaching here are unreachable (even if current variant is itself an enum, subtrees are
200+
/// irrelevant). [`VariantMatchTree::Empty`] No pattern has touched this path yet. Useful to emit a
201+
/// `MissingMatchArm` diagnostic later on.
202+
#[derive(Debug, Clone)]
203+
enum VariantMatchTree {
204+
/// Mapping of enum variant id(x) to sub [VariantMatchTree].
205+
Mapping(Vec<VariantMatchTree>),
206+
/// Path is covered by a pattern.
207+
Full(PatternId, PatternPath),
208+
/// No pattern has been added to this path.
209+
Empty,
210+
}
202211

203-
if let semantic::Pattern::Otherwise(_) = pattern {
204-
continue;
212+
impl VariantMatchTree {
213+
/// Pushes a pattern to the enum paths. Fails if the pattern is unreachable.
214+
fn push_pattern_path(
215+
&mut self,
216+
ptrn_id: PatternId,
217+
pattern_path: PatternPath,
218+
) -> Result<(), LoweringDiagnosticKind> {
219+
match self {
220+
VariantMatchTree::Empty => {
221+
*self = VariantMatchTree::Full(ptrn_id, pattern_path);
222+
Ok(())
223+
}
224+
VariantMatchTree::Full(_, _) => Err(MatchError(MatchError {
225+
kind: MatchKind::Match,
226+
error: MatchDiagnostic::UnreachableMatchArm,
227+
})),
228+
VariantMatchTree::Mapping(mapping) => {
229+
// Need at least one empty path, but should write to all (pattern covers multiple
230+
// paths).
231+
let mut any_ok = false;
232+
for path in mapping.iter_mut() {
233+
if path.push_pattern_path(ptrn_id, pattern_path).is_ok() {
234+
any_ok = true;
235+
}
236+
}
237+
if any_ok {
238+
Ok(())
239+
} else {
240+
Err(MatchError(MatchError {
241+
kind: MatchKind::Match,
242+
error: MatchDiagnostic::UnreachableMatchArm,
243+
}))
244+
}
205245
}
246+
}
247+
}
206248

207-
let pat_stable_ptr = pattern.stable_ptr();
249+
/// Utility to collect every [`PatternId`] found in `Full` leaves into `leaves`.
250+
fn collect_leaves(&self, leaves: &mut OrderedHashSet<PatternId>) {
251+
match self {
252+
VariantMatchTree::Empty => {}
253+
VariantMatchTree::Full(ptrn_id, _) => {
254+
leaves.insert(*ptrn_id);
255+
}
256+
VariantMatchTree::Mapping(mapping) => {
257+
for path in mapping.iter() {
258+
path.collect_leaves(leaves);
259+
}
260+
}
261+
}
262+
}
208263

209-
let Some(enum_pattern) = try_extract_matches!(pattern, semantic::Pattern::EnumVariant)
210-
else {
211-
return Err(LoweringFlowError::Failed(ctx.diagnostics.report(
212-
pat_stable_ptr,
213-
MatchError(MatchError {
214-
kind: match_type,
215-
error: MatchDiagnostic::UnsupportedMatchArmNotAVariant,
216-
}),
217-
)));
218-
};
264+
/// Fails on missing enum in db.
265+
/// Returns None if path is full otherwise reference the [VariantMatchTree] of appropriate
266+
/// variant.
267+
fn get_mapping_or_insert<'a>(
268+
&'a mut self,
269+
ctx: &LoweringContext<'_, '_>,
270+
concrete_variant: ConcreteVariant,
271+
) -> cairo_lang_diagnostics::Maybe<Option<&'a mut Self>> {
272+
match self {
273+
VariantMatchTree::Empty => {
274+
let variant_count =
275+
ctx.db.concrete_enum_variants(concrete_variant.concrete_enum_id)?.len();
276+
*self = VariantMatchTree::Mapping(vec![VariantMatchTree::Empty; variant_count]);
277+
if let VariantMatchTree::Mapping(items) = self {
278+
Ok(Some(&mut items[concrete_variant.idx]))
279+
} else {
280+
unreachable!("We just created the mapping.")
281+
}
282+
}
283+
VariantMatchTree::Full(_, _) => Ok(None),
284+
VariantMatchTree::Mapping(items) => Ok(Some(&mut items[concrete_variant.idx])),
285+
}
286+
}
287+
}
219288

220-
if enum_pattern.variant.concrete_enum_id != concrete_enum_id {
289+
/// Returns a map from variants to their corresponding pattern path in a match statement.
290+
fn get_variant_to_arm_map(
291+
ctx: &mut LoweringContext<'_, '_>,
292+
arms: &[MatchArmWrapper],
293+
concrete_enum_id: semantic::ConcreteEnumId,
294+
match_type: MatchKind,
295+
) -> LoweringResult<UnorderedHashMap<semantic::ConcreteVariant, PatternPath>> {
296+
// I could have Option<Option<Option<.....>>> but only match the first one.
297+
// This can be generalized to a tree of enums A::A(C::C) does not need to check recursion of
298+
// A::B(_). Check pattern is legal and collect paths.
299+
let mut variant_map = VariantMatchTree::Empty;
300+
for (arm_index, arm) in arms.iter().enumerate() {
301+
for (pattern_index, mut pattern) in arm.patterns.iter().copied().enumerate() {
302+
let pattern_path = PatternPath { arm_index, pattern_index: Some(pattern_index) };
303+
let pattern_ptr = ctx.function_body.arenas.patterns[pattern].stable_ptr();
304+
305+
let mut variant_map = &mut variant_map;
306+
let mut concrete_enum_id = concrete_enum_id;
307+
if !(matches_enum(ctx, pattern) | matches_other(ctx, pattern)) {
221308
return Err(LoweringFlowError::Failed(ctx.diagnostics.report(
222-
pat_stable_ptr,
309+
pattern_ptr,
223310
MatchError(MatchError {
224311
kind: match_type,
225312
error: MatchDiagnostic::UnsupportedMatchArmNotAVariant,
226313
}),
227314
)));
228315
}
316+
loop {
317+
match &ctx.function_body.arenas.patterns[pattern] {
318+
semantic::Pattern::Otherwise(_) => {
319+
// Fill leaves and check for usefulness.
320+
let _ = variant_map.push_pattern_path(pattern, pattern_path);
321+
// TODO(eytan-starkware) Check result and report warning if unreachable.
322+
break;
323+
}
324+
semantic::Pattern::EnumVariant(enum_pattern) => {
325+
if concrete_enum_id != enum_pattern.variant.concrete_enum_id {
326+
return Err(LoweringFlowError::Failed(ctx.diagnostics.report(
327+
pattern_ptr,
328+
MatchError(MatchError {
329+
kind: match_type,
330+
error: MatchDiagnostic::UnsupportedMatchArmNotAVariant,
331+
}),
332+
)));
333+
}
334+
// Expand paths in map to include all variants of this enum_pattern.
335+
if let Some(vmap) = variant_map
336+
.get_mapping_or_insert(ctx, enum_pattern.variant.clone())
337+
.map_err(LoweringFlowError::Failed)?
338+
{
339+
variant_map = vmap;
340+
} else {
341+
ctx.diagnostics.report(
342+
pattern_ptr,
343+
MatchError(MatchError {
344+
kind: match_type,
345+
error: MatchDiagnostic::UnreachableMatchArm,
346+
}),
347+
);
348+
break;
349+
}
229350

230-
match map.entry(enum_pattern.variant.clone()) {
231-
Entry::Occupied(_) => {
232-
ctx.diagnostics.report(
233-
pat_stable_ptr,
234-
MatchError(MatchError {
235-
kind: match_type,
236-
error: MatchDiagnostic::UnreachableMatchArm,
237-
}),
238-
);
351+
if let Some(inner_pattern) = enum_pattern.inner_pattern {
352+
if !matches_enum(ctx, inner_pattern) {
353+
let _ = try_push(ctx, pattern, pattern_path, variant_map);
354+
break;
355+
}
356+
357+
let ptr = ctx.function_body.arenas.patterns[inner_pattern].stable_ptr();
358+
let variant = &ctx
359+
.db
360+
.concrete_enum_variants(concrete_enum_id)
361+
.map_err(LoweringFlowError::Failed)?[enum_pattern.variant.idx];
362+
let next_enum =
363+
extract_concrete_enum(ctx, ptr.into(), variant.ty, match_type);
364+
concrete_enum_id = next_enum?.concrete_enum_id;
365+
366+
pattern = inner_pattern;
367+
} else {
368+
let _ = try_push(ctx, pattern, pattern_path, variant_map);
369+
break;
370+
}
371+
}
372+
_ => {
373+
break;
374+
}
239375
}
240-
Entry::Vacant(entry) => {
241-
entry.insert(PatternPath { arm_index, pattern_index: Some(pattern_index) });
376+
}
377+
}
378+
}
379+
let mut map = UnorderedHashMap::default();
380+
// Assert only one level of mapping for now and turn it into normal map format.
381+
let concrete_variants =
382+
ctx.db.concrete_enum_variants(concrete_enum_id).map_err(LoweringFlowError::Failed)?;
383+
match variant_map {
384+
VariantMatchTree::Empty => {}
385+
VariantMatchTree::Full(_, pattern_path) => {
386+
for variant in concrete_variants.iter() {
387+
map.insert(variant.clone(), pattern_path);
388+
}
389+
}
390+
VariantMatchTree::Mapping(items) => {
391+
for (variant_idx, path) in items.into_iter().enumerate() {
392+
match path {
393+
VariantMatchTree::Mapping(_) => {
394+
// Bad pattern we don't support inner enums.
395+
let mut leaves: OrderedHashSet<_> = Default::default();
396+
path.collect_leaves(&mut leaves);
397+
for leaf in leaves.iter() {
398+
ctx.diagnostics.report(
399+
ctx.function_body.arenas.patterns[*leaf].stable_ptr(),
400+
UnsupportedPattern,
401+
);
402+
}
403+
}
404+
VariantMatchTree::Full(_, pattern_path) => {
405+
map.insert(concrete_variants[variant_idx].clone(), pattern_path);
406+
}
407+
VariantMatchTree::Empty => {}
242408
}
243-
};
409+
}
244410
}
245411
}
412+
413+
/// This function attempts to push a pattern onto the [VariantMatchTree] representing the enum match being lowered.
414+
/// If the pattern is unreachable (i.e., the enum variant/s it represents is already covered), it returns an error.
415+
fn try_push(
416+
ctx: &mut LoweringContext<'_, '_>,
417+
pattern: id_arena::Id<Pattern>,
418+
pattern_path: PatternPath,
419+
variant_map: &mut VariantMatchTree,
420+
) -> Result<(), LoweringFlowError> {
421+
variant_map.push_pattern_path(pattern, pattern_path).map_err(|e| {
422+
LoweringFlowError::Failed(
423+
ctx.diagnostics.report(ctx.function_body.arenas.patterns[pattern].stable_ptr(), e),
424+
)
425+
})?;
426+
Ok(())
427+
}
428+
429+
/// Checks if a pattern matches an enum variant.
430+
fn matches_enum(ctx: &LoweringContext<'_, '_>, pattern: PatternId) -> bool {
431+
matches!(ctx.function_body.arenas.patterns[pattern], semantic::Pattern::EnumVariant(_))
432+
}
433+
434+
/// Checks if a pattern matches `otherwise` or a variable.
435+
fn matches_other(ctx: &LoweringContext<'_, '_>, pattern: PatternId) -> bool {
436+
matches!(
437+
ctx.function_body.arenas.patterns[pattern],
438+
semantic::Pattern::Otherwise(_) | semantic::Pattern::Variable(_)
439+
)
440+
}
441+
246442
Ok(map)
247443
}
248444

@@ -271,7 +467,7 @@ fn insert_tuple_path_patterns(
271467
match map.entry(path) {
272468
Entry::Occupied(_) => {}
273469
Entry::Vacant(entry) => {
274-
entry.insert(pattern_path.clone());
470+
entry.insert(*pattern_path);
275471
}
276472
};
277473
return Ok(());
@@ -493,6 +689,7 @@ fn lower_full_match_tree(
493689
leaves_builders: &mut Vec<MatchLeafBuilder>,
494690
match_type: MatchKind,
495691
) -> LoweringResult<MatchInfo> {
692+
// Always 0 for initial call as this is default
496693
let index = match_tuple_ctx.current_path.variants.len();
497694
let mut arm_var_ids = vec![];
498695
let block_ids = extracted_enums_details[index]
@@ -692,6 +889,7 @@ pub(crate) fn lower_expr_match(
692889
}
693890

694891
/// Lower the collected match arms according to the matched expression.
892+
/// To be used in multi pattern matching scenarios (if let/while let/match).
695893
pub(crate) fn lower_match_arms(
696894
ctx: &mut LoweringContext<'_, '_>,
697895
builder: &mut BlockBuilder,
@@ -726,6 +924,9 @@ pub(crate) fn lower_match_arms(
726924
lower_concrete_enum_match(ctx, builder, matched_expr, lowered_expr, &arms, location, match_type)
727925
}
728926

927+
/// Lowers a match expression on a concrete enum.
928+
/// This function is used for match expressions on concrete enums, such as `match x { A => 1, B => 2
929+
/// }` and in if/while let.
729930
pub(crate) fn lower_concrete_enum_match(
730931
ctx: &mut LoweringContext<'_, '_>,
731932
builder: &mut BlockBuilder,
@@ -738,12 +939,16 @@ pub(crate) fn lower_concrete_enum_match(
738939
let matched_expr = &ctx.function_body.arenas.exprs[matched_expr];
739940
let ExtractedEnumDetails { concrete_enum_id, concrete_variants, n_snapshots } =
740941
extract_concrete_enum(ctx, matched_expr.into(), matched_expr.ty(), match_type)?;
942+
943+
// TODO(eytan-starkware) I need to have all the concrete variants down to the lowest level?
741944
let match_input = lowered_matched_expr.as_var_usage(ctx, builder)?;
742945

743946
// Merge arm blocks.
947+
// Collect for all variants the arm index and pattern index. This can be recursive as for a
948+
// variant we can have inner variants.
744949
let otherwise_variant = get_underscore_pattern_path_and_mark_unreachable(ctx, arms, match_type);
950+
let variant_map = get_variant_to_arm_map(ctx, arms, concrete_enum_id, match_type)?;
745951

746-
let variant_map = get_variant_to_arm_map(ctx, arms.iter(), concrete_enum_id, match_type)?;
747952
let mut arm_var_ids = vec![];
748953
let mut block_ids = vec![];
749954
let variants_block_builders = concrete_variants
@@ -877,7 +1082,7 @@ pub(crate) fn lower_optimized_extern_match(
8771082
get_underscore_pattern_path_and_mark_unreachable(ctx, match_arms, match_type);
8781083

8791084
let variant_map =
880-
get_variant_to_arm_map(ctx, match_arms.iter(), extern_enum.concrete_enum_id, match_type)?;
1085+
get_variant_to_arm_map(ctx, match_arms, extern_enum.concrete_enum_id, match_type)?;
8811086
let mut arm_var_ids = vec![];
8821087
let mut block_ids = vec![];
8831088

0 commit comments

Comments
 (0)