Skip to content

Commit

Permalink
Inline app in Assoc (#4487)
Browse files Browse the repository at this point in the history
Related:
* #4466

This allows type-safe traversal of `Assoc` using `top_down` and
`bottom_up`.

---------

Co-authored-by: gtrepta <[email protected]>
  • Loading branch information
tothtamas28 and gtrepta authored Jul 2, 2024
1 parent 25bbf54 commit 144e5db
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 199 deletions.
10 changes: 10 additions & 0 deletions pyk/src/pyk/kllvm/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
ML_SYMBOLS,
AliasDecl,
App,
Assoc,
Axiom,
Claim,
Definition,
EVar,
Import,
LeftAssoc,
MLPattern,
Module,
RightAssoc,
SortApp,
SortDecl,
SortVar,
Expand Down Expand Up @@ -110,6 +113,8 @@ def pattern_to_llvm(pattern: Pattern) -> kllvm.Pattern:
return kllvm.VariablePattern(name, sort_to_llvm(sort))
case App(symbol, sorts, args):
return _composite_pattern(symbol, sorts, args)
case Assoc():
return _composite_pattern(pattern.kore_symbol(), [], [pattern.app])
case MLPattern():
return _composite_pattern(pattern.symbol(), pattern.sorts, pattern.ctor_patterns)
case _:
Expand Down Expand Up @@ -206,6 +211,11 @@ def llvm_to_pattern(pattern: kllvm.Pattern) -> Pattern:
symbol, sorts, patterns = _unpack_composite_pattern(pattern)
if symbol in ML_SYMBOLS:
return MLPattern.of(symbol, sorts, patterns)
elif symbol in [r'\left-assoc', r'\right-assoc']:
(app,) = patterns
assert isinstance(app, App)
assoc = LeftAssoc if symbol == r'\left-assoc' else RightAssoc
return assoc(app.symbol, app.sorts, app.args)
else:
return App(symbol, sorts, patterns)
case _:
Expand Down
5 changes: 1 addition & 4 deletions pyk/src/pyk/kore/manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING

from .syntax import And, Assoc, EVar, MLQuant, Top
from .syntax import And, EVar, MLQuant, Top

if TYPE_CHECKING:
from collections.abc import Collection
Expand All @@ -28,9 +28,6 @@ def collect(pattern: Pattern, bound_vars: set[str]) -> None:
else:
occurrences[pattern.name] = [pattern]

elif isinstance(pattern, Assoc):
collect(pattern.app, bound_vars)

elif isinstance(pattern, MLQuant):
new_bound_vars = {pattern.var.name}.union(bound_vars)
collect(pattern.pattern, new_bound_vars)
Expand Down
30 changes: 15 additions & 15 deletions pyk/src/pyk/kore/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,36 @@ def match_dv(pattern: Pattern, sort: Sort | None = None) -> DV:
return dv


def match_symbol(app: App, symbol: str) -> None:
if app.symbol != symbol:
raise ValueError(f'Expected symbol {symbol}, found: {app.symbol}')
def match_symbol(actual: str, expected: str) -> None:
if actual != expected:
raise ValueError(f'Expected symbol {expected}, found: {actual}')


def match_app(pattern: Pattern, symbol: str | None = None) -> App:
app = check_type(pattern, App)
if symbol is not None:
match_symbol(app, symbol)
match_symbol(app.symbol, symbol)
return app


def match_inj(pattern: Pattern) -> App:
return match_app(pattern, 'inj')


def match_left_assoc(pattern: Pattern) -> LeftAssoc:
return check_type(pattern, LeftAssoc)
def match_left_assoc(pattern: Pattern, symbol: str | None = None) -> LeftAssoc:
assoc = check_type(pattern, LeftAssoc)
if symbol is not None:
match_symbol(assoc.symbol, symbol)
return assoc


def match_list(pattern: Pattern) -> tuple[Pattern, ...]:
if type(pattern) is App:
match_app(pattern, "Lbl'Stop'List")
return ()

assoc = match_left_assoc(pattern)
cons = match_app(assoc.app, "Lbl'Unds'List'Unds'")
items = (match_app(arg, 'LblListItem') for arg in cons.args)
assoc = match_left_assoc(pattern, "Lbl'Unds'List'Unds'")
items = (match_app(arg, 'LblListItem') for arg in assoc.args)
elems = (item.args[0] for item in items)
return tuple(elems)

Expand All @@ -62,9 +64,8 @@ def match_set(pattern: Pattern) -> tuple[Pattern, ...]:
match_app(pattern, "Lbl'Stop'Set")
return ()

assoc = match_left_assoc(pattern)
cons = match_app(assoc.app, "Lbl'Unds'Set'Unds'")
items = (match_app(arg, 'LblSetItem') for arg in cons.args)
assoc = match_left_assoc(pattern, "Lbl'Unds'Set'Unds'")
items = (match_app(arg, 'LblSetItem') for arg in assoc.args)
elems = (item.args[0] for item in items)
return tuple(elems)

Expand All @@ -79,9 +80,8 @@ def match_map(pattern: Pattern, *, cell: str | None = None) -> tuple[tuple[Patte
match_app(pattern, stop_symbol)
return ()

assoc = match_left_assoc(pattern)
cons = match_app(assoc.app, cons_symbol)
items = (match_app(arg, item_symbol) for arg in cons.args)
assoc = match_left_assoc(pattern, cons_symbol)
items = (match_app(arg, item_symbol) for arg in assoc.args)
entries = ((item.args[0], item.args[1]) for item in items)
return tuple(entries)

Expand Down
2 changes: 1 addition & 1 deletion pyk/src/pyk/kore/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def _assoc(self, token_type: TokenType, cls: type[AS]) -> AS:
self._match(TokenType.LPAREN)
app = self.app()
self._match(TokenType.RPAREN)
return cls(app) # type: ignore
return cls(app.symbol, app.sorts, app.args) # type: ignore

def left_assoc(self) -> LeftAssoc:
return self._assoc(TokenType.ML_LEFT_ASSOC, LeftAssoc)
Expand Down
28 changes: 12 additions & 16 deletions pyk/src/pyk/kore/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,10 @@ def kseq(kitems: Iterable[Pattern], *, dotvar: EVar | None = None) -> Pattern:
if len(args) == 1:
return tail

app = App(KSEQ, (), args)

if len(args) == 2:
return app
return App(KSEQ, (), args)

return RightAssoc(app)
return RightAssoc(KSEQ, (), args)


def k_config_var(var: str) -> DV:
Expand Down Expand Up @@ -224,7 +222,7 @@ def top_cell_initializer(config: Mapping[str, Pattern]) -> App:
def list_pattern(*args: Pattern) -> Pattern:
if not args:
return STOP_LIST
return LeftAssoc(App(LBL_LIST, args=(App(LBL_LIST_ITEM, args=(arg,)) for arg in args)))
return LeftAssoc(LBL_LIST, args=(App(LBL_LIST_ITEM, args=(arg,)) for arg in args))


STOP_SET: Final = App("Lbl'Stop'Set")
Expand All @@ -235,7 +233,7 @@ def list_pattern(*args: Pattern) -> Pattern:
def set_pattern(*args: Pattern) -> Pattern:
if not args:
return STOP_SET
return LeftAssoc(App(LBL_SET, args=(App(LBL_SET_ITEM, args=(arg,)) for arg in args)))
return LeftAssoc(LBL_SET, args=(App(LBL_SET_ITEM, args=(arg,)) for arg in args))


STOP_MAP: Final = App("Lbl'Stop'Map")
Expand All @@ -249,7 +247,7 @@ def map_pattern(*args: tuple[Pattern, Pattern], cell: str | None = None) -> Patt

cons_symbol = SymbolId(f"Lbl'Unds'{cell}Map'Unds'") if cell else LBL_MAP
item_symbol = SymbolId(f'Lbl{cell}MapItem') if cell else LBL_MAP_ITEM
return LeftAssoc(App(cons_symbol, args=(App(item_symbol, args=arg) for arg in args)))
return LeftAssoc(cons_symbol, args=(App(item_symbol, args=arg) for arg in args))


STOP_RANGEMAP: Final = App("Lbl'Stop'RangeMap")
Expand All @@ -263,10 +261,8 @@ def rangemap_pattern(*args: tuple[tuple[Pattern, Pattern], Pattern]) -> Pattern:
return STOP_RANGEMAP

return LeftAssoc(
App(
LBL_RANGEMAP,
args=(App(LBL_RANGEMAP_ITEM, args=(App(LBL_RANGEMAP_RANGE, args=arg[0]), arg[1])) for arg in args),
)
LBL_RANGEMAP,
args=(App(LBL_RANGEMAP_ITEM, args=(App(LBL_RANGEMAP_RANGE, args=arg[0]), arg[1])) for arg in args),
)


Expand Down Expand Up @@ -306,7 +302,7 @@ def json_object(pattern: Pattern) -> App:


def jsons(patterns: Iterable[Pattern]) -> RightAssoc:
return RightAssoc(App(LBL_JSONS, (), chain(patterns, (STOP_JSONS,))))
return RightAssoc(LBL_JSONS, (), chain(patterns, (STOP_JSONS,)))


def json_key(key: str) -> App:
Expand Down Expand Up @@ -370,21 +366,21 @@ def kore_to_json(pattern: Pattern) -> Any:
def _iter_json_list(app: App) -> Iterator[Pattern]:
from . import match as km

km.match_symbol(app, LBL_JSON_LIST.value)
km.match_app(app, LBL_JSON_LIST.value)
curr = km.match_app(app.args[0])
while curr.symbol != STOP_JSONS.symbol:
km.match_symbol(curr, LBL_JSONS.value)
km.match_app(curr, LBL_JSONS.value)
yield curr.args[0]
curr = km.match_app(curr.args[1])


def _iter_json_object(app: App) -> Iterator[tuple[str, Pattern]]:
from . import match as km

km.match_symbol(app, LBL_JSON_OBJECT.value)
km.match_app(app, LBL_JSON_OBJECT.value)
curr = km.match_app(app.args[0])
while curr.symbol != STOP_JSONS.symbol:
km.match_symbol(curr, LBL_JSONS.value)
km.match_app(curr, LBL_JSONS.value)
entry = km.match_app(curr.args[0], LBL_JSON_ENTRY.value)
key = km.kore_str(km.inj(entry.args[0]))
value = entry.args[1]
Expand Down
Loading

0 comments on commit 144e5db

Please sign in to comment.