|
| 1 | +/- |
| 2 | + Copyright Strata Contributors |
| 3 | +
|
| 4 | + SPDX-License-Identifier: Apache-2.0 OR MIT |
| 5 | +-/ |
| 6 | + |
| 7 | +import Strata.Languages.Laurel.Laurel |
| 8 | +import Strata.Languages.Laurel.Resolution |
| 9 | + |
| 10 | +/-! |
| 11 | +# Constrained Type Elimination |
| 12 | +
|
| 13 | +A Laurel-to-Laurel pass that eliminates constrained types by: |
| 14 | +1. Adding `requires` for constrained-typed inputs (Core handles caller asserts and body assumes) |
| 15 | +2. Adding `ensures` for constrained-typed outputs (Core handles body checks and caller assumes) |
| 16 | + - For `isFunctional` procedures, clears the flag since the Laurel translator does not yet |
| 17 | + support function postconditions. TODO: restore `isFunctional` once supported. |
| 18 | +3. Inserting `assert` for local variable init and reassignment of constrained-typed variables |
| 19 | +4. Using the witness as default initializer for uninitialized constrained-typed variables |
| 20 | +5. Adding a synthetic witness-validation procedure per constrained type |
| 21 | +6. Injecting constraints into quantifier bodies (`forall` → `implies`, `exists` → `and`) |
| 22 | +7. Resolving all constrained type references to their base types |
| 23 | +-/ |
| 24 | + |
| 25 | +namespace Strata.Laurel |
| 26 | + |
| 27 | +open Strata |
| 28 | + |
| 29 | +abbrev ConstrainedTypeMap := Std.HashMap String ConstrainedType |
| 30 | +/-- Map from variable name to its original constrained type -/ |
| 31 | +abbrev PredVarMap := Std.HashMap String HighType |
| 32 | + |
| 33 | +def buildConstrainedTypeMap (types : List TypeDefinition) : ConstrainedTypeMap := |
| 34 | + types.foldl (init := {}) fun m td => |
| 35 | + match td with | .Constrained ct => m.insert ct.name.text ct | _ => m |
| 36 | + |
| 37 | +partial def resolveBaseType (ptMap : ConstrainedTypeMap) (ty : HighType) : HighType := |
| 38 | + match ty with |
| 39 | + | .UserDefined name => match ptMap.get? name.text with |
| 40 | + | some ct => resolveBaseType ptMap ct.base.val | none => ty |
| 41 | + | .Applied ctor args => |
| 42 | + .Applied ctor (args.map fun a => ⟨resolveBaseType ptMap a.val, a.md⟩) |
| 43 | + | _ => ty |
| 44 | + |
| 45 | +def resolveType (ptMap : ConstrainedTypeMap) (ty : HighTypeMd) : HighTypeMd := |
| 46 | + ⟨resolveBaseType ptMap ty.val, ty.md⟩ |
| 47 | + |
| 48 | +def isConstrainedType (ptMap : ConstrainedTypeMap) (ty : HighType) : Bool := |
| 49 | + match ty with | .UserDefined name => ptMap.contains name.text | _ => false |
| 50 | + |
| 51 | +/-- All predicates for a type transitively (e.g. posnat → [x > 0, x >= 0]) -/ |
| 52 | +partial def getAllConstraints (ptMap : ConstrainedTypeMap) (ty : HighType) |
| 53 | + : List (Identifier × StmtExprMd) := |
| 54 | + match ty with |
| 55 | + | .UserDefined name => match ptMap.get? name.text with |
| 56 | + | some ct => (ct.valueName, ct.constraint) :: getAllConstraints ptMap ct.base.val |
| 57 | + | none => [] |
| 58 | + | _ => [] |
| 59 | + |
| 60 | +/-- Substitute `Identifier old` with `Identifier new` in a constraint expression -/ |
| 61 | +partial def substId (old new : Identifier) : StmtExprMd → StmtExprMd |
| 62 | + | ⟨.Identifier n, md⟩ => ⟨if n == old then .Identifier new else .Identifier n, md⟩ |
| 63 | + | ⟨.PrimitiveOp op args, md⟩ => |
| 64 | + ⟨.PrimitiveOp op (args.map fun a => substId old new a), md⟩ |
| 65 | + | ⟨.StaticCall c args, md⟩ => |
| 66 | + ⟨.StaticCall c (args.map fun a => substId old new a), md⟩ |
| 67 | + | ⟨.IfThenElse c t (some el), md⟩ => |
| 68 | + ⟨.IfThenElse (substId old new c) (substId old new t) (some (substId old new el)), md⟩ |
| 69 | + | ⟨.IfThenElse c t none, md⟩ => |
| 70 | + ⟨.IfThenElse (substId old new c) (substId old new t) none, md⟩ |
| 71 | + | ⟨.Block ss sep, md⟩ => |
| 72 | + ⟨.Block (ss.map fun s => substId old new s) sep, md⟩ |
| 73 | + | ⟨.Forall param body, md⟩ => |
| 74 | + if param.name == old then ⟨.Forall param body, md⟩ |
| 75 | + else if param.name == new then |
| 76 | + let fresh : Identifier := mkId (param.name.text ++ "$") |
| 77 | + ⟨.Forall { param with name := fresh } (substId old new (substId param.name fresh body)), md⟩ |
| 78 | + else ⟨.Forall param (substId old new body), md⟩ |
| 79 | + | ⟨.Exists param body, md⟩ => |
| 80 | + if param.name == old then ⟨.Exists param body, md⟩ |
| 81 | + else if param.name == new then |
| 82 | + let fresh : Identifier := mkId (param.name.text ++ "$") |
| 83 | + ⟨.Exists { param with name := fresh } (substId old new (substId param.name fresh body)), md⟩ |
| 84 | + else ⟨.Exists param (substId old new body), md⟩ |
| 85 | + | e => e |
| 86 | + |
| 87 | +def mkAsserts (ptMap : ConstrainedTypeMap) (ty : HighType) (varName : Identifier) |
| 88 | + (md : Imperative.MetaData Core.Expression) : List StmtExprMd := |
| 89 | + (getAllConstraints ptMap ty).map fun (valueName, pred) => |
| 90 | + ⟨.Assert (substId valueName varName pred), md⟩ |
| 91 | + |
| 92 | +private def wrap (stmts : List StmtExprMd) (md : Imperative.MetaData Core.Expression) |
| 93 | + : StmtExprMd := |
| 94 | + match stmts with | [s] => s | ss => ⟨.Block ss none, md⟩ |
| 95 | + |
| 96 | +/-- Inject constraints into a quantifier body for a constrained type -/ |
| 97 | +private def injectQuantifierConstraint (ptMap : ConstrainedTypeMap) (ty : HighType) |
| 98 | + (varName : Identifier) (body : StmtExprMd) (isForall : Bool) : StmtExprMd := |
| 99 | + let constraints := getAllConstraints ptMap ty |
| 100 | + match constraints with |
| 101 | + | [] => body |
| 102 | + | _ => |
| 103 | + let preds := constraints.map fun (vn, pred) => substId vn varName pred |
| 104 | + let conj := preds.tail.foldl (init := preds.head!) fun acc p => |
| 105 | + ⟨.PrimitiveOp .And [acc, p], body.md⟩ |
| 106 | + if isForall then ⟨.PrimitiveOp .Implies [conj, body], body.md⟩ |
| 107 | + else ⟨.PrimitiveOp .And [conj, body], body.md⟩ |
| 108 | + |
| 109 | +/-- Resolve constrained types in all type positions of an expression -/ |
| 110 | +def resolveExpr (ptMap : ConstrainedTypeMap) : StmtExprMd → StmtExprMd |
| 111 | + | ⟨.LocalVariable n ty (some init), md⟩ => |
| 112 | + ⟨.LocalVariable n (resolveType ptMap ty) (some (resolveExpr ptMap init)), md⟩ |
| 113 | + | ⟨.LocalVariable n ty none, md⟩ => |
| 114 | + ⟨.LocalVariable n (resolveType ptMap ty) none, md⟩ |
| 115 | + | ⟨.Forall param body, md⟩ => |
| 116 | + let body' := resolveExpr ptMap body |
| 117 | + let param' := { param with type := resolveType ptMap param.type } |
| 118 | + ⟨.Forall param' (injectQuantifierConstraint ptMap param.type.val param.name body' true), md⟩ |
| 119 | + | ⟨.Exists param body, md⟩ => |
| 120 | + let body' := resolveExpr ptMap body |
| 121 | + let param' := { param with type := resolveType ptMap param.type } |
| 122 | + ⟨.Exists param' (injectQuantifierConstraint ptMap param.type.val param.name body' false), md⟩ |
| 123 | + | ⟨.AsType t ty, md⟩ => ⟨.AsType (resolveExpr ptMap t) (resolveType ptMap ty), md⟩ |
| 124 | + | ⟨.IsType t ty, md⟩ => ⟨.IsType (resolveExpr ptMap t) (resolveType ptMap ty), md⟩ |
| 125 | + | ⟨.PrimitiveOp op args, md⟩ => |
| 126 | + ⟨.PrimitiveOp op (args.attach.map fun ⟨a, _⟩ => resolveExpr ptMap a), md⟩ |
| 127 | + | ⟨.StaticCall c args, md⟩ => |
| 128 | + ⟨.StaticCall c (args.attach.map fun ⟨a, _⟩ => resolveExpr ptMap a), md⟩ |
| 129 | + | ⟨.Block ss sep, md⟩ => |
| 130 | + ⟨.Block (ss.attach.map fun ⟨s, _⟩ => resolveExpr ptMap s) sep, md⟩ |
| 131 | + | ⟨.IfThenElse c t (some el), md⟩ => |
| 132 | + ⟨.IfThenElse (resolveExpr ptMap c) (resolveExpr ptMap t) (some (resolveExpr ptMap el)), md⟩ |
| 133 | + | ⟨.IfThenElse c t none, md⟩ => |
| 134 | + ⟨.IfThenElse (resolveExpr ptMap c) (resolveExpr ptMap t) none, md⟩ |
| 135 | + | ⟨.While c inv dec body, md⟩ => |
| 136 | + ⟨.While (resolveExpr ptMap c) (inv.attach.map fun ⟨i, _⟩ => resolveExpr ptMap i) |
| 137 | + dec (resolveExpr ptMap body), md⟩ |
| 138 | + | ⟨.Assign ts v, md⟩ => |
| 139 | + ⟨.Assign (ts.attach.map fun ⟨t, _⟩ => resolveExpr ptMap t) (resolveExpr ptMap v), md⟩ |
| 140 | + | ⟨.Return (some v), md⟩ => ⟨.Return (some (resolveExpr ptMap v)), md⟩ |
| 141 | + | ⟨.Return none, md⟩ => ⟨.Return none, md⟩ |
| 142 | + | ⟨.Assert c, md⟩ => ⟨.Assert (resolveExpr ptMap c), md⟩ |
| 143 | + | ⟨.Assume c, md⟩ => ⟨.Assume (resolveExpr ptMap c), md⟩ |
| 144 | + | e => e |
| 145 | +termination_by e => sizeOf e |
| 146 | +decreasing_by all_goals (have := WithMetadata.sizeOf_val_lt ‹_›; term_by_mem) |
| 147 | + |
| 148 | +/-- Insert asserts for constrained-typed variable init and reassignment -/ |
| 149 | +abbrev ElimM := StateM PredVarMap |
| 150 | + |
| 151 | +def elimStmt (ptMap : ConstrainedTypeMap) |
| 152 | + (stmt : StmtExprMd) : ElimM (List StmtExprMd) := do |
| 153 | + let md := stmt.md |
| 154 | + match _h : stmt.val with |
| 155 | + | .LocalVariable name ty init => |
| 156 | + let isPred := isConstrainedType ptMap ty.val |
| 157 | + if isPred then modify fun pv => pv.insert name.text ty.val |
| 158 | + let asserts := if isPred then mkAsserts ptMap ty.val name md else [] |
| 159 | + -- Use witness as default initializer for uninitialized constrained variables |
| 160 | + let init' := match init with |
| 161 | + | none => match ty.val with |
| 162 | + | .UserDefined n => (ptMap.get? n.text).map (·.witness) |
| 163 | + | _ => none |
| 164 | + | some _ => init |
| 165 | + pure ([⟨.LocalVariable name ty init', md⟩] ++ asserts) |
| 166 | + |
| 167 | + -- Single-target only; multi-target assignments are not supported by the Laurel grammar |
| 168 | + | .Assign [target] _ => match target.val with |
| 169 | + | .Identifier name => do |
| 170 | + match (← get).get? name.text with |
| 171 | + | some ty => pure ([stmt] ++ mkAsserts ptMap ty name md) |
| 172 | + | none => pure [stmt] |
| 173 | + | _ => pure [stmt] |
| 174 | + |
| 175 | + | .Block stmts sep => |
| 176 | + let stmtss ← stmts.mapM (elimStmt ptMap) |
| 177 | + pure [⟨.Block stmtss.flatten sep, md⟩] |
| 178 | + |
| 179 | + | .IfThenElse cond thenBr (some elseBr) => |
| 180 | + let thenSs ← elimStmt ptMap thenBr |
| 181 | + let elseSs ← elimStmt ptMap elseBr |
| 182 | + pure [⟨.IfThenElse cond (wrap thenSs md) (some (wrap elseSs md)), md⟩] |
| 183 | + | .IfThenElse cond thenBr none => |
| 184 | + let thenSs ← elimStmt ptMap thenBr |
| 185 | + pure [⟨.IfThenElse cond (wrap thenSs md) none, md⟩] |
| 186 | + |
| 187 | + | .While cond inv dec body => |
| 188 | + let bodySs ← elimStmt ptMap body |
| 189 | + pure [⟨.While cond inv dec (wrap bodySs md), md⟩] |
| 190 | + |
| 191 | + | _ => pure [stmt] |
| 192 | +termination_by sizeOf stmt |
| 193 | +decreasing_by |
| 194 | + all_goals simp_wf |
| 195 | + all_goals (try have := WithMetadata.sizeOf_val_lt stmt) |
| 196 | + all_goals (try term_by_mem) |
| 197 | + all_goals omega |
| 198 | + |
| 199 | +def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure := |
| 200 | + -- Add requires for constrained-typed inputs |
| 201 | + let inputRequires := proc.inputs.flatMap fun p => |
| 202 | + (getAllConstraints ptMap p.type.val).map fun (vn, pred) => |
| 203 | + ⟨(substId vn p.name pred).val, p.type.md⟩ |
| 204 | + -- Add ensures for constrained-typed outputs |
| 205 | + let outputEnsures := proc.outputs.flatMap fun p => |
| 206 | + (getAllConstraints ptMap p.type.val).map fun (vn, pred) => |
| 207 | + ⟨(substId vn p.name pred).val, p.type.md⟩ |
| 208 | + -- Transform body: insert asserts for local variable init/reassignment |
| 209 | + let initVars : PredVarMap := proc.inputs.foldl (init := {}) fun s p => |
| 210 | + if isConstrainedType ptMap p.type.val then s.insert p.name.text p.type.val else s |
| 211 | + let body' := match proc.body with |
| 212 | + | .Transparent bodyExpr => |
| 213 | + let (stmts, _) := (elimStmt ptMap bodyExpr).run initVars |
| 214 | + let body := wrap stmts bodyExpr.md |
| 215 | + if outputEnsures.isEmpty then .Transparent body |
| 216 | + else |
| 217 | + -- Wrap expression body in a Return so it translates correctly as a procedure |
| 218 | + let retBody := if proc.isFunctional then ⟨.Return (some body), bodyExpr.md⟩ else body |
| 219 | + .Opaque outputEnsures (some retBody) [] |
| 220 | + | .Opaque postconds impl modif => |
| 221 | + let impl' := impl.map fun b => wrap ((elimStmt ptMap b).run initVars).1 b.md |
| 222 | + .Opaque (postconds ++ outputEnsures) impl' modif |
| 223 | + | .Abstract postconds => .Abstract (postconds ++ outputEnsures) |
| 224 | + | .External => .External |
| 225 | + -- Resolve all constrained types to base types |
| 226 | + let resolve := resolveExpr ptMap |
| 227 | + let resolveBody : Body → Body := fun body => match body with |
| 228 | + | .Transparent b => .Transparent (resolve b) |
| 229 | + | .Opaque ps impl modif => .Opaque (ps.map resolve) (impl.map resolve) (modif.map resolve) |
| 230 | + | .Abstract ps => .Abstract (ps.map resolve) |
| 231 | + | .External => .External |
| 232 | + { proc with |
| 233 | + body := resolveBody body' |
| 234 | + -- TODO: restore isFunctional once function postconditions are supported |
| 235 | + isFunctional := if outputEnsures.isEmpty then proc.isFunctional else false |
| 236 | + inputs := proc.inputs.map fun p => { p with type := resolveType ptMap p.type } |
| 237 | + outputs := proc.outputs.map fun p => { p with type := resolveType ptMap p.type } |
| 238 | + preconditions := (proc.preconditions ++ inputRequires).map resolve } |
| 239 | + |
| 240 | +/-- Create a synthetic procedure that asserts the witness satisfies all constraints -/ |
| 241 | +private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure := |
| 242 | + let md := ct.constraint.md |
| 243 | + let witnessId : Identifier := mkId "$witness" |
| 244 | + let witnessInit : StmtExprMd := |
| 245 | + ⟨.LocalVariable witnessId (resolveType ptMap ct.base) (some ct.witness), md⟩ |
| 246 | + let asserts := (getAllConstraints ptMap (.UserDefined ct.name)).map fun (vn, pred) => |
| 247 | + ⟨.Assert (substId vn witnessId pred), md⟩ |
| 248 | + { name := mkId s!"$witness_{ct.name.text}" |
| 249 | + inputs := [] |
| 250 | + outputs := [] |
| 251 | + body := .Transparent ⟨.Block ([witnessInit] ++ asserts) none, md⟩ |
| 252 | + preconditions := [] |
| 253 | + isFunctional := false |
| 254 | + determinism := .deterministic none |
| 255 | + decreases := none |
| 256 | + md := md } |
| 257 | + |
| 258 | +/-- Eliminate constrained types from a Laurel program. |
| 259 | + The `witness` field is used as the default initializer for uninitialized |
| 260 | + constrained-typed variables, and is validated via synthetic procedures. -/ |
| 261 | +def constrainedTypeElim (_model : SemanticModel) (program : Program) : Program := |
| 262 | + let ptMap := buildConstrainedTypeMap program.types |
| 263 | + if ptMap.isEmpty then program else |
| 264 | + let witnessProcedures := program.types.filterMap fun |
| 265 | + | .Constrained ct => some (mkWitnessProc ptMap ct) |
| 266 | + | _ => none |
| 267 | + { program with |
| 268 | + staticProcedures := program.staticProcedures.map (elimProc ptMap) ++ witnessProcedures |
| 269 | + types := program.types.filter fun | .Constrained _ => false | _ => true } |
| 270 | + |
| 271 | +end Strata.Laurel |
0 commit comments