Skip to content

Commit 6fa0458

Browse files
authored
fix(postgres): chunk pg_copy data (#3703)
* fix(postgres): chunk pg_copy data
* fix: cleanup after review
1 parent 74da542 commit 6fa0458

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed

sqlx-postgres/src/copy.rs

+15-7
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ impl PgPoolCopyExt for Pool<Postgres> {
129129
}
130130
}
131131

132+
// (1 GiB - 1) - 1 - length prefix (4 bytes)
133+
pub const PG_COPY_MAX_DATA_LEN: usize = 0x3fffffff - 1 - 4;
134+
132135
/// A connection in streaming `COPY FROM STDIN` mode.
133136
///
134137
/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
@@ -186,15 +189,20 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
186189

187190
/// Send a chunk of `COPY` data.
188191
///
192+
/// The data is sent in chunks if it exceeds the maximum length of a `CopyData` message (1 GiB - 6
193+
/// bytes) and may be partially sent if this call is cancelled.
194+
///
189195
/// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
190196
pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
191-
self.conn
192-
.as_deref_mut()
193-
.expect("send_data: conn taken")
194-
.inner
195-
.stream
196-
.send(CopyData(data))
197-
.await?;
197+
for chunk in data.deref().chunks(PG_COPY_MAX_DATA_LEN) {
198+
self.conn
199+
.as_deref_mut()
200+
.expect("send_data: conn taken")
201+
.inner
202+
.stream
203+
.send(CopyData(chunk))
204+
.await?;
205+
}
198206

199207
Ok(self)
200208
}

sqlx-postgres/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ mod value;
3434
#[doc(hidden)]
3535
pub mod any;
3636

37+
#[doc(hidden)]
38+
pub use copy::PG_COPY_MAX_DATA_LEN;
39+
3740
#[cfg(feature = "migrate")]
3841
mod migrate;
3942

tests/postgres/postgres.rs

+21-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use futures::{Stream, StreamExt, TryStreamExt};
33
use sqlx::postgres::types::Oid;
44
use sqlx::postgres::{
55
PgAdvisoryLock, PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgListener,
6-
PgPoolOptions, PgRow, PgSeverity, Postgres,
6+
PgPoolOptions, PgRow, PgSeverity, Postgres, PG_COPY_MAX_DATA_LEN,
77
};
88
use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo};
99
use sqlx_core::{bytes::Bytes, error::BoxDynError};
@@ -2042,3 +2042,23 @@ async fn test_issue_3052() {
20422042
"expected encode error, got {too_large_error:?}",
20432043
);
20442044
}
2045+
2046+
#[sqlx_macros::test]
2047+
async fn test_pg_copy_chunked() -> anyhow::Result<()> {
2048+
let mut conn = new::<Postgres>().await?;
2049+
2050+
let mut row = "1".repeat(PG_COPY_MAX_DATA_LEN / 10 - 1);
2051+
row.push_str("\n");
2052+
2053+
// creates a payload with COPY_MAX_DATA_LEN + 1 as size
2054+
let mut payload = row.repeat(10);
2055+
payload.push_str("12345678\n");
2056+
2057+
assert_eq!(payload.len(), PG_COPY_MAX_DATA_LEN + 1);
2058+
2059+
let mut copy = conn.copy_in_raw("COPY products(name) FROM STDIN").await?;
2060+
2061+
assert!(copy.send(payload.as_bytes()).await.is_ok());
2062+
assert!(copy.finish().await.is_ok());
2063+
Ok(())
2064+
}

0 commit comments

Comments
 (0)