diff --git a/pyk/src/pyk/k2lean4/k2lean4.py b/pyk/src/pyk/k2lean4/k2lean4.py index 3992d39e1e..56c5da967d 100644 --- a/pyk/src/pyk/k2lean4/k2lean4.py +++ b/pyk/src/pyk/k2lean4/k2lean4.py @@ -22,6 +22,8 @@ Mutual, Signature, SimpleFieldVal, + StructCtor, + Structure, StructVal, Term, ) @@ -69,7 +71,7 @@ def _sort_block(self, sorts: list[str]) -> Command | None: def _transform_sort(self, sort: str) -> Declaration: def is_inductive(sort: str) -> bool: decl = self.defn.sorts[sort] - return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key + return not decl.hooked and 'hasDomainValues' not in decl.attrs_by_key and not self._is_cell(sort) def is_collection(sort: str) -> bool: return sort in self.defn.collections @@ -77,6 +79,9 @@ def is_collection(sort: str) -> bool: if is_inductive(sort): return self._inductive(sort) + if self._is_cell(sort): + return self._cell(sort) + if is_collection(sort): return self._collection(sort) @@ -109,7 +114,33 @@ def _symbol_ident(symbol: str) -> str: symbol = f'«{symbol}»' return symbol - def _collection(self, sort: str) -> Inductive: + @staticmethod + def _is_cell(sort: str) -> bool: + return sort.endswith('Cell') + + def _cell(self, sort: str) -> Structure: + (cell_ctor,) = self.defn.constructors[sort] + decl = self.defn.symbols[cell_ctor] + param_sorts = _param_sorts(decl) + + param_names: list[str] + + if all(self._is_cell(sort) for sort in param_sorts): + param_names = [] + for param_sort in param_sorts: + assert param_sort.startswith('Sort') + assert param_sort.endswith('Cell') + name = param_sort[4:-4] + name = name[0].lower() + name[1:] + param_names.append(name) + else: + assert len(param_sorts) == 1 + param_names = ['val'] + + fields = tuple(ExplBinder((name,), Term(sort)) for name, sort in zip(param_names, param_sorts, strict=True)) + return Structure(sort, Signature((), Term('Type')), ctor=StructCtor(fields)) + + def _collection(self, sort: str) -> Structure: coll = self.defn.collections[sort] elem = self.defn.symbols[coll.element] sorts = _param_sorts(elem) @@ -124,8 +155,8 @@ def _collection(self, sort: str) -> Inductive: case CollectionKind.MAP: key, value = sorts val = Term(f'List ({key} × {value})') - ctor = Ctor('mk', Signature((ExplBinder(('coll',), val),), Term(sort))) - return Inductive(sort, Signature((), Term('Type')), ctors=(ctor,)) + field = ExplBinder(('coll',), val) + return Structure(sort, Signature((), Term('Type')), ctor=StructCtor((field,))) def inj_module(self) -> Module: return Module(commands=self._inj_commands()) diff --git a/pyk/src/pyk/k2lean4/model.py b/pyk/src/pyk/k2lean4/model.py index 39e763e801..6ca0ad23fe 100644 --- a/pyk/src/pyk/k2lean4/model.py +++ b/pyk/src/pyk/k2lean4/model.py @@ -260,6 +260,101 @@ def __str__(self) -> str: return f'| {patterns} => {self.rhs}' +@final +@dataclass(frozen=True) +class Structure(Declaration): + ident: DeclId + signature: Signature | None + extends: tuple[Term, ...] + ctor: StructCtor | None + deriving: tuple[str, ...] + modifiers: Modifiers | None + + def __init__( + self, + ident: str | DeclId, + signature: Signature | None = None, + extends: Iterable[Term] | None = None, + ctor: StructCtor | None = None, + deriving: Iterable[str] | None = None, + modifiers: Modifiers | None = None, + ): + ident = DeclId(ident) if isinstance(ident, str) else ident + extends = tuple(extends) if extends is not None else () + deriving = tuple(deriving) if deriving is not None else () + object.__setattr__(self, 'ident', ident) + object.__setattr__(self, 'signature', signature) + object.__setattr__(self, 'extends', extends) + object.__setattr__(self, 'ctor', ctor) + object.__setattr__(self, 'deriving', deriving) + object.__setattr__(self, 'modifiers', modifiers) + + def __str__(self) -> str: + lines = [] + + modifiers = f'{self.modifiers} ' if self.modifiers else '' + binders = ( + ' '.join(str(binder) for binder in self.signature.binders) + if self.signature and self.signature.binders + else '' + ) + binders = f' {binders}' if binders else '' + extends = ', '.join(str(extend) for extend in self.extends) + extends = f' extends {extends}' if extends else '' + ty = f' : {self.signature.ty}' if self.signature and self.signature.ty else '' + where = ' where' if self.ctor else '' + lines.append(f'{modifiers}structure {self.ident}{binders}{extends}{ty}{where}') + + if self.deriving: + lines.append(f' deriving {self.deriving}') + + if self.ctor: + lines.extend(f' {line}' for line in str(self.ctor).splitlines()) + + return '\n'.join(lines) + + +@final +@dataclass(frozen=True) +class StructCtor: + fields: tuple[Binder, ...] # TODO implement StructField + ident: StructIdent | None + + def __init__( + self, + fields: Iterable[Binder], + ident: str | StructIdent | None = None, + ): + fields = tuple(fields) + ident = StructIdent(ident) if isinstance(ident, str) else ident + object.__setattr__(self, 'fields', fields) + object.__setattr__(self, 'ident', ident) + + def __str__(self) -> str: + lines = [] + if self.ident: + lines.append(f'{self.ident} ::') + for field in self.fields: + if isinstance(field, ExplBinder) and len(field.idents) == 1: + (ident,) = field.idents + ty = '' if field.ty is None else f' : {field.ty}' + lines.append(f'{ident}{ty}') + else: + lines.append(str(field)) + return '\n'.join(lines) + + +@final +@dataclass(frozen=True) +class StructIdent: + ident: str + modifiers: Modifiers | None = None + + def __str__(self) -> str: + modifiers = f'{self.modifiers} ' if self.modifiers else '' + return f'{modifiers}{ self.ident}' + + @final @dataclass(frozen=True) class DeclId: