diff --git a/src/async_impl/response.rs b/src/async_impl/response.rs index 4c0d52727..24620f032 100644 --- a/src/async_impl/response.rs +++ b/src/async_impl/response.rs @@ -1,9 +1,11 @@ use std::fmt; use std::net::SocketAddr; use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; use bytes::Bytes; +use http_body::{Frame, SizeHint}; use http_body_util::BodyExt; use hyper::{HeaderMap, StatusCode, Version}; use hyper_util::client::legacy::connect::HttpInfo; @@ -11,6 +13,7 @@ use hyper_util::client::legacy::connect::HttpInfo; use serde::de::DeserializeOwned; #[cfg(feature = "json")] use serde_json; +use tokio::sync::oneshot::{self, Receiver}; use tokio::time::Sleep; use url::Url; @@ -31,6 +34,7 @@ pub struct Response { // Boxed to save space (11 words to 1 word), and it's not accessed // frequently internally. url: Box, + trailers_rx: Receiver, } impl Response { @@ -42,16 +46,20 @@ impl Response { read_timeout: Option, ) -> Response { let (mut parts, body) = res.into_parts(); + let (body, trailers_rx) = extract_trailers_from_body(body); + let decoder = Decoder::detect( &mut parts.headers, super::body::response(body, total_timeout, read_timeout), accepts, ); + let res = hyper::Response::from_parts(parts, decoder); Response { res, url: Box::new(url), + trailers_rx, } } @@ -424,6 +432,22 @@ impl Response { } } + /// Get the response trailers if available. + /// + /// Trailers are additional headers sent after the response body in HTTP/1.1 chunked + /// encoding or HTTP/2 responses. They are typically used for metadata that can only + /// be determined after processing the entire response body. + #[inline] + pub async fn trailers(&mut self) -> crate::Result> { + match self.trailers_rx.try_recv() { + Ok(trailers) => Ok(Some(trailers)), + Err(err) => match err { + oneshot::error::TryRecvError::Empty => Ok(None), + oneshot::error::TryRecvError::Closed => Err(crate::error::body(err)), + }, + } + } + // private // The Response's body is an implementation detail. @@ -462,6 +486,9 @@ impl> From> for Response { let (mut parts, body) = r.into_parts(); let body: crate::async_impl::body::Body = body.into(); + + let (body, trailers_rx) = extract_trailers_from_body(body); + let decoder = Decoder::detect( &mut parts.headers, ResponseBody::new(body.map_err(Into::into)), @@ -476,6 +503,7 @@ impl> From> for Response { Response { res, url: Box::new(url), + trailers_rx, } } } @@ -490,6 +518,90 @@ impl From for http::Response { } } +pin_project_lite::pin_project! { + /// A body wrapper that extracts HTTP trailers while preserving the original size hint. + /// + /// This wrapper monitors HTTP frames for trailers and sends them through a oneshot + /// channel when found, while maintaining the original body's size hint and other + /// characteristics to ensure `content_length()` continues to work correctly. + /// + /// HTTP trailers are additional headers sent after the response body in chunked + /// encoding (HTTP/1.1) or HTTP/2 responses. They are useful for metadata that + /// can only be determined after processing the entire response body, such as + /// checksums or final status information. + pub struct TrailerExtractingBody { + #[pin] + inner: B, + trailers_tx: Option>, + } +} + +impl TrailerExtractingBody { + fn new(body: B, trailers_tx: oneshot::Sender) -> Self + where + B: http_body::Body, + { + Self { + inner: body, + trailers_tx: Some(trailers_tx), + } + } +} + +impl http_body::Body for TrailerExtractingBody +where + B: http_body::Body, +{ + type Data = Bytes; + type Error = B::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let this = self.project(); + + match std::task::ready!(this.inner.poll_frame(cx)) { + Some(Ok(mut frame)) => { + if let Some(trailers) = frame.trailers_mut() { + if let Some(tx) = this.trailers_tx.take() { + let _ = tx.send(std::mem::take(trailers)); + } + } + Poll::Ready(Some(Ok(frame))) + } + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } + + #[inline] + fn size_hint(&self) -> SizeHint { + self.inner.size_hint() + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } +} + +fn extract_trailers_from_body( + body: B, +) -> ( + http_body_util::combinators::BoxBody, + oneshot::Receiver, +) +where + B: http_body::Body + Send + Sync + 'static, +{ + let (trailers_tx, trailers_rx) = oneshot::channel(); + let wrapper = TrailerExtractingBody::new(body, trailers_tx); + let boxed_body = http_body_util::BodyExt::boxed(wrapper); + + (boxed_body, trailers_rx) +} + #[cfg(test)] mod tests { use super::Response; diff --git a/tests/client.rs b/tests/client.rs index 307892c98..272482e5e 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -600,3 +600,95 @@ async fn error_has_url() { let err = reqwest::get(u).await.unwrap_err(); assert_eq!(err.url().map(AsRef::as_ref), Some(u), "{err:?}"); } + +#[tokio::test] +async fn response_trailers() { + use tokio::io::AsyncWriteExt; + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + // Send HTTP response with chunked encoding and trailers + client_socket + .write_all(b"HTTP/1.1 200 OK\r\n") + .await + .expect("write status line"); + + client_socket + .write_all(b"Transfer-Encoding: chunked\r\n") + .await + .expect("write transfer-encoding header"); + + client_socket + .write_all(b"Trailer: X-Custom-Trailer, X-Checksum\r\n") + .await + .expect("write trailer header"); + + client_socket + .write_all(b"\r\n") + .await + .expect("write header end"); + + // Send chunked body + client_socket + .write_all(b"5\r\nHello\r\n") + .await + .expect("write chunk 1"); + + client_socket + .write_all(b"6\r\nWorld!\r\n") + .await + .expect("write chunk 2"); + + // Send end chunk + client_socket + .write_all(b"0\r\n") + .await + .expect("write end chunk"); + + // Send trailers + client_socket + .write_all(b"X-Custom-Trailer: custom-value\r\n") + .await + .expect("write custom trailer"); + + client_socket + .write_all(b"X-Checksum: abc123\r\n") + .await + .expect("write checksum trailer"); + + // End of trailers + client_socket + .write_all(b"\r\n") + .await + .expect("write trailers end"); + }) + }); + + let client = Client::new(); + + let mut res = client + .get(&format!("http://{}/trailers", server.addr())) + .send() + .await + .expect("Failed to get response"); + + assert_eq!(res.status(), reqwest::StatusCode::OK); + + // Read the body using chunk() to preserve response ownership + let mut body_content = Vec::new(); + + while let Some(chunk) = res.chunk().await.expect("Failed to read chunk") { + body_content.extend_from_slice(&chunk); + } + + let body = String::from_utf8(body_content).expect("Invalid UTF-8"); + assert_eq!(body, "HelloWorld!"); + + // Now we can check trailers since the response body has been fully consumed + if let Some(trailers) = res.trailers().await.expect("Failed to get trailers") { + assert_eq!(trailers.get("X-Custom-Trailer").unwrap(), "custom-value"); + assert_eq!(trailers.get("X-Checksum").unwrap(), "abc123"); + } else { + panic!("Expected trailers but got None"); + } +}