Skip to content

Commit

Permalink
Pretty print brackets in Formatter (#4388)
Browse files Browse the repository at this point in the history
Implements a simplified version of the bracketing algorithm from
`kprint`:
* If the symbol is a constant or appears between two terminals, do not
add brackets
* Otherwise, if associativity or priority rules forbid the term from
appearing as a direct child of its parent, add brackets
* Otherwise, do not add brackets.
  • Loading branch information
tothtamas28 authored Jun 3, 2024
1 parent f0c93bc commit 46e6372
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 43 deletions.
1 change: 1 addition & 0 deletions pyk/src/pyk/kast/att.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ class Atts:
ANYWHERE: Final = AttKey('anywhere', type=_NONE)
ASSOC: Final = AttKey('assoc', type=_NONE)
BRACKET: Final = AttKey('bracket', type=_NONE)
BRACKET_LABEL: Final = AttKey('bracketLabel', type=_ANY)
CIRCULARITY: Final = AttKey('circularity', type=_NONE)
CELL: Final = AttKey('cell', type=_NONE)
CELL_COLLECTION: Final = AttKey('cellCollection', type=_NONE)
Expand Down
89 changes: 87 additions & 2 deletions pyk/src/pyk/kast/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,27 @@

if TYPE_CHECKING:
from . import KInner
from .inner import KSort
from .outer import KDefinition, KProduction


class Formatter:
definition: KDefinition

_indent: int
_brackets: bool

def __init__(self, definition: KDefinition, *, indent: int = 0):
def __init__(self, definition: KDefinition, *, indent: int = 0, brackets: bool = True):
self.definition = definition
self._indent = indent
self._brackets = brackets

def __call__(self, term: KInner) -> str:
return self.format(term)

def format(self, term: KInner) -> str:
if self._brackets:
term = add_brackets(self.definition, term)
return ''.join(self._format(term))

def _format(self, term: KInner) -> list[str]:
Expand All @@ -47,7 +52,7 @@ def _format_ksequence(self, ksequence: KSequence) -> list[str]:
return [chunk for chunks in intersperse(items, [' ~> ']) for chunk in chunks]

def _format_kapply(self, kapply: KApply) -> list[str]:
production = self.definition.symbols[kapply.label.name]
production = self.definition.syntax_symbols[kapply.label.name]
formatt = production.att.get(Atts.FORMAT, production.default_format)
return [
chunk
Expand Down Expand Up @@ -104,3 +109,83 @@ def _interpret_index(self, index: int, production: KProduction, kapply: KApply)
raise ValueError(f'Invalid format index escape to regex terminal: {index}: {production}')
case _:
raise AssertionError()


def add_brackets(definition: KDefinition, term: KInner) -> KInner:
    """Return ``term`` with bracket applications inserted around the arguments that need them.

    Terms that are not a ``KApply`` are returned unchanged. For a ``KApply``, the
    production of its label is looked up in ``definition.symbols``; the n-th
    non-terminal of that production is paired with the n-th argument, the argument
    is processed recursively, and the result is passed to ``_with_bracket``.
    """
    if not isinstance(term, KApply):
        return term

    production = definition.symbols[term.label.name]

    # (position within the production items, non-terminal) for every non-terminal item
    nonterminals = [(pos, item) for pos, item in enumerate(production.items) if isinstance(item, KNonTerminal)]

    new_args: list[KInner] = []
    for arg_pos, (item_pos, nonterminal) in enumerate(nonterminals):
        inner = add_brackets(definition, term.args[arg_pos])
        # The item position (not the argument position) is what the bracketing rules inspect
        new_args.append(_with_bracket(definition, term, inner, nonterminal.sort, item_pos))

    return term.let(args=new_args)


def _with_bracket(definition: KDefinition, parent: KApply, term: KInner, bracket_sort: KSort, index: int) -> KInner:
    """Wrap ``term`` in the bracket symbol of ``bracket_sort`` if a bracket is required.

    ``term`` is returned unchanged when ``_requires_bracket`` says no bracket is
    needed, or when ``definition.brackets`` has no production for ``bracket_sort``.
    """
    if _requires_bracket(definition, parent, term, index):
        production = definition.brackets.get(bracket_sort)
        if production:
            # The label to apply is stored under 'name' in the bracketLabel attribute
            label = production.att[Atts.BRACKET_LABEL]['name']
            return KApply(label, term)
    return term


def _requires_bracket(definition: KDefinition, parent: KApply, term: KInner, index: int) -> bool:
    """Decide whether ``term``, located at item position ``index`` of ``parent``'s production, needs a bracket.

    Tokens, variables and sequences never do. Any other term is asserted to be a
    ``KApply``; it needs a bracket only if it does not have exactly one argument,
    is not enclosed by terminals, and violates associativity or priority.
    """
    if isinstance(term, (KToken, KVariable, KSequence)):
        return False

    assert isinstance(term, KApply)

    # NOTE(review): only single-argument applications are exempted here; zero-argument
    # applications fall through to the checks below - confirm this is intended.
    if len(term.args) == 1 or _between_terminals(definition, parent, index):
        return False

    return _associativity_wrong(definition, parent, term, index) or _priority_wrong(definition, parent, term)


def _between_terminals(definition: KDefinition, parent: KApply, index: int) -> bool:
    """Return whether the item at ``index`` of ``parent``'s production has a ``KTerminal`` on both sides.

    The first and the last item of the production are never considered enclosed.
    """
    items = definition.symbols[parent.label.name].items
    on_edge = index == 0 or index == len(items) - 1
    if on_edge:
        return False
    return isinstance(items[index - 1], KTerminal) and isinstance(items[index + 1], KTerminal)


def _associativity_wrong(definition: KDefinition, parent: KApply, term: KApply, index: int) -> bool:
    """Return whether placing ``term`` at item position ``index`` of ``parent`` breaks associativity.

    This holds when ``term`` is the first item and its label is listed in
    ``definition.right_assocs`` for the parent label, or when it is the last item
    and its label is listed in ``definition.left_assocs`` for the parent label.
    """
    outer = parent.label.name
    inner = term.label.name
    last = len(definition.symbols[outer].items) - 1

    leftmost_clash = index == 0 and inner in definition.right_assocs.get(outer, ())
    return leftmost_clash or (index == last and inner in definition.left_assocs.get(outer, ()))


def _priority_wrong(definition: KDefinition, parent: KApply, term: KApply) -> bool:
    """Return whether ``term``'s label is listed in ``definition.priorities`` for ``parent``'s label.

    A parent label without an entry in ``definition.priorities`` yields ``False``.
    """
    forbidden_children = definition.priorities.get(parent.label.name, ())
    return term.label.name in forbidden_children
44 changes: 43 additions & 1 deletion pyk/src/pyk/kast/outer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import InitVar # noqa: TC003
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from functools import cached_property, reduce
from itertools import pairwise, product
from typing import TYPE_CHECKING, final, overload

Expand Down Expand Up @@ -1200,6 +1200,18 @@ def subsorts(self, sort: KSort) -> frozenset[KSort]:
"""Return all subsorts of a given `KSort` by inspecting the definition."""
return self.subsort_table.get(sort, frozenset())

@cached_property
def brackets(self) -> FrozenDict[KSort, KProduction]:
    """Map each sort to the production of that sort carrying the ``bracket`` attribute.

    Raises ``ValueError`` if two bracket productions share the same sort.
    """
    by_sort: dict[KSort, KProduction] = {}
    for production in self.productions:
        if Atts.BRACKET not in production.att:
            continue
        assert not production.klabel
        target = production.sort
        if target in by_sort:
            raise ValueError(f'Multiple bracket productions for sort: {target.name}')
        by_sort[target] = production
    return FrozenDict(by_sort)

@cached_property
def symbols(self) -> FrozenDict[str, KProduction]:
symbols: dict[str, KProduction] = {}
Expand All @@ -1216,6 +1228,13 @@ def symbols(self) -> FrozenDict[str, KProduction]:
symbols[symbol] = prod
return FrozenDict(symbols)

@cached_property
def syntax_symbols(self) -> FrozenDict[str, KProduction]:
    """Return ``self.symbols`` extended with one entry per bracket production.

    Each bracket production is keyed by the ``'name'`` entry of its
    ``bracketLabel`` attribute; on a key clash the bracket production wins.
    """
    bracket_symbols: dict[str, KProduction] = {}
    for production in self.brackets.values():
        bracket_symbols[production.att[Atts.BRACKET_LABEL]['name']] = production

    combined: dict[str, KProduction] = dict(self.symbols)
    combined.update(bracket_symbols)
    return FrozenDict(combined)

@cached_property
def overloads(self) -> FrozenDict[str, frozenset[str]]:
"""Return a mapping from symbols to the sets of symbols that overload them."""
Expand Down Expand Up @@ -1272,6 +1291,29 @@ def priorities(self) -> FrozenDict[str, frozenset[str]]:
)
return POSet(relation).image

@cached_property
def left_assocs(self) -> FrozenDict[str, frozenset[str]]:
    """Return an immutable copy of the table computed by ``_assocs(KAssoc.LEFT)``."""
    table = self._assocs(KAssoc.LEFT)
    return FrozenDict({tag: frozenset(related) for tag, related in table.items()})

@cached_property
def right_assocs(self) -> FrozenDict[str, frozenset[str]]:
    """Return an immutable copy of the table computed by ``_assocs(KAssoc.RIGHT)``."""
    table = self._assocs(KAssoc.RIGHT)
    return FrozenDict({tag: frozenset(related) for tag, related in table.items()})

def _assocs(self, assoc: KAssoc) -> dict[str, set[str]]:
    """Collect, per tag, the tags it shares an associativity declaration with.

    Every ``KSyntaxAssociativity`` sentence of every module is considered if its
    associativity is ``assoc`` or ``KAssoc.NON_ASSOC``. Each tag of such a sentence
    is mapped to all tags of that sentence, itself included.
    """
    accepted = (assoc, KAssoc.NON_ASSOC)
    table: dict[str, set[str]] = {}
    for module in self.modules:
        for sentence in module.sentences:
            if not isinstance(sentence, KSyntaxAssociativity) or sentence.assoc not in accepted:
                continue
            for key, value in product(sentence.tags, sentence.tags):
                table.setdefault(key, set()).add(value)
    return table

def sort(self, kast: KInner) -> KSort | None:
"""Computes the sort of a given term using best-effort simple sorting algorithm, returns `None` on algorithm failure."""
match kast:
Expand Down
49 changes: 15 additions & 34 deletions pyk/src/pyk/konvert/_module_to_kore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import reduce
from itertools import product, repeat
from itertools import repeat
from pathlib import Path
from typing import ClassVar # noqa: TC003
from typing import TYPE_CHECKING, NamedTuple, final
Expand All @@ -13,17 +13,7 @@
from ..kast.att import Format
from ..kast.inner import KApply, KRewrite, KSort
from ..kast.manip import extract_lhs, extract_rhs
from ..kast.outer import (
KAssoc,
KDefinition,
KNonTerminal,
KProduction,
KRegexTerminal,
KRule,
KSyntaxAssociativity,
KSyntaxSort,
KTerminal,
)
from ..kast.outer import KDefinition, KNonTerminal, KProduction, KRegexTerminal, KRule, KSyntaxSort, KTerminal
from ..kore.prelude import inj
from ..kore.syntax import (
And,
Expand Down Expand Up @@ -1280,10 +1270,11 @@ def update(production: KProduction) -> KProduction:


@dataclass
class AddAssocAtts(SingleModulePass):
def _transform_module(self, module: KFlatModule) -> KFlatModule:
left_assocs = self._assocs(module, KAssoc.LEFT)
right_assocs = self._assocs(module, KAssoc.RIGHT)
class AddAssocAtts(KompilerPass):
def execute(self, definition: KDefinition) -> KDefinition:
if len(definition.modules) > 1:
raise ValueError('Expected a single module')
module = definition.modules[0]

def update(production: KProduction) -> KProduction:
if not production.klabel:
Expand All @@ -1292,23 +1283,13 @@ def update(production: KProduction) -> KProduction:
if Atts.FORMAT not in production.att:
return production

left = tuple(KApply(tag).to_dict() for tag in sorted(left_assocs.get(production.klabel.name, [])))
right = tuple(KApply(tag).to_dict() for tag in sorted(right_assocs.get(production.klabel.name, [])))
left = tuple(
KApply(tag).to_dict() for tag in sorted(definition.left_assocs.get(production.klabel.name, []))
)
right = tuple(
KApply(tag).to_dict() for tag in sorted(definition.right_assocs.get(production.klabel.name, []))
)
return production.let(att=production.att.update([Atts.LEFT(left), Atts.RIGHT(right)]))

return module.map_sentences(update, of_type=KProduction)

@staticmethod
def _assocs(module: KFlatModule, assoc: KAssoc) -> dict[str, set[str]]:
sents = (
sent
for sent in module.sentences
if isinstance(sent, KSyntaxAssociativity) and sent.assoc in (assoc, KAssoc.NON_ASSOC)
)
pairs = (pair for sent in sents for pair in product(sent.tags, sent.tags))

def insert(dct: dict[str, set[str]], *, key: str, value: str) -> dict[str, set[str]]:
dct.setdefault(key, set()).add(value)
return dct

return reduce(lambda res, pair: insert(res, key=pair[0], value=pair[1]), pairs, {})
module = module.map_sentences(update, of_type=KProduction)
return KDefinition(module.name, (module,))
56 changes: 50 additions & 6 deletions pyk/src/tests/integration/kast/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@
from pyk.kast.outer import KDefinition


class FormatterTest(KompiledTest):
    """Shared base for the formatter test classes; supplies the ``formatter`` fixture."""

    @pytest.fixture
    def formatter(self, definition: KDefinition) -> Formatter:
        """Build a ``Formatter`` from the ``definition`` fixture."""
        instance = Formatter(definition)
        return instance


x, y = (KToken(name, KSort('Id')) for name in ['x', 'y'])


TEST_DATA = (
FORMATTER_TEST_DATA = (
(
KApply('<k>', KSequence()),
"""
Expand Down Expand Up @@ -89,14 +95,52 @@
)


class TestFormatter(KompiledTest):
class TestFormatter(FormatterTest):
KOMPILE_MAIN_FILE = K_FILES / 'imp.k'

@pytest.fixture
def formatter(self, definition: KDefinition) -> Formatter:
return Formatter(definition)
@pytest.mark.parametrize('term,output', FORMATTER_TEST_DATA, ids=count())
def test(self, formatter: Formatter, term: KInner, output: str) -> None:
# Given
expected = textwrap.dedent(output).strip()

# When
actual = formatter(term)

# Then
assert actual == expected


# (term, expected formatter output) pairs used to parametrize the bracket tests.
BRACKETS_TEST_DATA = (
# A lone token: printed as is
(token(1), '1'),
# A single binary application: no brackets
(KApply('_+_', token(1), token(2)), '1 + 2'),
# `_+_` / `_-_` nested as the left child of `_+_`: no brackets expected
(KApply('_+_', KApply('_+_', token(1), token(2)), token(3)), '1 + 2 + 3'),
(KApply('_+_', KApply('_-_', token(1), token(2)), token(3)), '1 - 2 + 3'),
# `_+_` nested as the right child of `_+_`: brackets expected
(KApply('_+_', token(1), KApply('_+_', token(2), token(3))), '1 + ( 2 + 3 )'),
# `_*_` nested under `_+_` on either side: no brackets expected
(KApply('_+_', token(1), KApply('_*_', token(2), token(3))), '1 + 2 * 3'),
(KApply('_+_', KApply('_*_', token(1), token(2)), token(3)), '1 * 2 + 3'),
# `_+_` nested under `_*_` on either side: brackets expected
(KApply('_*_', token(1), KApply('_+_', token(2), token(3))), '1 * ( 2 + 3 )'),
(KApply('_*_', KApply('_+_', token(1), token(2)), token(3)), '( 1 + 2 ) * 3'),
# Argument of `sgn(_)`: no additional brackets expected in the output
(KApply('sgn(_)', KApply('_+_', token(1), token(2))), 'sgn ( 1 + 2 )'),
)


class TestBrackets(FormatterTest):
KOMPILE_DEFINITION = """
module BRACKETS
imports INT-SYNTAX
syntax Exp ::= Int
| sgn ( Exp ) [symbol(sgn(_))]
> Exp "*" Exp [symbol(_*_), left]
| Exp "/" Exp [symbol(_/_), left]
> Exp "+" Exp [symbol(_+_), left]
| Exp "-" Exp [symbol(_-_), left]
> "(" Exp ")" [bracket]
endmodule
"""
KOMPILE_MAIN_MODULE = 'BRACKETS'
KOMPILE_ARGS = {'syntax_module': KOMPILE_MAIN_MODULE}

@pytest.mark.parametrize('term,output', TEST_DATA, ids=count())
@pytest.mark.parametrize('term,output', BRACKETS_TEST_DATA, ids=count())
def test(self, formatter: Formatter, term: KInner, output: str) -> None:
# Given
expected = textwrap.dedent(output).strip()
Expand Down

0 comments on commit 46e6372

Please sign in to comment.