diff --git a/src/bors/handlers/info.rs b/src/bors/handlers/info.rs index eaceeea..ad30a54 100644 --- a/src/bors/handlers/info.rs +++ b/src/bors/handlers/info.rs @@ -12,7 +12,7 @@ pub(super) async fn command_info( ) -> anyhow::Result<()> { // Geting PR info from database let pr_model = db - .get_or_create_pull_request( + .upsert_pull_request( repo.client.repository(), pr.number, &pr.base.name, diff --git a/src/bors/handlers/mod.rs b/src/bors/handlers/mod.rs index e58984e..d13ccf8 100644 --- a/src/bors/handlers/mod.rs +++ b/src/bors/handlers/mod.rs @@ -140,7 +140,7 @@ pub async fn handle_bors_repository_event( let span = tracing::info_span!("Pull request pushed", repo = payload.repository.to_string()); - handle_push_to_pull_request(repo, db, payload) + handle_push_to_pull_request(repo, db, mergeable_queue_tx, payload) .instrument(span.clone()) .await?; } @@ -148,7 +148,7 @@ pub async fn handle_bors_repository_event( let span = tracing::info_span!("Pull request opened", repo = payload.repository.to_string()); - handle_pull_request_opened(repo, db, payload) + handle_pull_request_opened(repo, db, mergeable_queue_tx, payload) .instrument(span.clone()) .await?; } @@ -174,7 +174,7 @@ pub async fn handle_bors_repository_event( repo = payload.repository.to_string() ); - handle_pull_request_reopened(repo, db, payload) + handle_pull_request_reopened(repo, db, mergeable_queue_tx, payload) .instrument(span.clone()) .await?; } @@ -503,7 +503,7 @@ async fn has_permission( } let pr_model = db - .get_or_create_pull_request( + .upsert_pull_request( repo_state.repository(), pr.number, &pr.base.name, diff --git a/src/bors/handlers/pr_events.rs b/src/bors/handlers/pr_events.rs index 6a49c1e..e5b5b8c 100644 --- a/src/bors/handlers/pr_events.rs +++ b/src/bors/handlers/pr_events.rs @@ -20,7 +20,7 @@ pub(super) async fn handle_pull_request_edited( let pr = &payload.pull_request; let pr_number = pr.number; let pr_model = db - .get_or_create_pull_request( + .upsert_pull_request( repo_state.repository(), pr_number, &pr.base.name, @@ -48,12 +48,13 @@ pub(super) async fn handle_pull_request_edited( pub(super) async fn handle_push_to_pull_request( repo_state: Arc, db: Arc, + mergeable_queue: MergeableQueueSender, payload: PullRequestPushed, ) -> anyhow::Result<()> { let pr = &payload.pull_request; let pr_number = pr.number; let pr_model = db - .get_or_create_pull_request( + .upsert_pull_request( repo_state.repository(), pr_number, &pr.base.name, @@ -62,6 +63,8 @@ pub(super) async fn handle_push_to_pull_request( ) .await?; + mergeable_queue.enqueue(repo_state.repository().clone(), pr_number); + if !pr_model.is_approved() { return Ok(()); } @@ -74,6 +77,7 @@ pub(super) async fn handle_push_to_pull_request( pub(super) async fn handle_pull_request_opened( repo_state: Arc, db: Arc, + mergeable_queue: MergeableQueueSender, payload: PullRequestOpened, ) -> anyhow::Result<()> { let pr_status = if payload.draft { @@ -87,7 +91,11 @@ pub(super) async fn handle_pull_request_opened( &payload.pull_request.base.name, pr_status, ) - .await + .await?; + + mergeable_queue.enqueue(repo_state.repository().clone(), payload.pull_request.number); + + Ok(()) } pub(super) async fn handle_pull_request_closed( @@ -119,14 +127,23 @@ pub(super) async fn handle_pull_request_merged( pub(super) async fn handle_pull_request_reopened( repo_state: Arc, db: Arc, + mergeable_queue: MergeableQueueSender, payload: PullRequestReopened, ) -> anyhow::Result<()> { - db.set_pr_status( + let pr = &payload.pull_request; + let pr_number = pr.number; + db.upsert_pull_request( repo_state.repository(), - payload.pull_request.number, - PullRequestStatus::Open, + pr_number, + &pr.base.name, + pr.mergeable_state.clone().into(), + &pr.status, ) - .await + .await?; + + mergeable_queue.enqueue(repo_state.repository().clone(), pr_number); + + Ok(()) } pub(super) async fn handle_pull_request_converted_to_draft( @@ -506,4 +523,73 @@ mod tests { }) .await; } + + #[sqlx::test] + async fn enqueue_prs_on_pr_opened(pool: sqlx::PgPool) { + run_test(pool, |mut tester| async { + tester.open_pr(default_repo_name(), false).await?; + tester + .wait_for_default_pr(|pr| pr.mergeable_state == MergeableState::Unknown) + .await?; + tester + .default_repo() + .lock() + .get_pr_mut(default_pr_number()) + .mergeable_state = OctocrabMergeableState::Dirty; + tester + .wait_for_default_pr(|pr| pr.mergeable_state == MergeableState::HasConflicts) + .await?; + Ok(tester) + }) + .await; + } + + #[sqlx::test] + async fn enqueue_prs_on_pr_reopened(pool: sqlx::PgPool) { + run_test(pool, |mut tester| async { + tester + .default_repo() + .lock() + .get_pr_mut(default_pr_number()) + .mergeable_state = OctocrabMergeableState::Unknown; + tester + .reopen_pr(default_repo_name(), default_pr_number()) + .await?; + tester + .wait_for_default_pr(|pr| pr.mergeable_state == MergeableState::Unknown) + .await?; + tester + .default_repo() + .lock() + .get_pr_mut(default_pr_number()) + .mergeable_state = OctocrabMergeableState::Dirty; + tester + .wait_for_default_pr(|pr| pr.mergeable_state == MergeableState::HasConflicts) + .await?; + Ok(tester) + }) + .await; + } + + #[sqlx::test] + async fn enqueue_prs_on_push_to_pr(pool: sqlx::PgPool) { + run_test(pool, |mut tester| async { + tester + .push_to_pr(default_repo_name(), default_pr_number()) + .await?; + tester + .wait_for_default_pr(|pr| pr.mergeable_state == MergeableState::Unknown) + .await?; + tester + .default_repo() + .lock() + .get_pr_mut(default_pr_number()) + .mergeable_state = OctocrabMergeableState::Dirty; + tester + .wait_for_default_pr(|pr| pr.mergeable_state == MergeableState::HasConflicts) + .await?; + Ok(tester) + }) + .await; + } } diff --git a/src/bors/handlers/review.rs b/src/bors/handlers/review.rs index 1c30550..5ef7dd7 100644 --- a/src/bors/handlers/review.rs +++ b/src/bors/handlers/review.rs @@ -41,7 +41,7 @@ pub(super) async fn command_approve( sha: pr.head.sha.to_string(), }; let pr_model = db - .get_or_create_pull_request( + .upsert_pull_request( repo_state.repository(), pr.number, &pr.base.name, @@ -70,7 +70,7 @@ pub(super) async fn command_unapprove( return Ok(()); }; let pr_model = db - .get_or_create_pull_request( + .upsert_pull_request( repo_state.repository(), pr.number, &pr.base.name, @@ -98,7 +98,7 @@ pub(super) async fn command_set_priority( return Ok(()); }; let pr_model = db - .get_or_create_pull_request( + .upsert_pull_request( repo_state.repository(), pr.number, &pr.base.name, @@ -129,7 +129,7 @@ pub(super) async fn command_delegate( } let pr_model = db - .get_or_create_pull_request( + .upsert_pull_request( repo_state.repository(), pr.number, &pr.base.name, @@ -155,7 +155,7 @@ pub(super) async fn command_undelegate( return Ok(()); } let pr_model = db - .get_or_create_pull_request( + .upsert_pull_request( repo_state.repository(), pr.number, &pr.base.name, @@ -181,7 +181,7 @@ pub(super) async fn command_set_rollup( return Ok(()); } let pr_model = db - .get_or_create_pull_request( + .upsert_pull_request( repo_state.repository(), pr.number, &pr.base.name, diff --git a/src/bors/handlers/trybuild.rs b/src/bors/handlers/trybuild.rs index 4d010bf..aaa794d 100644 --- a/src/bors/handlers/trybuild.rs +++ b/src/bors/handlers/trybuild.rs @@ -56,7 +56,7 @@ pub(super) async fn command_try_build( // Create pr model based on CI repo, so we can retrieve the pr later when // the CI repo emits events let pr_model = db - .get_or_create_pull_request( + .upsert_pull_request( repo.client.repository(), pr.number, &pr.base.name, @@ -208,7 +208,7 @@ pub(super) async fn command_try_cancel( let pr_number: PullRequestNumber = pr.number; let pr = db - .get_or_create_pull_request( + .upsert_pull_request( repo.client.repository(), pr_number, &pr.base.name, diff --git a/src/database/client.rs b/src/database/client.rs index e1f64cb..f2358d3 100644 --- a/src/database/client.rs +++ b/src/database/client.rs @@ -94,7 +94,7 @@ impl PgDbClient { get_pull_request(&self.pool, repo, pr_number).await } - pub async fn get_or_create_pull_request( + pub async fn upsert_pull_request( &self, repo: &GithubRepoName, pr_number: PullRequestNumber, diff --git a/src/tests/mocks/bors.rs b/src/tests/mocks/bors.rs index 1c79563..22c699f 100644 --- a/src/tests/mocks/bors.rs +++ b/src/tests/mocks/bors.rs @@ -14,7 +14,7 @@ use tower::Service; use crate::bors::mergeable_queue::MergeableQueueSender; use crate::bors::{RollupMode, WAIT_FOR_REFRESH}; -use crate::database::{BuildStatus, DelegatedPermission, PullRequestModel}; +use crate::database::{BuildStatus, DelegatedPermission, OctocrabMergeableState, PullRequestModel}; use crate::github::api::load_repositories; use crate::github::server::BorsProcess; use crate::github::{GithubRepoName, PullRequestNumber}; @@ -510,6 +510,7 @@ impl BorsTester { .get_mut(&pr_number) .expect("PR must be initialized before pushing to it"); pr.head_sha = format!("pr-{pr_number}-commit-{counter}"); + pr.mergeable_state = OctocrabMergeableState::Unknown; pr.clone() };