From 541f4ec6f76218add3d45a9c40a216bca156901f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20T=C3=B3th?= Date: Thu, 23 Jan 2025 16:38:28 +0000 Subject: [PATCH] Special case `inj` application generation --- pyk/src/pyk/k2lean4/k2lean4.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/pyk/src/pyk/k2lean4/k2lean4.py b/pyk/src/pyk/k2lean4/k2lean4.py index f25cc3dc82..73dcd69654 100644 --- a/pyk/src/pyk/k2lean4/k2lean4.py +++ b/pyk/src/pyk/k2lean4/k2lean4.py @@ -40,7 +40,7 @@ from ..kore.internal import KoreDefn from ..kore.rule import RewriteRule - from ..kore.syntax import Pattern, SymbolDecl + from ..kore.syntax import Pattern, Sort, SymbolDecl from .model import Binder, Command, Declaration, FieldVal @@ -385,8 +385,8 @@ def _transform_pattern(self, pattern: Pattern) -> Term: return self._transform_evar(name) case DV(SortApp(sort), String(value)): return self._transform_dv(sort, value) - case App(symbol, _, args): - return self._transform_app(symbol, args) + case App(symbol, sorts, args): + return self._transform_app(symbol, sorts, args) case _: raise ValueError(f'Unsupported pattern: {pattern.text}') @@ -440,7 +440,10 @@ def encode(c: str) -> str: encoded = ''.join(encode(c) for c in value) return Term(f'"{encoded}"') - def _transform_app(self, symbol: str, args: Iterable[Pattern]) -> Term: + def _transform_app(self, symbol: str, sorts: tuple[Sort, ...], args: tuple[Pattern, ...]) -> Term: + if symbol == 'inj': + return self._transform_inj_app(sorts, args) + if symbol in self.structure_symbols: fields = self.structures[self.structure_symbols[symbol]] return self._transform_structure_app(fields, args) @@ -449,6 +452,22 @@ def _transform_app(self, symbol: str, args: Iterable[Pattern]) -> Term: sort = decl.sort.name if isinstance(decl.sort, SortApp) else None return self._transform_basic_app(sort, symbol, args) + def _transform_arg(self, pattern: Pattern) -> Term: + term = self._transform_pattern(pattern) + if not isinstance(pattern, App): + return term + return Term(f'({term})') + + def _transform_inj_app(self, sorts: tuple[Sort, ...], args: tuple[Pattern, ...]) -> Term: + _from_sort, _to_sort = sorts + assert isinstance(_from_sort, SortApp) + assert isinstance(_to_sort, SortApp) + from_str = _from_sort.name + to_str = _to_sort.name + (arg,) = args + term = self._transform_arg(arg) + return Term(f'(@inj {from_str} {to_str}) {term}') + def _transform_structure_app(self, fields: Iterable[Field], args: Iterable[Pattern]) -> Term: fields_str = ', '.join( f'{field.name} := {self._transform_pattern(arg)}' for field, arg in zip(fields, args, strict=True) @@ -467,12 +486,7 @@ def _transform_basic_app(self, sort: str | None, symbol: str, args: Iterable[Pat ident = self._symbol_ident(symbol) chunks.append(ident) - chunks.extend( - f'({term})' if isinstance(arg, App) and arg.args else str(term) - for arg in args - if (term := self._transform_pattern(arg)) - ) - + chunks.extend(str(self._transform_arg(arg)) for arg in args) return Term(' '.join(chunks))