Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 95 additions & 36 deletions Strata/Languages/Python/Specs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,11 @@ public abbrev EventType := String
def importEvent : EventType := "import"

/--
Set of event types to log to stderr. Test code can modify this to
enable/disable logging.
-/
initialize stdoutEventsRef : IO.Ref (Std.HashSet EventType) ← IO.mkRef {}

/--
Log message for event type if enabled in `stdoutEventsRef`.
Log message for event type if enabled in the given event set.
Output format: `[event]: message`
-/
def logEvent (event : EventType) (message : String) : BaseIO Unit := do
let events ← stdoutEventsRef.get
def baseLogEvent (events : Std.HashSet EventType)
(event : EventType) (message : String) : BaseIO Unit := do
if event ∈ events then
let _ ← IO.eprintln s!"[{event}]: {message}" |>.toBaseIO
pure ()
Expand Down Expand Up @@ -148,6 +142,7 @@ inductive SpecValue
| typingRequired
| typingNotRequired
| typingUnpack
| reCompile
| requiredType (type : SpecType)
| notRequiredType (type : SpecType)
| typeValue (type : SpecType)
Expand Down Expand Up @@ -206,6 +201,7 @@ def preludeSig :=
.mk .typingRequired .typingRequired,
.mk .typingNotRequired .typingNotRequired,
.mk .typingUnpack .typingUnpack,
.mk .reCompile .reCompile,
]

inductive ClassRef where
Expand All @@ -216,6 +212,10 @@ inductive ClassRef where
abbrev ModuleReader := ModuleName → EIO String System.FilePath

structure PySpecContext where
/-- Events to log -/
eventSet : Std.HashSet String
/-- Top-level definitions to skip. -/
skipNames : Std.HashSet PythonIdent := {}
/-- Command to run for Python -/
pythonCmd : String
/-- Path to Python dialect. -/
Expand Down Expand Up @@ -256,6 +256,23 @@ structure PySpecState where

abbrev PySpecM := ReaderT PySpecContext (StateT PySpecState BaseIO)

def logEvent (event : EventType) (message : String) : PySpecM Unit := do
baseLogEvent (←read).eventSet event message

/-- Check whether a decorator list contains `@overload`. -/
private def hasOverloadDecorator
(decorators : Array (Strata.Python.expr Strata.SourceRange)) : Bool :=
decorators.any fun d =>
match d with
| .Name _ ⟨_, "overload"⟩ _ => true
| _ => false

/-- Should we skip the given top-level name? -/
def shouldSkip (name : String) : PySpecM Bool := do
let ctx ← read
let mod := ctx.pythonFile.fileStem.getD ""
return ctx.skipNames.contains { pythonModule := mod, name }

def specErrorAt (file : System.FilePath) (loc : SourceRange) (message : String) : PySpecM Unit := do
let e : SpecError := { file, loc, message }
modify fun s => { s with errors := s.errors.push e }
Expand Down Expand Up @@ -448,7 +465,8 @@ def translateConstant (value : constant SourceRange) : PySpecM SpecValue := do
specError value.ann s!"Could not interpret constant {value}"
return default

def translateSubscript (paramLoc : SourceRange) (paramType : String) (sargs : SpecValue) : PySpecM SpecValue := do
def translateSubscript (paramLoc : SourceRange) (paramType : String)
(sargs : SpecValue) : PySpecM SpecValue := do
match ← getNameValue? paramType with
| none =>
specError paramLoc s!"Unknown parameterized type {paramType}."
Expand Down Expand Up @@ -757,12 +775,12 @@ def transAssertExpr (e : expr SourceRange)
| .GtE _ =>
match comparators.val[0] with
| .Constant _ (.ConPos _ n) _ =>
return .lenGe subj n.val
return .intGe (.len subj) (.intLit n.val)
| _ => pure ()
| .LtE _ =>
match comparators.val[0] with
| .Constant _ (.ConPos _ n) _ =>
return .lenLe subj n.val
return .intLe (.len subj) (.intLit n.val)
| _ => pure ()
| _ => pure ()
| none => pure ()
Expand All @@ -778,16 +796,16 @@ def transAssertExpr (e : expr SourceRange)
| .GtE _ =>
match comparators.val[0] with
| .Constant _ (.ConPos _ n) _ =>
return .valueGe subj (Int.ofNat n.val)
return .intGe subj (.intLit (Int.ofNat n.val))
| .Constant _ (.ConNeg _ n) _ =>
return .valueGe subj (Int.negOfNat n.val)
return .intGe subj (.intLit (Int.negOfNat n.val))
| _ => pure ()
| .LtE _ =>
match comparators.val[0] with
| .Constant _ (.ConPos _ n) _ =>
return .valueLe subj (Int.ofNat n.val)
return .intLe subj (.intLit (Int.ofNat n.val))
| .Constant _ (.ConNeg _ n) _ =>
return .valueLe subj (Int.negOfNat n.val)
return .intLe subj (.intLit (Int.negOfNat n.val))
| _ => pure ()
| _ => pure ()
| none => pure ()
Expand Down Expand Up @@ -945,7 +963,14 @@ def pySpecFunctionArgs (fnLoc : SourceRange)
match returns with
| none => pure <| .ident fnLoc .typingAny
| some tp => pySpecType tp
let as ← collectAssertions argDecls returnType <| body.forM blockStmt
let as ← collectAssertions argDecls returnType <|
if overload then
-- Overload stubs should have `...` as their only body statement.
unless body.size = 1 &&
body[0]! matches .Expr _ (.Constant _ (.ConEllipsis _) _) do
specWarning fnLoc "overload body is not `...`"
else
body.forM blockStmt

return {
loc := fnLoc
Expand Down Expand Up @@ -1039,7 +1064,8 @@ def checkLevel (loc : SourceRange) (level : Option (int SourceRange)) : PySpecM
| none =>
specError loc s!"Missing import level."

def translateImportFrom (mod : String) (types : Std.HashMap String SpecValue) (names : Array (alias SourceRange)) : PySpecM Unit := do
def translateImportFrom (mod : String) (types : Std.HashMap String SpecValue)
(names : Array (alias SourceRange)) : PySpecM Unit := do
-- Check if module is a builtin (in prelude) - if so, don't generate extern declarations
let isBuiltinModule := preludeSig.rank.contains mod
for a in names do
Expand Down Expand Up @@ -1126,15 +1152,14 @@ partial def resolveModule (loc : SourceRange) (modName : String) :
| .error _ => false
-- If Strata is newer use it.
if useStrata then
logEvent importEvent s!"Importing {modName} from PySpec file"
match ← readDDM strataFile |>.toBaseIO with
| .ok sigs =>
logEvent importEvent s!"Imported {modName} from PySpec file"
return signatureValueMap modName sigs
| .error msg =>
specError loc s!"Could not load Strata file: {msg}"
return default

logEvent importEvent s!"Importing {modName} from Python"
let pythonCmd := (←read).pythonCmd
let dialectFile := (←read).dialectFile
let commands ←
Expand All @@ -1144,13 +1169,18 @@ partial def resolveModule (loc : SourceRange) (modName : String) :
specError loc msg
return default
let errors := (←get).errors
let warnings := (←get).warnings
let errorCount := errors.size
modify fun s => { s with errors := #[] }
modify fun s => { s with errors := #[], warnings := #[] }
let ctx := { (←read) with pythonFile := pythonFile }
let (sigs, t) ← translateModuleAux commands |>.run ctx |>.run { errors := errors }
modify fun s => { s with errors := t.errors }
let initState : PySpecState := { errors, warnings }
let (sigs, t) ← translateModuleAux commands |>.run ctx |>.run initState
let newWarnings := t.warnings.size - warnings.size
modify fun s => { s with errors := t.errors, warnings := t.warnings }
if t.errors.size > errorCount then
return default
let warnMsg := if newWarnings > 0 then s!" ({newWarnings} warning(s))" else ""
logEvent importEvent s!"Imported {modName}{warnMsg}"

if let .error msg ← IO.FS.createDirAll strataDir |>.toBaseIO then
specError loc s!"Could not create directory {strataDir}: {msg}"
Expand All @@ -1173,7 +1203,7 @@ partial def resolveModuleCached (loc : SourceRange) (modName : String)
modify fun s => { s with typeSigs := s.typeSigs.insert modName r }
return r

partial def translate (body : Array (Strata.Python.stmt Strata.SourceRange)) : PySpecM Unit := do
partial def translate (body : Array (stmt Strata.SourceRange)) : PySpecM Unit := do
for stmt in body do
match stmt with
| .Assign loc ⟨_, targets⟩ value _typeAnn =>
Expand Down Expand Up @@ -1208,6 +1238,9 @@ partial def translate (body : Array (Strata.Python.stmt Strata.SourceRange)) : P
⟨_returnsLoc, returns⟩
⟨_typeCommentLoc, typeComment⟩
⟨_typeParamsLoc, typeParams⟩ =>
if hasOverloadDecorator decorators = false ∧ (←shouldSkip funName) then
logEvent "skip" s!"Skipping function {funName}"
continue
assert! _bodyLoc.isNone
-- Flag indicating this is an overload
assert! _decoratorsLoc.isNone
Expand All @@ -1228,7 +1261,21 @@ partial def translate (body : Array (Strata.Python.stmt Strata.SourceRange)) : P
| specError loc s!"Local imports not supported"; continue
if let some types ← resolveModuleCached loc mod then
translateImportFrom mod types names
else
-- Module resolution failed; register imported names as opaque extern
-- types so that downstream references don't produce unknown-identifier
-- errors.
for a in names do
let name := a.name
let asname := a.asname.getD name
let source : PythonIdent := { pythonModule := mod, name := name }
let tpv : SpecValue := .typeValue (.ident loc source)
setNameValue asname tpv
pushSignature (.externTypeDecl asname source)
| .ClassDef loc ⟨_classNameLoc, className⟩ bases keywords stmts decorators typeParams =>
if ←shouldSkip className then
logEvent "skip" s!"Skipping class {className}"
continue
assert! _classNameLoc.isNone
assert! keywords.val.size = 0
assert! decorators.val.size = 0
Expand Down Expand Up @@ -1275,29 +1322,37 @@ end
/-- Maps file paths to their FileMap for error location reporting. -/
public abbrev FileMaps := Std.HashMap System.FilePath Lean.FileMap

def FileMaps.ppSourceRange (fmm : Strata.Python.Specs.FileMaps) (path : System.FilePath) (loc : SourceRange) : String :=
namespace FileMaps

def ppSourceRange (fmm : FileMaps) (path : System.FilePath) (loc : SourceRange) : String :=
match fmm[path]? with
| none =>
panic! "Invalid path {file}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of scope for this PR, but noting a panic here that'd be a good candidate to convert to an exception.

| some fm =>
loc.format path fm

end FileMaps

/-- Translates Python AST statements to PySpec signatures with dependency resolution. -/
def translateModule
(dialectFile searchPath strataDir pythonFile : System.FilePath)
(fileMap : Lean.FileMap)
(body : Array (Strata.Python.stmt Strata.SourceRange))
(pythonCmd : String := "python") :
BaseIO (FileMaps × Array Signature × Array SpecError) := do
(pythonCmd : String := "python")
(events : Std.HashSet EventType := {})
(skipNames : Std.HashSet PythonIdent := {}) :
BaseIO (FileMaps × Array Signature × Array SpecError × Array SpecError) := do
let fmm : FileMaps := {}
let fmm := fmm.insert pythonFile fileMap
let fileMapsRef : IO.Ref FileMaps ← IO.mkRef fmm
let ctx : PySpecContext := {
eventSet := events
skipNames := skipNames
pythonCmd := pythonCmd
dialectFile := dialectFile.toString
moduleReader := fun (mod : ModuleName) => do
let pythonPath ← mod.findInPath searchPath
logEvent "findFile" s!"Found {mod} as {pythonPath} in {searchPath}"
baseLogEvent events "findFile" s!"Found {mod} as {pythonPath} in {searchPath}"
match ← IO.FS.readFile pythonPath |>.toBaseIO with
| .ok contents =>
let fm := Lean.FileMap.ofString contents
Expand All @@ -1310,16 +1365,16 @@ def translateModule
}
let (res, s) ← translateModuleAux body |>.run ctx |>.run {}
let fmm ← fileMapsRef.get
for w in s.warnings do
let _ ← IO.eprintln s!"warning: {fmm.ppSourceRange w.file w.loc}: {w.message}" |>.toBaseIO
pure (fmm, res, s.errors)
pure (fmm, res, s.errors, s.warnings)

/-- Translates a Python source file to PySpec signatures. Main entry point for translation. -/
public def translateFile
(dialectFile strataDir pythonFile : System.FilePath)
(pythonCmd : String := "python")
(searchPath : Option System.FilePath := none) :
EIO String (Array Signature) := do
(searchPath : Option System.FilePath := none)
(events : Std.HashSet EventType := {})
(skipNames : Std.HashSet PythonIdent := {}) :
EIO String (Array Signature × Array String) := do
let searchPath ←
match searchPath with
| some p => pure p
Expand All @@ -1340,9 +1395,11 @@ public def translateFile
match ← pythonToStrata (pythonCmd := pythonCmd) dialectFile pythonFile |>.toBaseIO with
| .ok r => pure r
| .error msg => throw msg
let (fmm, sigs, errors) ←
Strata.Python.Specs.translateModule
let (fmm, sigs, errors, warnings) ←
translateModule
(pythonCmd := pythonCmd)
(events := events)
(skipNames := skipNames)
(dialectFile := dialectFile)
(searchPath := searchPath)
(strataDir := strataDir)
Expand All @@ -1354,6 +1411,8 @@ public def translateFile
let msg := errors.foldl (init := msg) fun msg e =>
s!"{msg}{fmm.ppSourceRange pythonFile e.loc}: {e.message}\n"
throw msg
pure sigs
let warningMsgs := warnings.map fun w =>
s!"{fmm.ppSourceRange w.file w.loc}: {w.message}"
pure (sigs, warningMsgs)

end Strata.Python.Specs
Loading
Loading