diff --git a/src/pyk/proof/implies.py b/src/pyk/proof/implies.py index 6932b421c..8a1c98071 100644 --- a/src/pyk/proof/implies.py +++ b/src/pyk/proof/implies.py @@ -12,7 +12,7 @@ from ..prelude.kbool import BOOL, FALSE, TRUE from ..prelude.ml import is_bottom, is_top, mlAnd, mlEquals, mlEqualsFalse, mlEqualsTrue from ..utils import ensure_dir_path -from .proof import FailureInfo, Proof, ProofStatus, ProofSummary, Prover, StepResult +from .proof import FailureInfo, Proof, ProofStatus, ProofStep, ProofSummary, Prover, StepResult if TYPE_CHECKING: from collections.abc import Iterable, Mapping @@ -26,6 +26,11 @@ _LOGGER: Final = logging.getLogger(__name__) +@dataclass +class ImpliesProofStep(ProofStep): + proof: ImpliesProof + + class ImpliesProof(Proof): antecedent: KInner consequent: KInner @@ -55,6 +60,9 @@ def __init__( self.simplified_consequent = simplified_consequent self.csubst = csubst + def get_steps(self) -> Iterable[ProofStep]: + return [ImpliesProofStep(self)] + def commit(self, result: StepResult) -> None: proof_type = type(self).__name__ if isinstance(result, ImpliesProofResult): @@ -404,29 +412,31 @@ def __init__(self, proof: ImpliesProof, kcfg_explore: KCFGExplore): super().__init__(kcfg_explore) self.proof = proof - def step_proof(self) -> Iterable[StepResult]: - proof_type = type(self.proof).__name__ - _LOGGER.info(f'Attempting {proof_type} {self.proof.id}') + def step_proof(self, step: ProofStep) -> Iterable[StepResult]: + assert isinstance(step, ImpliesProofStep) + + proof_type = type(step.proof).__name__ + _LOGGER.info(f'Attempting {proof_type} {step.proof.id}') - if self.proof.status is not ProofStatus.PENDING: - _LOGGER.info(f'{proof_type} finished {self.proof.id}: {self.proof.status}') + if step.proof.status is not ProofStatus.PENDING: + _LOGGER.info(f'{proof_type} finished {step.proof.id}: {step.proof.status}') return [] # to prove the equality, we check the implication of the form `constraints #Implies LHS #Equals RHS`, i.e. # "LHS equals RHS under these constraints" - simplified_antecedent, _ = self.kcfg_explore.cterm_symbolic.kast_simplify(self.proof.antecedent) - simplified_consequent, _ = self.kcfg_explore.cterm_symbolic.kast_simplify(self.proof.consequent) + simplified_antecedent, _ = self.kcfg_explore.cterm_symbolic.kast_simplify(step.proof.antecedent) + simplified_consequent, _ = self.kcfg_explore.cterm_symbolic.kast_simplify(step.proof.consequent) _LOGGER.info(f'Simplified antecedent: {self.kcfg_explore.pretty_print(simplified_antecedent)}') _LOGGER.info(f'Simplified consequent: {self.kcfg_explore.pretty_print(simplified_consequent)}') csubst: CSubst | None = None if is_bottom(simplified_antecedent): - _LOGGER.warning(f'Antecedent of implication (proof constraints) simplifies to #Bottom {self.proof.id}') + _LOGGER.warning(f'Antecedent of implication (proof constraints) simplifies to #Bottom {step.proof.id}') csubst = CSubst(Subst({}), ()) elif is_top(simplified_consequent): - _LOGGER.warning(f'Consequent of implication (proof equality) simplifies to #Top {self.proof.id}') + _LOGGER.warning(f'Consequent of implication (proof equality) simplifies to #Top {step.proof.id}') csubst = CSubst(Subst({}), ()) else: @@ -435,13 +445,13 @@ def step_proof(self) -> Iterable[StepResult]: _result = self.kcfg_explore.cterm_symbolic.implies( antecedent=CTerm(config=dummy_config, constraints=[simplified_antecedent]), consequent=CTerm(config=dummy_config, constraints=[simplified_consequent]), - bind_universally=self.proof.bind_universally, + bind_universally=step.proof.bind_universally, ) result = _result.csubst if result is not None: csubst = result - _LOGGER.info(f'{proof_type} finished {self.proof.id}: {self.proof.status}') + _LOGGER.info(f'{proof_type} finished {step.proof.id}: {step.proof.status}') return [ ImpliesProofResult( csubst=csubst, simplified_antecedent=simplified_antecedent, simplified_consequent=simplified_consequent diff --git a/src/pyk/proof/proof.py b/src/pyk/proof/proof.py index d04f30005..a6a7b565e 100644 --- a/src/pyk/proof/proof.py +++ b/src/pyk/proof/proof.py @@ -268,6 +268,9 @@ def lines(self) -> list[str]: subproofs_summaries = [subproof.summary for subproof in self.subproofs] return CompositeSummary([BaseSummary(self.id, self.status), *subproofs_summaries]) + @abstractmethod + def get_steps(self) -> Iterable[ProofStep]: ... + class ProofSummary(ABC): id: str @@ -296,6 +299,9 @@ def lines(self) -> list[str]: return [line for lines in (summary.lines for summary in self.summaries) for line in lines] +class ProofStep: ... + + class StepResult: ... @@ -313,21 +319,23 @@ def __init__(self, kcfg_explore: KCFGExplore): def failure_info(self) -> FailureInfo: ... @abstractmethod - def step_proof(self) -> Iterable[StepResult]: ... + def step_proof(self, step: ProofStep) -> Iterable[StepResult]: ... def advance_proof(self, max_iterations: int | None = None, fail_fast: bool = False) -> None: iterations = 0 - while self.proof.can_progress: - if fail_fast and self.proof.failed: - _LOGGER.warning(f'Terminating proof early because fail_fast is set: {self.proof.id}') - self.proof.failure_info = self.failure_info() - return - if max_iterations is not None and max_iterations <= iterations: - return - iterations += 1 - results = self.step_proof() - for result in results: - self.proof.commit(result) - self.proof.write_proof_data() + while True: + steps = self.proof.get_steps() + for step in steps: + if fail_fast and self.proof.failed: + _LOGGER.warning(f'Terminating proof early because fail_fast is set: {self.proof.id}') + self.proof.failure_info = self.failure_info() + return + if max_iterations is not None and max_iterations <= iterations: + return + iterations += 1 + results = self.step_proof(step) + for result in results: + self.proof.commit(result) + self.proof.write_proof_data() if self.proof.failed: self.proof.failure_info = self.failure_info() diff --git a/src/pyk/proof/reachability.py b/src/pyk/proof/reachability.py index 332636597..bc3e33795 100644 --- a/src/pyk/proof/reachability.py +++ b/src/pyk/proof/reachability.py @@ -19,7 +19,7 @@ from ..prelude.ml import mlAnd, mlTop from ..utils import FrozenDict, ensure_dir_path, hash_str, shorten_hashes, single from .implies import ProofSummary, Prover, RefutationProof -from .proof import CompositeSummary, FailureInfo, Proof, ProofStatus, StepResult +from .proof import CompositeSummary, FailureInfo, Proof, ProofStatus, ProofStep, StepResult if TYPE_CHECKING: from collections.abc import Iterable, Mapping @@ -59,6 +59,12 @@ class APRProofTerminalResult(APRProofResult): ... class APRProofBoundedResult(APRProofResult): ... +@dataclass +class APRProofStep(ProofStep): + proof: APRProof + node_id: int + + class APRProof(Proof, KCFGExploration): """APRProof and APRProver implement all-path reachability logic, as introduced by A. Stefanescu and others in their paper 'All-Path Reachability Logic': @@ -132,6 +138,9 @@ def __init__( assert type(subproof) is RefutationProof self.node_refutations[node_id] = subproof + def get_steps(self) -> Iterable[APRProofStep]: + return [APRProofStep(self, node.id) for node in self.pending] + def commit(self, result: StepResult) -> None: if isinstance(result, APRProofExtendResult): self.kcfg.extend(result.extend_result, self.kcfg.node(result.node_id), logs=self.logs) @@ -720,37 +729,37 @@ def _check_subsume(self, node: KCFG.Node) -> CSubst | None: _LOGGER.info(f'Subsumed into target node {self.proof.id}: {shorten_hashes((node.id, self.proof.target))}') return csubst - def step_proof(self) -> Iterable[StepResult]: - if not self.proof.pending: - return [] - curr_node = self.proof.pending[0] + def step_proof(self, step: ProofStep) -> Iterable[StepResult]: + assert isinstance(step, APRProofStep) + + curr_node = step.proof.kcfg.node(step.node_id) - if self.proof.bmc_depth is not None and curr_node.id not in self._checked_for_bounded: - _LOGGER.info(f'Checking bmc depth for node {self.proof.id}: {curr_node.id}') + if step.proof.bmc_depth is not None and curr_node.id not in self._checked_for_bounded: + _LOGGER.info(f'Checking bmc depth for node {step.proof.id}: {curr_node.id}') self._checked_for_bounded.add(curr_node.id) prior_loops = [] - for succ in reversed(self.proof.shortest_path_to(curr_node.id)): + for succ in reversed(step.proof.shortest_path_to(curr_node.id)): if self.kcfg_explore.kcfg_semantics.same_loop(succ.source.cterm, curr_node.cterm): - if succ.source.id in self.proof.prior_loops_cache: - if self.proof.kcfg.zero_depth_between(succ.source.id, curr_node.id): - prior_loops = self.proof.prior_loops_cache[succ.source.id] + if succ.source.id in step.proof.prior_loops_cache: + if step.proof.kcfg.zero_depth_between(succ.source.id, curr_node.id): + prior_loops = step.proof.prior_loops_cache[succ.source.id] else: - prior_loops = self.proof.prior_loops_cache[succ.source.id] + [succ.source.id] + prior_loops = step.proof.prior_loops_cache[succ.source.id] + [succ.source.id] break else: - self.proof.prior_loops_cache[succ.source.id] = [] + step.proof.prior_loops_cache[succ.source.id] = [] - self.proof.prior_loops_cache[curr_node.id] = prior_loops + step.proof.prior_loops_cache[curr_node.id] = prior_loops - _LOGGER.info(f'Prior loop heads for node {self.proof.id}: {(curr_node.id, prior_loops)}') - if len(prior_loops) > self.proof.bmc_depth: - _LOGGER.warning(f'Bounded node {self.proof.id}: {curr_node.id} at bmc depth {self.proof.bmc_depth}') + _LOGGER.info(f'Prior loop heads for node {step.proof.id}: {(curr_node.id, prior_loops)}') + if len(prior_loops) > step.proof.bmc_depth: + _LOGGER.warning(f'Bounded node {step.proof.id}: {curr_node.id} at bmc depth {step.proof.bmc_depth}') return [APRProofBoundedResult(curr_node.id)] # Terminal checks for current node and target node is_terminal = self.kcfg_explore.kcfg_semantics.is_terminal(curr_node.cterm) - target_is_terminal = self.proof.is_terminal(self.proof.target) + target_is_terminal = step.proof.is_terminal(step.proof.target) terminal_result = [APRProofTerminalResult(node_id=curr_node.id)] if is_terminal else [] @@ -768,7 +777,7 @@ def step_proof(self) -> Iterable[StepResult]: module_name = self.circularities_module_name if self.nonzero_depth(curr_node) else self.dependencies_module_name - self.kcfg_explore.check_extendable(self.proof, curr_node) + self.kcfg_explore.check_extendable(step.proof, curr_node) extend_result = self.kcfg_explore.extend_cterm( curr_node.cterm, execute_depth=self.execute_depth,