Skip to content

Commit 357ae12

Browse files
Bordapre-commit-ci[bot]Copilot
authored
cli: debug optional args (#420)
* tests: debug optional args * as_positional=False * Apply suggestions from code review --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <[email protected]>
1 parent 757ea65 commit 357ae12

File tree

3 files changed

+78
-8
lines changed

3 files changed

+78
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616

1717
### Fixed
1818

19+
- CLI: fix accepting keyword arguments ([#420](https://github.com/Lightning-AI/utilities/pull/420))
1920
- Scripts: fix CLI parsing ([#419](https://github.com/Lightning-AI/utilities/pull/419))
2021

2122

src/lightning_utilities/cli/__main__.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,18 @@ def main() -> None:
2121
from jsonargparse import auto_cli, set_parsing_settings
2222

2323
set_parsing_settings(parse_optionals_as_positionals=True)
24-
auto_cli({
25-
"requirements": {
26-
"_help": "Manage requirements files.",
27-
"prune-pkgs": prune_packages_in_requirements,
28-
"set-oldest": replace_oldest_version,
29-
"replace-pkg": replace_package_in_requirements,
24+
auto_cli(
25+
{
26+
"requirements": {
27+
"_help": "Manage requirements files.",
28+
"prune-pkgs": prune_packages_in_requirements,
29+
"set-oldest": replace_oldest_version,
30+
"replace-pkg": replace_package_in_requirements,
31+
},
32+
"version": _get_version,
3033
},
31-
"version": _get_version,
32-
})
34+
as_positional=False,
35+
)
3336

3437

3538
if __name__ == "__main__":
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import subprocess
2+
from pathlib import Path
3+
4+
import pytest
5+
6+
7+
def test_version():
8+
"""Prints the help message for the requirements commands."""
9+
return_code = subprocess.call(["python", "-mlightning_utilities.cli", "version"]) # noqa: S607
10+
assert return_code == 0
11+
12+
13+
@pytest.mark.parametrize("args", ["positional", "optional"])
14+
class TestRequirements:
15+
"""Test requirements commands."""
16+
17+
BASE_CMD = ("python", "-m", "lightning_utilities.cli", "requirements")
18+
REQUIREMENTS_SAMPLE = """
19+
# This is sample requirements file
20+
# with multi line comments
21+
22+
torchvision >=0.13.0, <0.16.0 # sample # comment
23+
gym[classic,control] >=0.17.0, <0.27.0
24+
ipython[all] <8.15.0 # strict
25+
torchmetrics >=0.10.0, <1.3.0
26+
deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict
27+
"""
28+
29+
def _create_requirements_file(self, local_path: Path, filename: str = "testing-cli-requirements.txt"):
30+
"""Create a sample requirements file."""
31+
req_file = local_path / filename
32+
with open(req_file, "w", encoding="utf8") as fopen:
33+
fopen.write(self.REQUIREMENTS_SAMPLE)
34+
return str(req_file)
35+
36+
def _build_command(self, subcommand: str, cli_params: tuple, arg_style: str):
37+
"""Build the command for the CLI."""
38+
if arg_style == "positional":
39+
return list(self.BASE_CMD) + [subcommand] + [value for _, value in cli_params]
40+
if arg_style == "optional":
41+
return list(self.BASE_CMD) + [subcommand] + [f"--{key}={value}" for key, value in cli_params]
42+
raise ValueError(f"Unknown test configuration: {arg_style}")
43+
44+
def test_requirements_prune_pkgs(self, args, tmp_path):
45+
"""Prune packages from requirements files."""
46+
req_file = self._create_requirements_file(tmp_path)
47+
cli_params = (("packages", "ipython"), ("req_files", req_file))
48+
cmd = self._build_command("prune-pkgs", cli_params, args)
49+
return_code = subprocess.call(cmd) # noqa: S603
50+
assert return_code == 0
51+
52+
def test_requirements_set_oldest(self, args, tmp_path):
53+
"""Set the oldest version of packages in requirement files."""
54+
req_file = self._create_requirements_file(tmp_path, "requirements.txt")
55+
cli_params = (("req_files", req_file),)
56+
cmd = self._build_command("set-oldest", cli_params, args)
57+
return_code = subprocess.call(cmd) # noqa: S603
58+
assert return_code == 0
59+
60+
def test_requirements_replace_pkg(self, args, tmp_path):
61+
"""Replace a package in requirements files."""
62+
req_file = self._create_requirements_file(tmp_path, "requirements.txt")
63+
cli_params = (("old_package", "torchvision"), ("new_package", "torchtext"), ("req_files", req_file))
64+
cmd = self._build_command("replace-pkg", cli_params, args)
65+
return_code = subprocess.call(cmd) # noqa: S603
66+
assert return_code == 0

0 commit comments

Comments
 (0)