diff --git a/src/_canary/testspec.py b/src/_canary/testspec.py index 86ab2ab6..8bf05fca 100644 --- a/src/_canary/testspec.py +++ b/src/_canary/testspec.py @@ -60,7 +60,9 @@ def unmasked(cls) -> "Mask": @dataclasses.dataclass class BaseSpec(Generic[T]): file_root: Path + """The search path contiaining the generated spec; typically the some version-control root""" file_path: Path + """The path to the test file relative to `file_root`""" family: str = "" stdout: str = "canary-out.txt" stderr: str | None = None # combine stdout/stderr by default @@ -131,6 +133,7 @@ def from_dict(cls: Type[T], d: dict, lookup: dict[str, T]) -> T: @cached_property def file(self) -> Path: + """Path to the test specification file""" return self.file_root / self.file_path @property @@ -482,30 +485,18 @@ def _generate_dependency_patterns( class _GlobalSpecCache: - _file_hash: dict[Path, str] = {} - _repo_root: dict[Path, Path] = {} - _lock = threading.Lock() + """Simple cache for storing re-used data feeding the spec ID""" - @classmethod - def file_hash(cls, path: Path) -> str: - path = path.absolute() - try: - return cls._file_hash[path] - except KeyError: - pass - h = hashlib.blake2b(digest_size=16) - h.update(path.read_bytes()) - digest = h.hexdigest() - with cls._lock: - return cls._file_hash.setdefault(path, digest) + _key: dict[Path, Path] = {} + """Maps the input file path to a key index (absolute path)""" + + _file_hash: dict[Path, bytes] = {} + _repo_root: dict[Path, bytes] = {} + _rel_repo: dict[Path, bytes] = {} + _lock = threading.Lock() @classmethod - def repo_root(cls, path: Path) -> Path: - path = path.absolute() - try: - return cls._repo_root[path] - except KeyError: - pass + def _compute_repo_root(cls, path: Path) -> Path: d = path.parent while d.parent != d: if (d / ".git").exists() or (d / ".repo").exists(): @@ -514,8 +505,41 @@ def repo_root(cls, path: Path) -> Path: d = d.parent else: root = path.parent + return root + + @classmethod + def populate_cache(cls, path: Path) -> Path: + try: + return cls._key[path] + except KeyError: + pass + + key = path.absolute() + h = hashlib.blake2b(digest_size=16) + h.update(key.read_bytes()) + root = cls._compute_repo_root(key) + rel = key.relative_to(root) + with cls._lock: - return cls._repo_root.setdefault(path, root) + cls._repo_root[key] = str(root).encode() + cls._rel_repo[key] = str(rel).encode() + cls._file_hash[key] = h.hexdigest().encode() + return cls._key.setdefault(path, key) + + @classmethod + def file_hash(cls, path: Path) -> bytes: + key = cls.populate_cache(path) + return cls._file_hash[key] + + @classmethod + def repo_root(cls, path: Path) -> bytes: + key = cls.populate_cache(path) + return cls._repo_root[key] + + @classmethod + def rel_repo(cls, path: Path) -> bytes: + key = cls.populate_cache(path) + return cls._rel_repo[key] def build_spec_id(spec: BaseSpec) -> str: @@ -526,8 +550,9 @@ def build_spec_id(spec: BaseSpec) -> str: if parameters: for p in sorted(parameters): hasher.update(f"{p}={stringify(parameters[p], float_fmt='%.16e')}".encode()) - hasher.update(_GlobalSpecCache.file_hash(spec.file).encode()) - hasher.update(str(_GlobalSpecCache.repo_root(spec.file)).encode()) + hasher.update(_GlobalSpecCache.file_hash(spec.file)) + hasher.update(_GlobalSpecCache.repo_root(spec.file)) + hasher.update(_GlobalSpecCache.rel_repo(spec.file)) return hasher.hexdigest() diff --git a/src/_canary/workspace.py b/src/_canary/workspace.py index b26ba1d9..1943f06c 100644 --- a/src/_canary/workspace.py +++ b/src/_canary/workspace.py @@ -13,6 +13,7 @@ from pathlib import Path from typing import Any from typing import Callable +from typing import Iterable from typing import Sequence from typing import TypeVar @@ -1169,14 +1170,14 @@ def delete_selection(self, tag: str) -> bool: self.connection.execute("DELETE FROM selections WHERE tag = ?", (tag,)) return True - def get_updownstream_ids(self, seeds: list[str] | None = None) -> tuple[set[str], set[str]]: + def get_updownstream_ids(self, seeds: Iterable[str] | None = None) -> tuple[set[str], set[str]]: if seeds is None: return set(), set() downstream = self.get_downstream_ids(seeds) - upstream = self.get_upstream_ids(list(downstream.union(seeds))) + upstream = self.get_upstream_ids(downstream.union(seeds)) return upstream, downstream - def get_downstream_ids(self, seeds: list[str]) -> set[str]: + def get_downstream_ids(self, seeds: Iterable[str]) -> set[str]: """Return dependencies in instantiation order.""" if not seeds: return set() @@ -1198,7 +1199,7 @@ def get_downstream_ids(self, seeds: list[str]) -> set[str]: rows = self.connection.execute(query, tuple(seeds)).fetchall() return {r[0] for r in rows} - def get_upstream_ids(self, seeds: list[str]) -> set[str]: + def get_upstream_ids(self, seeds: Iterable[str]) -> set[str]: """Return dependents in reverse instantiation order.""" if not seeds: return set()