Skip to content

Commit 7ec7013

Browse files
Match on enum arm collection using tree structure
This is a preparation step to supporting enum matching on inner enum variants. To support multi level matching we need to address different enum variants will have different parameters, therefor a tree structure, EnumPaths, is introduced. Coverage of all variants is done by checking if all tree leaves are full and not empty, where empty cases are missing subvariants. commit-id:978a80be
1 parent 5657547 commit 7ec7013

File tree

2 files changed

+521
-45
lines changed

2 files changed

+521
-45
lines changed

crates/cairo-lang-lowering/src/lower/lower_match.rs

Lines changed: 244 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ use cairo_lang_debug::DebugWithDb;
22
use cairo_lang_defs::ids::NamedLanguageElementId;
33
use cairo_lang_filesystem::flag::Flag;
44
use cairo_lang_filesystem::ids::FlagId;
5-
use cairo_lang_semantic::{self as semantic, GenericArgumentId, corelib};
5+
use cairo_lang_semantic::{self as semantic, ConcreteVariant, GenericArgumentId, corelib};
66
use cairo_lang_syntax::node::TypedStablePtr;
77
use cairo_lang_syntax::node::ids::SyntaxStablePtrId;
8+
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
89
use cairo_lang_utils::unordered_hash_map::{Entry, UnorderedHashMap};
910
use cairo_lang_utils::{LookupIntern, try_extract_matches};
1011
use itertools::{Itertools, zip_eq};
@@ -26,7 +27,7 @@ use super::{
2627
alloc_empty_block, generators, lower_expr_block, lower_expr_literal, lower_tail_expr,
2728
lowered_expr_to_block_scope_end, recursively_call_loop_func,
2829
};
29-
use crate::diagnostic::LoweringDiagnosticKind::*;
30+
use crate::diagnostic::LoweringDiagnosticKind::{self, *};
3031
use crate::diagnostic::{LoweringDiagnosticsBuilder, MatchDiagnostic, MatchError, MatchKind};
3132
use crate::ids::{LocationId, SemanticFunctionIdEx};
3233
use crate::lower::context::VarRequest;
@@ -116,7 +117,7 @@ fn extract_concrete_enum_tuple(
116117
}
117118

118119
/// The arm and pattern indices of a pattern in a match arm with an or list.
119-
#[derive(Debug, Clone)]
120+
#[derive(Debug, Clone, Copy)]
120121
struct PatternPath {
121122
arm_index: usize,
122123
pattern_index: Option<usize>,
@@ -188,61 +189,251 @@ fn get_underscore_pattern_path_and_mark_unreachable(
188189
Some(otherwise_variant)
189190
}
190191

191-
/// Returns a map from variants to their corresponding pattern path in a match statement.
192-
fn get_variant_to_arm_map<'a>(
193-
ctx: &mut LoweringContext<'_, '_>,
194-
arms: impl Iterator<Item = &'a MatchArmWrapper>,
195-
concrete_enum_id: semantic::ConcreteEnumId,
196-
match_type: MatchKind,
197-
) -> LoweringResult<UnorderedHashMap<semantic::ConcreteVariant, PatternPath>> {
198-
let mut map = UnorderedHashMap::default();
199-
for (arm_index, arm) in arms.enumerate() {
200-
for (pattern_index, pattern) in arm.patterns.iter().enumerate() {
201-
let pattern = &ctx.function_body.arenas.patterns[*pattern];
192+
/// A sparse tree that records which enum‑variant paths are *already*
193+
/// covered by user code during `match` lowering.
194+
///
195+
/// Each node captures the coverage state for a single variant path:
196+
/// [`Mapping`] The current variant is itself an enum; a `Vec` entry is kept for every child variant
197+
/// match tree. [`Full`] A concrete pattern fully covers this path. Additional patterns reaching
198+
/// here are unreachable (even if current variant is itself an enum, subtrees are irrelevant).
199+
/// [`Empty`] No pattern has touched this path yet. Useful to emit a `MissingMatchArm` diagnostic
200+
/// later on.
201+
#[derive(Debug, Clone)]
202+
enum VariantMatchTree {
203+
/// Mapping of enum variant id(x) to sub [VariantMatchTree].
204+
Mapping(Vec<VariantMatchTree>),
205+
/// Path is covered by a pattern.
206+
Full(PatternId, PatternPath),
207+
/// No pattern has been added to this path.
208+
Empty,
209+
}
202210

203-
if let semantic::Pattern::Otherwise(_) = pattern {
204-
continue;
211+
impl VariantMatchTree {
212+
/// Pushes a pattern to the enum paths. Fails if the pattern is unreachable.
213+
fn push_pattern_path(
214+
&mut self,
215+
ptrn_id: PatternId,
216+
pattern_path: PatternPath,
217+
) -> Result<(), LoweringDiagnosticKind> {
218+
match self {
219+
VariantMatchTree::Empty => {
220+
*self = VariantMatchTree::Full(ptrn_id, pattern_path);
221+
Ok(())
222+
}
223+
VariantMatchTree::Full(_, _) => Err(MatchError(MatchError {
224+
kind: MatchKind::Match,
225+
error: MatchDiagnostic::UnreachableMatchArm,
226+
})),
227+
VariantMatchTree::Mapping(mapping) => {
228+
// Need at least one empty path, but should write to all (pattern covers multiple
229+
// paths).
230+
let mut any_ok = false;
231+
for path in mapping.iter_mut() {
232+
if path.push_pattern_path(ptrn_id, pattern_path).is_ok() {
233+
any_ok = true;
234+
}
235+
}
236+
if any_ok {
237+
Ok(())
238+
} else {
239+
Err(MatchError(MatchError {
240+
kind: MatchKind::Match,
241+
error: MatchDiagnostic::UnreachableMatchArm,
242+
}))
243+
}
205244
}
245+
}
246+
}
206247

207-
let pat_stable_ptr = pattern.stable_ptr();
248+
/// Utility to collect every [`PatternId`] found in `Full` leaves into `leaves`.
249+
fn collect_leaves(&self, leaves: &mut OrderedHashSet<PatternId>) {
250+
match self {
251+
VariantMatchTree::Empty => {}
252+
VariantMatchTree::Full(ptrn_id, _) => {
253+
leaves.insert(*ptrn_id);
254+
}
255+
VariantMatchTree::Mapping(mapping) => {
256+
for path in mapping.iter() {
257+
path.collect_leaves(leaves);
258+
}
259+
}
260+
}
261+
}
208262

209-
let Some(enum_pattern) = try_extract_matches!(pattern, semantic::Pattern::EnumVariant)
210-
else {
211-
return Err(LoweringFlowError::Failed(ctx.diagnostics.report(
212-
pat_stable_ptr,
213-
MatchError(MatchError {
214-
kind: match_type,
215-
error: MatchDiagnostic::UnsupportedMatchArmNotAVariant,
216-
}),
217-
)));
218-
};
263+
/// Fails on missing enum in db.
264+
/// Returns None if path is full otherwise reference the [EnumPaths] of appropriate variant.
265+
fn get_mapping_or_insert<'a>(
266+
&'a mut self,
267+
ctx: &LoweringContext<'_, '_>,
268+
concrete_variant: ConcreteVariant,
269+
) -> cairo_lang_diagnostics::Maybe<Option<&'a mut Self>> {
270+
match self {
271+
VariantMatchTree::Empty => {
272+
let variant_count =
273+
ctx.db.concrete_enum_variants(concrete_variant.concrete_enum_id)?.len();
274+
*self = VariantMatchTree::Mapping(vec![VariantMatchTree::Empty; variant_count]);
275+
if let VariantMatchTree::Mapping(items) = self {
276+
Ok(Some(&mut items[concrete_variant.idx]))
277+
} else {
278+
unreachable!("We just created the mapping.")
279+
}
280+
}
281+
VariantMatchTree::Full(_, _) => Ok(None),
282+
VariantMatchTree::Mapping(items) => Ok(Some(&mut items[concrete_variant.idx])),
283+
}
284+
}
285+
}
219286

220-
if enum_pattern.variant.concrete_enum_id != concrete_enum_id {
287+
/// Returns a map from variants to their corresponding pattern path in a match statement.
288+
fn get_variant_to_arm_map(
289+
ctx: &mut LoweringContext<'_, '_>,
290+
arms: &[MatchArmWrapper],
291+
concrete_enum_id: semantic::ConcreteEnumId,
292+
match_type: MatchKind,
293+
) -> LoweringResult<UnorderedHashMap<semantic::ConcreteVariant, PatternPath>> {
294+
// I could have Option<Option<Option<.....>>> but only match the first one.
295+
// This can be generalized to a tree of enums A::A(C::C) does not need to check recursion of
296+
// A::B(_). Check pattern is legal and collect paths.
297+
let mut variant_map = VariantMatchTree::Empty;
298+
for (arm_index, arm) in arms.iter().enumerate() {
299+
for (pattern_index, mut pattern) in arm.patterns.iter().copied().enumerate() {
300+
let pattern_path = PatternPath { arm_index, pattern_index: Some(pattern_index) };
301+
let pattern_ptr = ctx.function_body.arenas.patterns[pattern].stable_ptr();
302+
303+
let mut variant_map = &mut variant_map;
304+
let mut concrete_enum_id = concrete_enum_id;
305+
if !(matches_enum(ctx, pattern) | matches_other(ctx, pattern)) {
221306
return Err(LoweringFlowError::Failed(ctx.diagnostics.report(
222-
pat_stable_ptr,
307+
pattern_ptr,
223308
MatchError(MatchError {
224309
kind: match_type,
225310
error: MatchDiagnostic::UnsupportedMatchArmNotAVariant,
226311
}),
227312
)));
228313
}
314+
loop {
315+
match &ctx.function_body.arenas.patterns[pattern] {
316+
semantic::Pattern::Otherwise(_) => {
317+
// Fill leaves and check for usefulness.
318+
let _ = variant_map.push_pattern_path(pattern, pattern_path);
319+
// TODO(eytan-starkware) Check result and report warning if unreachable.
320+
break;
321+
}
322+
semantic::Pattern::EnumVariant(enum_pattern) => {
323+
if concrete_enum_id != enum_pattern.variant.concrete_enum_id {
324+
return Err(LoweringFlowError::Failed(ctx.diagnostics.report(
325+
pattern_ptr,
326+
MatchError(MatchError {
327+
kind: match_type,
328+
error: MatchDiagnostic::UnsupportedMatchArmNotAVariant,
329+
}),
330+
)));
331+
}
332+
// Expand paths in map to include all variants of this enum_pattern.
333+
if let Some(vmap) = variant_map
334+
.get_mapping_or_insert(ctx, enum_pattern.variant.clone())
335+
.map_err(LoweringFlowError::Failed)?
336+
{
337+
variant_map = vmap;
338+
} else {
339+
ctx.diagnostics.report(
340+
pattern_ptr,
341+
MatchError(MatchError {
342+
kind: match_type,
343+
error: MatchDiagnostic::UnreachableMatchArm,
344+
}),
345+
);
346+
break;
347+
}
229348

230-
match map.entry(enum_pattern.variant.clone()) {
231-
Entry::Occupied(_) => {
232-
ctx.diagnostics.report(
233-
pat_stable_ptr,
234-
MatchError(MatchError {
235-
kind: match_type,
236-
error: MatchDiagnostic::UnreachableMatchArm,
237-
}),
238-
);
349+
if let Some(inner_pattern) = enum_pattern.inner_pattern {
350+
if !matches_enum(ctx, inner_pattern) {
351+
let _ = try_push(ctx, pattern, pattern_path, variant_map);
352+
break;
353+
}
354+
355+
let ptr = ctx.function_body.arenas.patterns[inner_pattern].stable_ptr();
356+
let variant = &ctx
357+
.db
358+
.concrete_enum_variants(concrete_enum_id)
359+
.map_err(LoweringFlowError::Failed)?[enum_pattern.variant.idx];
360+
let next_enum =
361+
extract_concrete_enum(ctx, ptr.into(), variant.ty, match_type);
362+
concrete_enum_id = next_enum?.concrete_enum_id;
363+
364+
pattern = inner_pattern;
365+
} else {
366+
let _ = try_push(ctx, pattern, pattern_path, variant_map);
367+
break;
368+
}
369+
}
370+
_ => {
371+
break;
372+
}
239373
}
240-
Entry::Vacant(entry) => {
241-
entry.insert(PatternPath { arm_index, pattern_index: Some(pattern_index) });
374+
}
375+
}
376+
}
377+
let mut map = UnorderedHashMap::default();
378+
// Assert only one level of mapping for now and turn it into normal map format.
379+
let concrete_variants =
380+
ctx.db.concrete_enum_variants(concrete_enum_id).map_err(LoweringFlowError::Failed)?;
381+
match variant_map {
382+
VariantMatchTree::Empty => {}
383+
VariantMatchTree::Full(_, pattern_path) => {
384+
for variant in concrete_variants.iter() {
385+
map.insert(variant.clone(), pattern_path);
386+
}
387+
}
388+
VariantMatchTree::Mapping(items) => {
389+
for (variant_idx, path) in items.into_iter().enumerate() {
390+
match path {
391+
VariantMatchTree::Mapping(_) => {
392+
// Bad pattern we don't support inner enums.
393+
let mut leaves: OrderedHashSet<_> = Default::default();
394+
path.collect_leaves(&mut leaves);
395+
for leaf in leaves.iter() {
396+
ctx.diagnostics.report(
397+
ctx.function_body.arenas.patterns[*leaf].stable_ptr(),
398+
UnsupportedPattern,
399+
);
400+
}
401+
}
402+
VariantMatchTree::Full(_, pattern_path) => {
403+
map.insert(concrete_variants[variant_idx].clone(), pattern_path);
404+
}
405+
VariantMatchTree::Empty => {}
242406
}
243-
};
407+
}
244408
}
245409
}
410+
411+
/// Fails on unreachable pattern.
412+
fn try_push(
413+
ctx: &mut LoweringContext<'_, '_>,
414+
pattern: id_arena::Id<Pattern>,
415+
pattern_path: PatternPath,
416+
variant_map: &mut VariantMatchTree,
417+
) -> Result<(), LoweringFlowError> {
418+
variant_map.push_pattern_path(pattern, pattern_path).map_err(|e| {
419+
LoweringFlowError::Failed(
420+
ctx.diagnostics.report(ctx.function_body.arenas.patterns[pattern].stable_ptr(), e),
421+
)
422+
})?;
423+
Ok(())
424+
}
425+
426+
fn matches_enum(ctx: &LoweringContext<'_, '_>, pattern: PatternId) -> bool {
427+
matches!(ctx.function_body.arenas.patterns[pattern], semantic::Pattern::EnumVariant(_))
428+
}
429+
430+
fn matches_other(ctx: &LoweringContext<'_, '_>, pattern: PatternId) -> bool {
431+
matches!(
432+
ctx.function_body.arenas.patterns[pattern],
433+
semantic::Pattern::Otherwise(_) | semantic::Pattern::Variable(_)
434+
)
435+
}
436+
246437
Ok(map)
247438
}
248439

@@ -271,7 +462,7 @@ fn insert_tuple_path_patterns(
271462
match map.entry(path) {
272463
Entry::Occupied(_) => {}
273464
Entry::Vacant(entry) => {
274-
entry.insert(pattern_path.clone());
465+
entry.insert(*pattern_path);
275466
}
276467
};
277468
return Ok(());
@@ -493,6 +684,7 @@ fn lower_full_match_tree(
493684
leaves_builders: &mut Vec<MatchLeafBuilder>,
494685
match_type: MatchKind,
495686
) -> LoweringResult<MatchInfo> {
687+
// Always 0 for initial call as this is default
496688
let index = match_tuple_ctx.current_path.variants.len();
497689
let mut arm_var_ids = vec![];
498690
let block_ids = extracted_enums_details[index]
@@ -692,6 +884,7 @@ pub(crate) fn lower_expr_match(
692884
}
693885

694886
/// Lower the collected match arms according to the matched expression.
887+
/// To be used in multi pattern matching scenarios (if let/while let/match).
695888
pub(crate) fn lower_match_arms(
696889
ctx: &mut LoweringContext<'_, '_>,
697890
builder: &mut BlockBuilder,
@@ -726,6 +919,9 @@ pub(crate) fn lower_match_arms(
726919
lower_concrete_enum_match(ctx, builder, matched_expr, lowered_expr, &arms, location, match_type)
727920
}
728921

922+
/// Lowers a match expression on a concrete enum.
923+
/// This function is used for match expressions on concrete enums, such as `match x { A => 1, B => 2
924+
/// }` and in if/while let.
729925
pub(crate) fn lower_concrete_enum_match(
730926
ctx: &mut LoweringContext<'_, '_>,
731927
builder: &mut BlockBuilder,
@@ -738,12 +934,16 @@ pub(crate) fn lower_concrete_enum_match(
738934
let matched_expr = &ctx.function_body.arenas.exprs[matched_expr];
739935
let ExtractedEnumDetails { concrete_enum_id, concrete_variants, n_snapshots } =
740936
extract_concrete_enum(ctx, matched_expr.into(), matched_expr.ty(), match_type)?;
937+
938+
// TODO(eytan-starkware) I need to have all the concrete variants down to the lowest level?
741939
let match_input = lowered_matched_expr.as_var_usage(ctx, builder)?;
742940

743941
// Merge arm blocks.
942+
// Collect for all variants the arm index and pattern index. This can be recursive as for a
943+
// variant we can have inner variants.
744944
let otherwise_variant = get_underscore_pattern_path_and_mark_unreachable(ctx, arms, match_type);
945+
let variant_map = get_variant_to_arm_map(ctx, arms, concrete_enum_id, match_type)?;
745946

746-
let variant_map = get_variant_to_arm_map(ctx, arms.iter(), concrete_enum_id, match_type)?;
747947
let mut arm_var_ids = vec![];
748948
let mut block_ids = vec![];
749949
let variants_block_builders = concrete_variants
@@ -877,7 +1077,7 @@ pub(crate) fn lower_optimized_extern_match(
8771077
get_underscore_pattern_path_and_mark_unreachable(ctx, match_arms, match_type);
8781078

8791079
let variant_map =
880-
get_variant_to_arm_map(ctx, match_arms.iter(), extern_enum.concrete_enum_id, match_type)?;
1080+
get_variant_to_arm_map(ctx, match_arms, extern_enum.concrete_enum_id, match_type)?;
8811081
let mut arm_var_ids = vec![];
8821082
let mut block_ids = vec![];
8831083

0 commit comments

Comments
 (0)