Skip to content

Commit 87603c4

Browse files
Use Unit and Experiment aliases in experiment profiles
1 parent 8eca98a commit 87603c4

File tree

2 files changed

+26
-24
lines changed

2 files changed

+26
-24
lines changed

pioreactor/actions/leader/experiment_profile.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pioreactor.experiment_profiles import profile_struct as struct
2020
from pioreactor.logging import create_logger
2121
from pioreactor.logging import CustomLogger
22+
from pioreactor import types as pt
2223
from pioreactor.mureq import HTTPException
2324
from pioreactor.pubsub import get_from
2425
from pioreactor.pubsub import patch_into
@@ -135,7 +136,7 @@ def check_syntax_of_bool_expression(bool_expression: BoolExpression) -> bool:
135136
return check_syntax(bool_expression)
136137

137138

138-
def check_if_job_running(unit: str, job: str) -> bool:
139+
def check_if_job_running(unit: pt.Unit, job: str) -> bool:
139140
if is_testing_env():
140141
return True
141142
try:
@@ -197,8 +198,8 @@ def get_simple_priority(action: struct.Action):
197198

198199

199200
def wrapped_execute_action(
200-
unit: str,
201-
experiment: str,
201+
unit: pt.Unit,
202+
experiment: pt.Experiment,
202203
global_env: Env,
203204
job_name: str,
204205
logger: CustomLogger,
@@ -322,7 +323,7 @@ def combined_function() -> None:
322323

323324

324325
def common_wrapped_execute_action(
325-
experiment: str,
326+
experiment: pt.Experiment,
326327
job_name: str,
327328
global_env: Env,
328329
logger: CustomLogger,
@@ -353,8 +354,8 @@ def common_wrapped_execute_action(
353354

354355

355356
def when(
356-
unit: str,
357-
experiment: str,
357+
unit: pt.Unit,
358+
experiment: pt.Experiment,
358359
parent_job: long_running_managed_lifecycle,
359360
job_name: str,
360361
dry_run: bool,
@@ -431,8 +432,8 @@ def _callable() -> None:
431432

432433

433434
def repeat(
434-
unit: str,
435-
experiment: str,
435+
unit: pt.Unit,
436+
experiment: pt.Experiment,
436437
parent_job: long_running_managed_lifecycle,
437438
job_name: str,
438439
dry_run: bool,
@@ -521,8 +522,8 @@ def _callable() -> None:
521522

522523

523524
def log(
524-
unit: str,
525-
experiment: str,
525+
unit: pt.Unit,
526+
experiment: pt.Experiment,
526527
parent_job: long_running_managed_lifecycle,
527528
job_name: str,
528529
dry_run: bool,
@@ -565,8 +566,8 @@ def _callable() -> None:
565566

566567

567568
def start_job(
568-
unit: str,
569-
experiment: str,
569+
unit: pt.Unit,
570+
experiment: pt.Experiment,
570571
parent_job: long_running_managed_lifecycle,
571572
job_name: str,
572573
dry_run: bool,
@@ -627,8 +628,8 @@ def _callable() -> None:
627628

628629

629630
def pause_job(
630-
unit: str,
631-
experiment: str,
631+
unit: pt.Unit,
632+
experiment: pt.Experiment,
632633
parent_job: long_running_managed_lifecycle,
633634
job_name: str,
634635
dry_run: bool,
@@ -670,8 +671,8 @@ def _callable() -> None:
670671

671672

672673
def resume_job(
673-
unit: str,
674-
experiment: str,
674+
unit: pt.Unit,
675+
experiment: pt.Experiment,
675676
parent_job: long_running_managed_lifecycle,
676677
job_name: str,
677678
dry_run: bool,
@@ -714,8 +715,8 @@ def _callable() -> None:
714715

715716

716717
def stop_job(
717-
unit: str,
718-
experiment: str,
718+
unit: pt.Unit,
719+
experiment: pt.Experiment,
719720
parent_job: long_running_managed_lifecycle,
720721
job_name: str,
721722
dry_run: bool,
@@ -754,8 +755,8 @@ def _callable() -> None:
754755

755756

756757
def update_job(
757-
unit: str,
758-
experiment: str,
758+
unit: pt.Unit,
759+
experiment: pt.Experiment,
759760
parent_job: long_running_managed_lifecycle,
760761
job_name: str,
761762
dry_run: bool,
@@ -855,7 +856,7 @@ def load_and_verify_profile(profile_filename: str) -> struct.Profile:
855856
return profile
856857

857858

858-
def push_labels_to_ui(experiment, labels_map: dict[str, str]) -> None:
859+
def push_labels_to_ui(experiment: pt.Experiment, labels_map: dict[str, str]) -> None:
859860
try:
860861
for unit_name, label in labels_map.items():
861862
patch_into_leader(
@@ -912,7 +913,7 @@ def check_plugins(required_plugins: list[struct.Plugin]) -> None:
912913
raise ImportError(f"Missing plugins: {not_installed}")
913914

914915

915-
def execute_experiment_profile(profile_filename: str, experiment: str, dry_run: bool = False) -> None:
916+
def execute_experiment_profile(profile_filename: str, experiment: pt.Experiment, dry_run: bool = False) -> None:
916917
unit = get_unit_name()
917918
action_name = "experiment_profile"
918919
with long_running_managed_lifecycle(unit, experiment, action_name) as mananged_job:
@@ -1065,7 +1066,7 @@ def click_experiment_profile():
10651066
@click.argument("filename", type=click.Path())
10661067
@click.argument("experiment", type=str)
10671068
@click.option("--dry-run", is_flag=True, help="Don't actually execute, just print to screen")
1068-
def click_execute_experiment_profile(filename: str, experiment: str, dry_run: bool) -> None:
1069+
def click_execute_experiment_profile(filename: str, experiment: pt.Experiment, dry_run: bool) -> None:
10691070
"""
10701071
(leader only) Run an experiment profile.
10711072
"""

pioreactor/experiment_profiles/profile_struct.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from msgspec import field
77
from msgspec import Struct
8+
from pioreactor import types as pt
89

910

1011
bool_expression = str | bool
@@ -102,7 +103,7 @@ class Job(Struct, forbid_unknown_fields=True):
102103
# logging?
103104

104105

105-
PioreactorUnitName = str
106+
PioreactorUnitName = pt.Unit
106107
PioreactorLabel = str
107108
JobName = str
108109
Jobs = dict[JobName, Job]

0 commit comments

Comments
 (0)