Skip to content

Commit 76b5c16

Browse files
committed
feat: Implement before_connect callback to modify connect options.
Allows the user to see and maybe modify the connect options before each attempt to connect to a database. May be used in a number of ways, e.g.: - adding jitter to connection lifetime - validating/setting a per-connection password - using a custom server discovery process
1 parent 028084b commit 76b5c16

File tree

2 files changed

+99
-9
lines changed

2 files changed

+99
-9
lines changed

sqlx-core/src/pool/inner.rs

+36-9
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crossbeam_queue::ArrayQueue;
88

99
use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser};
1010

11+
use std::borrow::Cow;
1112
use std::cmp;
1213
use std::future::Future;
1314
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
@@ -324,6 +325,36 @@ impl<DB: Database> PoolInner<DB> {
324325
Ok(acquired)
325326
}
326327

328+
/// Attempts to get connect options, possibly modify using before_connect, then connect.
329+
///
330+
/// Wrapping this code in a timeout allows the total time taken for these steps to
331+
/// be bounded by the connection deadline.
332+
async fn get_connect_options_and_connect(
333+
self: &Arc<Self>,
334+
num_attempts: u32,
335+
) -> Result<DB::Connection, Error> {
336+
// clone the connect options arc so it can be used without holding the RwLockReadGuard
337+
// across an async await point
338+
let connect_options_arc = self
339+
.connect_options
340+
.read()
341+
.expect("write-lock holder panicked")
342+
.clone();
343+
344+
let connect_options = if let Some(callback) = &self.options.before_connect {
345+
callback(connect_options_arc.as_ref(), num_attempts)
346+
.await
347+
.map_err(|error| {
348+
tracing::error!(%error, "error returned from before_connect");
349+
error
350+
})?
351+
} else {
352+
Cow::Borrowed(connect_options_arc.as_ref())
353+
};
354+
355+
connect_options.connect().await
356+
}
357+
327358
pub(super) async fn connect(
328359
self: &Arc<Self>,
329360
deadline: Instant,
@@ -335,21 +366,17 @@ impl<DB: Database> PoolInner<DB> {
335366

336367
let mut backoff = Duration::from_millis(10);
337368
let max_backoff = deadline_as_timeout(deadline)? / 5;
369+
let mut num_attempts: u32 = 0;
338370

339371
loop {
340372
let timeout = deadline_as_timeout(deadline)?;
341-
342-
// clone the connect options arc so it can be used without holding the RwLockReadGuard
343-
// across an async await point
344-
let connect_options = self
345-
.connect_options
346-
.read()
347-
.expect("write-lock holder panicked")
348-
.clone();
373+
num_attempts += 1;
349374

350375
// result here is `Result<Result<C, Error>, TimeoutError>`
351376
// if this block does not return, sleep for the backoff timeout and try again
352-
match crate::rt::timeout(timeout, connect_options.connect()).await {
377+
match crate::rt::timeout(timeout, self.get_connect_options_and_connect(num_attempts))
378+
.await
379+
{
353380
// successfully established connection
354381
Ok(Ok(mut raw)) => {
355382
// See comment on `PoolOptions::after_connect`

sqlx-core/src/pool/options.rs

+63
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::pool::inner::PoolInner;
55
use crate::pool::Pool;
66
use futures_core::future::BoxFuture;
77
use log::LevelFilter;
8+
use std::borrow::Cow;
89
use std::fmt::{self, Debug, Formatter};
910
use std::sync::Arc;
1011
use std::time::{Duration, Instant};
@@ -44,6 +45,18 @@ use std::time::{Duration, Instant};
4445
/// the perspectives of both API designer and consumer.
4546
pub struct PoolOptions<DB: Database> {
4647
pub(crate) test_before_acquire: bool,
48+
pub(crate) before_connect: Option<
49+
Arc<
50+
dyn Fn(
51+
&<DB::Connection as Connection>::Options,
52+
u32,
53+
)
54+
-> BoxFuture<'_, Result<Cow<'_, <DB::Connection as Connection>::Options>, Error>>
55+
+ 'static
56+
+ Send
57+
+ Sync,
58+
>,
59+
>,
4760
pub(crate) after_connect: Option<
4861
Arc<
4962
dyn Fn(&mut DB::Connection, PoolConnectionMetadata) -> BoxFuture<'_, Result<(), Error>>
@@ -94,6 +107,7 @@ impl<DB: Database> Clone for PoolOptions<DB> {
94107
fn clone(&self) -> Self {
95108
PoolOptions {
96109
test_before_acquire: self.test_before_acquire,
110+
before_connect: self.before_connect.clone(),
97111
after_connect: self.after_connect.clone(),
98112
before_acquire: self.before_acquire.clone(),
99113
after_release: self.after_release.clone(),
@@ -143,6 +157,7 @@ impl<DB: Database> PoolOptions<DB> {
143157
pub fn new() -> Self {
144158
Self {
145159
// User-specifiable routines
160+
before_connect: None,
146161
after_connect: None,
147162
before_acquire: None,
148163
after_release: None,
@@ -339,6 +354,54 @@ impl<DB: Database> PoolOptions<DB> {
339354
self
340355
}
341356

357+
/// Perform an asynchronous action before connecting to the database.
358+
///
359+
/// This operation is performed on every attempt to connect, including retries. The
360+
/// current `ConnectOptions` is passed, and this may be passed unchanged, or modified
361+
/// after cloning. The current connection attempt is passed as the second parameter
362+
/// (starting at 1).
363+
///
364+
/// If the operation returns with an error, then the connection attempt fails without
365+
/// attempting further retries. The operation therefore may need to implement error
366+
/// handling and/or value caching to avoid failing the connection attempt.
367+
///
368+
/// # Example: Per-Request Authentication
369+
/// This callback may be used to modify values in the database's `ConnectOptions`, before
370+
/// connecting to the database.
371+
///
372+
/// This example is written for PostgreSQL but can likely be adapted to other databases.
373+
///
374+
/// ```no_run
375+
/// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
376+
/// use std::borrow::Cow;
377+
/// use sqlx::Executor;
378+
/// use sqlx::postgres::PgPoolOptions;
379+
///
380+
/// let pool = PgPoolOptions::new()
381+
/// .before_connect(move |opts, _num_attempts| Box::pin(async move {
382+
/// Ok(Cow::Owned(opts.clone().password("abc")))
383+
/// }))
384+
/// .connect("postgres:// …").await?;
385+
/// # Ok(())
386+
/// # }
387+
/// ```
388+
///
389+
/// For a discussion on why `Box::pin()` is required, see [the type-level docs][Self].
390+
pub fn before_connect<F>(mut self, callback: F) -> Self
391+
where
392+
for<'c> F: Fn(
393+
&'c <DB::Connection as Connection>::Options,
394+
u32,
395+
)
396+
-> BoxFuture<'c, crate::Result<Cow<'c, <DB::Connection as Connection>::Options>>>
397+
+ 'static
398+
+ Send
399+
+ Sync,
400+
{
401+
self.before_connect = Some(Arc::new(callback));
402+
self
403+
}
404+
342405
/// Perform an asynchronous action after connecting to the database.
343406
///
344407
/// If the operation returns with an error then the error is logged, the connection is closed

0 commit comments

Comments
 (0)