-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsweep_cli.py
111 lines (90 loc) · 2.86 KB
/
sweep_cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import itertools
import json
import math
import os
import shlex
import subprocess
from pathlib import Path
from typing import Any, Literal
import ray
from jsonargparse import CLI
from ray import train, tune
# Initialise Ray when this module is loaded, passing ~/.cache/ray (under the
# current user's home directory) as Ray's `_temp_dir`.
ray.init(_temp_dir=str(Path.home() / ".cache" / "ray"))
def run_cli(config, debug: bool = True, command: str = "fit", devices: int = 1):
    """Build a ``./run <command>`` command line from one trial config and run it.

    The working directory is first changed to ``$TUNE_ORIG_WORKING_DIR`` so the
    relative ``./run`` script and config paths resolve from there.

    Args:
        config: Mapping of trial parameters. The keys ``ckpt_path``, ``config``
            and ``data_config`` are handled specially; every other item is
            forwarded as ``--<key> <value>``, with non-string values encoded
            as JSON. The mapping itself is left unmodified.
        debug: When true, ``--config configs/presets/tester.yaml`` is appended
            as the last argument.
        command: Subcommand passed to ``./run``.
        devices: Value forwarded as ``--trainer.devices``.

    Raises:
        KeyError: If ``TUNE_ORIG_WORKING_DIR`` is not set in the environment.
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    # Work on a copy so the caller's mapping keeps all of its keys.
    config = dict(config)
    os.chdir(os.environ["TUNE_ORIG_WORKING_DIR"])
    argv = ["./run", command]

    ckpt_path = config.pop("ckpt_path", None)
    # Always remove the config-file keys so they are never emitted as overrides.
    config_files = [config.pop(key, None) for key in ("config", "data_config")]

    if ckpt_path is not None:
        # With a checkpoint, the config.yaml found two directory levels above
        # the checkpoint file is used instead of the listed config files.
        config_path = Path(ckpt_path).parents[1] / "config.yaml"
        argv += ["--config", str(config_path), "--ckpt_path", ckpt_path]
    else:
        for config_file in config_files:
            # A None entry means "no file for this slot"; passing it on would
            # put a non-string into argv and fail before the command starts.
            if config_file is not None:
                argv += ["--config", config_file]

    for key, value in config.items():
        argv += [f"--{key}", value if isinstance(value, str) else json.dumps(value)]

    argv += ["--trainer.devices", str(devices)]
    if debug:
        argv += ["--config", "configs/presets/tester.yaml"]

    # Echo the exact command before running it.
    print(shlex.join(argv))
    subprocess.check_output(argv)
def sweep(
    command: Literal["fit", "validate", "test"],
    debug: bool = False,
    gpus_per_trial: int | float = 1,
    *,
    ckpt_paths: list[str | None] | None = None,
    configs: list[str] | None = None,
    data_configs: list[str | None] | None = None,
    override_kwargs: dict[str, Any] | None = None,
):
    """Grid-search over the given options, running `run_cli` once per combination.

    Args:
        command: Subcommand handed to `run_cli` for every trial.
        debug: Forwarded to `run_cli` as its ``debug`` argument.
        gpus_per_trial: GPU resource requested for each trial; rounded up to
            an integer to obtain the ``devices`` argument of `run_cli`.
        ckpt_paths: Values searched under the ``ckpt_path`` key. Omitted from
            the search space when empty or ``None``.
        configs: Values searched under the ``config`` key. Omitted from the
            search space when empty or ``None``.
        data_configs: Values searched under the ``data_config`` key. Omitted
            from the search space when empty or ``None``.
        override_kwargs: Extra search dimensions. A list value is searched
            element by element; any other value becomes a single-choice
            dimension.
    """
    param_space = {}
    if ckpt_paths:
        param_space["ckpt_path"] = tune.grid_search(ckpt_paths)
    if configs:
        param_space["config"] = tune.grid_search(configs)
    if data_configs:
        param_space["data_config"] = tune.grid_search(data_configs)
    # Overrides are added last, so they replace an earlier entry with the same key.
    for key, value in (override_kwargs or {}).items():
        candidates = value if isinstance(value, list) else [value]
        param_space[key] = tune.grid_search(candidates)

    # Results go under ./results/ray, resolved against the current directory.
    results_dir = Path("./results/ray").resolve()
    run_config = train.RunConfig(log_to_file=True, storage_path=results_dir)
    tune_config = tune.TuneConfig()

    bound_trainable = tune.with_parameters(
        run_cli,
        debug=debug,
        command=command,
        devices=math.ceil(gpus_per_trial),
    )
    resourced_trainable = tune.with_resources(
        bound_trainable, resources={"gpu": gpus_per_trial}
    )

    tuner = tune.Tuner(
        resourced_trainable,
        param_space=param_space,
        tune_config=tune_config,
        run_config=run_config,
    )
    tuner.fit()
def fit(*args, **kwargs):
    """Run `sweep` with the ``fit`` command, forwarding all other arguments."""
    sweep("fit", *args, **kwargs)
def validate(*args, **kwargs):
    """Run `sweep` with the ``validate`` command, forwarding all other arguments."""
    sweep("validate", *args, **kwargs)
def test(*args, **kwargs):
    """Run `sweep` with the ``test`` command, forwarding all other arguments."""
    sweep("test", *args, **kwargs)
def sweep_cli():
    """Hand the `fit`, `validate` and `test` functions to jsonargparse's ``CLI``."""
    entry_points = [fit, validate, test]
    CLI(entry_points)
# Start the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    sweep_cli()