Skip to content

Make the tonic-web layer axum compatible #2157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tonic-web/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ tonic = { version = "0.13.0", path = "../tonic", default-features = false }
tower-service = "0.3"
tower-layer = "0.3"
tracing = "0.1"
axum = { version = "0.8.1", optional = true }

[dev-dependencies]
tokio = { version = "1", features = ["macros", "rt"] }
Expand Down
75 changes: 58 additions & 17 deletions tonic-web/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ use std::task::{ready, Context, Poll};
use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version};
use pin_project::pin_project;
use tonic::metadata::GRPC_CONTENT_TYPE;
use tonic::{body::Body, server::NamedService};
use tonic::server::NamedService;
use tower_service::Service;
use tracing::{debug, trace};

use crate::call::content_types::is_grpc_web;
use crate::call::{Encoding, GrpcWebCall};

use bytes::Bytes;

/// Service implementing the grpc-web protocol.
#[derive(Debug, Clone)]
pub struct GrpcWebService<S> {
Expand Down Expand Up @@ -45,9 +47,9 @@ impl<S> GrpcWebService<S> {

impl<S, B> Service<Request<B>> for GrpcWebService<S>
where
S: Service<Request<Body>, Response = Response<Body>>,
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
B::Error: Into<crate::BoxError> + fmt::Display,
S: Service<Request<B>, Response = Response<B>>,
B: http_body::Body<Data = bytes::Bytes> + BoxedBody + Send + 'static,
B::Error: Into<crate::BoxError> + std::error::Error + fmt::Display + Send + Sync,
{
type Response = S::Response;
type Error = S::Error;
Expand Down Expand Up @@ -100,7 +102,7 @@ where
debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE));
ResponseFuture {
case: Case::Other {
future: self.inner.call(req.map(Body::new)),
future: self.inner.call(req.map(B::new)),
},
}
}
Expand Down Expand Up @@ -152,11 +154,13 @@ impl<F> Case<F> {
}
}

impl<F, E> Future for ResponseFuture<F>
impl<F, E, A> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<Body>, E>>,
F: Future<Output = Result<Response<A>, E>>,
A: BoxedBody + 'static,
<A as http_body::Body>::Error: std::error::Error + Send + Sync + 'static,
{
type Output = Result<Response<Body>, E>;
type Output = Result<Response<A>, E>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
Expand All @@ -169,7 +173,7 @@ where
}
CaseProj::Other { future } => future.poll(cx),
CaseProj::ImmediateResponse { res } => {
let res = Response::from_parts(res.take().unwrap(), Body::empty());
let res = Response::from_parts(res.take().unwrap(), A::empty());
Poll::Ready(Ok(res))
}
}
Expand Down Expand Up @@ -203,9 +207,9 @@ impl<'a> RequestKind<'a> {
// Mutating request headers to conform to a gRPC request is not really
// necessary for us at this point. We could remove most of these except
// maybe for inserting `header::TE`, which tonic should check?
fn coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<Body>
fn coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<B>
where
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
B: http_body::Body<Data = bytes::Bytes> + BoxedBody + Send + 'static,
B::Error: Into<crate::BoxError> + fmt::Display,
{
req.headers_mut().remove(header::CONTENT_LENGTH);
Expand All @@ -221,17 +225,15 @@ where
HeaderValue::from_static("identity,deflate,gzip"),
);

req.map(|b| Body::new(GrpcWebCall::request(b, encoding)))
req.map(|b| B::new(GrpcWebCall::request(b, encoding)))
}

fn coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<Body>
fn coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<B>
where
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
B: http_body::Body<Data = bytes::Bytes> + BoxedBody + Send + 'static,
B::Error: Into<crate::BoxError> + fmt::Display,
{
let mut res = res
.map(|b| GrpcWebCall::response(b, encoding))
.map(Body::new);
let mut res = res.map(|b| GrpcWebCall::response(b, encoding)).map(B::new);

res.headers_mut().insert(
header::CONTENT_TYPE,
Expand All @@ -241,6 +243,45 @@ where
res
}

/// Alias for a type-erased error type.
type BoxError = Box<dyn std::error::Error + Send + Sync>;

trait BoxedBody: http_body::Body<Data = bytes::Bytes> + Send {
fn new<B>(body: B) -> Self
where
B: http_body::Body<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>;

fn empty() -> Self;
}

impl BoxedBody for tonic::body::Body {
fn new<B>(body: B) -> Self
where
B: http_body::Body<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>,
{
Self::new(body)
}

fn empty() -> Self {
Self::empty()
}
}
#[cfg(feature = "axum")]
impl BoxedBody for axum::body::Body {
fn new<B>(body: B) -> Self
where
B: http_body::Body<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>,
{
Self::new(body)
}

fn empty() -> Self {
Self::empty()
}
}
#[cfg(test)]
mod tests {
use super::*;
Expand Down