Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Strata/DL/Lambda/IntBoolFactory.lean
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,14 @@ def intSafeModFunc : WFLFunc T :=
binaryOp (InValTy := Int) "Int.SafeMod" (· % ·) (· != 0)
(preconditions := [yNeZeroPrecond])

def intSafeDivTFunc : WFLFunc T :=
binaryOp (InValTy := Int) "Int.SafeDivT" Int.tdiv (· != 0)
(preconditions := [yNeZeroPrecond])

def intSafeModTFunc : WFLFunc T :=
binaryOp (InValTy := Int) "Int.SafeModT" Int.tmod (· != 0)
(preconditions := [yNeZeroPrecond])

end

def IntBoolFactory [Inhabited T.mono.base.Metadata] : @Factory T := (#[
Expand All @@ -304,7 +312,9 @@ def IntBoolFactory [Inhabited T.mono.base.Metadata] : @Factory T := (#[
intModFunc,
intSafeModFunc,
intDivTFunc,
intSafeDivTFunc,
intModTFunc,
intSafeModTFunc,
intNegFunc,
intLtFunc,
intLeFunc,
Expand Down
4 changes: 4 additions & 0 deletions Strata/Languages/Core/DDMTransform/ASTtoCST.lean
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,10 @@ def handleBinaryOps {M} [Inhabited M] (name : String)
| "Int.SafeDiv" => pure (.safediv_expr default ty arg1 arg2)
| "Int.Mod" => pure (.mod_expr default ty arg1 arg2)
| "Int.SafeMod" => pure (.safemod_expr default ty arg1 arg2)
| "Int.DivT" => pure (.divt_expr default ty arg1 arg2)
| "Int.SafeDivT" => pure (.safedivt_expr default ty arg1 arg2)
| "Int.ModT" => pure (.modt_expr default ty arg1 arg2)
| "Int.SafeModT" => pure (.safemodt_expr default ty arg1 arg2)
| "Int.Le" | "Real.Le" => pure (.le default ty arg1 arg2)
| "Int.Lt" | "Real.Lt" => pure (.lt default ty arg1 arg2)
| "Int.Ge" | "Real.Ge" => pure (.ge default ty arg1 arg2)
Expand Down
4 changes: 4 additions & 0 deletions Strata/Languages/Core/DDMTransform/Grammar.lean
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ fn div_expr (tp : Type, a : tp, b : tp) : tp => @[prec(30), leftassoc] a " div "
fn mod_expr (tp : Type, a : tp, b : tp) : tp => @[prec(30), leftassoc] a " mod " b;
fn safediv_expr (tp : Type, a : tp, b : tp) : tp => @[prec(30), leftassoc] a " / " b;
fn safemod_expr (tp : Type, a : tp, b : tp) : tp => @[prec(30), leftassoc] a " % " b;
fn divt_expr (tp : Type, a : tp, b : tp) : tp => @[prec(30), leftassoc] a " divt " b;
fn modt_expr (tp : Type, a : tp, b : tp) : tp => @[prec(30), leftassoc] a " modt " b;
fn safedivt_expr (tp : Type, a : tp, b : tp) : tp => @[prec(30), leftassoc] a " /t " b;
fn safemodt_expr (tp : Type, a : tp, b : tp) : tp => @[prec(30), leftassoc] a " %t " b;

fn bvnot (tp : Type, a : tp) : tp => "~" a;
fn bvand (tp : Type, a : tp, b : tp) : tp => @[prec(20), leftassoc] a " & " b;
Expand Down
8 changes: 8 additions & 0 deletions Strata/Languages/Core/DDMTransform/Translate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,10 @@ def translateFn (ty? : Option LMonoTy) (q : QualifiedIdent) : TransM Core.Expres
| .some .int, q`Core.mod_expr => return Core.intModOp
| .some .int, q`Core.safediv_expr => return Core.intSafeDivOp
| .some .int, q`Core.safemod_expr => return Core.intSafeModOp
| .some .int, q`Core.divt_expr => return Core.intDivTOp
| .some .int, q`Core.modt_expr => return Core.intModTOp
| .some .int, q`Core.safedivt_expr => return Core.intSafeDivTOp
| .some .int, q`Core.safemodt_expr => return Core.intSafeModTOp
| .some .int, q`Core.neg_expr => return Core.intNegOp

| .some .real, q`Core.le => return Core.realLeOp
Expand Down Expand Up @@ -888,6 +892,10 @@ partial def translateExpr (p : Program) (bindings : TransBindings) (arg : Arg) :
| q`Core.safediv_expr
| q`Core.mod_expr
| q`Core.safemod_expr
| q`Core.divt_expr
| q`Core.modt_expr
| q`Core.safedivt_expr
| q`Core.safemodt_expr
| q`Core.bvand
| q`Core.bvor
| q`Core.bvxor
Expand Down
4 changes: 4 additions & 0 deletions Strata/Languages/Core/Factory.lean
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def WFFactory : Lambda.WFLFactory CoreLParams :=
intModFunc (T := CoreLParams),
intSafeModFunc (T := CoreLParams),
intDivTFunc (T := CoreLParams),
intSafeDivTFunc (T := CoreLParams),
intModTFunc (T := CoreLParams),
intSafeModTFunc (T := CoreLParams),
intNegFunc (T := CoreLParams),

intLtFunc (T := CoreLParams),
Expand Down Expand Up @@ -426,7 +428,9 @@ def intSafeDivOp : Expression.Expr := (@intSafeDivFunc CoreLParams _ _).opExpr
def intModOp : Expression.Expr := (@intModFunc CoreLParams _).opExpr
def intSafeModOp : Expression.Expr := (@intSafeModFunc CoreLParams _ _).opExpr
def intDivTOp : Expression.Expr := (@intDivTFunc CoreLParams _).opExpr
def intSafeDivTOp : Expression.Expr := (@intSafeDivTFunc CoreLParams _ _).opExpr
def intModTOp : Expression.Expr := (@intModTFunc CoreLParams _).opExpr
def intSafeModTOp : Expression.Expr := (@intSafeModTFunc CoreLParams _ _).opExpr
def intNegOp : Expression.Expr := (@intNegFunc CoreLParams _).opExpr
def intLtOp : Expression.Expr := (@intLtFunc CoreLParams _).opExpr
def intLeOp : Expression.Expr := (@intLeFunc CoreLParams _).opExpr
Expand Down
4 changes: 2 additions & 2 deletions Strata/Languages/Core/SMTEncoder.lean
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ partial def toSMTOp (E : Env) (fn : CoreIdent) (fnty : LMonoTy) (ctx : SMT.Conte
| "Int.Mod" => .ok (.app Op.mod, .int , ctx)
| "Int.SafeMod" => .ok (.app Op.mod, .int , ctx)
-- Truncating division: tdiv(a,b) = let q = ediv(abs(a), abs(b)) in ite(a*b >= 0, q, -q)
| "Int.DivT" =>
| "Int.DivT" | "Int.SafeDivT" =>
let divTApp := fun (args : List Term) (retTy : TermType) =>
match args with
| [a, b] =>
Expand All @@ -398,7 +398,7 @@ partial def toSMTOp (E : Env) (fn : CoreIdent) (fnty : LMonoTy) (ctx : SMT.Conte
.ok (divTApp, .int, ctx)
-- Truncating modulo: tmod(a,b) = a - b * tdiv(a,b)
-- tdiv(a,b) = let q = ediv(abs(a), abs(b)) in ite(a*b >= 0, q, -q)
| "Int.ModT" =>
| "Int.ModT" | "Int.SafeModT" =>
let modTApp := fun (args : List Term) (retTy : TermType) =>
match args with
| [a, b] =>
Expand Down
10 changes: 5 additions & 5 deletions Strata/Languages/Laurel/LaurelToCoreTranslator.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import Strata.Languages.Laurel.LaurelFormat
import Strata.Util.Tactics

open Core (VCResult VCResults VerifyOptions)
open Core (intAddOp intSubOp intMulOp intDivOp intModOp intDivTOp intModTOp intNegOp intLtOp intLeOp intGtOp intGeOp boolAndOp boolOrOp boolNotOp boolImpliesOp strConcatOp)
open Core (intAddOp intSubOp intMulOp intSafeDivOp intSafeModOp intSafeDivTOp intSafeModTOp intNegOp intLtOp intLeOp intGtOp intGeOp boolAndOp boolOrOp boolNotOp boolImpliesOp strConcatOp)

namespace Strata.Laurel

Expand Down Expand Up @@ -161,10 +161,10 @@ def translateExpr (env : TypeEnv) (expr : StmtExprMd)
| .Add => return binOp intAddOp
| .Sub => return binOp intSubOp
| .Mul => return binOp intMulOp
| .Div => return binOp intDivOp
| .Mod => return binOp intModOp
| .DivT => return binOp intDivTOp
| .ModT => return binOp intModTOp
| .Div => return binOp intSafeDivOp
| .Mod => return binOp intSafeModOp
| .DivT => return binOp intSafeDivTOp
| .ModT => return binOp intSafeModTOp
| .Lt => return binOp intLtOp
| .Leq => return binOp intLeOp
| .Gt => return binOp intGtOp
Expand Down
33 changes: 18 additions & 15 deletions Strata/Languages/Python/PythonToLaurel.lean
Original file line number Diff line number Diff line change
Expand Up @@ -476,13 +476,13 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang
let target := name.val
let valueExpr ← translateExpr ctx value
let targetExpr := mkStmtExprMd (StmtExpr.Identifier target)
let assignStmt := mkStmtExprMd (StmtExpr.Assign [targetExpr] valueExpr)
let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr] valueExpr) md
return (ctx, assignStmt)
| .Attribute _ obj attr _ =>
-- Field assignment: obj.field = expr or self.field = expr
let valueExpr ← translateExpr ctx value
let targetExpr ← translateExpr ctx targets.val[0]! -- This will handle self.field via translateExpr
let assignStmt := mkStmtExprMd (StmtExpr.Assign [targetExpr] valueExpr)
let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [targetExpr] valueExpr) md
return (ctx, assignStmt)
| _ => throw (.unsupportedConstruct "Only simple variable or field assignment supported" (toString (repr s)))

Expand All @@ -501,7 +501,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang
let fieldAccess := mkStmtExprMd (StmtExpr.FieldSelect
(mkStmtExprMd (StmtExpr.Identifier "self"))
attr.val)
let assignStmt := mkStmtExprMd (StmtExpr.Assign [fieldAccess] valueExpr)
let assignStmt := mkStmtExprMdWithLoc (StmtExpr.Assign [fieldAccess] valueExpr) md
return (ctx, assignStmt)
else
throw (.unsupportedConstruct "Only self.field assignments supported in methods" (toString (repr s)))
Expand Down Expand Up @@ -535,7 +535,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang
let translatedArgs ← args.val.toList.mapM (translateExpr ctx)

let newExpr := mkStmtExprMd (StmtExpr.New funcName)
let declStmt := mkStmtExprMd (StmtExpr.LocalVariable varName varType (some newExpr))
let declStmt := mkStmtExprMdWithLoc (StmtExpr.LocalVariable varName varType (some newExpr)) md

let initCall := mkStmtExprMd (StmtExpr.InstanceCall
(mkStmtExprMd (StmtExpr.Identifier varName))
Expand All @@ -547,16 +547,18 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang
else
-- Regular call, not a constructor
let initVal ← translateCall ctx f args.val.toList
let declStmt := mkStmtExprMd (StmtExpr.LocalVariable varName varType (some initVal))
let initVal := { initVal with md := md }
let declStmt := mkStmtExprMdWithLoc (StmtExpr.LocalVariable varName varType (some initVal)) md
return (newCtx, declStmt)
| some initExpr => do
-- Regular annotated assignment with initializer
let initVal ← translateExpr newCtx initExpr
let declStmt := mkStmtExprMd (StmtExpr.LocalVariable varName varType (some initVal))
let initVal := { initVal with md := md }
let declStmt := mkStmtExprMdWithLoc (StmtExpr.LocalVariable varName varType (some initVal)) md
return (newCtx, declStmt)
| none =>
-- Declaration without initializer
let declStmt := mkStmtExprMd (StmtExpr.LocalVariable varName varType none)
let declStmt := mkStmtExprMdWithLoc (StmtExpr.LocalVariable varName varType none) md
return (newCtx, declStmt)

-- If statement
Expand All @@ -583,13 +585,13 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang
else do
let (_, elseStmts) ← translateStmtList bodyCtx orelse.val.toList
.ok (some (mkStmtExprMd (StmtExpr.Block elseStmts none)))
let ifStmt := mkStmtExprMd (StmtExpr.IfThenElse finalCondExpr bodyBlock elseBlock)
let ifStmt := mkStmtExprMdWithLoc (StmtExpr.IfThenElse finalCondExpr bodyBlock elseBlock) md

-- Wrap in block if we hoisted condition
let result := if condStmts.isEmpty then
ifStmt
else
mkStmtExprMd (StmtExpr.Block (condStmts ++ [ifStmt]) none)
mkStmtExprMdWithLoc (StmtExpr.Block (condStmts ++ [ifStmt]) none) md

return (bodyCtx, result)

Expand All @@ -610,13 +612,13 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang

let (loopCtx, bodyStmts) ← translateStmtList condCtx body.val.toList
let bodyBlock := mkStmtExprMd (StmtExpr.Block bodyStmts none)
let whileStmt := mkStmtExprMd (StmtExpr.While finalCondExpr [] none bodyBlock)
let whileStmt := mkStmtExprMdWithLoc (StmtExpr.While finalCondExpr [] none bodyBlock) md

-- Wrap in block if we hoisted condition
let result := if condStmts.isEmpty then
whileStmt
else
mkStmtExprMd (StmtExpr.Block (condStmts ++ [whileStmt]) none)
mkStmtExprMdWithLoc (StmtExpr.Block (condStmts ++ [whileStmt]) none) md

return (loopCtx, result)

Expand All @@ -627,7 +629,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang
let e ← translateExpr ctx expr
.ok (some e)
| none => .ok none
let retStmt := mkStmtExprMd (StmtExpr.Return retVal)
let retStmt := mkStmtExprMdWithLoc (StmtExpr.Return retVal) md
return (ctx, retStmt)

-- Assert statement
Expand All @@ -650,13 +652,14 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang
let result := if condStmts.isEmpty then
assertStmt
else
mkStmtExprMd (StmtExpr.Block (condStmts ++ [assertStmt]) none)
mkStmtExprMdWithLoc (StmtExpr.Block (condStmts ++ [assertStmt]) none) md

return (condCtx, result)

-- Expression statement (e.g., function call)
| .Expr _ value => do
let expr ← translateExpr ctx value
let expr := { expr with md := md }
return (ctx, expr)

| .Import _ _ | .ImportFrom _ _ _ _ => return (ctx, mkStmtExprMd .Hole)
Expand Down Expand Up @@ -690,7 +693,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang
let handlerBlock := mkStmtExprMd (StmtExpr.Block handlerStmts (some handlerLabel))

-- Wrap in try block
let tryBlock := mkStmtExprMd (StmtExpr.Block (bodyStmtsWithChecks ++ [handlerBlock]) (some tryLabel))
let tryBlock := mkStmtExprMdWithLoc (StmtExpr.Block (bodyStmtsWithChecks ++ [handlerBlock]) (some tryLabel)) md
return (bodyCtx, tryBlock)

| .Raise _ _ _ => return (ctx, mkStmtExprMd .Hole)
Expand All @@ -717,7 +720,7 @@ partial def translateStmt (ctx : TranslationContext) (s : Python.stmt SourceRang
-- Create: { target = havoc; body_statements }
-- This abstracts: execute body once with arbitrary target value
let targetDecl := mkStmtExprMd (StmtExpr.LocalVariable targetName targetType (some (mkStmtExprMd .Hole)))
let loopBlock := mkStmtExprMd (StmtExpr.Block ([targetDecl] ++ bodyStmts) none)
let loopBlock := mkStmtExprMdWithLoc (StmtExpr.Block ([targetDecl] ++ bodyStmts) none) md

return (finalCtx, loopBlock)

Expand Down
4 changes: 2 additions & 2 deletions StrataTest/DL/Lambda/TestGenTests.lean
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ Lambda.LTy.forAll [] (Lambda.LMonoTy.tcons "arrow" [Lambda.LMonoTy.tcons "bool"
in context
{ types := [[]], aliases := [] }
in factory
#[Int.Add, Int.Sub, Int.Mul, Int.Div, Int.SafeDiv, Int.Mod, Int.SafeMod, Int.DivT, Int.ModT, Int.Neg, Int.Lt, Int.Le, Int.Gt, Int.Ge, Bool.And, Bool.Or, Bool.Implies, Bool.Equiv, Bool.Not]
#[Int.Add, Int.Sub, Int.Mul, Int.Div, Int.SafeDiv, Int.Mod, Int.SafeMod, Int.DivT, Int.SafeDivT, Int.ModT, Int.SafeModT, Int.Neg, Int.Lt, Int.Le, Int.Gt, Int.Ge, Bool.And, Bool.Or, Bool.Implies, Bool.Equiv, Bool.Not]
-/
#guard_msgs in
#eval Strata.Util.withStdGenSeed 0 do
Expand All @@ -167,7 +167,7 @@ Lambda.LTy.forAll [] (Lambda.LMonoTy.tcons "arrow" [Lambda.LMonoTy.tcons "bool"
in context
{ types := [[]], aliases := [] }
in factory
#[Int.Add, Int.Sub, Int.Mul, Int.Div, Int.SafeDiv, Int.Mod, Int.SafeMod, Int.DivT, Int.ModT, Int.Neg, Int.Lt, Int.Le, Int.Gt, Int.Ge, Bool.And, Bool.Or, Bool.Implies, Bool.Equiv, Bool.Not]
#[Int.Add, Int.Sub, Int.Mul, Int.Div, Int.SafeDiv, Int.Mod, Int.SafeMod, Int.DivT, Int.SafeDivT, Int.ModT, Int.SafeModT, Int.Neg, Int.Lt, Int.Le, Int.Gt, Int.Ge, Bool.And, Bool.Or, Bool.Implies, Bool.Equiv, Bool.Not]
-/
#guard_msgs(info, drop error) in
#eval Strata.Util.withStdGenSeed 0 do
Expand Down
4 changes: 4 additions & 0 deletions StrataTest/Languages/Core/ProcedureEvalTests.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ func Int.Mod : ((x : int) (y : int)) → int;
func Int.SafeMod : ((x : int) (y : int)) → int
requires ((~Bool.Not : (arrow bool bool)) ((y : int) == #0));
func Int.DivT : ((x : int) (y : int)) → int;
func Int.SafeDivT : ((x : int) (y : int)) → int
requires ((~Bool.Not : (arrow bool bool)) ((y : int) == #0));
func Int.ModT : ((x : int) (y : int)) → int;
func Int.SafeModT : ((x : int) (y : int)) → int
requires ((~Bool.Not : (arrow bool bool)) ((y : int) == #0));
func Int.Neg : ((x : int)) → int;
func Int.Lt : ((x : int) (y : int)) → bool;
func Int.Le : ((x : int) (y : int)) → bool;
Expand Down
4 changes: 4 additions & 0 deletions StrataTest/Languages/Core/ProgramTypeTests.lean
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,11 @@ info: ok: [(type Foo (a0 : Type, a1 : Type);
func Int.SafeMod : ((x : int) (y : int)) → int
requires ((~Bool.Not : (arrow bool bool)) ((y : int) == #0));
func Int.DivT : ((x : int) (y : int)) → int;
func Int.SafeDivT : ((x : int) (y : int)) → int
requires ((~Bool.Not : (arrow bool bool)) ((y : int) == #0));
func Int.ModT : ((x : int) (y : int)) → int;
func Int.SafeModT : ((x : int) (y : int)) → int
requires ((~Bool.Not : (arrow bool bool)) ((y : int) == #0));
func Int.Neg : ((x : int)) → int;
func Int.Lt : ((x : int) (y : int)) → bool;
func Int.Le : ((x : int) (y : int)) → bool;
Expand Down
54 changes: 54 additions & 0 deletions StrataTest/Languages/Laurel/DivisionByZeroCheckTest.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/-
Copyright Strata Contributors

SPDX-License-Identifier: Apache-2.0 OR MIT
-/

import StrataTest.Util.TestDiagnostics
import StrataTest.Languages.Laurel.TestExamples

open StrataTest.Util

namespace Strata.Laurel

/-! ## End-to-end test: safe division (no errors) and unsafe division (error)

Division and modulo in Laurel translate to Core's safe operators, which have
built-in preconditions (divisor ≠ 0). The PrecondElim transform automatically
generates verification conditions for these preconditions.
-/

def e2eProgram := r"
procedure safeDivision() {
var x: int := 10;
var y: int := 2;
var z: int := x / y;
assert z == 5;
}

procedure unsafeDivision(x: int) {
var z: int := 10 / x;
// ^^^^^^ error: assertion does not hold
}

function pureDiv(x: int, y: int): int
requires y != 0
{
x / y
}

procedure callPureDivSafe() {
var z: int := pureDiv(10, 2);
assert z == 5;
}

procedure callPureDivUnsafe(x: int) {
var z: int := pureDiv(10, x);
// ^^^^^^^^^^^^^^ error: assertion does not hold
}
"

#guard_msgs(drop info, error) in
#eval testInputWithOffset "DivByZeroE2E" e2eProgram 22 processLaurelFile

end Laurel
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ assert(102): ✅ pass (at line 7, col 4)
assert(226): ✅ pass (at line 12, col 4)
assert(345): ✅ pass (at line 16, col 4)
assert(458): ✅ pass (at line 20, col 4)
init_calls_Int.SafeDiv_0: ✅ pass (at line 23, col 4)
assert(567): ✅ pass (at line 24, col 4)
Loading