Skip to content

Commit 66592d5

Browse files
committed
Make the layer axum compatible
1 parent 88aed0e commit 66592d5

File tree

2 files changed

+62
-17
lines changed

2 files changed

+62
-17
lines changed

tonic-web/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ tower-service = "0.3"
2626
tower-layer = "0.3"
2727
tracing = "0.1"
2828

29+
[dependencies.axum]
30+
version = "0.8.1"
31+
optional = true
32+
2933
[dev-dependencies]
3034
tokio = { version = "1", features = ["macros", "rt"] }
3135
tower-http = { version = "0.6", features = ["cors"] }

tonic-web/src/service.rs

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ use std::task::{ready, Context, Poll};
66
use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version};
77
use pin_project::pin_project;
88
use tonic::metadata::GRPC_CONTENT_TYPE;
9-
use tonic::{body::Body, server::NamedService};
9+
use tonic::server::NamedService;
1010
use tower_service::Service;
1111
use tracing::{debug, trace};
1212

1313
use crate::call::content_types::is_grpc_web;
1414
use crate::call::{Encoding, GrpcWebCall};
1515

16+
use bytes::Bytes;
17+
1618
/// Service implementing the grpc-web protocol.
1719
#[derive(Debug, Clone)]
1820
pub struct GrpcWebService<S> {
@@ -45,9 +47,9 @@ impl<S> GrpcWebService<S> {
4547

4648
impl<S, B> Service<Request<B>> for GrpcWebService<S>
4749
where
48-
S: Service<Request<Body>, Response = Response<Body>>,
49-
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
50-
B::Error: Into<crate::BoxError> + fmt::Display,
50+
S: Service<Request<B>, Response = Response<B>>,
51+
B: http_body::Body<Data = bytes::Bytes> + BoxedBody + Send + 'static,
52+
B::Error: Into<crate::BoxError> + std::error::Error + fmt::Display + Send + Sync,
5153
{
5254
type Response = S::Response;
5355
type Error = S::Error;
@@ -100,7 +102,7 @@ where
100102
debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE));
101103
ResponseFuture {
102104
case: Case::Other {
103-
future: self.inner.call(req.map(Body::new)),
105+
future: self.inner.call(req.map(B::new)),
104106
},
105107
}
106108
}
@@ -152,11 +154,13 @@ impl<F> Case<F> {
152154
}
153155
}
154156

155-
impl<F, E> Future for ResponseFuture<F>
157+
impl<F, E, A> Future for ResponseFuture<F>
156158
where
157-
F: Future<Output = Result<Response<Body>, E>>,
159+
F: Future<Output = Result<Response<A>, E>>,
160+
A: BoxedBody + 'static,
161+
<A as http_body::Body>::Error: std::error::Error + Send + Sync + 'static,
158162
{
159-
type Output = Result<Response<Body>, E>;
163+
type Output = Result<Response<A>, E>;
160164

161165
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
162166
let mut this = self.project();
@@ -169,7 +173,7 @@ where
169173
}
170174
CaseProj::Other { future } => future.poll(cx),
171175
CaseProj::ImmediateResponse { res } => {
172-
let res = Response::from_parts(res.take().unwrap(), Body::empty());
176+
let res = Response::from_parts(res.take().unwrap(), A::empty());
173177
Poll::Ready(Ok(res))
174178
}
175179
}
@@ -203,9 +207,9 @@ impl<'a> RequestKind<'a> {
203207
// Mutating request headers to conform to a gRPC request is not really
204208
// necessary for us at this point. We could remove most of these except
205209
// maybe for inserting `header::TE`, which tonic should check?
206-
fn coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<Body>
210+
fn coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<B>
207211
where
208-
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
212+
B: http_body::Body<Data = bytes::Bytes> + BoxedBody + Send + 'static,
209213
B::Error: Into<crate::BoxError> + fmt::Display,
210214
{
211215
req.headers_mut().remove(header::CONTENT_LENGTH);
@@ -221,17 +225,15 @@ where
221225
HeaderValue::from_static("identity,deflate,gzip"),
222226
);
223227

224-
req.map(|b| Body::new(GrpcWebCall::request(b, encoding)))
228+
req.map(|b| B::new(GrpcWebCall::request(b, encoding)))
225229
}
226230

227-
fn coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<Body>
231+
fn coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<B>
228232
where
229-
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
233+
B: http_body::Body<Data = bytes::Bytes> + BoxedBody + Send + 'static,
230234
B::Error: Into<crate::BoxError> + fmt::Display,
231235
{
232-
let mut res = res
233-
.map(|b| GrpcWebCall::response(b, encoding))
234-
.map(Body::new);
236+
let mut res = res.map(|b| GrpcWebCall::response(b, encoding)).map(B::new);
235237

236238
res.headers_mut().insert(
237239
header::CONTENT_TYPE,
@@ -241,6 +243,45 @@ where
241243
res
242244
}
243245

246+
/// Alias for a type-erased error type.
247+
type BoxError = Box<dyn std::error::Error + Send + Sync>;
248+
249+
trait BoxedBody: http_body::Body<Data = bytes::Bytes> + Send {
250+
fn new<B>(body: B) -> Self
251+
where
252+
B: http_body::Body<Data = Bytes> + Send + 'static,
253+
B::Error: Into<BoxError>;
254+
255+
fn empty() -> Self;
256+
}
257+
258+
impl BoxedBody for tonic::body::Body {
259+
fn new<B>(body: B) -> Self
260+
where
261+
B: http_body::Body<Data = Bytes> + Send + 'static,
262+
B::Error: Into<BoxError>,
263+
{
264+
Self::new(body)
265+
}
266+
267+
fn empty() -> Self {
268+
Self::empty()
269+
}
270+
}
271+
#[cfg(feature = "axum")]
272+
impl BoxedBody for axum::body::Body {
273+
fn new<B>(body: B) -> Self
274+
where
275+
B: http_body::Body<Data = Bytes> + Send + 'static,
276+
B::Error: Into<BoxError>,
277+
{
278+
Self::new(body)
279+
}
280+
281+
fn empty() -> Self {
282+
Self::empty()
283+
}
284+
}
244285
#[cfg(test)]
245286
mod tests {
246287
use super::*;

0 commit comments

Comments
 (0)