Skip to content

Commit c6b650d

Browse files
authored
RUST-2204 Enforce size limits on outgoing messages (#1369)
1 parent 5077ae5 commit c6b650d

File tree

2 files changed

+37
-10
lines changed

2 files changed

+37
-10
lines changed

src/cmap/conn.rs

+17-10
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ impl Connection {
222222

223223
self.command_executing = true;
224224

225+
let max_message_size = self.max_message_size_bytes();
225226
#[cfg(any(
226227
feature = "zstd-compression",
227228
feature = "zlib-compression",
@@ -230,30 +231,30 @@ impl Connection {
230231
let write_result = match self.compressor {
231232
Some(ref compressor) if message.should_compress => {
232233
message
233-
.write_op_compressed_to(&mut self.stream, compressor)
234+
.write_op_compressed_to(&mut self.stream, compressor, max_message_size)
235+
.await
236+
}
237+
_ => {
238+
message
239+
.write_op_msg_to(&mut self.stream, max_message_size)
234240
.await
235241
}
236-
_ => message.write_op_msg_to(&mut self.stream).await,
237242
};
238243
#[cfg(all(
239244
not(feature = "zstd-compression"),
240245
not(feature = "zlib-compression"),
241246
not(feature = "snappy-compression")
242247
))]
243-
let write_result = message.write_op_msg_to(&mut self.stream).await;
248+
let write_result = message
249+
.write_op_msg_to(&mut self.stream, max_message_size)
250+
.await;
244251

245252
if let Err(ref err) = write_result {
246253
self.error = Some(err.clone());
247254
}
248255
write_result?;
249256

250-
let response_message_result = Message::read_from(
251-
&mut self.stream,
252-
self.stream_description
253-
.as_ref()
254-
.map(|d| d.max_message_size_bytes),
255-
)
256-
.await;
257+
let response_message_result = Message::read_from(&mut self.stream, max_message_size).await;
257258
self.command_executing = false;
258259
if let Err(ref err) = response_message_result {
259260
self.error = Some(err.clone());
@@ -306,6 +307,12 @@ impl Connection {
306307
pub(crate) fn is_streaming(&self) -> bool {
307308
self.more_to_come
308309
}
310+
311+
fn max_message_size_bytes(&self) -> Option<i32> {
312+
self.stream_description
313+
.as_ref()
314+
.map(|d| d.max_message_size_bytes)
315+
}
309316
}
310317

311318
/// A handle to a pinned connection - the connection itself can be retrieved or returned to the

src/cmap/conn/wire/message.rs

+20
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ impl Message {
274274
pub(crate) async fn write_op_msg_to<T: AsyncWrite + Send + Unpin>(
275275
&self,
276276
mut writer: T,
277+
max_message_size_bytes: Option<i32>,
277278
) -> Result<()> {
278279
let sections = self.get_sections_bytes()?;
279280

@@ -286,6 +287,15 @@ impl Message {
286287
.map(std::mem::size_of_val)
287288
.unwrap_or(0);
288289

290+
let max_len =
291+
Checked::try_from(max_message_size_bytes.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE_BYTES))?;
292+
if total_length > max_len {
293+
return Err(ErrorKind::InvalidArgument {
294+
message: format!("Message length {} over maximum {}", total_length, max_len),
295+
}
296+
.into());
297+
}
298+
289299
let header = Header {
290300
length: total_length.try_into()?,
291301
request_id: self.request_id.unwrap_or_else(next_request_id),
@@ -316,6 +326,7 @@ impl Message {
316326
&self,
317327
mut writer: T,
318328
compressor: &Compressor,
329+
max_message_size_bytes: Option<i32>,
319330
) -> Result<()> {
320331
let flag_bytes = &self.flags.bits().to_le_bytes();
321332
let section_bytes = self.get_sections_bytes()?;
@@ -329,6 +340,15 @@ impl Message {
329340
+ std::mem::size_of::<u8>()
330341
+ compressed_bytes.len();
331342

343+
let max_len =
344+
Checked::try_from(max_message_size_bytes.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE_BYTES))?;
345+
if total_length > max_len {
346+
return Err(ErrorKind::InvalidArgument {
347+
message: format!("Message length {} over maximum {}", total_length, max_len),
348+
}
349+
.into());
350+
}
351+
332352
let header = Header {
333353
length: total_length.try_into()?,
334354
request_id: self.request_id.unwrap_or_else(next_request_id),

0 commit comments

Comments
 (0)