diff --git a/Strata/DDM/AST.lean b/Strata/DDM/AST.lean index 55c48ecdb..2cfc6aa39 100644 --- a/Strata/DDM/AST.lean +++ b/Strata/DDM/AST.lean @@ -13,6 +13,7 @@ public import Strata.DDM.Util.SourceRange import Std.Data.HashMap import all Strata.DDM.Util.Array +import Strata.Util.DecideProp import all Strata.DDM.Util.ByteArray set_option autoImplicit false @@ -614,6 +615,23 @@ def resultLevel (varCount : Nat) (metadata : Metadata) : Option (Fin varCount) : else panic! s!"Scope index {idx} out of bounds (varCount = {varCount})" +/-- Returns the argument index from @[preRegisterTypes] metadata, if present. -/ +def preRegisterTypesIndex (metadata : Metadata) : Option Nat := + match metadata[q`StrataDDL.preRegisterTypes]? with + | none => none + | some #[.catbvar idx] => some idx + | some _ => panic! s!"Unexpected argument count to preRegisterTypes" + +/-- Returns the level for @[preRegisterTypes] metadata, if present. -/ +def preRegisterTypesLevel (varCount : Nat) (metadata : Metadata) : Option (Fin varCount) := + match metadata.preRegisterTypesIndex with + | none => none + | some idx => + if _ : idx < varCount then + some ⟨varCount - (idx + 1), by omega⟩ + else + panic! s!"preRegisterTypes index {idx} out of bounds (varCount = {varCount})" + end Metadata abbrev Var := String @@ -907,122 +925,6 @@ def buildFunctionType (template : FunctionTemplate) -- Build arrow type: param1 -> param2 -> ... -> returnType .ok <| paramTypes.foldr (init := returnType) fun argType tp => .arrow default argType tp -/-- -Result of expanding a single template. -Contains the generated function signatures and any errors encountered. --/ -structure TemplateExpansionResult where - /-- Generated function signatures as (name, type) pairs -/ - functions : Array (String × TypeExpr) - /-- Errors encountered during expansion -/ - errors : Array String - deriving Repr - -/-- -Expand a single function template based on its scope. - -Function templates specify patterns for generating auxiliary functions -from datatype declarations. This function expands one template according to -its iteration scope: - -- **perConstructor**: Generates one function per constructor (e.g., testers -like `..isNone`) -- **perField**: Generates one function per unique field across all constructors -(e.g., accessors) - -**Parameters:** -- `datatypeName`: Name of the datatype (used in name pattern expansion) -- `datatypeType`: TypeExpr for the datatype (used in function signatures) -- `constructorInfo`: Array of constructor information -- `template`: The function template to expand -- `dialectName`: Dialect name (for resolving builtin types) -- `existingNames`: Set of already-used names (for duplicate detection) - -**Example:** For a `perConstructor` template defined as: -``` -perConstructor([.datatype, .literal "..is", .constructor], [.datatype], -.builtin "bool") -``` -This specifies: -- Name pattern: `[.datatype, .literal "..is", .constructor]` → generates names -like `Option..isNone` -- Parameter types: `[.datatype]` → takes one parameter of the datatype type -- Return type: `.builtin "bool"` → returns a boolean - -Applied to `Option` with constructors `None` and `Some`, this generates: -- `Option..isNone : Option -> bool` -- `Option..isSome : Option -> bool` --/ -def expandSingleTemplate - (datatypeName : String) - (datatypeType : TypeExpr) - (constructorInfo : Array ConstructorInfo) - (template : FunctionTemplate) - (dialectName : String) - (existingNames : Std.HashSet String) : TemplateExpansionResult := - -- First validate the pattern - match validateNamePattern template.namePattern template.scope with - | some err => { functions := #[], errors := #[err] } - | none => - match template.scope with - | .perConstructor => - -- Generate one function per constructor - let (funcs, errs, _) := constructorInfo.foldl (init := (#[], #[], existingNames)) fun (funcs, errs, names) constr => - let funcName := expandNamePattern template.namePattern datatypeName (some constr.name) - if names.contains funcName then - (funcs, errs.push s!"Duplicate function name: {funcName}", names) - else - match buildFunctionType template datatypeType none dialectName with - | .ok funcType => - (funcs.push (funcName, funcType), errs, names.insert funcName) - | .error e => - (funcs, errs.push e, names) - { functions := funcs, errors := errs } - - | .perField => - -- Generate one function per unique field across all constructors - -- Error if the same field name appears with different types - let allFields := constructorInfo.foldl (init := #[]) fun acc c => acc ++ c.fields - let (funcs, errs, _) := allFields.foldl (init := (#[], #[], existingNames)) fun (funcs, errs, names) (fieldName, fieldTp) => - let funcName := expandNamePattern template.namePattern datatypeName none (some fieldName) - if names.contains funcName then - (funcs, errs.push s!"Duplicate field name '{fieldName}' across constructors in datatype '{datatypeName}'", names) - else - match buildFunctionType template datatypeType (some fieldTp) dialectName with - | .ok funcType => - (funcs.push (funcName, funcType), errs, names.insert funcName) - | .error e => - (funcs, errs.push e, names) - { functions := funcs, errors := errs } - -/-- -This function generates function signatures for an array of function templates -in order. Templates are specified in `@[declareDatatype]` annotations -to automatically generate auxiliary functions like testers and field accessors. -Within each template, functions are generated in constructor/field declaration order. - -**Parameters:** -- `datatypeName`: Name of the datatype -- `datatypeType`: TypeExpr for the datatype -- `constructorInfo`: Array of constructor information -- `templates`: Array of function templates to expand -- `dialectName`: Dialect name (for resolving builtin types) -- `existingNames`: Optional set of pre-existing names to avoid --/ -def expandFunctionTemplates - (datatypeName : String) - (datatypeType : TypeExpr) - (constructorInfo : Array ConstructorInfo) - (templates : Array FunctionTemplate) - (dialectName : String) - (existingNames : Std.HashSet String := {}) : TemplateExpansionResult := - templates.foldl (init := { functions := #[], errors := #[] }) fun acc template => - -- Track names from previous templates to detect cross-template duplicates - let currentNames := acc.functions.foldl (init := existingNames) fun s (name, _) => s.insert name - let result := expandSingleTemplate datatypeName datatypeType constructorInfo template dialectName currentNames - { functions := acc.functions ++ result.functions - errors := acc.errors ++ result.errors } - /-- Specification for datatype declarations. Includes indices for extracting datatype information and optional function templates. @@ -1053,7 +955,6 @@ A spec for introducing a new binding into a type context. inductive BindingSpec (argDecls : ArgDecls) where | value (_ : ValueBindingSpec argDecls) | type (_ : TypeBindingSpec argDecls) -| typeForward (_ : TypeBindingSpec argDecls) -- Forward declaration (no AST node) | datatype (_ : DatatypeBindingSpec argDecls) | tvar (_ : TvarBindingSpec argDecls) deriving Repr @@ -1063,7 +964,6 @@ namespace BindingSpec def nameIndex {argDecls} : BindingSpec argDecls → DebruijnIndex argDecls.size | .value v => v.nameIndex | .type v => v.nameIndex -| .typeForward v => v.nameIndex | .datatype v => v.nameIndex | .tvar v => v.nameIndex @@ -1104,12 +1004,19 @@ private def mkValueBindingSpec newBindingErr "Arguments only allowed when result is a type." return { nameIndex, argsIndex, typeIndex, allowCat } -/-- Parse function templates from metadata arguments. -/ -private def parseFunctionTemplates (args : Array MetadataArg) : Array FunctionTemplate := - args.filterMap fun arg => +/-- Parse and validate function templates from metadata arguments. -/ +private def parseFunctionTemplates (args : Array MetadataArg) + : NewBindingM (Array FunctionTemplate) := do + let mut result := #[] + for arg in args do match arg with - | .functionTemplate t => some t - | _ => none + | .functionTemplate t => + if let some err := validateNamePattern t.namePattern t.scope then + newBindingErr s!"Function template error: {err}" + else + result := result.push t + | _ => pure () + return result def parseNewBindings (md : Metadata) (argDecls : ArgDecls) : Array (BindingSpec argDecls) × Array String := let ins (attr : MetadataAttr) : NewBindingM (Option (BindingSpec argDecls)) := do @@ -1148,22 +1055,6 @@ def parseNewBindings (md : Metadata) (argDecls : ArgDecls) : Array (BindingSpec pure <| some ⟨idx, argsP⟩ | _ => newBindingErr "declareType args invalid."; return none some <$> .type <$> pure { nameIndex, argsIndex, defIndex := none } - | q`StrataDDL.declareTypeForward => do - let #[.catbvar nameIndex, .option mArgsArg ] := attr.args - | newBindingErr s!"declareTypeForward has bad arguments {repr attr.args}."; return none - let .isTrue nameP := inferInstanceAs (Decidable (nameIndex < argDecls.size)) - | return panic! "Invalid name index" - let nameIndex := ⟨nameIndex, nameP⟩ - checkNameIndexIsIdent argDecls nameIndex - let argsIndex ← - match mArgsArg with - | none => pure none - | some (.catbvar idx) => - let .isTrue argsP := inferInstanceAs (Decidable (idx < argDecls.size)) - | return panic! "Invalid arg index" - pure <| some ⟨idx, argsP⟩ - | _ => newBindingErr "declareTypeForward args invalid."; return none - some <$> .typeForward <$> pure { nameIndex, argsIndex, defIndex := none } | q`StrataDDL.aliasType => do let #[.catbvar nameIndex, .option mArgsArg, .catbvar defIndex] := attr.args | newBindingErr "aliasType missing arguments."; return none @@ -1212,8 +1103,8 @@ def parseNewBindings (md : Metadata) (argDecls : ArgDecls) : Array (BindingSpec | return panic! "Invalid typeParams index" let .isTrue constructorsP := inferInstanceAs (Decidable (constructorsIndex < argDecls.size)) | return panic! "Invalid constructors index" - -- Parse function templates from remaining arguments (args[3..]) - let functionTemplates := parseFunctionTemplates (args.extract 3 args.size) + -- Parse and validate function templates from remaining arguments (args[3..]) + let functionTemplates ← parseFunctionTemplates (args.extract 3 args.size) some <$> .datatype <$> pure { nameIndex := ⟨nameIndex, nameP⟩, typeParamsIndex := ⟨typeParamsIndex, typeParamsP⟩, @@ -1726,6 +1617,15 @@ This includes transitive imports. partial def importedDialects (dm : DialectMap) (dialect : DialectName) (p : dialect ∈ dm) : DialectMap := importedDialectsAux dm.map dm.closed dialect p +/-- +Look up an operation's metadata in the dialect. +Returns the OpDecl if found, or none if the operation is not in the dialect. +-/ +def lookupOpDecl (dialects : DialectMap) (opName : QualifiedIdent) : Option OpDecl := + match dialects[opName.dialect]? with + | none => none + | some dialect => dialect.ops[opName.name]? + end DialectMap mutual @@ -1751,7 +1651,7 @@ partial def foldOverArgBindingSpecs {α β} /-- Invoke a function `f` over each of the declaration specifications for an operator. -/ -private partial def OperationF.foldBindingSpecs {α β} +partial def OperationF.foldBindingSpecs {α β} (m : DialectMap) (f : β → α → ∀{argDecls : ArgDecls}, BindingSpec argDecls → Vector (ArgF α) argDecls.size → β) (init : β) @@ -1798,12 +1698,6 @@ inductive GlobalKind where | type (params : List String) (definition : Option TypeExpr) deriving BEq, Inhabited, Repr -/-- State of a symbol in the GlobalContext -/ -inductive DeclState where - | forward -- Symbol is forward-declared (no AST node will be generated) - | defined -- Symbol has a complete definition -deriving BEq, DecidableEq, Repr, Inhabited - /-- Resolves a binding spec into a global kind. -/ partial def resolveBindingIndices { argDecls : ArgDecls } (m : DialectMap) (src : SourceRange) (b : BindingSpec argDecls) (args : Vector Arg argDecls.size) : Option GlobalKind := match b with @@ -1829,7 +1723,7 @@ partial def resolveBindingIndices { argDecls : ArgDecls } (m : DialectMap) (src panic! s!"Expected new binding to be Type instead of {repr c}." | a => panic! s!"Expected new binding to be bound to type instead of {repr a}." - | .type b | .typeForward b => + | .type b => let params : Array String := match b.argsIndex with | none => #[] @@ -1849,9 +1743,9 @@ partial def resolveBindingIndices { argDecls : ArgDecls } (m : DialectMap) (src | _ => panic! "Bad arg" some <| .type params.toList value | .datatype b => - /- For datatypes, resolveBindingIndices only returns the datatype type - itself; the constructors and template-generated functions are handled - separately in addDatatypeBindings. -/ + -- For datatypes, resolveBindingIndices only returns the datatype type + -- itself; the constructors and template-generated functions are handled + -- separately in addDatatypeBindings!. let params : Array String := let addBinding (a : Array String) (_ : SourceRange) {argDecls : _} (b : BindingSpec argDecls) (args : Vector Arg argDecls.size) := match args[b.nameIndex.toLevel] with @@ -1863,80 +1757,6 @@ partial def resolveBindingIndices { argDecls : ArgDecls } (m : DialectMap) (src -- tvar bindings are local only, not added to GlobalContext none -/-- -Typing environment created from declarations in an environment. --/ -structure GlobalContext where - nameMap : Std.HashMap Var FreeVarIndex - vars : Array (Var × GlobalKind × DeclState) -deriving Repr - -namespace GlobalContext - -instance : EmptyCollection GlobalContext where - emptyCollection := private { nameMap := {}, vars := {}} - ---deriving instance BEq for GlobalContext - -instance : Inhabited GlobalContext where - default := {} - -instance : Membership Var GlobalContext where - mem ctx v := v ∈ ctx.nameMap - -@[instance] -def instDecidableMem (v : Var) (ctx : GlobalContext) : Decidable (v ∈ ctx) := - inferInstanceAs (Decidable (v ∈ ctx.nameMap)) - -/-- Add a forward declaration (must not exist). Used by @[declareTypeForward]. - This adds to GlobalContext for name resolution but will NOT generate an AST node. -/ -def declareForward (ctx : GlobalContext) (v : Var) (k : GlobalKind) : Except String GlobalContext := - if v ∈ ctx then - .error s!"Symbol '{v}' is already in scope" - else - let idx := ctx.vars.size - .ok { nameMap := ctx.nameMap.insert v idx, - vars := ctx.vars.push (v, k, .forward) } - -/-- Define a symbol. Used by @[declareDatatype], @[declareFn] with body, etc. - Replaces forward declaration, or adds new as defined. -/ -def define (ctx : GlobalContext) (v : Var) (k : GlobalKind) : Except String GlobalContext := - match ctx.nameMap.get? v with - | none => - -- Not declared, add as defined directly - let idx := ctx.vars.size - .ok { nameMap := ctx.nameMap.insert v idx, - vars := ctx.vars.push (v, k, .defined) } - | some idx => - let (name, _, state) := ctx.vars[idx]! - match state with - | .forward => - -- Replace forward declaration with definition (update in place) - .ok { ctx with vars := ctx.vars.set! idx (name, k, .defined) } - | .defined => - .error s!"Symbol '{v}' is already defined" - -/-- Check if a symbol is forward-declared (not yet defined). -/ -def isForward (ctx : GlobalContext) (idx : FreeVarIndex) : Bool := - match ctx.vars[idx]? with - | some (_, _, .forward) => true - | _ => false - -/-- Add a symbol as defined. -/ -def push (ctx : GlobalContext) (v : Var) (k : GlobalKind) : GlobalContext := - match ctx.define v k with - | .ok ctx' => ctx' - | .error msg => panic! msg - -/-- Return the index of the variable with the given name. -/ -def findIndex? (ctx : GlobalContext) (v : Var) : Option FreeVarIndex := ctx.nameMap.get? v - -def nameOf? (ctx : GlobalContext) (idx : FreeVarIndex) : Option String := ctx.vars[idx]? |>.map (·.fst) - -def kindOf! (ctx : GlobalContext) (idx : FreeVarIndex) : GlobalKind := - assert! idx < ctx.vars.size - ctx.vars[idx]!.2.1 - /-! ## Annotation-based Constructor Info Extraction @@ -1949,15 +1769,6 @@ The annotation-based approach: 3. Uses the indices from the annotation to extract the relevant arguments -/ -/-- -Look up an operation's metadata in the dialect. -Returns the OpDecl if found, or none if the operation is not in the dialect. --/ -private def lookupOpDecl (dialects : DialectMap) (opName : QualifiedIdent) : Option OpDecl := - match dialects[opName.dialect]? with - | none => none - | some dialect => dialect.ops[opName.name]? - /-- Check if an operation has the @[constructor(name, fields)] annotation. Returns the (nameIndex, fieldsIndex) if present. @@ -1985,52 +1796,58 @@ private def getConstructorListPushAnnotation (opDecl : OpDecl) : Option (Nat × | some #[.catbvar listIdx, .catbvar constrIdx] => some (listIdx, constrIdx) | _ => none -/-- -Extract fields from a Bindings argument using the existing @[declare] annotations. --/ -private def extractFieldsFromBindings (dialects : DialectMap) (arg : Arg) : Array (String × TypeExpr) := - let addField (acc : Array (String × TypeExpr)) (_ : SourceRange) - {argDecls : ArgDecls} (b : BindingSpec argDecls) (args : Vector Arg argDecls.size) : Array (String × TypeExpr) := +/-- Extract fields from a Bindings argument using the existing @[declare] annotations. +The accumulator is `Except String ...` because `foldOverArgBindingSpecs` fixes the +fold's accumulator type; wrapping in `Except` lets us propagate errors through +the fold without changing its generic signature. -/ +private def extractFieldsFromBindings (dialects : DialectMap) (arg : Arg) + : Except String (Array (String × TypeExpr)) := + -- We thread `Except` through the accumulator rather than changing + -- `foldOverArgBindingSpecs`, which is used broadly with plain accumulators. + let addField (acc : Except String (Array (String × TypeExpr))) (_ : SourceRange) + {argDecls : ArgDecls} (b : BindingSpec argDecls) + (args : Vector Arg argDecls.size) + : Except String (Array (String × TypeExpr)) := do + let acc ← acc match b with | .value vb => match args[vb.nameIndex.toLevel], args[vb.typeIndex.toLevel] with - | .ident _ name, .type tp => acc.push (name, tp) - | _, _ => acc - | _ => acc - foldOverArgBindingSpecs dialects addField #[] arg + | .ident _ name, .type tp => return acc.push (name, tp) + | _, _ => throw s!"Expected (ident, type) for field binding, \ + got ({repr args[vb.nameIndex.toLevel]}, \ + {repr args[vb.typeIndex.toLevel]})" + | _ => return acc + foldOverArgBindingSpecs dialects addField (.ok #[]) arg /-- Extract constructor information using the @[constructor] annotation. -/ -private def extractSingleConstructor (dialects : DialectMap) (arg : Arg) : Option ConstructorInfo := - match arg with - | .op op => - match lookupOpDecl dialects op.name with - | none => none - | some opDecl => - match getConstructorAnnotation opDecl with - | none => none - | some (nameIdx, fieldsIdx) => - -- Convert deBruijn indices to levels - let argCount := opDecl.argDecls.size - if nameIdx < argCount && fieldsIdx < argCount then - let nameLevel := argCount - nameIdx - 1 - let fieldsLevel := argCount - fieldsIdx - 1 - if h1 : nameLevel < op.args.size then - if h2 : fieldsLevel < op.args.size then - match op.args[nameLevel] with - | .ident _ constrName => - -- Extract fields from the Bindings argument using @[declare] annotations - let fields := match op.args[fieldsLevel] with - | .option _ (some bindingsArg) => extractFieldsFromBindings dialects bindingsArg - | .option _ none => #[] - | other => extractFieldsFromBindings dialects other - some { name := constrName, fields := fields } - | _ => none - else none - else none - else none - | _ => none +private def extractSingleConstructor (dialects : DialectMap) (arg : Arg) + : Except String ConstructorInfo := do + let .op op := arg + | throw s!"Expected op for constructor, got {repr arg}" + let some opDecl := dialects.lookupOpDecl op.name + | throw s!"Unknown operation '{op.name}'" + let some (nameIdx, fieldsIdx) := getConstructorAnnotation opDecl + | throw s!"Operation '{op.name}' missing @[constructor] annotation" + let argCount := opDecl.argDecls.size + unless nameIdx < argCount && fieldsIdx < argCount do + throw s!"Annotation indices out of bounds: \ + nameIdx={nameIdx}, fieldsIdx={fieldsIdx}, \ + argCount={argCount}" + let nameLevel := argCount - nameIdx - 1 + let fieldsLevel := argCount - fieldsIdx - 1 + let .isTrue h1 := decideProp (nameLevel < op.args.size) + | throw s!"Name index {nameLevel} out of bounds (size {op.args.size})" + let .isTrue h2 := decideProp (fieldsLevel < op.args.size) + | throw s!"Fields index {fieldsLevel} out of bounds (size {op.args.size})" + let .ident _ constrName := op.args[nameLevel] + | throw s!"Expected ident for constructor name, got {repr op.args[nameLevel]}" + let fields ← match op.args[fieldsLevel] with + | .option _ (some bindingsArg) => extractFieldsFromBindings dialects bindingsArg + | .option _ none => pure #[] + | other => extractFieldsFromBindings dialects other + return { name := constrName, fields } /-- This function traverses a constructor list AST node and extracts structured @@ -2046,49 +1863,255 @@ dialect annotations `@[constructor]`, `@[constructorListAtom]`, ] ``` -/ -def extractConstructorInfo (dialects : DialectMap) (arg : Arg) : Array ConstructorInfo := - match arg with - | .op op => - match lookupOpDecl dialects op.name with - | none => #[] - | some opDecl => - match getConstructorListAtomAnnotation opDecl with - | some constrIdx => - let argCount := opDecl.argDecls.size - if constrIdx < argCount then - let constrLevel := argCount - constrIdx - 1 - if h : constrLevel < op.args.size then - match extractSingleConstructor dialects op.args[constrLevel] with - | some constr => #[constr] - | none => #[] - else #[] - else #[] - | none => - match getConstructorListPushAnnotation opDecl with - | some (listIdx, constrIdx) => - let argCount := opDecl.argDecls.size - if listIdx < argCount && constrIdx < argCount then - let listLevel := argCount - listIdx - 1 - let constrLevel := argCount - constrIdx - 1 - if h1 : listLevel < op.args.size then - if h2 : constrLevel < op.args.size then - let prevConstrs := extractConstructorInfo dialects op.args[listLevel] - match extractSingleConstructor dialects op.args[constrLevel] with - | some constr => prevConstrs.push constr - | none => prevConstrs - else #[] - else #[] - else #[] - | none => - -- Could be a direct constructor operation - match extractSingleConstructor dialects arg with - | some constr => #[constr] - | none => #[] - | _ => #[] +def extractConstructorInfo (dialects : DialectMap) (arg : Arg) + : Except String (Array ConstructorInfo) := do + let .op op := arg + | throw s!"Expected op for constructor list, got {repr arg}" + let some opDecl := dialects.lookupOpDecl op.name + | throw s!"Unknown operation '{op.name}'" + -- Try constructorListAtom annotation + if let some constrIdx := getConstructorListAtomAnnotation opDecl then + let argCount := opDecl.argDecls.size + unless constrIdx < argCount do + throw s!"constructorListAtom index {constrIdx} out of bounds (argCount={argCount})" + let constrLevel := argCount - constrIdx - 1 + let .isTrue h := decideProp (constrLevel < op.args.size) + | throw s!"Constructor level {constrLevel} out of bounds (size {op.args.size})" + let constr ← extractSingleConstructor dialects op.args[constrLevel] + return #[constr] + -- Try constructorListPush annotation + if let some (listIdx, constrIdx) := getConstructorListPushAnnotation opDecl then + let argCount := opDecl.argDecls.size + unless listIdx < argCount && constrIdx < argCount do + throw s!"constructorListPush indices out of bounds: \ + listIdx={listIdx}, constrIdx={constrIdx}, \ + argCount={argCount}" + let listLevel := argCount - listIdx - 1 + let constrLevel := argCount - constrIdx - 1 + let .isTrue h1 := decideProp (listLevel < op.args.size) + | throw s!"List level {listLevel} out of bounds (size {op.args.size})" + let .isTrue h2 := decideProp (constrLevel < op.args.size) + | throw s!"Constructor level {constrLevel} out of bounds (size {op.args.size})" + let prevConstrs ← extractConstructorInfo dialects op.args[listLevel] + let constr ← extractSingleConstructor dialects op.args[constrLevel] + return prevConstrs.push constr + -- Fallback: try as a direct constructor + let constr ← extractSingleConstructor dialects arg + return #[constr] decreasing_by simp_wf; rw[OperationF.sizeOf_spec] have := Array.sizeOf_get op.args (opDecl.argDecls.size - listIdx - 1) (by omega); omega + +/-- +Typing environment created from declarations in an environment. +-/ +structure GlobalContext where + nameMap : Std.HashMap Var FreeVarIndex + vars : Array (Var × GlobalKind) +deriving Repr + +namespace GlobalContext + +instance : EmptyCollection GlobalContext where + emptyCollection := private { nameMap := {}, vars := {}} + +--deriving instance BEq for GlobalContext + +instance : Inhabited GlobalContext where + default := {} + +instance : Membership Var GlobalContext where + mem ctx v := v ∈ ctx.nameMap + +@[instance] +def instDecidableMem (v : Var) (ctx : GlobalContext) : Decidable (v ∈ ctx) := + inferInstanceAs (Decidable (v ∈ ctx.nameMap)) + +/-- Define a symbol. Caller must prove `v ∉ ctx`. -/ +def define (ctx : GlobalContext) (v : Var) (k : GlobalKind) (_ : v ∉ ctx) : GlobalContext := + let idx := ctx.vars.size + { nameMap := ctx.nameMap.insert v idx, + vars := ctx.vars.push (v, k) } + +/-- Define a symbol if not already present. No-op if already defined. -/ +def ensureDefined (ctx : GlobalContext) (v : Var) (k : GlobalKind) : GlobalContext := + if h : v ∈ ctx then + ctx + else + ctx.define v k h + +/-- Define a symbol, with behavior controlled by `preRegistered`: +- `preRegistered = true`: the name is expected to already exist (was + pre-registered). Returns the context unchanged, or an error if missing. +- `preRegistered = false`: the name must be fresh. Defines it, or returns + an error if already present. -/ +def defineChecked (ctx : GlobalContext) (v : Var) (k : GlobalKind) (preRegistered : Bool) : + Except String GlobalContext := + match instDecidableMem v ctx, preRegistered with + | .isTrue _, true => .ok ctx + | .isTrue _, false => .error s!"'{v}' already defined" + | .isFalse h, false => .ok (ctx.define v k h) + | .isFalse _, true => .error s!"pre-registered '{v}' not found" + +/-- Return the index of the variable with the given name. -/ +def findIndex? (ctx : GlobalContext) (v : Var) : Option FreeVarIndex := ctx.nameMap.get? v + +def nameOf? (ctx : GlobalContext) (idx : FreeVarIndex) : Option String := ctx.vars[idx]? |>.map (·.fst) + +def kindOf! (ctx : GlobalContext) (idx : FreeVarIndex) : GlobalKind := + assert! idx < ctx.vars.size + ctx.vars[idx]!.2 + +private structure TemplateExpandState where + gctx : GlobalContext + errors : Array String := #[] + deriving Inhabited + +namespace TemplateExpandState + +private def addError (s : TemplateExpandState) (msg : String) : TemplateExpandState := + { s with errors := s.errors.push msg } + +end TemplateExpandState + +private abbrev TemplateExpandM := StateM TemplateExpandState + + +namespace TemplateExpandM + +private def runChecked (act : TemplateExpandM Unit) : TemplateExpandM Bool := do + let oldCount := (←get).errors.size + act + let newCount := (←get).errors.size + return oldCount = newCount + +private def addError (msg : String) : TemplateExpandM Unit := + modify (·.addError msg) + +private def addFunction + (name : String) (tp : TypeExpr) + (errorMsg : Thunk String := .mk fun _ => s!"{name} already defined." ) + : TemplateExpandM Unit := + modify fun s => + if h : name ∈ s.gctx then + s.addError errorMsg.get + else + { s with gctx := s.gctx.define name (.expr tp) h } + + +/-- +Build the function type from a template, then atomically check freshness +and define. Reports an error (via `dupMsg`) if the name already exists, or +if the type cannot be built. Uses `define` with a proof — never panics. +-/ +private def buildAndDefine + (template : FunctionTemplate) + (datatypeType : TypeExpr) + (fieldType : Option TypeExpr) + (dialectName : String) + (funcName : String) + (dupMsg : String) : TemplateExpandM Unit := do + match buildFunctionType template datatypeType fieldType dialectName with + | .ok funcType => + addFunction funcName funcType (errorMsg := .mk fun _ => dupMsg) + | .error e => + TemplateExpandM.addError e + +end TemplateExpandM + +/-- +Expand a single function template based on its scope. + +Function templates specify patterns for generating auxiliary functions +from datatype declarations. This function expands one template according to +its iteration scope: + +- **perConstructor**: Generates one function per constructor (e.g., testers +like `..isNone`) +- **perField**: Generates one function per unique field across all constructors +(e.g., accessors) + +**Parameters:** +- `datatypeName`: Name of the datatype (used in name pattern expansion) +- `datatypeType`: TypeExpr for the datatype (used in function signatures) +- `constructorInfo`: Array of constructor information +- `template`: The function template to expand +- `dialectName`: Dialect name (for resolving builtin types) + +**Example:** For a `perConstructor` template defined as: +``` +perConstructor([.datatype, .literal "..is", .constructor], [.datatype], +.builtin "bool") +``` +This specifies: +- Name pattern: `[.datatype, .literal "..is", .constructor]` → generates names +like `Option..isNone` +- Parameter types: `[.datatype]` → takes one parameter of the datatype type +- Return type: `.builtin "bool"` → returns a boolean + +Applied to `Option` with constructors `None` and `Some`, this generates: +- `Option..isNone : Option -> bool` +- `Option..isSome : Option -> bool` +-/ +private def expandSingleTemplate1 + (dialectName datatypeName : String) + (datatypeType : TypeExpr) + (constr : ConstructorInfo) + (template : FunctionTemplate) + : TemplateExpandM Unit := do + match template.scope with + | .perConstructor => + let funcName := expandNamePattern template.namePattern datatypeName (some constr.name) + TemplateExpandM.buildAndDefine template datatypeType none dialectName + funcName s!"Duplicate function name: {funcName}" + | .perField => + for (fieldName, fieldTp) in constr.fields do + let funcName := expandNamePattern + template.namePattern datatypeName none (some fieldName) + TemplateExpandM.buildAndDefine template datatypeType (some fieldTp) dialectName + funcName s!"Duplicate field name '{fieldName}' across \ + constructors in datatype '{datatypeName}'" + +/-- +Register constructor signatures and expand function templates for a +datatype, returning the updated `GlobalContext` and any error messages. + +`dialectName` is the dialect that owns the `@[declareDatatype]`-annotated +operator. It is used to qualify builtin type references in templates +(e.g., `.builtin "bool"` resolves to `⟨dialectName, "bool"⟩`). +-/ +def expandFunctionTemplates + (dialectName : String) + (src : SourceRange) + (datatypeName : String) + (datatypeType : TypeExpr) + (constructorInfo : Array ConstructorInfo) + (templates : Array FunctionTemplate) + (gctx : GlobalContext) + : GlobalContext × Array String := + let initState : TemplateExpandState := { gctx } + let ((), finalState) := StateT.run (m := Id) (do + -- Pass 1: Register all constructor signatures first to maintain + -- FreeVarIndex ordering (constructors before template functions). + let mut failures : Std.HashSet String := {} + for constr in constructorInfo do + let constrType := mkConstructorType src datatypeType constr.fields + let success ← TemplateExpandM.runChecked <| + TemplateExpandM.addFunction constr.name constrType + if not success then + failures := failures.insert constr.name + + -- Pass 2: Expand all templates for all constructors. + for template in templates do + for constr in constructorInfo do + -- Skip constructors that failed to be added. + if constr.name ∈ failures then + continue + expandSingleTemplate1 dialectName datatypeName datatypeType constr template + ) initState + (finalState.gctx, finalState.errors) + /-- Add all bindings for a datatype declaration to the GlobalContext when `@[declareDatatype]` is encountered. Bindings are 1) the type itself (added as @@ -2112,11 +2135,12 @@ FreeVarIndex values are consistent with this order. this adds entries for: `Option` (type), `None` (constructor), `Some` (constructor), `Option..isNone` (tester), `Option..isSome` (tester). -/ -private def addDatatypeBindings +private def addDatatypeBindings! (dialects : DialectMap) (gctx : GlobalContext) (src : SourceRange) (dialectName : DialectName) + (preRegistered : Bool) {argDecls : ArgDecls} (b : DatatypeBindingSpec argDecls) (args : Vector Arg argDecls.size) @@ -2134,56 +2158,90 @@ private def addDatatypeBindings | a => panic! s!"Expected ident for type param {repr a}" foldOverArgAtLevel dialects addBinding #[] argDecls args b.typeParamsIndex.toLevel - let constructorInfo := extractConstructorInfo dialects args[b.constructorsIndex.toLevel] - - -- Step 1: Add datatype type (or update forward declaration) - let gctx := match gctx.define datatypeName (GlobalKind.type typeParams.toList none) with - | .ok gctx' => gctx' - | .error msg => panic! s!"addDatatypeBindings: {msg}" + -- Step 1: Add datatype type. + -- When preRegistered, the type was already added by preRegisterTypeName; + -- otherwise it must be fresh. + let k := GlobalKind.type typeParams.toList none + let gctx := + match gctx.defineChecked datatypeName k preRegistered with + | .ok gctx => gctx + | .error e => panic! s!"addDatatypeBindings!: {e}" let datatypeIndex := gctx.findIndex? datatypeName |>.getD (gctx.vars.size - 1) let datatypeType := mkDatatypeTypeRef src datatypeIndex typeParams - -- Step 2: Add constructor signatures - let gctx := constructorInfo.foldl (init := gctx) fun gctx constr => - let constrType := mkConstructorType src datatypeType constr.fields - gctx.push constr.name (.expr constrType) - - -- Step 3: Expand and add function templates - let existingNames : Std.HashSet String := gctx.nameMap.fold (init := {}) fun s name _ => s.insert name - let result := expandFunctionTemplates datatypeName datatypeType constructorInfo b.functionTemplates dialectName existingNames + -- Step 2: Add constructor signatures and expand function templates + let constrArg := args[b.constructorsIndex.toLevel] + let constructorInfo := + match extractConstructorInfo dialects constrArg with + | .ok info => info + | .error e => panic! s!"Constructor extraction error: {e}" + -- Errors from template expansion are reported during elaboration + -- (evalBindingSpec); here we just take the updated context. + let (gctx, _) := expandFunctionTemplates dialectName src + datatypeName datatypeType constructorInfo + b.functionTemplates gctx + gctx - if !result.errors.isEmpty then - panic! s!"Datatype template expansion errors: {result.errors}" - else - result.functions.foldl (init := gctx) fun gctx (funcName, funcType) => - gctx.push funcName (.expr funcType) +/-- +Pre-register a type name in the `GlobalContext` before the main `addCommand` +pass. Used by operations annotated with `@[preRegisterTypes]` (e.g., mutual +blocks) so that forward references between sibling datatypes resolve correctly. +Names must be fresh — panics if the name is already defined. +-/ +private def preRegisterType (dialects : DialectMap) (gctx : GlobalContext) (l : SourceRange) + {argDecls} (b : BindingSpec argDecls) (args : Vector Arg argDecls.size) : GlobalContext := + match b with + | .datatype _ | .type _ => + let name := + match args[b.nameIndex.toLevel] with + | .ident _ e => e + | a => panic! s!"Expected ident at {b.nameIndex.toLevel} {repr a}" + match resolveBindingIndices dialects l b args with + -- Names must be fresh: this is the pre-registration pass. + | some kind => + if h : name ∈ gctx then + panic! s!"'{name}' already defined" + else + gctx.define name kind h + | none => gctx + | _ => gctx -def addCommand (dialects : DialectMap) (init : GlobalContext) (op : Operation) : GlobalContext := +private def addBinding (dialects : DialectMap) (dialectName : DialectName) (preRegistered : Bool) + (gctx : GlobalContext) (l : SourceRange) {argDecls} (b : BindingSpec argDecls) + (args : Vector Arg argDecls.size) := + match b with + | .datatype datatypeSpec => + addDatatypeBindings! dialects gctx l dialectName preRegistered datatypeSpec args + | _ => + let name : Var := + match args[b.nameIndex.toLevel] with + | .ident _ e => e + | a => panic! s!"Expected ident at {b.nameIndex.toLevel} {repr a}" + match resolveBindingIndices dialects l b args with + | some kind => + match gctx.defineChecked name kind preRegistered with + | .ok gctx => gctx + | .error e => panic! s!"addCommand: {e}" + | none => gctx + +def addCommand (dialects : DialectMap) (gctx : GlobalContext) (op : Operation) : GlobalContext := let dialectName := op.name.dialect - op.foldBindingSpecs dialects (addBinding dialectName) init - where addBinding (dialectName : DialectName) (gctx : GlobalContext) l {argDecls} (b : BindingSpec argDecls) args := - match b with - | .datatype datatypeSpec => - addDatatypeBindings dialects gctx l dialectName datatypeSpec args - | .typeForward typeSpec => - let name := - match args[typeSpec.nameIndex.toLevel] with - | .ident _ e => e - | a => panic! s!"Expected ident at {typeSpec.nameIndex.toLevel} {repr a}" - match resolveBindingIndices dialects l b args with - | some kind => - match gctx.declareForward name kind with - | .ok gctx' => gctx' - | .error msg => panic! msg - | none => gctx - | _ => - let name := - match args[b.nameIndex.toLevel] with - | .ident _ e => e - | a => panic! s!"Expected ident at {b.nameIndex.toLevel} {repr a}" - match resolveBindingIndices dialects l b args with - | some kind => gctx.push name kind - | none => gctx + -- Pre-register types if op has @[preRegisterTypes] metadata + let (gctx, preRegistered) := Id.run do + let .op decl := dialects.decl! op.name + | return (panic! "Expected operator declaration", false) + let .isTrue h := decideProp (op.args.size = decl.argDecls.size) + | return (panic! "Expected arguments to match", false) + match decl.metadata.preRegisterTypesLevel decl.argDecls.size with + | some lvl => + let gctx := foldOverArgAtLevel dialects + (preRegisterType dialects) gctx + decl.argDecls ⟨op.args, h⟩ lvl + (gctx, true) + | none => + (gctx, false) + -- Normal fold + op.foldBindingSpecs dialects (addBinding dialects dialectName preRegistered) gctx end GlobalContext diff --git a/Strata/DDM/BuiltinDialects/StrataDDL.lean b/Strata/DDM/BuiltinDialects/StrataDDL.lean index 6e2313dfa..cb97babb1 100644 --- a/Strata/DDM/BuiltinDialects/StrataDDL.lean +++ b/Strata/DDM/BuiltinDialects/StrataDDL.lean @@ -164,7 +164,7 @@ def StrataDDL : Dialect := BuiltinM.create! "StrataDDL" #[initDialect] do -- Metadata for marking an operation as a constructor list push (list followed by constructor) declareMetadata { name := "constructorListPush", args := #[.mk "list" .ident, .mk "constructor" .ident] } declareMetadata { name := "declareType", args := #[.mk "name" .ident, .mk "args" (.opt .ident)] } - declareMetadata { name := "declareTypeForward", args := #[.mk "name" .ident, .mk "args" (.opt .ident)] } + declareMetadata { name := "preRegisterTypes", args := #[.mk "scope" .ident] } declareMetadata { name := "aliasType", args := #[.mk "name" .ident, .mk "args" (.opt .ident), .mk "def" .ident] } declareMetadata { name := "declare", args := #[.mk "name" .ident, .mk "type" .ident] } declareMetadata { name := "declareFn", args := #[.mk "name" .ident, .mk "args" .ident, .mk "type" .ident] } diff --git a/Strata/DDM/Elab/Core.lean b/Strata/DDM/Elab/Core.lean index 8a6c8b98e..0bdce6b6c 100644 --- a/Strata/DDM/Elab/Core.lean +++ b/Strata/DDM/Elab/Core.lean @@ -7,11 +7,11 @@ module public import Strata.DDM.Elab.DeclM public import Strata.DDM.Elab.Tree - +import Strata.DDM.HNF import all Strata.DDM.Util.Array import all Strata.DDM.Util.Fin import all Strata.DDM.Util.Lean -import Strata.DDM.HNF +import Strata.Util.DecideProp open Lean ( Message @@ -119,8 +119,6 @@ def applyNArgs (tctx : TypingContext) (e : TypeExpr) (n : Nat) := aux #[] e end TypingContext -def commaPrec := 30 - def elabIdent (stx : Syntax) : String := assert! stx.getKind = `ident match stx with @@ -525,7 +523,7 @@ def elabOption (f : ElabArgFn) : ElabArgFn := fun tctx stx => let tree ← f tctx astx pure <| .node (.ofOptionInfo info) #[tree] -def evalBindingNameIndex (trees : Vector Tree n) (idx : DebruijnIndex n) : String := +def evalBindingNameIndex {n} (trees : Vector Tree n) (idx : DebruijnIndex n) : String := match trees[idx.toLevel].info with | .ofIdentInfo e => e.val | a => panic! s!"Expected ident at {idx.val} {repr a}" @@ -807,15 +805,17 @@ def translateBindingKind (tree : Tree) : ElabM BindingKind := do logInternalError argInfo.loc s!"translateArgDeclKind given invalid kind {opInfo.op.name}" return default +private def checkIsTypeKind (argLoc : SourceRange) (b : Binding) : ElabM Unit := + match b.kind with + | .type _ [] _ => pure () + | .tvar .. | .type .. | .expr _ | .cat _ => + logError argLoc s!"{b.ident} must have type Type instead of {repr b.kind}." + /-- Extract type parameter names from a bindings argument. -/ def elabTypeParams {n} (initSize : Nat) (args : Vector Tree n) (idx : Option (DebruijnIndex n)) : ElabM (List String) := do let params ← elabArgIndex initSize args idx fun argLoc b => do - match b.kind with - | .type _ [] _ => pure () - | .tvar _ _ => pure () - | .type .. | .expr _ | .cat _ => - logError argLoc s!"{b.ident} must have type Type instead of {repr b.kind}." + checkIsTypeKind argLoc b return b.ident pure params.toList @@ -824,11 +824,13 @@ Construct a binding from a binding spec and the arguments to an operation. -/ def evalBindingSpec {bindings} + (tctx : TypingContext) (loc : SourceRange) (initSize : Nat) + (dialectName : DialectName) (b : BindingSpec bindings) (args : Vector Tree bindings.size) - : ElabM Binding := do + : ElabM TypingContext := do match b with | .value b => let ident := evalBindingNameIndex args b.nameIndex @@ -863,8 +865,8 @@ def evalBindingSpec | arg => panic! s!"Cannot bind {ident}: Type at {b.typeIndex.val} has unexpected arg {repr arg}" -- TODO: Decide if new bindings for Type and Expr (or other categories) and should not be allowed? - pure { ident, kind } - | .type b | .typeForward b => + pure <| tctx.push { ident, kind } + | .type b => let ident := evalBindingNameIndex args b.nameIndex let params ← elabTypeParams initSize args b.argsIndex let value : Option TypeExpr := @@ -876,14 +878,48 @@ def evalBindingSpec some info.typeExpr | _ => panic! "Bad arg" - pure { ident, kind := .type loc params value } + pure <| tctx.push { ident, kind := .type loc params value } | .datatype b => - let ident := evalBindingNameIndex args b.nameIndex + let nameInfo := args[b.nameIndex.toLevel].info + let (nameLoc, ident) ← + match nameInfo with + | .ofIdentInfo i => + pure (i.loc, i.val) + | _ => + logInternalError nameInfo.loc s!"Expected ident" + return tctx let params ← elabTypeParams initSize args (some b.typeParamsIndex) - pure { ident, kind := .type loc params none } + assert! tctx.bindings.size = 0 + let gctx := tctx.globalContext + let gctx := gctx.ensureDefined ident (.type params none) + + let dialects := (← read).dialects + + let t := args[b.constructorsIndex.toLevel] + match extractConstructorInfo dialects t.arg with + | .ok info => + let mut seen : Std.HashSet String := {} + for c in info do + let name := c.name + if name ∈ seen then + logError loc s!"Duplicate constructor name '{name}'." + continue + seen := seen.insert name + -- Expand function templates to catch name collisions early + let datatypeIndex := gctx.findIndex? ident |>.getD (gctx.vars.size - 1) + let datatypeType := + mkDatatypeTypeRef loc datatypeIndex params.toArray + let (gctx, errors) := gctx.expandFunctionTemplates + dialectName loc ident datatypeType info + b.functionTemplates + errors.forM (logError loc) + pure <| .empty gctx + | .error e => + logError loc e + pure <| .empty gctx | .tvar b => let ident := evalBindingNameIndex args b.nameIndex - pure { ident, kind := .tvar loc ident } + pure <| tctx.push { ident, kind := .tvar loc ident } /-- Given a type expression and a natural number, this returns a @@ -1001,6 +1037,38 @@ def ofArgDeclKind : ArgDeclKind → ElabArgKind end ElabArgKind +/-- Map a sequence category name to its separator format and child extraction function. -/ +private def scopeSepFormat (name : QualifiedIdent) + : Option (SepFormat × (Syntax → Array Syntax)) := + match name with + | q`Init.Seq => some (.none, Syntax.getArgs) + | q`Init.CommaSepBy => some (.comma, Syntax.getSepArgs) + | q`Init.SpaceSepBy => some (.space, Syntax.getSepArgs) + | q`Init.SpacePrefixSepBy => some (.spacePrefix, Syntax.getArgs) + | _ => none + +/-- Look up the syntax level for a given arg level +via the precomputed `argElabIndex`. -/ +private def argSyntaxLevel? + (se : SyntaxElaborator) (argLevel : Nat) : Option Nat := + se.argElabIndex[argLevel]?.join.bind fun idx => + (se.argElaborators[idx]?).map (·.val.syntaxLevel) + +/-- +Compute the result `TypingContext` after elaboration. If the +`SyntaxElaborator` specifies a `resultScope`, returns the context from +that tree; otherwise returns the input context `tctx0` unchanged. +-/ +private def resultContext {argc : Nat} + (se : SyntaxElaborator) (tctx0 : TypingContext) + (trees : Vector Tree argc) : TypingContext := + match se.resultScope with + | none => tctx0 + | some idx => Id.run do + let .isTrue p := inferInstanceAs (Decidable (idx < argc)) + | return panic! "Invalid index" + trees[idx].resultContext + mutual partial def elabOperation (tctx : TypingContext) (stx : Syntax) : ElabM Tree := do @@ -1022,116 +1090,312 @@ partial def elabOperation (tctx : TypingContext) (stx : Syntax) : ElabM Tree := let (stxArgs, success) ← runChecked <| getSyntaxArgs stx i se.syntaxCount if not success then return default - let isType i := .ofArgDeclKind argDecls[i].kind + let getKind i := .ofArgDeclKind argDecls[i].kind let ((args, newCtx), success) ← runChecked <| - runSyntaxElaborator (argc := argDecls.size) isType se tctx stxArgs + match se.preRegisterTypesScope with + | some scopeArgLevel => + elaborateWithPreRegistration argDecls se tctx loc stxArgs scopeArgLevel + | none => do + let args ← runSyntaxElaborator (argc := argDecls.size) getKind se tctx stxArgs + return (args, resultContext se tctx args) + if !success then return default + let resultCtx ← decl.newBindings.foldlM (init := newCtx) <| fun ctx spec => do - ctx.push <$> evalBindingSpec loc initSize spec args + evalBindingSpec ctx loc initSize i.dialect spec args let op : Operation := { ann := loc, name := i, args := args.toArray.map (·.arg) } if loc.isNone then return panic! s!"Missing position info {repr stx}." let info : OperationInfo := { loc := loc, inputCtx := tctx, op, resultCtx } return .node (.ofOperationInfo info) args.toArray +/-- Elaborate a single argument based on its `ElabArgKind`. +Returns the updated `trees` vector with the result placed at `argIdx`. -/ +partial def elabSyntaxArg + {argc : Nat} + (getKind : Fin argc → ElabArgKind) + (isTypeP : Fin argc → Bool) + (tctx : TypingContext) + (astx : Syntax) + (argIdx : Fin argc) + (trees : Vector (Option Tree) argc) + : ElabM (Vector (Option Tree) argc) := do + match getKind argIdx with + | .preType expectedType => + let (tree, success) ← runChecked <| elabExpr tctx astx + if success then + let expr := tree.info.asExpr!.expr + let inferredType ← inferType tctx expr + let dialects := (← read).dialects + let resolveArg (i : Nat) : Option Arg := do + assert! i < argIdx.val + Tree.arg <$> trees[argIdx.val - i - 1]! + match expandMacros dialects expectedType resolveArg with + | .error () => + panic! "Could not infer type." + | .ok expectedType => do + let trees ← unifyTypes isTypeP argIdx + expectedType tctx astx inferredType trees + assert! trees[argIdx].isNone + return trees.set argIdx (some tree) + else + return trees + | .typeExpr expectedType => + let (tree, success) ← runChecked <| elabExpr tctx astx + if success then + let expr := tree.info.asExpr!.expr + let inferredType ← inferType tctx expr + let trees ← unifyTypes isTypeP argIdx + expectedType tctx astx inferredType trees + assert! trees[argIdx].isNone + return trees.set argIdx (some tree) + else + return trees + | .cat c => + let (tree, success) ← runChecked <| catElaborator c tctx astx + if success then + return trees.set argIdx (some tree) + else + return trees + +/-- Elaborate all syntax arguments for an operation according to the +`SyntaxElaborator`'s ordering. Iterates over `se.argElaborators`, computes +the typing context for each argument (handling datatype scopes for +recursive types), and delegates to `elabSyntaxArg` for the actual +elaboration. Returns the elaborated `Tree` vector and the result +`TypingContext`. -/ partial def runSyntaxElaborator {argc : Nat} (getKind : Fin argc → ElabArgKind) (se : SyntaxElaborator) (tctx0 : TypingContext) - (args : Vector Syntax se.syntaxCount) : ElabM (Vector Tree argc × TypingContext) := do + (args : Vector Syntax se.syntaxCount) + : ElabM (Vector Tree argc) := do let isTypeP := fun i => (getKind i).isType let mut trees : Vector (Option Tree) argc := .replicate argc none for ⟨ae, sp⟩ in se.argElaborators do let argLevel := ae.argLevel let .isTrue argLevelP := inferInstanceAs (Decidable (argLevel < argc)) | return panic! "Invalid argLevel" - -- Compute the typing context for this argument - let tctx ← - /- Recursive datatypes make this a bit complicated, since we need to make - sure the type is resolved as an fvar even while processing it. -/ - match ae.datatypeScope with - | some (nameLevel, typeParamsLevel) => - let nameTree := trees[nameLevel] - let typeParamsTree := trees[typeParamsLevel] - match nameTree, typeParamsTree with - | some nameT, some typeParamsT => - let datatypeName := - match nameT.info with - | .ofIdentInfo info => info.val - | _ => panic! "Expected identifier for datatype name" - let baseCtx := typeParamsT.resultContext - /- Extract type parameter names only from NEW bindings added by - typeParams, not inherited bindings (which may include datatypes from - previous commands) -/ - let inheritedCount := tctx0.bindings.size - let typeParamNames := baseCtx.bindings.toArray.extract inheritedCount baseCtx.bindings.size - |>.filterMap fun b => - match b.kind with - | .type _ [] _ => some b.ident - | _ => none - -- Add the datatype name to the GlobalContext as a type - let gctx := baseCtx.globalContext - let gctx := - if datatypeName ∈ gctx then gctx - else gctx.push datatypeName (GlobalKind.type typeParamNames.toList none) - -- Add .tvar bindings for type parameters - let loc := typeParamsT.info.loc - -- Start with empty local bindings - don't inherit from baseCtx - -- This prevents datatype names from leaking between mutual block entries - let tctx := typeParamNames.foldl (init := TypingContext.empty gctx) fun ctx name => - ctx.push { ident := name, kind := .tvar loc name } - pure tctx - | _, _ => continue - | none => - match ae.contextLevel with - | some idx => - match trees[idx] with - | some t => pure t.resultContext - | none => continue - | none => pure tctx0 + -- Skip pre-elaborated args + if trees[argLevel].isSome then continue + -- Get syntax let astx := args[ae.syntaxLevel] - match getKind ⟨argLevel, argLevelP⟩ with - | .preType expectedType => - let (tree, success) ← runChecked <| elabExpr tctx astx - if success then - let expr := tree.info.asExpr!.expr - let inferredType ← inferType tctx expr - let dialects := (← read).dialects - let resolveArg (i : Nat) : Option Arg := do - assert! i < argLevel - Tree.arg <$> trees[argLevel - i - 1]! - match expandMacros dialects expectedType resolveArg with - | .error () => - panic! "Could not infer type." - | .ok expectedType => do - trees ← unifyTypes isTypeP ⟨argLevel, argLevelP⟩ - expectedType tctx astx inferredType trees - assert! trees[argLevel].isNone - trees := trees.set argLevel (some tree) - | .typeExpr expectedType => - let (tree, success) ← runChecked <| elabExpr tctx astx - if success then - let expr := tree.info.asExpr!.expr - let inferredType ← inferType tctx expr - trees ← unifyTypes isTypeP ⟨argLevel, argLevelP⟩ - expectedType tctx astx inferredType trees - assert! trees[argLevel].isNone - trees := trees.set argLevel (some tree) - | .cat c => - let (tree, success) ← runChecked <| catElaborator c tctx astx - if success then - trees := trees.set argLevel (some tree) - let treesr := trees.map (·.getD default) - let mut tctx := - match se.resultScope with - | none => tctx0 - | some idx => Id.run do - let .isTrue p := inferInstanceAs (Decidable (idx < argc)) - | return panic! "Invalid index" - treesr[idx].resultContext - return (treesr, tctx) + let some aloc := mkSourceRange? astx + | panic! "Arg syntax missing position information" + -- Handle datatype declaration. + if let some (nameLevel, typeParamsLevel) := ae.datatypeScope then + let some nameT := trees[nameLevel] + | logError aloc "Internal: missing name assignment" + return default + let some typeParamsT := trees[typeParamsLevel] + | logError aloc "Internal: missing type parameter" + return default + let datatypeName := + match nameT.info with + | .ofIdentInfo info => info.val + | _ => panic! "Expected identifier for datatype name" + let tloc := typeParamsT.info.loc + let paramCtx := typeParamsT.resultContext + + -- Extract type parameter names only from NEW bindings added by + -- typeParams, not inherited bindings (which may include datatypes from + -- previous commands) + let (typeParamNames, success) ← runChecked <| + paramCtx.bindings.toArray.foldlM (init := #[]) fun a b => do + match b.kind with + | .type _ [] _ => + pure <| a.push b.ident + | _ => + logError tloc "Expected only type arguments." + pure a + if success = false then + return default + -- Add the datatype name to the GlobalContext as a type + -- Use tctx0.globalContext so pre-registered types are visible + let gctx := tctx0.globalContext + let gctx := gctx.ensureDefined datatypeName (GlobalKind.type typeParamNames.toList none) + let tctx := TypingContext.empty gctx + -- Add .tvar bindings for type parameters + -- Start with extended global context. + let tctx := typeParamNames.foldl (init := tctx) fun ctx name => + ctx.push { ident := name, kind := .tvar tloc name } + trees ← elabSyntaxArg getKind isTypeP tctx astx ⟨argLevel, argLevelP⟩ trees + else if let some idx := ae.contextLevel then + let some t := trees[idx] + | -- This failed so skip + continue + trees ← elabSyntaxArg getKind isTypeP t.resultContext astx ⟨argLevel, argLevelP⟩ trees + else + trees ← elabSyntaxArg getKind isTypeP tctx0 astx ⟨argLevel, argLevelP⟩ trees + return trees.map (·.getD default) + +/-- +Two-phase elaboration for operations annotated with `@[preRegisterTypes]`. + +When a parent operation (like `mutual`) has `@[preRegisterTypes(scope)]`, its children +may reference each other's types before they are declared. This function handles that by: + +- **Phase 1**: For each child, partially elaborate name + typeParams args (which don't + reference sibling types) to extract type names and parameter names. +- **Pre-register**: Pre-register all extracted types so mutual references resolve. +- **Phase 2**: Fully elaborate each child with the updated context. Name args that were + already elaborated in Phase 1 are passed via `preElabMap` to avoid redundant work. + The remaining parent args are then elaborated by `runSyntaxElaborator`. + +**Known deviation**: typeParams args are elaborated twice — once in Phase 1 (against +`tctx0`) to extract param names, and again in Phase 2 (against the per-child context). +Phase 1 typeParams trees cannot be reused because `collectNewBindingsM` requires +`tree.info.inputCtx.bindings.size ≥ initialScope`, which fails when the Phase 1 context +is smaller than the Phase 2 per-child context. +-/ +partial def elaborateWithPreRegistration + {argc : Nat} + (argDecls : Vector ArgDecl argc) + (se : SyntaxElaborator) + (tctx0 : TypingContext) + (fallbackLoc : SourceRange) + (stxArgs : Vector Syntax se.syntaxCount) + (scopeArgLevel : Nat) : ElabM (Vector Tree argc × TypingContext) := do + -- Resolve scope: find the scope arg's syntax, category, and children + let some scopeSyntaxLevel := argSyntaxLevel? se scopeArgLevel + | logInternalError fallbackLoc "elaborateWithPreRegistration: no syntax level for scope arg" + return default + let .isTrue scopeSLBound := inferInstanceAs (Decidable (scopeSyntaxLevel < se.syntaxCount)) + | logInternalError fallbackLoc "elaborateWithPreRegistration: scope syntax level out of bounds" + return default + let scopeStx := stxArgs[scopeSyntaxLevel] + let .isTrue scopeALBound := inferInstanceAs (Decidable (scopeArgLevel < argc)) + | logInternalError fallbackLoc "elaborateWithPreRegistration: scope arg level out of bounds" + return default + let scopeArgDecl := argDecls[scopeArgLevel] + let scopeCat ← do + match scopeArgDecl.kind with + | .cat c => pure c + | _ => + logInternalError fallbackLoc "elaborateWithPreRegistration: expected category for scope arg" + return default + if scopeCat.args.size ≠ 1 then + logInternalError fallbackLoc + s!"elaborateWithPreRegistration: \ + expected 1 scope cat arg, got {scopeCat.args.size}" + return default + let some (sep, getChildren) := scopeSepFormat scopeCat.name + | logInternalError fallbackLoc + s!"elaborateWithPreRegistration: \ + unsupported scope category {scopeCat.name}" + return default + let children := getChildren scopeStx + -- Phase 1: Pre-register add all types so mutual references resolve + assert! tctx0.bindings.size = 0 + let gctx0 : GlobalContext := tctx0.globalContext + let preGCtx ← children.foldlM (init := gctx0) fun preCtx child => + extractDatatypeInfo preCtx child + -- Phase 2: Elaborate args with the scope tree pre-elaborated + let getKind (i : Fin argDecls.size) := ElabArgKind.ofArgDeclKind argDecls[i].kind + let preCtx := .empty preGCtx + let args ← runSyntaxElaborator (argc := argDecls.size) getKind se preCtx stxArgs + pure (args, resultContext se preCtx args) + +/-- +Phase 1 helper for a single child operation in a mutual block. + +Partially elaborates the child's name and typeParams args to extract +the type name and parameter names, then pre-registers the type in the +`TypingContext`'s `GlobalContext` via `preRegisterType`. Returns the +updated `TypingContext` with the new type registered. +-/ +partial def extractDatatypeInfo (gctx0 : GlobalContext) (child : Syntax) : ElabM GlobalContext := do + let dialects := (← read).dialects + let syntaxElabs := (← read).syntaxElabs + let some childIdent := qualIdentKind child + | return panic! s!"Unknown command {child.getKind}" + + let some childLoc := mkSourceRange? child + | panic! "extractDatatypeInfo: child missing source location" + let some childDecl := dialects.lookupOpDecl childIdent + | logInternalError childLoc s!"extractDatatypeInfo: unknown op declaration {childIdent}" + return default + let some childSe := syntaxElabs[childIdent]? + | logInternalError childLoc s!"extractDatatypeInfo: no syntax elaborator for {childIdent}" + return default + let childStxArgs := child.getArgs + let childArgDecls := childDecl.argDecls.toArray + let mut gctxLoop := gctx0 + for spec in childDecl.newBindings do + let (nameArgLevel, typeParamsArgLevel?) ← + match spec with + | .datatype b => pure (b.nameIndex.toLevel, some b.typeParamsIndex.toLevel) + | .type b => pure (b.nameIndex.toLevel, b.argsIndex.map (·.toLevel)) + | _ => continue + -- Elaborate name arg + let some nameSL := argSyntaxLevel? childSe nameArgLevel + | logInternalError childLoc + "extractDatatypeInfo: argLevelToSyntaxLevel \ + failed for name" + continue + if nameSL ≥ childStxArgs.size then + logInternalError childLoc + s!"extractDatatypeInfo: nameSL {nameSL} \ + out of bounds ({childStxArgs.size})" + continue + let .isTrue _ := decideProp (nameArgLevel < childArgDecls.size) + | logInternalError childLoc + s!"extractDatatypeInfo: nameArgLevel \ + {nameArgLevel} out of bounds \ + ({childArgDecls.size})" + continue + let nameCat := childArgDecls[nameArgLevel].kind.categoryOf + let gctx := gctxLoop + let (nameTree, nameSuccess) ← runChecked <| + catElaborator nameCat (.empty gctx) childStxArgs[nameSL]! + if !nameSuccess then continue + let name ← + match nameTree.info with + | .ofIdentInfo info => + pure info.val + | _ => + logInternalError childLoc "extractDatatypeInfo: expected ident for type name" + continue + + let .isFalse nameIsNew := decideProp (name ∈ gctx) + | logError nameTree.info.loc s!"Type '{name}' is already declared." + continue + + -- Elaborate typeParams to extract real param names + let some tpArgLevel := typeParamsArgLevel? + | -- .type spec with no argsIndex: register with empty params + gctxLoop := gctx.define name (.type [] none) nameIsNew + continue + let some tpSL := argSyntaxLevel? childSe tpArgLevel + | logInternalError childLoc + "extractDatatypeInfo: argLevelToSyntaxLevel \ + failed for typeParams" + continue + if tpSL ≥ childStxArgs.size then + logInternalError childLoc + s!"extractDatatypeInfo: tpSL {tpSL} \ + out of bounds ({childStxArgs.size})" + continue + if tpArgLevel ≥ childArgDecls.size then + logInternalError childLoc + s!"extractDatatypeInfo: tpArgLevel \ + {tpArgLevel} out of bounds \ + ({childArgDecls.size})" + continue + let tpCat := childArgDecls[tpArgLevel]!.kind.categoryOf + let (tpTree, tpSuccess) ← runChecked <| + catElaborator tpCat (.empty gctx0) childStxArgs[tpSL]! + if !tpSuccess then + return default + let params ← collectNewBindingsM 0 tpTree fun argLoc b => do + checkIsTypeKind argLoc b + return b.ident + gctxLoop := gctx.define name (.type params.toList none) nameIsNew + + return gctxLoop + partial def elabType (tctx : TypingContext) (stx : Syntax) : ElabM Tree := do let (tree, success) ← runChecked <| elabOperation tctx stx @@ -1242,7 +1506,7 @@ partial def catElaborator (c : SyntaxCat) : TypingContext → Syntax → ElabM T elabSeqWith c .newline "newlineSepBy" (·.getArgs) | _ => assert! c.args.isEmpty - elabOperation + fun tctx stx => elabOperation tctx stx where elabSeqWith (c : SyntaxCat) (sep : SepFormat) (name : String) (getArgsFrom : Syntax → Array Syntax) : TypingContext → Syntax → ElabM Tree := @@ -1333,7 +1597,7 @@ partial def elabExpr (tctx : TypingContext) (stx : Syntax) : ElabM Tree := do ⟨e, lvlp⟩ resultScope := none } - let (args, _) ← runSyntaxElaborator getKind se tctx ⟨args, Eq.refl args.size⟩ + let args ← runSyntaxElaborator getKind se tctx ⟨args, Eq.refl args.size⟩ let e := args.toArray.foldl (init := fvar) fun e t => .app { start := fnLoc.start, stop := t.info.loc.stop } e t.arg let info : ExprInfo := { toElabInfo := einfo, expr := e } @@ -1356,7 +1620,7 @@ partial def elabExpr (tctx : TypingContext) (stx : Syntax) : ElabM Tree := do return default let getKind (i : Fin argDecls.size) := ElabArgKind.ofArgDeclKind argDecls[i].kind - let ((args, _), success) ← runChecked <| + let (args, success) ← runChecked <| runSyntaxElaborator getKind se tctx stxArgs if !success then return default diff --git a/Strata/DDM/Elab/SyntaxElab.lean b/Strata/DDM/Elab/SyntaxElab.lean index 3724896af..6f91b5b16 100644 --- a/Strata/DDM/Elab/SyntaxElab.lean +++ b/Strata/DDM/Elab/SyntaxElab.lean @@ -91,8 +91,24 @@ structure SyntaxElaborator where syntaxCount : Nat argElaborators : ArgElaboratorArray syntaxCount resultScope : Option Nat + /-- Index into argElaborators for each argument + (indexed by argLevel), None if arg has no syntax. -/ + argElabIndex : Array (Option Nat) := #[] + /-- If set, pre-register type names from children at this arg level before elaboration. -/ + preRegisterTypesScope : Option Nat := none deriving Inhabited, Repr +/-- Build an argElabIndex mapping each argLevel to its +position in the given elaborator array. -/ +private def buildArgElabIndex (argDecls : ArgDecls) + {sc} (elabs : ArgElaboratorArray sc) + : Array (Option Nat) := + let init := Array.replicate argDecls.size none + let (result, _) := elabs.foldl (init := (init, 0)) + fun (arr, idx) ⟨ae, _⟩ => + (arr.set! ae.argLevel (some idx), idx + 1) + result + /-- Build the syntax elaborator that maps parsed syntax positions to argument positions. For `.passthrough`, this is trivial: one syntax position maps directly to argument 0. For `.mk`, we walk the atoms to @@ -110,10 +126,13 @@ private def mkSyntaxElab! (argDecls : ArgDecls) (stx : SyntaxDef) (opMd : Metada contextLevel := argDecls.argScopeLevel ⟨0, h⟩ datatypeScope := argDecls.argScopeDatatypeLevel ⟨0, h⟩ } + let elabs := #[⟨ae, Nat.zero_lt_one⟩] { syntaxCount := 1 - argElaborators := #[⟨ae, Nat.zero_lt_one⟩] + argElaborators := elabs resultScope := opMd.resultLevel argDecls.size + argElabIndex := buildArgElabIndex argDecls elabs + preRegisterTypesScope := opMd.preRegisterTypesLevel argDecls.size } | .std atoms _ => let init : ArgElaborators := { @@ -131,6 +150,8 @@ private def mkSyntaxElab! (argDecls : ArgDecls) (stx : SyntaxDef) (opMd : Metada syntaxCount := as.syntaxCount argElaborators := elabs resultScope := opMd.resultLevel argDecls.size + argElabIndex := buildArgElabIndex argDecls elabs + preRegisterTypesScope := opMd.preRegisterTypesLevel argDecls.size } private def opDeclElaborator! (decl : OpDecl) : SyntaxElaborator := diff --git a/Strata/DDM/Integration/Lean/ToExpr.lean b/Strata/DDM/Integration/Lean/ToExpr.lean index be1985962..0e7f030c4 100644 --- a/Strata/DDM/Integration/Lean/ToExpr.lean +++ b/Strata/DDM/Integration/Lean/ToExpr.lean @@ -422,7 +422,6 @@ private def toExpr {argDecls} (bi : BindingSpec argDecls) (argDeclsExpr : Lean.E match bi with | .value b => astExpr! value argDeclsExpr (b.toExpr argDeclsExpr) | .type b => astExpr! type argDeclsExpr (b.toExpr argDeclsExpr) - | .typeForward b => astExpr! typeForward argDeclsExpr (b.toExpr argDeclsExpr) | .datatype b => astExpr! datatype argDeclsExpr (b.toExpr argDeclsExpr) | .tvar b => astExpr! tvar argDeclsExpr (b.toExpr argDeclsExpr) diff --git a/Strata/Languages/Core/DDMTransform/ASTtoCST.lean b/Strata/Languages/Core/DDMTransform/ASTtoCST.lean index 0ef0a0feb..9f18028ed 100644 --- a/Strata/Languages/Core/DDMTransform/ASTtoCST.lean +++ b/Strata/Languages/Core/DDMTransform/ASTtoCST.lean @@ -314,23 +314,10 @@ def datatypeToCST {M} [Inhabited M] (datatypes : List (Lambda.LDatatype Visibili let cmd ← processDatatype dt pure [cmd] | _ => do - -- Multiple datatypes - generate forward declarations and mutual block. - let mut forwardDecls : List (Command M) := [] - for dt in datatypes.reverse do - let name : Ann String M := ⟨default, dt.name⟩ - let args : Ann (Option (Bindings M)) M := - if dt.typeArgs.isEmpty then - ⟨default, none⟩ - else - let bindings := dt.typeArgs.map fun param => - let paramName : Ann String M := ⟨default, param⟩ - let paramType := TypeP.type default - Binding.mkBinding default paramName paramType - ⟨default, some (.mkBindings default ⟨default, bindings.toArray⟩)⟩ - forwardDecls := forwardDecls ++ [.command_forward_typedecl default name args] + -- Multiple datatypes - mutual block with pre-registration handles forward references. let cmds ← datatypes.mapM processDatatype let mutualCmd := Command.command_mutual default ⟨default, cmds.toArray⟩ - pure (forwardDecls ++ [mutualCmd]) + pure [mutualCmd] /-- Convert a type synonym declaration to CST -/ def typeSynToCST {M} [Inhabited M] (syn : TypeSynonym) @@ -1081,7 +1068,7 @@ private def recreateGlobalContext (ctx : ToCSTContext M) (map.insert name i, i + 1) let vars := allFreeVars.map fun name => -- .fvar below is really a dummy value. - (name, GlobalKind.expr (.fvar default 0 #[]), DeclState.defined) + (name, GlobalKind.expr (.fvar default 0 #[])) { nameMap, vars } -- Extract types not in `Core.KnownTypes`. diff --git a/Strata/Languages/Core/DDMTransform/Grammar.lean b/Strata/Languages/Core/DDMTransform/Grammar.lean index c12107580..9e26c503e 100644 --- a/Strata/Languages/Core/DDMTransform/Grammar.lean +++ b/Strata/Languages/Core/DDMTransform/Grammar.lean @@ -274,10 +274,6 @@ op command_procedure (name : Ident, op command_typedecl (name : Ident, args : Option Bindings) : Command => "type " name args ";\n"; -@[declareTypeForward(name, some args)] -op command_forward_typedecl (name : Ident, args : Option Bindings) : Command => - "forward type " name args ";\n"; - @[aliasType(name, some args, rhs)] op command_typesynonym (name : Ident, args : Option Bindings, @@ -372,8 +368,8 @@ op command_datatype (name : Ident, "datatype " name typeParams " {" constructors "\n}" ";\n"; // Mutual block for defining mutually recursive types -// Types should be forward-declared before the mutual block -@[scope(commands)] +// Type names are pre-registered via @[preRegisterTypes] before elaboration +@[scope(commands), preRegisterTypes(commands)] op command_mutual (commands : SpacePrefixSepBy Command) : Command => "mutual\n " indent(2, commands) "end;\n"; diff --git a/Strata/Languages/Core/DDMTransform/Translate.lean b/Strata/Languages/Core/DDMTransform/Translate.lean index 4af61fd93..3d76b1ff5 100644 --- a/Strata/Languages/Core/DDMTransform/Translate.lean +++ b/Strata/Languages/Core/DDMTransform/Translate.lean @@ -350,31 +350,6 @@ def translateTypeDecl (bindings : TransBindings) (op : Operation) : let decl := Core.Decl.type (.con { name := name, numargs := numargs }) md return (decl, { bindings with freeVars := bindings.freeVars.push decl }) -/-- -Translate a forward type declaration. This creates a placeholder entry that will -be replaced when the actual datatype definition is encountered in a mutual block. --/ -def translateForwardTypeDecl (bindings : TransBindings) (op : Operation) : - TransM (Core.Decl × TransBindings) := do - let _ ← @checkOp (Core.Decl × TransBindings) op q`Core.command_forward_typedecl 2 - let name ← translateIdent TyIdentifier op.args[0]! - let numargs ← - translateOption - (fun maybearg => - do match maybearg with - | none => pure 0 - | some arg => - let bargs ← checkOpArg arg q`Core.mkBindings 1 - let numargs ← - match bargs[0]! with - | .seq _ .comma args => pure args.size - | _ => TransM.error - s!"translateForwardTypeDecl expects a comma separated list: {repr bargs[0]!}") - op.args[1]! - let md ← getOpMetaData op - let decl := Core.Decl.type (.con { name := name, numargs := numargs }) md - return (decl, { bindings with freeVars := bindings.freeVars.push decl }) - --------------------------------------------------------------------- def translateLhs (arg : Arg) : TransM Core.CoreIdent := do @@ -1491,7 +1466,9 @@ Extract and translate constructor information from a constructor list argument. -/ def translateConstructorList (p : Program) (bindings : TransBindings) (arg : Arg) : TransM (Array TransConstructorInfo) := do - let constructorInfos := GlobalContext.extractConstructorInfo p.dialects arg + let constructorInfos ← match extractConstructorInfo p.dialects arg with + | .ok info => pure info + | .error e => TransM.error s!"Constructor extraction error: {e}" constructorInfos.mapM (translateConstructorInfo bindings) --------------------------------------------------------------------- @@ -1634,7 +1611,9 @@ def translateDatatype (p : Program) (bindings : TransBindings) (op : Operation) /-- Translate a mutual block containing mutually recursive datatype definitions. This collects all datatypes, creates a single TypeDecl.data with all of them, -and updates the forward-declared entries in bindings.freeVars. +and adds placeholder entries for type references during translation. +The `@[preRegisterTypes]` metadata on the mutual block operation ensures that +type names are pre-registered in the DDM GlobalContext before processing. -/ def translateMutualBlock (p : Program) (bindings : TransBindings) (op : Operation) : TransM (Core.Decl × TransBindings) := do @@ -1653,8 +1632,8 @@ def translateMutualBlock (p : Program) (bindings : TransBindings) (op : Operatio if datatypeOps.size == 0 then TransM.error "Mutual block must contain at least one datatype" else - -- First pass: collect all datatype names, type args, and their indices in freeVars - -- Forward declarations MUST already be in bindings.freeVars + -- First pass: collect all datatype names and type args, and allocate placeholder + -- entries in freeVars for each one (replacing any pre-registered entries if present) let mut datatypeInfos : Array (String × List TyIdentifier × Nat) := #[] let mut bindingsWithPlaceholders := bindings @@ -1662,20 +1641,25 @@ def translateMutualBlock (p : Program) (bindings : TransBindings) (op : Operatio let datatypeName ← translateIdent String dtOp.args[0]! let (typeArgs, _) ← translateDatatypeTypeArgs bindings dtOp.args[1]! "translateMutualBlock" - -- Find the index of this datatype in freeVars (from forward declaration) + -- Check if this datatype was already pre-registered in freeVars let existingIdx := bindings.freeVars.findIdx? fun decl => match decl with | .type t _ => t.names.contains datatypeName | _ => false + let placeholderDecl := Core.Decl.type (.data [mkPlaceholderLDatatype datatypeName typeArgs]) match existingIdx with | some i => - let placeholderDecl := Core.Decl.type (.data [mkPlaceholderLDatatype datatypeName typeArgs]) + -- Replace existing pre-registered entry with placeholder datatypeInfos := datatypeInfos.push (datatypeName, typeArgs, i) bindingsWithPlaceholders := { bindingsWithPlaceholders with freeVars := bindingsWithPlaceholders.freeVars.set! i placeholderDecl } | none => - TransM.error s!"Mutual datatype {datatypeName} requires a forward declaration" + -- Allocate a new placeholder entry + let idx := bindingsWithPlaceholders.freeVars.size + datatypeInfos := datatypeInfos.push (datatypeName, typeArgs, idx) + bindingsWithPlaceholders := { bindingsWithPlaceholders with + freeVars := bindingsWithPlaceholders.freeVars.push placeholderDecl } -- Second pass: translate all constructors with all placeholders in scope let ldatatypes ← (datatypeOps.zip datatypeInfos).toList.mapM fun (dtOp, (datatypeName, typeArgs, _idx)) => do @@ -1696,8 +1680,7 @@ def translateMutualBlock (p : Program) (bindings : TransBindings) (op : Operatio let md ← getOpMetaData op let mutualTypeDecl := Core.Decl.type (.data ldatatypes) md - -- Update bindings.freeVars: replace forward-declared entries with the mutual block - -- For each datatype, update its entry to point to the mutual TypeDecl + -- Update bindings.freeVars: replace placeholder entries with the mutual block let mut finalBindings := bindings for (_datatypeName, _typeArgs, idx) in datatypeInfos do @@ -1739,13 +1722,7 @@ partial def translateCoreDecls (p : Program) (bindings : TransBindings) : | 0 => return ([], bindings) | _ + 1 => let op := ops[count]! - let (newDecls, bindings) ← - match op.name with - | q`Core.command_forward_typedecl => - -- Forward declarations do NOT produce AST nodes - they only update bindings - let (_, bindings) ← translateForwardTypeDecl bindings op - pure ([], bindings) - | _ => + let (newDecls, bindings) ← do let (decl, bindings) ← match op.name with | q`Core.command_datatype => diff --git a/StrataTest/DDM/MutualDatatypes.lean b/StrataTest/DDM/MutualDatatypes.lean index f15ccf970..0f4dd7017 100644 --- a/StrataTest/DDM/MutualDatatypes.lean +++ b/StrataTest/DDM/MutualDatatypes.lean @@ -10,14 +10,16 @@ import Strata.DDM.Integration.Lean /-! # Tests for mutual datatype blocks in DDM -Tests that mutually recursive datatypes can be declared via forward -declarations and mutual blocks. +Tests that mutually recursive datatypes can be declared via +pre-registration and mutual blocks. -/ #dialect dialect TestMutual; -metadata declareDatatype (name : Ident, typeParams : Ident, constructors : Ident); +metadata declareDatatype (name : Ident, typeParams : Ident, + constructors : Ident, testerTemplate : FunctionTemplate, + accessorTemplate : FunctionTemplate); type int; @@ -41,59 +43,55 @@ op constructorListAtom (c : Constructor) : ConstructorList => c; @[constructorListPush(cl, c)] op constructorListPush (cl : ConstructorList, c : Constructor) : ConstructorList => - cl ", " c; + @[prec(30), leftassoc] cl ", " c; -@[declareTypeForward(name, none)] -op command_forward (name : Ident) : Command => - "forward type " name ";\n"; - -@[declareDatatype(name, typeParams, constructors)] +@[declareDatatype(name, typeParams, constructors, + perConstructor([.literal "..is", .constructor], + [.datatype], .builtin "bool"), + perField([.datatype, .literal "..", .field], [.datatype], .fieldType))] op command_datatype (name : Ident, typeParams : Option Bindings, @[scopeDatatype(name, typeParams)] constructors : ConstructorList) : Command => "datatype " name typeParams " { " constructors " };\n"; -@[scope(commands)] +@[scope(commands), preRegisterTypes(commands)] op command_mutual (commands : SpacePrefixSepBy Command) : Command => "mutual\n" indent(2, commands) "end;\n"; #end --------------------------------------------------------------------- --- Test 1: Basic mutual recursion (Tree/Forest) +-- Test 1: Types from mutual block visible after the block --------------------------------------------------------------------- -def mutualBasicPgm := +def mutualVisibleAfterPgm := #strata program TestMutual; -forward type Tree; -forward type Forest; mutual datatype Tree { Node(val: int, children: Forest) }; datatype Forest { FNil(), FCons(head: Tree, tail: Forest) }; end; +datatype Wrapper { MkWrapper(t: Tree, f: Forest) }; #end /-- info: program TestMutual; -forward type Tree; -forward type Forest; mutual datatype Tree { Node(val:int, children:Forest) }; - datatype Forest { (FNil()), (FCons(head:Tree, tail:Forest)) }; + datatype Forest { FNil(), FCons(head:Tree, tail:Forest) }; end; +datatype Wrapper { MkWrapper(t:Tree, f:Forest) }; -/ #guard_msgs in -#eval IO.println mutualBasicPgm +#eval IO.println mutualVisibleAfterPgm --------------------------------------------------------------------- --- Test 2: Single datatype in mutual block (should not actually be used) +-- Test 2: Single datatype in mutual block (allowed but not common) --------------------------------------------------------------------- def mutualSinglePgm := #strata program TestMutual; -forward type List; mutual datatype List { Nil(), Cons(head: int, tail: List) }; end; @@ -101,9 +99,8 @@ end; /-- info: program TestMutual; -forward type List; mutual - datatype List { (Nil()), (Cons(head:int, tail:List)) }; + datatype List { Nil(), Cons(head:int, tail:List) }; end; -/ #guard_msgs in @@ -116,9 +113,6 @@ end; def mutualThreeWayPgm := #strata program TestMutual; -forward type A; -forward type B; -forward type C; mutual datatype A { MkA(toB: B) }; datatype B { MkB(toC: C) }; @@ -128,14 +122,112 @@ end; /-- info: program TestMutual; -forward type A; -forward type B; -forward type C; mutual datatype A { MkA(toB:B) }; datatype B { MkB(toC:C) }; - datatype C { (MkC(toA:A)), (CBase()) }; + datatype C { MkC(toA:A), CBase() }; end; -/ #guard_msgs in #eval IO.println mutualThreeWayPgm + +--------------------------------------------------------------------- +-- Test 4: Empty mutual block +--------------------------------------------------------------------- + +def mutualEmptyPgm := +#strata +program TestMutual; +mutual +end; +#end + +/-- +info: program TestMutual; +mutual +end; +-/ +#guard_msgs in +#eval IO.println mutualEmptyPgm + +--------------------------------------------------------------------- +-- Test 5: Function templates expand for mutual types +-- The perConstructor/perField templates on declareDatatype generate +-- tester and accessor functions (e.g., Tree..isNode, Tree..val). +-- This test verifies template expansion succeeds for mutual types +-- with multiple constructors and fields. +--------------------------------------------------------------------- + +def mutualTemplatesPgm := +#strata +program TestMutual; +mutual + datatype Expr { Lit(val: int), Add(lhs: Expr, rhs: Expr), + Call(tag: int, args: ExprList) }; + datatype ExprList { ENil(), ECons(head: Expr, tail: ExprList) }; +end; +datatype Program { MkProgram(body: Expr) }; +#end + +/-- +info: program TestMutual; +mutual + datatype Expr { Lit(val:int), Add(lhs:Expr, rhs:Expr), Call(tag:int, args:ExprList) }; + datatype ExprList { ENil(), ECons(head:Expr, tail:ExprList) }; +end; +datatype Program { MkProgram(body:Expr) }; +-/ +#guard_msgs in +#eval IO.println mutualTemplatesPgm + +--------------------------------------------------------------------- +-- Negative Tests +--------------------------------------------------------------------- + +-- Test: Reference to undefined type inside mutual block +/-- error: Undeclared type or category Bogus. -/ +#guard_msgs in +def mutualUndefinedRefPgm := +#strata +program TestMutual; +mutual + datatype A { MkA(x: Bogus) }; +end; +#end + +-- Test: Duplicate type name in mutual block +/-- error: Type 'Dup' is already declared. -/ +#guard_msgs in +def mutualDuplicatePgm := +#strata +program TestMutual; +mutual + datatype Dup { MkDup1() }; + datatype Dup { MkDup2() }; +end; +#end + +-- Test: Mutual type clashes with previously defined type +/-- error: Type 'Existing' is already declared. -/ +#guard_msgs in +def mutualClashPgm := +#strata +program TestMutual; +datatype Existing { MkExisting() }; +mutual + datatype Existing { MkClash() }; +end; +#end + +-- Test: Duplicate constructor name across mutual datatypes +/-- +error: Mk already defined. +-/ +#guard_msgs in +#eval #strata +program TestMutual; +mutual + datatype A { Mk() }; + datatype B { Mk() }; +end; +#end diff --git a/StrataTest/Languages/Core/Examples/DatatypeTypingTests.lean b/StrataTest/Languages/Core/Examples/DatatypeTypingTests.lean index ae81f23e4..6af025f61 100644 --- a/StrataTest/Languages/Core/Examples/DatatypeTypingTests.lean +++ b/StrataTest/Languages/Core/Examples/DatatypeTypingTests.lean @@ -173,8 +173,6 @@ program Core; datatype List (a : Type) { Nil(), Cons(hd: a, tl: List a) }; -forward type MutNestA (a : Type); -forward type MutNestB (a : Type); mutual datatype MutNestA (a : Type) { MkA(x: List (MutNestB a)) }; datatype MutNestB (a : Type) { BBase(), MkB(x: MutNestA a) }; @@ -214,8 +212,6 @@ def mutualNonPositivePgm : Program := #strata program Core; -forward type BadA; -forward type BadB; mutual datatype BadA () { MkA(f: BadB -> int) }; datatype BadB () { BadBBase(), MkB(a: BadA) }; @@ -253,8 +249,6 @@ def mutualUninhabitedPgm : Program := #strata program Core; -forward type Bad1; -forward type Bad2; mutual datatype Bad1 () { B1(x: Bad2) }; datatype Bad2 () { B2(x: Bad1) }; @@ -275,9 +269,6 @@ def threeWayCyclePgm : Program := #strata program Core; -forward type Cycle1; -forward type Cycle2; -forward type Cycle3; mutual datatype Cycle1 () { C1(x: Cycle2) }; datatype Cycle2 () { C2(x: Cycle3) }; diff --git a/StrataTest/Languages/Core/Examples/MutualDatatypes.lean b/StrataTest/Languages/Core/Examples/MutualDatatypes.lean index dcf894a3e..252ca0246 100644 --- a/StrataTest/Languages/Core/Examples/MutualDatatypes.lean +++ b/StrataTest/Languages/Core/Examples/MutualDatatypes.lean @@ -22,8 +22,6 @@ def roseTreeTesterPgm : Program := #strata program Core; -forward type RoseTree; -forward type Forest; mutual datatype Forest { FNil(), FCons(head: RoseTree, tail: Forest) }; datatype RoseTree { Node(val: int, children: Forest) }; @@ -92,8 +90,6 @@ def roseTreeDestructorPgm : Program := #strata program Core; -forward type RoseTree; -forward type Forest; mutual datatype Forest { FNil(), FCons(head: RoseTree, tail: Forest) }; datatype RoseTree { Node(val: int, children: Forest) }; @@ -190,8 +186,6 @@ def roseTreeEqualityPgm : Program := #strata program Core; -forward type RoseTree; -forward type Forest; mutual datatype Forest { FNil(), FCons(head: RoseTree, tail: Forest) }; datatype RoseTree { Node(val: int, children: Forest) }; @@ -255,8 +249,6 @@ def polyRoseTreeHavocPgm : Program := #strata program Core; -forward type RoseTree (a : Type); -forward type Forest (a : Type); mutual datatype Forest (a : Type) { FNil(), FCons(head: RoseTree a, tail: Forest a) }; datatype RoseTree (a : Type) { Node(val: a, children: Forest a) }; @@ -333,8 +325,6 @@ def stmtListHavocPgm : Program := #strata program Core; -forward type Stmt (e : Type, c : Type); -forward type StmtList (e : Type, c : Type); mutual datatype StmtList (e : Type, c : Type) { SNil(), SCons(hd: Stmt e c, tl: StmtList e c) }; datatype Stmt (e : Type, c : Type) { diff --git a/StrataTest/Languages/Core/TestASTtoCST.lean b/StrataTest/Languages/Core/TestASTtoCST.lean index cf8d6e744..8e98f20b8 100644 --- a/StrataTest/Languages/Core/TestASTtoCST.lean +++ b/StrataTest/Languages/Core/TestASTtoCST.lean @@ -297,8 +297,6 @@ private def polyRoseTreeHavocPgm : Program := #strata program Core; -forward type RoseTree (a : Type); -forward type Forest (a : Type); mutual datatype Forest (a : Type) { FNil(), FCons(head: RoseTree a, tail: Forest a) }; datatype RoseTree (a : Type) { Node(val: a, children: Forest a) }; @@ -322,9 +320,7 @@ spec { #end /-- -info: forward type RoseTree (a : Type); -forward type Forest (a : Type); -mutual +info: mutual datatype Forest (a : Type) {( (FNil())), (FCons(head : (RoseTree a), tail : (Forest a)))