Skip to content

Commit e2ff8b6

Browse files
committed
Laurel: add constrained type support
A Laurel-to-Laurel elimination pass (ConstrainedTypeElim.lean) that: - Adds requires for constrained-typed inputs - Adds ensures for constrained-typed outputs - Clears isFunctional when adding ensures (function postconditions not yet supported) - Inserts assert for local variable init and reassignment - Uses witness as default initializer for uninitialized constrained variables - Validates witnesses via synthetic procedures - Injects constraints into quantifier bodies (forall → implies, exists → and) - Resolves all constrained type references to base types - Handles capture avoidance in identifier substitution Core's call elimination handles caller-side argument asserts and return value assumes automatically via requires/ensures. Grammar: constrained type syntax Parser: parseConstrainedType + topLevelConstrainedType Test: T09_ConstrainedTypes — 25 test procedures
1 parent 198670b commit e2ff8b6

File tree

5 files changed

+463
-1
lines changed

5 files changed

+463
-1
lines changed
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
/-
2+
Copyright Strata Contributors
3+
4+
SPDX-License-Identifier: Apache-2.0 OR MIT
5+
-/
6+
7+
import Strata.Languages.Laurel.Laurel
8+
import Strata.Languages.Laurel.Resolution
9+
10+
/-!
11+
# Constrained Type Elimination
12+
13+
A Laurel-to-Laurel pass that eliminates constrained types by:
14+
1. Adding `requires` for constrained-typed inputs (Core handles caller asserts and body assumes)
15+
2. Adding `ensures` for constrained-typed outputs (Core handles body checks and caller assumes)
16+
- For `isFunctional` procedures, clears the flag since the Laurel translator does not yet
17+
support function postconditions. TODO: restore `isFunctional` once supported.
18+
3. Inserting `assert` for local variable init and reassignment of constrained-typed variables
19+
4. Using the witness as default initializer for uninitialized constrained-typed variables
20+
5. Adding a synthetic witness-validation procedure per constrained type
21+
6. Injecting constraints into quantifier bodies (`forall` → `implies`, `exists` → `and`)
22+
7. Resolving all constrained type references to their base types
23+
-/
24+
25+
namespace Strata.Laurel
26+
27+
open Strata
28+
29+
abbrev ConstrainedTypeMap := Std.HashMap String ConstrainedType
30+
/-- Map from variable name to its original constrained type -/
31+
abbrev PredVarMap := Std.HashMap String HighType
32+
33+
def buildConstrainedTypeMap (types : List TypeDefinition) : ConstrainedTypeMap :=
34+
types.foldl (init := {}) fun m td =>
35+
match td with | .Constrained ct => m.insert ct.name.text ct | _ => m
36+
37+
partial def resolveBaseType (ptMap : ConstrainedTypeMap) (ty : HighType) : HighType :=
38+
match ty with
39+
| .UserDefined name => match ptMap.get? name.text with
40+
| some ct => resolveBaseType ptMap ct.base.val | none => ty
41+
| .Applied ctor args =>
42+
.Applied ctor (args.map fun a => ⟨resolveBaseType ptMap a.val, a.md⟩)
43+
| _ => ty
44+
45+
def resolveType (ptMap : ConstrainedTypeMap) (ty : HighTypeMd) : HighTypeMd :=
46+
⟨resolveBaseType ptMap ty.val, ty.md⟩
47+
48+
def isConstrainedType (ptMap : ConstrainedTypeMap) (ty : HighType) : Bool :=
49+
match ty with | .UserDefined name => ptMap.contains name.text | _ => false
50+
51+
/-- All predicates for a type transitively (e.g. posnat → [x > 0, x >= 0]) -/
52+
partial def getAllConstraints (ptMap : ConstrainedTypeMap) (ty : HighType)
53+
: List (Identifier × StmtExprMd) :=
54+
match ty with
55+
| .UserDefined name => match ptMap.get? name.text with
56+
| some ct => (ct.valueName, ct.constraint) :: getAllConstraints ptMap ct.base.val
57+
| none => []
58+
| _ => []
59+
60+
/-- Substitute `Identifier old` with `Identifier new` in a constraint expression -/
61+
partial def substId (old new : Identifier) : StmtExprMd → StmtExprMd
62+
| ⟨.Identifier n, md⟩ => ⟨if n == old then .Identifier new else .Identifier n, md⟩
63+
| ⟨.PrimitiveOp op args, md⟩ =>
64+
⟨.PrimitiveOp op (args.map fun a => substId old new a), md⟩
65+
| ⟨.StaticCall c args, md⟩ =>
66+
⟨.StaticCall c (args.map fun a => substId old new a), md⟩
67+
| ⟨.IfThenElse c t (some el), md⟩ =>
68+
⟨.IfThenElse (substId old new c) (substId old new t) (some (substId old new el)), md⟩
69+
| ⟨.IfThenElse c t none, md⟩ =>
70+
⟨.IfThenElse (substId old new c) (substId old new t) none, md⟩
71+
| ⟨.Block ss sep, md⟩ =>
72+
⟨.Block (ss.map fun s => substId old new s) sep, md⟩
73+
| ⟨.Forall param body, md⟩ =>
74+
if param.name == old then ⟨.Forall param body, md⟩
75+
else if param.name == new then
76+
let fresh : Identifier := mkId (param.name.text ++ "$")
77+
⟨.Forall { param with name := fresh } (substId old new (substId param.name fresh body)), md⟩
78+
else ⟨.Forall param (substId old new body), md⟩
79+
| ⟨.Exists param body, md⟩ =>
80+
if param.name == old then ⟨.Exists param body, md⟩
81+
else if param.name == new then
82+
let fresh : Identifier := mkId (param.name.text ++ "$")
83+
⟨.Exists { param with name := fresh } (substId old new (substId param.name fresh body)), md⟩
84+
else ⟨.Exists param (substId old new body), md⟩
85+
| e => e
86+
87+
def mkAsserts (ptMap : ConstrainedTypeMap) (ty : HighType) (varName : Identifier)
88+
(md : Imperative.MetaData Core.Expression) : List StmtExprMd :=
89+
(getAllConstraints ptMap ty).map fun (valueName, pred) =>
90+
⟨.Assert (substId valueName varName pred), md⟩
91+
92+
private def wrap (stmts : List StmtExprMd) (md : Imperative.MetaData Core.Expression)
93+
: StmtExprMd :=
94+
match stmts with | [s] => s | ss => ⟨.Block ss none, md⟩
95+
96+
/-- Inject constraints into a quantifier body for a constrained type -/
97+
private def injectQuantifierConstraint (ptMap : ConstrainedTypeMap) (ty : HighType)
98+
(varName : Identifier) (body : StmtExprMd) (isForall : Bool) : StmtExprMd :=
99+
let constraints := getAllConstraints ptMap ty
100+
match constraints with
101+
| [] => body
102+
| _ =>
103+
let preds := constraints.map fun (vn, pred) => substId vn varName pred
104+
let conj := preds.tail.foldl (init := preds.head!) fun acc p =>
105+
⟨.PrimitiveOp .And [acc, p], body.md⟩
106+
if isForall then ⟨.PrimitiveOp .Implies [conj, body], body.md⟩
107+
else ⟨.PrimitiveOp .And [conj, body], body.md⟩
108+
109+
/-- Resolve constrained types in all type positions of an expression -/
110+
def resolveExpr (ptMap : ConstrainedTypeMap) : StmtExprMd → StmtExprMd
111+
| ⟨.LocalVariable n ty (some init), md⟩ =>
112+
⟨.LocalVariable n (resolveType ptMap ty) (some (resolveExpr ptMap init)), md⟩
113+
| ⟨.LocalVariable n ty none, md⟩ =>
114+
⟨.LocalVariable n (resolveType ptMap ty) none, md⟩
115+
| ⟨.Forall param body, md⟩ =>
116+
let body' := resolveExpr ptMap body
117+
let param' := { param with type := resolveType ptMap param.type }
118+
⟨.Forall param' (injectQuantifierConstraint ptMap param.type.val param.name body' true), md⟩
119+
| ⟨.Exists param body, md⟩ =>
120+
let body' := resolveExpr ptMap body
121+
let param' := { param with type := resolveType ptMap param.type }
122+
⟨.Exists param' (injectQuantifierConstraint ptMap param.type.val param.name body' false), md⟩
123+
| ⟨.AsType t ty, md⟩ => ⟨.AsType (resolveExpr ptMap t) (resolveType ptMap ty), md⟩
124+
| ⟨.IsType t ty, md⟩ => ⟨.IsType (resolveExpr ptMap t) (resolveType ptMap ty), md⟩
125+
| ⟨.PrimitiveOp op args, md⟩ =>
126+
⟨.PrimitiveOp op (args.attach.map fun ⟨a, _⟩ => resolveExpr ptMap a), md⟩
127+
| ⟨.StaticCall c args, md⟩ =>
128+
⟨.StaticCall c (args.attach.map fun ⟨a, _⟩ => resolveExpr ptMap a), md⟩
129+
| ⟨.Block ss sep, md⟩ =>
130+
⟨.Block (ss.attach.map fun ⟨s, _⟩ => resolveExpr ptMap s) sep, md⟩
131+
| ⟨.IfThenElse c t (some el), md⟩ =>
132+
⟨.IfThenElse (resolveExpr ptMap c) (resolveExpr ptMap t) (some (resolveExpr ptMap el)), md⟩
133+
| ⟨.IfThenElse c t none, md⟩ =>
134+
⟨.IfThenElse (resolveExpr ptMap c) (resolveExpr ptMap t) none, md⟩
135+
| ⟨.While c inv dec body, md⟩ =>
136+
⟨.While (resolveExpr ptMap c) (inv.attach.map fun ⟨i, _⟩ => resolveExpr ptMap i)
137+
dec (resolveExpr ptMap body), md⟩
138+
| ⟨.Assign ts v, md⟩ =>
139+
⟨.Assign (ts.attach.map fun ⟨t, _⟩ => resolveExpr ptMap t) (resolveExpr ptMap v), md⟩
140+
| ⟨.Return (some v), md⟩ => ⟨.Return (some (resolveExpr ptMap v)), md⟩
141+
| ⟨.Return none, md⟩ => ⟨.Return none, md⟩
142+
| ⟨.Assert c, md⟩ => ⟨.Assert (resolveExpr ptMap c), md⟩
143+
| ⟨.Assume c, md⟩ => ⟨.Assume (resolveExpr ptMap c), md⟩
144+
| e => e
145+
termination_by e => sizeOf e
146+
decreasing_by all_goals (have := WithMetadata.sizeOf_val_lt ‹_›; term_by_mem)
147+
148+
/-- Insert asserts for constrained-typed variable init and reassignment -/
149+
abbrev ElimM := StateM PredVarMap
150+
151+
def elimStmt (ptMap : ConstrainedTypeMap)
152+
(stmt : StmtExprMd) : ElimM (List StmtExprMd) := do
153+
let md := stmt.md
154+
match _h : stmt.val with
155+
| .LocalVariable name ty init =>
156+
let isPred := isConstrainedType ptMap ty.val
157+
if isPred then modify fun pv => pv.insert name.text ty.val
158+
let asserts := if isPred then mkAsserts ptMap ty.val name md else []
159+
-- Use witness as default initializer for uninitialized constrained variables
160+
let init' := match init with
161+
| none => match ty.val with
162+
| .UserDefined n => (ptMap.get? n.text).map (·.witness)
163+
| _ => none
164+
| some _ => init
165+
pure ([⟨.LocalVariable name ty init', md⟩] ++ asserts)
166+
167+
-- Single-target only; multi-target assignments are not supported by the Laurel grammar
168+
| .Assign [target] _ => match target.val with
169+
| .Identifier name => do
170+
match (← get).get? name.text with
171+
| some ty => pure ([stmt] ++ mkAsserts ptMap ty name md)
172+
| none => pure [stmt]
173+
| _ => pure [stmt]
174+
175+
| .Block stmts sep =>
176+
let stmtss ← stmts.mapM (elimStmt ptMap)
177+
pure [⟨.Block stmtss.flatten sep, md⟩]
178+
179+
| .IfThenElse cond thenBr (some elseBr) =>
180+
let thenSs ← elimStmt ptMap thenBr
181+
let elseSs ← elimStmt ptMap elseBr
182+
pure [⟨.IfThenElse cond (wrap thenSs md) (some (wrap elseSs md)), md⟩]
183+
| .IfThenElse cond thenBr none =>
184+
let thenSs ← elimStmt ptMap thenBr
185+
pure [⟨.IfThenElse cond (wrap thenSs md) none, md⟩]
186+
187+
| .While cond inv dec body =>
188+
let bodySs ← elimStmt ptMap body
189+
pure [⟨.While cond inv dec (wrap bodySs md), md⟩]
190+
191+
| _ => pure [stmt]
192+
termination_by sizeOf stmt
193+
decreasing_by
194+
all_goals simp_wf
195+
all_goals (try have := WithMetadata.sizeOf_val_lt stmt)
196+
all_goals (try term_by_mem)
197+
all_goals omega
198+
199+
def elimProc (ptMap : ConstrainedTypeMap) (proc : Procedure) : Procedure :=
200+
-- Add requires for constrained-typed inputs
201+
let inputRequires := proc.inputs.flatMap fun p =>
202+
(getAllConstraints ptMap p.type.val).map fun (vn, pred) =>
203+
⟨(substId vn p.name pred).val, p.type.md⟩
204+
-- Add ensures for constrained-typed outputs
205+
let outputEnsures := proc.outputs.flatMap fun p =>
206+
(getAllConstraints ptMap p.type.val).map fun (vn, pred) =>
207+
⟨(substId vn p.name pred).val, p.type.md⟩
208+
-- Transform body: insert asserts for local variable init/reassignment
209+
let initVars : PredVarMap := proc.inputs.foldl (init := {}) fun s p =>
210+
if isConstrainedType ptMap p.type.val then s.insert p.name.text p.type.val else s
211+
let body' := match proc.body with
212+
| .Transparent bodyExpr =>
213+
let (stmts, _) := (elimStmt ptMap bodyExpr).run initVars
214+
let body := wrap stmts bodyExpr.md
215+
if outputEnsures.isEmpty then .Transparent body
216+
else
217+
-- Wrap expression body in a Return so it translates correctly as a procedure
218+
let retBody := if proc.isFunctional then ⟨.Return (some body), bodyExpr.md⟩ else body
219+
.Opaque outputEnsures (some retBody) []
220+
| .Opaque postconds impl modif =>
221+
let impl' := impl.map fun b => wrap ((elimStmt ptMap b).run initVars).1 b.md
222+
.Opaque (postconds ++ outputEnsures) impl' modif
223+
| .Abstract postconds => .Abstract (postconds ++ outputEnsures)
224+
| .External => .External
225+
-- Resolve all constrained types to base types
226+
let resolve := resolveExpr ptMap
227+
let resolveBody : Body → Body := fun body => match body with
228+
| .Transparent b => .Transparent (resolve b)
229+
| .Opaque ps impl modif => .Opaque (ps.map resolve) (impl.map resolve) (modif.map resolve)
230+
| .Abstract ps => .Abstract (ps.map resolve)
231+
| .External => .External
232+
{ proc with
233+
body := resolveBody body'
234+
-- TODO: restore isFunctional once function postconditions are supported
235+
isFunctional := if outputEnsures.isEmpty then proc.isFunctional else false
236+
inputs := proc.inputs.map fun p => { p with type := resolveType ptMap p.type }
237+
outputs := proc.outputs.map fun p => { p with type := resolveType ptMap p.type }
238+
preconditions := (proc.preconditions ++ inputRequires).map resolve }
239+
240+
/-- Create a synthetic procedure that asserts the witness satisfies all constraints -/
241+
private def mkWitnessProc (ptMap : ConstrainedTypeMap) (ct : ConstrainedType) : Procedure :=
242+
let md := ct.constraint.md
243+
let witnessId : Identifier := mkId "$witness"
244+
let witnessInit : StmtExprMd :=
245+
⟨.LocalVariable witnessId (resolveType ptMap ct.base) (some ct.witness), md⟩
246+
let asserts := (getAllConstraints ptMap (.UserDefined ct.name)).map fun (vn, pred) =>
247+
⟨.Assert (substId vn witnessId pred), md⟩
248+
{ name := mkId s!"$witness_{ct.name.text}"
249+
inputs := []
250+
outputs := []
251+
body := .Transparent ⟨.Block ([witnessInit] ++ asserts) none, md⟩
252+
preconditions := []
253+
isFunctional := false
254+
determinism := .deterministic none
255+
decreases := none
256+
md := md }
257+
258+
/-- Eliminate constrained types from a Laurel program.
259+
The `witness` field is used as the default initializer for uninitialized
260+
constrained-typed variables, and is validated via synthetic procedures. -/
261+
def constrainedTypeElim (_model : SemanticModel) (program : Program) : Program :=
262+
let ptMap := buildConstrainedTypeMap program.types
263+
if ptMap.isEmpty then program else
264+
let witnessProcedures := program.types.filterMap fun
265+
| .Constrained ct => some (mkWitnessProc ptMap ct)
266+
| _ => none
267+
{ program with
268+
staticProcedures := program.staticProcedures.map (elimProc ptMap) ++ witnessProcedures
269+
types := program.types.filter fun | .Constrained _ => false | _ => true }
270+
271+
end Strata.Laurel

Strata/Languages/Laurel/Grammar/ConcreteToAbstractTreeTranslator.lean

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,20 @@ def parseDatatype (arg : Arg) : TransM TypeDefinition := do
495495
| _, _ =>
496496
TransM.error s!"parseDatatype expects datatype, got {repr op.name}"
497497

498+
def parseConstrainedType (arg : Arg) : TransM ConstrainedType := do
499+
let .op op := arg
500+
| TransM.error s!"parseConstrainedType expects operation"
501+
match op.name, op.args with
502+
| q`Laurel.constrainedType, #[nameArg, valueNameArg, baseArg, constraintArg, witnessArg] =>
503+
let name ← translateIdent nameArg
504+
let valueName ← translateIdent valueNameArg
505+
let base ← translateHighType baseArg
506+
let constraint ← translateStmtExpr constraintArg
507+
let witness ← translateStmtExpr witnessArg
508+
return { name, base, valueName, constraint, witness }
509+
| _, _ =>
510+
TransM.error s!"parseConstrainedType expects constrainedType, got {repr op.name}"
511+
498512
def parseTopLevel (arg : Arg) : TransM (Option Procedure × Option TypeDefinition) := do
499513
let .op op := arg
500514
| TransM.error s!"parseTopLevel expects operation"
@@ -509,8 +523,11 @@ def parseTopLevel (arg : Arg) : TransM (Option Procedure × Option TypeDefinitio
509523
| q`Laurel.topLevelDatatype, #[datatypeArg] =>
510524
let typeDef ← parseDatatype datatypeArg
511525
return (none, some typeDef)
526+
| q`Laurel.topLevelConstrainedType, #[ctArg] =>
527+
let ct ← parseConstrainedType ctArg
528+
return (none, some (.Constrained ct))
512529
| _, _ =>
513-
TransM.error s!"parseTopLevel expects topLevelProcedure, topLevelComposite, or topLevelDatatype, got {repr op.name}"
530+
TransM.error s!"parseTopLevel expects topLevelProcedure, topLevelComposite, topLevelDatatype, or topLevelConstrainedType, got {repr op.name}"
514531

515532
/--
516533
Translate concrete Laurel syntax into abstract Laurel syntax

Strata/Languages/Laurel/Grammar/LaurelGrammar.st

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,10 @@ op topLevelComposite(composite: Composite): TopLevel => composite;
166166
op topLevelProcedure(procedure: Procedure): TopLevel => procedure;
167167
op topLevelDatatype(datatype: Datatype): TopLevel => datatype;
168168

169+
category ConstrainedType;
170+
op constrainedType (name: Ident, valueName: Ident, base: LaurelType,
171+
constraint: StmtExpr, witness: StmtExpr): ConstrainedType
172+
=> "constrained " name " = " valueName ": " base " where " constraint:0 " witness " witness:0;
173+
op topLevelConstrainedType(ct: ConstrainedType): TopLevel => ct;
174+
169175
op program (items: Seq TopLevel): Command => items;

Strata/Languages/Laurel/LaurelToCoreTranslator.lean

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import Strata.DL.Imperative.Stmt
2020
import Strata.DL.Imperative.MetaData
2121
import Strata.DL.Lambda.LExpr
2222
import Strata.Languages.Laurel.LaurelFormat
23+
import Strata.Languages.Laurel.ConstrainedTypeElim
2324
import Strata.Util.Tactics
2425

2526
open Core (VCResult VCResults VerifyOptions)
@@ -583,6 +584,11 @@ def translate (program : Program) : Except (Array DiagnosticModel) (Core.Program
583584
let (program, model) := (result.program, result.model)
584585
_resolutionDiags := _resolutionDiags ++ result.errors
585586

587+
let program := constrainedTypeElim model program
588+
let result := resolve program (some model)
589+
let (program, model) := (result.program, result.model)
590+
_resolutionDiags := _resolutionDiags ++ result.errors
591+
586592
-- Procedures marked isFunctional are translated to Core functions; all others become Core procedures.
587593
-- External procedures are completely ignored (not translated to Core).
588594
let nonExternal := program.staticProcedures.filter (fun p => !p.body.isExternal)

0 commit comments

Comments
 (0)