From 9aa3752440efeb7457f7081d6d2570dd66f7d7d5 Mon Sep 17 00:00:00 2001 From: Petar Maksimovic Date: Tue, 9 Apr 2024 13:05:56 +0100 Subject: [PATCH] refactoring the branching mechanism --- src/pyk/cterm/symbolic.py | 18 +++- src/pyk/kcfg/explore.py | 89 ++++++++----------- .../cterm/test_multiple_definitions.py | 4 +- src/tests/integration/cterm/test_simple.py | 2 +- src/tests/integration/proof/test_imp.py | 2 +- 5 files changed, 54 insertions(+), 61 deletions(-) diff --git a/src/pyk/cterm/symbolic.py b/src/pyk/cterm/symbolic.py index 4565fdc46..25a727cb1 100644 --- a/src/pyk/cterm/symbolic.py +++ b/src/pyk/cterm/symbolic.py @@ -43,7 +43,7 @@ class CTermExecute(NamedTuple): state: CTerm - next_states: tuple[CTerm, ...] + next_states: tuple[tuple[CTerm, KInner | None], ...] depth: int vacuous: bool logs: tuple[LogEntry, ...] @@ -120,11 +120,23 @@ def execute( state = CTerm.from_kast(self.kore_to_kast(response.state.kore)) resp_next_states = response.next_states or () - next_states = tuple(CTerm.from_kast(self.kore_to_kast(ns.kore)) for ns in resp_next_states) + next_states = tuple( + ( + CTerm.from_kast(self.kore_to_kast(ns.kore)), + self.kore_to_kast(ns.rule_predicate) if ns.rule_predicate is not None else None, + ) + for ns in resp_next_states + ) - assert all(not cterm.is_bottom for cterm in next_states) + assert all(not cterm.is_bottom for cterm, _ in next_states) assert len(next_states) != 1 or response.reason is StopReason.CUT_POINT_RULE + for resp_next_state in resp_next_states: + if resp_next_state.rule_predicate is not None: + print(f'Rule predicate: {self.kore_to_kast(resp_next_state.rule_predicate) }') + else: + print('Rule predicate: None') + return CTermExecute( state=state, next_states=next_states, diff --git a/src/pyk/kcfg/explore.py b/src/pyk/kcfg/explore.py index 1df1a6c56..3b06ec198 100644 --- a/src/pyk/kcfg/explore.py +++ b/src/pyk/kcfg/explore.py @@ -4,7 +4,6 @@ from functools import cached_property from typing import TYPE_CHECKING -from ..cterm import CTerm from ..kast.inner import KApply, KVariable from ..kast.manip import ( flatten_label, @@ -16,10 +15,8 @@ ) from ..kast.pretty import PrettyPrinter from ..kore.rpc import RewriteSuccess -from ..prelude.kbool import notBool -from ..prelude.kint import leInt, ltInt -from ..prelude.ml import is_top, mlAnd, mlEqualsFalse, mlEqualsTrue, mlNot -from ..utils import shorten_hashes, single +from ..prelude.ml import is_top +from ..utils import not_none, shorten_hashes, single from .kcfg import KCFG, Abstract, Branch, NDBranch, Step, Stuck, Vacuous from .semantics import DefaultSemantics @@ -27,7 +24,7 @@ from collections.abc import Iterable from typing import Final - from ..cterm import CTermSymbolic + from ..cterm import CTerm, CTermSymbolic from ..kast import KInner from ..kcfg.exploration import KCFGExploration from ..kore.rpc import LogEntry @@ -225,19 +222,19 @@ def extract_rule_labels(_logs: tuple[LogEntry, ...]) -> list[str]: log(f'abstraction node: {node_id}') return Abstract(abstract_cterm) - _branches = self.kcfg_semantics.extract_branches(_cterm) - branches = [] - for constraint in _branches: - kast = mlAnd(list(_cterm.constraints) + [constraint]) - kast, _ = self.cterm_symbolic.kast_simplify(kast) - if not CTerm._is_bottom(kast): - branches.append(constraint) - if len(branches) > 1: - constraint_strs = [self.pretty_print(bc) for bc in branches] - log(f'{len(branches)} branches using heuristics: {node_id} -> {constraint_strs}') - return Branch(branches, heuristic=True) - - cterm, next_cterms, depth, vacuous, next_node_logs = self.cterm_symbolic.execute( + # _branches = self.kcfg_semantics.extract_branches(_cterm) + # branches = [] + # for constraint in _branches: + # kast = mlAnd(list(_cterm.constraints) + [constraint]) + # kast, _ = self.cterm_symbolic.kast_simplify(kast) + # if not CTerm._is_bottom(kast): + # branches.append(constraint) + # if len(branches) > 1: + # constraint_strs = [self.pretty_print(bc) for bc in branches] + # log(f'{len(branches)} branches using heuristics: {node_id} -> {constraint_strs}') + # return Branch(branches, heuristic=True) + + cterm, next_cterms_with_branch_constraints, depth, vacuous, next_node_logs = self.cterm_symbolic.execute( _cterm, depth=execute_depth, cut_point_rules=cut_point_rules, @@ -251,7 +248,7 @@ def extract_rule_labels(_logs: tuple[LogEntry, ...]) -> list[str]: return Step(cterm, depth, next_node_logs, extract_rule_labels(next_node_logs)) # Stuck or vacuous - if not next_cterms: + if not next_cterms_with_branch_constraints: if vacuous: log(f'vacuous node: {node_id}', warning=True) return Vacuous() @@ -259,41 +256,25 @@ def extract_rule_labels(_logs: tuple[LogEntry, ...]) -> list[str]: return Stuck() # Cut rule - if len(next_cterms) == 1: + if len(next_cterms_with_branch_constraints) == 1: log(f'cut-rule basic block at depth {depth}: {node_id}') - return Step(next_cterms[0], 1, next_node_logs, extract_rule_labels(next_node_logs), cut=True) + return Step( + next_cterms_with_branch_constraints[0][0], + 1, + next_node_logs, + extract_rule_labels(next_node_logs), + cut=True, + ) # Branch - assert len(next_cterms) > 1 - branches = [mlAnd(c for c in s.constraints if c not in cterm.constraints) for s in next_cterms] - branch_and = mlAnd(branches) - branch_patterns = [ - mlAnd([mlEqualsTrue(KVariable('B')), mlEqualsTrue(notBool(KVariable('B')))]), - mlAnd([mlEqualsTrue(notBool(KVariable('B'))), mlEqualsTrue(KVariable('B'))]), - mlAnd([mlEqualsTrue(KVariable('B')), mlEqualsFalse(KVariable('B'))]), - mlAnd([mlEqualsFalse(KVariable('B')), mlEqualsTrue(KVariable('B'))]), - mlAnd([mlNot(KVariable('B')), KVariable('B')]), - mlAnd([KVariable('B'), mlNot(KVariable('B'))]), - mlAnd( - [ - mlEqualsTrue(ltInt(KVariable('I1'), KVariable('I2'))), - mlEqualsTrue(leInt(KVariable('I2'), KVariable('I1'))), - ] - ), - mlAnd( - [ - mlEqualsTrue(leInt(KVariable('I1'), KVariable('I2'))), - mlEqualsTrue(ltInt(KVariable('I2'), KVariable('I1'))), - ] - ), - ] - - # Split on branch patterns - if any(branch_pattern.match(branch_and) for branch_pattern in branch_patterns): - constraint_strs = [self.pretty_print(bc) for bc in branches] - log(f'{len(branches)} branches using heuristics: {node_id} -> {constraint_strs}') + assert len(next_cterms_with_branch_constraints) > 1 + if all(branch_constraint for _, branch_constraint in next_cterms_with_branch_constraints): + branches = [not_none(rule_predicate) for _, rule_predicate in next_cterms_with_branch_constraints] + constraint_strs = [self.pretty_print(ml_pred_to_bool(bc)) for bc in branches] + log(f'{len(branches)} branches: {node_id} -> {constraint_strs}') return Branch(branches) - - # NDBranch on successor nodes - log(f'{len(next_cterms)} non-deterministic branches: {node_id}') - return NDBranch(next_cterms, next_node_logs, extract_rule_labels(next_node_logs)) + else: + # NDBranch + log(f'{len(next_cterms_with_branch_constraints)} non-deterministic branches: {node_id}') + next_cterms = [cterm for cterm, _ in next_cterms_with_branch_constraints] + return NDBranch(next_cterms, next_node_logs, extract_rule_labels(next_node_logs)) diff --git a/src/tests/integration/cterm/test_multiple_definitions.py b/src/tests/integration/cterm/test_multiple_definitions.py index f554ae81a..58dc8e0fc 100644 --- a/src/tests/integration/cterm/test_multiple_definitions.py +++ b/src/tests/integration/cterm/test_multiple_definitions.py @@ -81,10 +81,10 @@ def test_execute( 'a ( X:KItem ) ~> .K', ] == split_next_k - step_1_res = cterm_symbolic.execute(split_next_terms[0], depth=1) + step_1_res = cterm_symbolic.execute(split_next_terms[0][0], depth=1) step_1_k = kprint.pretty_print(step_1_res.state.cell('K_CELL')) assert 'c ~> .K' == step_1_k - step_2_res = cterm_symbolic.execute(split_next_terms[1], depth=1) + step_2_res = cterm_symbolic.execute(split_next_terms[1][0], depth=1) step_2_k = kprint.pretty_print(step_2_res.state.cell('K_CELL')) assert 'c ~> .K' == step_2_k diff --git a/src/tests/integration/cterm/test_simple.py b/src/tests/integration/cterm/test_simple.py index e1dff0b3d..e01d9ca37 100644 --- a/src/tests/integration/cterm/test_simple.py +++ b/src/tests/integration/cterm/test_simple.py @@ -87,7 +87,7 @@ def test_execute( kprint.pretty_print(s.cell('K_CELL')), kprint.pretty_print(s.cell('STATE_CELL')), ) - for s in exec_res.next_states + for s, _ in exec_res.next_states ] # Then diff --git a/src/tests/integration/proof/test_imp.py b/src/tests/integration/proof/test_imp.py index ce123d93b..4d0556795 100644 --- a/src/tests/integration/proof/test_imp.py +++ b/src/tests/integration/proof/test_imp.py @@ -786,7 +786,7 @@ def test_execute( kcfg_explore.pretty_print(s.cell('K_CELL')), kcfg_explore.pretty_print(s.cell('STATE_CELL')), ) - for s in exec_res.next_states + for s, _ in exec_res.next_states ] # Then