diff --git a/Strata/Languages/Laurel/ConstrainedTypeElim.lean b/Strata/Languages/Laurel/ConstrainedTypeElim.lean new file mode 100644 index 000000000..0d6264991 --- /dev/null +++ b/Strata/Languages/Laurel/ConstrainedTypeElim.lean @@ -0,0 +1,284 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +import Strata.Languages.Laurel.Laurel +import Strata.Languages.Laurel.Resolution + +/-! +# Constrained Type Elimination + +A Laurel-to-Laurel pass that eliminates constrained types by: +1. Adding `requires` for constrained-typed inputs (Core handles caller asserts and body assumes) +2. Adding `ensures` for constrained-typed outputs (Core handles body checks and caller assumes) + - Skipped for `isFunctional` procedures since the Laurel translator does not yet support + function postconditions. Constrained return types on functions are not checked. +3. Inserting `assert` for local variable init and reassignment of constrained-typed variables +4. Using the witness as default initializer for uninitialized constrained-typed variables +5. Adding a synthetic witness-validation procedure per constrained type +6. Injecting constraints into quantifier bodies (`forall` → `implies`, `exists` → `and`) +7. Resolving all constrained type references to their base types +-/ + +namespace Strata.Laurel + +open Strata + +abbrev ConstrainedTypeMap := Std.HashMap String ConstrainedType +/-- Map from variable name to its constrained HighType (e.g. UserDefined "nat") -/ +abbrev PredVarMap := Std.HashMap String HighType + +def buildConstrainedTypeMap (types : List TypeDefinition) : ConstrainedTypeMap := + types.foldl (init := {}) fun m td => + match td with | .Constrained ct => m.insert ct.name.text ct | _ => m + +partial def resolveBaseType (ptMap : ConstrainedTypeMap) (ty : HighType) : HighType := + match ty with + | .UserDefined name => match ptMap.get? name.text with + | some ct => resolveBaseType ptMap ct.base.val | none => ty + | .Applied ctor args => + .Applied ctor (args.map fun a => ⟨resolveBaseType ptMap a.val, a.md⟩) + | _ => ty + +def resolveType (ptMap : ConstrainedTypeMap) (ty : HighTypeMd) : HighTypeMd := + ⟨resolveBaseType ptMap ty.val, ty.md⟩ + +def isConstrainedType (ptMap : ConstrainedTypeMap) (ty : HighType) : Bool := + match ty with | .UserDefined name => ptMap.contains name.text | _ => false + +/-- All predicates for a type transitively (e.g. evenpos → [(x, x > 0), (x, x % 2 == 0)]) -/ +partial def getAllConstraints (ptMap : ConstrainedTypeMap) (ty : HighType) + : List (Identifier × StmtExprMd) := + match ty with + | .UserDefined name => match ptMap.get? name.text with + | some ct => (ct.valueName, ct.constraint) :: getAllConstraints ptMap ct.base.val + | none => [] + | _ => [] + +/-- Substitute `Identifier old` with `Identifier new` in a constraint expression -/ +partial def substId (old new : Identifier) : StmtExprMd → StmtExprMd + | ⟨.Identifier n, md⟩ => ⟨if n == old then .Identifier new else .Identifier n, md⟩ + | ⟨.PrimitiveOp op args, md⟩ => + ⟨.PrimitiveOp op (args.map fun a => substId old new a), md⟩ + | ⟨.StaticCall c args, md⟩ => + ⟨.StaticCall c (args.map fun a => substId old new a), md⟩ + | ⟨.IfThenElse c t (some el), md⟩ => + ⟨.IfThenElse (substId old new c) (substId old new t) (some (substId old new el)), md⟩ + | ⟨.IfThenElse c t none, md⟩ => + ⟨.IfThenElse (substId old new c) (substId old new t) none, md⟩ + | ⟨.Block ss sep, md⟩ => + ⟨.Block (ss.map fun s => substId old new s) sep, md⟩ + | ⟨.Forall param body, md⟩ => + if param.name == old then ⟨.Forall param body, md⟩ + else if param.name == new then + let fresh : Identifier := mkId (param.name.text ++ "$") + ⟨.Forall { param with name := fresh } (substId old new (substId param.name fresh body)), md⟩ + else ⟨.Forall param (substId old new body), md⟩ + | ⟨.Exists param body, md⟩ => + if param.name == old then ⟨.Exists param body, md⟩ + else if param.name == new then + let fresh : Identifier := mkId (param.name.text ++ "$") + ⟨.Exists { param with name := fresh } (substId old new (substId param.name fresh body)), md⟩ + else ⟨.Exists param (substId old new body), md⟩ + | e => e + +def mkAsserts (ptMap : ConstrainedTypeMap) (ty : HighType) (varName : Identifier) + (md : Imperative.MetaData Core.Expression) : List StmtExprMd := + (getAllConstraints ptMap ty).map fun (valueName, pred) => + ⟨.Assert (substId valueName varName pred), md⟩ + +private def wrap (stmts : List StmtExprMd) (md : Imperative.MetaData Core.Expression) + : StmtExprMd := + match stmts with | [s] => s | ss => ⟨.Block ss none, md⟩ + +/-- Inject constraints into a quantifier body for a constrained type -/ +private def injectQuantifierConstraint (ptMap : ConstrainedTypeMap) (ty : HighType) + (varName : Identifier) (body : StmtExprMd) (isForall : Bool) : StmtExprMd := + let constraints := getAllConstraints ptMap ty + match constraints with + | [] => body + | _ => + let preds := constraints.map fun (vn, pred) => substId vn varName pred + match preds with + | [] => body -- unreachable + | first :: rest => + let conj := rest.foldl (init := first) fun acc p => + ⟨.PrimitiveOp .And [acc, p], body.md⟩ + if isForall then ⟨.PrimitiveOp .Implies [conj, body], body.md⟩ + else ⟨.PrimitiveOp .And [conj, body], body.md⟩ + +/-- Resolve constrained types in all type positions of an expression -/ +def resolveExpr (ptMap : ConstrainedTypeMap) : StmtExprMd → StmtExprMd + | ⟨.LocalVariable n ty (some init), md⟩ => + ⟨.LocalVariable n (resolveType ptMap ty) (some (resolveExpr ptMap init)), md⟩ + | ⟨.LocalVariable n ty none, md⟩ => + ⟨.LocalVariable n (resolveType ptMap ty) none, md⟩ + | ⟨.Forall param body, md⟩ => + let body' := resolveExpr ptMap body + let param' := { param with type := resolveType ptMap param.type } + ⟨.Forall param' (injectQuantifierConstraint ptMap param.type.val param.name body' true), md⟩ + | ⟨.Exists param body, md⟩ => + let body' := resolveExpr ptMap body + let param' := { param with type := resolveType ptMap param.type } + ⟨.Exists param' (injectQuantifierConstraint ptMap param.type.val param.name body' false), md⟩ + | ⟨.AsType t ty, md⟩ => ⟨.AsType (resolveExpr ptMap t) (resolveType ptMap ty), md⟩ + | ⟨.IsType t ty, md⟩ => ⟨.IsType (resolveExpr ptMap t) (resolveType ptMap ty), md⟩ + | ⟨.PrimitiveOp op args, md⟩ => + ⟨.PrimitiveOp op (args.attach.map fun ⟨a, _⟩ => resolveExpr ptMap a), md⟩ + | ⟨.StaticCall c args, md⟩ => + ⟨.StaticCall c (args.attach.map fun ⟨a, _⟩ => resolveExpr ptMap a), md⟩ + | ⟨.Block ss sep, md⟩ => + ⟨.Block (ss.attach.map fun ⟨s, _⟩ => resolveExpr ptMap s) sep, md⟩ + | ⟨.IfThenElse c t (some el), md⟩ => + ⟨.IfThenElse (resolveExpr ptMap c) (resolveExpr ptMap t) (some (resolveExpr ptMap el)), md⟩ + | ⟨.IfThenElse c t none, md⟩ => + ⟨.IfThenElse (resolveExpr ptMap c) (resolveExpr ptMap t) none, md⟩ + | ⟨.While c inv dec body, md⟩ => + ⟨.While (resolveExpr ptMap c) (inv.attach.map fun ⟨i, _⟩ => resolveExpr ptMap i) + dec (resolveExpr ptMap body), md⟩ + | ⟨.Assign ts v, md⟩ => + ⟨.Assign (ts.attach.map fun ⟨t, _⟩ => resolveExpr ptMap t) (resolveExpr ptMap v), md⟩ + | ⟨.Return (some v), md⟩ => ⟨.Return (some (resolveExpr ptMap v)), md⟩ + | ⟨.Return none, md⟩ => ⟨.Return none, md⟩ + | ⟨.Assert c, md⟩ => ⟨.Assert (resolveExpr ptMap c), md⟩ + | ⟨.Assume c, md⟩ => ⟨.Assume (resolveExpr ptMap c), md⟩ + | e => e +termination_by e => sizeOf e +decreasing_by all_goals (have := WithMetadata.sizeOf_val_lt ‹_›; term_by_mem) + +/-- Insert asserts for constrained-typed variable init and reassignment -/ +abbrev ElimM := StateM PredVarMap + +private def inScope (action : ElimM α) : ElimM α := do + let saved ← get + let result ← action + set saved + return result + +def elimStmt (ptMap : ConstrainedTypeMap) + (stmt : StmtExprMd) : ElimM (List StmtExprMd) := do + let md := stmt.md + match _h : stmt.val with + | .LocalVariable name ty init => + let isPred := isConstrainedType ptMap ty.val + if isPred then modify fun pv => pv.insert name.text ty.val + let asserts := if isPred then mkAsserts ptMap ty.val name md else [] + -- Use witness as default initializer for uninitialized constrained variables + let init' := match init with + | none => match ty.val with + | .UserDefined n => (ptMap.get? n.text).map (·.witness) + | _ => none + | some _ => init + pure ([⟨.LocalVariable name ty init', md⟩] ++ asserts) + + -- Single-target only; multi-target assignments are not supported by the Laurel grammar + | .Assign [target] _ => match target.val with + | .Identifier name => do + match (← get).get? name.text with + | some ty => pure ([stmt] ++ mkAsserts ptMap ty name md) + | none => pure [stmt] + | _ => pure [stmt] + + | .Block stmts sep => + let stmtss ← inScope (stmts.mapM (elimStmt ptMap)) + pure [⟨.Block stmtss.flatten sep, md⟩] + + | .IfThenElse cond thenBr (some elseBr) => + let thenSs ← inScope (elimStmt ptMap thenBr) + let elseSs ← inScope (elimStmt ptMap elseBr) + pure [⟨.IfThenElse cond (wrap thenSs md) (some (wrap elseSs md)), md⟩] + | .IfThenElse cond thenBr none => + let thenSs ← inScope (elimStmt ptMap thenBr) + pure [⟨.IfThenElse cond (wrap thenSs md) none, md⟩] + + | .While cond inv dec body => + let bodySs ← inScope (elimStmt ptMap body) + pure [⟨.While cond inv dec (wrap bodySs md), md⟩] + + | _ => pure [stmt] +termination_by sizeOf stmt +decreasing_by + all_goals simp_wf + all_goals (try have := WithMetadata.sizeOf_val_lt stmt) + all_goals (try term_by_mem) + all_goals omega + +def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure := + -- Add requires for constrained-typed inputs + let inputRequires := proc.inputs.flatMap fun p => + (getAllConstraints ptMap p.type.val).map fun (vn, pred) => + ⟨(substId vn p.name pred).val, p.type.md⟩ + -- Add ensures for constrained-typed outputs (skip for isFunctional — not yet supported) + let outputEnsures := if proc.isFunctional then [] else proc.outputs.flatMap fun p => + (getAllConstraints ptMap p.type.val).map fun (vn, pred) => + ⟨(substId vn p.name pred).val, p.type.md⟩ + -- Transform body: insert asserts for local variable init/reassignment + let initVars : PredVarMap := proc.inputs.foldl (init := {}) fun s p => + if isConstrainedType ptMap p.type.val then s.insert p.name.text p.type.val else s + let body' := match proc.body with + | .Transparent bodyExpr => + let (stmts, _) := (elimStmt ptMap bodyExpr).run initVars + let body := wrap stmts bodyExpr.md + if outputEnsures.isEmpty then .Transparent body + else + -- Wrap expression body in a Return so it translates correctly as a procedure + let retBody := if proc.isFunctional then ⟨.Return (some body), bodyExpr.md⟩ else body + .Opaque outputEnsures (some retBody) [] + | .Opaque postconds impl modif => + let impl' := impl.map fun b => wrap ((elimStmt ptMap b).run initVars).1 b.md + .Opaque (postconds ++ outputEnsures) impl' modif + | .Abstract postconds => .Abstract (postconds ++ outputEnsures) + | .External => .External + -- Resolve all constrained types to base types + let resolve := resolveExpr ptMap + let resolveBody : Body → Body := fun body => match body with + | .Transparent b => .Transparent (resolve b) + | .Opaque ps impl modif => .Opaque (ps.map resolve) (impl.map resolve) (modif.map resolve) + | .Abstract ps => .Abstract (ps.map resolve) + | .External => .External + { proc with + body := resolveBody body' + inputs := proc.inputs.map fun p => { p with type := resolveType ptMap p.type } + outputs := proc.outputs.map fun p => { p with type := resolveType ptMap p.type } + preconditions := (proc.preconditions ++ inputRequires).map resolve } + +/-- Create a synthetic procedure that asserts the witness satisfies all constraints -/ +private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure := + let md := ct.witness.md + let witnessId : Identifier := mkId "$witness" + let witnessInit : StmtExprMd := + ⟨.LocalVariable witnessId (resolveType ptMap ct.base) (some ct.witness), md⟩ + let asserts := (getAllConstraints ptMap (.UserDefined ct.name)).map fun (vn, pred) => + ⟨.Assert (substId vn witnessId pred), md⟩ + { name := mkId s!"$witness_{ct.name.text}" + inputs := [] + outputs := [] + body := .Transparent ⟨.Block ([witnessInit] ++ asserts) none, md⟩ + preconditions := [] + isFunctional := false + determinism := .deterministic none + decreases := none + md := md } + +/-- Eliminate constrained types from a Laurel program. + The `witness` field is used as the default initializer for uninitialized + constrained-typed variables, and is validated via synthetic procedures. -/ +def constrainedTypeElim (_model : SemanticModel) (program : Program) : Program × Array DiagnosticModel := + let ptMap := buildConstrainedTypeMap program.types + if ptMap.isEmpty then (program, #[]) else + -- Report unsupported: isFunctional procedures with constrained return types + let funcDiags := program.staticProcedures.foldl (init := #[]) fun acc proc => + if proc.isFunctional && proc.outputs.any (fun p => isConstrainedType ptMap p.type.val) then + acc.push (proc.md.toDiagnostic "constrained return types on functions are not yet supported") + else acc + let witnessProcedures := program.types.filterMap fun + | .Constrained ct => some (mkWitnessProc ptMap ct) + | _ => none + ({ program with + staticProcedures := program.staticProcedures.map (elimProc ptMap) ++ witnessProcedures + types := program.types.filter fun | .Constrained _ => false | _ => true }, + funcDiags) + +end Strata.Laurel diff --git a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean index 2ada05bab..75eaf50ca 100644 --- a/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean +++ b/Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean @@ -495,6 +495,20 @@ def parseDatatype (arg : Arg) : TransM TypeDefinition := do | _, _ => TransM.error s!"parseDatatype expects datatype, got {repr op.name}" +def parseConstrainedType (arg : Arg) : TransM ConstrainedType := do + let .op op := arg + | TransM.error s!"parseConstrainedType expects operation" + match op.name, op.args with + | q`Laurel.constrainedType, #[nameArg, valueNameArg, baseArg, constraintArg, witnessArg] => + let name ← translateIdent nameArg + let valueName ← translateIdent valueNameArg + let base ← translateHighType baseArg + let constraint ← translateStmtExpr constraintArg + let witness ← translateStmtExpr witnessArg + return { name, base, valueName, constraint, witness } + | _, _ => + TransM.error s!"parseConstrainedType expects constrainedType, got {repr op.name}" + def parseTopLevel (arg : Arg) : TransM (Option Procedure × Option TypeDefinition) := do let .op op := arg | TransM.error s!"parseTopLevel expects operation" @@ -509,8 +523,11 @@ def parseTopLevel (arg : Arg) : TransM (Option Procedure × Option TypeDefinitio | q`Laurel.topLevelDatatype, #[datatypeArg] => let typeDef ← parseDatatype datatypeArg return (none, some typeDef) + | q`Laurel.topLevelConstrainedType, #[ctArg] => + let ct ← parseConstrainedType ctArg + return (none, some (.Constrained ct)) | _, _ => - TransM.error s!"parseTopLevel expects topLevelProcedure, topLevelComposite, or topLevelDatatype, got {repr op.name}" + TransM.error s!"parseTopLevel expects topLevelProcedure, topLevelComposite, topLevelDatatype, or topLevelConstrainedType, got {repr op.name}" /-- Translate concrete Laurel syntax into abstract Laurel syntax diff --git a/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean b/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean index 282254daa..632e0a69b 100644 --- a/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean +++ b/Strata/Languages/Laurel/Grammar/LaurelGrammar.lean @@ -7,7 +7,6 @@ -- Laurel dialect definition, loaded from LaurelGrammar.st -- NOTE: Changes to LaurelGrammar.st are not automatically tracked by the build system. -- Update this file (e.g. this comment) to trigger a recompile after modifying LaurelGrammar.st. --- Last grammar change: require semicolon after procedure/function definitions. import Strata.DDM.Integration.Lean namespace Strata.Laurel diff --git a/Strata/Languages/Laurel/Grammar/LaurelGrammar.st b/Strata/Languages/Laurel/Grammar/LaurelGrammar.st index 720d64b86..406e704bb 100644 --- a/Strata/Languages/Laurel/Grammar/LaurelGrammar.st +++ b/Strata/Languages/Laurel/Grammar/LaurelGrammar.st @@ -166,4 +166,10 @@ op topLevelComposite(composite: Composite): TopLevel => composite; op topLevelProcedure(procedure: Procedure): TopLevel => procedure; op topLevelDatatype(datatype: Datatype): TopLevel => datatype; +category ConstrainedType; +op constrainedType (name: Ident, valueName: Ident, base: LaurelType, + constraint: StmtExpr, witness: StmtExpr): ConstrainedType + => "constrained " name " = " valueName ": " base " where " constraint:0 " witness " witness:0; +op topLevelConstrainedType(ct: ConstrainedType): TopLevel => ct; + op program (items: Seq TopLevel): Command => items; \ No newline at end of file diff --git a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean index 5a67f500c..ca50f9c98 100644 --- a/Strata/Languages/Laurel/LaurelToCoreTranslator.lean +++ b/Strata/Languages/Laurel/LaurelToCoreTranslator.lean @@ -21,6 +21,7 @@ import Strata.DL.Imperative.Stmt import Strata.DL.Imperative.MetaData import Strata.DL.Lambda.LExpr import Strata.Languages.Laurel.LaurelFormat +import Strata.Languages.Laurel.ConstrainedTypeElim import Strata.Util.Tactics open Core (VCResult VCResults VerifyOptions) @@ -620,6 +621,11 @@ def translate (program : Program): Except (Array DiagnosticModel) (Core.Program let (program, model) := (result.program, result.model) _resolutionDiags := _resolutionDiags ++ result.errors + let (program, constrainedTypeDiags) := constrainedTypeElim model program + let result := resolve program (some model) + let (program, model) := (result.program, result.model) + _resolutionDiags := _resolutionDiags ++ result.errors + -- Procedures marked isFunctional are translated to Core functions; all others become Core procedures. -- External procedures are completely ignored (not translated to Core). let nonExternal := program.staticProcedures.filter (fun p => !p.body.isExternal) @@ -668,7 +674,7 @@ def translate (program : Program): Except (Array DiagnosticModel) (Core.Program -- dbg_trace "=== Generated Strata Core Program ===" -- dbg_trace (toString (Std.Format.pretty (Strata.Core.formatProgram program) 100)) -- dbg_trace "=================================" - pure (program, diamondErrors ++ modifiesDiags) + pure (program, diamondErrors ++ modifiesDiags ++ constrainedTypeDiags.toList) /-- Verify a Laurel program using an SMT solver diff --git a/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean new file mode 100644 index 000000000..54841c473 --- /dev/null +++ b/StrataTest/Languages/Laurel/ConstrainedTypeElimTest.lean @@ -0,0 +1,91 @@ +/- + Copyright Strata Contributors + + SPDX-License-Identifier: Apache-2.0 OR MIT +-/ + +/- +Tests that the constrained type elimination pass correctly transforms +Laurel programs by comparing the output against expected results. +-/ + +import Strata.DDM.Elab +import Strata.DDM.BuiltinDialects.Init +import Strata.Languages.Laurel.Grammar.LaurelGrammar +import Strata.Languages.Laurel.Grammar.ConcreteToAbstractTreeTranslator +import Strata.Languages.Laurel.ConstrainedTypeElim +import Strata.Languages.Laurel.Resolution + +open Strata +open Strata.Elab (parseStrataProgramFromDialect) + +namespace Strata.Laurel + +def testProgram : String := r" +constrained nat = x: int where x >= 0 witness 0 +procedure test(n: nat) returns (r: nat) { + var y: nat := n; + return y; +}; +" + +def parseLaurelAndElim (input : String) : IO Program := do + let inputCtx := Strata.Parser.stringInputContext "test" input + let dialects := Strata.Elab.LoadedDialects.ofDialects! #[initDialect, Laurel] + let strataProgram ← parseStrataProgramFromDialect dialects Laurel.name inputCtx + let uri := Strata.Uri.file "test" + match Laurel.TransM.run uri (Laurel.parseProgram strataProgram) with + | .error e => throw (IO.userError s!"Translation errors: {e}") + | .ok program => + let result := resolve program + let (program, model) := (result.program, result.model) + pure (constrainedTypeElim model program).1 + +/-- +info: procedure test(n: int) returns ⏎ +(r: int) +requires n >= 0 +deterministic + ensures r >= 0 := { var y: int := n; assert y >= 0; return y } +procedure $witness_nat() returns ⏎ +() +deterministic +{ var $witness: int := 0; assert $witness >= 0 } +-/ +#guard_msgs in +#eval! do + let program ← parseLaurelAndElim testProgram + for proc in program.staticProcedures do + IO.println (toString (Std.Format.pretty (Std.ToFormat.format proc))) + +-- Scope management: constrained variable in if-branch must not leak into sibling block +def scopeProgram : String := r" +constrained pos = v: int where v > 0 witness 1 +procedure test(b: bool) { + if (b) { + var x: pos := 1; + } + { + var x: int := -5; + x := -10; + } +}; +" + +/-- +info: procedure test(b: bool) returns ⏎ +() +deterministic +{ if b then { var x: int := 1; assert x > 0 }; { var x: int := -5; x := -10 } } +procedure $witness_pos() returns ⏎ +() +deterministic +{ var $witness: int := 1; assert $witness > 0 } +-/ +#guard_msgs in +#eval! do + let program ← parseLaurelAndElim scopeProgram + for proc in program.staticProcedures do + IO.println (toString (Std.Format.pretty (Std.ToFormat.format proc))) + +end Laurel diff --git a/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean b/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean index 33fcb29b0..a8c36fe87 100644 --- a/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean +++ b/StrataTest/Languages/Laurel/Examples/Fundamentals/T10_ConstrainedTypes.lean @@ -8,23 +8,161 @@ import StrataTest.Util.TestDiagnostics import StrataTest.Languages.Laurel.TestExamples open StrataTest.Util -open Strata +namespace Strata namespace Laurel def program := r" constrained nat = x: int where x >= 0 witness 0 +constrained posnat = x: nat where x > 0 witness 1 -composite Option {} -composite Some extends Option { - value: int -} -composite None extends Option -constrained SealedOption = x: Option where x is Some || x is None witness None +// Input constraint becomes requires — body can rely on it +procedure inputAssumed(n: nat) { + assert n >= 0; +}; + +// Output constraint — valid return passes +procedure outputValid(): nat { + return 3; +}; + +// Output constraint — invalid return fails +procedure outputInvalid(): nat { +// ^^^ error: assertion does not hold + return -1; +}; + +// Return value of constrained type — caller gets ensures via call elimination +procedure opaqueNat(): nat; +procedure callerAssumes() returns (r: int) { + var x: int := opaqueNat(); + assert x >= 0; + return x; +}; + +// Assignment to constrained-typed variable — valid +procedure assignValid() { + var y: nat := 5; +}; + +// Assignment to constrained-typed variable — invalid +procedure assignInvalid() { + var y: nat := -1; +//^^^^^^^^^^^^^^^^^ error: assertion does not hold +}; + +// Reassignment to constrained-typed variable — invalid +procedure reassignInvalid() { + var y: nat := 5; + y := -1; +//^^^^^^^^ error: assertion does not hold +}; + +// Argument to constrained-typed parameter — valid +procedure takesNat(n: nat) returns (r: int) { return n; }; +// ^^^ error: assertion does not hold +procedure argValid() returns (r: int) { + var x: int := takesNat(3); + return x; +}; + +// Argument to constrained-typed parameter — invalid (requires violation) +procedure argInvalid() returns (r: int) { + var x: int := takesNat(-1); + return x; +}; + +// Nested constrained type — independent constraints require transitive collection +constrained even = x: int where x % 2 == 0 witness 0 +constrained evenpos = x: even where x > 0 witness 2 +procedure nestedInput(x: evenpos) { + assert x > 0; + assert x % 2 == 0; +}; + +// Multiple constrained-typed parameters +procedure multiParam(a: nat, b: nat) { + assert a >= 0; + assert b >= 0; +}; -procedure foo() returns (r: nat) { +// Two calls to same procedure — no temp var collision +procedure twoCalls() returns (r: int) { + var a: int := takesNat(1); + var b: int := takesNat(2); + return a + b; +}; + +// Constrained type in expression position must be resolved +procedure constrainedInExpr() { + var b: bool := forall(n: nat) => n + 1 > n; + assert b; +}; + +// Invalid witness — witness -1 does not satisfy x > 0 +constrained bad = x: int where x > 0 witness -1 +// ^^ error: assertion does not hold + +// Uninitialized constrained variable — witness used as default +procedure uninitNat() { + var y: nat; + assert y >= 0; +}; + +// Uninitialized nested constrained variable — outermost witness used +procedure uninitPosnat() { + var y: posnat; + assert y > 0; + assert y >= 0; +}; + +// Function with valid constrained return — constraint not checked (not yet supported) +function goodFunc(): nat { 3 }; +// ^^^^^^^^ error: constrained return types on functions are not yet supported + +// Function with invalid constrained return — constraint not checked (not yet supported) +function badFunc(): nat { -1 }; +// ^^^^^^^ error: constrained return types on functions are not yet supported + +// Caller of constrained function — body is inlined, caller sees actual value +procedure callerGood() { + var x: int := goodFunc(); + assert x >= 0; +}; + +// Quantifier constraint injection — forall +// n + 1 > 0 is only provable with n >= 0 injected; false for all int +procedure forallNat() { + var b: bool := forall(n: nat) => n + 1 > 0; + assert b; +}; + +// Quantifier constraint injection — exists +// n == -1 is satisfiable for int, but not when n >= 0 is required +// n == 42 works because 42 >= 0 +procedure existsNat() { + var b: bool := exists(n: nat) => n == 42; + assert b; +}; + +// Quantifier constraint injection — nested constrained type +// n - 1 >= 0 is only provable with n > 0 injected +procedure forallPosnat() { + var b: bool := forall(n: posnat) => n - 1 >= 0; + assert b; +}; + +// Capture avoidance — bound var y in constraint must not collide with parameter y +// Without capture avoidance, requires becomes exists(y) => y > y (false), making body vacuously true +constrained haslarger = x: int where (exists(y: int) => y > x) witness 0 +procedure captureTest(y: haslarger) { + assert false; +//^^^^^^^^^^^^^ error: assertion does not hold }; " --- Not working yet --- #eval! testInput "ConstrainedTypes" program processLaurelFile +#guard_msgs(drop info, error) in +#eval testInputWithOffset "ConstrainedTypes" program 14 processLaurelFile + +end Laurel +end Strata