diff --git a/README.md b/README.md index b2f238cf51..0e04b6bab5 100644 --- a/README.md +++ b/README.md @@ -97,18 +97,16 @@ let router = Router::builder(endpoint) struct Echo; impl ProtocolHandler for Echo { - fn accept(&self, connection: Connection) -> BoxedFuture> { - Box::pin(async move { - let (mut send, mut recv) = connection.accept_bi().await?; + async fn accept(&self, connection: Connection) -> Result<()> { + let (mut send, mut recv) = connection.accept_bi().await?; - // Echo any bytes received back directly. - let bytes_sent = tokio::io::copy(&mut recv, &mut send).await?; + // Echo any bytes received back directly. + let bytes_sent = tokio::io::copy(&mut recv, &mut send).await?; - send.finish()?; - connection.closed().await; + send.finish()?; + connection.closed().await; - Ok(()) - }) + Ok(()) } } ``` diff --git a/iroh/examples/echo.rs b/iroh/examples/echo.rs index 3e863005f7..e02ee825c3 100644 --- a/iroh/examples/echo.rs +++ b/iroh/examples/echo.rs @@ -13,7 +13,6 @@ use iroh::{ watcher::Watcher as _, Endpoint, NodeAddr, }; -use n0_future::boxed::BoxFuture; /// Each protocol is identified by its ALPN string. /// @@ -84,31 +83,28 @@ impl ProtocolHandler for Echo { /// /// The returned future runs on a newly spawned tokio task, so it can run as long as /// the connection lasts. - fn accept(&self, connection: Connection) -> BoxFuture> { - // We have to return a boxed future from the handler. - Box::pin(async move { - // We can get the remote's node id from the connection. - let node_id = connection.remote_node_id()?; - println!("accepted connection from {node_id}"); - - // Our protocol is a simple request-response protocol, so we expect the - // connecting peer to open a single bi-directional stream. - let (mut send, mut recv) = connection.accept_bi().await?; - - // Echo any bytes received back directly. - // This will keep copying until the sender signals the end of data on the stream. - let bytes_sent = tokio::io::copy(&mut recv, &mut send).await?; - println!("Copied over {bytes_sent} byte(s)"); - - // By calling `finish` on the send stream we signal that we will not send anything - // further, which makes the receive stream on the other end terminate. - send.finish()?; - - // Wait until the remote closes the connection, which it does once it - // received the response. - connection.closed().await; - - Ok(()) - }) + async fn accept(&self, connection: Connection) -> Result<()> { + // We can get the remote's node id from the connection. + let node_id = connection.remote_node_id()?; + println!("accepted connection from {node_id}"); + + // Our protocol is a simple request-response protocol, so we expect the + // connecting peer to open a single bi-directional stream. + let (mut send, mut recv) = connection.accept_bi().await?; + + // Echo any bytes received back directly. + // This will keep copying until the sender signals the end of data on the stream. + let bytes_sent = tokio::io::copy(&mut recv, &mut send).await?; + println!("Copied over {bytes_sent} byte(s)"); + + // By calling `finish` on the send stream we signal that we will not send anything + // further, which makes the receive stream on the other end terminate. + send.finish()?; + + // Wait until the remote closes the connection, which it does once it + // received the response. + connection.closed().await; + + Ok(()) } } diff --git a/iroh/examples/search.rs b/iroh/examples/search.rs index f2120ea4e8..e007718dff 100644 --- a/iroh/examples/search.rs +++ b/iroh/examples/search.rs @@ -38,7 +38,6 @@ use iroh::{ protocol::{ProtocolHandler, Router}, Endpoint, NodeId, }; -use n0_future::boxed::BoxFuture; use tokio::sync::Mutex; use tracing_subscriber::{prelude::*, EnvFilter}; @@ -127,40 +126,36 @@ impl ProtocolHandler for BlobSearch { /// /// The returned future runs on a newly spawned tokio task, so it can run as long as /// the connection lasts. - fn accept(&self, connection: Connection) -> BoxFuture> { - let this = self.clone(); - // We have to return a boxed future from the handler. - Box::pin(async move { - // We can get the remote's node id from the connection. - let node_id = connection.remote_node_id()?; - println!("accepted connection from {node_id}"); - - // Our protocol is a simple request-response protocol, so we expect the - // connecting peer to open a single bi-directional stream. - let (mut send, mut recv) = connection.accept_bi().await?; - - // We read the query from the receive stream, while enforcing a max query length. - let query_bytes = recv.read_to_end(64).await?; - - // Now, we can perform the actual query on our local database. - let query = String::from_utf8(query_bytes)?; - let num_matches = this.query_local(&query).await; - - // We want to return a list of hashes. We do the simplest thing possible, and just send - // one hash after the other. Because the hashes have a fixed size of 32 bytes, this is - // very easy to parse on the other end. - send.write_all(&num_matches.to_le_bytes()).await?; - - // By calling `finish` on the send stream we signal that we will not send anything - // further, which makes the receive stream on the other end terminate. - send.finish()?; - - // Wait until the remote closes the connection, which it does once it - // received the response. - connection.closed().await; - - Ok(()) - }) + async fn accept(&self, connection: Connection) -> Result<()> { + // We can get the remote's node id from the connection. + let node_id = connection.remote_node_id()?; + println!("accepted connection from {node_id}"); + + // Our protocol is a simple request-response protocol, so we expect the + // connecting peer to open a single bi-directional stream. + let (mut send, mut recv) = connection.accept_bi().await?; + + // We read the query from the receive stream, while enforcing a max query length. + let query_bytes = recv.read_to_end(64).await?; + + // Now, we can perform the actual query on our local database. + let query = String::from_utf8(query_bytes)?; + let num_matches = self.query_local(&query).await; + + // We want to return a list of hashes. We do the simplest thing possible, and just send + // one hash after the other. Because the hashes have a fixed size of 32 bytes, this is + // very easy to parse on the other end. + send.write_all(&num_matches.to_le_bytes()).await?; + + // By calling `finish` on the send stream we signal that we will not send anything + // further, which makes the receive stream on the other end terminate. + send.finish()?; + + // Wait until the remote closes the connection, which it does once it + // received the response. + connection.closed().await; + + Ok(()) } } diff --git a/iroh/src/endpoint.rs b/iroh/src/endpoint.rs index ba948f1fad..8df217d86e 100644 --- a/iroh/src/endpoint.rs +++ b/iroh/src/endpoint.rs @@ -1991,7 +1991,7 @@ impl Connection { /// [`Connecting::handshake_data()`] succeeds. See that method's documentations for /// details on the returned value. /// - /// [`Connection::handshake_data()`]: crate::Connecting::handshake_data + /// [`Connection::handshake_data()`]: crate::endpoint::Connecting::handshake_data #[inline] pub fn handshake_data(&self) -> Option> { self.inner.handshake_data() diff --git a/iroh/src/protocol.rs b/iroh/src/protocol.rs index 4732ad004d..56abe7e29c 100644 --- a/iroh/src/protocol.rs +++ b/iroh/src/protocol.rs @@ -4,7 +4,6 @@ //! //! ```no_run //! # use anyhow::Result; -//! # use futures_lite::future::Boxed as BoxedFuture; //! # use iroh::{endpoint::Connection, protocol::{ProtocolHandler, Router}, Endpoint, NodeAddr}; //! # //! # async fn test_compile() -> Result<()> { @@ -21,27 +20,24 @@ //! struct Echo; //! //! impl ProtocolHandler for Echo { -//! fn accept(&self, connection: Connection) -> BoxedFuture> { -//! Box::pin(async move { -//! let (mut send, mut recv) = connection.accept_bi().await?; +//! async fn accept(&self, connection: Connection) -> Result<()> { +//! let (mut send, mut recv) = connection.accept_bi().await?; //! -//! // Echo any bytes received back directly. -//! let bytes_sent = tokio::io::copy(&mut recv, &mut send).await?; +//! // Echo any bytes received back directly. +//! let bytes_sent = tokio::io::copy(&mut recv, &mut send).await?; //! -//! send.finish()?; -//! connection.closed().await; +//! send.finish()?; +//! connection.closed().await; //! -//! Ok(()) -//! }) +//! Ok(()) //! } //! } //! ``` -use std::{collections::BTreeMap, sync::Arc}; +use std::{collections::BTreeMap, future::Future, pin::Pin, sync::Arc}; use anyhow::Result; use iroh_base::NodeId; use n0_future::{ - boxed::BoxFuture, join_all, task::{self, AbortOnDropHandle, JoinSet}, }; @@ -109,75 +105,135 @@ pub struct RouterBuilder { /// Implement this trait on a struct that should handle incoming connections. /// The protocol handler must then be registered on the node for an ALPN protocol with /// [`crate::protocol::RouterBuilder::accept`]. +/// +/// See the [module documentation](crate::protocol) for an example. pub trait ProtocolHandler: Send + Sync + std::fmt::Debug + 'static { /// Optional interception point to handle the `Connecting` state. /// + /// Can be implemented as `async fn on_connecting(&self, connecting: Connecting) -> Result`. + /// /// This enables accepting 0-RTT data from clients, among other things. - fn on_connecting(&self, connecting: Connecting) -> BoxFuture> { - Box::pin(async move { + fn on_connecting( + &self, + connecting: Connecting, + ) -> impl Future> + Send { + async move { let conn = connecting.await?; Ok(conn) - }) + } } /// Handle an incoming connection. /// - /// This runs on a freshly spawned tokio task so the returned future can be long-running. + /// Can be implemented as `async fn accept(&self, connection: Connection) -> Result`. + /// + /// The returned future runs on a freshly spawned tokio task so it can be long-running. /// /// When [`Router::shutdown`] is called, no further connections will be accepted, and /// the futures returned by [`Self::accept`] will be aborted after the future returned /// from [`ProtocolHandler::shutdown`] completes. - fn accept(&self, connection: Connection) -> BoxFuture>; + fn accept(&self, connection: Connection) -> impl Future> + Send; /// Called when the router shuts down. /// + /// Can be implemented as `async fn shutdown(&self)`. + /// /// This is called from [`Router::shutdown`]. The returned future is awaited before /// the router closes the endpoint. - fn shutdown(&self) -> BoxFuture<()> { - Box::pin(async move {}) + fn shutdown(&self) -> impl Future + Send { + async move {} } } impl ProtocolHandler for Arc { - fn on_connecting(&self, conn: Connecting) -> BoxFuture> { - self.as_ref().on_connecting(conn) + async fn on_connecting(&self, conn: Connecting) -> Result { + self.as_ref().on_connecting(conn).await } - fn accept(&self, conn: Connection) -> BoxFuture> { - self.as_ref().accept(conn) + async fn accept(&self, conn: Connection) -> Result<()> { + self.as_ref().accept(conn).await } - fn shutdown(&self) -> BoxFuture<()> { - self.as_ref().shutdown() + async fn shutdown(&self) { + self.as_ref().shutdown().await } } impl ProtocolHandler for Box { - fn on_connecting(&self, conn: Connecting) -> BoxFuture> { - self.as_ref().on_connecting(conn) + async fn on_connecting(&self, conn: Connecting) -> Result { + self.as_ref().on_connecting(conn).await + } + + async fn accept(&self, conn: Connection) -> Result<()> { + self.as_ref().accept(conn).await + } + + async fn shutdown(&self) { + self.as_ref().shutdown().await + } +} + +/// A dyn-compatible version of [`ProtocolHandler`] that returns boxed futures. +/// +/// We are not using [`n0_future::boxed::BoxFuture] because we don't need a `'static` bound +/// on these futures. +pub(crate) trait DynProtocolHandler: Send + Sync + std::fmt::Debug + 'static { + /// See [`ProtocolHandler::on_connecting`]. + fn on_connecting( + &self, + connecting: Connecting, + ) -> Pin> + Send + '_>> { + Box::pin(async move { + let conn = connecting.await?; + Ok(conn) + }) } - fn accept(&self, conn: Connection) -> BoxFuture> { - self.as_ref().accept(conn) + /// See [`ProtocolHandler::accept`]. + fn accept( + &self, + connection: Connection, + ) -> Pin> + Send + '_>>; + + /// See [`ProtocolHandler::shutdown`]. + fn shutdown(&self) -> Pin + Send + '_>> { + Box::pin(async move {}) + } +} + +impl DynProtocolHandler for P { + fn accept( + &self, + connection: Connection, + ) -> Pin> + Send + '_>> { + Box::pin(::accept(self, connection)) } - fn shutdown(&self) -> BoxFuture<()> { - self.as_ref().shutdown() + fn on_connecting( + &self, + connecting: Connecting, + ) -> Pin> + Send + '_>> { + Box::pin(::on_connecting(self, connecting)) + } + + fn shutdown(&self) -> Pin + Send + '_>> { + Box::pin(::shutdown(self)) } } /// A typed map of protocol handlers, mapping them from ALPNs. #[derive(Debug, Default)] -pub(crate) struct ProtocolMap(BTreeMap, Box>); +pub(crate) struct ProtocolMap(BTreeMap, Box>); impl ProtocolMap { /// Returns the registered protocol handler for an ALPN as a [`Arc`]. - pub(crate) fn get(&self, alpn: &[u8]) -> Option<&dyn ProtocolHandler> { + pub(crate) fn get(&self, alpn: &[u8]) -> Option<&dyn DynProtocolHandler> { self.0.get(alpn).map(|p| &**p) } /// Inserts a protocol handler. - pub(crate) fn insert(&mut self, alpn: Vec, handler: Box) { + pub(crate) fn insert(&mut self, alpn: Vec, handler: impl ProtocolHandler) { + let handler = Box::new(handler); self.0.insert(alpn, handler); } @@ -248,8 +304,7 @@ impl RouterBuilder { /// Configures the router to accept the [`ProtocolHandler`] when receiving a connection /// with this `alpn`. - pub fn accept(mut self, alpn: impl AsRef<[u8]>, handler: T) -> Self { - let handler = Box::new(handler); + pub fn accept(mut self, alpn: impl AsRef<[u8]>, handler: impl ProtocolHandler) -> Self { self.protocols.insert(alpn.as_ref().to_vec(), handler); self } @@ -416,25 +471,22 @@ impl AccessLimit

{ } impl ProtocolHandler for AccessLimit

{ - fn on_connecting(&self, conn: Connecting) -> BoxFuture> { + fn on_connecting(&self, conn: Connecting) -> impl Future> + Send { self.proto.on_connecting(conn) } - fn accept(&self, conn: Connection) -> BoxFuture> { - let this = self.clone(); - Box::pin(async move { - let remote = conn.remote_node_id()?; - let is_allowed = (this.limiter)(remote); - if !is_allowed { - conn.close(0u32.into(), b"not allowed"); - anyhow::bail!("not allowed"); - } - this.proto.accept(conn).await?; - Ok(()) - }) + async fn accept(&self, conn: Connection) -> Result<()> { + let remote = conn.remote_node_id()?; + let is_allowed = (self.limiter)(remote); + if !is_allowed { + conn.close(0u32.into(), b"not allowed"); + anyhow::bail!("not allowed"); + } + self.proto.accept(conn).await?; + Ok(()) } - fn shutdown(&self) -> BoxFuture<()> { + fn shutdown(&self) -> impl Future + Send { self.proto.shutdown() } } @@ -472,19 +524,17 @@ mod tests { const ECHO_ALPN: &[u8] = b"/iroh/echo/1"; impl ProtocolHandler for Echo { - fn accept(&self, connection: Connection) -> BoxFuture> { + async fn accept(&self, connection: Connection) -> Result<()> { println!("accepting echo"); - Box::pin(async move { - let (mut send, mut recv) = connection.accept_bi().await?; + let (mut send, mut recv) = connection.accept_bi().await?; - // Echo any bytes received back directly. - let _bytes_sent = tokio::io::copy(&mut recv, &mut send).await?; + // Echo any bytes received back directly. + let _bytes_sent = tokio::io::copy(&mut recv, &mut send).await?; - send.finish()?; - connection.closed().await; + send.finish()?; + connection.closed().await; - Ok(()) - }) + Ok(()) } } #[tokio::test] @@ -521,23 +571,17 @@ mod tests { const TEST_ALPN: &[u8] = b"/iroh/test/1"; impl ProtocolHandler for TestProtocol { - fn accept(&self, connection: Connection) -> BoxFuture> { - let this = self.clone(); - Box::pin(async move { - this.connections.lock().expect("poisoned").push(connection); - Ok(()) - }) + async fn accept(&self, connection: Connection) -> Result<()> { + self.connections.lock().expect("poisoned").push(connection); + Ok(()) } - fn shutdown(&self) -> BoxFuture<()> { - let this = self.clone(); - Box::pin(async move { - tokio::time::sleep(Duration::from_millis(100)).await; - let mut connections = this.connections.lock().expect("poisoned"); - for conn in connections.drain(..) { - conn.close(42u32.into(), b"shutdown"); - } - }) + async fn shutdown(&self) { + tokio::time::sleep(Duration::from_millis(100)).await; + let mut connections = self.connections.lock().expect("poisoned"); + for conn in connections.drain(..) { + conn.close(42u32.into(), b"shutdown"); + } } }