diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index cd3a876520..3fe4f402d8 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -25,9 +25,15 @@ async fn prepare( sql: &str, parameters: &[PgTypeInfo], metadata: Option>, + persistent: bool, ) -> Result<(StatementId, Arc), Error> { - let id = conn.inner.next_statement_id; - conn.inner.next_statement_id = id.next(); + let id = if persistent { + let id = conn.inner.next_statement_id; + conn.inner.next_statement_id = id.next(); + id + } else { + StatementId::UNNAMED + }; // build a list of type OIDs to send to the database in the PARSE command // we have not yet started the query sequence, so we are *safe* to cleanly make @@ -163,8 +169,7 @@ impl PgConnection { &mut self, sql: &str, parameters: &[PgTypeInfo], - // should we store the result of this prepare to the cache - store_to_cache: bool, + persistent: bool, // optional metadata that was provided by the user, this means they are reusing // a statement object metadata: Option>, @@ -173,9 +178,9 @@ impl PgConnection { return Ok((*statement).clone()); } - let statement = prepare(self, sql, parameters, metadata).await?; + let statement = prepare(self, sql, parameters, metadata, persistent).await?; - if store_to_cache && self.inner.cache_statement.is_enabled() { + if persistent && self.inner.cache_statement.is_enabled() { if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) { self.inner.stream.write_msg(Close::Statement(id))?; self.write_sync(); diff --git a/sqlx-postgres/src/message/parse.rs b/sqlx-postgres/src/message/parse.rs index 75300c4815..62f57a1cc4 100644 --- a/sqlx-postgres/src/message/parse.rs +++ b/sqlx-postgres/src/message/parse.rs @@ -77,3 +77,19 @@ fn test_encode_parse() { assert_eq!(buf, EXPECTED); } + +#[test] +fn test_encode_parse_unnamed_statement() { + const EXPECTED: &[u8] = b"P\0\0\0\x15\0SELECT $1\0\0\x01\0\0\0\x19"; + + let mut buf = Vec::new(); + let m = Parse { + statement: StatementId::UNNAMED, + query: "SELECT $1", + param_types: &[Oid(25)], + }; + + m.encode_msg(&mut buf).unwrap(); + + assert_eq!(buf, EXPECTED); +} diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index fc7108bf4f..f0d453a9a3 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -817,6 +817,27 @@ async fn it_closes_statement_from_cache_issue_470() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_closes_statements_when_not_persistent_issue_3850() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let _row = sqlx::query("SELECT $1 AS val") + .bind(Oid(1)) + .persistent(false) + .fetch_one(&mut conn) + .await?; + + let row = sqlx::query("SELECT count(*) AS num_prepared_statements FROM pg_prepared_statements") + .persistent(false) + .fetch_one(&mut conn) + .await?; + + let n: i64 = row.get("num_prepared_statements"); + assert_eq!(0, n, "no prepared statements should be open"); + + Ok(()) +} + #[sqlx_macros::test] async fn it_sets_application_name() -> anyhow::Result<()> { sqlx_test::setup_if_needed();