From b4c28c3804f5a85a869bfe0e841b42c1c68a2d0b Mon Sep 17 00:00:00 2001 From: Sakibul Islam Date: Sat, 12 Apr 2025 19:30:20 +0100 Subject: [PATCH] Add mergeable queue - Implement mergeable queue loop - Add to queue on PR edit base change - Add to queue on push to branch --- ...031aa4a035fc7f74cf773309e8d2079d48aba.json | 125 +++++++ ...e7d7614ffac5358b73ff039220ffed15c4fce.json | 15 + src/bin/bors.rs | 12 +- src/bors/handlers/mod.rs | 7 +- src/bors/handlers/pr_events.rs | 69 +++- src/bors/mergeable_queue.rs | 320 ++++++++++++++++++ src/bors/mod.rs | 1 + src/database/client.rs | 21 +- src/database/mod.rs | 2 +- src/database/operations.rs | 62 ++++ src/github/server.rs | 61 +++- src/lib.rs | 2 +- src/tests/mocks/bors.rs | 22 +- src/tests/mocks/pull_request.rs | 1 - 14 files changed, 694 insertions(+), 26 deletions(-) create mode 100644 .sqlx/query-2d6d97641c85f3ac05f952644ef031aa4a035fc7f74cf773309e8d2079d48aba.json create mode 100644 .sqlx/query-3cc3661fb1e1ff945e415f80637e7d7614ffac5358b73ff039220ffed15c4fce.json create mode 100644 src/bors/mergeable_queue.rs diff --git a/.sqlx/query-2d6d97641c85f3ac05f952644ef031aa4a035fc7f74cf773309e8d2079d48aba.json b/.sqlx/query-2d6d97641c85f3ac05f952644ef031aa4a035fc7f74cf773309e8d2079d48aba.json new file mode 100644 index 00000000..12e8aa3f --- /dev/null +++ b/.sqlx/query-2d6d97641c85f3ac05f952644ef031aa4a035fc7f74cf773309e8d2079d48aba.json @@ -0,0 +1,125 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n pr.id,\n pr.repository as \"repository: GithubRepoName\",\n pr.number as \"number!: i64\",\n (\n pr.approved_by,\n pr.approved_sha\n ) AS \"approval_status!: ApprovalStatus\",\n pr.status as \"pr_status: PullRequestStatus\",\n pr.priority,\n pr.rollup as \"rollup: RollupMode\",\n pr.delegated_permission as \"delegated_permission: DelegatedPermission\",\n pr.base_branch,\n pr.mergeable_state as \"mergeable_state: MergeableState\",\n pr.created_at as \"created_at: DateTime\",\n build AS \"try_build: BuildModel\"\n FROM pull_request as pr\n LEFT JOIN build ON pr.build_id = build.id\n WHERE pr.repository = $1\n AND pr.base_branch = $2\n AND pr.status IN ('open', 'draft')\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int4" + }, + { + "ordinal": 1, + "name": "repository: GithubRepoName", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "number!: i64", + "type_info": "Int8" + }, + { + "ordinal": 3, + "name": "approval_status!: ApprovalStatus", + "type_info": "Record" + }, + { + "ordinal": 4, + "name": "pr_status: PullRequestStatus", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "priority", + "type_info": "Int4" + }, + { + "ordinal": 6, + "name": "rollup: RollupMode", + "type_info": "Text" + }, + { + "ordinal": 7, + "name": "delegated_permission: DelegatedPermission", + "type_info": "Text" + }, + { + "ordinal": 8, + "name": "base_branch", + "type_info": "Text" + }, + { + "ordinal": 9, + "name": "mergeable_state: MergeableState", + "type_info": "Text" + }, + { + "ordinal": 10, + "name": "created_at: DateTime", + "type_info": "Timestamptz" + }, + { + "ordinal": 11, + "name": "try_build: BuildModel", + "type_info": { + "Custom": { + "name": "build", + "kind": { + "Composite": [ + [ + "id", + "Int4" + ], + [ + "repository", + "Text" + ], + [ + "branch", + "Text" + ], + [ + "commit_sha", + "Text" + ], + [ + "status", + "Text" + ], + [ + "parent", + "Text" + ], + [ + "created_at", + "Timestamptz" + ] + ] + } + } + } + } + ], + "parameters": { + "Left": [ + "Text", + "Text" + ] + }, + "nullable": [ + false, + false, + false, + null, + false, + true, + true, + true, + false, + false, + false, + null + ] + }, + "hash": "2d6d97641c85f3ac05f952644ef031aa4a035fc7f74cf773309e8d2079d48aba" +} diff --git a/.sqlx/query-3cc3661fb1e1ff945e415f80637e7d7614ffac5358b73ff039220ffed15c4fce.json b/.sqlx/query-3cc3661fb1e1ff945e415f80637e7d7614ffac5358b73ff039220ffed15c4fce.json new file mode 100644 index 00000000..2343d258 --- /dev/null +++ b/.sqlx/query-3cc3661fb1e1ff945e415f80637e7d7614ffac5358b73ff039220ffed15c4fce.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE pull_request SET mergeable_state = $1 WHERE id = $2", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Text", + "Int4" + ] + }, + "nullable": [] + }, + "hash": "3cc3661fb1e1ff945e415f80637e7d7614ffac5358b73ff039220ffed15c4fce" +} diff --git a/src/bin/bors.rs b/src/bin/bors.rs index 7811675b..920536a9 100644 --- a/src/bin/bors.rs +++ b/src/bin/bors.rs @@ -6,8 +6,9 @@ use std::time::Duration; use anyhow::Context; use bors::{ - BorsContext, BorsGlobalEvent, CommandParser, PgDbClient, ServerState, TeamApiClient, - WebhookSecret, create_app, create_bors_process, create_github_client, load_repositories, + BorsContext, BorsGlobalEvent, BorsProcess, CommandParser, PgDbClient, ServerState, + TeamApiClient, WebhookSecret, create_app, create_bors_process, create_github_client, + load_repositories, }; use clap::Parser; use sqlx::postgres::PgConnectOptions; @@ -109,7 +110,12 @@ fn try_main(opts: Opts) -> anyhow::Result<()> { } let ctx = BorsContext::new(CommandParser::new(opts.cmd_prefix), Arc::new(db), repos); - let (repository_tx, global_tx, bors_process) = create_bors_process(ctx, client, team_api); + let BorsProcess { + repository_tx, + global_tx, + bors_process, + .. + } = create_bors_process(ctx, client, team_api); let refresh_tx = global_tx.clone(); let refresh_process = async move { diff --git a/src/bors/handlers/mod.rs b/src/bors/handlers/mod.rs index 40891bda..539ee8f3 100644 --- a/src/bors/handlers/mod.rs +++ b/src/bors/handlers/mod.rs @@ -31,6 +31,8 @@ use tracing::Instrument; #[cfg(test)] use crate::tests::util::TestSyncMarker; +use super::mergeable_queue::MergeableQueueSender; + mod help; mod info; mod labels; @@ -48,6 +50,7 @@ pub static WAIT_FOR_WORKFLOW_STARTED: TestSyncMarker = TestSyncMarker::new(); pub async fn handle_bors_repository_event( event: BorsRepositoryEvent, ctx: Arc, + mergeable_queue_tx: MergeableQueueSender, ) -> anyhow::Result<()> { let db = Arc::clone(&ctx.db); let Some(repo) = ctx @@ -129,7 +132,7 @@ pub async fn handle_bors_repository_event( let span = tracing::info_span!("Pull request edited", repo = payload.repository.to_string()); - handle_pull_request_edited(repo, db, payload) + handle_pull_request_edited(repo, db, mergeable_queue_tx, payload) .instrument(span.clone()) .await?; } @@ -199,7 +202,7 @@ pub async fn handle_bors_repository_event( let span = tracing::info_span!("Pushed to branch", repo = payload.repository.to_string()); - handle_push_to_branch(repo, db, payload) + handle_push_to_branch(repo, db, mergeable_queue_tx, payload) .instrument(span.clone()) .await?; } diff --git a/src/bors/handlers/pr_events.rs b/src/bors/handlers/pr_events.rs index c9e9d0e3..122b3fe6 100644 --- a/src/bors/handlers/pr_events.rs +++ b/src/bors/handlers/pr_events.rs @@ -5,6 +5,7 @@ use crate::bors::event::{ PushToBranch, }; use crate::bors::handlers::labels::handle_label_trigger; +use crate::bors::mergeable_queue::MergeableQueueSender; use crate::bors::{Comment, PullRequestStatus, RepositoryState}; use crate::database::MergeableState; use crate::github::{CommitSha, LabelTrigger, PullRequestNumber}; @@ -13,6 +14,7 @@ use std::sync::Arc; pub(super) async fn handle_pull_request_edited( repo_state: Arc, db: Arc, + mergeable_queue: MergeableQueueSender, payload: PullRequestEdited, ) -> anyhow::Result<()> { let pr = &payload.pull_request; @@ -32,6 +34,8 @@ pub(super) async fn handle_pull_request_edited( return Ok(()); }; + mergeable_queue.enqueue(pr_model.repository.clone(), pr_number); + if !pr_model.is_approved() { return Ok(()); } @@ -154,6 +158,7 @@ pub(super) async fn handle_pull_request_ready_for_review( pub(super) async fn handle_push_to_branch( repo_state: Arc, db: Arc, + mergeable_queue: MergeableQueueSender, payload: PushToBranch, ) -> anyhow::Result<()> { let rows = db @@ -163,8 +168,18 @@ pub(super) async fn handle_push_to_branch( MergeableState::Unknown, ) .await?; + let affected_prs = db + .get_nonclosed_pull_requests_by_base_branch(repo_state.repository(), &payload.branch) + .await?; - tracing::info!("Updated mergeable_state to `unknown` for {} PR(s)", rows); + tracing::info!( + "Adding {} PR(s) to the mergeable queue due to base branch change", + rows + ); + + for pr in affected_prs { + mergeable_queue.enqueue(repo_state.repository().clone(), pr.number); + } Ok(()) } @@ -207,7 +222,7 @@ mod tests { use crate::bors::PullRequestStatus; use crate::tests::mocks::default_pr_number; use crate::{ - database::MergeableState, + database::{MergeableState, OctocrabMergeableState}, tests::mocks::{User, default_branch_name, default_repo_name, run_test}, }; @@ -344,7 +359,7 @@ mod tests { run_test(pool.clone(), |mut tester| async { tester .edit_pr(default_repo_name(), default_pr_number(), |pr| { - pr.mergeable_state = octocrab::models::pulls::MergeableState::Dirty; + pr.mergeable_state = OctocrabMergeableState::Dirty; }) .await?; tester @@ -425,6 +440,7 @@ mod tests { .await; } + #[tracing_test::traced_test] #[sqlx::test] async fn open_and_merge_pr(pool: sqlx::PgPool) { run_test(pool, |mut tester| async { @@ -444,4 +460,51 @@ mod tests { }) .await; } + + #[sqlx::test] + async fn mergeable_queue_processes_pr_base_change(pool: sqlx::PgPool) { + run_test(pool, |mut tester| async { + let branch = tester.create_branch("beta").clone(); + tester + .edit_pr(default_repo_name(), default_pr_number(), |pr| { + pr.base_branch = branch; + pr.mergeable_state = OctocrabMergeableState::Unknown; + }) + .await?; + tester + .wait_for_default_pr(|pr| pr.mergeable_state == MergeableState::Unknown) + .await?; + tester + .default_repo() + .lock() + .get_pr_mut(default_pr_number()) + .mergeable_state = OctocrabMergeableState::Dirty; + tester + .wait_for_default_pr(|pr| pr.mergeable_state == MergeableState::HasConflicts) + .await?; + Ok(tester) + }) + .await; + } + + #[sqlx::test] + async fn enqueue_prs_on_push_to_branch(pool: sqlx::PgPool) { + run_test(pool, |mut tester| async { + tester.open_pr(default_repo_name(), false).await?; + tester.push_to_branch(default_branch_name()).await?; + tester + .wait_for_default_pr(|pr| pr.mergeable_state == MergeableState::Unknown) + .await?; + tester + .default_repo() + .lock() + .get_pr_mut(default_pr_number()) + .mergeable_state = OctocrabMergeableState::Dirty; + tester + .wait_for_default_pr(|pr| pr.mergeable_state == MergeableState::HasConflicts) + .await?; + Ok(tester) + }) + .await; + } } diff --git a/src/bors/mergeable_queue.rs b/src/bors/mergeable_queue.rs new file mode 100644 index 00000000..c43be8fb --- /dev/null +++ b/src/bors/mergeable_queue.rs @@ -0,0 +1,320 @@ +use crate::database::OctocrabMergeableState; +use crate::github::{GithubRepoName, PullRequestNumber}; +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; +use tokio::sync::Notify; +use tokio::time::timeout; + +use super::BorsContext; + +/// Delay before processing a mergeable queue item for the first time. +const BASE_DELAY: Duration = Duration::from_millis(500); +/// Exponential backoff delay multiplier. +const BACKOFF_MULTIPLIER: f64 = 2.0; +/// Max number of mergeable check retries before giving up. +const MAX_RETRIES: u32 = 5; + +#[derive(Debug, Clone)] +pub struct QueuedPullRequest { + pub pr_number: PullRequestNumber, + pub repo: GithubRepoName, +} + +impl std::fmt::Display for QueuedPullRequest { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}/#{}", self.repo, self.pr_number) + } +} + +#[derive(Debug, Clone)] +pub struct MergeableQueueItem { + pub pull_request: QueuedPullRequest, + pub attempt: u32, +} + +#[derive(Debug, Clone)] +enum QueueMessage { + Item(MergeableQueueItem), + Shutdown, +} + +struct Item { + /// When to process item (None = immediate). + /// Reversed to create min-heap for expirations. + expiration: Reverse>, + inner: QueueMessage, +} + +impl PartialEq for Item { + fn eq(&self, other: &Self) -> bool { + self.expiration == other.expiration + } +} + +impl Eq for Item {} + +impl PartialOrd for Item { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Item { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.expiration.cmp(&other.expiration) + } +} + +struct SharedInner { + queue: Mutex>, + notify: Notify, +} + +#[derive(Clone)] +pub struct MergeableQueueSender { + inner: Arc, +} + +pub struct MergeableQueueReceiver { + inner: Arc, +} + +pub fn create_mergeable_queue() -> (MergeableQueueSender, MergeableQueueReceiver) { + let shared = Arc::new(SharedInner { + queue: Mutex::new(BinaryHeap::new()), + notify: Notify::new(), + }); + + ( + MergeableQueueSender { + inner: shared.clone(), + }, + MergeableQueueReceiver { inner: shared }, + ) +} + +impl MergeableQueueSender { + pub fn shutdown(&self) { + let mut queue = self.inner.queue.lock().unwrap(); + + // Send shutdown message + queue.push(Item { + expiration: Reverse(None), + inner: QueueMessage::Shutdown, + }); + + // and wake receiver for immediate processing. + self.inner.notify.notify_one(); + } + + pub fn enqueue(&self, repo: GithubRepoName, pr_number: PullRequestNumber) { + let expiration = Some(Instant::now() + BASE_DELAY); + + self.insert_item( + MergeableQueueItem { + pull_request: QueuedPullRequest { pr_number, repo }, + attempt: 1, + }, + expiration, + ); + } + + #[allow(dead_code)] + pub fn enqueue_now(&self, pr_number: PullRequestNumber, repo: GithubRepoName) { + self.insert_item( + MergeableQueueItem { + pull_request: QueuedPullRequest { pr_number, repo }, + attempt: 1, + }, + None, + ); + } + + pub fn enqueue_retry(&self, queue_item: MergeableQueueItem) { + let next_attempt = queue_item.attempt + 1; + let delay = calculate_exponential_backoff(BASE_DELAY, BACKOFF_MULTIPLIER, next_attempt); + let expiration = Some(Instant::now() + delay); + + self.insert_item( + MergeableQueueItem { + pull_request: queue_item.pull_request, + attempt: next_attempt, + }, + expiration, + ); + } + + fn insert_item(&self, item: MergeableQueueItem, expiration: Option) { + let mut queue = self.inner.queue.lock().unwrap(); + + // Notify when: + // 1. The current item expires sooner than the head of the queue + let has_earlier_expiration = + queue + .peek() + .is_some_and(|head| match (expiration, head.expiration) { + (Some(new_exp), Reverse(Some(head_exp))) => new_exp < head_exp, + _ => false, + }); + // 2. The queue was empty before insertion (reader might be waiting) + let should_notify = queue.is_empty() || expiration.is_none() || has_earlier_expiration; + + queue.push(Item { + expiration: Reverse(expiration), + inner: QueueMessage::Item(item), + }); + + if should_notify { + self.inner.notify.notify_one(); + } + } +} + +impl MergeableQueueReceiver { + fn peek_inner(&self) -> Result> { + let now = Instant::now(); + let mut queue = self.inner.queue.lock().unwrap(); + + match queue.peek() { + // Immediate item, ready for processing. + Some(Item { + expiration: Reverse(None), + .. + }) => { + let item = queue.pop().unwrap().inner; + Ok(item) + } + // Expiration has passed, ready for processing. + Some(Item { + expiration: Reverse(Some(expiration)), + .. + }) if *expiration <= now => { + let item = queue.pop().unwrap().inner; + Ok(item) + } + // Scheduled for future, wait until it's ready. + Some(Item { + expiration: Reverse(Some(expiration)), + .. + }) => { + let wait_time = *expiration - now; + Err(Some(wait_time)) + } + // Empty queue, wait for an item to be added. + None => Err(None), + } + } + + pub async fn dequeue(&self) -> Option<(MergeableQueueItem, MergeableQueueSender)> { + loop { + match self.peek_inner() { + // Item is ready. + Ok(QueueMessage::Item(item)) => { + break Some(( + item, + MergeableQueueSender { + inner: self.inner.clone(), + }, + )); + } + Ok(QueueMessage::Shutdown) => { + break None; + } + // Item exists but not ready, wait until then or until notified of a higher priority item. + Err(Some(duration)) => { + let _ = timeout(duration, self.inner.notify.notified()).await; + } + // Queue is empty, wait until notified of a new item. + Err(None) => { + self.inner.notify.notified().await; + } + } + } + } +} + +fn calculate_exponential_backoff( + base_delay: Duration, + backoff_multiplier: f64, + attempt: u32, +) -> Duration { + let multiplier = backoff_multiplier.powi(attempt as i32 - 1); + let timeout = (base_delay.as_millis() as f64 * multiplier) as u64; + Duration::from_millis(timeout) +} + +pub async fn handle_mergeable_queue_item( + ctx: Arc, + mq_tx: MergeableQueueSender, + mq_item: MergeableQueueItem, +) -> anyhow::Result<()> { + let MergeableQueueItem { + pull_request, + attempt, + .. + } = mq_item.clone(); + + if attempt >= MAX_RETRIES { + tracing::warn!( + "Exceeded max mergeable state checks for PR: {}", + pull_request + ); + return Ok(()); + } + + let pr_model = match ctx + .db + .get_pull_request(&pull_request.repo, pull_request.pr_number) + .await? + { + Some(model) => model, + None => { + tracing::error!("PR not found in database: {}", pull_request); + return Ok(()); + } + }; + + let repo_state = match ctx.repositories.read() { + Ok(guard) => match guard.get(&pull_request.repo) { + Some(state) => state.clone(), + None => { + return Err(anyhow::anyhow!( + "Repository not found: {}", + pull_request.repo + )); + } + }, + Err(err) => { + return Err(anyhow::anyhow!( + "Failed to acquire read lock on repositories: {}", + err + )); + } + }; + + let fetched_pr = repo_state + .client + .get_pull_request(pull_request.pr_number) + .await?; + let new_mergeable_state = fetched_pr.mergeable_state; + + if new_mergeable_state == OctocrabMergeableState::Unknown { + tracing::info!( + "Retrying mergeable state check for PR: {pull_request} ({attempt}/{MAX_RETRIES})", + ); + + mq_tx.enqueue_retry(mq_item); + + return Ok(()); + } + + ctx.db + .update_pr_mergeable_state(&pr_model, new_mergeable_state.clone().into()) + .await?; + + tracing::debug!("PR {pull_request} `mergeable_state` updated to `{new_mergeable_state:?}`"); + + Ok(()) +} diff --git a/src/bors/mod.rs b/src/bors/mod.rs index ccec07dc..88689472 100644 --- a/src/bors/mod.rs +++ b/src/bors/mod.rs @@ -22,6 +22,7 @@ pub mod comment; mod context; pub mod event; mod handlers; +pub mod mergeable_queue; #[derive(Clone, Debug, PartialEq, Eq)] pub enum CheckSuiteStatus { diff --git a/src/database/client.rs b/src/database/client.rs index e7dfd9a3..8957d431 100644 --- a/src/database/client.rs +++ b/src/database/client.rs @@ -10,11 +10,12 @@ use crate::github::{CommitSha, GithubRepoName}; use super::operations::{ approve_pull_request, create_build, create_pull_request, create_workflow, - delegate_pull_request, find_build, find_pr_by_build, get_pull_request, get_repository, + delegate_pull_request, find_build, find_pr_by_build, + get_nonclosed_pull_requests_by_base_branch, get_pull_request, get_repository, get_running_builds, get_workflow_urls_for_build, get_workflows_for_build, set_pr_priority, set_pr_rollup, set_pr_status, unapprove_pull_request, undelegate_pull_request, update_build_status, update_mergeable_states_by_base_branch, update_pr_build_id, - update_workflow_status, upsert_pull_request, upsert_repository, + update_pr_mergeable_state, update_workflow_status, upsert_pull_request, upsert_repository, }; use super::{ApprovalInfo, DelegatedPermission, MergeableState, RunId}; @@ -59,6 +60,14 @@ impl PgDbClient { undelegate_pull_request(&self.pool, pr.id).await } + pub async fn update_pr_mergeable_state( + &self, + pr: &PullRequestModel, + mergeable_state: MergeableState, + ) -> anyhow::Result<()> { + update_pr_mergeable_state(&self.pool, pr.id, mergeable_state).await + } + pub async fn update_mergeable_states_by_base_branch( &self, repo: &GithubRepoName, @@ -104,6 +113,14 @@ impl PgDbClient { Ok(pr) } + pub async fn get_nonclosed_pull_requests_by_base_branch( + &self, + repo: &GithubRepoName, + base_branch: &str, + ) -> anyhow::Result> { + get_nonclosed_pull_requests_by_base_branch(&self.pool, repo, base_branch).await + } + pub async fn create_pull_request( &self, repo: &GithubRepoName, diff --git a/src/database/mod.rs b/src/database/mod.rs index 3917f4cf..87649ac6 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -10,7 +10,7 @@ use crate::{ }; use chrono::{DateTime, Utc}; pub use client::PgDbClient; -use octocrab::models::pulls::MergeableState as OctocrabMergeableState; +pub use octocrab::models::pulls::MergeableState as OctocrabMergeableState; use sqlx::error::BoxDynError; use sqlx::{Database, Postgres}; diff --git a/src/database/operations.rs b/src/database/operations.rs index fef1a043..793c7cff 100644 --- a/src/database/operations.rs +++ b/src/database/operations.rs @@ -163,6 +163,68 @@ pub(crate) async fn upsert_pull_request( .await } +/// Uses inclusion rather than negation, which would cause a full table scan, +/// to leverage the index from PR #246 (https://github.com/rust-lang/bors/pull/246). +pub(crate) async fn get_nonclosed_pull_requests_by_base_branch( + executor: impl PgExecutor<'_>, + repo: &GithubRepoName, + base_branch: &str, +) -> anyhow::Result> { + measure_db_query("get_pull_requests_by_base_branch", || async { + let records = sqlx::query_as!( + PullRequestModel, + r#" + SELECT + pr.id, + pr.repository as "repository: GithubRepoName", + pr.number as "number!: i64", + ( + pr.approved_by, + pr.approved_sha + ) AS "approval_status!: ApprovalStatus", + pr.status as "pr_status: PullRequestStatus", + pr.priority, + pr.rollup as "rollup: RollupMode", + pr.delegated_permission as "delegated_permission: DelegatedPermission", + pr.base_branch, + pr.mergeable_state as "mergeable_state: MergeableState", + pr.created_at as "created_at: DateTime", + build AS "try_build: BuildModel" + FROM pull_request as pr + LEFT JOIN build ON pr.build_id = build.id + WHERE pr.repository = $1 + AND pr.base_branch = $2 + AND pr.status IN ('open', 'draft') + "#, + repo as &GithubRepoName, + base_branch + ) + .fetch_all(executor) + .await?; + + Ok(records) + }) + .await +} + +pub(crate) async fn update_pr_mergeable_state( + executor: impl PgExecutor<'_>, + pr_id: i32, + mergeable_state: MergeableState, +) -> anyhow::Result<()> { + measure_db_query("update_pr_mergeable_state", || async { + sqlx::query!( + "UPDATE pull_request SET mergeable_state = $1 WHERE id = $2", + mergeable_state as _, + pr_id + ) + .execute(executor) + .await?; + Ok(()) + }) + .await +} + pub(crate) async fn update_mergeable_states_by_base_branch( executor: impl PgExecutor<'_>, repo: &GithubRepoName, diff --git a/src/github/server.rs b/src/github/server.rs index d38e808f..c87398a7 100644 --- a/src/github/server.rs +++ b/src/github/server.rs @@ -1,4 +1,8 @@ use crate::bors::event::BorsEvent; +use crate::bors::mergeable_queue::{ + MergeableQueueReceiver, MergeableQueueSender, create_mergeable_queue, + handle_mergeable_queue_item, +}; use crate::bors::{BorsContext, handle_bors_global_event, handle_bors_repository_event}; use crate::github::webhook::GitHubWebhook; use crate::github::webhook::WebhookSecret; @@ -12,6 +16,7 @@ use axum::response::IntoResponse; use axum::routing::{get, post}; use octocrab::Octocrab; use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use tokio::sync::mpsc; use tower::limit::ConcurrencyLimitLayer; @@ -79,19 +84,25 @@ pub async fn github_webhook_handler( } } +pub struct BorsProcess { + pub repository_tx: mpsc::Sender, + pub global_tx: mpsc::Sender, + pub mergeable_queue_tx: MergeableQueueSender, + pub bors_process: Pin + Send>>, +} + /// Creates a future with a Bors process that continuously receives webhook events and reacts to /// them. pub fn create_bors_process( ctx: BorsContext, gh_client: Octocrab, team_api: TeamApiClient, -) -> ( - mpsc::Sender, - mpsc::Sender, - impl Future, -) { +) -> BorsProcess { let (repository_tx, repository_rx) = mpsc::channel::(1024); let (global_tx, global_rx) = mpsc::channel::(1024); + let (mergeable_queue_tx, mergeable_queue_rx) = create_mergeable_queue(); + + let mq_tx = mergeable_queue_tx.clone(); let service = async move { let ctx = Arc::new(ctx); @@ -103,8 +114,9 @@ pub fn create_bors_process( #[cfg(test)] { tokio::join!( - consume_repository_events(ctx.clone(), repository_rx), - consume_global_events(ctx.clone(), global_rx, gh_client, team_api) + consume_repository_events(ctx.clone(), repository_rx, mergeable_queue_tx), + consume_global_events(ctx.clone(), global_rx, gh_client, team_api), + consume_mergeable_queue(ctx, mergeable_queue_rx) ); } // In real execution, the bot runs forever. If there is something that finishes @@ -113,28 +125,39 @@ pub fn create_bors_process( #[cfg(not(test))] { tokio::select! { - _ = consume_repository_events(ctx.clone(), repository_rx) => { + _ = consume_repository_events(ctx.clone(), repository_rx, mergeable_queue_tx) => { tracing::error!("Repository event handling process has ended"); } _ = consume_global_events(ctx.clone(), global_rx, gh_client, team_api) => { tracing::error!("Global event handling process has ended"); } + _ = consume_mergeable_queue(ctx, mergeable_queue_rx) => { + tracing::error!("Mergeable queue handling process has ended") + } } } }; - (repository_tx, global_tx, service) + + BorsProcess { + repository_tx, + global_tx, + mergeable_queue_tx: mq_tx, + bors_process: Box::pin(service), + } } async fn consume_repository_events( ctx: Arc, mut repository_rx: mpsc::Receiver, + mergeable_queue_tx: MergeableQueueSender, ) { while let Some(event) = repository_rx.recv().await { let ctx = ctx.clone(); + let mergeable_queue_tx = mergeable_queue_tx.clone(); let span = tracing::info_span!("RepositoryEvent"); tracing::debug!("Received repository event: {event:#?}"); - if let Err(error) = handle_bors_repository_event(event, ctx) + if let Err(error) = handle_bors_repository_event(event, ctx, mergeable_queue_tx) .instrument(span.clone()) .await { @@ -163,6 +186,24 @@ async fn consume_global_events( } } +async fn consume_mergeable_queue( + ctx: Arc, + mergeable_queue_rx: MergeableQueueReceiver, +) { + while let Some((mq_item, mq_tx)) = mergeable_queue_rx.dequeue().await { + let ctx = ctx.clone(); + + let span = tracing::info_span!("MergeableQueue"); + tracing::debug!("Received mergeable queue item: {}", mq_item.pull_request); + if let Err(error) = handle_mergeable_queue_item(ctx, mq_tx, mq_item) + .instrument(span.clone()) + .await + { + handle_root_error(span, error); + } + } +} + #[allow(unused_variables)] fn handle_root_error(span: Span, error: Error) { // In tests, we want to panic on all errors. diff --git a/src/lib.rs b/src/lib.rs index 69316d45..c1161924 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ pub use github::{ WebhookSecret, api::create_github_client, api::load_repositories, - server::{ServerState, create_app, create_bors_process}, + server::{BorsProcess, ServerState, create_app, create_bors_process}, }; pub use permissions::TeamApiClient; diff --git a/src/tests/mocks/bors.rs b/src/tests/mocks/bors.rs index f1029c08..1c79563b 100644 --- a/src/tests/mocks/bors.rs +++ b/src/tests/mocks/bors.rs @@ -12,9 +12,11 @@ use tokio::sync::mpsc::Sender; use tokio::task::JoinHandle; use tower::Service; +use crate::bors::mergeable_queue::MergeableQueueSender; use crate::bors::{RollupMode, WAIT_FOR_REFRESH}; use crate::database::{BuildStatus, DelegatedPermission, PullRequestModel}; use crate::github::api::load_repositories; +use crate::github::server::BorsProcess; use crate::github::{GithubRepoName, PullRequestNumber}; use crate::tests::mocks::comment::{Comment, GitHubIssueCommentEventPayload}; use crate::tests::mocks::workflow::{ @@ -30,7 +32,9 @@ use crate::{ create_app, create_bors_process, }; -use super::pull_request::{GitHubPullRequestEventPayload, PullRequestChangeEvent}; +use super::pull_request::{ + GitHubPullRequestEventPayload, GitHubPushEventPayload, PullRequestChangeEvent, +}; use super::repository::PullRequest; pub struct BorsBuilder { @@ -95,6 +99,7 @@ pub struct BorsTester { http_mock: ExternalHttpMock, github: GitHubState, db: Arc, + mergeable_queue_tx: MergeableQueueSender, // Sender for bors global events global_tx: Sender, } @@ -115,8 +120,12 @@ impl BorsTester { let ctx = BorsContext::new(CommandParser::new("@bors".to_string()), db.clone(), repos); - let (repository_tx, global_tx, bors_process) = - create_bors_process(ctx, mock.github_client(), mock.team_api_client()); + let BorsProcess { + repository_tx, + global_tx, + mergeable_queue_tx, + bors_process, + } = create_bors_process(ctx, mock.github_client(), mock.team_api_client()); let state = ServerState::new( repository_tx, @@ -131,6 +140,7 @@ impl BorsTester { http_mock: mock, github, db, + mergeable_queue_tx, global_tx, }, bors, @@ -232,6 +242,11 @@ impl BorsTester { .clone() } + pub async fn push_to_branch(&mut self, branch: &str) -> anyhow::Result<()> { + self.send_webhook("push", GitHubPushEventPayload::new(branch)) + .await + } + pub fn try_branch(&self) -> Branch { self.get_branch("automation/bors/try") } @@ -617,6 +632,7 @@ impl BorsTester { // Make sure that the event channel senders are closed drop(self.app); drop(self.global_tx); + self.mergeable_queue_tx.shutdown(); // Wait until all events are handled in the bors service bors.await.unwrap(); // Flush any local queues diff --git a/src/tests/mocks/pull_request.rs b/src/tests/mocks/pull_request.rs index 2bd36f44..dd6040bc 100644 --- a/src/tests/mocks/pull_request.rs +++ b/src/tests/mocks/pull_request.rs @@ -291,7 +291,6 @@ pub struct GitHubPushEventPayload { } impl GitHubPushEventPayload { - #[allow(unused)] pub fn new(branch_name: &str) -> Self { GitHubPushEventPayload { repository: default_repo_name().into(),