diff --git a/dropshot/tests/test_streaming.rs b/dropshot/tests/test_streaming.rs index d931cc46d..8616a13a3 100644 --- a/dropshot/tests/test_streaming.rs +++ b/dropshot/tests/test_streaming.rs @@ -2,11 +2,15 @@ //! Test cases for streaming requests. -use dropshot::{endpoint, ApiDescription, HttpError, RequestContext}; +use std::convert::Infallible; + +use bytes::Bytes; +use dropshot::{ + endpoint, ApiDescription, HttpError, RequestContext, StreamingBody, +}; +use futures::TryStreamExt; use http::{Method, Response, StatusCode}; use hyper::{body::HttpBody, Body}; -use hyper_staticfile::FileBytesStream; -use tokio::io::{AsyncSeekExt, AsyncWriteExt}; extern crate slog; @@ -15,6 +19,7 @@ pub mod common; fn api() -> ApiDescription { let mut api = ApiDescription::new(); api.register(api_streaming).unwrap(); + api.register(api_client_streaming).unwrap(); api.register(api_not_streaming).unwrap(); api } @@ -22,6 +27,15 @@ fn api() -> ApiDescription { const BUF_SIZE: usize = 8192; const BUF_COUNT: usize = 128; +fn make_chunked_body(buf_count: usize) -> Body { + let bytes = Bytes::from(vec![0; BUF_SIZE]); + // This is cheap -- just a bunch of Arc clones. + let bufs = vec![bytes; buf_count]; + Body::wrap_stream(futures::stream::iter( + bufs.into_iter().map(|bytes| Ok::<_, Infallible>(bytes)), + )) +} + #[endpoint { method = GET, path = "/streaming", @@ -29,27 +43,32 @@ const BUF_COUNT: usize = 128; async fn api_streaming( _rqctx: RequestContext, ) -> Result, HttpError> { - let mut file = tempfile::tempfile() - .map_err(|_| { - HttpError::for_bad_request( - None, - "Cannot create tempfile".to_string(), - ) - }) - .map(|f| tokio::fs::File::from_std(f))?; - - // Fill the file with some arbitrary contents. - let mut buf = [0; BUF_SIZE]; - for i in 0..BUF_COUNT { - file.write_all(&buf).await.unwrap(); - buf.fill((i & 255) as u8); - } - file.seek(std::io::SeekFrom::Start(0)).await.unwrap(); + Ok(Response::builder() + .status(StatusCode::OK) + .body(make_chunked_body(BUF_COUNT))?) +} + +#[endpoint { + method = PUT, + path = "/client_streaming", + // 8192 * 128 = 1_048_576 + request_body_max_bytes = 1_048_576, +}] +async fn api_client_streaming( + rqctx: RequestContext, + body: StreamingBody, +) -> Result, HttpError> { + check_has_transfer_encoding(rqctx.request.headers(), Some("chunked")); + + let nbytes = body + .into_stream() + .try_fold(0, |acc, v| futures::future::ok(acc + v.len())) + .await?; + assert_eq!(nbytes, BUF_SIZE * BUF_COUNT); - let file_stream = FileBytesStream::new(file); Ok(Response::builder() .status(StatusCode::OK) - .body(file_stream.into_body())?) + .body(make_chunked_body(BUF_COUNT))?) } #[endpoint { @@ -65,15 +84,18 @@ async fn api_not_streaming( } fn check_has_transfer_encoding( - response: &Response, + headers: &http::HeaderMap, expected_value: Option<&str>, ) { - let transfer_encoding_header = response.headers().get("transfer-encoding"); + let transfer_encoding_header = headers.get("transfer-encoding"); match expected_value { Some(expected_value) => { assert_eq!( expected_value, - transfer_encoding_header.expect("expected value") + transfer_encoding_header.unwrap_or_else(|| panic!( + "expected transfer-encoding to be {}, found None", + expected_value + )) ); } None => { @@ -88,29 +110,64 @@ async fn test_streaming_server_streaming_client() { let testctx = common::test_setup("streaming_server_streaming_client", api); let client = &testctx.client_testctx; - let mut response = client + async fn check_chunked_response(mut response: Response) { + check_has_transfer_encoding(response.headers(), Some("chunked")); + + let mut chunk_count = 0; + let mut byte_count = 0; + while let Some(chunk) = response.body_mut().data().await { + let chunk = + chunk.expect("Should have received chunk without error"); + byte_count += chunk.len(); + chunk_count += 1; + } + + assert!( + chunk_count >= 2, + "Expected 2+ chunks for streaming, saw: {}", + chunk_count + ); + assert_eq!( + BUF_SIZE * BUF_COUNT, + byte_count, + "Mismatch of sent vs received byte count" + ); + } + + // Success case: GET without body. + let response = client .make_request_no_body(Method::GET, "/streaming", StatusCode::OK) .await .expect("Expected GET request to succeed"); - check_has_transfer_encoding(&response, Some("chunked")); - - let mut chunk_count = 0; - let mut byte_count = 0; - while let Some(chunk) = response.body_mut().data().await { - let chunk = chunk.expect("Should have received chunk without error"); - byte_count += chunk.len(); - chunk_count += 1; - } + check_chunked_response(response).await; - assert!( - chunk_count >= 2, - "Expected 2+ chunks for streaming, saw: {}", - chunk_count - ); + // Success case: PUT with streaming body. + let body = make_chunked_body(BUF_COUNT); + let response = client + .make_request_with_body( + Method::PUT, + "/client_streaming", + body, + StatusCode::OK, + ) + .await + .expect("Expected PUT request to succeed"); + check_chunked_response(response).await; + + // Error case: streaming body that's slightly too large. + let body = make_chunked_body(BUF_COUNT + 1); + let error = client + .make_request_with_body( + Method::PUT, + "/client_streaming", + body, + StatusCode::BAD_REQUEST, + ) + .await + .unwrap_err(); assert_eq!( - BUF_SIZE * BUF_COUNT, - byte_count, - "Mismatch of sent vs received byte count" + error.message, + "request body exceeded maximum size of 1048576 bytes" ); testctx.teardown().await; @@ -126,7 +183,7 @@ async fn test_streaming_server_buffered_client() { .make_request_no_body(Method::GET, "/streaming", StatusCode::OK) .await .expect("Expected GET request to succeed"); - check_has_transfer_encoding(&response, Some("chunked")); + check_has_transfer_encoding(response.headers(), Some("chunked")); let body_bytes = hyper::body::to_bytes(response.body_mut()) .await @@ -153,6 +210,6 @@ async fn test_non_streaming_servers_do_not_use_transfer_encoding() { .make_request_no_body(Method::GET, "/not-streaming", StatusCode::OK) .await .expect("Expected GET request to succeed"); - check_has_transfer_encoding(&response, None); + check_has_transfer_encoding(response.headers(), None); testctx.teardown().await; }