Skip to content

Commit

Permalink
Special case inj application generation
Browse files Browse the repository at this point in the history
  • Loading branch information
tothtamas28 committed Jan 24, 2025
1 parent 8aaa29a commit 2cdc209
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions pyk/src/pyk/k2lean4/k2lean4.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

from ..kore.internal import KoreDefn
from ..kore.rule import RewriteRule
from ..kore.syntax import Pattern, SymbolDecl
from ..kore.syntax import Pattern, Sort, SymbolDecl
from .model import Binder, Command, Declaration, FieldVal


Expand Down Expand Up @@ -385,8 +385,8 @@ def _transform_pattern(self, pattern: Pattern) -> Term:
return self._transform_evar(name)
case DV(SortApp(sort), String(value)):
return self._transform_dv(sort, value)
case App(symbol, _, args):
return self._transform_app(symbol, args)
case App(symbol, sorts, args):
return self._transform_app(symbol, sorts, args)
case _:
raise ValueError(f'Unsupported pattern: {pattern.text}')

Expand Down Expand Up @@ -440,7 +440,10 @@ def encode(c: str) -> str:
encoded = ''.join(encode(c) for c in value)
return Term(f'"{encoded}"')

def _transform_app(self, symbol: str, args: Iterable[Pattern]) -> Term:
def _transform_app(self, symbol: str, sorts: tuple[Sort, ...], args: tuple[Pattern, ...]) -> Term:
if symbol == 'inj':
return self._transform_inj_app(sorts, args)

if symbol in self.structure_symbols:
fields = self.structures[self.structure_symbols[symbol]]
return self._transform_structure_app(fields, args)
Expand All @@ -449,6 +452,22 @@ def _transform_app(self, symbol: str, args: Iterable[Pattern]) -> Term:
sort = decl.sort.name if isinstance(decl.sort, SortApp) else None
return self._transform_basic_app(sort, symbol, args)

def _transform_arg(self, pattern: Pattern) -> Term:
term = self._transform_pattern(pattern)
if not isinstance(pattern, App):
return term
return Term(f'({term})')

def _transform_inj_app(self, sorts: tuple[Sort, ...], args: tuple[Pattern, ...]) -> Term:
_from_sort, _to_sort = sorts
assert isinstance(_from_sort, SortApp)
assert isinstance(_to_sort, SortApp)
from_str = _from_sort.name
to_str = _to_sort.name
(arg,) = args
term = self._transform_arg(arg)
return Term(f'(@inj {from_str} {to_str}) {term}')

def _transform_structure_app(self, fields: Iterable[Field], args: Iterable[Pattern]) -> Term:
fields_str = ', '.join(
f'{field.name} := {self._transform_pattern(arg)}' for field, arg in zip(fields, args, strict=True)
Expand All @@ -467,12 +486,7 @@ def _transform_basic_app(self, sort: str | None, symbol: str, args: Iterable[Pat
ident = self._symbol_ident(symbol)

chunks.append(ident)
chunks.extend(
f'({term})' if isinstance(arg, App) and arg.args else str(term)
for arg in args
if (term := self._transform_pattern(arg))
)

chunks.extend(str(self._transform_arg(arg)) for arg in args)
return Term(' '.join(chunks))


Expand Down

0 comments on commit 2cdc209

Please sign in to comment.