Skip to content

Represent match arm wrapper as an enum with explicit cases #7726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/cairo-lang-lowering/src/lower/lower_if.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ pub fn lower_expr_if_let(
}

let arms = vec![
MatchArmWrapper { patterns: patterns.into(), expr: Some(expr.if_block) },
MatchArmWrapper { patterns: vec![], expr: expr.else_block },
MatchArmWrapper::Arm(patterns, expr.if_block),
expr.else_block.map(MatchArmWrapper::ElseClause).unwrap_or(MatchArmWrapper::DefaultClause),
];

lower_match::lower_match_arms(
Expand Down
188 changes: 118 additions & 70 deletions crates/cairo-lang-lowering/src/lower/lower_match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,51 @@ struct ExtractedEnumDetails {
n_snapshots: usize,
}

/// MatchArm wrapper that allows for optional expression clause.
/// Used in the case of if-let with missing else clause.
pub struct MatchArmWrapper {
pub patterns: Vec<PatternId>,
pub expr: Option<semantic::ExprId>,
/// A wrapper enum to provide a unified interface for handling different types of match arms
/// during the lowering phase of the compiler, allowing for consistent pattern matching
/// and expression evaluation across different match-like constructs.
pub enum MatchArmWrapper<'a> {
/// A match arm. Patterns (non-empty) guard the expression to evaluate.
Arm(&'a [PatternId], semantic::ExprId),
/// Else clause (no patterns) and it's expression to evaluate (if-let).
ElseClause(semantic::ExprId),
/// Default clause when else clause is not provided (if-let/while-let).
DefaultClause,
}

impl From<&semantic::MatchArm> for MatchArmWrapper {
fn from(arm: &semantic::MatchArm) -> Self {
Self { patterns: arm.patterns.clone(), expr: Some(arm.expression) }
impl<'a> From<&'a semantic::MatchArm> for MatchArmWrapper<'a> {
fn from(arm: &'a semantic::MatchArm) -> Self {
MatchArmWrapper::Arm(&arm.patterns, arm.expression)
}
}

impl MatchArmWrapper<'_> {
/// Returns the expression of the guarded by the match arm.
pub fn expr(&self) -> Option<semantic::ExprId> {
match self {
MatchArmWrapper::Arm(_, expr) => Some(*expr),
MatchArmWrapper::ElseClause(expr) => Some(*expr),
MatchArmWrapper::DefaultClause => None,
}
}

/// Returns the patterns of the match arm.
pub fn patterns(&self) -> Option<&[PatternId]> {
match self {
MatchArmWrapper::Arm(patterns, _) => Some(patterns),
MatchArmWrapper::ElseClause(_) => None,
MatchArmWrapper::DefaultClause => None,
}
}

/// Try and extract the pattern from the match arm by index.
/// Returns None if the arm is a missing clause, else clause.
pub fn pattern<'a>(
&self,
ctx: &'a LoweringContext<'_, '_>,
index: usize,
) -> Option<&'a Pattern> {
self.patterns().map(|patterns| &ctx.function_body.arenas.patterns[patterns[index]])
}
}

Expand Down Expand Up @@ -126,65 +161,69 @@ struct PatternPath {
/// Returns an option containing the PatternPath of the underscore pattern, if it exists.
fn get_underscore_pattern_path_and_mark_unreachable(
ctx: &mut LoweringContext<'_, '_>,
arms: &[MatchArmWrapper],
arms: &[MatchArmWrapper<'_>],
match_type: MatchKind,
) -> Option<PatternPath> {
let otherwise_variant = arms
.iter()
.enumerate()
.filter_map(|(arm_index, arm)| {
let pattern_index = if arm.patterns.is_empty() {
// Special path for if-let else clause where no patterns exist.
None
} else {
let position = arm.patterns.iter().position(|pattern| {
matches!(
ctx.function_body.arenas.patterns[*pattern],
semantic::Pattern::Otherwise(_)
)
})?;
Some(position)
let pattern_index = match arm {
MatchArmWrapper::Arm(patterns, _) => {
let position = patterns.iter().position(|pattern| {
matches!(
ctx.function_body.arenas.patterns[*pattern],
semantic::Pattern::Otherwise(_)
)
})?;
Some(position)
}
MatchArmWrapper::DefaultClause | MatchArmWrapper::ElseClause(_) => None,
};
Some(PatternPath { arm_index, pattern_index })
})
.next()?;

for arm in arms.iter().skip(otherwise_variant.arm_index + 1) {
if arm.patterns.is_empty() && arm.expr.is_some() {
let expr = ctx.function_body.arenas.exprs[arm.expr.unwrap()].clone();
ctx.diagnostics.report(
&expr,
MatchError(MatchError {
kind: match_type,
error: MatchDiagnostic::UnreachableMatchArm,
}),
);
match arm {
MatchArmWrapper::Arm(patterns, _expr) => {
for pattern in patterns.iter() {
let pattern_ptr = ctx.function_body.arenas.patterns[*pattern].stable_ptr();
ctx.diagnostics.report(
pattern_ptr,
MatchError(MatchError {
kind: match_type,
error: MatchDiagnostic::UnreachableMatchArm,
}),
);
}
}
MatchArmWrapper::DefaultClause => continue,
MatchArmWrapper::ElseClause(e) => {
let expr_ptr = ctx.function_body.arenas.exprs[*e].stable_ptr();
ctx.diagnostics.report(
expr_ptr,
MatchError(MatchError {
kind: match_type,
error: MatchDiagnostic::UnreachableMatchArm,
}),
);
}
}
for pattern in arm.patterns.iter() {
let pattern = ctx.function_body.arenas.patterns[*pattern].clone();
}

if let Some(patterns) = arms[otherwise_variant.arm_index].patterns() {
for pattern in patterns.iter().skip(otherwise_variant.pattern_index.unwrap_or(0) + 1) {
let pattern = &ctx.function_body.arenas.patterns[*pattern];
ctx.diagnostics.report(
&pattern,
pattern.stable_ptr(),
MatchError(MatchError {
kind: match_type,
error: MatchDiagnostic::UnreachableMatchArm,
}),
);
}
}
for pattern in arms[otherwise_variant.arm_index]
.patterns
.iter()
.skip(otherwise_variant.pattern_index.unwrap_or(0) + 1)
{
let pattern = ctx.function_body.arenas.patterns[*pattern].clone();
ctx.diagnostics.report(
&pattern,
MatchError(MatchError {
kind: match_type,
error: MatchDiagnostic::UnreachableMatchArm,
}),
);
}

Some(otherwise_variant)
}
Expand Down Expand Up @@ -287,7 +326,7 @@ impl VariantMatchTree {
/// Returns a map from variants to their corresponding pattern path in a match statement.
fn get_variant_to_arm_map(
ctx: &mut LoweringContext<'_, '_>,
arms: &[MatchArmWrapper],
arms: &[MatchArmWrapper<'_>],
concrete_enum_id: semantic::ConcreteEnumId,
match_type: MatchKind,
) -> LoweringResult<UnorderedHashMap<semantic::ConcreteVariant, PatternPath>> {
Expand All @@ -311,7 +350,10 @@ fn get_variant_to_arm_map(
// We use [VariantMatchTree] to check patterns are legal (reachable and all branches end with a pattern), and then collect an arm map.
let mut variant_match_tree = VariantMatchTree::Empty;
for (arm_index, arm) in arms.iter().enumerate() {
for (pattern_index, pattern) in arm.patterns.iter().copied().enumerate() {
let Some(patterns) = arm.patterns() else {
continue;
};
for (pattern_index, pattern) in patterns.iter().copied().enumerate() {
let pattern_path = PatternPath { arm_index, pattern_index: Some(pattern_index) };
let pattern_ptr = ctx.function_body.arenas.patterns[pattern].stable_ptr();

Expand Down Expand Up @@ -576,13 +618,16 @@ fn insert_tuple_path_patterns(
/// Returns a map from a matching paths to their corresponding pattern path in a match statement.
fn get_variants_to_arm_map_tuple<'a>(
ctx: &mut LoweringContext<'_, '_>,
arms: impl Iterator<Item = &'a MatchArmWrapper>,
arms: impl Iterator<Item = &'a MatchArmWrapper<'a>>,
extracted_enums_details: &[ExtractedEnumDetails],
match_type: MatchKind,
) -> LoweringResult<UnorderedHashMap<MatchingPath, PatternPath>> {
let mut map = UnorderedHashMap::default();
for (arm_index, arm) in arms.enumerate() {
for (pattern_index, pattern) in arm.patterns.iter().enumerate() {
let Some(patterns) = arm.patterns() else {
continue;
};
for (pattern_index, pattern) in patterns.iter().enumerate() {
let pattern = ctx.function_body.arenas.patterns[*pattern].clone();
if let semantic::Pattern::Otherwise(_) = pattern {
break;
Expand Down Expand Up @@ -644,7 +689,7 @@ struct LoweringMatchTupleContext {
fn lower_tuple_match_arm(
ctx: &mut LoweringContext<'_, '_>,
mut builder: BlockBuilder,
arms: &[MatchArmWrapper],
arms: &[MatchArmWrapper<'_>],
match_tuple_ctx: &mut LoweringMatchTupleContext,
leaves_builders: &mut Vec<MatchLeafBuilder>,
match_type: MatchKind,
Expand All @@ -668,11 +713,12 @@ fn lower_tuple_match_arm(
}),
))
})?;
let pattern = pattern_path.pattern_index.map(|pattern_index| {
ctx.function_body.arenas.patterns[arms[pattern_path.arm_index].patterns[pattern_index]]
let pattern = pattern_path.pattern_index.map(|i| {
arms[pattern_path.arm_index]
.pattern(ctx, i)
.expect("Pattern previously found in arm, but is now missing at index.")
.clone()
});

let lowering_inner_pattern_result = match pattern {
Some(semantic::Pattern::Tuple(patterns)) => patterns
.field_patterns
Expand Down Expand Up @@ -709,8 +755,9 @@ fn lower_tuple_match_arm(
.map(|_| ()),
Some(semantic::Pattern::Otherwise(_)) | None => Ok(()),
_ => {
let stable_ptr = pattern.unwrap().stable_ptr();
return Err(LoweringFlowError::Failed(ctx.diagnostics.report(
&pattern.unwrap(),
stable_ptr,
MatchError(MatchError {
kind: match_type,
error: MatchDiagnostic::UnsupportedMatchArmNotATuple,
Expand All @@ -730,7 +777,7 @@ fn lower_tuple_match_arm(
fn lower_full_match_tree(
ctx: &mut LoweringContext<'_, '_>,
builder: &mut BlockBuilder,
arms: &[MatchArmWrapper],
arms: &[MatchArmWrapper<'_>],
match_tuple_ctx: &mut LoweringMatchTupleContext,
extracted_enums_details: &[ExtractedEnumDetails],
leaves_builders: &mut Vec<MatchLeafBuilder>,
Expand Down Expand Up @@ -820,7 +867,7 @@ pub(crate) fn lower_expr_match_tuple(
expr: LoweredExpr,
matched_expr: semantic::ExprId,
tuple_info: &TupleInfo,
arms: &[MatchArmWrapper],
arms: &[MatchArmWrapper<'_>],
match_type: MatchKind,
) -> LoweringResult<LoweredExpr> {
let location = expr.location();
Expand Down Expand Up @@ -942,7 +989,7 @@ pub(crate) fn lower_match_arms(
builder: &mut BlockBuilder,
matched_expr: semantic::ExprId,
lowered_expr: LoweredExpr,
arms: Vec<MatchArmWrapper>,
arms: Vec<MatchArmWrapper<'_>>,
location: LocationId,
match_type: MatchKind,
) -> Result<LoweredExpr, LoweringFlowError> {
Expand Down Expand Up @@ -979,7 +1026,7 @@ pub(crate) fn lower_concrete_enum_match(
builder: &mut BlockBuilder,
matched_expr: semantic::ExprId,
lowered_matched_expr: LoweredExpr,
arms: &[MatchArmWrapper],
arms: &[MatchArmWrapper<'_>],
location: LocationId,
match_type: MatchKind,
) -> LoweringResult<LoweredExpr> {
Expand Down Expand Up @@ -1021,7 +1068,8 @@ pub(crate) fn lower_concrete_enum_match(
let mut subscope = create_subscope(ctx, builder);

let pattern = pattern_index.map(|pattern_index| {
&ctx.function_body.arenas.patterns[arm.patterns[pattern_index]]
arm.pattern(ctx, pattern_index)
.expect("Pattern was previously found and should be present in the arm.")
});
let block_id = subscope.block_id;
block_ids.push(block_id);
Expand Down Expand Up @@ -1050,9 +1098,10 @@ pub(crate) fn lower_concrete_enum_match(
Pattern::EnumVariant(PatternEnumVariant { inner_pattern: None, .. })
| Pattern::Otherwise(_),
) => {
let location = ctx.get_location(pattern.unwrap().into());
let var_id = ctx.new_var(VarRequest {
ty: wrap_in_snapshots(ctx.db, concrete_variant.ty, n_snapshots),
location: ctx.get_location(pattern.unwrap().into()),
location,
});
arm_var_ids.push(vec![var_id]);
Ok(())
Expand Down Expand Up @@ -1116,7 +1165,7 @@ pub(crate) fn lower_optimized_extern_match(
ctx: &mut LoweringContext<'_, '_>,
builder: &mut BlockBuilder,
extern_enum: LoweredExprExternEnum,
match_arms: &[MatchArmWrapper],
match_arms: &[MatchArmWrapper<'_>],
match_type: MatchKind,
) -> LoweringResult<LoweredExpr> {
log::trace!("Started lowering of an optimized extern match.");
Expand Down Expand Up @@ -1172,9 +1221,8 @@ pub(crate) fn lower_optimized_extern_match(
})?;

let arm = &match_arms[*arm_index];
let pattern = pattern_index.map(|pattern_index| {
&ctx.function_body.arenas.patterns[arm.patterns[pattern_index]]
});
let pattern =
pattern_index.map(|pattern_index| arm.pattern(ctx, pattern_index).unwrap());

let lowering_inner_pattern_result = match pattern {
Some(Pattern::EnumVariant(PatternEnumVariant {
Expand Down Expand Up @@ -1242,7 +1290,7 @@ fn group_match_arms(
ctx: &mut LoweringContext<'_, '_>,
empty_match_info: MatchInfo,
location: LocationId,
arms: &[MatchArmWrapper],
arms: &[MatchArmWrapper<'_>],
variants_block_builders: Vec<MatchLeafBuilder>,
kind: MatchKind,
) -> LoweringResult<Vec<SealedBlockBuilder>> {
Expand All @@ -1265,7 +1313,7 @@ fn group_match_arms(
return match lowering_inner_pattern_result {
Ok(_) => {
// Lower the arm expression.
match (arm.expr, kind) {
match (arm.expr(), kind) {
(Some(expr), MatchKind::IfLet | MatchKind::Match) => {
lower_tail_expr(ctx, subscope, expr)
}
Expand Down Expand Up @@ -1311,8 +1359,8 @@ fn group_match_arms(
.map(|(lowering_inner_pattern_result, subscope)| {
// Use the first pattern for the location of the for variable assignment block.
let location = arm
.patterns
.first()
.patterns()
.and_then(|ps| ps.first())
.map(|pattern| {
ctx.get_location(
ctx.function_body.arenas.patterns[*pattern].stable_ptr().untyped(),
Expand All @@ -1337,7 +1385,7 @@ fn group_match_arms(
sealed_blocks,
location,
)?;
match (arm.expr, kind) {
match (arm.expr(), kind) {
(Some(expr), MatchKind::IfLet | MatchKind::Match) => {
lower_tail_expr(ctx, outer_subscope, expr)
}
Expand Down
5 changes: 1 addition & 4 deletions crates/cairo-lang-lowering/src/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,7 @@ pub fn lower_expr_while_let(
)));
}

let arms = vec![
MatchArmWrapper { patterns: patterns.into(), expr: Some(loop_expr.body) },
MatchArmWrapper { patterns: vec![], expr: None },
];
let arms = vec![MatchArmWrapper::Arm(patterns, loop_expr.body), MatchArmWrapper::DefaultClause];

lower_match::lower_match_arms(
ctx,
Expand Down