From 46e63721f847451a123c448e06ab779349ba69cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20T=C3=B3th?= Date: Mon, 3 Jun 2024 16:02:29 +0200 Subject: [PATCH] Pretty print brackets in `Formatter` (#4388) Implements a simplified version of the bracketing algorithm from `kprint`: * If the symbol is a constant or appears between two terminals, do not add brackets * Otherwise, if associativity or priority rules forbid the term from appearing as a direct child of its parent, add brackets * Otherwise, do not add brackets. --- pyk/src/pyk/kast/att.py | 1 + pyk/src/pyk/kast/formatter.py | 89 ++++++++++++++++++- pyk/src/pyk/kast/outer.py | 44 ++++++++- pyk/src/pyk/konvert/_module_to_kore.py | 49 ++++------ .../tests/integration/kast/test_formatter.py | 56 ++++++++++-- 5 files changed, 196 insertions(+), 43 deletions(-) diff --git a/pyk/src/pyk/kast/att.py b/pyk/src/pyk/kast/att.py index d84038a3ab7..b9f4f4bc2cb 100644 --- a/pyk/src/pyk/kast/att.py +++ b/pyk/src/pyk/kast/att.py @@ -291,6 +291,7 @@ class Atts: ANYWHERE: Final = AttKey('anywhere', type=_NONE) ASSOC: Final = AttKey('assoc', type=_NONE) BRACKET: Final = AttKey('bracket', type=_NONE) + BRACKET_LABEL: Final = AttKey('bracketLabel', type=_ANY) CIRCULARITY: Final = AttKey('circularity', type=_NONE) CELL: Final = AttKey('cell', type=_NONE) CELL_COLLECTION: Final = AttKey('cellCollection', type=_NONE) diff --git a/pyk/src/pyk/kast/formatter.py b/pyk/src/pyk/kast/formatter.py index e960c406025..64cb4636b2b 100644 --- a/pyk/src/pyk/kast/formatter.py +++ b/pyk/src/pyk/kast/formatter.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from . import KInner + from .inner import KSort from .outer import KDefinition, KProduction @@ -16,15 +17,19 @@ class Formatter: definition: KDefinition _indent: int + _brackets: bool - def __init__(self, definition: KDefinition, *, indent: int = 0): + def __init__(self, definition: KDefinition, *, indent: int = 0, brackets: bool = True): self.definition = definition self._indent = indent + self._brackets = brackets def __call__(self, term: KInner) -> str: return self.format(term) def format(self, term: KInner) -> str: + if self._brackets: + term = add_brackets(self.definition, term) return ''.join(self._format(term)) def _format(self, term: KInner) -> list[str]: @@ -47,7 +52,7 @@ def _format_ksequence(self, ksequence: KSequence) -> list[str]: return [chunk for chunks in intersperse(items, [' ~> ']) for chunk in chunks] def _format_kapply(self, kapply: KApply) -> list[str]: - production = self.definition.symbols[kapply.label.name] + production = self.definition.syntax_symbols[kapply.label.name] formatt = production.att.get(Atts.FORMAT, production.default_format) return [ chunk @@ -104,3 +109,83 @@ def _interpret_index(self, index: int, production: KProduction, kapply: KApply) raise ValueError(f'Invalid format index escape to regex terminal: {index}: {production}') case _: raise AssertionError() + + +def add_brackets(definition: KDefinition, term: KInner) -> KInner: + if not isinstance(term, KApply): + return term + prod = definition.symbols[term.label.name] + + args: list[KInner] = [] + + arg_index = -1 + for index, item in enumerate(prod.items): + if not isinstance(item, KNonTerminal): + continue + + arg_index += 1 + arg = term.args[arg_index] + arg = add_brackets(definition, arg) + arg = _with_bracket(definition, term, arg, item.sort, index) + args.append(arg) + + return term.let(args=args) + + +def _with_bracket(definition: KDefinition, parent: KApply, term: KInner, bracket_sort: KSort, index: int) -> KInner: + if not _requires_bracket(definition, parent, term, index): + return term + + bracket_prod = definition.brackets.get(bracket_sort) + if not bracket_prod: + return term + + bracket_label = bracket_prod.att[Atts.BRACKET_LABEL]['name'] + return KApply(bracket_label, term) + + +def _requires_bracket(definition: KDefinition, parent: KApply, term: KInner, index: int) -> bool: + if isinstance(term, (KToken, KVariable, KSequence)): + return False + + assert isinstance(term, KApply) + + if len(term.args) == 1: + return False + + if _between_terminals(definition, parent, index): + return False + + if _associativity_wrong(definition, parent, term, index): + return True + + if _priority_wrong(definition, parent, term): + return True + + return False + + +def _between_terminals(definition: KDefinition, parent: KApply, index: int) -> bool: + prod = definition.symbols[parent.label.name] + if index in [0, len(prod.items) - 1]: + return False + return all(isinstance(prod.items[index + offset], KTerminal) for offset in [-1, 1]) + + +def _associativity_wrong(definition: KDefinition, parent: KApply, term: KApply, index: int) -> bool: + """A left (right) associative symbol cannot appear as the rightmost (leftmost) child of a symbol with equal priority.""" + parent_label = parent.label.name + term_label = term.label.name + prod = definition.symbols[parent_label] + if index == 0 and term_label in definition.right_assocs.get(parent_label, ()): + return True + if index == len(prod.items) - 1 and term_label in definition.left_assocs.get(parent_label, ()): + return True + return False + + +def _priority_wrong(definition: KDefinition, parent: KApply, term: KApply) -> bool: + """A symbol with a lesser priority cannot appear as the child of a symbol with greater priority.""" + parent_label = parent.label.name + term_label = term.label.name + return term_label in definition.priorities.get(parent_label, ()) diff --git a/pyk/src/pyk/kast/outer.py b/pyk/src/pyk/kast/outer.py index 435ee5436c9..0f9438e0495 100644 --- a/pyk/src/pyk/kast/outer.py +++ b/pyk/src/pyk/kast/outer.py @@ -9,7 +9,7 @@ from dataclasses import InitVar # noqa: TC003 from dataclasses import dataclass from enum import Enum -from functools import cached_property +from functools import cached_property, reduce from itertools import pairwise, product from typing import TYPE_CHECKING, final, overload @@ -1200,6 +1200,18 @@ def subsorts(self, sort: KSort) -> frozenset[KSort]: """Return all subsorts of a given `KSort` by inspecting the definition.""" return self.subsort_table.get(sort, frozenset()) + @cached_property + def brackets(self) -> FrozenDict[KSort, KProduction]: + brackets: dict[KSort, KProduction] = {} + for prod in self.productions: + if Atts.BRACKET in prod.att: + assert not prod.klabel + sort = prod.sort + if sort in brackets: + raise ValueError(f'Multiple bracket productions for sort: {sort.name}') + brackets[sort] = prod + return FrozenDict(brackets) + @cached_property def symbols(self) -> FrozenDict[str, KProduction]: symbols: dict[str, KProduction] = {} @@ -1216,6 +1228,13 @@ def symbols(self) -> FrozenDict[str, KProduction]: symbols[symbol] = prod return FrozenDict(symbols) + @cached_property + def syntax_symbols(self) -> FrozenDict[str, KProduction]: + brackets: dict[str, KProduction] = { + prod.att[Atts.BRACKET_LABEL]['name']: prod for _, prod in self.brackets.items() + } + return FrozenDict({**self.symbols, **brackets}) + @cached_property def overloads(self) -> FrozenDict[str, frozenset[str]]: """Return a mapping from symbols to the sets of symbols that overload them.""" @@ -1272,6 +1291,29 @@ def priorities(self) -> FrozenDict[str, frozenset[str]]: ) return POSet(relation).image + @cached_property + def left_assocs(self) -> FrozenDict[str, frozenset[str]]: + return FrozenDict({key: frozenset(value) for key, value in self._assocs(KAssoc.LEFT).items()}) + + @cached_property + def right_assocs(self) -> FrozenDict[str, frozenset[str]]: + return FrozenDict({key: frozenset(value) for key, value in self._assocs(KAssoc.RIGHT).items()}) + + def _assocs(self, assoc: KAssoc) -> dict[str, set[str]]: + sents = ( + sent + for module in self.modules + for sent in module.sentences + if isinstance(sent, KSyntaxAssociativity) and sent.assoc in (assoc, KAssoc.NON_ASSOC) + ) + pairs = (pair for sent in sents for pair in product(sent.tags, sent.tags)) + + def insert(dct: dict[str, set[str]], *, key: str, value: str) -> dict[str, set[str]]: + dct.setdefault(key, set()).add(value) + return dct + + return reduce(lambda res, pair: insert(res, key=pair[0], value=pair[1]), pairs, {}) + def sort(self, kast: KInner) -> KSort | None: """Computes the sort of a given term using best-effort simple sorting algorithm, returns `None` on algorithm failure.""" match kast: diff --git a/pyk/src/pyk/konvert/_module_to_kore.py b/pyk/src/pyk/konvert/_module_to_kore.py index 7ea5670a5de..2fb1519425b 100644 --- a/pyk/src/pyk/konvert/_module_to_kore.py +++ b/pyk/src/pyk/konvert/_module_to_kore.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from functools import reduce -from itertools import product, repeat +from itertools import repeat from pathlib import Path from typing import ClassVar # noqa: TC003 from typing import TYPE_CHECKING, NamedTuple, final @@ -13,17 +13,7 @@ from ..kast.att import Format from ..kast.inner import KApply, KRewrite, KSort from ..kast.manip import extract_lhs, extract_rhs -from ..kast.outer import ( - KAssoc, - KDefinition, - KNonTerminal, - KProduction, - KRegexTerminal, - KRule, - KSyntaxAssociativity, - KSyntaxSort, - KTerminal, -) +from ..kast.outer import KDefinition, KNonTerminal, KProduction, KRegexTerminal, KRule, KSyntaxSort, KTerminal from ..kore.prelude import inj from ..kore.syntax import ( And, @@ -1280,10 +1270,11 @@ def update(production: KProduction) -> KProduction: @dataclass -class AddAssocAtts(SingleModulePass): - def _transform_module(self, module: KFlatModule) -> KFlatModule: - left_assocs = self._assocs(module, KAssoc.LEFT) - right_assocs = self._assocs(module, KAssoc.RIGHT) +class AddAssocAtts(KompilerPass): + def execute(self, definition: KDefinition) -> KDefinition: + if len(definition.modules) > 1: + raise ValueError('Expected a single module') + module = definition.modules[0] def update(production: KProduction) -> KProduction: if not production.klabel: @@ -1292,23 +1283,13 @@ def update(production: KProduction) -> KProduction: if Atts.FORMAT not in production.att: return production - left = tuple(KApply(tag).to_dict() for tag in sorted(left_assocs.get(production.klabel.name, []))) - right = tuple(KApply(tag).to_dict() for tag in sorted(right_assocs.get(production.klabel.name, []))) + left = tuple( + KApply(tag).to_dict() for tag in sorted(definition.left_assocs.get(production.klabel.name, [])) + ) + right = tuple( + KApply(tag).to_dict() for tag in sorted(definition.right_assocs.get(production.klabel.name, [])) + ) return production.let(att=production.att.update([Atts.LEFT(left), Atts.RIGHT(right)])) - return module.map_sentences(update, of_type=KProduction) - - @staticmethod - def _assocs(module: KFlatModule, assoc: KAssoc) -> dict[str, set[str]]: - sents = ( - sent - for sent in module.sentences - if isinstance(sent, KSyntaxAssociativity) and sent.assoc in (assoc, KAssoc.NON_ASSOC) - ) - pairs = (pair for sent in sents for pair in product(sent.tags, sent.tags)) - - def insert(dct: dict[str, set[str]], *, key: str, value: str) -> dict[str, set[str]]: - dct.setdefault(key, set()).add(value) - return dct - - return reduce(lambda res, pair: insert(res, key=pair[0], value=pair[1]), pairs, {}) + module = module.map_sentences(update, of_type=KProduction) + return KDefinition(module.name, (module,)) diff --git a/pyk/src/tests/integration/kast/test_formatter.py b/pyk/src/tests/integration/kast/test_formatter.py index 90358684ee7..74faf290abb 100644 --- a/pyk/src/tests/integration/kast/test_formatter.py +++ b/pyk/src/tests/integration/kast/test_formatter.py @@ -18,10 +18,16 @@ from pyk.kast.outer import KDefinition +class FormatterTest(KompiledTest): + @pytest.fixture + def formatter(self, definition: KDefinition) -> Formatter: + return Formatter(definition) + + x, y = (KToken(name, KSort('Id')) for name in ['x', 'y']) -TEST_DATA = ( +FORMATTER_TEST_DATA = ( ( KApply('', KSequence()), """ @@ -89,14 +95,52 @@ ) -class TestFormatter(KompiledTest): +class TestFormatter(FormatterTest): KOMPILE_MAIN_FILE = K_FILES / 'imp.k' - @pytest.fixture - def formatter(self, definition: KDefinition) -> Formatter: - return Formatter(definition) + @pytest.mark.parametrize('term,output', FORMATTER_TEST_DATA, ids=count()) + def test(self, formatter: Formatter, term: KInner, output: str) -> None: + # Given + expected = textwrap.dedent(output).strip() + + # When + actual = formatter(term) + + # Then + assert actual == expected + + +BRACKETS_TEST_DATA = ( + (token(1), '1'), + (KApply('_+_', token(1), token(2)), '1 + 2'), + (KApply('_+_', KApply('_+_', token(1), token(2)), token(3)), '1 + 2 + 3'), + (KApply('_+_', KApply('_-_', token(1), token(2)), token(3)), '1 - 2 + 3'), + (KApply('_+_', token(1), KApply('_+_', token(2), token(3))), '1 + ( 2 + 3 )'), + (KApply('_+_', token(1), KApply('_*_', token(2), token(3))), '1 + 2 * 3'), + (KApply('_+_', KApply('_*_', token(1), token(2)), token(3)), '1 * 2 + 3'), + (KApply('_*_', token(1), KApply('_+_', token(2), token(3))), '1 * ( 2 + 3 )'), + (KApply('_*_', KApply('_+_', token(1), token(2)), token(3)), '( 1 + 2 ) * 3'), + (KApply('sgn(_)', KApply('_+_', token(1), token(2))), 'sgn ( 1 + 2 )'), +) + + +class TestBrackets(FormatterTest): + KOMPILE_DEFINITION = """ + module BRACKETS + imports INT-SYNTAX + syntax Exp ::= Int + | sgn ( Exp ) [symbol(sgn(_))] + > Exp "*" Exp [symbol(_*_), left] + | Exp "/" Exp [symbol(_/_), left] + > Exp "+" Exp [symbol(_+_), left] + | Exp "-" Exp [symbol(_-_), left] + > "(" Exp ")" [bracket] + endmodule + """ + KOMPILE_MAIN_MODULE = 'BRACKETS' + KOMPILE_ARGS = {'syntax_module': KOMPILE_MAIN_MODULE} - @pytest.mark.parametrize('term,output', TEST_DATA, ids=count()) + @pytest.mark.parametrize('term,output', BRACKETS_TEST_DATA, ids=count()) def test(self, formatter: Formatter, term: KInner, output: str) -> None: # Given expected = textwrap.dedent(output).strip()