diff --git a/codegen/src/main.rs b/codegen/src/main.rs index 6cc193421..d0d4ef67c 100644 --- a/codegen/src/main.rs +++ b/codegen/src/main.rs @@ -68,6 +68,19 @@ fn main() { false, ); + // grpc + codegen( + &PathBuf::from(std::env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("grpc"), + &["proto/echo/echo.proto"], + &["proto"], + &PathBuf::from("src/generated"), + &PathBuf::from("src/generated/echo_fds.rs"), + true, + true, + ); println!("Codgen completed: {}ms", start.elapsed().as_millis()); } diff --git a/examples/Cargo.toml b/examples/Cargo.toml index b2f6d0ccf..361ec8fcd 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -263,7 +263,6 @@ tracing = ["dep:tracing", "dep:tracing-subscriber"] uds = ["dep:tokio-stream", "tokio-stream?/net", "dep:tower", "dep:hyper", "dep:hyper-util"] streaming = ["dep:tokio-stream", "dep:h2"] mock = ["dep:tokio-stream", "dep:tower", "dep:hyper-util"] -tower = ["dep:tower", "dep:http"] json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"] compression = ["tonic/gzip"] tls = ["tonic/tls-ring"] @@ -273,7 +272,7 @@ types = ["dep:tonic-types"] h2c = ["dep:hyper", "dep:tower", "dep:http", "dep:hyper-util"] cancellation = ["dep:tokio-util"] -full = ["gcp", "routeguide", "reflection", "autoreload", "health", "grpc-web", "tracing", "uds", "streaming", "mock", "tower", "json-codec", "compression", "tls", "tls-rustls", "tls-client-auth", "types", "cancellation", "h2c"] +full = ["gcp", "routeguide", "reflection", "autoreload", "health", "grpc-web", "tracing", "uds", "streaming", "mock", "json-codec", "compression", "tls", "tls-rustls", "tls-client-auth", "types", "cancellation", "h2c"] default = ["full"] [dependencies] diff --git a/grpc/Cargo.toml b/grpc/Cargo.toml index b82e54922..167c5e63a 100644 --- a/grpc/Cargo.toml +++ b/grpc/Cargo.toml @@ -5,39 +5,59 @@ edition = "2021" authors = ["gRPC Authors"] license = "MIT" +[package.metadata.cargo_check_external_types] +allowed_external_types = [ + "tonic::*", + "futures_core::stream::Stream", + "tokio::sync::oneshot::Sender", +] + +[features] +default = ["dns", "_runtime-tokio"] +dns = ["dep:hickory-resolver", "_runtime-tokio"] +# The following feature is used to ensure all modules use the runtime +# abstraction instead of using tokio directly. +# Using tower/buffer enables tokio's rt feature even though it's possible to +# create Buffers with a user provided executor. 
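For reference, the comment above captures the crux of the runtime work in this change: modules are written against the crate's `rt::Runtime` abstraction, and only the tokio-backed implementation pulls in tokio's `rt`/`net`/`time` features. A minimal, self-contained sketch of what such an abstraction can look like, with simplified and partly hypothetical signatures (the real trait also exposes `tcp_stream` and other hooks; the sketch assumes tokio with the `rt`, `time`, and `macros` features):

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

// Illustrative stand-ins for the crate's `rt` types; not the actual API.
trait TaskHandle: Send + Sync {
    fn abort(&self);
}

trait Runtime: Send + Sync {
    fn spawn(&self, fut: Pin<Box<dyn Future<Output = ()> + Send>>) -> Box<dyn TaskHandle>;
    fn sleep(&self, d: Duration) -> Pin<Box<dyn Future<Output = ()> + Send>>;
}

struct TokioRuntime;

impl TaskHandle for tokio::task::JoinHandle<()> {
    fn abort(&self) {
        tokio::task::JoinHandle::abort(self);
    }
}

impl Runtime for TokioRuntime {
    fn spawn(&self, fut: Pin<Box<dyn Future<Output = ()> + Send>>) -> Box<dyn TaskHandle> {
        Box::new(tokio::spawn(fut))
    }

    fn sleep(&self, d: Duration) -> Pin<Box<dyn Future<Output = ()> + Send>> {
        Box::pin(tokio::time::sleep(d))
    }
}

#[tokio::main]
async fn main() {
    let rt: Arc<dyn Runtime> = Arc::new(TokioRuntime);
    let rt2 = rt.clone();
    // Mirrors the pick_first change below: schedule follow-up work after a
    // delay without naming tokio at the call site.
    let handle = rt.spawn(Box::pin(async move {
        rt2.sleep(Duration::from_millis(10)).await;
        println!("scheduling work");
    }));
    rt.sleep(Duration::from_millis(50)).await;
    handle.abort();
}
```

With that split in place, `pick_first`, the subchannel state machine, and the tonic transport can all spawn tasks and timers generically, which is what the `_runtime-tokio` feature defined next guards.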
+_runtime-tokio = [ + "tokio/rt", + "tokio/net", + "tokio/time", + "dep:socket2", + "dep:tower", +] + [dependencies] bytes = "1.10.1" hickory-resolver = { version = "0.25.1", optional = true } http = "1.1.0" http-body = "1.0.1" hyper = { version = "1.6.0", features = ["client", "http2"] } -hyper-util = "0.1.14" parking_lot = "0.12.4" pin-project-lite = "0.2.16" rand = "0.9" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" -socket2 = "0.5.10" -tokio = { version = "1.37.0", features = ["sync", "rt", "net", "time", "macros"] } -tokio-stream = "0.1.17" -tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = ["codegen", "transport"] } -tower = "0.5.2" +socket2 = { version = "0.5.10", optional = true } +tokio = { version = "1.37.0", features = ["sync", "macros"] } +tokio-stream = { version = "0.1.17", default-features = false } +tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = [ + "codegen", +] } +tower = { version = "0.5.2", features = [ + "limit", + "util", + "buffer", +], optional = true } tower-service = "0.3.3" url = "2.5.0" [dev-dependencies] async-stream = "0.3.6" -tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = ["server", "router"] } -hickory-server = "0.25.2" -prost = "0.14" - -[features] -default = ["dns"] -dns = ["dep:hickory-resolver"] - -[package.metadata.cargo_check_external_types] -allowed_external_types = [ - "tonic::*", - "futures_core::stream::Stream", - "tokio::sync::oneshot::Sender", -] +hickory-server = "0.25.2" +prost = "0.14.0" +tonic = { version = "0.14.0", path = "../tonic", default-features = false, features = [ + "server", + "router", +] } +tonic-prost = { version = "0.14.0", path = "../tonic-prost" } diff --git a/grpc/examples/inmemory.rs b/grpc/examples/inmemory.rs index 1ffc74b9d..88b17ee01 100644 --- a/grpc/examples/inmemory.rs +++ b/grpc/examples/inmemory.rs @@ -10,11 +10,8 @@ struct Handler {} #[derive(Debug)] struct MyReqMessage(String); -impl Message for MyReqMessage {} - #[derive(Debug)] struct MyResMessage(String); -impl Message for MyResMessage {} #[async_trait] impl Service for Handler { diff --git a/grpc/examples/multiaddr.rs b/grpc/examples/multiaddr.rs index 9fcc8f0ed..c631d33c5 100644 --- a/grpc/examples/multiaddr.rs +++ b/grpc/examples/multiaddr.rs @@ -12,11 +12,8 @@ struct Handler { #[derive(Debug)] struct MyReqMessage(String); -impl Message for MyReqMessage {} - #[derive(Debug)] struct MyResMessage(String); -impl Message for MyResMessage {} #[async_trait] impl Service for Handler { diff --git a/grpc/proto/echo/echo.proto b/grpc/proto/echo/echo.proto new file mode 100644 index 000000000..1ed1207f0 --- /dev/null +++ b/grpc/proto/echo/echo.proto @@ -0,0 +1,43 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +syntax = "proto3"; + +package grpc.examples.echo; + +// EchoRequest is the request for echo. 
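The `Echo` service defined in this file is what the generated code and the tonic transport test further down key off of; on the wire each RPC is addressed by a path of the form `/<package>.<Service>/<Method>`. A small standalone check of how that path is assembled (plain Rust, no crate types):

```rust
fn main() {
    // package "grpc.examples.echo", service "Echo", method "BidirectionalStreamingEcho"
    let (package, service, method) = ("grpc.examples.echo", "Echo", "BidirectionalStreamingEcho");
    let path = format!("/{package}.{service}/{method}");
    // This is the literal path the transport test dials against the echo server.
    assert_eq!(path, "/grpc.examples.echo.Echo/BidirectionalStreamingEcho");
}
```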
+message EchoRequest { + string message = 1; +} + +// EchoResponse is the response for echo. +message EchoResponse { + string message = 1; +} + +// Echo is the echo service. +service Echo { + // UnaryEcho is unary echo. + rpc UnaryEcho(EchoRequest) returns (EchoResponse) {} + // ServerStreamingEcho is server side streaming. + rpc ServerStreamingEcho(EchoRequest) returns (stream EchoResponse) {} + // ClientStreamingEcho is client side streaming. + rpc ClientStreamingEcho(stream EchoRequest) returns (EchoResponse) {} + // BidirectionalStreamingEcho is bidi streaming. + rpc BidirectionalStreamingEcho(stream EchoRequest) returns (stream EchoResponse) {} +} diff --git a/grpc/src/client/channel.rs b/grpc/src/client/channel.rs index 6e41c6759..edbd55131 100644 --- a/grpc/src/client/channel.rs +++ b/grpc/src/client/channel.rs @@ -37,17 +37,16 @@ use std::{ }; use tokio::sync::{mpsc, oneshot, watch, Notify}; -use tokio::task::AbortHandle; use serde_json::json; use tonic::async_trait; use url::Url; // NOTE: http::Uri requires non-empty authority portion of URI -use crate::credentials::Credentials; +use crate::attributes::Attributes; use crate::rt; use crate::service::{Request, Response, Service}; -use crate::{attributes::Attributes, rt::tokio::TokioRuntime}; use crate::{client::ConnectivityState, rt::Runtime}; +use crate::{credentials::Credentials, rt::default_runtime}; use super::service_config::ServiceConfig; use super::transport::{TransportRegistry, GLOBAL_TRANSPORT_REGISTRY}; @@ -156,7 +155,7 @@ impl Channel { inner: Arc::new(PersistentChannel::new( target, credentials, - Arc::new(rt::tokio::TokioRuntime {}), + default_runtime(), options, )), } @@ -262,6 +261,7 @@ impl ActiveChannel { tx.clone(), picker.clone(), connectivity_state.clone(), + runtime.clone(), ); let resolver_helper = Box::new(tx.clone()); @@ -279,7 +279,7 @@ impl ActiveChannel { let resolver_opts = name_resolution::ResolverOptions { authority, work_scheduler, - runtime: Arc::new(TokioRuntime {}), + runtime: runtime.clone(), }; let resolver = rb.build(&target, resolver_opts); @@ -360,6 +360,7 @@ pub(crate) struct InternalChannelController { wqtx: WorkQueueTx, picker: Arc>>, connectivity_state: Arc>, + runtime: Arc, } impl InternalChannelController { @@ -369,8 +370,9 @@ impl InternalChannelController { wqtx: WorkQueueTx, picker: Arc>>, connectivity_state: Arc>, + runtime: Arc, ) -> Self { - let lb = Arc::new(GracefulSwitchBalancer::new(wqtx.clone())); + let lb = Arc::new(GracefulSwitchBalancer::new(wqtx.clone(), runtime.clone())); Self { lb, @@ -380,6 +382,7 @@ impl InternalChannelController { wqtx, picker, connectivity_state, + runtime, } } @@ -429,6 +432,7 @@ impl load_balancing::ChannelController for InternalChannelController { Box::new(move |k: SubchannelKey| { scp.unregister_subchannel(&k); }), + self.runtime.clone(), ); let _ = self.subchannel_pool.register_subchannel(&key, isc.clone()); self.new_esc_for_isc(isc) @@ -454,6 +458,7 @@ pub(super) struct GracefulSwitchBalancer { policy_builder: Mutex>>, work_scheduler: WorkQueueTx, pending: Mutex, + runtime: Arc, } impl WorkScheduler for GracefulSwitchBalancer { @@ -478,12 +483,13 @@ impl WorkScheduler for GracefulSwitchBalancer { } impl GracefulSwitchBalancer { - fn new(work_scheduler: WorkQueueTx) -> Self { + fn new(work_scheduler: WorkQueueTx, runtime: Arc) -> Self { Self { policy_builder: Mutex::default(), policy: Mutex::default(), // new(None::>), work_scheduler, pending: Mutex::default(), + runtime, } } @@ -501,6 +507,7 @@ impl GracefulSwitchBalancer { let builder = 
GLOBAL_LB_REGISTRY.get_policy(policy_name).unwrap(); let newpol = builder.build(LbPolicyOptions { work_scheduler: self.clone(), + runtime: self.runtime.clone(), }); *self.policy_builder.lock().unwrap() = Some(builder); *p = Some(newpol); diff --git a/grpc/src/client/load_balancing/child_manager.rs b/grpc/src/client/load_balancing/child_manager.rs index 0d4af6542..8086a842d 100644 --- a/grpc/src/client/load_balancing/child_manager.rs +++ b/grpc/src/client/load_balancing/child_manager.rs @@ -38,6 +38,7 @@ use crate::client::load_balancing::{ WeakSubchannel, WorkScheduler, }; use crate::client::name_resolution::{Address, ResolverUpdate}; +use crate::rt::Runtime; use super::{Subchannel, SubchannelState}; @@ -47,6 +48,7 @@ pub struct ChildManager { children: Vec>, update_sharder: Box>, pending_work: Arc>>, + runtime: Arc, } struct Child { @@ -81,12 +83,16 @@ pub trait ResolverUpdateSharder: Send { impl ChildManager { /// Creates a new ChildManager LB policy. shard_update is called whenever a /// resolver_update operation occurs. - pub fn new(update_sharder: Box>) -> Self { + pub fn new( + update_sharder: Box>, + runtime: Arc, + ) -> Self { Self { update_sharder, subchannel_child_map: Default::default(), children: Default::default(), pending_work: Default::default(), + runtime, } } @@ -197,6 +203,7 @@ impl LbPolicy for ChildManager }); let policy = builder.build(LbPolicyOptions { work_scheduler: work_scheduler.clone(), + runtime: self.runtime.clone(), }); let state = LbState::initial(); self.children.push(Child { diff --git a/grpc/src/client/load_balancing/mod.rs b/grpc/src/client/load_balancing/mod.rs index b91950c31..835355191 100644 --- a/grpc/src/client/load_balancing/mod.rs +++ b/grpc/src/client/load_balancing/mod.rs @@ -41,6 +41,7 @@ use tonic::{metadata::MetadataMap, Status}; use crate::{ client::channel::WorkQueueTx, + rt::Runtime, service::{Request, Response, Service}, }; @@ -64,6 +65,7 @@ pub struct LbPolicyOptions { /// A hook into the channel's work scheduler that allows the LbPolicy to /// request the ability to perform operations on the ChannelController. pub work_scheduler: Arc, + pub runtime: Arc, } /// Used to asynchronously request a call into the LbPolicy's work method if diff --git a/grpc/src/client/load_balancing/pick_first.rs b/grpc/src/client/load_balancing/pick_first.rs index ed7ae76f6..d88cd904f 100644 --- a/grpc/src/client/load_balancing/pick_first.rs +++ b/grpc/src/client/load_balancing/pick_first.rs @@ -4,7 +4,6 @@ use std::{ time::Duration, }; -use tokio::time::sleep; use tonic::metadata::MetadataMap; use crate::{ @@ -13,6 +12,7 @@ use crate::{ name_resolution::{Address, ResolverUpdate}, subchannel, ConnectivityState, }, + rt::Runtime, service::Request, }; @@ -31,6 +31,7 @@ impl LbPolicyBuilder for Builder { work_scheduler: options.work_scheduler, subchannel: None, next_addresses: Vec::default(), + runtime: options.runtime, }) } @@ -47,6 +48,7 @@ struct PickFirstPolicy { work_scheduler: Arc, subchannel: Option>, next_addresses: Vec
, + runtime: Arc, } impl LbPolicy for PickFirstPolicy { @@ -72,11 +74,12 @@ impl LbPolicy for PickFirstPolicy { self.next_addresses = addresses; let work_scheduler = self.work_scheduler.clone(); + let runtime = self.runtime.clone(); // TODO: Implement Drop that cancels this task. - tokio::task::spawn(async move { - sleep(Duration::from_millis(200)).await; + self.runtime.spawn(Box::pin(async move { + runtime.sleep(Duration::from_millis(200)).await; work_scheduler.schedule_work(); - }); + })); // TODO: return a picker that queues RPCs. Ok(()) } diff --git a/grpc/src/client/mod.rs b/grpc/src/client/mod.rs index 66c809e62..e896412ae 100644 --- a/grpc/src/client/mod.rs +++ b/grpc/src/client/mod.rs @@ -28,9 +28,8 @@ pub mod channel; pub(crate) mod load_balancing; pub(crate) mod name_resolution; pub mod service_config; -pub mod transport; - mod subchannel; +pub(crate) mod transport; pub use channel::Channel; pub use channel::ChannelOptions; diff --git a/grpc/src/client/name_resolution/dns/test.rs b/grpc/src/client/name_resolution/dns/test.rs index beda2ea32..135e8ccfa 100644 --- a/grpc/src/client/name_resolution/dns/test.rs +++ b/grpc/src/client/name_resolution/dns/test.rs @@ -290,6 +290,14 @@ impl rt::Runtime for FakeRuntime { fn sleep(&self, duration: std::time::Duration) -> Pin> { self.inner.sleep(duration) } + + fn tcp_stream( + &self, + target: std::net::SocketAddr, + opts: rt::TcpOptions, + ) -> Pin, String>> + Send>> { + self.inner.tcp_stream(target, opts) + } } #[tokio::test] diff --git a/grpc/src/client/subchannel.rs b/grpc/src/client/subchannel.rs index d9bef839b..c1db6fb20 100644 --- a/grpc/src/client/subchannel.rs +++ b/grpc/src/client/subchannel.rs @@ -2,14 +2,20 @@ use super::{ channel::{InternalChannelController, WorkQueueTx}, load_balancing::{self, ExternalSubchannel, Picker, Subchannel, SubchannelState}, name_resolution::Address, - transport::{self, ConnectedTransport, Transport, TransportRegistry}, + transport::{self, Transport, TransportRegistry}, ConnectivityState, }; use crate::{ - client::{channel::WorkQueueItem, subchannel}, + client::{ + channel::WorkQueueItem, + subchannel, + transport::{ConnectedTransport, TransportOptions}, + }, + rt::{Runtime, TaskHandle}, service::{Request, Response, Service}, }; use core::panic; +use std::time::{Duration, Instant}; use std::{ collections::BTreeMap, error::Error, @@ -17,14 +23,10 @@ use std::{ ops::Sub, sync::{Arc, Mutex, RwLock, Weak}, }; -use tokio::{ - sync::{mpsc, watch, Notify}, - task::{AbortHandle, JoinHandle}, - time::{Duration, Instant}, -}; +use tokio::sync::{mpsc, oneshot, watch, Notify}; use tonic::async_trait; -type SharedService = Arc; +type SharedService = Arc; pub trait Backoff: Send + Sync { fn backoff_until(&self) -> Instant; @@ -52,16 +54,16 @@ enum InternalSubchannelState { } struct InternalSubchannelConnectingState { - abort_handle: Option, + abort_handle: Option>, } struct InternalSubchannelReadyState { - abort_handle: Option, + abort_handle: Option>, svc: SharedService, } struct InternalSubchannelTransientFailureState { - abort_handle: Option, + task_handle: Option>, error: String, } @@ -163,7 +165,7 @@ impl Drop for InternalSubchannelState { } } Self::TransientFailure(st) => { - if let Some(ah) = &st.abort_handle { + if let Some(ah) = &st.task_handle { ah.abort(); } } @@ -178,13 +180,14 @@ pub(crate) struct InternalSubchannel { unregister_fn: Option>, state_machine_event_sender: mpsc::UnboundedSender, inner: Mutex, + runtime: Arc, } struct InnerSubchannel { state: InternalSubchannelState, watchers: Vec>, // 
TODO(easwars): Revisit the choice for this data structure. - backoff_task: Option>, - disconnect_task: Option>, + backoff_task: Option>, + disconnect_task: Option>, } #[async_trait] @@ -204,7 +207,7 @@ impl Service for InternalSubchannel { enum SubchannelStateMachineEvent { ConnectionRequested, - ConnectionSucceeded(SharedService), + ConnectionSucceeded(SharedService, oneshot::Receiver>), ConnectionTimedOut, ConnectionFailed(String), ConnectionTerminated, @@ -214,7 +217,7 @@ impl Debug for SubchannelStateMachineEvent { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::ConnectionRequested => write!(f, "ConnectionRequested"), - Self::ConnectionSucceeded(_) => write!(f, "ConnectionSucceeded"), + Self::ConnectionSucceeded(_, _) => write!(f, "ConnectionSucceeded"), Self::ConnectionTimedOut => write!(f, "ConnectionTimedOut"), Self::ConnectionFailed(_) => write!(f, "ConnectionFailed"), Self::ConnectionTerminated => write!(f, "ConnectionTerminated"), @@ -229,6 +232,7 @@ impl InternalSubchannel { transport: Arc, backoff: Arc, unregister_fn: Box, + runtime: Arc, ) -> Arc { println!("creating new internal subchannel for: {:?}", &key); let (tx, mut rx) = mpsc::unbounded_channel::(); @@ -244,6 +248,7 @@ impl InternalSubchannel { backoff_task: None, disconnect_task: None, }), + runtime: runtime.clone(), }); // This long running task implements the subchannel state machine. When @@ -251,7 +256,7 @@ impl InternalSubchannel { // closed, and therefore this task exits because rx.recv() returns None // in that case. let arc_to_self = Arc::clone(&isc); - tokio::task::spawn(async move { + runtime.spawn(Box::pin(async move { println!("starting subchannel state machine for: {:?}", &key); while let Some(m) = rx.recv().await { println!("subchannel {:?} received event {:?}", &key, &m); @@ -259,8 +264,8 @@ impl InternalSubchannel { SubchannelStateMachineEvent::ConnectionRequested => { arc_to_self.move_to_connecting(); } - SubchannelStateMachineEvent::ConnectionSucceeded(svc) => { - arc_to_self.move_to_ready(svc); + SubchannelStateMachineEvent::ConnectionSucceeded(svc, rx) => { + arc_to_self.move_to_ready(svc, rx); } SubchannelStateMachineEvent::ConnectionTimedOut => { arc_to_self.move_to_transient_failure("connect timeout expired".into()); @@ -277,7 +282,7 @@ impl InternalSubchannel { } } println!("exiting work queue task in subchannel"); - }); + })); isc } @@ -345,15 +350,19 @@ impl InternalSubchannel { let transport = self.transport.clone(); let address = self.address().address; let state_machine_tx = self.state_machine_event_sender.clone(); - let connect_task = tokio::task::spawn(async move { + // TODO: All these options to be configured by users. + let transport_opts = TransportOptions::default(); + let runtime = self.runtime.clone(); + + let connect_task = self.runtime.spawn(Box::pin(async move { tokio::select! 
{ - _ = tokio::time::sleep(min_connect_timeout) => { + _ = runtime.sleep(min_connect_timeout) => { let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionTimedOut); } - result = transport.connect(address.to_string().clone()) => { + result = transport.connect(address.to_string().clone(), runtime, &transport_opts) => { match result { Ok(s) => { - let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionSucceeded(Arc::from(s))); + let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionSucceeded(Arc::from(s.service), s.disconnection_listener)); } Err(e) => { let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionFailed(e)); @@ -361,14 +370,14 @@ impl InternalSubchannel { } }, } - }); + })); let mut inner = self.inner.lock().unwrap(); inner.state = InternalSubchannelState::Connecting(InternalSubchannelConnectingState { - abort_handle: Some(connect_task.abort_handle()), + abort_handle: Some(connect_task), }); } - fn move_to_ready(&self, svc: SharedService) { + fn move_to_ready(&self, svc: SharedService, closed_rx: oneshot::Receiver>) { let svc2 = svc.clone(); { let mut inner = self.inner.lock().unwrap(); @@ -383,17 +392,19 @@ impl InternalSubchannel { }); let state_machine_tx = self.state_machine_event_sender.clone(); - let disconnect_task = tokio::task::spawn(async move { + let task_handle = self.runtime.spawn(Box::pin(async move { // TODO(easwars): Does it make sense for disconnected() to return an // error string containing information about why the connection // terminated? But what can we do with that error other than logging // it, which the transport can do as well? - svc.disconnected().await; + if let Err(e) = closed_rx.await { + eprintln!("Transport closed with error: {e}",) + }; let _ = state_machine_tx.send(SubchannelStateMachineEvent::ConnectionTerminated); - }); + })); let mut inner = self.inner.lock().unwrap(); inner.state = InternalSubchannelState::Ready(InternalSubchannelReadyState { - abort_handle: Some(disconnect_task.abort_handle()), + abort_handle: Some(task_handle), svc: svc2.clone(), }); } @@ -403,7 +414,7 @@ impl InternalSubchannel { let mut inner = self.inner.lock().unwrap(); inner.state = InternalSubchannelState::TransientFailure( InternalSubchannelTransientFailureState { - abort_handle: None, + task_handle: None, error: err.clone(), }, ); @@ -417,14 +428,17 @@ impl InternalSubchannel { let backoff_interval = self.backoff.backoff_until(); let state_machine_tx = self.state_machine_event_sender.clone(); - let backoff_task = tokio::task::spawn(async move { - tokio::time::sleep_until(backoff_interval).await; + let runtime = self.runtime.clone(); + let backoff_task = self.runtime.spawn(Box::pin(async move { + runtime + .sleep(backoff_interval.saturating_duration_since(Instant::now())) + .await; let _ = state_machine_tx.send(SubchannelStateMachineEvent::BackoffExpired); - }); + })); let mut inner = self.inner.lock().unwrap(); inner.state = InternalSubchannelState::TransientFailure(InternalSubchannelTransientFailureState { - abort_handle: Some(backoff_task.abort_handle()), + task_handle: Some(backoff_task), error: err.clone(), }); } diff --git a/grpc/src/client/transport/mod.rs b/grpc/src/client/transport/mod.rs index 4c5b021b8..411a2954b 100644 --- a/grpc/src/client/transport/mod.rs +++ b/grpc/src/client/transport/mod.rs @@ -1,16 +1,49 @@ -use crate::service::Service; +use crate::{rt::Runtime, service::Service}; +use std::time::Instant; +use std::{sync::Arc, time::Duration}; mod registry; +// Using tower/buffer enables 
tokio's rt feature even though it's possible to +// create Buffers with a user provided executor. +#[cfg(feature = "_runtime-tokio")] +mod tonic; + use ::tonic::async_trait; -pub use registry::{TransportRegistry, GLOBAL_TRANSPORT_REGISTRY}; +pub(crate) use registry::TransportRegistry; +pub(crate) use registry::GLOBAL_TRANSPORT_REGISTRY; +use tokio::sync::oneshot; -#[async_trait] -pub trait Transport: Send + Sync { - async fn connect(&self, address: String) -> Result, String>; +pub(crate) struct ConnectedTransport { + pub service: Box, + pub disconnection_listener: oneshot::Receiver>, +} + +// TODO: The following options are specific to HTTP/2. We should +// instead pass an `Attribute` like struct to the connect method instead which +// can hold config relevant to a particular transport. +#[derive(Default)] +pub(crate) struct TransportOptions { + pub(crate) init_stream_window_size: Option, + pub(crate) init_connection_window_size: Option, + pub(crate) http2_keep_alive_interval: Option, + pub(crate) http2_keep_alive_timeout: Option, + pub(crate) http2_keep_alive_while_idle: Option, + pub(crate) http2_max_header_list_size: Option, + pub(crate) http2_adaptive_window: Option, + pub(crate) concurrency_limit: Option, + pub(crate) rate_limit: Option<(u64, Duration)>, + pub(crate) tcp_keepalive: Option, + pub(crate) tcp_nodelay: bool, + pub(crate) connect_deadline: Option, } #[async_trait] -pub trait ConnectedTransport: Service { - async fn disconnected(&self); +pub(crate) trait Transport: Send + Sync { + async fn connect( + &self, + address: String, + runtime: Arc, + opts: &TransportOptions, + ) -> Result; } diff --git a/grpc/src/client/transport/registry.rs b/grpc/src/client/transport/registry.rs index e5f7f7fe0..0b4f614ef 100644 --- a/grpc/src/client/transport/registry.rs +++ b/grpc/src/client/transport/registry.rs @@ -1,20 +1,17 @@ -use std::{ - collections::HashMap, - sync::{Arc, LazyLock, Mutex}, -}; - use super::Transport; +use std::sync::{Arc, LazyLock, Mutex}; +use std::{collections::HashMap, fmt::Debug}; /// A registry to store and retrieve transports. Transports are indexed by /// the address type they are intended to handle. -#[derive(Clone)] -pub struct TransportRegistry { - m: Arc>>>, +#[derive(Default, Clone)] +pub(crate) struct TransportRegistry { + inner: Arc>>>, } -impl std::fmt::Debug for TransportRegistry { +impl Debug for TransportRegistry { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let m = self.m.lock().unwrap(); + let m = self.inner.lock().unwrap(); for key in m.keys() { write!(f, "k: {key:?}")? } @@ -24,21 +21,21 @@ impl std::fmt::Debug for TransportRegistry { impl TransportRegistry { /// Construct an empty name resolver registry. - pub fn new() -> Self { - Self { m: Arc::default() } + pub(crate) fn new() -> Self { + Self::default() } - /// Add a name resolver into the registry. - pub fn add_transport(&self, address_type: &str, transport: impl Transport + 'static) { - //let a: Arc = transport; - //let a: Arc> = transport; - self.m + + /// Add a transport into the registry. + pub(crate) fn add_transport(&self, address_type: &str, transport: impl Transport + 'static) { + self.inner .lock() .unwrap() .insert(address_type.to_string(), Arc::new(transport)); } + /// Retrieve a name resolver from the registry, or None if not found. 
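The registry being made crate-private here keys transports by address type in a `Mutex`-guarded map behind a `LazyLock`, and `get_transport` (just below) returns a `Result` carrying an error message rather than the `Option` its doc comment mentions. A self-contained sketch of that shape, with illustrative names that are not the crate's actual items:

```rust
use std::collections::HashMap;
use std::sync::{Arc, LazyLock, Mutex};

trait Transport: Send + Sync {}

#[derive(Default)]
struct Registry {
    inner: Mutex<HashMap<String, Arc<dyn Transport>>>,
}

impl Registry {
    fn add(&self, address_type: &str, transport: impl Transport + 'static) {
        self.inner
            .lock()
            .unwrap()
            .insert(address_type.to_string(), Arc::new(transport));
    }

    fn get(&self, address_type: &str) -> Result<Arc<dyn Transport>, String> {
        self.inner
            .lock()
            .unwrap()
            .get(address_type)
            .cloned()
            .ok_or_else(|| format!("no transport registered for {address_type}"))
    }
}

static GLOBAL: LazyLock<Registry> = LazyLock::new(Registry::default);

struct Tcp;
impl Transport for Tcp {}

fn main() {
    GLOBAL.add("tcp", Tcp);
    assert!(GLOBAL.get("tcp").is_ok());
    assert!(GLOBAL.get("uds").is_err());
}
```

The global instance mirrors `GLOBAL_TRANSPORT_REGISTRY`, which the tonic transport module below registers itself into via `reg()`.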
- pub fn get_transport(&self, address_type: &str) -> Result, String> { - self.m + pub(crate) fn get_transport(&self, address_type: &str) -> Result, String> { + self.inner .lock() .unwrap() .get(address_type) @@ -49,12 +46,6 @@ impl TransportRegistry { } } -impl Default for TransportRegistry { - fn default() -> Self { - Self::new() - } -} - /// The registry used if a local registry is not provided to a channel or if it /// does not exist in the local registry. pub static GLOBAL_TRANSPORT_REGISTRY: LazyLock = diff --git a/grpc/src/client/transport/tonic/mod.rs b/grpc/src/client/transport/tonic/mod.rs new file mode 100644 index 000000000..7e53235ab --- /dev/null +++ b/grpc/src/client/transport/tonic/mod.rs @@ -0,0 +1,274 @@ +use crate::client::transport::registry::GLOBAL_TRANSPORT_REGISTRY; +use crate::client::transport::ConnectedTransport; +use crate::client::transport::Transport; +use crate::client::transport::TransportOptions; +use crate::codec::BytesCodec; +use crate::rt::hyper_wrapper::{HyperCompatExec, HyperCompatTimer, HyperStream}; +use crate::rt::Runtime; +use crate::rt::TaskHandle; +use crate::rt::TcpOptions; +use crate::service::Message; +use crate::service::Request as GrpcRequest; +use crate::service::Response as GrpcResponse; +use crate::{client::name_resolution::TCP_IP_NETWORK_TYPE, service::Service}; +use bytes::Bytes; +use http::uri::PathAndQuery; +use http::Request as HttpRequest; +use http::Response as HttpResponse; +use http::Uri; +use hyper::client::conn::http2::Builder; +use hyper::client::conn::http2::SendRequest; +use std::any::Any; +use std::task::{Context, Poll}; +use std::time::Instant; +use std::{error::Error, future::Future, net::SocketAddr, pin::Pin, str::FromStr, sync::Arc}; +use tokio::sync::oneshot; +use tokio_stream::Stream; +use tokio_stream::StreamExt; +use tonic::client::GrpcService; +use tonic::Request as TonicRequest; +use tonic::Response as TonicResponse; +use tonic::Streaming; +use tonic::{async_trait, body::Body, client::Grpc, Status}; +use tower::buffer::{future::ResponseFuture as BufferResponseFuture, Buffer}; +use tower::limit::{ConcurrencyLimitLayer, RateLimitLayer}; +use tower::{util::BoxService, ServiceBuilder}; +use tower_service::Service as TowerService; + +#[cfg(test)] +mod test; + +const DEFAULT_BUFFER_SIZE: usize = 1024; +pub(crate) type BoxError = Box; + +type BoxFuture<'a, T> = Pin + Send + 'a>>; +type BoxStream = Pin> + Send>>; + +pub(crate) fn reg() { + GLOBAL_TRANSPORT_REGISTRY.add_transport(TCP_IP_NETWORK_TYPE, TransportBuilder {}); +} + +struct TransportBuilder {} + +struct TonicTransport { + grpc: Grpc, + task_handle: Box, +} + +impl Drop for TonicTransport { + fn drop(&mut self) { + self.task_handle.abort(); + } +} + +#[async_trait] +impl Service for TonicTransport { + async fn call(&self, method: String, request: GrpcRequest) -> GrpcResponse { + let Ok(path) = PathAndQuery::from_maybe_shared(method) else { + let err = Status::internal("Failed to parse path"); + return create_error_response(err); + }; + let mut grpc = self.grpc.clone(); + if let Err(e) = grpc.ready().await { + let err = Status::unknown(format!("Service was not ready: {e}")); + return create_error_response(err); + }; + let request = convert_request(request); + let response = grpc.streaming(request, path, BytesCodec {}).await; + convert_response(response) + } +} + +/// Helper function to create an error response stream. 
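The converters defined below move messages across a type-erased boundary: outgoing items are boxed, recovered as `Bytes` through an `Any` downcast, and anything that fails the downcast is logged and skipped by `filter_map`. A standalone illustration of that pattern (std plus `bytes` only; the crate's `Message` type is assumed to erase to `dyn Any` the way `convert_request` treats it):

```rust
use bytes::Bytes;
use std::any::Any;

fn main() {
    // Type-erased on the way in, as the gRPC message stream carries boxed items...
    let erased: Box<dyn Any> = Box::new(Bytes::from_static(b"payload"));

    // ...and recovered on the way out; a non-`Bytes` payload is simply skipped.
    match erased.downcast::<Bytes>() {
        Ok(bytes) => assert_eq!(&bytes[..], b"payload"),
        Err(_) => eprintln!("message could not be downcast to Bytes; skipped"),
    }
}
```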
+fn create_error_response(status: Status) -> GrpcResponse { + let stream = tokio_stream::once(Err(status)); + TonicResponse::new(Box::pin(stream)) +} + +fn convert_request(req: GrpcRequest) -> TonicRequest + Send>>> { + let (metadata, extensions, stream) = req.into_parts(); + + let bytes_stream = Box::pin(stream.filter_map(|msg| { + if let Ok(bytes) = (msg as Box).downcast::() { + Some(*bytes) + } else { + // If it fails, log the error and return None to filter it out. + eprintln!("A message could not be downcast to Bytes and was skipped."); + None + } + })); + + TonicRequest::from_parts(metadata, extensions, bytes_stream as _) +} + +fn convert_response(res: Result>, Status>) -> GrpcResponse { + let response = match res { + Ok(s) => s, + Err(e) => { + let stream = tokio_stream::once(Err(e)); + return TonicResponse::new(Box::pin(stream)); + } + }; + let (metadata, stream, extensions) = response.into_parts(); + let message_stream: BoxStream> = Box::pin(stream.map(|msg| { + msg.map(|b| { + let msg: Box = Box::new(b); + msg + }) + })); + TonicResponse::from_parts(metadata, message_stream, extensions) +} + +#[async_trait] +impl Transport for TransportBuilder { + async fn connect( + &self, + address: String, + runtime: Arc, + opts: &TransportOptions, + ) -> Result { + let runtime = runtime.clone(); + let mut settings = Builder::::new(HyperCompatExec { + inner: runtime.clone(), + }) + .timer(HyperCompatTimer { + inner: runtime.clone(), + }) + .initial_stream_window_size(opts.init_stream_window_size) + .initial_connection_window_size(opts.init_connection_window_size) + .keep_alive_interval(opts.http2_keep_alive_interval) + .clone(); + + if let Some(val) = opts.http2_keep_alive_timeout { + settings.keep_alive_timeout(val); + } + + if let Some(val) = opts.http2_keep_alive_while_idle { + settings.keep_alive_while_idle(val); + } + + if let Some(val) = opts.http2_adaptive_window { + settings.adaptive_window(val); + } + + if let Some(val) = opts.http2_max_header_list_size { + settings.max_header_list_size(val); + } + + let addr: SocketAddr = SocketAddr::from_str(&address).map_err(|err| err.to_string())?; + let tcp_stream_fut = runtime.tcp_stream( + addr, + TcpOptions { + enable_nodelay: opts.tcp_nodelay, + keepalive: opts.tcp_keepalive, + }, + ); + let tcp_stream = if let Some(deadline) = opts.connect_deadline { + let timeout = deadline.saturating_duration_since(Instant::now()); + tokio::select! { + _ = runtime.sleep(timeout) => { + return Err("timed out waiting for TCP stream to connect".to_string()) + } + tcp_stream = tcp_stream_fut => { tcp_stream? } + } + } else { + tcp_stream_fut.await? 
+ }; + let tcp_stream = HyperStream::new(tcp_stream); + + let (sender, connection) = settings + .handshake(tcp_stream) + .await + .map_err(|err| err.to_string())?; + let (tx, rx) = oneshot::channel(); + + let task_handle = runtime.spawn(Box::pin(async move { + if let Err(err) = connection.await { + let _ = tx.send(Err(err.to_string())); + } else { + let _ = tx.send(Ok(())); + } + })); + let sender = SendRequestWrapper::from(sender); + + let service = ServiceBuilder::new() + .option_layer(opts.concurrency_limit.map(ConcurrencyLimitLayer::new)) + .option_layer(opts.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) + .map_err(Into::::into) + .service(sender); + + let service = BoxService::new(service); + let (service, worker) = Buffer::pair(service, DEFAULT_BUFFER_SIZE); + runtime.spawn(Box::pin(worker)); + let uri = + Uri::from_maybe_shared(format!("http://{}", &address)).map_err(|e| e.to_string())?; // TODO: err msg + let grpc = Grpc::with_origin(TonicService { inner: service }, uri); + + let service = TonicTransport { grpc, task_handle }; + Ok(ConnectedTransport { + service: Box::new(service), + disconnection_listener: rx, + }) + } +} + +struct SendRequestWrapper { + inner: SendRequest, +} + +impl From> for SendRequestWrapper { + fn from(inner: SendRequest) -> Self { + Self { inner } + } +} + +impl TowerService> for SendRequestWrapper { + type Response = HttpResponse; + type Error = BoxError; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: http::Request) -> Self::Future { + let fut = self.inner.send_request(req); + Box::pin(async move { fut.await.map_err(Into::into).map(|res| res.map(Body::new)) }) + } +} + +#[derive(Clone)] +struct TonicService { + inner: Buffer, BoxFuture<'static, Result, BoxError>>>, +} + +impl GrpcService for TonicService { + type ResponseBody = Body; + type Error = BoxError; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + tower::Service::poll_ready(&mut self.inner, cx) + } + + fn call(&mut self, request: http::Request) -> Self::Future { + ResponseFuture { + inner: tower::Service::call(&mut self.inner, request), + } + } +} + +/// A future that resolves to an HTTP response. +/// +/// This is returned by the `Service::call` on [`Channel`]. 
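The `Buffer::pair` + `runtime.spawn(Box::pin(worker))` sequence above is the concrete reason for the Cargo.toml comment repeated in this change: `Buffer::new` would spawn its worker on the ambient tokio runtime, while `pair` hands the worker back so the caller can drive it on whatever executor it owns. A minimal sketch of the same pattern in isolation (assumes tower with the `buffer` and `util` features, with tokio standing in for the runtime abstraction):

```rust
use tower::{buffer::Buffer, service_fn, Service, ServiceExt};

#[tokio::main]
async fn main() {
    // A trivial service standing in for the wrapped HTTP/2 `SendRequest`.
    let svc = service_fn(|x: u64| async move { Ok::<_, std::convert::Infallible>(x + 1) });

    // `pair` returns the worker instead of spawning it, so the caller picks the
    // executor; in the transport this goes through `runtime.spawn`.
    let (mut buffered, worker) = Buffer::pair(svc, 16);
    tokio::spawn(worker);

    let resp = buffered.ready().await.unwrap().call(41).await.unwrap();
    assert_eq!(resp, 42);
}
```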
+pub struct ResponseFuture { + inner: BufferResponseFuture, BoxError>>>, +} + +impl Future for ResponseFuture { + type Output = Result, BoxError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.inner).poll(cx) + } +} diff --git a/grpc/src/client/transport/tonic/test.rs b/grpc/src/client/transport/tonic/test.rs new file mode 100644 index 000000000..678280e34 --- /dev/null +++ b/grpc/src/client/transport/tonic/test.rs @@ -0,0 +1,165 @@ +use crate::client::name_resolution::TCP_IP_NETWORK_TYPE; +use crate::client::transport::registry::GLOBAL_TRANSPORT_REGISTRY; +use crate::echo_pb::echo_server::{Echo, EchoServer}; +use crate::echo_pb::{EchoRequest, EchoResponse}; +use crate::service::Message; +use crate::service::Request as GrpcRequest; +use crate::{client::transport::TransportOptions, rt::tokio::TokioRuntime}; +use bytes::Bytes; +use std::any::Any; +use std::{pin::Pin, sync::Arc, time::Duration}; +use tokio::net::TcpListener; +use tokio::sync::{mpsc, oneshot, Notify}; +use tokio::time::timeout; +use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt}; +use tonic::async_trait; +use tonic::{transport::Server, Request, Response, Status}; +use tonic_prost::prost::Message as ProstMessage; + +const DEFAULT_TEST_DURATION: Duration = Duration::from_secs(10); +const DEFAULT_TEST_SHORT_DURATION: Duration = Duration::from_millis(10); + +// Tests the tonic transport by creating a bi-di stream with a tonic server. +#[tokio::test] +pub async fn tonic_transport_rpc() { + super::reg(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); // get the assigned address + let shutdown_notify = Arc::new(Notify::new()); + let shutdown_notify_copy = shutdown_notify.clone(); + println!("EchoServer listening on: {addr}"); + let server_handle = tokio::spawn(async move { + let echo_server = EchoService {}; + let svc = EchoServer::new(echo_server); + let _ = Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + shutdown_notify_copy.notified(), + ) + .await; + }); + + let builder = GLOBAL_TRANSPORT_REGISTRY + .get_transport(TCP_IP_NETWORK_TYPE) + .unwrap(); + let config = Arc::new(TransportOptions::default()); + let mut connected_transport = builder + .connect(addr.to_string(), Arc::new(TokioRuntime {}), &config) + .await + .unwrap(); + let conn = connected_transport.service; + + let (tx, rx) = mpsc::channel::>(1); + + // Convert the mpsc receiver into a Stream + let outbound: GrpcRequest = + Request::new(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx))); + + let mut inbound = conn + .call( + "/grpc.examples.echo.Echo/BidirectionalStreamingEcho".to_string(), + outbound, + ) + .await + .into_inner(); + + // Spawn a sender task + let client_handle = tokio::spawn(async move { + for i in 0..5 { + let message = format!("message {i}"); + let request = EchoRequest { + message: message.clone(), + }; + + let bytes = Bytes::from(request.encode_to_vec()); + + println!("Sent request: {request:?}"); + assert!(tx.send(Box::new(bytes)).await.is_ok(), "Receiver dropped"); + + // Wait for the reply + let resp = inbound + .next() + .await + .expect("server unexpectedly closed the stream!") + .expect("server returned error"); + + let bytes = (resp as Box).downcast::().unwrap(); + let echo_response = EchoResponse::decode(bytes).unwrap(); + println!("Got response: {echo_response:?}"); + assert_eq!(echo_response.message, message); + } + }); + + 
client_handle.await.unwrap(); + // The connection should break only after the server is stopped. + assert_eq!( + connected_transport.disconnection_listener.try_recv(), + Err(oneshot::error::TryRecvError::Empty), + ); + shutdown_notify.notify_waiters(); + let res = timeout( + DEFAULT_TEST_DURATION, + connected_transport.disconnection_listener, + ) + .await + .unwrap() + .unwrap(); + assert_eq!(res, Ok(())); + server_handle.await.unwrap(); +} + +#[derive(Debug)] +pub struct EchoService {} + +#[async_trait] +impl Echo for EchoService { + async fn unary_echo( + &self, + _: tonic::Request, + ) -> std::result::Result, tonic::Status> { + unimplemented!() + } + + type ServerStreamingEchoStream = ReceiverStream>; + + async fn server_streaming_echo( + &self, + _: tonic::Request, + ) -> std::result::Result, tonic::Status> { + unimplemented!() + } + + async fn client_streaming_echo( + &self, + _: tonic::Request>, + ) -> std::result::Result, tonic::Status> { + unimplemented!() + } + type BidirectionalStreamingEchoStream = + Pin> + Send + 'static>>; + + async fn bidirectional_streaming_echo( + &self, + request: tonic::Request>, + ) -> std::result::Result, tonic::Status> + { + let mut inbound = request.into_inner(); + + // Map each request to a corresponding EchoResponse + let outbound = async_stream::try_stream! { + while let Some(req) = inbound.next().await { + let req = req?; // Return Err(Status) if stream item is error + let reply = EchoResponse { + message: req.message.clone(), + }; + yield reply; + } + println!("Server closing stream"); + }; + + Ok(Response::new( + Box::pin(outbound) as Self::BidirectionalStreamingEchoStream + )) + } +} diff --git a/grpc/src/codec.rs b/grpc/src/codec.rs new file mode 100644 index 000000000..eb9cc03e7 --- /dev/null +++ b/grpc/src/codec.rs @@ -0,0 +1,53 @@ +use bytes::{Buf, BufMut, Bytes}; +use tonic::{ + codec::{Codec, Decoder, EncodeBuf, Encoder}, + Status, +}; + +/// An adapter for sending and receiving messages as bytes using tonic. +/// Coding/decoding is handled within gRPC. +/// TODO: Remove this when tonic allows access to bytes without requiring a +/// codec. +pub(crate) struct BytesCodec {} + +impl Codec for BytesCodec { + type Encode = Bytes; + type Decode = Bytes; + type Encoder = BytesEncoder; + type Decoder = BytesDecoder; + + fn encoder(&mut self) -> Self::Encoder { + BytesEncoder {} + } + + fn decoder(&mut self) -> Self::Decoder { + BytesDecoder {} + } +} + +pub struct BytesEncoder {} + +impl Encoder for BytesEncoder { + type Item = Bytes; + type Error = Status; + + fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error> { + dst.put_slice(&item); + Ok(()) + } +} + +#[derive(Debug)] +pub struct BytesDecoder {} + +impl Decoder for BytesDecoder { + type Item = Bytes; + type Error = Status; + + fn decode( + &mut self, + src: &mut tonic::codec::DecodeBuf<'_>, + ) -> Result, Self::Error> { + Ok(Some(src.copy_to_bytes(src.remaining()))) + } +} diff --git a/grpc/src/generated/echo_fds.rs b/grpc/src/generated/echo_fds.rs new file mode 100644 index 000000000..9833d2636 --- /dev/null +++ b/grpc/src/generated/echo_fds.rs @@ -0,0 +1,61 @@ +// This file is @generated by codegen. +// +// +// Copyright 2018 gRPC authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// +/// Byte encoded FILE_DESCRIPTOR_SET. +pub const FILE_DESCRIPTOR_SET: &[u8] = &[ + 10u8, 246u8, 3u8, 10u8, 15u8, 101u8, 99u8, 104u8, 111u8, 47u8, 101u8, 99u8, 104u8, + 111u8, 46u8, 112u8, 114u8, 111u8, 116u8, 111u8, 18u8, 18u8, 103u8, 114u8, 112u8, + 99u8, 46u8, 101u8, 120u8, 97u8, 109u8, 112u8, 108u8, 101u8, 115u8, 46u8, 101u8, 99u8, + 104u8, 111u8, 34u8, 39u8, 10u8, 11u8, 69u8, 99u8, 104u8, 111u8, 82u8, 101u8, 113u8, + 117u8, 101u8, 115u8, 116u8, 18u8, 24u8, 10u8, 7u8, 109u8, 101u8, 115u8, 115u8, 97u8, + 103u8, 101u8, 24u8, 1u8, 32u8, 1u8, 40u8, 9u8, 82u8, 7u8, 109u8, 101u8, 115u8, 115u8, + 97u8, 103u8, 101u8, 34u8, 40u8, 10u8, 12u8, 69u8, 99u8, 104u8, 111u8, 82u8, 101u8, + 115u8, 112u8, 111u8, 110u8, 115u8, 101u8, 18u8, 24u8, 10u8, 7u8, 109u8, 101u8, 115u8, + 115u8, 97u8, 103u8, 101u8, 24u8, 1u8, 32u8, 1u8, 40u8, 9u8, 82u8, 7u8, 109u8, 101u8, + 115u8, 115u8, 97u8, 103u8, 101u8, 50u8, 243u8, 2u8, 10u8, 4u8, 69u8, 99u8, 104u8, + 111u8, 18u8, 78u8, 10u8, 9u8, 85u8, 110u8, 97u8, 114u8, 121u8, 69u8, 99u8, 104u8, + 111u8, 18u8, 31u8, 46u8, 103u8, 114u8, 112u8, 99u8, 46u8, 101u8, 120u8, 97u8, 109u8, + 112u8, 108u8, 101u8, 115u8, 46u8, 101u8, 99u8, 104u8, 111u8, 46u8, 69u8, 99u8, 104u8, + 111u8, 82u8, 101u8, 113u8, 117u8, 101u8, 115u8, 116u8, 26u8, 32u8, 46u8, 103u8, + 114u8, 112u8, 99u8, 46u8, 101u8, 120u8, 97u8, 109u8, 112u8, 108u8, 101u8, 115u8, + 46u8, 101u8, 99u8, 104u8, 111u8, 46u8, 69u8, 99u8, 104u8, 111u8, 82u8, 101u8, 115u8, + 112u8, 111u8, 110u8, 115u8, 101u8, 18u8, 90u8, 10u8, 19u8, 83u8, 101u8, 114u8, 118u8, + 101u8, 114u8, 83u8, 116u8, 114u8, 101u8, 97u8, 109u8, 105u8, 110u8, 103u8, 69u8, + 99u8, 104u8, 111u8, 18u8, 31u8, 46u8, 103u8, 114u8, 112u8, 99u8, 46u8, 101u8, 120u8, + 97u8, 109u8, 112u8, 108u8, 101u8, 115u8, 46u8, 101u8, 99u8, 104u8, 111u8, 46u8, 69u8, + 99u8, 104u8, 111u8, 82u8, 101u8, 113u8, 117u8, 101u8, 115u8, 116u8, 26u8, 32u8, 46u8, + 103u8, 114u8, 112u8, 99u8, 46u8, 101u8, 120u8, 97u8, 109u8, 112u8, 108u8, 101u8, + 115u8, 46u8, 101u8, 99u8, 104u8, 111u8, 46u8, 69u8, 99u8, 104u8, 111u8, 82u8, 101u8, + 115u8, 112u8, 111u8, 110u8, 115u8, 101u8, 48u8, 1u8, 18u8, 90u8, 10u8, 19u8, 67u8, + 108u8, 105u8, 101u8, 110u8, 116u8, 83u8, 116u8, 114u8, 101u8, 97u8, 109u8, 105u8, + 110u8, 103u8, 69u8, 99u8, 104u8, 111u8, 18u8, 31u8, 46u8, 103u8, 114u8, 112u8, 99u8, + 46u8, 101u8, 120u8, 97u8, 109u8, 112u8, 108u8, 101u8, 115u8, 46u8, 101u8, 99u8, + 104u8, 111u8, 46u8, 69u8, 99u8, 104u8, 111u8, 82u8, 101u8, 113u8, 117u8, 101u8, + 115u8, 116u8, 26u8, 32u8, 46u8, 103u8, 114u8, 112u8, 99u8, 46u8, 101u8, 120u8, 97u8, + 109u8, 112u8, 108u8, 101u8, 115u8, 46u8, 101u8, 99u8, 104u8, 111u8, 46u8, 69u8, 99u8, + 104u8, 111u8, 82u8, 101u8, 115u8, 112u8, 111u8, 110u8, 115u8, 101u8, 40u8, 1u8, 18u8, + 99u8, 10u8, 26u8, 66u8, 105u8, 100u8, 105u8, 114u8, 101u8, 99u8, 116u8, 105u8, 111u8, + 110u8, 97u8, 108u8, 83u8, 116u8, 114u8, 101u8, 97u8, 109u8, 105u8, 110u8, 103u8, + 69u8, 99u8, 104u8, 111u8, 18u8, 31u8, 46u8, 103u8, 114u8, 112u8, 99u8, 46u8, 101u8, + 120u8, 97u8, 109u8, 112u8, 108u8, 101u8, 115u8, 46u8, 101u8, 99u8, 104u8, 111u8, + 46u8, 69u8, 99u8, 104u8, 111u8, 82u8, 
101u8, 113u8, 117u8, 101u8, 115u8, 116u8, 26u8, + 32u8, 46u8, 103u8, 114u8, 112u8, 99u8, 46u8, 101u8, 120u8, 97u8, 109u8, 112u8, 108u8, + 101u8, 115u8, 46u8, 101u8, 99u8, 104u8, 111u8, 46u8, 69u8, 99u8, 104u8, 111u8, 82u8, + 101u8, 115u8, 112u8, 111u8, 110u8, 115u8, 101u8, 40u8, 1u8, 48u8, 1u8, 98u8, 6u8, + 112u8, 114u8, 111u8, 116u8, 111u8, 51u8, +]; diff --git a/grpc/src/generated/grpc_examples_echo.rs b/grpc/src/generated/grpc_examples_echo.rs new file mode 100644 index 000000000..5545928b0 --- /dev/null +++ b/grpc/src/generated/grpc_examples_echo.rs @@ -0,0 +1,547 @@ +// This file is @generated by prost-build. +/// EchoRequest is the request for echo. +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct EchoRequest { + #[prost(string, tag = "1")] + pub message: ::prost::alloc::string::String, +} +/// EchoResponse is the response for echo. +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct EchoResponse { + #[prost(string, tag = "1")] + pub message: ::prost::alloc::string::String, +} +/// Generated client implementations. +pub mod echo_client { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + clippy::let_unit_value, + )] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + /// Echo is the echo service. + #[derive(Debug, Clone)] + pub struct EchoClient { + inner: tonic::client::Grpc, + } + impl EchoClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + std::marker::Send + 'static, + ::Error: Into + std::marker::Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> EchoClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + std::marker::Send + std::marker::Sync, + { + EchoClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + /// UnaryEcho is unary echo. 
+ pub async fn unary_echo( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic_prost::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/grpc.examples.echo.Echo/UnaryEcho", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("grpc.examples.echo.Echo", "UnaryEcho")); + self.inner.unary(req, path, codec).await + } + /// ServerStreamingEcho is server side streaming. + pub async fn server_streaming_echo( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic_prost::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/grpc.examples.echo.Echo/ServerStreamingEcho", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("grpc.examples.echo.Echo", "ServerStreamingEcho"), + ); + self.inner.server_streaming(req, path, codec).await + } + /// ClientStreamingEcho is client side streaming. + pub async fn client_streaming_echo( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> std::result::Result, tonic::Status> { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic_prost::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/grpc.examples.echo.Echo/ClientStreamingEcho", + ); + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert( + GrpcMethod::new("grpc.examples.echo.Echo", "ClientStreamingEcho"), + ); + self.inner.client_streaming(req, path, codec).await + } + /// BidirectionalStreamingEcho is bidi streaming. + pub async fn bidirectional_streaming_echo( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::unknown( + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic_prost::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/grpc.examples.echo.Echo/BidirectionalStreamingEcho", + ); + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert( + GrpcMethod::new( + "grpc.examples.echo.Echo", + "BidirectionalStreamingEcho", + ), + ); + self.inner.streaming(req, path, codec).await + } + } +} +/// Generated server implementations. +pub mod echo_server { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + clippy::let_unit_value, + )] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with EchoServer. + #[async_trait] + pub trait Echo: std::marker::Send + std::marker::Sync + 'static { + /// UnaryEcho is unary echo. + async fn unary_echo( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the ServerStreamingEcho method. + type ServerStreamingEchoStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, + > + + std::marker::Send + + 'static; + /// ServerStreamingEcho is server side streaming. 
+ async fn server_streaming_echo( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + /// ClientStreamingEcho is client side streaming. + async fn client_streaming_echo( + &self, + request: tonic::Request>, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the BidirectionalStreamingEcho method. + type BidirectionalStreamingEchoStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, + > + + std::marker::Send + + 'static; + /// BidirectionalStreamingEcho is bidi streaming. + async fn bidirectional_streaming_echo( + &self, + request: tonic::Request>, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + } + /// Echo is the echo service. + #[derive(Debug)] + pub struct EchoServer { + inner: Arc, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + } + impl EchoServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, + } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. 
+ /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + } + impl tonic::codegen::Service> for EchoServer + where + T: Echo, + B: Body + std::marker::Send + 'static, + B::Error: Into + std::marker::Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + match req.uri().path() { + "/grpc.examples.echo.Echo/UnaryEcho" => { + #[allow(non_camel_case_types)] + struct UnaryEchoSvc(pub Arc); + impl tonic::server::UnaryService + for UnaryEchoSvc { + type Response = super::EchoResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::unary_echo(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = UnaryEchoSvc(inner); + let codec = tonic_prost::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/grpc.examples.echo.Echo/ServerStreamingEcho" => { + #[allow(non_camel_case_types)] + struct ServerStreamingEchoSvc(pub Arc); + impl< + T: Echo, + > tonic::server::ServerStreamingService + for ServerStreamingEchoSvc { + type Response = super::EchoResponse; + type ResponseStream = T::ServerStreamingEchoStream; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::server_streaming_echo(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = ServerStreamingEchoSvc(inner); + let codec = tonic_prost::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/grpc.examples.echo.Echo/ClientStreamingEcho" => { + #[allow(non_camel_case_types)] + struct ClientStreamingEchoSvc(pub Arc); + impl< + T: Echo, + > tonic::server::ClientStreamingService + for ClientStreamingEchoSvc { + type Response = super::EchoResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: 
+                        ) -> Self::Future {
+                            let inner = Arc::clone(&self.0);
+                            let fut = async move {
+                                <T as Echo>::client_streaming_echo(&inner, request).await
+                            };
+                            Box::pin(fut)
+                        }
+                    }
+                    let accept_compression_encodings = self.accept_compression_encodings;
+                    let send_compression_encodings = self.send_compression_encodings;
+                    let max_decoding_message_size = self.max_decoding_message_size;
+                    let max_encoding_message_size = self.max_encoding_message_size;
+                    let inner = self.inner.clone();
+                    let fut = async move {
+                        let method = ClientStreamingEchoSvc(inner);
+                        let codec = tonic_prost::ProstCodec::default();
+                        let mut grpc = tonic::server::Grpc::new(codec)
+                            .apply_compression_config(
+                                accept_compression_encodings,
+                                send_compression_encodings,
+                            )
+                            .apply_max_message_size_config(
+                                max_decoding_message_size,
+                                max_encoding_message_size,
+                            );
+                        let res = grpc.client_streaming(method, req).await;
+                        Ok(res)
+                    };
+                    Box::pin(fut)
+                }
+                "/grpc.examples.echo.Echo/BidirectionalStreamingEcho" => {
+                    #[allow(non_camel_case_types)]
+                    struct BidirectionalStreamingEchoSvc<T: Echo>(pub Arc<T>);
+                    impl<T: Echo> tonic::server::StreamingService<super::EchoRequest>
+                    for BidirectionalStreamingEchoSvc<T> {
+                        type Response = super::EchoResponse;
+                        type ResponseStream = T::BidirectionalStreamingEchoStream;
+                        type Future = BoxFuture<
+                            tonic::Response<Self::ResponseStream>,
+                            tonic::Status,
+                        >;
+                        fn call(
+                            &mut self,
+                            request: tonic::Request<tonic::Streaming<super::EchoRequest>>,
+                        ) -> Self::Future {
+                            let inner = Arc::clone(&self.0);
+                            let fut = async move {
+                                <T as Echo>::bidirectional_streaming_echo(&inner, request)
+                                    .await
+                            };
+                            Box::pin(fut)
+                        }
+                    }
+                    let accept_compression_encodings = self.accept_compression_encodings;
+                    let send_compression_encodings = self.send_compression_encodings;
+                    let max_decoding_message_size = self.max_decoding_message_size;
+                    let max_encoding_message_size = self.max_encoding_message_size;
+                    let inner = self.inner.clone();
+                    let fut = async move {
+                        let method = BidirectionalStreamingEchoSvc(inner);
+                        let codec = tonic_prost::ProstCodec::default();
+                        let mut grpc = tonic::server::Grpc::new(codec)
+                            .apply_compression_config(
+                                accept_compression_encodings,
+                                send_compression_encodings,
+                            )
+                            .apply_max_message_size_config(
+                                max_decoding_message_size,
+                                max_encoding_message_size,
+                            );
+                        let res = grpc.streaming(method, req).await;
+                        Ok(res)
+                    };
+                    Box::pin(fut)
+                }
+                _ => {
+                    Box::pin(async move {
+                        let mut response = http::Response::new(
+                            tonic::body::Body::default(),
+                        );
+                        let headers = response.headers_mut();
+                        headers
+                            .insert(
+                                tonic::Status::GRPC_STATUS,
+                                (tonic::Code::Unimplemented as i32).into(),
+                            );
+                        headers
+                            .insert(
+                                http::header::CONTENT_TYPE,
+                                tonic::metadata::GRPC_CONTENT_TYPE,
+                            );
+                        Ok(response)
+                    })
+                }
+            }
+        }
+    }
+    impl<T> Clone for EchoServer<T> {
+        fn clone(&self) -> Self {
+            let inner = self.inner.clone();
+            Self {
+                inner,
+                accept_compression_encodings: self.accept_compression_encodings,
+                send_compression_encodings: self.send_compression_encodings,
+                max_decoding_message_size: self.max_decoding_message_size,
+                max_encoding_message_size: self.max_encoding_message_size,
+            }
+        }
+    }
+    /// Generated gRPC service name
+    pub const SERVICE_NAME: &str = "grpc.examples.echo.Echo";
+    impl<T> tonic::server::NamedService for EchoServer<T> {
+        const NAME: &'static str = SERVICE_NAME;
+    }
+}
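// Illustrative sketch (not part of this change): the builder methods generated
// above in use. The helper is hypothetical and left generic over any handler
// `T: Echo`; the 2 MiB limits are arbitrary example values, not defaults taken
// from this diff.
fn configured_echo_server<T: Echo>(handler: T) -> EchoServer<T> {
    EchoServer::new(handler)
        .max_decoding_message_size(2 * 1024 * 1024)
        .max_encoding_message_size(2 * 1024 * 1024)
}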
diff --git a/grpc/src/inmemory/mod.rs b/grpc/src/inmemory/mod.rs
index 48e97ceb5..b9dae99e0 100644
--- a/grpc/src/inmemory/mod.rs
+++ b/grpc/src/inmemory/mod.rs
@@ -1,11 +1,6 @@
-use std::{
-    collections::HashMap,
-    ops::Add,
-    sync::{
-        atomic::{AtomicU32, Ordering},
-        Arc, LazyLock,
-    },
-};
+use std::sync::atomic::{AtomicU32, Ordering};
+use std::sync::{Arc, LazyLock, Mutex};
+use std::{collections::HashMap, ops::Add};
 
 use crate::{
     client::{
@@ -13,20 +8,22 @@ use crate::{
             self, global_registry, Address, ChannelController, Endpoint, Resolver,
             ResolverBuilder, ResolverOptions, ResolverUpdate,
         },
-        transport::{self, ConnectedTransport, GLOBAL_TRANSPORT_REGISTRY},
+        transport::{self, ConnectedTransport, TransportOptions, GLOBAL_TRANSPORT_REGISTRY},
+    },
+    rt::Runtime,
     server,
     service::{Request, Response, Service},
 };
-use tokio::sync::{mpsc, oneshot, Mutex, Notify};
+use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex, Notify};
 use tonic::async_trait;
 
 pub struct Listener {
     id: String,
     s: Box>>,
-    r: Arc>>>,
+    r: Arc>>>,
     // List of notifiers to call when closed.
-    closed: Notify,
+    #[allow(clippy::type_complexity)]
+    closed_tx: Arc>>>>,
 }
 
 static ID: AtomicU32 = AtomicU32::new(0);
@@ -37,8 +34,8 @@ impl Listener {
         let s = Arc::new(Self {
             id: format!("{}", ID.fetch_add(1, Ordering::Relaxed)),
             s: Box::new(tx),
-            r: Arc::new(Mutex::new(rx)),
-            closed: Notify::new(),
+            r: Arc::new(AsyncMutex::new(rx)),
+            closed_tx: Arc::new(Mutex::new(Vec::new())),
         });
         LISTENERS.lock().unwrap().insert(s.id.clone(), s.clone());
         s
@@ -59,7 +56,10 @@ impl Listener {
 
 impl Drop for Listener {
     fn drop(&mut self) {
-        self.closed.notify_waiters();
+        let txs = std::mem::take(&mut *self.closed_tx.lock().unwrap());
+        for rx in txs {
+            let _ = rx.send(Ok(()));
+        }
         LISTENERS.lock().unwrap().remove(&self.id);
     }
 }
@@ -75,25 +75,17 @@ impl Service for Arc<Listener> {
     }
 }
 
-#[async_trait]
-impl ConnectedTransport for Arc<Listener> {
-    async fn disconnected(&self) {
-        self.closed.notified().await;
-    }
-}
-
 #[async_trait]
 impl crate::server::Listener for Arc<Listener> {
     async fn accept(&self) -> Option {
         let mut recv = self.r.lock().await;
         let r = recv.recv().await;
-        r.as_ref()?;
-        r.unwrap()
+        // Listener may be closed.
+        r?
     }
 }
 
-static LISTENERS: LazyLock<std::sync::Mutex<HashMap<String, Arc<Listener>>>> =
-    LazyLock::new(std::sync::Mutex::default);
+static LISTENERS: LazyLock<Mutex<HashMap<String, Arc<Listener>>>> = LazyLock::new(Mutex::default);
 
 struct ClientTransport {}
 
@@ -105,14 +97,24 @@ impl ClientTransport {
 
 #[async_trait]
 impl transport::Transport for ClientTransport {
-    async fn connect(&self, address: String) -> Result<Box<dyn ConnectedTransport>, String> {
+    async fn connect(
+        &self,
+        address: String,
+        _: Arc<dyn Runtime>,
+        _: &TransportOptions,
+    ) -> Result<ConnectedTransport, String> {
         let lis = LISTENERS
             .lock()
            .unwrap()
            .get(&address)
            .ok_or(format!("Could not find listener for address {address}"))?
            .clone();
-        Ok(Box::new(lis))
+        let (tx, rx) = oneshot::channel();
+        lis.closed_tx.lock().unwrap().push(tx);
+        Ok(ConnectedTransport {
+            service: Box::new(lis),
+            disconnection_listener: rx,
+        })
     }
 }
diff --git a/grpc/src/lib.rs b/grpc/src/lib.rs
index f56fd2cab..512adbc8f 100644
--- a/grpc/src/lib.rs
+++ b/grpc/src/lib.rs
@@ -41,3 +41,11 @@ pub mod service;
 
 pub(crate) mod attributes;
 pub(crate) mod byte_str;
+pub(crate) mod codec;
+#[cfg(test)]
+pub(crate) mod echo_pb {
+    include!(concat!(
+        env!("CARGO_MANIFEST_DIR"),
+        "/src/generated/grpc_examples_echo.rs"
+    ));
+}
diff --git a/grpc/src/rt/hyper_wrapper.rs b/grpc/src/rt/hyper_wrapper.rs
new file mode 100644
index 000000000..6bdaad48f
--- /dev/null
+++ b/grpc/src/rt/hyper_wrapper.rs
@@ -0,0 +1,158 @@
+use super::{Runtime, TcpStream};
+use hyper::rt::{Executor, Timer};
+use pin_project_lite::pin_project;
+use std::task::{Context, Poll};
+use std::{future::Future, io, pin::Pin, sync::Arc, time::Instant};
+use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
+
+/// Adapts a runtime to a hyper compatible executor.
+#[derive(Clone)]
+pub(crate) struct HyperCompatExec {
+    pub(crate) inner: Arc<dyn Runtime>,
+}
+
+impl<F> Executor<F> for HyperCompatExec
+where
+    F: Future + Send + 'static,
+    F::Output: Send + 'static,
+{
+    fn execute(&self, fut: F) {
+        self.inner.spawn(Box::pin(async {
+            let _ = fut.await;
+        }));
+    }
+}
+
+struct HyperCompatSleep {
+    inner: Pin<Box<dyn super::Sleep>>,
+}
+
+impl Future for HyperCompatSleep {
+    type Output = ();
+
+    fn poll(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Self::Output> {
+        self.inner.as_mut().poll(cx)
+    }
+}
+
+impl hyper::rt::Sleep for HyperCompatSleep {}
+
+/// Adapts a runtime to a hyper compatible timer.
+pub(crate) struct HyperCompatTimer {
+    pub(crate) inner: Arc<dyn Runtime>,
+}
+
+impl Timer for HyperCompatTimer {
+    fn sleep(&self, duration: std::time::Duration) -> Pin<Box<dyn hyper::rt::Sleep>> {
+        let sleep = self.inner.sleep(duration);
+        Box::pin(HyperCompatSleep { inner: sleep })
+    }
+
+    fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn hyper::rt::Sleep>> {
+        let now = Instant::now();
+        let duration = deadline.saturating_duration_since(now);
+        self.sleep(duration)
+    }
+}
+
+// The following adapters are copied from hyper:
+// https://github.com/hyperium/hyper/blob/v1.6.0/benches/support/tokiort.rs
+
+pin_project! {
+    /// A wrapper to make any `TcpStream` compatible with Hyper. It implements
+    /// Tokio's async IO traits.
+    pub(crate) struct HyperStream {
+        #[pin]
+        inner: Box<dyn TcpStream>,
+    }
+}
+
+impl HyperStream {
+    /// Creates a new `HyperStream` from a type implementing `TcpStream`.
+    pub fn new(stream: Box<dyn TcpStream>) -> Self {
+        Self { inner: stream }
+    }
+}
+
+impl AsyncRead for HyperStream {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        // Delegate the poll_read call to the inner stream.
+        Pin::new(&mut self.inner).poll_read(cx, buf)
+    }
+}
+
+impl AsyncWrite for HyperStream {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, io::Error>> {
+        Pin::new(&mut self.inner).poll_write(cx, buf)
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+        Pin::new(&mut self.inner).poll_flush(cx)
+    }
+
+    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+        Pin::new(&mut self.inner).poll_shutdown(cx)
+    }
+}
+
+impl hyper::rt::Read for HyperStream {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        mut buf: hyper::rt::ReadBufCursor<'_>,
+    ) -> Poll<Result<(), io::Error>> {
+        let n = unsafe {
+            let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
+            match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
+                Poll::Ready(Ok(())) => tbuf.filled().len(),
+                other => return other,
+            }
+        };
+
+        unsafe {
+            buf.advance(n);
+        }
+        Poll::Ready(Ok(()))
+    }
+}
+
+impl hyper::rt::Write for HyperStream {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<Result<usize, io::Error>> {
+        AsyncWrite::poll_write(self.project().inner, cx, buf)
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+        AsyncWrite::poll_flush(self.project().inner, cx)
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
+        AsyncWrite::poll_shutdown(self.project().inner, cx)
+    }
+
+    fn is_write_vectored(&self) -> bool {
+        AsyncWrite::is_write_vectored(&self.inner)
+    }
+
+    fn poll_write_vectored(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        bufs: &[io::IoSlice<'_>],
+    ) -> Poll<Result<usize, io::Error>> {
+        AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
+    }
+}
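// Illustrative sketch (not part of this change): wiring the adapters above into
// hyper's HTTP/2 client connection builder. `rt` is assumed to be the crate's
// Arc<dyn Runtime>; the builder calls are standard hyper 1.x APIs.
fn http2_client_builder(
    rt: std::sync::Arc<dyn Runtime>,
) -> hyper::client::conn::http2::Builder<HyperCompatExec> {
    let exec = HyperCompatExec { inner: rt.clone() };
    let timer = HyperCompatTimer { inner: rt };
    let mut builder = hyper::client::conn::http2::Builder::new(exec);
    builder.timer(timer);
    // The returned builder's `handshake()` can then drive an HTTP/2 connection
    // over a `HyperStream`-wrapped `Box<dyn TcpStream>`.
    builder
}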
diff --git a/grpc/src/rt/mod.rs b/grpc/src/rt/mod.rs
index 78accb53f..9dff135e1 100644
--- a/grpc/src/rt/mod.rs
+++ b/grpc/src/rt/mod.rs
@@ -23,10 +23,13 @@
  */
 
 use ::tokio::io::{AsyncRead, AsyncWrite};
+use std::{future::Future, net::SocketAddr, pin::Pin, sync::Arc, time::Duration};
 
-use std::{future::Future, net::SocketAddr, pin::Pin, time::Duration};
+pub(crate) mod hyper_wrapper;
+#[cfg(feature = "_runtime-tokio")]
+pub(crate) mod tokio;
 
-pub mod tokio;
+type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
 
 /// An abstraction over an asynchronous runtime.
 ///
@@ -49,6 +52,14 @@
 
     /// Returns a future that completes after the specified duration.
     fn sleep(&self, duration: std::time::Duration) -> Pin<Box<dyn Sleep>>;
+
+    /// Establishes a TCP connection to the given `target` address with the
+    /// specified `opts`.
+    fn tcp_stream(
+        &self,
+        target: SocketAddr,
+        opts: TcpOptions,
+    ) -> BoxFuture<Result<Box<dyn TcpStream>, String>>;
 }
 
 /// A future that resolves after a specified duration.
@@ -77,7 +88,51 @@ pub(super) struct ResolverOptions {
 }
 
 #[derive(Default)]
-pub struct TcpOptions {
-    pub enable_nodelay: bool,
-    pub keepalive: Option<Duration>,
+pub(crate) struct TcpOptions {
+    pub(crate) enable_nodelay: bool,
+    pub(crate) keepalive: Option<Duration>,
+}
+
+pub(crate) trait TcpStream: AsyncRead + AsyncWrite + Send + Unpin {}
+
+/// A fake runtime to satisfy the compiler when no runtime is enabled. This will
+///
+/// # Panics
+///
+/// Panics if any of its functions are called.
+#[derive(Default)]
+pub(crate) struct NoOpRuntime {}
+
+impl Runtime for NoOpRuntime {
+    fn spawn(
+        &self,
+        task: Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
+    ) -> Box<dyn TaskHandle> {
+        unimplemented!()
+    }
+
+    fn get_dns_resolver(&self, opts: ResolverOptions) -> Result<Box<dyn DnsResolver>, String> {
+        unimplemented!()
+    }
+
+    fn sleep(&self, duration: std::time::Duration) -> Pin<Box<dyn Sleep>> {
+        unimplemented!()
+    }
+
+    fn tcp_stream(
+        &self,
+        target: SocketAddr,
+        opts: TcpOptions,
+    ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TcpStream>, String>> + Send>> {
+        unimplemented!()
+    }
+}
+
+pub(crate) fn default_runtime() -> Arc<dyn Runtime> {
+    #[cfg(feature = "_runtime-tokio")]
+    {
+        return Arc::new(tokio::TokioRuntime {});
+    }
+    #[allow(unreachable_code)]
+    Arc::new(NoOpRuntime::default())
+}
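// Illustrative sketch (not part of this change): dialing through the runtime
// abstraction instead of tokio directly. `default_runtime`, `TcpOptions`,
// `TcpStream`, and `tcp_stream` come from the hunks above; the helper itself
// and the 30-second keepalive are made-up example values.
async fn dial(target: std::net::SocketAddr) -> Result<Box<dyn TcpStream>, String> {
    let rt = default_runtime();
    let opts = TcpOptions {
        enable_nodelay: true,
        keepalive: Some(std::time::Duration::from_secs(30)),
    };
    rt.tcp_stream(target, opts).await
}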
diff --git a/grpc/src/rt/tokio/mod.rs b/grpc/src/rt/tokio/mod.rs
index b0a66ae39..faa7ebbb7 100644
--- a/grpc/src/rt/tokio/mod.rs
+++ b/grpc/src/rt/tokio/mod.rs
@@ -31,6 +31,7 @@ use std::{
 
 use tokio::{
     io::{AsyncRead, AsyncWrite},
+    net::TcpStream,
     task::JoinHandle,
 };
@@ -95,6 +96,28 @@ impl Runtime for TokioRuntime {
     fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> {
         Box::pin(tokio::time::sleep(duration))
     }
+
+    fn tcp_stream(
+        &self,
+        target: SocketAddr,
+        opts: super::TcpOptions,
+    ) -> Pin<Box<dyn Future<Output = Result<Box<dyn super::TcpStream>, String>> + Send>> {
+        Box::pin(async move {
+            let stream = TcpStream::connect(target)
+                .await
+                .map_err(|err| err.to_string())?;
+            if let Some(duration) = opts.keepalive {
+                let sock_ref = socket2::SockRef::from(&stream);
+                let mut ka = socket2::TcpKeepalive::new();
+                ka = ka.with_time(duration);
+                sock_ref
+                    .set_tcp_keepalive(&ka)
+                    .map_err(|err| err.to_string())?;
+            }
+            let stream: Box<dyn super::TcpStream> = Box::new(TokioTcpStream { inner: stream });
+            Ok(stream)
+        })
+    }
 }
 
 impl TokioDefaultDnsResolver {
@@ -106,6 +129,46 @@ impl TokioDefaultDnsResolver {
     }
 }
 
+struct TokioTcpStream {
+    inner: TcpStream,
+}
+
+impl AsyncRead for TokioTcpStream {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        Pin::new(&mut self.inner).poll_read(cx, buf)
+    }
+}
+
+impl AsyncWrite for TokioTcpStream {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> std::task::Poll<Result<usize, std::io::Error>> {
+        Pin::new(&mut self.inner).poll_write(cx, buf)
+    }
+
+    fn poll_flush(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), std::io::Error>> {
+        Pin::new(&mut self.inner).poll_flush(cx)
+    }
+
+    fn poll_shutdown(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), std::io::Error>> {
+        Pin::new(&mut self.inner).poll_shutdown(cx)
+    }
+}
+
+impl super::TcpStream for TokioTcpStream {}
+
 #[cfg(test)]
 mod tests {
     use super::{DnsResolver, ResolverOptions, Runtime, TokioDefaultDnsResolver, TokioRuntime};
diff --git a/grpc/src/service.rs b/grpc/src/service.rs
index b16f9c9b7..0f9a79901 100644
--- a/grpc/src/service.rs
+++ b/grpc/src/service.rs
@@ -29,7 +29,7 @@ use tonic::{async_trait, Request as TonicRequest, Response as TonicResponse, Sta
 
 pub type Request = TonicRequest<Pin<Box<dyn Stream<Item = Box<dyn Message>> + Send + Sync>>>;
 pub type Response =
-    TonicResponse<Pin<Box<dyn Stream<Item = Result<Box<dyn Message>, Status>> + Send + Sync>>>;
+    TonicResponse<Pin<Box<dyn Stream<Item = Result<Box<dyn Message>, Status>> + Send>>>;
 
 #[async_trait]
 pub trait Service: Send + Sync {
@@ -38,3 +38,5 @@ pub trait Service: Send + Sync {
 
 // TODO: define methods that will allow serialization/deserialization.
 pub trait Message: Any + Send + Sync {}
+
+impl<T> Message for T where T: Any + Send + Sync {}
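// Illustrative sketch (not part of this change): with the blanket impl above,
// any `Any + Send + Sync` type now satisfies `Message` without a manual
// `impl Message for ...` line. `EchoPayload` is a hypothetical type.
#[derive(Debug)]
struct EchoPayload(String);

fn assert_is_message<T: Message>() {}

fn example() {
    assert_is_message::<EchoPayload>();
}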