
Add PoolOptions::before_connect to execute a function before creating a new connection in the pool #3117


Open
MattDelac opened this issue Mar 14, 2024 · 3 comments · May be fixed by #3582
Labels
enhancement New feature or request


MattDelac commented Mar 14, 2024

Is your feature request related to a problem? Please describe.
When connecting to RDS using IAM authentication, I need SQLx to call a function that provides a new connection string (or PoolOptions) before creating each connection.

I've looked at some workarounds, but none of them work in my case. For example, calling set_connect_options() once at startup does not solve my problem, since the password in the PgConnectOptions is invalidated 15 minutes after my app starts:

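#[tokio::main]
async fn main() {
    let pool = get_pg_pool().await;
    // set_connect_options is synchronous and returns (); it only changes the
    // options used for connections opened from now on
    pool.set_connect_options(get_pg_options().await);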

    let state = SharedState {
        pool,
        ...
    };

    let app = Router::new()...;
    ...
    axum::serve(listener, app).await.unwrap();
}

Describe the solution you'd like
The ability to execute a function that creates a new URL string (or PgConnectOptions object) right before the pool creates a new connection.
Example of how I could use it:

#[tokio::main]
async fn main() {
    let pool = get_pg_pool()
        .await
        .before_connect(function_to_execute_right_before_creating_a_new_connection);
    ....
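To make the requested semantics concrete, here is a std-only toy model of a pool that calls a hook for fresh options every time it opens a connection. `ToyPool`, `ConnectOptions` and the `BeforeConnect` type are hypothetical stand-ins, not SQLx APIs, and the hook is synchronous for brevity; a real hook for RDS IAM would be async, because token generation awaits the AWS SDK.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Hypothetical stand-in for connection options (what PgConnectOptions would carry).
#[derive(Clone, Debug, PartialEq)]
pub struct ConnectOptions {
    pub password: String,
}

/// Hypothetical stand-in for an open connection; records the options it was opened with.
#[derive(Debug)]
pub struct Connection {
    pub opened_with: ConnectOptions,
}

/// Hypothetical hook type: called right before each new connection is opened.
pub type BeforeConnect = Arc<dyn Fn() -> ConnectOptions + Send + Sync>;

/// Toy pool that asks the hook for fresh options every time it connects.
pub struct ToyPool {
    before_connect: BeforeConnect,
}

impl ToyPool {
    pub fn new(before_connect: BeforeConnect) -> Self {
        Self { before_connect }
    }

    /// Opens a new connection using whatever options the hook returns right now.
    pub fn connect(&self) -> Connection {
        let opts = (self.before_connect)();
        Connection { opened_with: opts }
    }
}

fn main() {
    // Simulated token generator: yields a new password on every call.
    let counter = Arc::new(AtomicUsize::new(0));
    let c = Arc::clone(&counter);
    let pool = ToyPool::new(Arc::new(move || {
        let n = c.fetch_add(1, Ordering::SeqCst);
        ConnectOptions { password: format!("token-{n}") }
    }));

    let first = pool.connect();
    let second = pool.connect();
    println!("{:?}", first.opened_with); // ConnectOptions { password: "token-0" }
    println!("{:?}", second.opened_with); // ConnectOptions { password: "token-1" }
    println!("hook called {} times", counter.load(Ordering::SeqCst)); // 2
}
```

Because the hook runs per connection, every connection opened after startup gets a token generated moments earlier, which is what the 15-minute validity window requires.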

Describe alternatives you've considered
If I create all the connections at startup and any of them dies, it won't be possible to create a new one, since the IAM auth password is valid for only 15 minutes.

So basically the following does not solve my problem either

PgPoolOptions::new()
    .min_connections(max_connections)
    .max_connections(max_connections)
    ...

If any connection is reaped by max_lifetime or idle_timeout or explicitly closed, and it brings the connection count below this amount, a new connection will be opened to replace it.
https://docs.rs/sqlx/latest/sqlx/pool/struct.PoolOptions.html#method.min_connections

Another alternative, suggested in a conversation on Discord, would be to create my own custom Database that overrides the connect method. This sounds like a fair amount of work, and fairly risky without much context on how SQLx is built:

Hmm, actually, I guess you could do it by having a custom Database implementation which wraps the underlying postgres one. Then that Database implementation would have a ConnectOptions::connect implementation which fetches a fresh password.

Additional context
This is because we use AWS RDS IAM authentication, for which the password is valid for only 15 minutes. So if any connection dies after the first 15 minutes, creating a new connection with the initial PgConnectOptions or connection string will simply fail.
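The 15-minute window boils down to a refresh-before-expiry check. Below is a std-only sketch of that logic, the kind of thing a `before_connect` hook or a background task would wrap. `TokenCache`, the `generate` closure and the one-minute safety margin are hypothetical; in a real service `generate` would call the RDS IAM token generator. Time is passed in explicitly so the behaviour is deterministic.

```rust
use std::time::{Duration, Instant};

/// Caches a short-lived auth token and regenerates it shortly before it expires.
/// `generate` is a hypothetical stand-in for an RDS IAM token generator.
pub struct TokenCache<F: FnMut() -> String> {
    generate: F,
    ttl: Duration,                      // token lifetime, e.g. 15 minutes for RDS IAM
    margin: Duration,                   // refresh this long before expiry
    current: Option<(String, Instant)>, // (token, issued_at)
}

impl<F: FnMut() -> String> TokenCache<F> {
    pub fn new(generate: F, ttl: Duration, margin: Duration) -> Self {
        Self { generate, ttl, margin, current: None }
    }

    /// Returns a token that stays valid for at least `margin` after `now`,
    /// generating a new one if the cached token is missing or too old.
    pub fn get(&mut self, now: Instant) -> &str {
        let fresh_enough = match &self.current {
            Some((_, issued_at)) => {
                now.saturating_duration_since(*issued_at) + self.margin < self.ttl
            }
            None => false,
        };
        if !fresh_enough {
            let token = (self.generate)();
            self.current = Some((token, now));
        }
        &self.current.as_ref().expect("token was just set").0
    }
}

fn main() {
    let mut n = 0;
    let mut cache = TokenCache::new(
        || {
            n += 1;
            format!("token-{n}")
        },
        Duration::from_secs(15 * 60),
        Duration::from_secs(60),
    );
    let t0 = Instant::now();
    println!("{}", cache.get(t0)); // token-1
    println!("{}", cache.get(t0 + Duration::from_secs(10 * 60))); // token-1 (still fresh)
    println!("{}", cache.get(t0 + Duration::from_secs(14 * 60))); // token-2 (refreshed)
}
```

Taking `now` as a parameter rather than calling `Instant::now()` inside `get` keeps the expiry logic testable without sleeping.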

@MattDelac MattDelac added the enhancement New feature or request label Mar 14, 2024
@abonander abonander linked a pull request Oct 29, 2024 that will close this issue

GeeWee commented May 6, 2025

How did y'all end up solving this? I see that #3562 and #3851 both attempted to solve this, but were put on hold in favour of some upcoming work that will make the whole thing more flexible. What is everyone doing in the meantime?


MattDelac commented May 6, 2025


👋
High level

  1. Write a function that generates the RDS IAM password (auth token).
  2. Generate it and create the connection pool when the service boots.
  3. Run a background task that renews the password before the 15-minute token expires (the code below uses a 14-minute interval).

Note: This has been running in production for 14 months without any issues.
Note 2: The bottom line is to call set_connect_options with a fresh PgConnectOptions object.

Generate the RDS IAM password

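// Assumes aws-config, aws-credential-types and aws-sigv4 1.x plus the url crate;
// KarmaError is an application-defined error type.
use std::time::{Duration, SystemTime};

use aws_config::BehaviorVersion;
use aws_credential_types::provider::ProvideCredentials;
use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings};
use aws_sigv4::sign::v4;

// https://github.com/awslabs/aws-sdk-rust/issues/951#issuecomment-1838117702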
// https://github.com/awslabs/aws-sdk-rust/issues/951#issuecomment-1961010056
async fn generate_rds_iam_token(host: &str, port: u16, user: &str) -> Result<String, KarmaError> {
    let config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await;

    let credentials = config
        .credentials_provider()
        .expect("no credentials provider found")
        .provide_credentials()
        .await
        .expect("unable to load credentials");
    let identity = credentials.into();
    let region = config.region().unwrap().to_string();

    let mut signing_settings = SigningSettings::default();
    signing_settings.expires_in = Some(Duration::from_secs(900));
    signing_settings.signature_location = aws_sigv4::http_request::SignatureLocation::QueryParams;

    let signing_params = v4::SigningParams::builder()
        .identity(&identity)
        .region(&region)
        .name("rds-db")
        .time(SystemTime::now())
        .settings(signing_settings)
        .build()?;

    let url = format!("https://{host}:{port}/?Action=connect&DBUser={user}");

    let signable_request =
        SignableRequest::new("GET", &url, std::iter::empty(), SignableBody::Bytes(&[]))
            .expect("signable request");

    let (signing_instructions, _signature) =
        sign(signable_request, &signing_params.into())?.into_parts();

    let mut url = url::Url::parse(&url).unwrap();
    for (name, value) in signing_instructions.params() {
        url.query_pairs_mut().append_pair(name, value);
    }

    let response = url.to_string().split_off("https://".len());

    Ok(response)
}
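The last step returns the presigned URL without its scheme, and that `host:port/?Action=connect&DBUser=...&X-Amz-...` string is what RDS accepts as the password. Here is a std-only sketch of just that step, using `strip_prefix` so a missing scheme shows up as `None` instead of being silently mis-sliced. The SigV4 signing is omitted, and the host and user names are placeholders.

```rust
/// Builds the unsigned RDS "connect" URL (same format as in the function above).
fn connect_url(host: &str, port: u16, user: &str) -> String {
    format!("https://{host}:{port}/?Action=connect&DBUser={user}")
}

/// Drops the scheme from a presigned URL, leaving `host:port/?Action=connect&...`,
/// which is the form used as the password. Returns None if the scheme is missing.
fn token_from_presigned(url: &str) -> Option<&str> {
    url.strip_prefix("https://")
}

fn main() {
    // In the real function, the SigV4 query parameters are appended to `url` first.
    let url = connect_url("db.example.internal", 5432, "app_user");
    println!("{:?}", token_from_presigned(&url));
    // Some("db.example.internal:5432/?Action=connect&DBUser=app_user")
}
```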

PG connect options

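// Assumes sqlx with the postgres feature and `encode` from the urlencoding crate;
// EndpointType and Environment are application-defined.
use std::env;

use sqlx::postgres::PgConnectOptions;
use urlencoding::encode;

async fn get_pg_options(endpoint_type: EndpointType) -> PgConnectOptions {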
    let user = env::var("POSTGRES_USER").expect("Environment variable POSTGRES_USER not found");
    let port = env::var("POSTGRES_PORT")
        .expect("Environment variable POSTGRES_PORT not found")
        .parse::<u16>()
        .expect("Environment variable POSTGRES_PORT is not a valid port number");
    let database = env::var("POSTGRES_DB").expect("Environment variable POSTGRES_DB not found");

    let host = match endpoint_type {
        EndpointType::Write => env::var("POSTGRES_HOST_WRITE")
            .expect("Environment variable POSTGRES_HOST_WRITE not found"),
        EndpointType::ReadOnly => env::var("POSTGRES_HOST_READ_ONLY")
            .expect("Environment variable POSTGRES_HOST_READ_ONLY not found"),
    };

    let password = match Environment::new() {
        Environment::Local => env::var("POSTGRES_USER").unwrap_or("postgres".to_string()),
        _ => generate_rds_iam_token(&host, port, &user).await.unwrap(),
    };
    let encoded_password = encode(&password).to_string();

    let url = format!("postgres://{user}:{encoded_password}@{host}:{port}/{database}");
    url.parse().unwrap()
}
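An RDS IAM token contains `:`, `/`, `?`, `&`, `=` and `%`, all reserved in a URL, which is why the password is percent-encoded before it goes into the `postgres://` URL. The `encode` call is presumably from the `urlencoding` crate; the std-only sketch below shows the same idea by encoding every byte outside the RFC 3986 unreserved set. `encode_password` and the sample token are hypothetical.

```rust
/// Percent-encodes every byte except RFC 3986 "unreserved" characters
/// (ALPHA / DIGIT / "-" / "." / "_" / "~"), so the value is safe in the
/// userinfo (password) part of a postgres:// URL.
fn encode_password(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len());
    for byte in raw.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                out.push(byte as char)
            }
            _ => out.push_str(&format!("%{byte:02X}")),
        }
    }
    out
}

fn main() {
    // Hypothetical, shortened token with the same shape as an RDS IAM token.
    let token = "db.example.internal:5432/?Action=connect&DBUser=app_user&X-Amz-Signature=ab12";
    let url = format!(
        "postgres://app_user:{}@db.example.internal:5432/app",
        encode_password(token)
    );
    println!("{url}");
}
```

Alternatively, building the options with `PgConnectOptions::new()` and its `host`, `port`, `username`, `password` and `database` setters avoids the URL round-trip and the encoding step entirely.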

Update the pool in the background

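// Assumes tokio with the "time" feature; SharedState, ReadOnlyPgPool and
// WritePgPool are application-defined.
use tokio::time;

pub fn update_pool_password_in_background(state: SharedState) {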
    tokio::spawn(async move {
        // Fetch new pool password every 14 minutes
        let mut interval = time::interval(time::Duration::from_secs(60 * 14));
        loop {
            interval.tick().await;
            let read_only_opts = ReadOnlyPgPool::get_pg_options().await;
            let write_opts = WritePgPool::get_pg_options().await;

            let state = state.write().unwrap();

            state
                .read_only_pool
                .clone()
                .inner()
                .set_connect_options(read_only_opts);
            state
                .write_pool
                .clone()
                .inner()
                .set_connect_options(write_opts);
        }
    });
}
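This workaround relies on `set_connect_options` affecting only connections opened afterwards: connections already in the pool keep working because RDS validates the IAM token only when a connection is established. Below is a std-only toy model of that swap, with a thread standing in for the tokio task; `Pool` and `Options` are hypothetical stand-ins, not SQLx internals.

```rust
use std::sync::{Arc, RwLock};
use std::thread;
use std::time::Duration;

/// Hypothetical stand-in for PgConnectOptions: only the field that rotates.
#[derive(Clone, Debug)]
struct Options {
    password: String,
}

/// Toy pool: new connections read the *current* options; existing ones keep theirs.
struct Pool {
    options: RwLock<Arc<Options>>,
}

impl Pool {
    /// Swaps the options used for future connections.
    fn set_connect_options(&self, opts: Options) {
        *self.options.write().unwrap() = Arc::new(opts);
    }

    /// "Opens" a connection by snapshotting the options at this moment.
    fn connect(&self) -> Arc<Options> {
        Arc::clone(&self.options.read().unwrap())
    }
}

fn main() {
    let pool = Arc::new(Pool {
        options: RwLock::new(Arc::new(Options { password: "token-0".into() })),
    });
    let early = pool.connect();

    // Background refresher; in the real app this is a tokio task ticking every 14 minutes.
    let bg = Arc::clone(&pool);
    let handle = thread::spawn(move || {
        for i in 1..=3 {
            thread::sleep(Duration::from_millis(10));
            bg.set_connect_options(Options { password: format!("token-{i}") });
        }
    });
    handle.join().unwrap();

    let late = pool.connect();
    println!("{} {}", early.password, late.password); // token-0 token-3
}
```

One detail of the real loop: `tokio::time::interval` completes its first tick immediately, so the task also refreshes the options once right at startup, which is harmless.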


GeeWee commented May 6, 2025

Thanks @MattDelac! This is also what I ended up with as an overall strategy, so I'm super happy to hear it's working in production!
