Skip to content

Commit 1a4ed34

Browse files
committed
fix(cargo_provider): support workspace virtual manifests
1 parent b64c6f6 commit 1a4ed34

File tree

1 file changed

+87
-70
lines changed

1 file changed

+87
-70
lines changed
Lines changed: 87 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,122 @@
11
from __future__ import annotations
22

3-
import fnmatch
4-
import glob
3+
import fnmatch, glob
54
from pathlib import Path
6-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, Iterable, Any, cast
76

87
from tomlkit import TOMLDocument, dumps, parse
98
from tomlkit.exceptions import NonExistentKey
10-
119
from commitizen.providers.base_provider import TomlProvider
1210

1311
if TYPE_CHECKING:
1412
from tomlkit.items import AoT
1513

1614

17-
class CargoProvider(TomlProvider):
18-
"""
19-
Cargo version management
15+
DictLike = dict[str, Any]
2016

21-
With support for `workspaces`
22-
"""
17+
18+
class CargoProvider(TomlProvider):
19+
"""Cargo version management for virtual workspace manifests + version.workspace=true members."""
2320

2421
filename = "Cargo.toml"
2522
lock_filename = "Cargo.lock"
2623

2724
@property
2825
def lock_file(self) -> Path:
29-
return Path() / self.lock_filename
26+
return Path(self.lock_filename)
3027

3128
def get(self, document: TOMLDocument) -> str:
32-
out = _try_get_workspace(document)["package"]["version"]
33-
if TYPE_CHECKING:
34-
assert isinstance(out, str)
35-
return out
29+
t = _root_version_table(document)
30+
v = t.get("version")
31+
if not isinstance(v, str):
32+
raise TypeError("expected root version to be a string")
33+
return v
3634

3735
def set(self, document: TOMLDocument, version: str) -> None:
38-
_try_get_workspace(document)["package"]["version"] = version
36+
_root_version_table(document)["version"] = version
3937

4038
def set_version(self, version: str) -> None:
4139
super().set_version(version)
4240
if self.lock_file.exists():
4341
self.set_lock_version(version)
4442

4543
def set_lock_version(self, version: str) -> None:
46-
cargo_toml_content = parse(self.file.read_text())
47-
cargo_lock_content = parse(self.lock_file.read_text())
48-
packages = cargo_lock_content["package"]
49-
44+
cargo_toml = parse(self.file.read_text())
45+
cargo_lock = parse(self.lock_file.read_text())
46+
packages = cargo_lock["package"]
5047
if TYPE_CHECKING:
5148
assert isinstance(packages, AoT)
5249

53-
try:
54-
cargo_package_name = cargo_toml_content["package"]["name"] # type: ignore[index]
55-
if TYPE_CHECKING:
56-
assert isinstance(cargo_package_name, str)
57-
for i, package in enumerate(packages):
58-
if package["name"] == cargo_package_name:
59-
cargo_lock_content["package"][i]["version"] = version # type: ignore[index]
60-
break
61-
except NonExistentKey:
62-
workspace = cargo_toml_content.get("workspace", {})
63-
if TYPE_CHECKING:
64-
assert isinstance(workspace, dict)
65-
workspace_members = workspace.get("members", [])
66-
excluded_workspace_members = workspace.get("exclude", [])
67-
members_inheriting: list[str] = []
68-
69-
for member in workspace_members:
70-
for path in glob.glob(member, recursive=True):
71-
if any(
72-
fnmatch.fnmatch(path, pattern)
73-
for pattern in excluded_workspace_members
74-
):
75-
continue
76-
77-
cargo_file = Path(path) / "Cargo.toml"
78-
package_content = parse(cargo_file.read_text()).get("package", {})
79-
if TYPE_CHECKING:
80-
assert isinstance(package_content, dict)
81-
try:
82-
version_workspace = package_content["version"]["workspace"]
83-
if version_workspace is True:
84-
package_name = package_content["name"]
85-
if TYPE_CHECKING:
86-
assert isinstance(package_name, str)
87-
members_inheriting.append(package_name)
88-
except NonExistentKey:
89-
pass
90-
91-
for i, package in enumerate(packages):
92-
if package["name"] in members_inheriting:
93-
cargo_lock_content["package"][i]["version"] = version # type: ignore[index]
94-
95-
self.lock_file.write_text(dumps(cargo_lock_content))
96-
97-
98-
def _try_get_workspace(document: TOMLDocument) -> dict:
50+
root_pkg = _table_get(cargo_toml, "package")
51+
if root_pkg is not None:
52+
name = root_pkg.get("name")
53+
if isinstance(name, str):
54+
_lock_set_versions(packages, {name}, version)
55+
self.lock_file.write_text(dumps(cargo_lock))
56+
return
57+
58+
ws = _table_get(cargo_toml, "workspace") or {}
59+
members = cast(list[str], ws.get("members", []) or [])
60+
excludes = cast(list[str], ws.get("exclude", []) or [])
61+
inheriting = _workspace_inheriting_member_names(members, excludes)
62+
_lock_set_versions(packages, inheriting, version)
63+
self.lock_file.write_text(dumps(cargo_lock))
64+
65+
66+
def _table_get(doc: TOMLDocument, key: str) -> DictLike | None:
67+
"""Return a dict-like table for `key` if present, else None (type-safe for Pylance)."""
9968
try:
100-
workspace = document["workspace"]
101-
if TYPE_CHECKING:
102-
assert isinstance(workspace, dict)
103-
return workspace
69+
v = doc[key] # tomlkit returns Container/Table-like; typing is loose
10470
except NonExistentKey:
105-
return document
71+
return None
72+
return cast(DictLike, v) if hasattr(v, "get") else None
73+
74+
75+
def _root_version_table(doc: TOMLDocument) -> DictLike:
76+
"""Prefer [workspace.package]; fallback to [package]."""
77+
ws = _table_get(doc, "workspace")
78+
if ws is not None:
79+
pkg = ws.get("package")
80+
if hasattr(pkg, "get"):
81+
return cast(DictLike, pkg)
82+
pkg = _table_get(doc, "package")
83+
if pkg is None:
84+
raise NonExistentKey('expected either [workspace.package] or [package]')
85+
return pkg
86+
87+
88+
def _is_workspace_inherited_version(v: Any) -> bool:
89+
return hasattr(v, "get") and cast(DictLike, v).get("workspace") is True
90+
91+
92+
def _iter_member_dirs(members: Iterable[str], excludes: Iterable[str]) -> Iterable[Path]:
93+
for pat in members:
94+
for p in glob.glob(pat, recursive=True):
95+
if any(fnmatch.fnmatch(p, ex) for ex in excludes):
96+
continue
97+
yield Path(p)
98+
99+
100+
def _workspace_inheriting_member_names(members: Iterable[str], excludes: Iterable[str]) -> set[str]:
101+
out: set[str] = set()
102+
for d in _iter_member_dirs(members, excludes):
103+
cargo_file = d / "Cargo.toml"
104+
if not cargo_file.exists():
105+
continue
106+
pkg = parse(cargo_file.read_text()).get("package")
107+
if not hasattr(pkg, "get"):
108+
continue
109+
pkgd = cast(DictLike, pkg)
110+
if _is_workspace_inherited_version(pkgd.get("version")):
111+
name = pkgd.get("name")
112+
if isinstance(name, str):
113+
out.add(name)
114+
return out
115+
116+
117+
def _lock_set_versions(packages: Any, names: set[str], version: str) -> None:
118+
if not names:
119+
return
120+
for i, p in enumerate(packages):
121+
if getattr(p, "get", None) and p.get("name") in names:
122+
packages[i]["version"] = version # type: ignore[index]

0 commit comments

Comments
 (0)