diff --git a/crates/wasi-http/src/p3/request.rs b/crates/wasi-http/src/p3/request.rs index 4f921258f1af..6d3792f2e77a 100644 --- a/crates/wasi-http/src/p3/request.rs +++ b/crates/wasi-http/src/p3/request.rs @@ -1,13 +1,18 @@ +use crate::get_content_length; use crate::p3::bindings::http::types::ErrorCode; -use crate::p3::body::Body; +use crate::p3::body::{Body, GuestBody}; +use crate::p3::{WasiHttpCtxView, WasiHttpView}; use bytes::Bytes; use core::time::Duration; +use http::header::HOST; use http::uri::{Authority, PathAndQuery, Scheme}; -use http::{HeaderMap, Method}; +use http::{HeaderMap, HeaderValue, Method, Uri}; use http_body_util::BodyExt as _; use http_body_util::combinators::UnsyncBoxBody; use std::sync::Arc; use tokio::sync::oneshot; +use tracing::debug; +use wasmtime::AsContextMut; /// The concrete type behind a `wasi:http/types.request-options` resource. #[derive(Copy, Clone, Debug, Default)] @@ -119,6 +124,114 @@ impl Request { body.map_err(Into::into).boxed_unsync(), ) } + + /// Convert this [`Request`] into an [`http::Request>`]. + /// + /// The specified future `fut` can be used to communicate a request processing + /// error, if any, back to the caller (e.g., if this request was constructed + /// through `wasi:http/types.request#new`). + pub fn into_http( + self, + store: impl AsContextMut, + fut: impl Future> + Send + 'static, + ) -> wasmtime::Result>> { + self.into_http_with_getter(store, fut, T::http) + } + + /// Like [`Self::into_http`], but uses a custom getter for obtaining the [`WasiHttpCtxView`]. + pub fn into_http_with_getter( + self, + mut store: impl AsContextMut, + fut: impl Future> + Send + 'static, + getter: fn(&mut T) -> WasiHttpCtxView<'_>, + ) -> wasmtime::Result>> { + let Request { + method, + scheme, + authority, + path_with_query, + headers, + options: _, + body, + } = self; + let content_length = match get_content_length(&headers) { + Ok(content_length) => content_length, + Err(err) => { + body.drop(&mut store); + return Err(ErrorCode::InternalError(Some(format!("{err:#}"))).into()); + } + }; + // This match must appear before any potential errors handled with '?' + // (or errors have to explicitly be addressed and drop the body, as above), + // as otherwise the Body::Guest resources will not be cleaned up when dropped. + // see: https://github.com/bytecodealliance/wasmtime/pull/11440#discussion_r2326139381 + // for additional context. + let body = match body { + Body::Guest { + contents_rx, + trailers_rx, + result_tx, + } => GuestBody::new( + &mut store, + contents_rx, + trailers_rx, + result_tx, + fut, + content_length, + ErrorCode::HttpRequestBodySize, + getter, + ) + .boxed_unsync(), + Body::Host { body, result_tx } => { + _ = result_tx.send(Box::new(fut)); + body + } + }; + let mut headers = Arc::unwrap_or_clone(headers); + let mut store_ctx = store.as_context_mut(); + let WasiHttpCtxView { ctx, table: _ } = getter(store_ctx.data_mut()); + if ctx.set_host_header() { + let host = if let Some(authority) = authority.as_ref() { + HeaderValue::try_from(authority.as_str()) + .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))? + } else { + HeaderValue::from_static("") + }; + headers.insert(HOST, host); + } + let scheme = match scheme { + None => ctx.default_scheme().ok_or(ErrorCode::HttpProtocolError)?, + Some(scheme) if ctx.is_supported_scheme(&scheme) => scheme, + Some(..) => return Err(ErrorCode::HttpProtocolError.into()), + }; + let mut uri = Uri::builder().scheme(scheme); + if let Some(authority) = authority { + uri = uri.authority(authority) + }; + if let Some(path_with_query) = path_with_query { + uri = uri.path_and_query(path_with_query) + }; + let uri = uri.build().map_err(|err| { + debug!(?err, "failed to build request URI"); + ErrorCode::HttpRequestUriInvalid + })?; + let mut req = http::Request::builder(); + if let Some(headers_mut) = req.headers_mut() { + *headers_mut = headers; + } else { + return Err(ErrorCode::InternalError(Some( + "failed to get mutable headers from request builder".to_string(), + )) + .into()); + } + let req = req + .method(method) + .uri(uri) + .body(body) + .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?; + let (req, body) = req.into_parts(); + Ok(http::Request::from_parts(req, body)) + } } /// The default implementation of how an outgoing request is sent. @@ -348,3 +461,123 @@ pub async fn default_send_request( conn.await.map_err(ErrorCode::from_hyper_response_error) })) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::p3::WasiHttpCtx; + use anyhow::Result; + use http_body_util::{BodyExt, Empty, Full}; + use std::future::Future; + use std::str::FromStr; + use std::task::{Context, Waker}; + use wasmtime::{Engine, Store}; + use wasmtime_wasi::{ResourceTable, WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView}; + + struct TestHttpCtx; + struct TestCtx { + table: ResourceTable, + wasi: WasiCtx, + http: TestHttpCtx, + } + + impl TestCtx { + fn new() -> Self { + Self { + table: ResourceTable::default(), + wasi: WasiCtxBuilder::new().build(), + http: TestHttpCtx, + } + } + } + + impl WasiView for TestCtx { + fn ctx(&mut self) -> WasiCtxView<'_> { + WasiCtxView { + ctx: &mut self.wasi, + table: &mut self.table, + } + } + } + + impl WasiHttpCtx for TestHttpCtx {} + + impl WasiHttpView for TestCtx { + fn http(&mut self) -> WasiHttpCtxView<'_> { + WasiHttpCtxView { + ctx: &mut self.http, + table: &mut self.table, + } + } + } + + #[tokio::test] + async fn test_request_into_http_schemes() -> Result<()> { + let schemes = vec![Some(Scheme::HTTP), Some(Scheme::HTTPS), None]; + let engine = Engine::default(); + + for scheme in schemes { + let (req, fut) = Request::new( + Method::POST, + scheme.clone(), + Some(Authority::from_static("example.com")), + Some(PathAndQuery::from_static("/path?query=1")), + HeaderMap::new(), + None, + Full::new(Bytes::from_static(b"body")) + .map_err(|x| match x {}) + .boxed_unsync(), + ); + let mut store = Store::new(&engine, TestCtx::new()); + let http_req = req.into_http(&mut store, async { Ok(()) }).unwrap(); + assert_eq!(http_req.method(), Method::POST); + let expected_scheme = scheme.unwrap_or(Scheme::HTTPS); // default scheme + assert_eq!( + http_req.uri(), + &http::Uri::from_str(&format!( + "{}://example.com/path?query=1", + expected_scheme.as_str() + )) + .unwrap() + ); + let body_bytes = http_req.into_body().collect().await?; + assert_eq!(*body_bytes.to_bytes(), *b"body"); + let mut cx = Context::from_waker(Waker::noop()); + let mut fut = Box::pin(fut); + let result = fut.as_mut().poll(&mut cx); + assert!(matches!(result, futures::task::Poll::Ready(Ok(())))); + } + + Ok(()) + } + + #[tokio::test] + async fn test_request_into_http_uri_error() -> Result<()> { + let (req, fut) = Request::new( + Method::GET, + Some(Scheme::HTTP), + Some(Authority::from_static("example.com")), + None, // <-- should fail, must be Some(_) when authority is set + HeaderMap::new(), + None, + Empty::new().map_err(|x| match x {}).boxed_unsync(), + ); + let mut store = Store::new(&Engine::default(), TestCtx::new()); + let result = req.into_http(&mut store, async { + Err(ErrorCode::InternalError(Some("uh oh".to_string()))) + }); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err().downcast_ref::(), + Some(ErrorCode::HttpRequestUriInvalid) + )); + let mut cx = Context::from_waker(Waker::noop()); + let result = Box::pin(fut).as_mut().poll(&mut cx); + assert!(matches!( + result, + futures::task::Poll::Ready(Err(ErrorCode::InternalError(Some(_)))) + )); + + Ok(()) + } +}