Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring the branching mechanism #4248

Merged
merged 9 commits into from
Apr 24, 2024
52 changes: 46 additions & 6 deletions pyk/regression-new/kprove-haskell/sum-spec.k.out
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
APRProof: 3b7070ca4603e18d26d032c0fe64dd71ffadc2c63dae9d57315e8e471a371bd3
status: ProofStatus.PASSED
admitted: False
nodes: 5
nodes: 7
pending: 0
failing: 0
vacuous: 0
Expand All @@ -12,7 +12,7 @@ APRProof: 3b7070ca4603e18d26d032c0fe64dd71ffadc2c63dae9d57315e8e471a371bd3
execution time: 0s
Subproofs: 0

┌─ 1 (root, init)
┌─ 1 (root, split, init)
│ <generatedTop>
│ <k>
│ addCounter ( N:Int )
Expand All @@ -30,12 +30,32 @@ Subproofs: 0
│ </generatedTop>
│ #And { true #Equals N:Int >=Int 0 }
┃ (1 step)
┣━━┓
┃ (branch)
┣━━┓ constraint: N:Int >Int 0
┃ │
┃ ├─ 3
┃ │ <generatedTop>
┃ │ <k>
┃ │ addCounter ( N:Int )
┃ │ ~> K_CELL_fc656f08:K
┃ │ </k>
┃ │ <counter>
┃ │ C:Int
┃ │ </counter>
┃ │ <sum>
┃ │ S:Int
┃ │ </sum>
┃ │ <generatedCounter>
┃ │ GENERATEDCOUNTER_CELL_949ec677:Int
┃ │ </generatedCounter>
┃ │ </generatedTop>
┃ │ #And { true #Equals N:Int >=Int 0 }
┃ │ #And { true #Equals N:Int >Int 0 }
┃ │
┃ │ (1 step)
┃ ├─ 5
┃ │ <generatedTop>
┃ │ <k>
┃ │ addCounter ( N:Int +Int -1 )
┃ │ ~> K_CELL_fc656f08:K
┃ │ </k>
Expand All @@ -53,7 +73,7 @@ Subproofs: 0
┃ │ #And { true #Equals N:Int >=Int 0 }
┃ │
┃ │ (1 step)
┃ ├─ 5
┃ ├─ 7
┃ │ <generatedTop>
┃ │ <k>
┃ │ K_CELL_fc656f08:K
Expand Down Expand Up @@ -92,11 +112,31 @@ Subproofs: 0
┃ #And { true #Equals N:Int >=Int 0 }
┃ #And { true #Equals ?S:Int ==Int S:Int +Int N:Int *Int C:Int +Int N:Int -Int 1 *Int N:Int /Int 2 }
┗━━┓
┗━━┓ constraint: N:Int ==K 0
├─ 4
│ <generatedTop>
│ <k>
│ addCounter ( N:Int )
│ ~> K_CELL_fc656f08:K
│ </k>
│ <counter>
│ C:Int
│ </counter>
│ <sum>
│ S:Int
│ </sum>
│ <generatedCounter>
│ GENERATEDCOUNTER_CELL_949ec677:Int
│ </generatedCounter>
│ </generatedTop>
│ #And { true #Equals N:Int >=Int 0 }
│ #And { N:Int #Equals 0 }
│ (1 step)
├─ 6
│ <generatedTop>
│ <k>
│ K_CELL_fc656f08:K
│ </k>
│ <counter>
Expand Down
17 changes: 14 additions & 3 deletions pyk/src/pyk/cterm/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,14 @@
_LOGGER: Final = logging.getLogger(__name__)


class NextState(NamedTuple):
state: CTerm
condition: KInner | None


class CTermExecute(NamedTuple):
state: CTerm
next_states: tuple[CTerm, ...]
next_states: tuple[NextState, ...]
depth: int
vacuous: bool
logs: tuple[LogEntry, ...]
Expand Down Expand Up @@ -120,9 +125,15 @@ def execute(

state = CTerm.from_kast(self.kore_to_kast(response.state.kore))
resp_next_states = response.next_states or ()
next_states = tuple(CTerm.from_kast(self.kore_to_kast(ns.kore)) for ns in resp_next_states)
next_states = tuple(
NextState(
CTerm.from_kast(self.kore_to_kast(ns.kore)),
self.kore_to_kast(ns.rule_predicate) if ns.rule_predicate is not None else None,
)
for ns in resp_next_states
)

assert all(not cterm.is_bottom for cterm in next_states)
assert all(not cterm.is_bottom for cterm, _ in next_states)
PetarMax marked this conversation as resolved.
Show resolved Hide resolved
assert len(next_states) != 1 or response.reason is StopReason.CUT_POINT_RULE

return CTermExecute(
Expand Down
77 changes: 23 additions & 54 deletions pyk/src/pyk/kcfg/explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from functools import cached_property
from typing import TYPE_CHECKING

from ..cterm import CTerm
from ..kast.inner import KApply, KVariable
from ..kast.manip import (
flatten_label,
Expand All @@ -16,18 +15,16 @@
)
from ..kast.pretty import PrettyPrinter
from ..kore.rpc import RewriteSuccess
from ..prelude.kbool import notBool
from ..prelude.kint import leInt, ltInt
from ..prelude.ml import is_top, mlAnd, mlEqualsFalse, mlEqualsTrue, mlNot
from ..utils import shorten_hashes, single
from ..prelude.ml import is_top
from ..utils import not_none, shorten_hashes, single
from .kcfg import KCFG, Abstract, Branch, NDBranch, Step, Stuck, Vacuous
from .semantics import DefaultSemantics

if TYPE_CHECKING:
from collections.abc import Iterable
from typing import Final

from ..cterm import CTermSymbolic
from ..cterm import CTerm, CTermSymbolic
from ..kast import KInner
from ..kcfg.exploration import KCFGExploration
from ..kore.rpc import LogEntry
Expand Down Expand Up @@ -225,19 +222,7 @@ def extract_rule_labels(_logs: tuple[LogEntry, ...]) -> list[str]:
log(f'abstraction node: {node_id}')
return Abstract(abstract_cterm)

_branches = self.kcfg_semantics.extract_branches(_cterm)
branches = []
for constraint in _branches:
kast = mlAnd(list(_cterm.constraints) + [constraint])
kast, _ = self.cterm_symbolic.kast_simplify(kast)
if not CTerm._is_bottom(kast):
branches.append(constraint)
if len(branches) > 1:
constraint_strs = [self.pretty_print(bc) for bc in branches]
log(f'{len(branches)} branches using heuristics: {node_id} -> {constraint_strs}')
return Branch(branches, heuristic=True)

cterm, next_cterms, depth, vacuous, next_node_logs = self.cterm_symbolic.execute(
cterm, next_states, depth, vacuous, next_node_logs = self.cterm_symbolic.execute(
_cterm,
depth=execute_depth,
cut_point_rules=cut_point_rules,
Expand All @@ -251,49 +236,33 @@ def extract_rule_labels(_logs: tuple[LogEntry, ...]) -> list[str]:
return Step(cterm, depth, next_node_logs, extract_rule_labels(next_node_logs))

# Stuck or vacuous
if not next_cterms:
if not next_states:
if vacuous:
log(f'vacuous node: {node_id}', warning=True)
return Vacuous()
log(f'stuck node: {node_id}')
return Stuck()

# Cut rule
if len(next_cterms) == 1:
if len(next_states) == 1:
log(f'cut-rule basic block at depth {depth}: {node_id}')
return Step(next_cterms[0], 1, next_node_logs, extract_rule_labels(next_node_logs), cut=True)
return Step(
next_states[0].state,
1,
next_node_logs,
extract_rule_labels(next_node_logs),
cut=True,
)

# Branch
assert len(next_cterms) > 1
branches = [mlAnd(c for c in s.constraints if c not in cterm.constraints) for s in next_cterms]
branch_and = mlAnd(branches)
branch_patterns = [
mlAnd([mlEqualsTrue(KVariable('B')), mlEqualsTrue(notBool(KVariable('B')))]),
mlAnd([mlEqualsTrue(notBool(KVariable('B'))), mlEqualsTrue(KVariable('B'))]),
mlAnd([mlEqualsTrue(KVariable('B')), mlEqualsFalse(KVariable('B'))]),
mlAnd([mlEqualsFalse(KVariable('B')), mlEqualsTrue(KVariable('B'))]),
mlAnd([mlNot(KVariable('B')), KVariable('B')]),
mlAnd([KVariable('B'), mlNot(KVariable('B'))]),
mlAnd(
[
mlEqualsTrue(ltInt(KVariable('I1'), KVariable('I2'))),
mlEqualsTrue(leInt(KVariable('I2'), KVariable('I1'))),
]
),
mlAnd(
[
mlEqualsTrue(leInt(KVariable('I1'), KVariable('I2'))),
mlEqualsTrue(ltInt(KVariable('I2'), KVariable('I1'))),
]
),
]

# Split on branch patterns
if any(branch_pattern.match(branch_and) for branch_pattern in branch_patterns):
constraint_strs = [self.pretty_print(bc) for bc in branches]
log(f'{len(branches)} branches using heuristics: {node_id} -> {constraint_strs}')
assert len(next_states) > 1
if all(branch_constraint for _, branch_constraint in next_states):
branches = [not_none(rule_predicate) for _, rule_predicate in next_states]
constraint_strs = [self.pretty_print(ml_pred_to_bool(bc)) for bc in branches]
log(f'{len(branches)} branches: {node_id} -> {constraint_strs}')
return Branch(branches)

# NDBranch on successor nodes
log(f'{len(next_cterms)} non-deterministic branches: {node_id}')
return NDBranch(next_cterms, next_node_logs, extract_rule_labels(next_node_logs))
else:
# NDBranch
log(f'{len(next_states)} non-deterministic branches: {node_id}')
next_cterms = [cterm for cterm, _ in next_states]
return NDBranch(next_cterms, next_node_logs, extract_rule_labels(next_node_logs))
7 changes: 0 additions & 7 deletions pyk/src/pyk/kcfg/semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,12 @@

if TYPE_CHECKING:
from ..cterm import CTerm
from ..kast.inner import KInner


class KCFGSemantics(ABC):
@abstractmethod
def is_terminal(self, c: CTerm) -> bool: ...

@abstractmethod
def extract_branches(self, c: CTerm) -> list[KInner]: ...

@abstractmethod
def abstract_node(self, c: CTerm) -> CTerm: ...

Expand All @@ -26,9 +22,6 @@ class DefaultSemantics(KCFGSemantics):
def is_terminal(self, c: CTerm) -> bool:
return False

def extract_branches(self, c: CTerm) -> list[KInner]:
return []

def abstract_node(self, c: CTerm) -> CTerm:
return c

Expand Down
4 changes: 2 additions & 2 deletions pyk/src/tests/integration/cterm/test_multiple_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def test_execute(
'a ( X:KItem ) ~> .K',
] == split_next_k

step_1_res = cterm_symbolic.execute(split_next_terms[0], depth=1)
step_1_res = cterm_symbolic.execute(split_next_terms[0].state, depth=1)
step_1_k = kprint.pretty_print(step_1_res.state.cell('K_CELL'))
assert 'c ~> .K' == step_1_k

step_2_res = cterm_symbolic.execute(split_next_terms[1], depth=1)
step_2_res = cterm_symbolic.execute(split_next_terms[1].state, depth=1)
step_2_k = kprint.pretty_print(step_2_res.state.cell('K_CELL'))
assert 'c ~> .K' == step_2_k
2 changes: 1 addition & 1 deletion pyk/src/tests/integration/cterm/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_execute(
kprint.pretty_print(s.cell('K_CELL')),
kprint.pretty_print(s.cell('STATE_CELL')),
)
for s in exec_res.next_states
for s, _ in exec_res.next_states
]

# Then
Expand Down
18 changes: 0 additions & 18 deletions pyk/src/tests/integration/ktool/test_imp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from pyk.cli.pyk import ProveOptions
from pyk.kast.inner import KApply, KSequence, KVariable
from pyk.kcfg.semantics import KCFGSemantics
from pyk.prelude.kbool import BOOL, notBool
from pyk.prelude.ml import mlEqualsTrue
from pyk.proof import ProofStatus
from pyk.testing import KProveTest
from pyk.utils import single
Expand All @@ -22,7 +20,6 @@
from typing import Final

from pyk.cterm import CTerm
from pyk.kast.inner import KInner
from pyk.kast.outer import KDefinition
from pyk.ktool.kprove import KProve

Expand All @@ -47,21 +44,6 @@ def is_terminal(self, c: CTerm) -> bool:
return True
return False

def extract_branches(self, c: CTerm) -> list[KInner]:
if self.definition is None:
raise ValueError('IMP branch extraction requires a non-None definition')

k_cell = c.cell('K_CELL')
if type(k_cell) is KSequence and len(k_cell) > 0:
k_cell = k_cell[0]
if type(k_cell) is KApply and k_cell.label.name == 'if(_)_else_':
condition = k_cell.args[0]
if (type(condition) is KVariable and condition.sort == BOOL) or (
type(condition) is KApply and self.definition.return_sort(condition.label) == BOOL
):
return [mlEqualsTrue(condition), mlEqualsTrue(notBool(condition))]
return []

def abstract_node(self, c: CTerm) -> CTerm:
return c

Expand Down
17 changes: 1 addition & 16 deletions pyk/src/tests/integration/proof/test_goto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@

import pytest

from pyk.kast.inner import KApply, KSequence, KVariable
from pyk.kast.inner import KApply, KSequence
from pyk.kcfg.semantics import KCFGSemantics
from pyk.kcfg.show import KCFGShow
from pyk.prelude.ml import mlEqualsTrue
from pyk.prelude.utils import token
from pyk.proof import APRProof, APRProver, ProofStatus
from pyk.proof.show import APRProofNodePrinter
from pyk.testing import KCFGExploreTest, KProveTest
Expand All @@ -23,7 +21,6 @@
from typing import Final

from pyk.cterm import CTerm
from pyk.kast.inner import KInner
from pyk.kast.outer import KDefinition
from pyk.kcfg import KCFGExplore
from pyk.ktool.kprove import KProve
Expand All @@ -35,18 +32,6 @@ class GotoSemantics(KCFGSemantics):
def is_terminal(self, c: CTerm) -> bool:
return False

def extract_branches(self, c: CTerm) -> list[KInner]:
k_cell_pattern = KSequence([KApply('jumpi', [KVariable('JD')])])
stack_cell_pattern = KApply('ws', [KVariable('S'), KVariable('SS')])
k_cell_match = k_cell_pattern.match(c.cell('K_CELL'))
stack_cell_match = stack_cell_pattern.match(c.cell('STACK_CELL'))
if k_cell_match is not None and stack_cell_match is not None:
return [
mlEqualsTrue(KApply('_==Int_', [token(0), stack_cell_match['S']])),
mlEqualsTrue(KApply('_=/=Int_', [token(0), stack_cell_match['S']])),
]
return []

def abstract_node(self, c: CTerm) -> CTerm:
return c

Expand Down
Loading
Loading