Skip to content

Commit 72a1d91

Browse files
authored
Validate kernel Python dependencies (#182)
* Validate kernel Python dependencies * Fix nit: deps will never be None
1 parent a8bfd18 commit 72a1d91

File tree

4 files changed

+61
-0
lines changed

4 files changed

+61
-0
lines changed

.github/workflows/test.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ jobs:
5454
run: |
5555
uv run pytest tests
5656
57+
- name: Run dependency test with dependency installed
58+
run: |
59+
uv pip install nvidia-cutlass-dsl
60+
uv run pytest tests/test_deps.py
61+
5762
- name: Run staging tests
5863
env:
5964
HF_TOKEN: ${{ secrets.HF_STAGING_TOKEN }}

src/kernels/deps.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import importlib.util
2+
from typing import List, Set
3+
4+
allowed_dependencies: Set[str] = {
5+
"einops",
6+
"nvidia-cutlass-dsl",
7+
}
8+
9+
10+
def validate_dependencies(dependencies: List[str]):
11+
"""
12+
Validate a list of dependencies to ensure they are installed.
13+
14+
Args:
15+
dependencies (`List[str]`): A list of dependency strings.
16+
"""
17+
for dependency in dependencies:
18+
if dependency not in allowed_dependencies:
19+
allowed = ", ".join(sorted(allowed_dependencies))
20+
raise ValueError(
21+
f"Invalid dependency: {dependency}, allowed dependencies: {allowed}"
22+
)
23+
24+
if importlib.util.find_spec(dependency.replace("-", "_")) is None:
25+
raise ImportError(
26+
f"Kernel requires dependency `{dependency}`. Please install with: pip install {dependency}"
27+
)

src/kernels/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from kernels._system import glibc_version
2020
from kernels._versions import select_revision_or_version
2121
from kernels.lockfile import KernelLock, VariantLock
22+
from kernels.deps import validate_dependencies
2223

2324
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
2425

@@ -89,6 +90,13 @@ def universal_build_variant() -> str:
8990

9091

9192
def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
93+
metadata_path = variant_path / "metadata.json"
94+
if metadata_path.exists():
95+
with open(metadata_path, "r") as f:
96+
metadata = json.load(f)
97+
deps = metadata.get("python-depends", [])
98+
validate_dependencies(deps)
99+
92100
file_path = variant_path / "__init__.py"
93101
if not file_path.exists():
94102
file_path = variant_path / module_name / "__init__.py"

tests/test_deps.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from importlib.util import find_spec
2+
3+
import pytest
4+
5+
from kernels import get_kernel
6+
7+
8+
def test_python_deps():
9+
must_raise = find_spec("nvidia_cutlass_dsl") is None
10+
if must_raise:
11+
with pytest.raises(
12+
ImportError, match=r"Kernel requires dependency `nvidia-cutlass-dsl`"
13+
):
14+
get_kernel("kernels-test/python-dep")
15+
else:
16+
get_kernel("kernels-test/python-dep")
17+
18+
19+
def test_illegal_dep():
20+
with pytest.raises(ValueError, match=r"Invalid dependency: kepler-22b"):
21+
get_kernel("kernels-test/python-invalid-dep")

0 commit comments

Comments
 (0)