2121from .utils import RestartAction , RetryWithBackoff , RevertAction
2222
2323
24- DEFAULT_WORKFLOWS = ["Lint" , "trunk" , "pull" , "inductor" , "linux-aarch64" ]
25- DEFAULT_REPO_FULL_NAME = "pytorch/pytorch"
26- DEFAULT_HOURS = 16
27- DEFAULT_COMMENT_ISSUE_NUMBER = (
28- 163650 # https://github.com/pytorch/pytorch/issues/163650
29- )
3024# Special constant to indicate --hud-html was passed as a flag (without a value)
3125HUD_HTML_NO_VALUE_FLAG = object ()
3226
3327
28+ class DefaultConfig :
29+ def __init__ (self ):
30+ self .bisection_limit = (
31+ int (os .environ ["BISECTION_LIMIT" ])
32+ if "BISECTION_LIMIT" in os .environ
33+ else None
34+ )
35+ self .clickhouse_database = os .environ .get ("CLICKHOUSE_DATABASE" , "default" )
36+ self .clickhouse_host = os .environ .get ("CLICKHOUSE_HOST" , "localhost" )
37+ self .clickhouse_password = os .environ .get ("CLICKHOUSE_PASSWORD" , "" )
38+ self .clickhouse_port = int (os .environ .get ("CLICKHOUSE_PORT" , 8443 ))
39+ self .clickhouse_username = os .environ .get ("CLICKHOUSE_USERNAME" , "" )
40+ self .github_access_token = os .environ .get ("GITHUB_TOKEN" , "" )
41+ self .github_app_id = os .environ .get ("GITHUB_APP_ID" , "" )
42+ self .github_app_secret = os .environ .get ("GITHUB_APP_SECRET" , "" )
43+ self .github_installation_id = os .environ .get ("GITHUB_INSTALLATION_ID" , "" )
44+ self .hours = int (os .environ .get ("HOURS" , 16 ))
45+ self .log_level = os .environ .get ("LOG_LEVEL" , "INFO" )
46+ self .notify_issue_number = int (
47+ os .environ .get ("NOTIFY_ISSUE_NUMBER" , 163650 )
48+ ) # https://github.com/pytorch/pytorch/issues/163650
49+ self .repo_full_name = os .environ .get ("REPO_FULL_NAME" , "pytorch/pytorch" )
50+ self .restart_action = (
51+ RestartAction .from_str (os .environ ["RESTART_ACTION" ])
52+ if "RESTART_ACTION" in os .environ
53+ else None
54+ )
55+ self .revert_action = (
56+ RevertAction .from_str (os .environ ["REVERT_ACTION" ])
57+ if "REVERT_ACTION" in os .environ
58+ else None
59+ )
60+ self .secret_store_name = os .environ .get ("SECRET_STORE_NAME" , "" )
61+ self .workflows = os .environ .get (
62+ "WORKFLOWS" ,
63+ "," .join (["Lint" , "trunk" , "pull" , "inductor" , "linux-aarch64" ]),
64+ ).split ("," )
65+
66+ def to_autorevert_v2_params (
67+ self ,
68+ * ,
69+ default_restart_action : RestartAction ,
70+ default_revert_action : RevertAction ,
71+ dry_run : bool ,
72+ ) -> dict :
73+ """Convert the configuration to a dictionary."""
74+ return {
75+ "workflows" : self .workflows ,
76+ "repo_full_name" : self .repo_full_name ,
77+ "hours" : self .hours ,
78+ "notify_issue_number" : self .notify_issue_number ,
79+ "restart_action" : RestartAction .LOG
80+ if dry_run
81+ else (self .restart_action or default_restart_action ),
82+ "revert_action" : RevertAction .LOG
83+ if dry_run
84+ else (self .revert_action or default_revert_action ),
85+ "bisection_limit" : self .bisection_limit ,
86+ }
87+
88+
89+ def validate_actions_dry_run (
90+ opts : argparse .Namespace , default_config : DefaultConfig
91+ ) -> None :
92+ """Validate the actions to be taken in dry run mode."""
93+ if (
94+ default_config .restart_action is not None
95+ or default_config .revert_action is not None
96+ ) and opts .dry_run :
97+ logging .error (
98+ "Dry run mode: using dry-run flag with environment variables is not allowed."
99+ )
100+ raise ValueError (
101+ "Conflicting options: --dry-run with explicit actions via environment variables"
102+ )
103+ if (
104+ opts .subcommand == "autorevert-checker"
105+ and (opts .restart_action is not None or opts .revert_action is not None )
106+ and opts .dry_run
107+ ):
108+ logging .error (
109+ "Dry run mode: using dry-run flag with explicit actions is not allowed."
110+ )
111+ raise ValueError ("Conflicting options: --dry-run with explicit actions" )
112+
113+
34114def setup_logging (log_level : str ) -> None :
35115 """Set up logging configuration."""
36116 numeric_level = getattr (logging , log_level .upper (), None )
@@ -54,45 +134,41 @@ def setup_logging(log_level: str) -> None:
54134 handler .setLevel (numeric_level )
55135
56136
57- def get_opts () -> argparse .Namespace :
137+ def get_opts (default_config : DefaultConfig ) -> argparse .Namespace :
58138 parser = argparse .ArgumentParser ()
59139
60140 # General options and configurations
61141 parser .add_argument (
62142 "--log-level" ,
63- default = os . environ . get ( "LOG_LEVEL" , "INFO" ) ,
143+ default = default_config . log_level ,
64144 choices = ["NOTSET" , "DEBUG" , "INFO" , "WARNING" , "ERROR" , "CRITICAL" ],
65145 help = "Set the logging level for the application." ,
66146 )
67- parser .add_argument (
68- "--clickhouse-host" , default = os .environ .get ("CLICKHOUSE_HOST" , "" )
69- )
147+ parser .add_argument ("--clickhouse-host" , default = default_config .clickhouse_host )
70148 parser .add_argument (
71149 "--clickhouse-port" ,
72150 type = int ,
73- default = int ( os . environ . get ( "CLICKHOUSE_PORT" , "8443" )) ,
151+ default = default_config . clickhouse_port ,
74152 )
75153 parser .add_argument (
76- "--clickhouse-username" , default = os . environ . get ( "CLICKHOUSE_USERNAME" , "" )
154+ "--clickhouse-username" , default = default_config . clickhouse_username
77155 )
78156 parser .add_argument (
79- "--clickhouse-password" , default = os . environ . get ( "CLICKHOUSE_PASSWORD" , "" )
157+ "--clickhouse-password" , default = default_config . clickhouse_password
80158 )
81159 parser .add_argument (
82160 "--clickhouse-database" ,
83- default = os .environ .get ("CLICKHOUSE_DATABASE" , "default" ),
84- )
85- parser .add_argument (
86- "--github-access-token" , default = os .environ .get ("GITHUB_TOKEN" , "" )
161+ default = default_config .clickhouse_database ,
87162 )
88- parser .add_argument ("--github-app-id" , default = os .environ .get ("GITHUB_APP_ID" , "" ))
89163 parser .add_argument (
90- "--github-app-secret " , default = os . environ . get ( "GITHUB_APP_SECRET" , "" )
164+ "--github-access-token " , default = default_config . github_access_token
91165 )
166+ parser .add_argument ("--github-app-id" , default = default_config .github_app_id )
167+ parser .add_argument ("--github-app-secret" , default = default_config .github_app_secret )
92168 parser .add_argument (
93169 "--github-installation-id" ,
94170 type = int ,
95- default = int ( os . environ . get ( "GITHUB_INSTALLATION_ID" , "0" )) ,
171+ default = default_config . github_installation_id ,
96172 )
97173 parser .add_argument (
98174 "--dry-run" ,
@@ -102,7 +178,7 @@ def get_opts() -> argparse.Namespace:
102178 parser .add_argument (
103179 "--secret-store-name" ,
104180 action = "store" ,
105- default = os . environ . get ( "SECRET_STORE_NAME" , "" ) ,
181+ default = default_config . secret_store_name ,
106182 help = "Name of the secret in AWS Secrets Manager to fetch GitHub App secret from" ,
107183 )
108184
@@ -117,41 +193,38 @@ def get_opts() -> argparse.Namespace:
117193 workflow_parser .add_argument (
118194 "workflows" ,
119195 nargs = "+" ,
120- default = DEFAULT_WORKFLOWS ,
196+ default = default_config . workflows ,
121197 help = "Workflow name(s) to analyze - single name or comma/space separated"
122198 + ' list (e.g., "pull" or "pull,trunk,inductor")' ,
123199 )
124200 workflow_parser .add_argument (
125201 "--hours" ,
126202 type = int ,
127- default = DEFAULT_HOURS ,
128- help = f"Lookback window in hours (default: { DEFAULT_HOURS } )" ,
203+ default = default_config . hours ,
204+ help = f"Lookback window in hours (default: { default_config . hours } )" ,
129205 )
130206 workflow_parser .add_argument (
131207 "--repo-full-name" ,
132- default = os . environ . get ( "REPO_FULL_NAME" , DEFAULT_REPO_FULL_NAME ) ,
208+ default = default_config . repo_full_name ,
133209 help = "Full repo name to filter by (owner/repo)." ,
134210 )
135211 workflow_parser .add_argument (
136212 "--restart-action" ,
137213 type = RestartAction .from_str ,
138- default = RestartAction .from_str (
139- os .environ .get ("RESTART_ACTION" , RestartAction .RUN )
140- ),
214+ default = default_config .restart_action ,
141215 choices = list (RestartAction ),
142216 help = (
143- "Restart mode: skip (no logging), log (no side effects), or run (dispatch)."
217+ "Restart mode: skip (no logging), log (no side effects), or run (dispatch). Default is run. "
144218 ),
145219 )
146220 workflow_parser .add_argument (
147221 "--revert-action" ,
148222 type = RevertAction .from_str ,
149- default = RevertAction .from_str (
150- os .environ .get ("REVERT_ACTION" , RevertAction .LOG )
151- ),
223+ default = default_config .revert_action ,
152224 choices = list (RevertAction ),
153225 help = (
154- "Revert mode: skip, log (no side effects), run-log (prod-style logging), run-notify, or run-revert."
226+ "Revert mode: skip, log (no side effects), run-log (prod-style logging), run-notify, or "
227+ "run-revert. Default is log."
155228 ),
156229 )
157230 workflow_parser .add_argument (
@@ -166,22 +239,16 @@ def get_opts() -> argparse.Namespace:
166239 workflow_parser .add_argument (
167240 "--bisection-limit" ,
168241 type = int ,
169- default = (
170- int (os .environ ["BISECTION_LIMIT" ])
171- if os .environ .get ("BISECTION_LIMIT" , "" ).strip ()
172- else None
173- ),
242+ default = default_config .bisection_limit ,
174243 help = (
175244 "Max new pending jobs to schedule per signal to cover gaps (None = unlimited)."
176245 ),
177246 )
178247 workflow_parser .add_argument (
179248 "--notify-issue-number" ,
180249 type = int ,
181- default = int (
182- os .environ .get ("NOTIFY_ISSUE_NUMBER" , DEFAULT_COMMENT_ISSUE_NUMBER )
183- ),
184- help = f"Issue number to notify (default: { DEFAULT_COMMENT_ISSUE_NUMBER } )" ,
250+ default = default_config .notify_issue_number ,
251+ help = "Issue number to notify" ,
185252 )
186253
187254 # workflow-restart-checker subcommand
@@ -265,7 +332,8 @@ def get_secret_from_aws(secret_store_name: str) -> AWSSecretsFromStore:
265332
266333def main (* args , ** kwargs ) -> None :
267334 load_dotenv ()
268- opts = get_opts ()
335+ default_config = DefaultConfig ()
336+ opts = get_opts (default_config )
269337
270338 gh_app_secret = ""
271339 if opts .github_app_secret :
@@ -299,51 +367,39 @@ def main(*args, **kwargs) -> None:
299367 )
300368
301369 if opts .subcommand is None :
302- repo_name = os .environ .get ("REPO_FULL_NAME" , DEFAULT_REPO_FULL_NAME )
303-
304- if check_autorevert_disabled (repo_name ):
370+ if check_autorevert_disabled (default_config .repo_full_name ):
305371 logging .error (
306372 "Autorevert is disabled via circuit breaker (ci: disable-autorevert issue found). "
307373 "Exiting successfully."
308374 )
309375 return
310376
311- # Read env-driven defaults for the lambda path
312- _bis_env = os .environ .get ("BISECTION_LIMIT" , "" ).strip ()
313- _bis_limit = int (_bis_env ) if _bis_env else None
377+ validate_actions_dry_run (opts , default_config )
314378
315379 autorevert_v2 (
316- os .environ .get ("WORKFLOWS" , "," .join (DEFAULT_WORKFLOWS )).split ("," ),
317- hours = int (os .environ .get ("HOURS" , DEFAULT_HOURS )),
318- notify_issue_number = int (
319- os .environ .get ("NOTIFY_ISSUE_NUMBER" , DEFAULT_COMMENT_ISSUE_NUMBER )
320- ),
321- repo_full_name = repo_name ,
380+ ** default_config .to_autorevert_v2_params (
381+ default_restart_action = RestartAction .RUN ,
382+ default_revert_action = RevertAction .RUN_NOTIFY ,
383+ dry_run = opts .dry_run ,
384+ )
385+ )
386+ elif opts .subcommand == "autorevert-checker" :
387+ validate_actions_dry_run (opts , default_config )
388+ _ , _ , state_json = autorevert_v2 (
389+ opts .workflows ,
390+ hours = opts .hours ,
391+ notify_issue_number = opts .notify_issue_number ,
392+ repo_full_name = opts .repo_full_name ,
322393 restart_action = (
323394 RestartAction .LOG
324395 if opts .dry_run
325- else RestartAction .from_str (
326- os .environ .get ("RESTART_ACTION" , RestartAction .RUN )
327- )
396+ else (opts .restart_action or RestartAction .RUN )
328397 ),
329398 revert_action = (
330399 RevertAction .LOG
331400 if opts .dry_run
332- else RevertAction .from_str (
333- os .environ .get ("REVERT_ACTION" , RevertAction .RUN_NOTIFY )
334- )
401+ else (opts .revert_action or RevertAction .LOG )
335402 ),
336- bisection_limit = _bis_limit ,
337- )
338- elif opts .subcommand == "autorevert-checker" :
339- # New default behavior under the same subcommand
340- _ , _ , state_json = autorevert_v2 (
341- opts .workflows ,
342- hours = opts .hours ,
343- notify_issue_number = opts .notify_issue_number ,
344- repo_full_name = opts .repo_full_name ,
345- restart_action = (RestartAction .LOG if opts .dry_run else opts .restart_action ),
346- revert_action = (RevertAction .LOG if opts .dry_run else opts .revert_action ),
347403 bisection_limit = opts .bisection_limit ,
348404 )
349405 write_hud_html_from_cli (opts .hud_html , HUD_HTML_NO_VALUE_FLAG , state_json )
0 commit comments