diff --git a/tonic-web/Cargo.toml b/tonic-web/Cargo.toml index b3301aaa4..e26303ad1 100644 --- a/tonic-web/Cargo.toml +++ b/tonic-web/Cargo.toml @@ -25,6 +25,7 @@ tonic = { version = "0.13.0", path = "../tonic", default-features = false } tower-service = "0.3" tower-layer = "0.3" tracing = "0.1" +axum = { version = "0.8.1", optional = true } [dev-dependencies] tokio = { version = "1", features = ["macros", "rt"] } diff --git a/tonic-web/src/service.rs b/tonic-web/src/service.rs index a60a724ba..31bf3c1ba 100644 --- a/tonic-web/src/service.rs +++ b/tonic-web/src/service.rs @@ -6,13 +6,15 @@ use std::task::{ready, Context, Poll}; use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version}; use pin_project::pin_project; use tonic::metadata::GRPC_CONTENT_TYPE; -use tonic::{body::Body, server::NamedService}; +use tonic::server::NamedService; use tower_service::Service; use tracing::{debug, trace}; use crate::call::content_types::is_grpc_web; use crate::call::{Encoding, GrpcWebCall}; +use bytes::Bytes; + /// Service implementing the grpc-web protocol. #[derive(Debug, Clone)] pub struct GrpcWebService { @@ -45,9 +47,9 @@ impl GrpcWebService { impl Service> for GrpcWebService where - S: Service, Response = Response>, - B: http_body::Body + Send + 'static, - B::Error: Into + fmt::Display, + S: Service, Response = Response>, + B: http_body::Body + BoxedBody + Send + 'static, + B::Error: Into + std::error::Error + fmt::Display + Send + Sync, { type Response = S::Response; type Error = S::Error; @@ -100,7 +102,7 @@ where debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE)); ResponseFuture { case: Case::Other { - future: self.inner.call(req.map(Body::new)), + future: self.inner.call(req.map(B::new)), }, } } @@ -152,11 +154,13 @@ impl Case { } } -impl Future for ResponseFuture +impl Future for ResponseFuture where - F: Future, E>>, + F: Future, E>>, + A: BoxedBody + 'static, + ::Error: std::error::Error + Send + Sync + 'static, { - type Output = Result, E>; + type Output = Result, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -169,7 +173,7 @@ where } CaseProj::Other { future } => future.poll(cx), CaseProj::ImmediateResponse { res } => { - let res = Response::from_parts(res.take().unwrap(), Body::empty()); + let res = Response::from_parts(res.take().unwrap(), A::empty()); Poll::Ready(Ok(res)) } } @@ -203,9 +207,9 @@ impl<'a> RequestKind<'a> { // Mutating request headers to conform to a gRPC request is not really // necessary for us at this point. We could remove most of these except // maybe for inserting `header::TE`, which tonic should check? -fn coerce_request(mut req: Request, encoding: Encoding) -> Request +fn coerce_request(mut req: Request, encoding: Encoding) -> Request where - B: http_body::Body + Send + 'static, + B: http_body::Body + BoxedBody + Send + 'static, B::Error: Into + fmt::Display, { req.headers_mut().remove(header::CONTENT_LENGTH); @@ -221,17 +225,15 @@ where HeaderValue::from_static("identity,deflate,gzip"), ); - req.map(|b| Body::new(GrpcWebCall::request(b, encoding))) + req.map(|b| B::new(GrpcWebCall::request(b, encoding))) } -fn coerce_response(res: Response, encoding: Encoding) -> Response +fn coerce_response(res: Response, encoding: Encoding) -> Response where - B: http_body::Body + Send + 'static, + B: http_body::Body + BoxedBody + Send + 'static, B::Error: Into + fmt::Display, { - let mut res = res - .map(|b| GrpcWebCall::response(b, encoding)) - .map(Body::new); + let mut res = res.map(|b| GrpcWebCall::response(b, encoding)).map(B::new); res.headers_mut().insert( header::CONTENT_TYPE, @@ -241,6 +243,45 @@ where res } +/// Alias for a type-erased error type. +type BoxError = Box; + +trait BoxedBody: http_body::Body + Send { + fn new(body: B) -> Self + where + B: http_body::Body + Send + 'static, + B::Error: Into; + + fn empty() -> Self; +} + +impl BoxedBody for tonic::body::Body { + fn new(body: B) -> Self + where + B: http_body::Body + Send + 'static, + B::Error: Into, + { + Self::new(body) + } + + fn empty() -> Self { + Self::empty() + } +} +#[cfg(feature = "axum")] +impl BoxedBody for axum::body::Body { + fn new(body: B) -> Self + where + B: http_body::Body + Send + 'static, + B::Error: Into, + { + Self::new(body) + } + + fn empty() -> Self { + Self::empty() + } +} #[cfg(test)] mod tests { use super::*;