Skip to content

Commit 799e8fb

Browse files
committed
Add test for ensuring exposed pyclasses default to frozen
1 parent c95e8b1 commit 799e8fb

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Ensure exposed pyclasses default to frozen."""
2+
3+
from __future__ import annotations
4+
5+
import re
6+
from dataclasses import dataclass
7+
from pathlib import Path
8+
from typing import Iterator
9+
10+
PYCLASS_RE = re.compile(r"#\[\s*pyclass\s*(?:\((?P<args>.*?)\))?\s*\]", re.DOTALL)
11+
ARG_STRING_RE = re.compile(r"(?P<key>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P<value>[^\"]+)\"")
12+
STRUCT_NAME_RE = re.compile(r"\b(?:pub\s+)?(?:struct|enum)\s+(?P<name>[A-Za-z_][A-Za-z0-9_]*)")
13+
14+
15+
@dataclass
16+
class PyClass:
17+
module: str
18+
name: str
19+
frozen: bool
20+
source: Path
21+
22+
23+
def iter_pyclasses(root: Path) -> Iterator[PyClass]:
24+
for path in root.rglob("*.rs"):
25+
text = path.read_text(encoding="utf8")
26+
for match in PYCLASS_RE.finditer(text):
27+
args = match.group("args") or ""
28+
frozen = re.search(r"\bfrozen\b", args) is not None
29+
30+
module = None
31+
name = None
32+
for arg_match in ARG_STRING_RE.finditer(args):
33+
key = arg_match.group("key")
34+
value = arg_match.group("value")
35+
if key == "module":
36+
module = value
37+
elif key == "name":
38+
name = value
39+
40+
remainder = text[match.end() :]
41+
struct_match = STRUCT_NAME_RE.search(remainder)
42+
struct_name = struct_match.group("name") if struct_match else None
43+
44+
yield PyClass(
45+
module=module or "datafusion",
46+
name=name or struct_name or "<unknown>",
47+
frozen=frozen,
48+
source=path,
49+
)
50+
51+
52+
def test_pyclasses_are_frozen() -> None:
53+
allowlist = {
54+
# NOTE: Any new exceptions must include a justification comment in the Rust source
55+
# and, ideally, a follow-up issue to remove the exemption.
56+
("datafusion.common", "SqlTable"),
57+
("datafusion.common", "SqlView"),
58+
("datafusion.common", "DataTypeMap"),
59+
("datafusion.expr", "TryCast"),
60+
("datafusion.expr", "WriteOp"),
61+
}
62+
63+
unfrozen = [
64+
pyclass
65+
for pyclass in iter_pyclasses(Path("src"))
66+
if not pyclass.frozen and (pyclass.module, pyclass.name) not in allowlist
67+
]
68+
69+
assert not unfrozen, (
70+
"Found pyclasses missing `frozen`; add them to the allowlist only with a "
71+
"justification comment and follow-up plan:\n" +
72+
"\n".join(
73+
f"- {pyclass.module}.{pyclass.name} (defined in {pyclass.source})"
74+
for pyclass in unfrozen
75+
)
76+
)

0 commit comments

Comments
 (0)