|
| 1 | +import subprocess |
| 2 | +from pathlib import Path |
| 3 | + |
| 4 | +import pytest |
| 5 | + |
| 6 | + |
| 7 | +def test_version(): |
| 8 | + """Prints the help message for the requirements commands.""" |
| 9 | + return_code = subprocess.call(["python", "-mlightning_utilities.cli", "version"]) # noqa: S607 |
| 10 | + assert return_code == 0 |
| 11 | + |
| 12 | + |
| 13 | +@pytest.mark.parametrize("args", ["positional", "optional"]) |
| 14 | +class TestRequirements: |
| 15 | + """Test requirements commands.""" |
| 16 | + |
| 17 | + BASE_CMD = ("python", "-m", "lightning_utilities.cli", "requirements") |
| 18 | + REQUIREMENTS_SAMPLE = """ |
| 19 | +# This is sample requirements file |
| 20 | +# with multi line comments |
| 21 | +
|
| 22 | +torchvision >=0.13.0, <0.16.0 # sample # comment |
| 23 | +gym[classic,control] >=0.17.0, <0.27.0 |
| 24 | +ipython[all] <8.15.0 # strict |
| 25 | +torchmetrics >=0.10.0, <1.3.0 |
| 26 | +deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict |
| 27 | + """ |
| 28 | + |
| 29 | + def _create_requirements_file(self, local_path: Path, filename: str = "testing-cli-requirements.txt"): |
| 30 | + """Create a sample requirements file.""" |
| 31 | + req_file = local_path / filename |
| 32 | + with open(req_file, "w", encoding="utf8") as fopen: |
| 33 | + fopen.write(self.REQUIREMENTS_SAMPLE) |
| 34 | + return str(req_file) |
| 35 | + |
| 36 | + def _build_command(self, subcommand: str, cli_params: tuple, arg_style: str): |
| 37 | + """Build the command for the CLI.""" |
| 38 | + if arg_style == "positional": |
| 39 | + return list(self.BASE_CMD) + [subcommand] + [value for _, value in cli_params] |
| 40 | + if arg_style == "optional": |
| 41 | + return list(self.BASE_CMD) + [subcommand] + [f"--{key}={value}" for key, value in cli_params] |
| 42 | + raise ValueError(f"Unknown test configuration: {arg_style}") |
| 43 | + |
| 44 | + def test_requirements_prune_pkgs(self, args, tmp_path): |
| 45 | + """Prune packages from requirements files.""" |
| 46 | + req_file = self._create_requirements_file(tmp_path) |
| 47 | + cli_params = (("packages", "ipython"), ("req_files", req_file)) |
| 48 | + cmd = self._build_command("prune-pkgs", cli_params, args) |
| 49 | + return_code = subprocess.call(cmd) # noqa: S603 |
| 50 | + assert return_code == 0 |
| 51 | + |
| 52 | + def test_requirements_set_oldest(self, args, tmp_path): |
| 53 | + """Set the oldest version of packages in requirement files.""" |
| 54 | + req_file = self._create_requirements_file(tmp_path, "requirements.txt") |
| 55 | + cli_params = (("req_files", req_file),) |
| 56 | + cmd = self._build_command("set-oldest", cli_params, args) |
| 57 | + return_code = subprocess.call(cmd) # noqa: S603 |
| 58 | + assert return_code == 0 |
| 59 | + |
| 60 | + def test_requirements_replace_pkg(self, args, tmp_path): |
| 61 | + """Replace a package in requirements files.""" |
| 62 | + req_file = self._create_requirements_file(tmp_path, "requirements.txt") |
| 63 | + cli_params = (("old_package", "torchvision"), ("new_package", "torchtext"), ("req_files", req_file)) |
| 64 | + cmd = self._build_command("replace-pkg", cli_params, args) |
| 65 | + return_code = subprocess.call(cmd) # noqa: S603 |
| 66 | + assert return_code == 0 |
0 commit comments