Skip to content

Commit

Permalink
Generate structure-s for cell sorts
Browse files Browse the repository at this point in the history
  • Loading branch information
tothtamas28 committed Jan 16, 2025
1 parent a53995f commit 3025e82
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 1 deletion.
33 changes: 32 additions & 1 deletion pyk/src/pyk/k2lean4/k2lean4.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
Mutual,
Signature,
SimpleFieldVal,
StructCtor,
Structure,
StructVal,
Term,
)
Expand Down Expand Up @@ -69,14 +71,17 @@ def _sort_block(self, sorts: list[str]) -> Command | None:
def _transform_sort(self, sort: str) -> Declaration:
def is_inductive(sort: str) -> bool:
decl = self.defn.sorts[sort]
return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key
return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key and not self._is_cell(sort)

def is_collection(sort: str) -> bool:
return sort in self.defn.collections

if is_inductive(sort):
return self._inductive(sort)

if self._is_cell(sort):
return self._cell(sort)

if is_collection(sort):
return self._collection(sort)

Expand Down Expand Up @@ -109,6 +114,32 @@ def _symbol_ident(symbol: str) -> str:
symbol = f'«{symbol}»'
return symbol

@staticmethod
def _is_cell(sort: str) -> bool:
return sort.endswith('Cell')

def _cell(self, sort: str) -> Structure:
(cell_ctor,) = self.defn.constructors[sort]
decl = self.defn.symbols[cell_ctor]
param_sorts = _param_sorts(decl)

param_names: list[str]

if all(self._is_cell(sort) for sort in param_sorts):
param_names = []
for param_sort in param_sorts:
assert param_sort.startswith('Sort')
assert param_sort.endswith('Cell')
name = param_sort[4:-4]
name = name[0].lower() + name[1:]
param_names.append(name)
else:
assert len(param_sorts) == 1
param_names = ['val']

fields = tuple(ExplBinder((name,), Term(sort)) for name, sort in zip(param_names, param_sorts, strict=True))
return Structure(sort, Signature((), Term('Type')), ctor=StructCtor(fields))

def _collection(self, sort: str) -> Inductive:
coll = self.defn.collections[sort]
elem = self.defn.symbols[coll.element]
Expand Down
95 changes: 95 additions & 0 deletions pyk/src/pyk/k2lean4/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,101 @@ def __str__(self) -> str:
return f'| {patterns} => {self.rhs}'


@final
@dataclass(frozen=True)
class Structure(Declaration):
ident: DeclId
signature: Signature | None
extends: tuple[Term, ...]
ctor: StructCtor | None
deriving: tuple[str, ...]
modifiers: Modifiers | None

def __init__(
self,
ident: str | DeclId,
signature: Signature | None = None,
extends: Iterable[Term] | None = None,
ctor: StructCtor | None = None,
deriving: Iterable[str] | None = None,
modifiers: Modifiers | None = None,
):
ident = DeclId(ident) if isinstance(ident, str) else ident
extends = tuple(extends) if extends is not None else ()
deriving = tuple(deriving) if deriving is not None else ()
object.__setattr__(self, 'ident', ident)
object.__setattr__(self, 'signature', signature)
object.__setattr__(self, 'extends', extends)
object.__setattr__(self, 'ctor', ctor)
object.__setattr__(self, 'deriving', deriving)
object.__setattr__(self, 'modifiers', modifiers)

def __str__(self) -> str:
lines = []

modifiers = f'{self.modifiers} ' if self.modifiers else ''
binders = (
' '.join(str(binder) for binder in self.signature.binders)
if self.signature and self.signature.binders
else ''
)
binders = f' {binders}' if binders else ''
extends = ', '.join(str(extend) for extend in self.extends)
extends = f' extends {extends}' if extends else ''
ty = f' : {self.signature.ty}' if self.signature and self.signature.ty else ''
where = ' where' if self.ctor else ''
lines.append(f'{modifiers}structure {self.ident}{binders}{extends}{ty}{where}')

if self.deriving:
lines.append(f' deriving {self.deriving}')

if self.ctor:
lines.extend(f' {line}' for line in str(self.ctor).splitlines())

return '\n'.join(lines)


@final
@dataclass(frozen=True)
class StructCtor:
fields: tuple[Binder, ...] # TODO implement StructField
ident: StructIdent | None

def __init__(
self,
fields: Iterable[Binder],
ident: str | StructIdent | None = None,
):
fields = tuple(fields)
ident = StructIdent(ident) if isinstance(ident, str) else ident
object.__setattr__(self, 'fields', fields)
object.__setattr__(self, 'ident', ident)

def __str__(self) -> str:
lines = []
if self.ident:
lines.append(f'{self.ident} ::')
for field in self.fields:
if isinstance(field, ExplBinder) and len(field.idents) == 1:
(ident,) = field.idents
ty = '' if field.ty is None else f' : {field.ty}'
lines.append(f'{ident}{ty}')
else:
lines.append(str(field))
return '\n'.join(lines)


@final
@dataclass(frozen=True)
class StructIdent:
ident: str
modifiers: Modifiers | None = None

def __str__(self) -> str:
modifiers = f'{self.modifiers} ' if self.modifiers else ''
return f'{modifiers}{ self.ident}'


@final
@dataclass(frozen=True)
class DeclId:
Expand Down

0 comments on commit 3025e82

Please sign in to comment.