diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index b8b3fdbbf1..0132b6de18 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -118,7 +118,7 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: mozilla-actions/sccache-action@v0.0.4 + - uses: mozilla-actions/sccache-action@v0.0.9 - uses: dtolnay/rust-toolchain@master with: toolchain: ${{ matrix.rust }} @@ -180,7 +180,7 @@ jobs: SCCACHE_GHA_ENABLED: "on" steps: - uses: actions/checkout@v4 - - uses: mozilla-actions/sccache-action@v0.0.4 + - uses: mozilla-actions/sccache-action@v0.0.9 - uses: dtolnay/rust-toolchain@1.71.0 - run: cargo check --lib --all-features -p iroh-quinn-udp -p iroh-quinn-proto -p iroh-quinn @@ -191,7 +191,7 @@ jobs: SCCACHE_GHA_ENABLED: "on" steps: - uses: actions/checkout@v4 - - uses: mozilla-actions/sccache-action@v0.0.4 + - uses: mozilla-actions/sccache-action@v0.0.9 - uses: dtolnay/rust-toolchain@stable with: components: rustfmt, clippy diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index 828c34dad7..ee7bf02ab0 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -19,7 +19,7 @@ use tracing::{debug_span, Instrument, Span}; use crate::{ mutex::Mutex, recv_stream::RecvStream, - runtime::{AsyncTimer, AsyncUdpSocket, Runtime, UdpPoller}, + runtime::{AsyncTimer, Runtime, UdpSender}, send_stream::SendStream, udp_transmit, ConnectionEvent, Duration, Instant, VarInt, }; @@ -42,7 +42,7 @@ impl Connecting { conn: proto::Connection, endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, conn_events: mpsc::UnboundedReceiver, - socket: Arc, + sender: Pin>, runtime: Arc, ) -> Self { let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel(); @@ -54,7 +54,7 @@ impl Connecting { conn_events, on_handshake_data_send, on_connected_send, - socket, + sender, runtime.clone(), ); @@ -877,7 +877,7 @@ impl ConnectionRef { conn_events: mpsc::UnboundedReceiver, on_handshake_data: oneshot::Sender<()>, on_connected: oneshot::Sender, - socket: Arc, + sender: Pin>, runtime: Arc, ) -> Self { Self(Arc::new(ConnectionInner { @@ -897,8 +897,7 @@ impl ConnectionRef { stopped: FxHashMap::default(), error: None, ref_count: 0, - io_poller: socket.clone().create_io_poller(), - socket, + sender, runtime, send_buffer: Vec::new(), buffered_transmit: None, @@ -1017,8 +1016,7 @@ pub(crate) struct State { pub(crate) error: Option, /// Number of live handles that can be used to initiate or handle I/O; excludes the driver ref_count: usize, - socket: Arc, - io_poller: Pin>, + sender: Pin>, runtime: Arc, send_buffer: Vec, /// We buffer a transmit when the underlying I/O would block @@ -1032,7 +1030,7 @@ impl State { let now = self.runtime.now(); let mut transmits = 0; - let max_datagrams = self.socket.max_transmit_segments(); + let max_datagrams = self.sender.max_transmit_segments(); loop { // Retry the last transmit, or get a new one. @@ -1057,28 +1055,18 @@ impl State { } }; - if self.io_poller.as_mut().poll_writable(cx)?.is_pending() { - // Retry after a future wakeup - self.buffered_transmit = Some(t); - return Ok(false); - } - let len = t.size; - let retry = match self - .socket - .try_send(&udp_transmit(&t, &self.send_buffer[..len])) + match self + .sender + .as_mut() + .poll_send(&udp_transmit(&t, &self.send_buffer[..len]), cx) { - Ok(()) => false, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true, - Err(e) => return Err(e), - }; - if retry { - // We thought the socket was writable, but it wasn't. Retry so that either another - // `poll_writable` call determines that the socket is indeed not writable and - // registers us for a wakeup, or the send succeeds if this really was just a - // transient failure. - self.buffered_transmit = Some(t); - continue; + Poll::Pending => { + self.buffered_transmit = Some(t); + return Ok(false); + } + Poll::Ready(Err(e)) => return Err(e), + Poll::Ready(Ok(())) => {} } if transmits >= MAX_TRANSMIT_DATAGRAMS { @@ -1108,9 +1096,8 @@ impl State { ) -> Result<(), ConnectionError> { loop { match self.conn_events.poll_recv(cx) { - Poll::Ready(Some(ConnectionEvent::Rebind(socket))) => { - self.socket = socket; - self.io_poller = self.socket.clone().create_io_poller(); + Poll::Ready(Some(ConnectionEvent::Rebind(sender))) => { + self.sender = sender; self.inner.local_address_changed(); } Poll::Ready(Some(ConnectionEvent::Proto(event))) => { diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index f70be9ad4f..c2de97a1e2 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -15,7 +15,7 @@ use std::{ #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] use crate::runtime::default_runtime; use crate::{ - runtime::{AsyncUdpSocket, Runtime}, + runtime::{AsyncUdpSocket, Runtime, UdpSender}, udp_transmit, Instant, }; use bytes::{Bytes, BytesMut}; @@ -129,7 +129,7 @@ impl Endpoint { pub fn new_with_abstract_socket( config: EndpointConfig, server_config: Option, - socket: Arc, + socket: Box, runtime: Arc, ) -> io::Result { let addr = socket.local_addr()?; @@ -224,12 +224,12 @@ impl Endpoint { .inner .connect(self.runtime.now(), config, addr, server_name)?; - let socket = endpoint.socket.clone(); + let sender = endpoint.socket.create_sender(); endpoint.stats.outgoing_handshakes += 1; Ok(endpoint .recv_state .connections - .insert(ch, conn, socket, self.runtime.clone())) + .insert(ch, conn, sender, self.runtime.clone())) } /// Switch to a new UDP socket @@ -246,7 +246,7 @@ impl Endpoint { /// connections and connections to servers unreachable from the new address will be lost. /// /// On error, the old UDP socket is retained. - pub fn rebind_abstract(&self, socket: Arc) -> io::Result<()> { + pub fn rebind_abstract(&self, socket: Box) -> io::Result<()> { let addr = socket.local_addr()?; let mut inner = self.inner.state.lock().unwrap(); inner.prev_socket = Some(mem::replace(&mut inner.socket, socket)); @@ -255,7 +255,7 @@ impl Endpoint { // Update connection socket references for sender in inner.recv_state.connections.senders.values() { // Ignoring errors from dropped connections - let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone())); + let _ = sender.send(ConnectionEvent::Rebind(inner.socket.create_sender())); } Ok(()) @@ -420,16 +420,16 @@ impl EndpointInner { { Ok((handle, conn)) => { state.stats.accepted_handshakes += 1; - let socket = state.socket.clone(); + let sender = state.socket.create_sender(); let runtime = state.runtime.clone(); Ok(state .recv_state .connections - .insert(handle, conn, socket, runtime)) + .insert(handle, conn, sender, runtime)) } Err(error) => { if let Some(transmit) = error.response { - respond(transmit, &response_buffer, &*state.socket); + respond(transmit, &response_buffer, &mut state.sender); } Err(error.cause) } @@ -441,14 +441,14 @@ impl EndpointInner { state.stats.refused_handshakes += 1; let mut response_buffer = Vec::new(); let transmit = state.inner.refuse(incoming, &mut response_buffer); - respond(transmit, &response_buffer, &*state.socket); + respond(transmit, &response_buffer, &mut state.sender); } pub(crate) fn retry(&self, incoming: proto::Incoming) -> Result<(), proto::RetryError> { let mut state = self.state.lock().unwrap(); let mut response_buffer = Vec::new(); let transmit = state.inner.retry(incoming, &mut response_buffer)?; - respond(transmit, &response_buffer, &*state.socket); + respond(transmit, &response_buffer, &mut state.sender); Ok(()) } @@ -461,10 +461,11 @@ impl EndpointInner { #[derive(Debug)] pub(crate) struct State { - socket: Arc, + socket: Box, + sender: Pin>, /// During an active migration, abandoned_socket receives traffic /// until the first packet arrives on the new socket. - prev_socket: Option>, + prev_socket: Option>, inner: proto::Endpoint, recv_state: RecvState, driver: Option, @@ -487,18 +488,28 @@ impl State { fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result { let get_time = || self.runtime.now(); self.recv_state.recv_limiter.start_cycle(get_time); - if let Some(socket) = &self.prev_socket { + if let Some(socket) = &mut self.prev_socket { // We don't care about the `PollProgress` from old sockets. - let poll_res = - self.recv_state - .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now); + let poll_res = self.recv_state.poll_socket( + cx, + &mut self.inner, + &mut *socket, + &mut self.sender, + &*self.runtime, + now, + ); if poll_res.is_err() { self.prev_socket = None; } }; - let poll_res = - self.recv_state - .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now); + let poll_res = self.recv_state.poll_socket( + cx, + &mut self.inner, + &mut self.socket, + &mut self.sender, + &*self.runtime, + now, + ); self.recv_state.recv_limiter.finish_cycle(get_time); let poll_res = poll_res?; if poll_res.received_connection_packet { @@ -550,7 +561,11 @@ impl Drop for State { } } -fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) { +fn respond( + transmit: proto::Transmit, + response_buffer: &[u8], + sender: &mut Pin>, +) { // Send if there's kernel buffer space; otherwise, drop it // // As an endpoint-generated packet, we know this is an @@ -571,7 +586,9 @@ fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn Async // to transmit. This is morally equivalent to the packet getting // lost due to congestion further along the link, which // similarly relies on peer retries for recovery. - _ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size])); + _ = sender + .as_mut() + .try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size])); } #[inline] @@ -598,7 +615,7 @@ impl ConnectionSet { &mut self, handle: ConnectionHandle, conn: proto::Connection, - socket: Arc, + sender: Pin>, runtime: Arc, ) -> Connecting { let (send, recv) = mpsc::unbounded_channel(); @@ -610,7 +627,7 @@ impl ConnectionSet { .unwrap(); } self.senders.insert(handle, send); - Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime) + Connecting::new(handle, conn, self.sender.clone(), recv, sender, runtime) } fn is_empty(&self) -> bool { @@ -669,13 +686,14 @@ pub(crate) struct EndpointRef(Arc); impl EndpointRef { pub(crate) fn new( - socket: Arc, + socket: Box, inner: proto::Endpoint, ipv6: bool, runtime: Arc, ) -> Self { let (sender, events) = mpsc::unbounded_channel(); let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner); + let sender = socket.create_sender(); Self(Arc::new(EndpointInner { shared: Shared { incoming: Notify::new(), @@ -683,6 +701,7 @@ impl EndpointRef { }, state: Mutex::new(State { socket, + sender, prev_socket: None, inner, ipv6, @@ -764,7 +783,8 @@ impl RecvState { &mut self, cx: &mut Context, endpoint: &mut proto::Endpoint, - socket: &dyn AsyncUdpSocket, + socket: &mut Box, + sender: &mut Pin>, runtime: &dyn Runtime, now: Instant, ) -> Result { @@ -804,7 +824,7 @@ impl RecvState { } else { let transmit = endpoint.refuse(incoming, &mut response_buffer); - respond(transmit, &response_buffer, socket); + respond(transmit, &response_buffer, sender); } } Some(DatagramEvent::ConnectionEvent(handle, event)) => { @@ -818,7 +838,7 @@ impl RecvState { .send(ConnectionEvent::Proto(event)); } Some(DatagramEvent::Response(transmit)) => { - respond(transmit, &response_buffer, socket); + respond(transmit, &response_buffer, sender); } None => {} } diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index d971574e67..b9d1c89181 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -41,7 +41,7 @@ #![warn(unreachable_pub)] #![warn(clippy::use_self)] -use std::sync::Arc; +use std::pin::Pin; macro_rules! ready { ($e:expr $(,)?) => { @@ -92,7 +92,10 @@ pub use crate::runtime::AsyncStdRuntime; pub use crate::runtime::SmolRuntime; #[cfg(feature = "runtime-tokio")] pub use crate::runtime::TokioRuntime; -pub use crate::runtime::{default_runtime, AsyncTimer, AsyncUdpSocket, Runtime, UdpPoller}; +pub use crate::runtime::{ + default_runtime, AsyncTimer, AsyncUdpSocket, Runtime, UdpSender, UdpSenderHelper, + UdpSenderHelperSocket, +}; pub use crate::send_stream::{SendStream, StoppedError, WriteError}; #[cfg(test)] @@ -105,7 +108,7 @@ enum ConnectionEvent { reason: bytes::Bytes, }, Proto(proto::ConnectionEvent), - Rebind(Arc), + Rebind(Pin>), } fn udp_transmit<'a>(t: &proto::Transmit, buffer: &'a [u8]) -> udp::Transmit<'a> { diff --git a/quinn/src/runtime.rs b/quinn/src/runtime.rs index 9471c1834e..71241b351e 100644 --- a/quinn/src/runtime.rs +++ b/quinn/src/runtime.rs @@ -20,7 +20,7 @@ pub trait Runtime: Send + Sync + Debug + 'static { fn spawn(&self, future: Pin + Send>>); /// Convert `t` into the socket type used by this runtime #[cfg(not(wasm_browser))] - fn wrap_udp_socket(&self, t: std::net::UdpSocket) -> io::Result>; + fn wrap_udp_socket(&self, t: std::net::UdpSocket) -> io::Result>; /// Look up the current time /// /// Allows simulating the flow of time for testing. @@ -48,18 +48,11 @@ pub trait AsyncUdpSocket: Send + Sync + Debug + 'static { /// [`Waker`]. /// /// [`Waker`]: std::task::Waker - fn create_io_poller(self: Arc) -> Pin>; - - /// Send UDP datagrams from `transmits`, or return `WouldBlock` and clear the underlying - /// socket's readiness, or return an I/O error - /// - /// If this returns [`io::ErrorKind::WouldBlock`], [`UdpPoller::poll_writable`] must be called - /// to register the calling task to be woken when a send should be attempted again. - fn try_send(&self, transmit: &Transmit) -> io::Result<()>; + fn create_sender(&self) -> Pin>; /// Receive UDP datagrams, or register to be woken if receiving may succeed in the future fn poll_recv( - &self, + &mut self, cx: &mut Context, bufs: &mut [IoSliceMut<'_>], meta: &mut [RecvMeta], @@ -68,11 +61,6 @@ pub trait AsyncUdpSocket: Send + Sync + Debug + 'static { /// Look up the local IP address and port used by this socket fn local_addr(&self) -> io::Result; - /// Maximum number of datagrams that a [`Transmit`] may encode - fn max_transmit_segments(&self) -> usize { - 1 - } - /// Maximum number of datagrams that might be described by a single [`RecvMeta`] fn max_receive_segments(&self) -> usize { 1 @@ -91,71 +79,119 @@ pub trait AsyncUdpSocket: Send + Sync + Debug + 'static { /// /// Any number of `UdpPoller`s may exist for a single [`AsyncUdpSocket`]. Each `UdpPoller` is /// responsible for notifying at most one task when that socket becomes writable. -pub trait UdpPoller: Send + Sync + Debug + 'static { +pub trait UdpSender: Send + Sync + Debug + 'static { /// Check whether the associated socket is likely to be writable /// /// Must be called after [`AsyncUdpSocket::try_send`] returns [`io::ErrorKind::WouldBlock`] to /// register the task associated with `cx` to be woken when a send should be attempted /// again. Unlike in [`Future::poll`], a [`UdpPoller`] may be reused indefinitely no matter how /// many times `poll_writable` returns [`Poll::Ready`]. - fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll>; + /// + /// // TODO(matheus23): Fix weird documentation merge + /// + /// Send UDP datagrams from `transmits`, or return `WouldBlock` and clear the underlying + /// socket's readiness, or return an I/O error + /// + /// If this returns [`io::ErrorKind::WouldBlock`], [`UdpPoller::poll_writable`] must be called + /// to register the calling task to be woken when a send should be attempted again. + fn poll_send( + self: Pin<&mut Self>, + transmit: &Transmit, + cx: &mut Context, + ) -> Poll>; + + /// Maximum number of datagrams that a [`Transmit`] may encode + fn max_transmit_segments(&self) -> usize { + 1 + } + + /// TODO(matheus23): Docs + /// Last ditch/best effort of sending a transmit. + /// Used by the endpoint for resets / close frames when dropped, etc. + fn try_send(self: Pin<&mut Self>, transmit: &Transmit) -> io::Result<()>; } pin_project_lite::pin_project! { - /// Helper adapting a function `MakeFut` that constructs a single-use future `Fut` into a - /// [`UdpPoller`] that may be reused indefinitely - struct UdpPollHelper { + pub struct UdpSenderHelper { + socket: Socket, make_fut: MakeFut, #[pin] fut: Option, } } -impl UdpPollHelper { - /// Construct a [`UdpPoller`] that calls `make_fut` to get the future to poll, storing it until - /// it yields [`Poll::Ready`], then creating a new one on the next - /// [`poll_writable`](UdpPoller::poll_writable) - #[cfg(any( - feature = "runtime-async-std", - feature = "runtime-smol", - feature = "runtime-tokio" - ))] - fn new(make_fut: MakeFut) -> Self { +impl Debug for UdpSenderHelper { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("UdpSender") + } +} + +impl UdpSenderHelper { + pub fn new(inner: Socket, make_fut: MakeFut) -> Self { Self { + socket: inner, make_fut, fut: None, } } } -impl UdpPoller for UdpPollHelper +impl super::UdpSender for UdpSenderHelper where - MakeFut: Fn() -> Fut + Send + Sync + 'static, + Socket: UdpSenderHelperSocket, + MakeFut: Fn(&Socket) -> Fut + Send + Sync + 'static, Fut: Future> + Send + Sync + 'static, { - fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_send( + self: Pin<&mut Self>, + transmit: &udp::Transmit, + cx: &mut Context, + ) -> Poll> { let mut this = self.project(); - if this.fut.is_none() { - this.fut.set(Some((this.make_fut)())); - } - // We're forced to `unwrap` here because `Fut` may be `!Unpin`, which means we can't safely - // obtain an `&mut Fut` after storing it in `self.fut` when `self` is already behind `Pin`, - // and if we didn't store it then we wouldn't be able to keep it alive between - // `poll_writable` calls. - let result = this.fut.as_mut().as_pin_mut().unwrap().poll(cx); - if result.is_ready() { + loop { + if this.fut.is_none() { + this.fut.set(Some((this.make_fut)(&this.socket))); + } + // We're forced to `unwrap` here because `Fut` may be `!Unpin`, which means we can't safely + // obtain an `&mut Fut` after storing it in `self.fut` when `self` is already behind `Pin`, + // and if we didn't store it then we wouldn't be able to keep it alive between + // `poll_writable` calls. + let result = ready!(this.fut.as_mut().as_pin_mut().unwrap().poll(cx)); + // Polling an arbitrary `Future` after it becomes ready is a logic error, so arrange for // a new `Future` to be created on the next call. this.fut.set(None); + + // If .writable() fails, propagate the error + result?; + + let result = this.socket.try_send(transmit); + + match result { + // We thought the socket was writable, but it wasn't, then retry so that either another + // `writable().await` call determines that the socket is indeed not writable and + // registers us for a wakeup, or the send succeeds if this really was just a + // transient failure. + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + // In all other cases, either propagate the error or we're Ok + _ => return Poll::Ready(result), + } } - result } -} -impl Debug for UdpPollHelper { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("UdpPollHelper").finish_non_exhaustive() + fn max_transmit_segments(&self) -> usize { + self.socket.max_transmit_segments() } + + fn try_send(self: Pin<&mut Self>, transmit: &udp::Transmit) -> io::Result<()> { + self.socket.try_send(transmit) + } +} + +pub trait UdpSenderHelperSocket: Send + Sync + 'static { + fn try_send(&self, transmit: &udp::Transmit) -> io::Result<()>; + + fn max_transmit_segments(&self) -> usize; } /// Automatically select an appropriate runtime from those enabled at compile time diff --git a/quinn/src/runtime/async_io.rs b/quinn/src/runtime/async_io.rs index 34df24d76f..5a7a2f5fe0 100644 --- a/quinn/src/runtime/async_io.rs +++ b/quinn/src/runtime/async_io.rs @@ -9,7 +9,7 @@ use std::{ use async_io::{Async, Timer}; -use super::{AsyncTimer, AsyncUdpSocket, Runtime, UdpPollHelper}; +use super::{AsyncTimer, AsyncUdpSocket, Runtime, UdpSenderHelper}; #[cfg(feature = "smol")] // Due to MSRV, we must specify `self::` where there's crate/module ambiguity @@ -35,8 +35,8 @@ mod smol { fn wrap_udp_socket( &self, sock: std::net::UdpSocket, - ) -> io::Result> { - Ok(Arc::new(UdpSocket::new(sock)?)) + ) -> io::Result> { + Ok(Box::new(UdpSocket::new(sock)?)) } } } @@ -65,8 +65,8 @@ mod async_std { fn wrap_udp_socket( &self, sock: std::net::UdpSocket, - ) -> io::Result> { - Ok(Arc::new(UdpSocket::new(sock)?)) + ) -> io::Result> { + Ok(Box::new(UdpSocket::new(sock)?)) } } } @@ -81,35 +81,41 @@ impl AsyncTimer for Timer { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct UdpSocket { - io: Async, - inner: udp::UdpSocketState, + io: Arc>, + inner: Arc, } impl UdpSocket { fn new(sock: std::net::UdpSocket) -> io::Result { Ok(Self { - inner: udp::UdpSocketState::new((&sock).into())?, - io: Async::new_nonblocking(sock)?, + inner: Arc::new(udp::UdpSocketState::new((&sock).into())?), + io: Arc::new(Async::new_nonblocking(sock)?), }) } } -impl AsyncUdpSocket for UdpSocket { - fn create_io_poller(self: Arc) -> Pin> { - Box::pin(UdpPollHelper::new(move || { - let socket = self.clone(); - async move { socket.io.writable().await } - })) +impl super::UdpSenderHelperSocket for UdpSocket { + fn max_transmit_segments(&self) -> usize { + self.inner.max_gso_segments() } fn try_send(&self, transmit: &udp::Transmit) -> io::Result<()> { self.inner.send((&self.io).into(), transmit) } +} + +impl AsyncUdpSocket for UdpSocket { + fn create_sender(&self) -> Pin> { + Box::pin(UdpSenderHelper::new(self.clone(), |socket: &UdpSocket| { + let socket = socket.clone(); + async move { socket.io.writable().await } + })) + } fn poll_recv( - &self, + &mut self, cx: &mut Context, bufs: &mut [io::IoSliceMut<'_>], meta: &mut [udp::RecvMeta], @@ -123,17 +129,13 @@ impl AsyncUdpSocket for UdpSocket { } fn local_addr(&self) -> io::Result { - self.io.as_ref().local_addr() + self.io.as_ref().as_ref().local_addr() } fn may_fragment(&self) -> bool { self.inner.may_fragment() } - fn max_transmit_segments(&self) -> usize { - self.inner.max_gso_segments() - } - fn max_receive_segments(&self) -> usize { self.inner.gro_segments() } diff --git a/quinn/src/runtime/tokio.rs b/quinn/src/runtime/tokio.rs index ad321a240c..2d66dfd5b7 100644 --- a/quinn/src/runtime/tokio.rs +++ b/quinn/src/runtime/tokio.rs @@ -1,4 +1,5 @@ use std::{ + fmt::Debug, future::Future, io, pin::Pin, @@ -12,7 +13,7 @@ use tokio::{ time::{sleep_until, Sleep}, }; -use super::{AsyncTimer, AsyncUdpSocket, Runtime, UdpPollHelper}; +use super::{AsyncTimer, AsyncUdpSocket, Runtime, UdpSenderHelper, UdpSenderHelperSocket}; /// A Quinn runtime for Tokio #[derive(Debug)] @@ -27,10 +28,10 @@ impl Runtime for TokioRuntime { tokio::spawn(future); } - fn wrap_udp_socket(&self, sock: std::net::UdpSocket) -> io::Result> { - Ok(Arc::new(UdpSocket { - inner: udp::UdpSocketState::new((&sock).into())?, - io: tokio::net::UdpSocket::from_std(sock)?, + fn wrap_udp_socket(&self, sock: std::net::UdpSocket) -> io::Result> { + Ok(Box::new(UdpSocket { + inner: Arc::new(udp::UdpSocketState::new((&sock).into())?), + io: Arc::new(tokio::net::UdpSocket::from_std(sock)?), })) } @@ -48,18 +49,15 @@ impl AsyncTimer for Sleep { } } -#[derive(Debug)] +#[derive(Debug, Clone)] struct UdpSocket { - io: tokio::net::UdpSocket, - inner: udp::UdpSocketState, + io: Arc, + inner: Arc, } -impl AsyncUdpSocket for UdpSocket { - fn create_io_poller(self: Arc) -> Pin> { - Box::pin(UdpPollHelper::new(move || { - let socket = self.clone(); - async move { socket.io.writable().await } - })) +impl UdpSenderHelperSocket for UdpSocket { + fn max_transmit_segments(&self) -> usize { + self.inner.max_gso_segments() } fn try_send(&self, transmit: &udp::Transmit) -> io::Result<()> { @@ -67,9 +65,18 @@ impl AsyncUdpSocket for UdpSocket { self.inner.send((&self.io).into(), transmit) }) } +} + +impl AsyncUdpSocket for UdpSocket { + fn create_sender(&self) -> Pin> { + Box::pin(UdpSenderHelper::new(self.clone(), |socket: &UdpSocket| { + let socket = socket.clone(); + async move { socket.io.writable().await } + })) + } fn poll_recv( - &self, + &mut self, cx: &mut Context, bufs: &mut [std::io::IoSliceMut<'_>], meta: &mut [udp::RecvMeta], @@ -92,10 +99,6 @@ impl AsyncUdpSocket for UdpSocket { self.inner.may_fragment() } - fn max_transmit_segments(&self) -> usize { - self.inner.max_gso_segments() - } - fn max_receive_segments(&self) -> usize { self.inner.gro_segments() }