Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 49 additions & 24 deletions src/_canary/testspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def unmasked(cls) -> "Mask":
@dataclasses.dataclass
class BaseSpec(Generic[T]):
file_root: Path
"""The search path containing the generated spec; typically the version-control root"""
file_path: Path
"""The path to the test file relative to `file_root`"""
family: str = ""
stdout: str = "canary-out.txt"
stderr: str | None = None # combine stdout/stderr by default
Expand Down Expand Up @@ -131,6 +133,7 @@ def from_dict(cls: Type[T], d: dict, lookup: dict[str, T]) -> T:

@cached_property
def file(self) -> Path:
"""Path to the test specification file"""
return self.file_root / self.file_path

@property
Expand Down Expand Up @@ -482,30 +485,18 @@ def _generate_dependency_patterns(


class _GlobalSpecCache:
_file_hash: dict[Path, str] = {}
_repo_root: dict[Path, Path] = {}
_lock = threading.Lock()
"""Simple cache for storing re-used data feeding the spec ID"""

@classmethod
def file_hash(cls, path: Path) -> str:
path = path.absolute()
try:
return cls._file_hash[path]
except KeyError:
pass
h = hashlib.blake2b(digest_size=16)
h.update(path.read_bytes())
digest = h.hexdigest()
with cls._lock:
return cls._file_hash.setdefault(path, digest)
_key: dict[Path, Path] = {}
"""Maps the input file path to a key index (absolute path)"""

_file_hash: dict[Path, bytes] = {}
_repo_root: dict[Path, bytes] = {}
_rel_repo: dict[Path, bytes] = {}
_lock = threading.Lock()

@classmethod
def repo_root(cls, path: Path) -> Path:
path = path.absolute()
try:
return cls._repo_root[path]
except KeyError:
pass
def _compute_repo_root(cls, path: Path) -> Path:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like this method would return user paths and would result in unstable spec IDs for different users. Should we actually include this in the ID hash, or only use it to compute the relative paths? Thoughts @tjfulle ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a counterargument, the returned path is simply the parent if we don't find a VC directory in the tree. Generally speaking, constructing stable hashes from such a path would be very difficult or impossible, so maybe this fall-back case should instead return /? This would make rel_repo be the full path to the file and we can drop repo_root from the hash.

d = path.parent
while d.parent != d:
if (d / ".git").exists() or (d / ".repo").exists():
Expand All @@ -514,8 +505,41 @@ def repo_root(cls, path: Path) -> Path:
d = d.parent
else:
root = path.parent
return root

@classmethod
def populate_cache(cls, path: Path) -> Path:
try:
return cls._key[path]
except KeyError:
pass

key = path.absolute()
h = hashlib.blake2b(digest_size=16)
h.update(key.read_bytes())
root = cls._compute_repo_root(key)
rel = key.relative_to(root)

with cls._lock:
return cls._repo_root.setdefault(path, root)
cls._repo_root[key] = str(root).encode()
cls._rel_repo[key] = str(rel).encode()
cls._file_hash[key] = h.hexdigest().encode()
return cls._key.setdefault(path, key)

@classmethod
def file_hash(cls, path: Path) -> bytes:
key = cls.populate_cache(path)
return cls._file_hash[key]

@classmethod
def repo_root(cls, path: Path) -> bytes:
key = cls.populate_cache(path)
return cls._repo_root[key]

@classmethod
def rel_repo(cls, path: Path) -> bytes:
key = cls.populate_cache(path)
return cls._rel_repo[key]


def build_spec_id(spec: BaseSpec) -> str:
Expand All @@ -526,8 +550,9 @@ def build_spec_id(spec: BaseSpec) -> str:
if parameters:
for p in sorted(parameters):
hasher.update(f"{p}={stringify(parameters[p], float_fmt='%.16e')}".encode())
hasher.update(_GlobalSpecCache.file_hash(spec.file).encode())
hasher.update(str(_GlobalSpecCache.repo_root(spec.file)).encode())
hasher.update(_GlobalSpecCache.file_hash(spec.file))
hasher.update(_GlobalSpecCache.repo_root(spec.file))
hasher.update(_GlobalSpecCache.rel_repo(spec.file))
return hasher.hexdigest()


Expand Down
9 changes: 5 additions & 4 deletions src/_canary/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Sequence
from typing import TypeVar

Expand Down Expand Up @@ -1169,14 +1170,14 @@ def delete_selection(self, tag: str) -> bool:
self.connection.execute("DELETE FROM selections WHERE tag = ?", (tag,))
return True

def get_updownstream_ids(self, seeds: list[str] | None = None) -> tuple[set[str], set[str]]:
def get_updownstream_ids(self, seeds: Iterable[str] | None = None) -> tuple[set[str], set[str]]:
if seeds is None:
return set(), set()
downstream = self.get_downstream_ids(seeds)
upstream = self.get_upstream_ids(list(downstream.union(seeds)))
upstream = self.get_upstream_ids(downstream.union(seeds))
return upstream, downstream

def get_downstream_ids(self, seeds: list[str]) -> set[str]:
def get_downstream_ids(self, seeds: Iterable[str]) -> set[str]:
"""Return dependencies in instantiation order."""
if not seeds:
return set()
Expand All @@ -1198,7 +1199,7 @@ def get_downstream_ids(self, seeds: list[str]) -> set[str]:
rows = self.connection.execute(query, tuple(seeds)).fetchall()
return {r[0] for r in rows}

def get_upstream_ids(self, seeds: list[str]) -> set[str]:
def get_upstream_ids(self, seeds: Iterable[str]) -> set[str]:
"""Return dependents in reverse instantiation order."""
if not seeds:
return set()
Expand Down
Loading