diff --git a/pyk/src/pyk/__main__.py b/pyk/src/pyk/__main__.py index a7b8d6143ea..8664ccbeee6 100644 --- a/pyk/src/pyk/__main__.py +++ b/pyk/src/pyk/__main__.py @@ -36,8 +36,9 @@ from .kore.syntax import Pattern, kore_term from .ktool.kompile import Kompile, KompileBackend from .ktool.kprint import KPrint -from .ktool.kprove import KProve, ProveRpc +from .ktool.kprove import KProve from .ktool.krun import KRun +from .ktool.prove_rpc import ProveRpc from .prelude.k import GENERATED_TOP_CELL from .prelude.ml import is_top, mlAnd, mlOr from .proof.reachability import APRFailureInfo, APRProof diff --git a/pyk/src/pyk/ktool/claim_index.py b/pyk/src/pyk/ktool/claim_index.py new file mode 100644 index 00000000000..a33294f46b2 --- /dev/null +++ b/pyk/src/pyk/ktool/claim_index.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from functools import partial +from graphlib import TopologicalSorter +from typing import TYPE_CHECKING + +from ..kast import Atts +from ..kast.outer import KClaim +from ..utils import FrozenDict, unique + +if TYPE_CHECKING: + from collections.abc import Container, Iterable, Iterator + + from ..kast.outer import KFlatModule, KFlatModuleList + + +@dataclass(frozen=True) +class ClaimIndex(Mapping[str, KClaim]): + claims: FrozenDict[str, KClaim] + main_module_name: str | None + + def __init__( + self, + claims: Mapping[str, KClaim], + main_module_name: str | None = None, + ): + self._validate(claims) + object.__setattr__(self, 'claims', FrozenDict(claims)) + object.__setattr__(self, 'main_module_name', main_module_name) + + @staticmethod + def from_module_list(module_list: KFlatModuleList) -> ClaimIndex: + module_list = ClaimIndex._resolve_depends(module_list) + return ClaimIndex( + claims={claim.label: claim for module in module_list.modules for claim in module.claims}, + main_module_name=module_list.main_module, + ) + + @staticmethod + def _validate(claims: Mapping[str, KClaim]) -> None: + for label, claim in claims.items(): + if claim.label != label: + raise ValueError(f'Claim label mismatch, expected: {label}, found: {claim.label}') + + for depend in claim.dependencies: + if depend not in claims: + raise ValueError(f'Invalid dependency label: {depend}') + + @staticmethod + def _resolve_depends(module_list: KFlatModuleList) -> KFlatModuleList: + """Resolve each depends value relative to the module the claim belongs to. + + Example: + ``` + module THIS-MODULE + claim ... [depends(foo,OTHER-MODULE.bar)] + endmodule + ``` + + becomes + + ``` + module THIS-MODULE + claim ... [depends(THIS-MODULE.foo,OTHER-MODULE.bar)] + endmodule + ``` + """ + labels = {claim.label for module in module_list.modules for claim in module.claims} + + def resolve_claim_depends(module_name: str, claim: KClaim) -> KClaim: + depends = claim.dependencies + if not depends: + return claim + + resolve = partial(ClaimIndex._resolve_claim_label, labels, module_name) + resolved = [resolve(label) for label in depends] + return claim.let(att=claim.att.update([Atts.DEPENDS(','.join(resolved))])) + + modules: list[KFlatModule] = [] + for module in module_list.modules: + resolve_depends = partial(resolve_claim_depends, module.name) + module = module.map_sentences(resolve_depends, of_type=KClaim) + modules.append(module) + + return module_list.let(modules=modules) + + @staticmethod + def _resolve_claim_label(labels: Container[str], module_name: str | None, label: str) -> str: + """Resolve `label` to a valid label in `labels`, or raise. + + If a `label` is not found and `module_name` is set, the label is tried after qualifying. + """ + if label in labels: + return label + + if module_name is not None: + qualified = f'{module_name}.{label}' + if qualified in labels: + return qualified + + raise ValueError(f'Claim label not found: {label}') + + def __iter__(self) -> Iterator[str]: + return iter(self.claims) + + def __len__(self) -> int: + return len(self.claims) + + def __getitem__(self, label: str) -> KClaim: + try: + label = self.resolve(label) + except ValueError: + raise KeyError(f'Claim not found: {label}') from None + return self.claims[label] + + def resolve(self, label: str) -> str: + return self._resolve_claim_label(self.claims, self.main_module_name, label) + + def resolve_all(self, labels: Iterable[str]) -> list[str]: + return [self.resolve(label) for label in unique(labels)] + + def labels( + self, + *, + include: Iterable[str] | None = None, + exclude: Iterable[str] | None = None, + with_depends: bool = True, + ordered: bool = False, + ) -> list[str]: + """Return a list of labels from the index. + + Args: + include: Labels to include in the result. If `None`, all labels are included. + exclude: Labels to exclude from the result. If `None`, no labels are excluded. + Takes precedence over `include`. + with_depends: If `True`, the result is transitively closed w.r.t. the dependency relation. + Labels in `exclude` are pruned, and their dependencies are not considered on the given path. + ordered: If `True`, the result is topologically sorted w.r.t. the dependency relation. + + Returns: + A list of labels from the index. + + Raises: + ValueError: If an item in `include` or `exclude` cannot be resolved to a valid label. + """ + include = self.resolve_all(include) if include is not None else self.claims + exclude = self.resolve_all(exclude) if exclude is not None else [] + + labels: list[str] + + if with_depends: + labels = self._close_dependencies(labels=include, prune=exclude) + else: + labels = [label for label in include if label not in set(exclude)] + + if ordered: + return self._sort_topologically(labels) + + return labels + + def _close_dependencies(self, labels: Iterable[str], prune: Iterable[str]) -> list[str]: + res: list[str] = [] + + pending = list(labels) + done = set(prune) + + while pending: + label = pending.pop(0) # BFS + + if label in done: + continue + + res.append(label) + pending += self.claims[label].dependencies + done.add(label) + + return res + + def _sort_topologically(self, labels: list[str]) -> list[str]: + label_set = set(labels) + graph = { + label: [dep for dep in claim.dependencies if dep in label_set] + for label, claim in self.claims.items() + if label in labels + } + return list(TopologicalSorter(graph).static_order()) diff --git a/pyk/src/pyk/ktool/kprove.py b/pyk/src/pyk/ktool/kprove.py index 195bee76536..0c2c945f3e2 100644 --- a/pyk/src/pyk/ktool/kprove.py +++ b/pyk/src/pyk/ktool/kprove.py @@ -4,12 +4,8 @@ import logging import os import re -from collections.abc import Mapping from contextlib import contextmanager -from dataclasses import dataclass from enum import Enum -from functools import cached_property, partial -from graphlib import TopologicalSorter from itertools import chain from pathlib import Path from subprocess import CalledProcessError @@ -17,27 +13,25 @@ from ..cli.utils import check_dir_path, check_file_path from ..cterm import CTerm -from ..kast import Atts, kast_term +from ..kast import kast_term from ..kast.inner import KInner -from ..kast.manip import extract_lhs, flatten_label -from ..kast.outer import KApply, KClaim, KDefinition, KFlatModule, KFlatModuleList, KImport, KRequire +from ..kast.manip import flatten_label +from ..kast.outer import KDefinition, KFlatModule, KFlatModuleList, KImport, KRequire from ..kore.rpc import KoreExecLogFormat from ..prelude.ml import is_top -from ..proof import APRProof, APRProver, EqualityProof, ImpliesProver -from ..utils import FrozenDict, gen_file_timestamp, run_process, unique +from ..utils import gen_file_timestamp, run_process from . import TypeInferenceMode +from .claim_index import ClaimIndex from .kprint import KPrint if TYPE_CHECKING: - from collections.abc import Callable, Container, Iterable, Iterator + from collections.abc import Callable, Iterable, Iterator, Mapping from subprocess import CompletedProcess - from typing import ContextManager, Final + from typing import Final - from ..cli.pyk import ProveOptions - from ..kast.outer import KRule, KRuleLike + from ..kast.outer import KClaim, KRule, KRuleLike from ..kast.pretty import SymbolTable from ..kcfg import KCFGExplore - from ..proof import Proof, Prover from ..utils import BugReport _LOGGER: Final = logging.getLogger(__name__) @@ -401,219 +395,3 @@ def _get_rule_line(_line: str) -> tuple[str, bool, int] | None: axioms.pop(-1) return axioms - - -class ProveRpc: - _kprove: KProve - _explore_context: Callable[[], ContextManager[KCFGExplore]] - - def __init__( - self, - kprove: KProve, - explore_context: Callable[[], ContextManager[KCFGExplore]], - ): - self._kprove = kprove - self._explore_context = explore_context - - def prove_rpc(self, options: ProveOptions) -> list[Proof]: - all_claims = self._kprove.get_claims( - options.spec_file, - spec_module_name=options.spec_module, - claim_labels=options.claim_labels, - exclude_claim_labels=options.exclude_claim_labels, - type_inference_mode=options.type_inference_mode, - ) - - if all_claims is None: - raise ValueError(f'No claims found in file: {options.spec_file}') - - return [ - self._prove_claim_rpc( - claim, - max_depth=options.max_depth, - save_directory=options.save_directory, - max_iterations=options.max_iterations, - ) - for claim in all_claims - ] - - def _prove_claim_rpc( - self, - claim: KClaim, - max_depth: int | None = None, - save_directory: Path | None = None, - max_iterations: int | None = None, - ) -> Proof: - definition = self._kprove.definition - - proof: Proof - prover: Prover - lhs_top = extract_lhs(claim.body) - is_functional_claim = type(lhs_top) is KApply and definition.symbols[lhs_top.label.name] in definition.functions - - if is_functional_claim: - proof = EqualityProof.from_claim(claim, definition, proof_dir=save_directory) - if save_directory is not None and EqualityProof.proof_data_exists(proof.id, save_directory): - _LOGGER.info(f'Reloading from disk {proof.id}: {save_directory}') - proof = EqualityProof.read_proof_data(save_directory, proof.id) - - else: - proof = APRProof.from_claim(definition, claim, {}, proof_dir=save_directory) - if save_directory is not None and APRProof.proof_data_exists(proof.id, save_directory): - _LOGGER.info(f'Reloading from disk {proof.id}: {save_directory}') - proof = APRProof.read_proof_data(save_directory, proof.id) - - if not proof.passed and (max_iterations is None or max_iterations > 0): - with self._explore_context() as kcfg_explore: - if is_functional_claim: - assert type(proof) is EqualityProof - prover = ImpliesProver(proof, kcfg_explore) - else: - assert type(proof) is APRProof - prover = APRProver(kcfg_explore, execute_depth=max_depth) - prover.advance_proof(proof, max_iterations=max_iterations) - - if proof.passed: - _LOGGER.info(f'Proof passed: {proof.id}') - elif proof.failed: - _LOGGER.info(f'Proof failed: {proof.id}') - else: - _LOGGER.info(f'Proof pending: {proof.id}') - return proof - - -@dataclass(frozen=True) -class ClaimIndex(Mapping[str, KClaim]): - claims: FrozenDict[str, KClaim] - main_module_name: str | None - - def __init__( - self, - claims: Mapping[str, KClaim], - main_module_name: str | None = None, - ): - self._validate(claims) - object.__setattr__(self, 'claims', FrozenDict(claims)) - object.__setattr__(self, 'main_module_name', main_module_name) - - @staticmethod - def from_module_list(module_list: KFlatModuleList) -> ClaimIndex: - module_list = ClaimIndex._resolve_depends(module_list) - return ClaimIndex( - claims={claim.label: claim for module in module_list.modules for claim in module.claims}, - main_module_name=module_list.main_module, - ) - - @staticmethod - def _validate(claims: Mapping[str, KClaim]) -> None: - for label, claim in claims.items(): - if claim.label != label: - raise ValueError(f'Claim label mismatch, expected: {label}, found: {claim.label}') - - for depend in claim.dependencies: - if depend not in claims: - raise ValueError(f'Invalid dependency label: {depend}') - - @staticmethod - def _resolve_depends(module_list: KFlatModuleList) -> KFlatModuleList: - """Resolve each depends value relative to the module the claim belongs to. - - Example: - ``` - module THIS-MODULE - claim ... [depends(foo,OTHER-MODULE.bar)] - endmodule - ``` - - becomes - - ``` - module THIS-MODULE - claim ... [depends(THIS-MODULE.foo,OTHER-MODULE.bar)] - endmodule - ``` - """ - labels = {claim.label for module in module_list.modules for claim in module.claims} - - def resolve_claim_depends(module_name: str, claim: KClaim) -> KClaim: - depends = claim.dependencies - if not depends: - return claim - - resolve = partial(ClaimIndex._resolve_claim_label, labels, module_name) - resolved = [resolve(label) for label in depends] - return claim.let(att=claim.att.update([Atts.DEPENDS(','.join(resolved))])) - - modules: list[KFlatModule] = [] - for module in module_list.modules: - resolve_depends = partial(resolve_claim_depends, module.name) - module = module.map_sentences(resolve_depends, of_type=KClaim) - modules.append(module) - - return module_list.let(modules=modules) - - @staticmethod - def _resolve_claim_label(labels: Container[str], module_name: str | None, label: str) -> str: - """Resolve `label` to a valid label in `labels`, or raise. - - If a `label` is not found and `module_name` is set, the label is tried after qualifying. - """ - if label in labels: - return label - - if module_name is not None: - qualified = f'{module_name}.{label}' - if qualified in labels: - return qualified - - raise ValueError(f'Claim label not found: {label}') - - def __iter__(self) -> Iterator[str]: - return iter(self.claims) - - def __len__(self) -> int: - return len(self.claims) - - def __getitem__(self, label: str) -> KClaim: - try: - label = self.resolve(label) - except ValueError: - raise KeyError(f'Claim not found: {label}') from None - return self.claims[label] - - @cached_property - def topological(self) -> tuple[str, ...]: - graph = {label: claim.dependencies for label, claim in self.claims.items()} - return tuple(TopologicalSorter(graph).static_order()) - - def resolve(self, label: str) -> str: - return self._resolve_claim_label(self.claims, self.main_module_name, label) - - def resolve_all(self, labels: Iterable[str]) -> list[str]: - return [self.resolve(label) for label in unique(labels)] - - def labels( - self, - *, - include: Iterable[str] | None = None, - exclude: Iterable[str] | None = None, - with_depends: bool = True, - ) -> list[str]: - res: list[str] = [] - - pending = self.resolve_all(include) if include is not None else list(self.claims) - done = set(self.resolve_all(exclude)) if exclude is not None else set() - - while pending: - label = pending.pop(0) # BFS - - if label in done: - continue - - res.append(label) - done.add(label) - - if with_depends: - pending += self.claims[label].dependencies - - return res diff --git a/pyk/src/pyk/ktool/prove_rpc.py b/pyk/src/pyk/ktool/prove_rpc.py new file mode 100644 index 00000000000..250381d3f54 --- /dev/null +++ b/pyk/src/pyk/ktool/prove_rpc.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from ..kast.manip import extract_lhs +from ..kast.outer import KApply +from ..proof import APRProof, APRProver, EqualityProof, ImpliesProver + +if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + from typing import ContextManager, Final + + from ..cli.pyk import ProveOptions + from ..kast.outer import KClaim + from ..kcfg import KCFGExplore + from ..proof import Proof, Prover + from .kprove import KProve + +_LOGGER: Final = logging.getLogger(__name__) + + +class ProveRpc: + _kprove: KProve + _explore_context: Callable[[], ContextManager[KCFGExplore]] + + def __init__( + self, + kprove: KProve, + explore_context: Callable[[], ContextManager[KCFGExplore]], + ): + self._kprove = kprove + self._explore_context = explore_context + + def prove_rpc(self, options: ProveOptions) -> list[Proof]: + all_claims = self._kprove.get_claims( + options.spec_file, + spec_module_name=options.spec_module, + claim_labels=options.claim_labels, + exclude_claim_labels=options.exclude_claim_labels, + type_inference_mode=options.type_inference_mode, + ) + + if all_claims is None: + raise ValueError(f'No claims found in file: {options.spec_file}') + + return [ + self._prove_claim_rpc( + claim, + max_depth=options.max_depth, + save_directory=options.save_directory, + max_iterations=options.max_iterations, + ) + for claim in all_claims + ] + + def _prove_claim_rpc( + self, + claim: KClaim, + max_depth: int | None = None, + save_directory: Path | None = None, + max_iterations: int | None = None, + ) -> Proof: + definition = self._kprove.definition + + proof: Proof + prover: Prover + lhs_top = extract_lhs(claim.body) + is_functional_claim = type(lhs_top) is KApply and definition.symbols[lhs_top.label.name] in definition.functions + + if is_functional_claim: + proof = EqualityProof.from_claim(claim, definition, proof_dir=save_directory) + if save_directory is not None and EqualityProof.proof_data_exists(proof.id, save_directory): + _LOGGER.info(f'Reloading from disk {proof.id}: {save_directory}') + proof = EqualityProof.read_proof_data(save_directory, proof.id) + + else: + proof = APRProof.from_claim(definition, claim, {}, proof_dir=save_directory) + if save_directory is not None and APRProof.proof_data_exists(proof.id, save_directory): + _LOGGER.info(f'Reloading from disk {proof.id}: {save_directory}') + proof = APRProof.read_proof_data(save_directory, proof.id) + + if not proof.passed and (max_iterations is None or max_iterations > 0): + with self._explore_context() as kcfg_explore: + if is_functional_claim: + assert type(proof) is EqualityProof + prover = ImpliesProver(proof, kcfg_explore) + else: + assert type(proof) is APRProof + prover = APRProver(kcfg_explore, execute_depth=max_depth) + prover.advance_proof(proof, max_iterations=max_iterations) + + if proof.passed: + _LOGGER.info(f'Proof passed: {proof.id}') + elif proof.failed: + _LOGGER.info(f'Proof failed: {proof.id}') + else: + _LOGGER.info(f'Proof pending: {proof.id}') + return proof diff --git a/pyk/src/pyk/proof/reachability.py b/pyk/src/pyk/proof/reachability.py index a079190441e..e0ece7406f9 100644 --- a/pyk/src/pyk/proof/reachability.py +++ b/pyk/src/pyk/proof/reachability.py @@ -1,6 +1,5 @@ from __future__ import annotations -import graphlib import json import logging import re @@ -10,13 +9,13 @@ from pyk.kore.rpc import LogEntry from ..cterm.cterm import remove_useless_constraints -from ..kast.att import AttEntry, Atts from ..kast.inner import KInner, Subst from ..kast.manip import flatten_label, free_vars, ml_pred_to_bool from ..kast.outer import KFlatModule, KImport, KRule from ..kcfg import KCFG, KCFGStore from ..kcfg.exploration import KCFGExploration from ..konvert import kflatmodule_to_kore +from ..ktool.claim_index import ClaimIndex from ..prelude.ml import mlAnd, mlTop from ..utils import FrozenDict, ensure_dir_path, hash_str, shorten_hashes, single from .implies import ProofSummary, Prover, RefutationProof @@ -423,74 +422,28 @@ def from_spec_modules( logs: dict[int, tuple[LogEntry, ...]], proof_dir: Path | None = None, spec_labels: Iterable[str] | None = None, - **kwargs: Any, ) -> list[APRProof]: - claims_by_label = {claim.label: claim for module in spec_modules.modules for claim in module.claims} - if spec_labels is None: - spec_labels = list(claims_by_label.keys()) - _spec_labels = [] - for spec_label in spec_labels: - if spec_label in claims_by_label: - _spec_labels.append(spec_label) - elif f'{spec_modules.main_module}.{spec_label}' in claims_by_label: - _spec_labels.append(f'{spec_modules.main_module}.{spec_label}') + claim_index = ClaimIndex.from_module_list(spec_modules) + spec_labels = claim_index.labels(include=spec_labels, with_depends=True, ordered=True) + + res: list[APRProof] = [] + + for label in spec_labels: + if proof_dir is not None and Proof.proof_data_exists(label, proof_dir): + apr_proof = APRProof.read_proof_data(proof_dir, label) else: - raise ValueError( - f'Could not find specification label: {spec_label} or {spec_modules.main_module}.{spec_label}' + _LOGGER.info(f'Building APRProof for claim: {label}') + claim = claim_index[label] + apr_proof = APRProof.from_claim( + defn, + claim, + logs=logs, + proof_dir=proof_dir, ) - spec_labels = _spec_labels - - claims_graph: dict[str, list[str]] = {} - unfound_dependencies = [] - for module in spec_modules.modules: - for claim in module.claims: - claims_graph[claim.label] = [] - for dependency in claim.dependencies: - if dependency in claims_by_label: - claims_graph[claim.label].append(dependency) - elif f'{module.name}.{dependency}' in claims_by_label: - claims_graph[claim.label].append(f'{module.name}.{dependency}') - else: - unfound_dependencies.append((claim.label, module.name, dependency)) - if unfound_dependencies: - unfound_dependency_list = [ - f'Could not find dependency for claim {label}: {dependency} or {module_name}.{dependency}' - for label, module_name, dependency in unfound_dependencies - ] - unfound_dependency_message = '\n - ' + '\n - '.join(unfound_dependency_list) - raise ValueError(f'Could not find dependencies:{unfound_dependency_message}') - - claims_subgraph: dict[str, list[str]] = {} - remaining_claims = spec_labels - while len(remaining_claims) > 0: - claim_label = remaining_claims.pop() - claims_subgraph[claim_label] = claims_graph[claim_label] - remaining_claims.extend(claims_graph[claim_label]) - - topological_sorter = graphlib.TopologicalSorter(claims_subgraph) - topological_sorter.prepare() - apr_proofs: list[APRProof] = [] - while topological_sorter.is_active(): - for claim_label in topological_sorter.get_ready(): - if proof_dir is not None and Proof.proof_data_exists(claim_label, proof_dir): - apr_proof = APRProof.read_proof_data(proof_dir, claim_label) - else: - _LOGGER.info(f'Building APRProof for claim: {claim_label}') - claim = claims_by_label[claim_label] - if len(claims_graph[claim_label]) > 0: - claim_att = claim.att.update([AttEntry(Atts.DEPENDS, ','.join(claims_graph[claim_label]))]) - claim = claim.let_att(claim_att) - apr_proof = APRProof.from_claim( - defn, - claim, - logs=logs, - proof_dir=proof_dir, - ) - apr_proof.write_proof_data() - apr_proofs.append(apr_proof) - topological_sorter.done(claim_label) - - return apr_proofs + apr_proof.write_proof_data() + res.append(apr_proof) + + return res def path_constraints(self, final_node_id: NodeIdLike) -> KInner: path = self.shortest_path_to(final_node_id) diff --git a/pyk/src/tests/integration/ktool/test_imp.py b/pyk/src/tests/integration/ktool/test_imp.py index 347c12621c2..66b45faa559 100644 --- a/pyk/src/tests/integration/ktool/test_imp.py +++ b/pyk/src/tests/integration/ktool/test_imp.py @@ -10,7 +10,7 @@ from pyk.cli.pyk import ProveOptions from pyk.kast.inner import KApply, KSequence, KVariable from pyk.kcfg.semantics import KCFGSemantics -from pyk.ktool.kprove import ProveRpc +from pyk.ktool.prove_rpc import ProveRpc from pyk.proof import ProofStatus from pyk.testing import KCFGExploreTest, KProveTest from pyk.utils import single