1515# flake8: noqa
1616
1717
18+ # ------------------------------
19+ # Test helpers (to avoid duplication)
20+ # ------------------------------
21+
22+
23+ def setup_gh_mocks (
24+ mock_gh_factory , * , pr_number : int = 12345 , state : str | None = None , labels = None
25+ ):
26+ """Set up common GitHub mocks and return (mock_pr, mock_issue).
27+
28+ - Configures GHClientFactory().client.get_repo().get_issue() to return a mock issue
29+ - Creates a mock PR object with provided number/state/labels
30+ """
31+ if labels is None :
32+ labels = []
33+
34+ mock_pr = Mock ()
35+ mock_pr .number = pr_number
36+ if state is not None :
37+ mock_pr .state = state
38+ mock_pr .labels = labels
39+ mock_pr .get_labels .return_value = labels
40+
41+ mock_issue = Mock ()
42+ mock_repo = Mock ()
43+ mock_repo .get_issue .return_value = mock_issue
44+ mock_client = Mock ()
45+ mock_client .get_repo .return_value = mock_repo
46+ mock_gh_factory .return_value .client = mock_client
47+
48+ return mock_pr , mock_issue
49+
50+
51+ def make_ctx (* , revert_action , restart_action = RestartAction .SKIP ):
52+ return RunContext (
53+ ts = datetime .now (timezone .utc ),
54+ notify_issue_number = 123456 ,
55+ repo_full_name = "pytorch/pytorch" ,
56+ workflows = ["trunk" ],
57+ lookback_hours = 24 ,
58+ revert_action = revert_action ,
59+ restart_action = restart_action ,
60+ )
61+
62+
63+ def make_source (
64+ * ,
65+ workflow_name : str = "trunk" ,
66+ key : str = "test_signal" ,
67+ job_base_name : str | None = "linux-jammy / test" ,
68+ wf_run_id : int | None = 12345 ,
69+ job_id : int | None = 67890 ,
70+ ):
71+ return SignalMetadata (
72+ workflow_name = workflow_name ,
73+ key = key ,
74+ job_base_name = job_base_name ,
75+ wf_run_id = wf_run_id ,
76+ job_id = job_id ,
77+ )
78+
79+
80+ def set_find_pr_to_merge (proc : SignalActionProcessor , pr ):
81+ proc ._find_pr_by_sha = Mock (return_value = (CommitPRSourceAction .MERGE , pr ))
82+
83+
1884class FakeLogger :
1985 def __init__ (self ):
2086 self ._recent = []
@@ -251,31 +317,10 @@ def test_comment_no_pr_found(self, mock_gh_factory):
251317 @patch ("pytorch_auto_revert.signal_actions.GHClientFactory" )
252318 def test_comment_with_job_and_hud_links (self , mock_gh_factory ):
253319 """Test that comment includes job link and HUD link when available."""
254- # Mock PR and issue
255- mock_pr = Mock ()
256- mock_pr .number = 12345
257- mock_pr .get_labels .return_value = []
258-
259- mock_issue = Mock ()
260- mock_repo = Mock ()
261- mock_repo .get_issue .return_value = mock_issue
262- mock_client = Mock ()
263- mock_client .get_repo .return_value = mock_repo
264- mock_gh_factory .return_value .client = mock_client
265-
266- self .proc ._find_pr_by_sha = Mock (
267- return_value = (CommitPRSourceAction .MERGE , mock_pr )
268- )
320+ mock_pr , mock_issue = setup_gh_mocks (mock_gh_factory , pr_number = 12345 )
321+ set_find_pr_to_merge (self .proc , mock_pr )
269322
270- sources = [
271- SignalMetadata (
272- workflow_name = "trunk" ,
273- key = "test_signal" ,
274- job_base_name = "linux-jammy / test" ,
275- wf_run_id = 12345 ,
276- job_id = 67890 ,
277- )
278- ]
323+ sources = [make_source ()]
279324
280325 result = self .proc ._comment_issue_pr_revert ("abc123" , sources , self .ctx )
281326
@@ -294,31 +339,10 @@ def test_comment_with_job_and_hud_links(self, mock_gh_factory):
294339 @patch ("pytorch_auto_revert.signal_actions.GHClientFactory" )
295340 def test_comment_with_job_without_hud_links (self , mock_gh_factory ):
296341 """Test that comment includes job link but without HUD link."""
297- # Mock PR and issue
298- mock_pr = Mock ()
299- mock_pr .number = 12345
300- mock_pr .get_labels .return_value = []
301-
302- mock_issue = Mock ()
303- mock_repo = Mock ()
304- mock_repo .get_issue .return_value = mock_issue
305- mock_client = Mock ()
306- mock_client .get_repo .return_value = mock_repo
307- mock_gh_factory .return_value .client = mock_client
308-
309- self .proc ._find_pr_by_sha = Mock (
310- return_value = (CommitPRSourceAction .MERGE , mock_pr )
311- )
342+ mock_pr , mock_issue = setup_gh_mocks (mock_gh_factory , pr_number = 12345 )
343+ set_find_pr_to_merge (self .proc , mock_pr )
312344
313- sources = [
314- SignalMetadata (
315- workflow_name = "trunk" ,
316- key = "test_signal" ,
317- job_base_name = None ,
318- wf_run_id = 12345 ,
319- job_id = 67890 ,
320- )
321- ]
345+ sources = [make_source (job_base_name = None )]
322346
323347 result = self .proc ._comment_issue_pr_revert ("abc123" , sources , self .ctx )
324348
@@ -337,30 +361,10 @@ def test_comment_with_job_without_hud_links(self, mock_gh_factory):
337361 @patch ("pytorch_auto_revert.signal_actions.GHClientFactory" )
338362 def test_comment_without_job_info (self , mock_gh_factory ):
339363 """Test that comment works without job_id/wf_run_id."""
340- mock_pr = Mock ()
341- mock_pr .number = 12345
342- mock_pr .get_labels .return_value = []
343-
344- mock_issue = Mock ()
345- mock_repo = Mock ()
346- mock_repo .get_issue .return_value = mock_issue
347- mock_client = Mock ()
348- mock_client .get_repo .return_value = mock_repo
349- mock_gh_factory .return_value .client = mock_client
350-
351- self .proc ._find_pr_by_sha = Mock (
352- return_value = (CommitPRSourceAction .MERGE , mock_pr )
353- )
364+ mock_pr , mock_issue = setup_gh_mocks (mock_gh_factory , pr_number = 12345 )
365+ set_find_pr_to_merge (self .proc , mock_pr )
354366
355- sources = [
356- SignalMetadata (
357- workflow_name = "trunk" ,
358- key = "test_signal" ,
359- job_base_name = "linux-jammy / test" ,
360- wf_run_id = None ,
361- job_id = None ,
362- )
363- ]
367+ sources = [make_source (wf_run_id = None , job_id = None )]
364368
365369 result = self .proc ._comment_issue_pr_revert ("abc123" , sources , self .ctx )
366370
@@ -381,43 +385,15 @@ def test_comment_autorevert_disabled(self, mock_gh_factory):
381385 """Test that revert is not requested when autorevert is disabled."""
382386 mock_label = Mock ()
383387 mock_label .name = "autorevert: disable"
384-
385- mock_pr = Mock ()
386- mock_pr .number = 12345
387- mock_pr .labels = [mock_label ]
388- mock_pr .get_labels .return_value = [mock_label ]
389-
390- mock_issue = Mock ()
391- mock_repo = Mock ()
392- mock_repo .get_issue .return_value = mock_issue
393- mock_client = Mock ()
394- mock_client .get_repo .return_value = mock_repo
395- mock_gh_factory .return_value .client = mock_client
396-
397- self .proc ._find_pr_by_sha = Mock (
398- return_value = (CommitPRSourceAction .MERGE , mock_pr )
388+ mock_pr , mock_issue = setup_gh_mocks (
389+ mock_gh_factory , pr_number = 12345 , labels = [mock_label ]
399390 )
391+ set_find_pr_to_merge (self .proc , mock_pr )
400392
401393 # Use RUN_REVERT to test the disable logic
402- ctx = RunContext (
403- ts = datetime .now (timezone .utc ),
404- notify_issue_number = 123456 ,
405- repo_full_name = "pytorch/pytorch" ,
406- workflows = ["trunk" ],
407- lookback_hours = 24 ,
408- revert_action = RevertAction .RUN_REVERT ,
409- restart_action = RestartAction .SKIP ,
410- )
394+ ctx = make_ctx (revert_action = RevertAction .RUN_REVERT )
411395
412- sources = [
413- SignalMetadata (
414- workflow_name = "trunk" ,
415- key = "test_signal" ,
416- job_base_name = "linux-jammy / test" ,
417- wf_run_id = 12345 ,
418- job_id = 67890 ,
419- )
420- ]
396+ sources = [make_source ()]
421397
422398 result = self .proc ._comment_issue_pr_revert ("abc123" , sources , ctx )
423399
@@ -430,40 +406,39 @@ def test_comment_autorevert_disabled(self, mock_gh_factory):
430406 # Issue notification should still be created
431407 mock_issue .create_comment .assert_called_once ()
432408
409+ @patch ("pytorch_auto_revert.signal_actions.GHClientFactory" )
410+ def test_comment_pr_open_fallback (self , mock_gh_factory ):
411+ """When PR is open, do not request revert; just notify."""
412+ mock_pr , mock_issue = setup_gh_mocks (
413+ mock_gh_factory , pr_number = 98765 , state = "open"
414+ )
415+ # Find PR by sha returns Merge action type, but PR is open; fallback to notify
416+ set_find_pr_to_merge (self .proc , mock_pr )
417+
418+ # Use RUN_REVERT to ensure the code path would try a revert absent the open-state check
419+ ctx = make_ctx (revert_action = RevertAction .RUN_REVERT )
420+
421+ sources = [make_source ()]
422+
423+ result = self .proc ._comment_issue_pr_revert ("abc123" , sources , ctx )
424+
425+ # Should not request pytorchbot revert when PR is open
426+ mock_pr .create_issue_comment .assert_not_called ()
427+ # Should still post a notification comment
428+ mock_issue .create_comment .assert_called_once ()
429+ # Return False because RUN_REVERT was requested but we fell back to notify-only
430+ self .assertFalse (result )
431+
433432 @patch ("pytorch_auto_revert.signal_actions.GHClientFactory" )
434433 def test_comment_multiple_workflows (self , mock_gh_factory ):
435434 """Test that comment groups signals by workflow."""
436- mock_pr = Mock ()
437- mock_pr .number = 12345
438- mock_pr .get_labels .return_value = []
439-
440- mock_issue = Mock ()
441- mock_repo = Mock ()
442- mock_repo .get_issue .return_value = mock_issue
443- mock_client = Mock ()
444- mock_client .get_repo .return_value = mock_repo
445- mock_gh_factory .return_value .client = mock_client
446-
447- self .proc ._find_pr_by_sha = Mock (
448- return_value = (CommitPRSourceAction .MERGE , mock_pr )
449- )
435+ mock_pr , mock_issue = setup_gh_mocks (mock_gh_factory , pr_number = 12345 )
436+ set_find_pr_to_merge (self .proc , mock_pr )
450437
451438 sources = [
452- SignalMetadata (
453- workflow_name = "trunk" ,
454- key = "test_signal_1" ,
455- job_base_name = "linux-jammy / test" ,
456- wf_run_id = 12345 ,
457- job_id = 67890 ,
458- ),
459- SignalMetadata (
460- workflow_name = "trunk" ,
461- key = "test_signal_2" ,
462- job_base_name = "linux-jammy / test" ,
463- wf_run_id = 12345 ,
464- job_id = 67890 ,
465- ),
466- SignalMetadata (
439+ make_source (key = "test_signal_1" ),
440+ make_source (key = "test_signal_2" ),
441+ make_source (
467442 workflow_name = "inductor" ,
468443 key = "test_inductor" ,
469444 job_base_name = "linux-jammy / inductor" ,
0 commit comments