diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index 0d3b3948..cf9f8bcd 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -128,7 +128,6 @@ impl Codegen { let inner = self.lift_type(t.as_ref()); CompType::Vec(Box::new(inner)).into() } - ValueType::Any => panic!("ValueType::Any"), ValueType::Record(fields) => { let mut rt_fields = Vec::new(); for (lab, ty) in fields.iter() { diff --git a/src/codegen/rust_ast.rs b/src/codegen/rust_ast.rs index 6d979bf4..84b185ef 100644 --- a/src/codegen/rust_ast.rs +++ b/src/codegen/rust_ast.rs @@ -667,7 +667,7 @@ impl TryFrom for RustType { let inner = Self::try_from(t.as_ref().clone())?; Ok(CompType::>::Vec(Box::new(inner)).into()) } - ValueType::Any | ValueType::Record(..) | ValueType::Union(..) => Err(value), + ValueType::Record(..) | ValueType::Union(..) => Err(value), } } } diff --git a/src/decoder.rs b/src/decoder.rs index 43bb777e..ea550ce8 100644 --- a/src/decoder.rs +++ b/src/decoder.rs @@ -516,7 +516,7 @@ impl<'a> Compiler<'a> { let mut compiler = Compiler::new(module); // type let scope = TypeScope::new(); - let t = module.infer_format_type(&scope, format)?; + let t = module.infer_format_type(&scope, format)?.to_value_type(); // decoder compiler.queue_compile(t, format, Rc::new(Next::Empty)); while let Some((f, next, n)) = compiler.compile_queue.pop() { diff --git a/src/lib.rs b/src/lib.rs index d0ab4707..6f46657f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #![allow(clippy::new_without_default)] #![deny(rust_2018_idioms)] +use std::cell::{OnceCell, RefCell}; use std::collections::HashSet; use std::ops::Add; use std::rc::Rc; @@ -57,32 +58,35 @@ impl Pattern { Pattern::Binding(name.into()) } - fn build_scope(&self, scope: &mut TypeScope<'_>, t: &ValueType) { - match (self, t) { + fn build_scope(&self, scope: &mut TypeScope<'_>, t: &VarType) { + match (self, t.expand_var()) { (Pattern::Binding(name), t) => { scope.push(name.clone(), t.clone()); } (Pattern::Wildcard, _) => {} - (Pattern::Bool(_b0), ValueType::Bool) => {} - (Pattern::U8(_i0), ValueType::U8) => {} - (Pattern::U16(_i0), ValueType::U16) => {} - (Pattern::U32(_i0), ValueType::U32) => {} - (Pattern::Tuple(ps), ValueType::Tuple(ts)) if ps.len() == ts.len() => { + (Pattern::Bool(_b0), VarType::Bool) => {} + (Pattern::U8(_i0), VarType::U8) => {} + (Pattern::U16(_i0), VarType::U16) => {} + (Pattern::U32(_i0), VarType::U32) => {} + (Pattern::Tuple(ps), VarType::Tuple(ts)) if ps.len() == ts.len() => { for (p, t) in Iterator::zip(ps.iter(), ts.iter()) { p.build_scope(scope, t); } } - (Pattern::Seq(ps), ValueType::Seq(t)) => { + (Pattern::Seq(ps), VarType::Seq(t)) => { for p in ps { p.build_scope(scope, t); } } - (Pattern::Variant(label, p), ValueType::Union(branches)) => { - if let Some((_l, t)) = branches.iter().find(|(l, _t)| label == l) { - p.build_scope(scope, t); - } else { - panic!("no {label} in {branches:?}"); - } + (Pattern::Variant(label, p), VarType::Union(r)) => { + let u = r.borrow(); + u.with_branches(|branches: &[(Label, VarType)]| { + if let Some((_l, t)) = branches.iter().find(|(l, _t)| label == l) { + p.build_scope(scope, t); + } else { + panic!("no {label} in {branches:?}"); + } + }) } _ => panic!("pattern build_scope failed"), } @@ -91,9 +95,9 @@ impl Pattern { fn infer_expr_branch_type( &self, scope: &TypeScope<'_>, - head_type: &ValueType, + head_type: &VarType, expr: &Expr, - ) -> Result { + ) -> Result { let mut pattern_scope = TypeScope::child(scope); self.build_scope(&mut pattern_scope, head_type); expr.infer_type(&pattern_scope) @@ -102,10 +106,10 @@ impl Pattern { fn infer_format_branch_type( &self, scope: &TypeScope<'_>, - head_type: &ValueType, + head_type: &VarType, module: &FormatModule, format: &Format, - ) -> Result { + ) -> Result { let mut pattern_scope = TypeScope::child(scope); self.build_scope(&mut pattern_scope, head_type); module.infer_format_type(&pattern_scope, format) @@ -113,13 +117,12 @@ impl Pattern { } pub enum ValueKind { - Value(ValueType), - Format(ValueType), + Value(VarType), + Format(VarType), } #[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)] pub enum ValueType { - Any, Empty, Bool, U8, @@ -132,10 +135,176 @@ pub enum ValueType { Seq(Box), } +#[derive(Clone, Debug)] +pub enum VarUnion { + Var(Rc>), + Union(Rc>>), +} + +#[derive(Clone, Debug)] +pub enum VarType { + Var(Rc>), + Empty, + Bool, + U8, + U16, + U32, + Char, + Tuple(Vec), + Record(Vec<(Label, VarType)>), + Union(Rc>), + Seq(Box), +} + impl ValueType { - fn record_proj(&self, label: &str) -> ValueType { + fn to_var_type(&self) -> VarType { + match self { + ValueType::Empty => VarType::Empty, + ValueType::Bool => VarType::Bool, + ValueType::U8 => VarType::U8, + ValueType::U16 => VarType::U16, + ValueType::U32 => VarType::U32, + ValueType::Char => VarType::Char, + ValueType::Tuple(ts) => VarType::Tuple(ts.iter().map(|t| t.to_var_type()).collect()), + ValueType::Record(fields) => VarType::Record( + fields + .iter() + .map(|(l, t)| (l.clone(), t.to_var_type())) + .collect(), + ), + ValueType::Union(branches) => VarType::union( + branches + .iter() + .map(|(l, t)| (l.clone(), t.to_var_type())) + .collect(), + ), + ValueType::Seq(t) => VarType::Seq(Box::new(t.to_var_type())), + } + } +} + +impl VarUnion { + fn with_branches(&self, mut f: impl FnMut(&[(Label, VarType)]) -> T) -> T { + match self { + VarUnion::Var(v) => v.borrow().with_branches(f), + VarUnion::Union(r) => f(&r.borrow()), + } + } + + fn unify(r1: &Rc>, r2: &Rc>) -> Result { + VarUnion::unify1(r1, r2)?; + if r1.as_ptr() != r2.as_ptr() { + *r1.borrow_mut() = VarUnion::Var(r2.clone()); + } + Ok(VarType::Union(r1.clone())) + } + + fn unify1(r1: &Rc>, r2: &Rc>) -> Result<(), String> { + match &*r1.borrow() { + VarUnion::Var(v) => VarUnion::unify1(v, r2), + VarUnion::Union(u1) => VarUnion::unify2(u1, r2), + } + } + + fn unify2( + u1: &Rc>>, + r2: &Rc>, + ) -> Result<(), String> { + match &*r2.borrow() { + VarUnion::Var(v) => VarUnion::unify2(u1, v), + VarUnion::Union(u2) => { + let bs = VarUnion::unify3(u1, u2)?; + *u2.borrow_mut() = bs; + Ok(()) + } + } + } + + fn unify3( + r1: &Rc>>, + r2: &Rc>>, + ) -> Result, String> { + let bs1 = r1.borrow(); + let bs2 = r2.borrow(); + let mut bs: Vec<(Label, VarType)> = Vec::new(); + for (label, t2) in bs2.iter() { + let t = if let Some((_l, t1)) = bs.iter().find(|(l, _)| label == l) { + t1.unify(t2)? + } else { + t2.clone() + }; + bs.push((label.clone(), t)); + } + for (label, t1) in bs1.iter() { + if !bs.iter().any(|(l, _)| label == l) { + bs.push((label.clone(), t1.clone())); + } + } + Ok(bs) + } +} + +impl VarType { + // FIXME all uses of this function should probably be unifying type vars + fn expand_var(&self) -> &VarType { + if let VarType::Var(v) = self { + if let Some(v) = v.get() { + v.expand_var() + } else { + panic!("expand_var: unbound variable") + } + } else { + self + } + } + + fn to_value_type(&self) -> ValueType { match self { - ValueType::Record(fields) => match fields.iter().find(|(l, _)| label == l) { + VarType::Var(v) => { + if let Some(v) = v.get() { + v.to_value_type() + } else { + panic!("to_value_type: unbound variable") + } + } + VarType::Empty => ValueType::Empty, + VarType::Bool => ValueType::Bool, + VarType::U8 => ValueType::U8, + VarType::U16 => ValueType::U16, + VarType::U32 => ValueType::U32, + VarType::Char => ValueType::Char, + VarType::Tuple(ts) => ValueType::Tuple(ts.iter().map(|t| t.to_value_type()).collect()), + VarType::Record(fields) => ValueType::Record( + fields + .iter() + .map(|(l, t)| (l.clone(), t.to_value_type())) + .collect(), + ), + VarType::Union(u) => { + ValueType::Union(u.borrow().with_branches(&|branches: &[(Label, VarType)]| { + branches + .iter() + .map(|(l, t)| (l.clone(), t.to_value_type())) + .collect() + })) + } + VarType::Seq(t) => ValueType::Seq(Box::new(t.to_value_type())), + } + } + + fn var() -> VarType { + VarType::Var(Rc::new(OnceCell::new())) + } + + fn union(branches: Vec<(Label, VarType)>) -> VarType { + VarType::Union(Rc::new(RefCell::new(VarUnion::Union(Rc::new( + RefCell::new(branches), + ))))) + } + + fn record_proj(&self, label: &str) -> VarType { + match self.expand_var() { + VarType::Record(fields) => match fields.iter().find(|(l, _)| label == l) { Some((_, t)) => t.clone(), None => panic!("{label} not found in record type"), }, @@ -143,28 +312,34 @@ impl ValueType { } } - fn unwrap_tuple_type(self) -> Vec { - match self { - ValueType::Tuple(ts) => ts, + fn unwrap_tuple_type(self) -> Vec { + match self.expand_var().clone() { + VarType::Tuple(ts) => ts, _ => panic!("type is not a tuple"), } } fn is_numeric_type(&self) -> bool { - matches!(self, ValueType::U8 | ValueType::U16 | ValueType::U32) + matches!(self.expand_var(), VarType::U8 | VarType::U16 | VarType::U32) } - fn unify(&self, other: &ValueType) -> Result { + fn unify(&self, other: &VarType) -> Result { match (self, other) { - (ValueType::Any, rhs) => Ok(rhs.clone()), - (lhs, ValueType::Any) => Ok(lhs.clone()), - (ValueType::Empty, ValueType::Empty) => Ok(ValueType::Empty), - (ValueType::Bool, ValueType::Bool) => Ok(ValueType::Bool), - (ValueType::U8, ValueType::U8) => Ok(ValueType::U8), - (ValueType::U16, ValueType::U16) => Ok(ValueType::U16), - (ValueType::U32, ValueType::U32) => Ok(ValueType::U32), - (ValueType::Char, ValueType::Char) => Ok(ValueType::Char), - (ValueType::Tuple(ts1), ValueType::Tuple(ts2)) => { + (VarType::Var(v), other) | (other, VarType::Var(v)) => { + if let Some(t) = v.get() { + t.unify(other)?; + } else { + v.set(other.clone()).unwrap() + } + Ok(other.clone()) + } + (VarType::Empty, VarType::Empty) => Ok(VarType::Empty), + (VarType::Bool, VarType::Bool) => Ok(VarType::Bool), + (VarType::U8, VarType::U8) => Ok(VarType::U8), + (VarType::U16, VarType::U16) => Ok(VarType::U16), + (VarType::U32, VarType::U32) => Ok(VarType::U32), + (VarType::Char, VarType::Char) => Ok(VarType::Char), + (VarType::Tuple(ts1), VarType::Tuple(ts2)) => { if ts1.len() != ts2.len() { return Err(format!("tuples must have same length {ts1:?} vs. {ts2:?}")); } @@ -172,9 +347,9 @@ impl ValueType { for (t1, t2) in Iterator::zip(ts1.iter(), ts2.iter()) { ts.push(t1.unify(t2)?); } - Ok(ValueType::Tuple(ts)) + Ok(VarType::Tuple(ts)) } - (ValueType::Record(fs1), ValueType::Record(fs2)) => { + (VarType::Record(fs1), VarType::Record(fs2)) => { if fs1.len() != fs2.len() { return Err(format!( "records must have same number of fields {fs1:?} vs. {fs2:?}" @@ -188,26 +363,10 @@ impl ValueType { } fs.push((l1.clone(), t1.unify(t2)?)); } - Ok(ValueType::Record(fs)) - } - (ValueType::Union(bs1), ValueType::Union(bs2)) => { - let mut bs: Vec<(Label, ValueType)> = Vec::new(); - for (label, t2) in bs2 { - let t = if let Some((_l, t1)) = bs.iter().find(|(l, _)| label == l) { - t1.unify(t2)? - } else { - t2.clone() - }; - bs.push((label.clone(), t)); - } - for (label, t1) in bs1 { - if !bs.iter().any(|(l, _)| label == l) { - bs.push((label.clone(), t1.clone())); - } - } - Ok(ValueType::Union(bs)) + Ok(VarType::Record(fs)) } - (ValueType::Seq(t1), ValueType::Seq(t2)) => Ok(ValueType::Seq(Box::new(t1.unify(t2)?))), + (VarType::Union(r1), VarType::Union(r2)) => VarUnion::unify(r1, r2), + (VarType::Seq(t1), VarType::Seq(t2)) => Ok(VarType::Seq(Box::new(t1.unify(t2)?))), (t1, t2) => Err(format!("failed to unify types {t1:?} and {t2:?}")), } } @@ -283,25 +442,25 @@ impl Expr { } impl Expr { - fn infer_type(&self, scope: &TypeScope<'_>) -> Result { + fn infer_type(&self, scope: &TypeScope<'_>) -> Result { match self { Expr::Var(name) => match scope.get_type_by_name(name) { ValueKind::Value(t) => Ok(t.clone()), ValueKind::Format(_t) => Err("expected value type".to_string()), }, - Expr::Bool(_b) => Ok(ValueType::Bool), - Expr::U8(_i) => Ok(ValueType::U8), - Expr::U16(_i) => Ok(ValueType::U16), - Expr::U32(_i) => Ok(ValueType::U32), + Expr::Bool(_b) => Ok(VarType::Bool), + Expr::U8(_i) => Ok(VarType::U8), + Expr::U16(_i) => Ok(VarType::U16), + Expr::U32(_i) => Ok(VarType::U32), Expr::Tuple(exprs) => { let mut ts = Vec::new(); for expr in exprs { ts.push(expr.infer_type(scope)?); } - Ok(ValueType::Tuple(ts)) + Ok(VarType::Tuple(ts)) } Expr::TupleProj(head, index) => match head.infer_type(scope)? { - ValueType::Tuple(vs) => Ok(vs[*index].clone()), + VarType::Tuple(vs) => Ok(vs[*index].clone()), _ => Err("expected tuple type".to_string()), }, Expr::Record(fields) => { @@ -309,26 +468,26 @@ impl Expr { for (label, expr) in fields { fs.push((label.clone(), expr.infer_type(scope)?)); } - Ok(ValueType::Record(fs)) + Ok(VarType::Record(fs)) } Expr::RecordProj(head, label) => Ok(head.infer_type(scope)?.record_proj(label)), - Expr::Variant(label, expr) => Ok(ValueType::Union(vec![( + Expr::Variant(label, expr) => Ok(VarType::union(vec![( label.clone(), expr.infer_type(scope)?, )])), Expr::Seq(exprs) => { - let mut t = ValueType::Any; + let mut t = VarType::var(); for e in exprs { t = t.unify(&e.infer_type(scope)?)?; } - Ok(ValueType::Seq(Box::new(t))) + Ok(VarType::Seq(Box::new(t))) } Expr::Match(head, branches) => { if branches.is_empty() { return Err("infer_type: empty Match".to_string()); } let head_type = head.infer_type(scope)?; - let mut t = ValueType::Any; + let mut t = VarType::var(); for (pattern, branch) in branches { t = t.unify(&pattern.infer_expr_branch_type(scope, &head_type, branch)?)?; } @@ -336,70 +495,76 @@ impl Expr { } Expr::Lambda(_, _) => Err("cannot infer_type lambda".to_string()), - Expr::IntRel(_rel, x, y) => match (x.infer_type(scope)?, y.infer_type(scope)?) { - (ValueType::U8, ValueType::U8) => Ok(ValueType::Bool), - (ValueType::U16, ValueType::U16) => Ok(ValueType::Bool), - (ValueType::U32, ValueType::U32) => Ok(ValueType::Bool), + Expr::IntRel(_rel, x, y) => match ( + x.infer_type(scope)?.expand_var(), + y.infer_type(scope)?.expand_var(), + ) { + (VarType::U8, VarType::U8) => Ok(VarType::Bool), + (VarType::U16, VarType::U16) => Ok(VarType::Bool), + (VarType::U32, VarType::U32) => Ok(VarType::Bool), (x, y) => Err(format!("mismatched operands {x:?}, {y:?}")), }, - Expr::Arith(_arith, x, y) => match (x.infer_type(scope)?, y.infer_type(scope)?) { - (ValueType::U8, ValueType::U8) => Ok(ValueType::U8), - (ValueType::U16, ValueType::U16) => Ok(ValueType::U16), - (ValueType::U32, ValueType::U32) => Ok(ValueType::U32), + Expr::Arith(_arith, x, y) => match ( + x.infer_type(scope)?.expand_var(), + y.infer_type(scope)?.expand_var(), + ) { + (VarType::U8, VarType::U8) => Ok(VarType::U8), + (VarType::U16, VarType::U16) => Ok(VarType::U16), + (VarType::U32, VarType::U32) => Ok(VarType::U32), (x, y) => Err(format!("mismatched operands {x:?}, {y:?}")), }, - Expr::AsU8(x) => match x.infer_type(scope)? { - ValueType::U8 => Ok(ValueType::U8), - ValueType::U16 => Ok(ValueType::U8), - ValueType::U32 => Ok(ValueType::U8), + Expr::AsU8(x) => match x.infer_type(scope)?.expand_var() { + VarType::U8 => Ok(VarType::U8), + VarType::U16 => Ok(VarType::U8), + VarType::U32 => Ok(VarType::U8), x => Err(format!("cannot convert {x:?} to U8")), }, - Expr::AsU16(x) => match x.infer_type(scope)? { - ValueType::U8 => Ok(ValueType::U16), - ValueType::U16 => Ok(ValueType::U16), - ValueType::U32 => Ok(ValueType::U16), + Expr::AsU16(x) => match x.infer_type(scope)?.expand_var() { + VarType::U8 => Ok(VarType::U16), + VarType::U16 => Ok(VarType::U16), + VarType::U32 => Ok(VarType::U16), x => Err(format!("cannot convert {x:?} to U16")), }, - Expr::AsU32(x) => match x.infer_type(scope)? { - ValueType::U8 => Ok(ValueType::U32), - ValueType::U16 => Ok(ValueType::U32), - ValueType::U32 => Ok(ValueType::U32), + Expr::AsU32(x) => match x.infer_type(scope)?.expand_var() { + VarType::U8 => Ok(VarType::U32), + VarType::U16 => Ok(VarType::U32), + VarType::U32 => Ok(VarType::U32), x => Err(format!("cannot convert {x:?} to U32")), }, - Expr::AsChar(x) => match x.infer_type(scope)? { - ValueType::U8 => Ok(ValueType::Char), - ValueType::U16 => Ok(ValueType::Char), - ValueType::U32 => Ok(ValueType::Char), + Expr::AsChar(x) => match x.infer_type(scope)?.expand_var() { + VarType::U8 => Ok(VarType::Char), + VarType::U16 => Ok(VarType::Char), + VarType::U32 => Ok(VarType::Char), x => Err(format!("cannot convert {x:?} to Char")), }, Expr::U16Be(bytes) => match bytes.infer_type(scope)?.unwrap_tuple_type().as_slice() { - [ValueType::U8, ValueType::U8] => Ok(ValueType::U16), + [VarType::U8, VarType::U8] => Ok(VarType::U16), other => Err(format!("U16Be: expected (U8, U8), found {other:#?}")), }, Expr::U16Le(bytes) => match bytes.infer_type(scope)?.unwrap_tuple_type().as_slice() { - [ValueType::U8, ValueType::U8] => Ok(ValueType::U16), + [VarType::U8, VarType::U8] => Ok(VarType::U16), other => Err(format!("U16Le: expected (U8, U8), found {other:#?}")), }, Expr::U32Be(bytes) => match bytes.infer_type(scope)?.unwrap_tuple_type().as_slice() { - [ValueType::U8, ValueType::U8, ValueType::U8, ValueType::U8] => Ok(ValueType::U32), + [VarType::U8, VarType::U8, VarType::U8, VarType::U8] => Ok(VarType::U32), other => Err(format!( "U32Be: expected (U8, U8, U8, U8), found {other:#?}" )), }, Expr::U32Le(bytes) => match bytes.infer_type(scope)?.unwrap_tuple_type().as_slice() { - [ValueType::U8, ValueType::U8, ValueType::U8, ValueType::U8] => Ok(ValueType::U32), + [VarType::U8, VarType::U8, VarType::U8, VarType::U8] => Ok(VarType::U32), other => Err(format!( "U32Le: expected (U8, U8, U8, U8), found {other:#?}" )), }, - Expr::SeqLength(seq) => match seq.infer_type(scope)? { - ValueType::Seq(_t) => Ok(ValueType::U32), + Expr::SeqLength(seq) => match seq.infer_type(scope)?.expand_var() { + VarType::Seq(_t) => Ok(VarType::U32), other => Err(format!("SeqLength: expected Seq, found {other:?}")), }, - Expr::SubSeq(seq, start, length) => match seq.infer_type(scope)? { - ValueType::Seq(t) => { + Expr::SubSeq(seq, start, length) => match seq.infer_type(scope)?.expand_var() { + VarType::Seq(t) => { let start_type = start.infer_type(scope)?; let length_type = length.infer_type(scope)?; if !start_type.is_numeric_type() { @@ -412,17 +577,17 @@ impl Expr { "SubSeq length must be numeric, found {length_type:?}" )); } - Ok(ValueType::Seq(t)) + Ok(VarType::Seq(t.clone())) } other => Err(format!("SubSeq: expected Seq, found {other:?}")), }, Expr::FlatMap(expr, seq) => match expr.as_ref() { - Expr::Lambda(name, expr) => match seq.infer_type(scope)? { - ValueType::Seq(t) => { + Expr::Lambda(name, expr) => match seq.infer_type(scope)?.expand_var() { + VarType::Seq(t) => { let mut child_scope = TypeScope::child(scope); - child_scope.push(name.clone(), *t); - match expr.infer_type(&child_scope)? { - ValueType::Seq(t2) => Ok(ValueType::Seq(t2)), + child_scope.push(name.clone(), (**t).clone()); + match expr.infer_type(&child_scope)?.expand_var() { + VarType::Seq(t2) => Ok(VarType::Seq(t2.clone())), other => Err(format!("FlatMap: expected Seq, found {other:?}")), } } @@ -431,20 +596,23 @@ impl Expr { other => Err(format!("FlatMap: expected Lambda, found {other:?}")), }, Expr::FlatMapAccum(expr, accum, accum_type, seq) => match expr.as_ref() { - Expr::Lambda(name, expr) => match seq.infer_type(scope)? { - ValueType::Seq(t) => { - let accum_type = accum.infer_type(scope)?.unify(accum_type)?; + Expr::Lambda(name, expr) => match seq.infer_type(scope)?.expand_var() { + VarType::Seq(t) => { + let accum_type = + accum.infer_type(scope)?.unify(&accum_type.to_var_type())?; let mut child_scope = TypeScope::child(scope); - child_scope - .push(name.clone(), ValueType::Tuple(vec![accum_type.clone(), *t])); + child_scope.push( + name.clone(), + VarType::Tuple(vec![accum_type.clone(), *t.clone()]), + ); match expr .infer_type(&child_scope)? .unwrap_tuple_type() .as_mut_slice() { - [accum_result, ValueType::Seq(t2)] => { + [accum_result, VarType::Seq(t2)] => { accum_result.unify(&accum_type)?; - Ok(ValueType::Seq(t2.clone())) + Ok(VarType::Seq(t2.clone())) } _ => panic!("FlatMapAccum: expected two values"), } @@ -458,11 +626,11 @@ impl Expr { return Err(format!("Dup: count is not numeric: {count:?}")); } let t = expr.infer_type(scope)?; - Ok(ValueType::Seq(Box::new(t))) + Ok(VarType::Seq(Box::new(t))) } - Expr::Inflate(seq) => match seq.infer_type(scope)? { + Expr::Inflate(seq) => match seq.infer_type(scope)?.expand_var() { // FIXME should check values are appropriate variants - ValueType::Seq(_values) => Ok(ValueType::Seq(Box::new(ValueType::U8))), + VarType::Seq(_values) => Ok(VarType::Seq(Box::new(VarType::U8))), other => Err(format!("Inflate: expected Seq, found {other:?}")), }, } @@ -800,10 +968,10 @@ impl FormatModule { ) -> FormatRef { let mut scope = TypeScope::new(); for (arg_name, arg_type) in &args { - scope.push(arg_name.clone(), arg_type.clone()); + scope.push(arg_name.clone(), arg_type.to_var_type()); } let format_type = match self.infer_format_type(&scope, &format) { - Ok(t) => t, + Ok(t) => t.to_value_type(), Err(msg) => panic!("{msg}"), }; let level = self.names.len(); @@ -830,21 +998,21 @@ impl FormatModule { &self.format_types[level] } - fn infer_format_type(&self, scope: &TypeScope<'_>, f: &Format) -> Result { + fn infer_format_type(&self, scope: &TypeScope<'_>, f: &Format) -> Result { match f { Format::ItemVar(level, arg_exprs) => { let arg_names = self.get_args(*level); for ((_name, arg_type), expr) in Iterator::zip(arg_names.iter(), arg_exprs.iter()) { let t = expr.infer_type(scope)?; - let _t = arg_type.unify(&t)?; + let _t = arg_type.to_var_type().unify(&t)?; } - Ok(self.get_format_type(*level).clone()) + Ok(self.get_format_type(*level).to_var_type()) } - Format::Fail => Ok(ValueType::Empty), - Format::EndOfInput => Ok(ValueType::Tuple(vec![])), - Format::Align(_n) => Ok(ValueType::Tuple(vec![])), - Format::Byte(_bs) => Ok(ValueType::U8), - Format::Variant(label, f) => Ok(ValueType::Union(vec![( + Format::Fail => Ok(VarType::Empty), + Format::EndOfInput => Ok(VarType::Tuple(vec![])), + Format::Align(_n) => Ok(VarType::Tuple(vec![])), + Format::Byte(_bs) => Ok(VarType::U8), + Format::Variant(label, f) => Ok(VarType::union(vec![( label.clone(), self.infer_format_type(scope, f)?, )])), @@ -853,10 +1021,10 @@ impl FormatModule { for (label, f) in branches { ts.push((label.clone(), self.infer_format_type(scope, f)?)); } - Ok(ValueType::Union(ts)) + Ok(VarType::union(ts)) } Format::Union(branches) => { - let mut t = ValueType::Any; + let mut t = VarType::var(); for f in branches { t = t.unify(&self.infer_format_type(scope, f)?)?; } @@ -867,7 +1035,7 @@ impl FormatModule { for f in fields { ts.push(self.infer_format_type(scope, f)?); } - Ok(ValueType::Tuple(ts)) + Ok(VarType::Tuple(ts)) } Format::Record(fields) => { let mut ts = Vec::with_capacity(fields.len()); @@ -877,20 +1045,20 @@ impl FormatModule { ts.push((label.clone(), t.clone())); record_scope.push(label.clone(), t); } - Ok(ValueType::Record(ts)) + Ok(VarType::Record(ts)) } Format::Repeat(a) | Format::Repeat1(a) => { let t = self.infer_format_type(scope, a)?; - Ok(ValueType::Seq(Box::new(t))) + Ok(VarType::Seq(Box::new(t))) } Format::RepeatCount(_expr, a) | Format::RepeatUntilLast(_expr, a) | Format::RepeatUntilSeq(_expr, a) => { let t = self.infer_format_type(scope, a)?; - Ok(ValueType::Seq(Box::new(t))) + Ok(VarType::Seq(Box::new(t))) } Format::Peek(a) => self.infer_format_type(scope, a), - Format::PeekNot(_a) => Ok(ValueType::Tuple(vec![])), + Format::PeekNot(_a) => Ok(VarType::Tuple(vec![])), Format::Slice(_expr, a) => self.infer_format_type(scope, a), Format::Bits(a) => self.infer_format_type(scope, a), Format::WithRelativeOffset(_expr, a) => self.infer_format_type(scope, a), @@ -917,7 +1085,7 @@ impl FormatModule { return Err("infer_format_type: empty Match".to_string()); } let head_type = head.infer_type(scope)?; - let mut t = ValueType::Any; + let mut t = VarType::var(); for (pattern, branch) in branches { t = t.unify( &pattern.infer_format_branch_type(scope, &head_type, self, branch)?, @@ -928,9 +1096,9 @@ impl FormatModule { Format::Dynamic(name, dynformat, format) => { match dynformat { DynFormat::Huffman(lengths_expr, _opt_values_expr) => { - match lengths_expr.infer_type(scope)? { - ValueType::Seq(t) => match &*t { - ValueType::U8 | ValueType::U16 => {} + match lengths_expr.infer_type(scope)?.expand_var() { + VarType::Seq(t) => match t.expand_var() { + VarType::U8 | VarType::U16 => {} other => { return Err(format!( "Huffman: expected U8 or U16, found {other:?}" @@ -943,7 +1111,7 @@ impl FormatModule { } } let mut child_scope = TypeScope::child(scope); - child_scope.push_format(name.clone(), ValueType::U16); + child_scope.push_format(name.clone(), VarType::U16); self.infer_format_type(&child_scope, format) } Format::Apply(name) => match scope.get_type_by_name(name) { @@ -1504,12 +1672,12 @@ impl<'a> TypeScope<'a> { } } - fn push(&mut self, name: Label, t: ValueType) { + fn push(&mut self, name: Label, t: VarType) { self.names.push(name); self.types.push(ValueKind::Value(t)); } - fn push_format(&mut self, name: Label, t: ValueType) { + fn push_format(&mut self, name: Label, t: VarType) { self.names.push(name); self.types.push(ValueKind::Format(t)); }