@@ -34,6 +34,7 @@ use crate::surface::elaboration::reporting::Message;
3434use crate :: surface:: { distillation, pretty, BinOp , FormatField , Item , Module , Pattern , Term } ;
3535
3636mod order;
37+ mod patterns;
3738mod reporting;
3839mod unification;
3940
@@ -463,6 +464,70 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
463464 ( labels. into ( ) , filtered_fields)
464465 }
465466
467+ fn check_tuple_fields < F > (
468+ & mut self ,
469+ range : ByteRange ,
470+ fields : & [ F ] ,
471+ get_range : fn ( & F ) -> ByteRange ,
472+ expected_labels : & [ StringId ] ,
473+ ) -> Result < ( ) , ( ) > {
474+ if fields. len ( ) == expected_labels. len ( ) {
475+ return Ok ( ( ) ) ;
476+ }
477+
478+ let mut found_labels = Vec :: with_capacity ( fields. len ( ) ) ;
479+ let mut fields_iter = fields. iter ( ) . enumerate ( ) . peekable ( ) ;
480+ let mut expected_labels_iter = expected_labels. iter ( ) ;
481+
482+ // use the label names from the expected labels
483+ while let Some ( ( ( _, field) , label) ) =
484+ Option :: zip ( fields_iter. peek ( ) , expected_labels_iter. next ( ) )
485+ {
486+ found_labels. push ( ( get_range ( field) , * label) ) ;
487+ fields_iter. next ( ) ;
488+ }
489+
490+ // use numeric labels for excess fields
491+ for ( index, field) in fields_iter {
492+ found_labels. push ( (
493+ get_range ( field) ,
494+ self . interner . borrow_mut ( ) . get_tuple_label ( index) ,
495+ ) ) ;
496+ }
497+
498+ self . push_message ( Message :: MismatchedFieldLabels {
499+ range,
500+ found_labels,
501+ expected_labels : expected_labels. to_vec ( ) ,
502+ } ) ;
503+ Err ( ( ) )
504+ }
505+
506+ fn check_record_fields < F > (
507+ & mut self ,
508+ range : ByteRange ,
509+ fields : & [ F ] ,
510+ get_label : impl Fn ( & F ) -> ( ByteRange , StringId ) ,
511+ labels : & ' arena [ StringId ] ,
512+ ) -> Result < ( ) , ( ) > {
513+ if fields. len ( ) == labels. len ( )
514+ && fields
515+ . iter ( )
516+ . zip ( labels. iter ( ) )
517+ . all ( |( field, type_label) | get_label ( field) . 1 == * type_label)
518+ {
519+ return Ok ( ( ) ) ;
520+ }
521+
522+ // TODO: improve handling of duplicate labels
523+ self . push_message ( Message :: MismatchedFieldLabels {
524+ range,
525+ found_labels : fields. iter ( ) . map ( get_label) . collect ( ) ,
526+ expected_labels : labels. to_vec ( ) ,
527+ } ) ;
528+ Err ( ( ) )
529+ }
530+
466531 /// Parse a source string into number, assuming an ASCII encoding.
467532 fn parse_ascii < T > (
468533 & mut self ,
@@ -696,177 +761,6 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
696761 term
697762 }
698763
699- /// Check that a pattern matches an expected type.
700- fn check_pattern (
701- & mut self ,
702- pattern : & Pattern < ByteRange > ,
703- expected_type : & ArcValue < ' arena > ,
704- ) -> CheckedPattern {
705- match pattern {
706- Pattern :: Name ( range, name) => CheckedPattern :: Binder ( * range, * name) ,
707- Pattern :: Placeholder ( range) => CheckedPattern :: Placeholder ( * range) ,
708- Pattern :: StringLiteral ( range, lit) => {
709- let constant = match expected_type. match_prim_spine ( ) {
710- Some ( ( Prim :: U8Type , [ ] ) ) => self . parse_ascii ( * range, * lit, Const :: U8 ) ,
711- Some ( ( Prim :: U16Type , [ ] ) ) => self . parse_ascii ( * range, * lit, Const :: U16 ) ,
712- Some ( ( Prim :: U32Type , [ ] ) ) => self . parse_ascii ( * range, * lit, Const :: U32 ) ,
713- Some ( ( Prim :: U64Type , [ ] ) ) => self . parse_ascii ( * range, * lit, Const :: U64 ) ,
714- // Some((Prim::Array8Type, [len, _])) => todo!(),
715- // Some((Prim::Array16Type, [len, _])) => todo!(),
716- // Some((Prim::Array32Type, [len, _])) => todo!(),
717- // Some((Prim::Array64Type, [len, _])) => todo!(),
718- Some ( ( Prim :: ReportedError , _) ) => None ,
719- _ => {
720- let expected_type = self . pretty_print_value ( expected_type) ;
721- self . push_message ( Message :: StringLiteralNotSupported {
722- range : * range,
723- expected_type,
724- } ) ;
725- None
726- }
727- } ;
728-
729- match constant {
730- Some ( constant) => CheckedPattern :: ConstLit ( * range, constant) ,
731- None => CheckedPattern :: ReportedError ( * range) ,
732- }
733- }
734- Pattern :: NumberLiteral ( range, lit) => {
735- let constant = match expected_type. match_prim_spine ( ) {
736- Some ( ( Prim :: U8Type , [ ] ) ) => self . parse_number_radix ( * range, * lit, Const :: U8 ) ,
737- Some ( ( Prim :: U16Type , [ ] ) ) => self . parse_number_radix ( * range, * lit, Const :: U16 ) ,
738- Some ( ( Prim :: U32Type , [ ] ) ) => self . parse_number_radix ( * range, * lit, Const :: U32 ) ,
739- Some ( ( Prim :: U64Type , [ ] ) ) => self . parse_number_radix ( * range, * lit, Const :: U64 ) ,
740- Some ( ( Prim :: S8Type , [ ] ) ) => self . parse_number ( * range, * lit, Const :: S8 ) ,
741- Some ( ( Prim :: S16Type , [ ] ) ) => self . parse_number ( * range, * lit, Const :: S16 ) ,
742- Some ( ( Prim :: S32Type , [ ] ) ) => self . parse_number ( * range, * lit, Const :: S32 ) ,
743- Some ( ( Prim :: S64Type , [ ] ) ) => self . parse_number ( * range, * lit, Const :: S64 ) ,
744- Some ( ( Prim :: F32Type , [ ] ) ) => self . parse_number ( * range, * lit, Const :: F32 ) ,
745- Some ( ( Prim :: F64Type , [ ] ) ) => self . parse_number ( * range, * lit, Const :: F64 ) ,
746- Some ( ( Prim :: ReportedError , _) ) => None ,
747- _ => {
748- let expected_type = self . pretty_print_value ( expected_type) ;
749- self . push_message ( Message :: NumericLiteralNotSupported {
750- range : * range,
751- expected_type,
752- } ) ;
753- None
754- }
755- } ;
756-
757- match constant {
758- Some ( constant) => CheckedPattern :: ConstLit ( * range, constant) ,
759- None => CheckedPattern :: ReportedError ( * range) ,
760- }
761- }
762- Pattern :: BooleanLiteral ( range, boolean) => {
763- let constant = match expected_type. match_prim_spine ( ) {
764- Some ( ( Prim :: BoolType , [ ] ) ) => match * boolean {
765- true => Some ( Const :: Bool ( true ) ) ,
766- false => Some ( Const :: Bool ( false ) ) ,
767- } ,
768- _ => {
769- self . push_message ( Message :: BooleanLiteralNotSupported { range : * range } ) ;
770- None
771- }
772- } ;
773-
774- match constant {
775- Some ( constant) => CheckedPattern :: ConstLit ( * range, constant) ,
776- None => CheckedPattern :: ReportedError ( * range) ,
777- }
778- }
779- Pattern :: RecordLiteral ( _, _) => todo ! ( ) ,
780- Pattern :: Tuple ( _, _) => todo ! ( ) ,
781- }
782- }
783-
784- /// Synthesize the type of a pattern.
785- fn synth_pattern (
786- & mut self ,
787- pattern : & Pattern < ByteRange > ,
788- ) -> ( CheckedPattern , ArcValue < ' arena > ) {
789- match pattern {
790- Pattern :: Name ( range, name) => {
791- let source = MetaSource :: NamedPatternType ( * range, * name) ;
792- let r#type = self . push_unsolved_type ( source) ;
793- ( CheckedPattern :: Binder ( * range, * name) , r#type)
794- }
795- Pattern :: Placeholder ( range) => {
796- let source = MetaSource :: PlaceholderPatternType ( * range) ;
797- let r#type = self . push_unsolved_type ( source) ;
798- ( CheckedPattern :: Placeholder ( * range) , r#type)
799- }
800- Pattern :: StringLiteral ( range, _) => {
801- self . push_message ( Message :: AmbiguousStringLiteral { range : * range } ) ;
802- let source = MetaSource :: ReportedErrorType ( * range) ;
803- let r#type = self . push_unsolved_type ( source) ;
804- ( CheckedPattern :: ReportedError ( * range) , r#type)
805- }
806- Pattern :: NumberLiteral ( range, _) => {
807- self . push_message ( Message :: AmbiguousNumericLiteral { range : * range } ) ;
808- let source = MetaSource :: ReportedErrorType ( * range) ;
809- let r#type = self . push_unsolved_type ( source) ;
810- ( CheckedPattern :: ReportedError ( * range) , r#type)
811- }
812- Pattern :: BooleanLiteral ( range, val) => {
813- let r#const = Const :: Bool ( * val) ;
814- let r#type = self . bool_type . clone ( ) ;
815- ( CheckedPattern :: ConstLit ( * range, r#const) , r#type)
816- }
817- Pattern :: RecordLiteral ( _, _) => todo ! ( ) ,
818- Pattern :: Tuple ( _, _) => todo ! ( ) ,
819- }
820- }
821-
822- /// Check that the type of an annotated pattern matches an expected type.
823- fn check_ann_pattern (
824- & mut self ,
825- pattern : & Pattern < ByteRange > ,
826- r#type : Option < & Term < ' _ , ByteRange > > ,
827- expected_type : & ArcValue < ' arena > ,
828- ) -> CheckedPattern {
829- match r#type {
830- None => self . check_pattern ( pattern, expected_type) ,
831- Some ( r#type) => {
832- let range = r#type. range ( ) ;
833- let r#type = self . check ( r#type, & self . universe . clone ( ) ) ;
834- let r#type = self . eval_env ( ) . eval ( & r#type) ;
835-
836- match self . unification_context ( ) . unify ( & r#type, expected_type) {
837- Ok ( ( ) ) => self . check_pattern ( pattern, & r#type) ,
838- Err ( error) => {
839- let lhs = self . pretty_print_value ( & r#type) ;
840- let rhs = self . pretty_print_value ( expected_type) ;
841- self . push_message ( Message :: FailedToUnify {
842- range,
843- lhs,
844- rhs,
845- error,
846- } ) ;
847- CheckedPattern :: ReportedError ( range)
848- }
849- }
850- }
851- }
852- }
853-
854- /// Synthesize the type of an annotated pattern.
855- fn synth_ann_pattern (
856- & mut self ,
857- pattern : & Pattern < ByteRange > ,
858- r#type : Option < & Term < ' _ , ByteRange > > ,
859- ) -> ( CheckedPattern , ArcValue < ' arena > ) {
860- match r#type {
861- None => self . synth_pattern ( pattern) ,
862- Some ( r#type) => {
863- let r#type = self . check ( r#type, & self . universe . clone ( ) ) ;
864- let type_value = self . eval_env ( ) . eval ( & r#type) ;
865- ( self . check_pattern ( pattern, & type_value) , type_value)
866- }
867- }
868- }
869-
870764 /// Push a local definition onto the context.
871765 /// The supplied `pattern` is expected to be irrefutable.
872766 fn push_local_def (
@@ -886,6 +780,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
886780 None
887781 }
888782 CheckedPattern :: ReportedError ( _) => None ,
783+ CheckedPattern :: RecordLit ( _, _, _) => todo ! ( ) ,
889784 } ;
890785
891786 self . local_env . push_def ( name, expr, r#type) ;
@@ -911,6 +806,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
911806 None
912807 }
913808 CheckedPattern :: ReportedError ( _) => None ,
809+ CheckedPattern :: RecordLit ( _, _, _) => todo ! ( ) ,
914810 } ;
915811
916812 let expr = self . local_env . push_param ( name, r#type) ;
@@ -970,18 +866,10 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
970866 self . check_fun_lit ( * range, patterns, body_expr, & expected_type)
971867 }
972868 ( Term :: RecordLiteral ( range, expr_fields) , Value :: RecordType ( labels, types) ) => {
973- // TODO: improve handling of duplicate labels
974- if expr_fields. len ( ) != labels. len ( )
975- || Iterator :: zip ( expr_fields. iter ( ) , labels. iter ( ) )
976- . any ( |( expr_field, type_label) | expr_field. label . 1 != * type_label)
869+ if self
870+ . check_record_fields ( * range, expr_fields, |field| field. label , labels)
871+ . is_err ( )
977872 {
978- self . push_message ( Message :: MismatchedFieldLabels {
979- range : * range,
980- expr_labels : ( expr_fields. iter ( ) )
981- . map ( |expr_field| expr_field. label )
982- . collect ( ) ,
983- type_labels : labels. to_vec ( ) ,
984- } ) ;
985873 return core:: Term :: Prim ( range. into ( ) , Prim :: ReportedError ) ;
986874 }
987875
@@ -1045,33 +933,11 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
1045933 core:: Term :: FormatRecord ( range. into ( ) , labels, formats)
1046934 }
1047935 ( Term :: Tuple ( range, elem_exprs) , Value :: RecordType ( labels, types) ) => {
1048- if elem_exprs. len ( ) != labels. len ( ) {
1049- let mut expr_labels = Vec :: with_capacity ( elem_exprs. len ( ) ) ;
1050- let mut elem_exprs = elem_exprs. iter ( ) . enumerate ( ) . peekable ( ) ;
1051- let mut label_iter = labels. iter ( ) ;
1052-
1053- // use the label names from the expected type
1054- while let Some ( ( ( _, elem_expr) , label) ) =
1055- Option :: zip ( elem_exprs. peek ( ) , label_iter. next ( ) )
1056- {
1057- expr_labels. push ( ( elem_expr. range ( ) , * label) ) ;
1058- elem_exprs. next ( ) ;
1059- }
1060-
1061- // use numeric labels for excess elems
1062- for ( index, elem_expr) in elem_exprs {
1063- expr_labels. push ( (
1064- elem_expr. range ( ) ,
1065- self . interner . borrow_mut ( ) . get_tuple_label ( index) ,
1066- ) ) ;
1067- }
1068-
1069- self . push_message ( Message :: MismatchedFieldLabels {
1070- range : * range,
1071- expr_labels,
1072- type_labels : labels. to_vec ( ) ,
1073- } ) ;
1074- return core:: Term :: Prim ( range. into ( ) , Prim :: ReportedError ) ;
936+ if self
937+ . check_tuple_fields ( * range, elem_exprs, |expr| expr. range ( ) , labels)
938+ . is_err ( )
939+ {
940+ return core:: Term :: error ( range. into ( ) ) ;
1075941 }
1076942
1077943 let mut types = types. clone ( ) ;
@@ -2027,6 +1893,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
20271893 self . elab_match_unreachable ( match_info, equations) ;
20281894 core:: Term :: Prim ( range. into ( ) , Prim :: ReportedError )
20291895 }
1896+ CheckedPattern :: RecordLit ( _, _, _) => todo ! ( ) ,
20301897 }
20311898 }
20321899 None => self . elab_match_absurd ( is_reachable, match_info) ,
@@ -2116,6 +1983,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
21161983 default_branch = ( None , self . scope . to_scope ( default_expr) as & _ ) ;
21171984 self . local_env . pop ( ) ;
21181985 }
1986+ CheckedPattern :: RecordLit ( _, _, _) => todo ! ( ) ,
21191987 } ;
21201988
21211989 // A default pattern was found, check any unreachable patterns.
@@ -2196,15 +2064,17 @@ impl_from_str_radix!(u64);
21962064
21972065/// Simple patterns that have had some initial elaboration performed on them
21982066#[ derive( Debug ) ]
2199- enum CheckedPattern {
2200- /// Pattern that binds local variable
2201- Binder ( ByteRange , StringId ) ,
2067+ enum CheckedPattern < ' arena > {
2068+ /// Error sentinel
2069+ ReportedError ( ByteRange ) ,
22022070 /// Placeholder patterns that match everything
22032071 Placeholder ( ByteRange ) ,
2072+ /// Pattern that binds local variable
2073+ Binder ( ByteRange , StringId ) ,
22042074 /// Constant literals
22052075 ConstLit ( ByteRange , Const ) ,
2206- /// Error sentinel
2207- ReportedError ( ByteRange ) ,
2076+ /// Record literals
2077+ RecordLit ( ByteRange , & ' arena [ StringId ] , & ' arena [ Self ] ) ,
22082078}
22092079
22102080/// Scrutinee of a match expression
0 commit comments