Fix expanding OR patterns
This was found in Inko (see
inko-lang/inko#679). The flatten_or() method
was applied after pushing bindings out of the rows. The result was that
OR patterns containing bindings would be left as-is, triggering
unreachable panics elsewhere in the implementation.

This commit fixes it by using a different implementation: we process and
expand all rows containing OR patterns _before_ pushing bindings out of
patterns. Because this is potentially a bit more expensive to perform,
the implementation skips this work if there aren't any OR patterns
present.
yorickpeterse committed Feb 2, 2024
1 parent 08a9fac commit 50fcd0e
Showing 2 changed files with 95 additions and 66 deletions.
39 changes: 10 additions & 29 deletions jacobs2021/README.md
@@ -164,35 +164,16 @@ code and corresponding tests.

## OR patterns

OR patterns are not covered in the article, but supporting them is easy.
Supporting these requires an extra `flatten_or` function that takes as input a
pattern and a `Row`, returning an array of `(Pattern, Row)` tuples. If the input
pattern is an OR pattern, it returns its sub patterns zipped with a copy of the
input row. If the input pattern is any other pattern, the function just returns
an array of the input pattern and row:

```rust
fn flatten_or(pattern: Pattern, row: Row) -> Vec<(Pattern, Row)> {
    if let Pattern::Or(args) = pattern {
        args.into_iter().map(|p| (p, row.clone())).collect()
    } else {
        vec![(pattern, row)]
    }
}
```

When removing the branch column from a row you then use this function, instead
of acting upon a column's pattern directly:

```rust
if let Some(col) = row.remove_column(&branch_var) {
    for (pat, row) in flatten_or(col.pattern, row) {
        ...
    }
} else {
    ...
}
```
OR patterns are not covered in the article. To support these patterns we have to
take rows containing OR patterns in any column, then expand those OR patterns
into separate rows. The code here handles this in the `expand_or_patterns()`
function. This function is called _before_ pushing variable/wildcard patterns
out of the rows, ensuring that OR patterns containing these patterns work as
expected.

**NOTE:** a previous implementation used a method called `flatten_or`, which
worked differently. That implementation proved incorrect, as it failed to
handle bindings in OR patterns (e.g. `10 or number`).
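
The expansion described above can be sketched in isolation. The following is a minimal illustration rather than the crate's actual code: `Pat`, `SimpleRow` and `expand` are hypothetical stand-ins that drop the variables, guards and bodies carried by the real `Pattern` and `Row` types, and the sketch omits the early exit that skips the work when no OR patterns are present.

```rust
// Simplified stand-ins for the crate's `Pattern` and `Row` types. These names
// are assumptions made for this example only.
#[derive(Clone, Debug, PartialEq)]
enum Pat {
    Int(i64),
    Binding(&'static str),
    Or(Vec<Pat>),
}

type SimpleRow = Vec<Pat>;

/// Repeatedly replaces the first OR pattern found in each row with one row per
/// alternative, until no row has an OR pattern left in any column.
fn expand(mut rows: Vec<SimpleRow>) -> Vec<SimpleRow> {
    loop {
        let mut found = false;
        let mut new_rows = Vec::with_capacity(rows.len());

        for row in rows {
            // Index of the first column holding an OR pattern, if any.
            let idx = row.iter().position(|p| matches!(p, Pat::Or(_)));

            match idx {
                Some(idx) => {
                    found = true;

                    // One new row per alternative; other columns are copied
                    // as-is and handled in a later pass if they are OR
                    // patterns themselves.
                    if let Pat::Or(alts) = &row[idx] {
                        for alt in alts {
                            let mut new_row = row.clone();
                            new_row[idx] = alt.clone();
                            new_rows.push(new_row);
                        }
                    }
                }
                None => new_rows.push(row),
            }
        }

        rows = new_rows;

        if !found {
            return rows;
        }
    }
}
```

Given a single row whose two columns are `10 or number` and `1 or 2`, this sketch produces four rows, and the `number` binding ends up as a plain top-level pattern in two of them. That is the shape the later step that pushes variable patterns out of the rows expects.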

## Range patterns

122 changes: 85 additions & 37 deletions jacobs2021/src/lib.rs
@@ -42,6 +42,69 @@ impl Constructor {
}
}

/// Expands rows containing OR patterns into individual rows, such that each
/// branch in the OR produces its own row.
///
/// For each column that tests against an OR pattern, each sub pattern is
/// translated into a new row. This work repeats itself until no more OR
/// patterns remain in the rows.
///
/// The implementation here is probably not as fast as it can be. Instead, it's
/// optimized for ease of maintenance and readability.
fn expand_or_patterns(rows: &mut Vec<Row>) {
    // If none of the rows contain any OR patterns, we can avoid the below work
    // loop, saving some allocations and time.
    if !rows
        .iter()
        .any(|r| r.columns.iter().any(|c| matches!(c.pattern, Pattern::Or(_))))
    {
        return;
    }

    // The implementation uses two Vecs: the original one, and a temporary one
    // we push newly created rows into. After processing all rows we swap the
    // two, repeating this process until we no longer find any OR patterns.
    let mut new_rows = Vec::with_capacity(rows.len());
    let mut found = true;

    while found {
        found = false;

        for row in rows.drain(0..) {
            // Find the first column containing an OR pattern. We process this
            // one column at a time, as that's (much) easier to implement
            // compared to handling all columns at once (as multiple columns may
            // contain OR patterns).
            let res = row.columns.iter().enumerate().find_map(|(idx, col)| {
                if let Pattern::Or(pats) = &col.pattern {
                    Some((idx, col.variable, pats))
                } else {
                    None
                }
            });

            if let Some((idx, var, pats)) = res {
                found = true;

                // This creates a new row for each branch in the OR pattern.
                // Other columns are left as-is. If such columns contain OR
                // patterns themselves, we'll expand them in a future iteration
                // of the surrounding `while` loop.
                for pat in pats {
                    let mut new_row = row.clone();

                    new_row.columns[idx] = Column::new(var, pat.clone());
                    new_rows.push(new_row);
                }
            } else {
                new_rows.push(row);
            }
        }

        std::mem::swap(rows, &mut new_rows);
    }
}

/// A user defined pattern such as `Some((x, 10))`.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum Pattern {
@@ -53,16 +116,6 @@ pub enum Pattern {
Range(i64, i64),
}

impl Pattern {
    fn flatten_or(self, row: Row) -> Vec<(Pattern, Row)> {
        if let Pattern::Or(args) = self {
            args.into_iter().map(|p| (p, row.clone())).collect()
        } else {
            vec![(self, row)]
        }
    }
}

/// A representation of a type.
///
/// In a real compiler this would probably be a more complicated structure, but
@@ -384,6 +437,8 @@ impl Compiler {
return Decision::Failure;
}

expand_or_patterns(&mut rows);

for row in &mut rows {
self.move_variable_patterns(row);
}
@@ -479,25 +534,21 @@ impl Compiler {

for mut row in rows {
if let Some(col) = row.remove_column(&branch_var) {
for (pat, row) in col.pattern.flatten_or(row) {
let (key, cons) = match pat {
Pattern::Int(val) => {
((val, val), Constructor::Int(val))
}
Pattern::Range(start, stop) => {
((start, stop), Constructor::Range(start, stop))
}
_ => unreachable!(),
};

if let Some(index) = tested.get(&key) {
raw_cases[*index].2.push(row);
continue;
let (key, cons) = match col.pattern {
Pattern::Int(val) => ((val, val), Constructor::Int(val)),
Pattern::Range(start, stop) => {
((start, stop), Constructor::Range(start, stop))
}
_ => unreachable!(),
};

tested.insert(key, raw_cases.len());
raw_cases.push((cons, Vec::new(), vec![row]));
if let Some(index) = tested.get(&key) {
raw_cases[*index].2.push(row);
continue;
}

tested.insert(key, raw_cases.len());
raw_cases.push((cons, Vec::new(), vec![row]));
} else {
fallback_rows.push(row);
}
@@ -549,19 +600,16 @@ impl Compiler {
) -> Vec<Case> {
for mut row in rows {
if let Some(col) = row.remove_column(&branch_var) {
for (pat, row) in col.pattern.flatten_or(row) {
if let Pattern::Constructor(cons, args) = pat {
let idx = cons.index();
let mut cols = row.columns;

for (var, pat) in
cases[idx].1.iter().zip(args.into_iter())
{
cols.push(Column::new(*var, pat));
}
if let Pattern::Constructor(cons, args) = col.pattern {
let idx = cons.index();
let mut cols = row.columns;

cases[idx].2.push(Row::new(cols, row.guard, row.body));
for (var, pat) in cases[idx].1.iter().zip(args.into_iter())
{
cols.push(Column::new(*var, pat));
}

cases[idx].2.push(Row::new(cols, row.guard, row.body));
}
} else {
for (_, _, rows) in &mut cases {