Commit 53c6bdf
[autorevert] correctly fetch and build the gaps in the signal (#7248)
1. Fixed the commits-without-jobs issue
   - Problem: commits with no workflow jobs (e.g., from the periodic workflow) were excluded from signal extraction.
   - Solution:
     - Added `fetch_commits_in_time_range()` to query the `push` table directly.
     - Modified the job query to filter by an explicit list of `head_shas` instead of a JOIN.
     - Changed ORDER BY to use the sha dimension first (this preserves grouping; the actual order no longer matters, since the extractors now iterate over the list of commits passed in explicitly).
2. Added a mandatory `timestamp` field to `SignalCommit`
   - Changes:
     - `SignalCommit.__init__(head_sha, timestamp, events)` - `timestamp` is now mandatory.
     - Signal extraction populates timestamps from the `push` table.
     - The HUD state logger uses the commit timestamp instead of computing it from event times.
     - Updated 36 test constructor calls.

### Testing

Before: [2025-09-29T19-29-47.670686-00-00.html](https://github.com/user-attachments/files/22606856/2025-09-29T19-29-47.670686-00-00.html)

After: [2025-09-29T21-38-10.190584-00-00.html](https://github.com/user-attachments/files/22606859/2025-09-29T21-38-10.190584-00-00.html)
1 parent 5398e1a commit 53c6bdf
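A minimal sketch of the new fetch order described above, assuming only the datasource methods shown in the diffs below (`fetch_commits_in_time_range`, `fetch_jobs_for_workflows`); the `datasource` object and the wrapper function are illustrative, not the exact implementation:

```python
from datetime import datetime
from typing import List, Tuple


def fetch_signal_inputs(datasource, repo: str, workflows: List[str], lookback_hours: int):
    """Illustrative only: fetch commits first, then jobs restricted to those commits."""
    # (sha, timestamp) pairs, newest -> older, straight from the push table;
    # commits with no workflow jobs (e.g. periodic runs) are still included.
    commits: List[Tuple[str, datetime]] = datasource.fetch_commits_in_time_range(
        repo_full_name=repo, lookback_hours=lookback_hours
    )
    # Jobs are now filtered by an explicit sha list instead of a JOIN on push.
    jobs = datasource.fetch_jobs_for_workflows(
        repo_full_name=repo,
        workflows=workflows,
        lookback_hours=lookback_hours,
        head_shas=[sha for sha, _ in commits],
    )
    return commits, jobs
```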

File tree: 7 files changed, +229 −53 lines changed

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/run_state_logger.py

Lines changed: 4 additions & 16 deletions

@@ -34,29 +34,17 @@ def _build_state_json(
         # Collect commit order (newest → older) across signals
         commits: List[str] = []
+        commit_times: Dict[str, str] = {}
         seen = set()
         for s in signals:
             for c in s.commits:
                 if c.head_sha not in seen:
                     seen.add(c.head_sha)
                     commits.append(c.head_sha)
+                    commit_times[c.head_sha] = c.timestamp.isoformat()

-        # Compute minimal started_at per commit (for timestamp context)
-        commit_times: Dict[str, str] = {}
-        for sha in commits:
-            tmin_iso: str | None = None
-            for s in signals:
-                # find commit in this signal
-                sc = next((cc for cc in s.commits if cc.head_sha == sha), None)
-                if not sc or not sc.events:
-                    continue
-                # events are sorted oldest first
-                t = sc.events[0].started_at
-                ts_iso = t.isoformat()
-                if tmin_iso is None or ts_iso < tmin_iso:
-                    tmin_iso = ts_iso
-            if tmin_iso is not None:
-                commit_times[sha] = tmin_iso
+        # sorting commits by their timestamp
+        commits.sort(key=lambda sha: commit_times[sha], reverse=True)

         # Build columns with outcomes, notes, and per-commit events
         cols = []
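The sort above keys on `isoformat()` strings rather than datetimes. A tiny hedged illustration (values invented for the example) of why that works: ISO-8601 strings with a consistent timezone offset sort lexicographically in the same order as the underlying datetimes.

```python
from datetime import datetime, timezone

# Example data only; the real timestamps come from the push table via SignalCommit.
commit_times = {
    "abc123": datetime(2025, 9, 29, 19, 0, tzinfo=timezone.utc).isoformat(),
    "def456": datetime(2025, 9, 29, 21, 30, tzinfo=timezone.utc).isoformat(),
}
commits = ["abc123", "def456"]

# Same call shape as the diff above: newest first.
commits.sort(key=lambda sha: commit_times[sha], reverse=True)
assert commits == ["def456", "abc123"]
```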

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py

Lines changed: 2 additions & 1 deletion

@@ -100,8 +100,9 @@ def is_failure(self) -> bool:
 class SignalCommit:
     """All events for a single commit, ordered oldest → newest by start time."""

-    def __init__(self, head_sha: str, events: List[SignalEvent]):
+    def __init__(self, head_sha: str, timestamp: datetime, events: List[SignalEvent]):
         self.head_sha = head_sha
+        self.timestamp = timestamp
         # enforce events ordered by time, then by wf_run_id (oldest first)
         self.events = (
             sorted(events, key=lambda e: (e.started_at, e.wf_run_id)) if events else []
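A brief usage sketch of the changed constructor. The import path is assumed from the file layout in this commit, and the timestamp value is a stand-in for the commit's push timestamp; event construction is elided, since an empty events list is still valid.

```python
from datetime import datetime, timezone

# Assumed import path based on the package directory in this commit.
from pytorch_auto_revert.signal import SignalCommit

# The push timestamp now travels with the commit instead of being derived
# from event start times.
push_ts = datetime(2025, 9, 29, 21, 38, tzinfo=timezone.utc)
commit = SignalCommit(head_sha="abc123", timestamp=push_ts, events=[])

assert commit.timestamp == push_ts
assert commit.events == []
```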

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction.py

Lines changed: 49 additions & 18 deletions

@@ -64,10 +64,18 @@ def _fmt_event_name(
     # -----------------------------
     def extract(self) -> List[Signal]:
         """Extract Signals for configured workflows within the lookback window."""
+        # Fetch commits first to ensure we include commits without jobs
+        commits = self._datasource.fetch_commits_in_time_range(
+            repo_full_name=self.repo_full_name,
+            lookback_hours=self.lookback_hours,
+        )
+
+        # Fetch jobs for these commits
         jobs = self._datasource.fetch_jobs_for_workflows(
             repo_full_name=self.repo_full_name,
             workflows=self.workflows,
             lookback_hours=self.lookback_hours,
+            head_shas=[sha for sha, _ in commits],
         )

         # Select jobs to participate in test-track details fetch
@@ -76,8 +84,8 @@ def extract(self) -> List[Signal]:
             test_track_job_ids, failed_job_ids=failed_job_ids
         )

-        test_signals = self._build_test_signals(jobs, test_rows)
-        job_signals = self._build_non_test_signals(jobs)
+        test_signals = self._build_test_signals(jobs, test_rows, commits)
+        job_signals = self._build_non_test_signals(jobs, commits)
         # Deduplicate events within commits across all signals as a final step
         # GitHub-specific behavior like "rerun failed" can reuse job instances for reruns.
         # When that happens, the jobs have identical timestamps by DIFFERENT job ids.
@@ -101,7 +109,11 @@ def _dedup_signal_events(self, signals: List[Signal]) -> List[Signal]:
                        continue
                    filtered.append(e)
                    prev_key = key
-                new_commits.append(SignalCommit(head_sha=c.head_sha, events=filtered))
+                new_commits.append(
+                    SignalCommit(
+                        head_sha=c.head_sha, timestamp=c.timestamp, events=filtered
+                    )
+                )
             deduped.append(
                 Signal(key=s.key, workflow_name=s.workflow_name, commits=new_commits)
             )
@@ -145,6 +157,7 @@ def _build_test_signals(
         self,
         jobs: List[JobRow],
         test_rows: List[TestRow],
+        commits: List[Tuple[Sha, datetime]],
     ) -> List[Signal]:
         """Build per-test Signals across commits, scoped to job base.
@@ -155,9 +168,15 @@ def _build_test_signals(
         - If test_run_s3 rows exist → FAILURE if any failing/errored else SUCCESS
         - Else if group pending → PENDING
         - Else → no event (missing)
+
+        Args:
+            jobs: List of job rows from the datasource
+            test_rows: List of test rows from the datasource
+            commits: Ordered list of (sha, timestamp) tuples (newest → older)
         """

         jobs_by_id = {j.job_id: j for j in jobs}
+        commit_timestamps = dict(commits)

         index_by_commit_job_base_wf_run_attempt: JobAggIndex[
             Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]
@@ -172,11 +191,6 @@ def _build_test_signals(
             ),
         )

-        # Preserve newest → older commit order from the datasource
-        commit_shas = index_by_commit_job_base_wf_run_attempt.unique_values(
-            lambda j: j.head_sha
-        )
-
         run_ids_attempts = index_by_commit_job_base_wf_run_attempt.group_map_values_by(
             key_fn=lambda j: (j.head_sha, j.workflow_name, j.base_name),
             value_fn=lambda j: (j.wf_run_id, j.run_attempt),
@@ -225,7 +239,7 @@ def _build_test_signals(
             )

             # y-axis: commits (newest → older)
-            for commit_sha in commit_shas:
+            for commit_sha, _ in commits:
                 events: List[SignalEvent] = []

                 # x-axis: events for the signal
@@ -286,7 +300,13 @@ def _build_test_signals(
                     has_any_events = True

                 # important to always include the commit, even if no events
-                commit_objs.append(SignalCommit(head_sha=commit_sha, events=events))
+                commit_objs.append(
+                    SignalCommit(
+                        head_sha=commit_sha,
+                        timestamp=commit_timestamps[commit_sha],
+                        events=events,
+                    )
+                )

             if has_any_events:
                 signals.append(
@@ -295,9 +315,19 @@ def _build_test_signals(

         return signals

-    def _build_non_test_signals(self, jobs: List[JobRow]) -> List[Signal]:
-        # Build Signals keyed by normalized job base name per workflow.
-        # Aggregate across shards within (wf_run_id, run_attempt) using JobAggIndex.
+    def _build_non_test_signals(
+        self, jobs: List[JobRow], commits: List[Tuple[Sha, datetime]]
+    ) -> List[Signal]:
+        """Build Signals keyed by normalized job base name per workflow.
+
+        Aggregate across shards within (wf_run_id, run_attempt) using JobAggIndex.
+
+        Args:
+            jobs: List of job rows from the datasource
+            commits: Ordered list of (sha, timestamp) tuples (newest → older)
+        """
+
+        commit_timestamps = dict(commits)

         index = JobAggIndex.from_rows(
             jobs,
@@ -310,9 +340,6 @@ def _build_non_test_signals(self, jobs: List[JobRow]) -> List[Signal]:
             ),
         )

-        # Preserve commit order as first-seen in the job rows (datasource orders newest→older).
-        commit_shas = index.unique_values(lambda j: j.head_sha)
-
         # Map (sha, workflow, base) -> [attempt_keys]
         groups_index = index.group_keys_by(
             key_fn=lambda j: (j.head_sha, j.workflow_name, j.base_name)
@@ -329,7 +356,7 @@ def _build_non_test_signals(self, jobs: List[JobRow]) -> List[Signal]:
             # Track failure types across all attempts/commits for this base
             has_relevant_failures = False  # at least one non-test failure observed

-            for sha in commit_shas:
+            for sha, _ in commits:
                 attempt_keys: List[
                     Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]
                 ] = groups_index.get((sha, wf_name, base_name), [])
@@ -374,7 +401,11 @@ def _build_non_test_signals(self, jobs: List[JobRow]) -> List[Signal]:
                 )

                 # important to always include the commit, even if no events
-                commit_objs.append(SignalCommit(head_sha=sha, events=events))
+                commit_objs.append(
+                    SignalCommit(
+                        head_sha=sha, timestamp=commit_timestamps[sha], events=events
+                    )
+                )

             # Emit job signal when failures were present and failures were NOT exclusively test-caused
             if has_relevant_failures:
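The core pattern in both builders is the same: keep the explicit `(sha, timestamp)` list as the source of truth for the commit axis, and only look up events per sha. A toy illustration with plain strings in place of the real `Sha`/`SignalEvent` types (the data here is invented) shows why a commit with no jobs still gets a row:

```python
from datetime import datetime, timezone
from typing import Dict, List, Tuple

# Example inputs only; in the real code these come from the datasource.
commits: List[Tuple[str, datetime]] = [
    ("newer_sha", datetime(2025, 9, 29, 21, 0, tzinfo=timezone.utc)),
    ("older_sha", datetime(2025, 9, 29, 19, 0, tzinfo=timezone.utc)),  # no jobs yet
]
events_by_sha: Dict[str, list] = {"newer_sha": ["some_event"]}  # older_sha has none

commit_timestamps = dict(commits)      # sha -> push timestamp, as in the diff
rows = []
for sha, _ in commits:                 # preserves the newest -> older order
    rows.append((sha, commit_timestamps[sha], events_by_sha.get(sha, [])))

# The "gap" commit is present even though it contributed no events.
assert [r[0] for r in rows] == ["newer_sha", "older_sha"]
assert rows[1][2] == []
```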

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal_extraction_datasource.py

Lines changed: 52 additions & 13 deletions

@@ -24,20 +24,65 @@ class SignalExtractionDatasource:
     Encapsulates ClickHouse queries used by the signal extraction layer.
     """

+    def fetch_commits_in_time_range(
+        self, *, repo_full_name: str, lookback_hours: int
+    ) -> List[tuple[Sha, datetime]]:
+        """
+        Fetch all commits pushed to main within the lookback window.
+        Returns list of (sha, timestamp) tuples ordered newest → older.
+        """
+        lookback_time = datetime.now() - timedelta(hours=lookback_hours)
+
+        query = """
+            SELECT head_commit.id AS sha, max(head_commit.timestamp) AS ts
+            FROM default.push
+            WHERE head_commit.timestamp >= {lookback_time:DateTime}
+              AND ref = 'refs/heads/main'
+              AND dynamoKey like {repo:String}
+            GROUP BY sha
+            ORDER BY ts DESC
+        """
+
+        params = {
+            "lookback_time": lookback_time,
+            "repo": f"{repo_full_name}%",
+        }
+
+        log = logging.getLogger(__name__)
+        log.info(
+            "[extract] Fetching commits in time range: repo=%s lookback=%sh",
+            repo_full_name,
+            lookback_hours,
+        )
+        t0 = time.perf_counter()
+        for attempt in RetryWithBackoff():
+            with attempt:
+                res = CHCliFactory().client.query(query, parameters=params)
+                commits = [(Sha(row[0]), row[1]) for row in res.result_rows]
+                dt = time.perf_counter() - t0
+                log.info("[extract] Commits fetched: %d commits in %.2fs", len(commits), dt)
+                return commits
+
     def fetch_jobs_for_workflows(
-        self, *, repo_full_name: str, workflows: Iterable[str], lookback_hours: int
+        self,
+        *,
+        repo_full_name: str,
+        workflows: Iterable[str],
+        lookback_hours: int,
+        head_shas: List[Sha],
     ) -> List[JobRow]:
         """
-        Fetch recent workflow job rows for the given workflows within the lookback window.
+        Fetch workflow job rows for the given head_shas and workflows.

-        Returns rows ordered by push timestamp desc, then by workflow run/job identity.
+        Returns rows ordered by head_sha (following the order of head_shas), then by started_at ASC.
         """
         lookback_time = datetime.now() - timedelta(hours=lookback_hours)

         workflow_filter = ""
         params: Dict[str, Any] = {
             "lookback_time": lookback_time,
             "repo": repo_full_name,
+            "head_shas": [str(s) for s in head_shas],
         }
         workflow_list = list(workflows)
         if workflow_list:
@@ -55,13 +100,6 @@ def fetch_jobs_for_workflows(
         # the extractor and downstream logic rely on the KG-adjusted value so
         # that pending jobs can also be recognized as failures-in-progress.
         query = f"""
-        WITH push_dedup AS (
-            SELECT head_commit.id AS sha, max(head_commit.timestamp) AS ts
-            FROM default.push
-            WHERE head_commit.timestamp >= {{lookback_time:DateTime}}
-              AND ref = 'refs/heads/main'
-            GROUP BY sha
-        )
         SELECT
             wf.head_sha,
             wf.workflow_name,
@@ -76,19 +114,20 @@ def fetch_jobs_for_workflows(
             wf.created_at,
             tupleElement(wf.torchci_classification_kg,'rule') AS rule
         FROM default.workflow_job AS wf FINAL
-        INNER JOIN push_dedup p ON wf.head_sha = p.sha
         WHERE wf.repository_full_name = {{repo:String}}
+          AND wf.head_sha IN {{head_shas:Array(String)}}
          AND wf.created_at >= {{lookback_time:DateTime}}
          AND (wf.name NOT LIKE '%mem_leak_check%' AND wf.name NOT LIKE '%rerun_disabled_tests%')
          {workflow_filter}
-        ORDER BY p.ts DESC, wf.started_at ASC, wf.head_sha, wf.run_id, wf.run_attempt, wf.name
+        ORDER BY wf.head_sha, wf.started_at ASC, wf.run_id, wf.run_attempt, wf.name
         """

         log = logging.getLogger(__name__)
         log.info(
-            "[extract] Fetching jobs: repo=%s workflows=%s lookback=%sh",
+            "[extract] Fetching jobs: repo=%s workflows=%s commits=%d lookback=%sh",
             repo_full_name,
             ",".join(workflow_list) if workflow_list else "<all>",
+            len(head_shas),
             lookback_hours,
         )
         t0 = time.perf_counter()
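For reference, the `{head_shas:Array(String)}` placeholder is ClickHouse server-side parameter binding, which is what lets the job query take an explicit sha list instead of joining on `push`. A hedged, standalone sketch of exercising that filter with clickhouse-connect (the repo wraps the client in `CHCliFactory`/`RetryWithBackoff`; host, credentials, and the sha values here are placeholders):

```python
import clickhouse_connect

# Placeholder connection details; substitute the real host and credentials.
client = clickhouse_connect.get_client(host="localhost")

head_shas = ["abc123", "def456"]  # example values, not real commits
query = """
    SELECT wf.head_sha, count() AS jobs
    FROM default.workflow_job AS wf FINAL
    WHERE wf.head_sha IN {head_shas:Array(String)}
    GROUP BY wf.head_sha
    ORDER BY wf.head_sha
"""

# The list is bound as an Array(String) parameter on the server side.
result = client.query(query, parameters={"head_shas": head_shas})
for sha, n_jobs in result.result_rows:
    print(sha, n_jobs)
```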
