Skip to content
This repository has been archived by the owner on Apr 25, 2024. It is now read-only.

Commit

Permalink
Add separate get proof steps and execute proof step phase
Browse files Browse the repository at this point in the history
  • Loading branch information
nwatson22 committed Apr 5, 2024
1 parent 7ee5e70 commit 6c7eb22
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 44 deletions.
34 changes: 22 additions & 12 deletions src/pyk/proof/implies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..prelude.kbool import BOOL, FALSE, TRUE
from ..prelude.ml import is_bottom, is_top, mlAnd, mlEquals, mlEqualsFalse, mlEqualsTrue
from ..utils import ensure_dir_path
from .proof import FailureInfo, Proof, ProofStatus, ProofSummary, Prover, StepResult
from .proof import FailureInfo, Proof, ProofStatus, ProofStep, ProofSummary, Prover, StepResult

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
Expand All @@ -26,6 +26,11 @@
_LOGGER: Final = logging.getLogger(__name__)


@dataclass
class ImpliesProofStep(ProofStep):
proof: ImpliesProof


class ImpliesProof(Proof):
antecedent: KInner
consequent: KInner
Expand Down Expand Up @@ -55,6 +60,9 @@ def __init__(
self.simplified_consequent = simplified_consequent
self.csubst = csubst

def get_steps(self) -> Iterable[ProofStep]:
return [ImpliesProofStep(self)]

def commit(self, result: StepResult) -> None:
proof_type = type(self).__name__
if isinstance(result, ImpliesProofResult):
Expand Down Expand Up @@ -404,29 +412,31 @@ def __init__(self, proof: ImpliesProof, kcfg_explore: KCFGExplore):
super().__init__(kcfg_explore)
self.proof = proof

def step_proof(self) -> Iterable[StepResult]:
proof_type = type(self.proof).__name__
_LOGGER.info(f'Attempting {proof_type} {self.proof.id}')
def step_proof(self, step: ProofStep) -> Iterable[StepResult]:
assert isinstance(step, ImpliesProofStep)

proof_type = type(step.proof).__name__
_LOGGER.info(f'Attempting {proof_type} {step.proof.id}')

if self.proof.status is not ProofStatus.PENDING:
_LOGGER.info(f'{proof_type} finished {self.proof.id}: {self.proof.status}')
if step.proof.status is not ProofStatus.PENDING:
_LOGGER.info(f'{proof_type} finished {step.proof.id}: {step.proof.status}')
return []

# to prove the equality, we check the implication of the form `constraints #Implies LHS #Equals RHS`, i.e.
# "LHS equals RHS under these constraints"
simplified_antecedent, _ = self.kcfg_explore.cterm_symbolic.kast_simplify(self.proof.antecedent)
simplified_consequent, _ = self.kcfg_explore.cterm_symbolic.kast_simplify(self.proof.consequent)
simplified_antecedent, _ = self.kcfg_explore.cterm_symbolic.kast_simplify(step.proof.antecedent)
simplified_consequent, _ = self.kcfg_explore.cterm_symbolic.kast_simplify(step.proof.consequent)
_LOGGER.info(f'Simplified antecedent: {self.kcfg_explore.pretty_print(simplified_antecedent)}')
_LOGGER.info(f'Simplified consequent: {self.kcfg_explore.pretty_print(simplified_consequent)}')

csubst: CSubst | None = None

if is_bottom(simplified_antecedent):
_LOGGER.warning(f'Antecedent of implication (proof constraints) simplifies to #Bottom {self.proof.id}')
_LOGGER.warning(f'Antecedent of implication (proof constraints) simplifies to #Bottom {step.proof.id}')
csubst = CSubst(Subst({}), ())

elif is_top(simplified_consequent):
_LOGGER.warning(f'Consequent of implication (proof equality) simplifies to #Top {self.proof.id}')
_LOGGER.warning(f'Consequent of implication (proof equality) simplifies to #Top {step.proof.id}')
csubst = CSubst(Subst({}), ())

else:
Expand All @@ -435,13 +445,13 @@ def step_proof(self) -> Iterable[StepResult]:
_result = self.kcfg_explore.cterm_symbolic.implies(
antecedent=CTerm(config=dummy_config, constraints=[simplified_antecedent]),
consequent=CTerm(config=dummy_config, constraints=[simplified_consequent]),
bind_universally=self.proof.bind_universally,
bind_universally=step.proof.bind_universally,
)
result = _result.csubst
if result is not None:
csubst = result

_LOGGER.info(f'{proof_type} finished {self.proof.id}: {self.proof.status}')
_LOGGER.info(f'{proof_type} finished {step.proof.id}: {step.proof.status}')
return [
ImpliesProofResult(
csubst=csubst, simplified_antecedent=simplified_antecedent, simplified_consequent=simplified_consequent
Expand Down
34 changes: 21 additions & 13 deletions src/pyk/proof/proof.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ def lines(self) -> list[str]:
subproofs_summaries = [subproof.summary for subproof in self.subproofs]
return CompositeSummary([BaseSummary(self.id, self.status), *subproofs_summaries])

@abstractmethod
def get_steps(self) -> Iterable[ProofStep]: ...


class ProofSummary(ABC):
id: str
Expand Down Expand Up @@ -296,6 +299,9 @@ def lines(self) -> list[str]:
return [line for lines in (summary.lines for summary in self.summaries) for line in lines]


class ProofStep: ...


class StepResult: ...


Expand All @@ -313,21 +319,23 @@ def __init__(self, kcfg_explore: KCFGExplore):
def failure_info(self) -> FailureInfo: ...

@abstractmethod
def step_proof(self) -> Iterable[StepResult]: ...
def step_proof(self, step: ProofStep) -> Iterable[StepResult]: ...

def advance_proof(self, max_iterations: int | None = None, fail_fast: bool = False) -> None:
iterations = 0
while self.proof.can_progress:
if fail_fast and self.proof.failed:
_LOGGER.warning(f'Terminating proof early because fail_fast is set: {self.proof.id}')
self.proof.failure_info = self.failure_info()
return
if max_iterations is not None and max_iterations <= iterations:
return
iterations += 1
results = self.step_proof()
for result in results:
self.proof.commit(result)
self.proof.write_proof_data()
while True:
steps = self.proof.get_steps()
for step in steps:
if fail_fast and self.proof.failed:
_LOGGER.warning(f'Terminating proof early because fail_fast is set: {self.proof.id}')
self.proof.failure_info = self.failure_info()
return
if max_iterations is not None and max_iterations <= iterations:
return
iterations += 1
results = self.step_proof(step)
for result in results:
self.proof.commit(result)
self.proof.write_proof_data()
if self.proof.failed:
self.proof.failure_info = self.failure_info()
47 changes: 28 additions & 19 deletions src/pyk/proof/reachability.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ..prelude.ml import mlAnd, mlTop
from ..utils import FrozenDict, ensure_dir_path, hash_str, shorten_hashes, single
from .implies import ProofSummary, Prover, RefutationProof
from .proof import CompositeSummary, FailureInfo, Proof, ProofStatus, StepResult
from .proof import CompositeSummary, FailureInfo, Proof, ProofStatus, ProofStep, StepResult

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
Expand Down Expand Up @@ -59,6 +59,12 @@ class APRProofTerminalResult(APRProofResult): ...
class APRProofBoundedResult(APRProofResult): ...


@dataclass
class APRProofStep(ProofStep):
proof: APRProof
node_id: int


class APRProof(Proof, KCFGExploration):
"""APRProof and APRProver implement all-path reachability logic,
as introduced by A. Stefanescu and others in their paper 'All-Path Reachability Logic':
Expand Down Expand Up @@ -132,6 +138,9 @@ def __init__(
assert type(subproof) is RefutationProof
self.node_refutations[node_id] = subproof

def get_steps(self) -> Iterable[APRProofStep]:
return [APRProofStep(self, node.id) for node in self.pending]

def commit(self, result: StepResult) -> None:
if isinstance(result, APRProofExtendResult):
self.kcfg.extend(result.extend_result, self.kcfg.node(result.node_id), logs=self.logs)
Expand Down Expand Up @@ -720,37 +729,37 @@ def _check_subsume(self, node: KCFG.Node) -> CSubst | None:
_LOGGER.info(f'Subsumed into target node {self.proof.id}: {shorten_hashes((node.id, self.proof.target))}')
return csubst

def step_proof(self) -> Iterable[StepResult]:
if not self.proof.pending:
return []
curr_node = self.proof.pending[0]
def step_proof(self, step: ProofStep) -> Iterable[StepResult]:
assert isinstance(step, APRProofStep)

curr_node = step.proof.kcfg.node(step.node_id)

if self.proof.bmc_depth is not None and curr_node.id not in self._checked_for_bounded:
_LOGGER.info(f'Checking bmc depth for node {self.proof.id}: {curr_node.id}')
if step.proof.bmc_depth is not None and curr_node.id not in self._checked_for_bounded:
_LOGGER.info(f'Checking bmc depth for node {step.proof.id}: {curr_node.id}')
self._checked_for_bounded.add(curr_node.id)

prior_loops = []
for succ in reversed(self.proof.shortest_path_to(curr_node.id)):
for succ in reversed(step.proof.shortest_path_to(curr_node.id)):
if self.kcfg_explore.kcfg_semantics.same_loop(succ.source.cterm, curr_node.cterm):
if succ.source.id in self.proof.prior_loops_cache:
if self.proof.kcfg.zero_depth_between(succ.source.id, curr_node.id):
prior_loops = self.proof.prior_loops_cache[succ.source.id]
if succ.source.id in step.proof.prior_loops_cache:
if step.proof.kcfg.zero_depth_between(succ.source.id, curr_node.id):
prior_loops = step.proof.prior_loops_cache[succ.source.id]
else:
prior_loops = self.proof.prior_loops_cache[succ.source.id] + [succ.source.id]
prior_loops = step.proof.prior_loops_cache[succ.source.id] + [succ.source.id]
break
else:
self.proof.prior_loops_cache[succ.source.id] = []
step.proof.prior_loops_cache[succ.source.id] = []

self.proof.prior_loops_cache[curr_node.id] = prior_loops
step.proof.prior_loops_cache[curr_node.id] = prior_loops

_LOGGER.info(f'Prior loop heads for node {self.proof.id}: {(curr_node.id, prior_loops)}')
if len(prior_loops) > self.proof.bmc_depth:
_LOGGER.warning(f'Bounded node {self.proof.id}: {curr_node.id} at bmc depth {self.proof.bmc_depth}')
_LOGGER.info(f'Prior loop heads for node {step.proof.id}: {(curr_node.id, prior_loops)}')
if len(prior_loops) > step.proof.bmc_depth:
_LOGGER.warning(f'Bounded node {step.proof.id}: {curr_node.id} at bmc depth {step.proof.bmc_depth}')
return [APRProofBoundedResult(curr_node.id)]

# Terminal checks for current node and target node
is_terminal = self.kcfg_explore.kcfg_semantics.is_terminal(curr_node.cterm)
target_is_terminal = self.proof.is_terminal(self.proof.target)
target_is_terminal = step.proof.is_terminal(step.proof.target)

terminal_result = [APRProofTerminalResult(node_id=curr_node.id)] if is_terminal else []

Expand All @@ -768,7 +777,7 @@ def step_proof(self) -> Iterable[StepResult]:

module_name = self.circularities_module_name if self.nonzero_depth(curr_node) else self.dependencies_module_name

self.kcfg_explore.check_extendable(self.proof, curr_node)
self.kcfg_explore.check_extendable(step.proof, curr_node)
extend_result = self.kcfg_explore.extend_cterm(
curr_node.cterm,
execute_depth=self.execute_depth,
Expand Down

0 comments on commit 6c7eb22

Please sign in to comment.