Skip to content

Commit 8a43a2f

Browse files
authored
[autorevert] Prevent revert requests for open PRs, fallback to notification (#7331)
There were some examples when autorevert was trying to revert already reverted PRs, e.g. pytorch/pytorch#164144 (comment) it's benign, but annoying, this should fix it. ---- ### Testing unit test + [this manual test](https://gist.github.com/izaitsevfb/b647c0d639ce26e566c11fe3586f1a8a)
1 parent 264eed5 commit 8a43a2f

File tree

2 files changed

+124
-133
lines changed

2 files changed

+124
-133
lines changed

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_actions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,22 @@ def _comment_issue_pr_revert(
576576
and action_type == CommitPRSourceAction.MERGE
577577
)
578578

579+
# If the PR is still open, do not request a bot revert.
580+
# This covers cases where the commit belongs to an open PR
581+
# (not yet merged) or the PR has already been reverted and is open.
582+
# In such cases, fall back to posting a notification comment only.
583+
if should_do_revert_on_pr:
584+
pr_state = getattr(pr, "state", None)
585+
if pr_state == "open":
586+
logging.info(
587+
"[v2][action] (%s, %s) revert for sha %s: PR #%s is open, will just notify",
588+
ctx.revert_action,
589+
action_type,
590+
commit_sha[:8],
591+
pr.number,
592+
)
593+
should_do_revert_on_pr = False
594+
579595
if should_do_revert_on_pr:
580596
# check if label 'autorevert: disable' is on the `pr`
581597
labels = []

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_signal_actions.py

Lines changed: 108 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,72 @@
1515
# flake8: noqa
1616

1717

18+
# ------------------------------
19+
# Test helpers (to avoid duplication)
20+
# ------------------------------
21+
22+
23+
def setup_gh_mocks(
24+
mock_gh_factory, *, pr_number: int = 12345, state: str | None = None, labels=None
25+
):
26+
"""Set up common GitHub mocks and return (mock_pr, mock_issue).
27+
28+
- Configures GHClientFactory().client.get_repo().get_issue() to return a mock issue
29+
- Creates a mock PR object with provided number/state/labels
30+
"""
31+
if labels is None:
32+
labels = []
33+
34+
mock_pr = Mock()
35+
mock_pr.number = pr_number
36+
if state is not None:
37+
mock_pr.state = state
38+
mock_pr.labels = labels
39+
mock_pr.get_labels.return_value = labels
40+
41+
mock_issue = Mock()
42+
mock_repo = Mock()
43+
mock_repo.get_issue.return_value = mock_issue
44+
mock_client = Mock()
45+
mock_client.get_repo.return_value = mock_repo
46+
mock_gh_factory.return_value.client = mock_client
47+
48+
return mock_pr, mock_issue
49+
50+
51+
def make_ctx(*, revert_action, restart_action=RestartAction.SKIP):
52+
return RunContext(
53+
ts=datetime.now(timezone.utc),
54+
notify_issue_number=123456,
55+
repo_full_name="pytorch/pytorch",
56+
workflows=["trunk"],
57+
lookback_hours=24,
58+
revert_action=revert_action,
59+
restart_action=restart_action,
60+
)
61+
62+
63+
def make_source(
64+
*,
65+
workflow_name: str = "trunk",
66+
key: str = "test_signal",
67+
job_base_name: str | None = "linux-jammy / test",
68+
wf_run_id: int | None = 12345,
69+
job_id: int | None = 67890,
70+
):
71+
return SignalMetadata(
72+
workflow_name=workflow_name,
73+
key=key,
74+
job_base_name=job_base_name,
75+
wf_run_id=wf_run_id,
76+
job_id=job_id,
77+
)
78+
79+
80+
def set_find_pr_to_merge(proc: SignalActionProcessor, pr):
81+
proc._find_pr_by_sha = Mock(return_value=(CommitPRSourceAction.MERGE, pr))
82+
83+
1884
class FakeLogger:
1985
def __init__(self):
2086
self._recent = []
@@ -251,31 +317,10 @@ def test_comment_no_pr_found(self, mock_gh_factory):
251317
@patch("pytorch_auto_revert.signal_actions.GHClientFactory")
252318
def test_comment_with_job_and_hud_links(self, mock_gh_factory):
253319
"""Test that comment includes job link and HUD link when available."""
254-
# Mock PR and issue
255-
mock_pr = Mock()
256-
mock_pr.number = 12345
257-
mock_pr.get_labels.return_value = []
258-
259-
mock_issue = Mock()
260-
mock_repo = Mock()
261-
mock_repo.get_issue.return_value = mock_issue
262-
mock_client = Mock()
263-
mock_client.get_repo.return_value = mock_repo
264-
mock_gh_factory.return_value.client = mock_client
265-
266-
self.proc._find_pr_by_sha = Mock(
267-
return_value=(CommitPRSourceAction.MERGE, mock_pr)
268-
)
320+
mock_pr, mock_issue = setup_gh_mocks(mock_gh_factory, pr_number=12345)
321+
set_find_pr_to_merge(self.proc, mock_pr)
269322

270-
sources = [
271-
SignalMetadata(
272-
workflow_name="trunk",
273-
key="test_signal",
274-
job_base_name="linux-jammy / test",
275-
wf_run_id=12345,
276-
job_id=67890,
277-
)
278-
]
323+
sources = [make_source()]
279324

280325
result = self.proc._comment_issue_pr_revert("abc123", sources, self.ctx)
281326

@@ -294,31 +339,10 @@ def test_comment_with_job_and_hud_links(self, mock_gh_factory):
294339
@patch("pytorch_auto_revert.signal_actions.GHClientFactory")
295340
def test_comment_with_job_without_hud_links(self, mock_gh_factory):
296341
"""Test that comment includes job link but without HUD link."""
297-
# Mock PR and issue
298-
mock_pr = Mock()
299-
mock_pr.number = 12345
300-
mock_pr.get_labels.return_value = []
301-
302-
mock_issue = Mock()
303-
mock_repo = Mock()
304-
mock_repo.get_issue.return_value = mock_issue
305-
mock_client = Mock()
306-
mock_client.get_repo.return_value = mock_repo
307-
mock_gh_factory.return_value.client = mock_client
308-
309-
self.proc._find_pr_by_sha = Mock(
310-
return_value=(CommitPRSourceAction.MERGE, mock_pr)
311-
)
342+
mock_pr, mock_issue = setup_gh_mocks(mock_gh_factory, pr_number=12345)
343+
set_find_pr_to_merge(self.proc, mock_pr)
312344

313-
sources = [
314-
SignalMetadata(
315-
workflow_name="trunk",
316-
key="test_signal",
317-
job_base_name=None,
318-
wf_run_id=12345,
319-
job_id=67890,
320-
)
321-
]
345+
sources = [make_source(job_base_name=None)]
322346

323347
result = self.proc._comment_issue_pr_revert("abc123", sources, self.ctx)
324348

@@ -337,30 +361,10 @@ def test_comment_with_job_without_hud_links(self, mock_gh_factory):
337361
@patch("pytorch_auto_revert.signal_actions.GHClientFactory")
338362
def test_comment_without_job_info(self, mock_gh_factory):
339363
"""Test that comment works without job_id/wf_run_id."""
340-
mock_pr = Mock()
341-
mock_pr.number = 12345
342-
mock_pr.get_labels.return_value = []
343-
344-
mock_issue = Mock()
345-
mock_repo = Mock()
346-
mock_repo.get_issue.return_value = mock_issue
347-
mock_client = Mock()
348-
mock_client.get_repo.return_value = mock_repo
349-
mock_gh_factory.return_value.client = mock_client
350-
351-
self.proc._find_pr_by_sha = Mock(
352-
return_value=(CommitPRSourceAction.MERGE, mock_pr)
353-
)
364+
mock_pr, mock_issue = setup_gh_mocks(mock_gh_factory, pr_number=12345)
365+
set_find_pr_to_merge(self.proc, mock_pr)
354366

355-
sources = [
356-
SignalMetadata(
357-
workflow_name="trunk",
358-
key="test_signal",
359-
job_base_name="linux-jammy / test",
360-
wf_run_id=None,
361-
job_id=None,
362-
)
363-
]
367+
sources = [make_source(wf_run_id=None, job_id=None)]
364368

365369
result = self.proc._comment_issue_pr_revert("abc123", sources, self.ctx)
366370

@@ -381,43 +385,15 @@ def test_comment_autorevert_disabled(self, mock_gh_factory):
381385
"""Test that revert is not requested when autorevert is disabled."""
382386
mock_label = Mock()
383387
mock_label.name = "autorevert: disable"
384-
385-
mock_pr = Mock()
386-
mock_pr.number = 12345
387-
mock_pr.labels = [mock_label]
388-
mock_pr.get_labels.return_value = [mock_label]
389-
390-
mock_issue = Mock()
391-
mock_repo = Mock()
392-
mock_repo.get_issue.return_value = mock_issue
393-
mock_client = Mock()
394-
mock_client.get_repo.return_value = mock_repo
395-
mock_gh_factory.return_value.client = mock_client
396-
397-
self.proc._find_pr_by_sha = Mock(
398-
return_value=(CommitPRSourceAction.MERGE, mock_pr)
388+
mock_pr, mock_issue = setup_gh_mocks(
389+
mock_gh_factory, pr_number=12345, labels=[mock_label]
399390
)
391+
set_find_pr_to_merge(self.proc, mock_pr)
400392

401393
# Use RUN_REVERT to test the disable logic
402-
ctx = RunContext(
403-
ts=datetime.now(timezone.utc),
404-
notify_issue_number=123456,
405-
repo_full_name="pytorch/pytorch",
406-
workflows=["trunk"],
407-
lookback_hours=24,
408-
revert_action=RevertAction.RUN_REVERT,
409-
restart_action=RestartAction.SKIP,
410-
)
394+
ctx = make_ctx(revert_action=RevertAction.RUN_REVERT)
411395

412-
sources = [
413-
SignalMetadata(
414-
workflow_name="trunk",
415-
key="test_signal",
416-
job_base_name="linux-jammy / test",
417-
wf_run_id=12345,
418-
job_id=67890,
419-
)
420-
]
396+
sources = [make_source()]
421397

422398
result = self.proc._comment_issue_pr_revert("abc123", sources, ctx)
423399

@@ -430,40 +406,39 @@ def test_comment_autorevert_disabled(self, mock_gh_factory):
430406
# Issue notification should still be created
431407
mock_issue.create_comment.assert_called_once()
432408

409+
@patch("pytorch_auto_revert.signal_actions.GHClientFactory")
410+
def test_comment_pr_open_fallback(self, mock_gh_factory):
411+
"""When PR is open, do not request revert; just notify."""
412+
mock_pr, mock_issue = setup_gh_mocks(
413+
mock_gh_factory, pr_number=98765, state="open"
414+
)
415+
# Find PR by sha returns Merge action type, but PR is open; fallback to notify
416+
set_find_pr_to_merge(self.proc, mock_pr)
417+
418+
# Use RUN_REVERT to ensure the code path would try a revert absent the open-state check
419+
ctx = make_ctx(revert_action=RevertAction.RUN_REVERT)
420+
421+
sources = [make_source()]
422+
423+
result = self.proc._comment_issue_pr_revert("abc123", sources, ctx)
424+
425+
# Should not request pytorchbot revert when PR is open
426+
mock_pr.create_issue_comment.assert_not_called()
427+
# Should still post a notification comment
428+
mock_issue.create_comment.assert_called_once()
429+
# Return False because RUN_REVERT was requested but we fell back to notify-only
430+
self.assertFalse(result)
431+
433432
@patch("pytorch_auto_revert.signal_actions.GHClientFactory")
434433
def test_comment_multiple_workflows(self, mock_gh_factory):
435434
"""Test that comment groups signals by workflow."""
436-
mock_pr = Mock()
437-
mock_pr.number = 12345
438-
mock_pr.get_labels.return_value = []
439-
440-
mock_issue = Mock()
441-
mock_repo = Mock()
442-
mock_repo.get_issue.return_value = mock_issue
443-
mock_client = Mock()
444-
mock_client.get_repo.return_value = mock_repo
445-
mock_gh_factory.return_value.client = mock_client
446-
447-
self.proc._find_pr_by_sha = Mock(
448-
return_value=(CommitPRSourceAction.MERGE, mock_pr)
449-
)
435+
mock_pr, mock_issue = setup_gh_mocks(mock_gh_factory, pr_number=12345)
436+
set_find_pr_to_merge(self.proc, mock_pr)
450437

451438
sources = [
452-
SignalMetadata(
453-
workflow_name="trunk",
454-
key="test_signal_1",
455-
job_base_name="linux-jammy / test",
456-
wf_run_id=12345,
457-
job_id=67890,
458-
),
459-
SignalMetadata(
460-
workflow_name="trunk",
461-
key="test_signal_2",
462-
job_base_name="linux-jammy / test",
463-
wf_run_id=12345,
464-
job_id=67890,
465-
),
466-
SignalMetadata(
439+
make_source(key="test_signal_1"),
440+
make_source(key="test_signal_2"),
441+
make_source(
467442
workflow_name="inductor",
468443
key="test_inductor",
469444
job_base_name="linux-jammy / inductor",

0 commit comments

Comments
 (0)