
feat(quinn): Refactor polling & sending to take &mut self #67


Draft
wants to merge 12 commits into base: iroh-0.11.x
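In outline, this refactor replaces the shared `Arc<dyn AsyncUdpSocket>` plus `UdpPoller` pair (a split `poll_writable` / `try_send` protocol) with an owned, `&mut`-pollable `Pin<Box<dyn UdpSender>>` that each consumer obtains via `AsyncUdpSocket::create_sender()`. The sketch below reconstructs the trait surface the diff programs against; the method names are taken from the call sites in the diff, but the actual definitions in `quinn/src/runtime.rs` are not shown here, so the exact signatures and bounds are assumptions.

```rust
use std::{
    fmt::Debug,
    io,
    net::SocketAddr,
    pin::Pin,
    task::{Context, Poll},
};

/// Stand-in for `quinn_udp::Transmit` so this sketch is self-contained.
pub struct Transmit<'a> {
    pub destination: SocketAddr,
    pub contents: &'a [u8],
    // ECN, segment size and source IP elided.
}

/// Assumed shape of the new sender trait: sending takes `&mut self` (through
/// `Pin<&mut Self>`), so no separate `UdpPoller` object is needed to wait for
/// writability.
pub trait UdpSender: Send + Debug + 'static {
    /// Send one transmit; registers the waker and returns `Pending` when the
    /// socket is not currently writable.
    fn poll_send(
        self: Pin<&mut Self>,
        transmit: &Transmit<'_>,
        cx: &mut Context,
    ) -> Poll<io::Result<()>>;

    /// Best-effort send for endpoint-generated packets that may be dropped.
    fn try_send(self: Pin<&mut Self>, transmit: &Transmit<'_>) -> io::Result<()>;

    /// Number of GSO segments the underlying socket can send in one call.
    fn max_transmit_segments(&self) -> usize;
}

/// Assumed shape of the socket trait after the refactor: the endpoint owns it
/// (`Box`, not `Arc`) and it hands out independent senders.
pub trait AsyncUdpSocket: Send + Sync + Debug + 'static {
    fn create_sender(&self) -> Pin<Box<dyn UdpSender>>;
    fn max_receive_segments(&self) -> usize;
    fn local_addr(&self) -> io::Result<SocketAddr>;
    // Receive-side methods elided.
}
```

Taking `&mut self` lets an implementation keep per-sender state without interior mutability, which is what the connection driver below relies on when it calls `poll_send` directly instead of the old `poll_writable` / `try_send` two-step.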
6 changes: 3 additions & 3 deletions .github/workflows/rust.yml
@@ -118,7 +118,7 @@ jobs:

steps:
- uses: actions/checkout@v4
- uses: mozilla-actions/sccache-action@v0.0.4
- uses: mozilla-actions/sccache-action@v0.0.9
- uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ matrix.rust }}
@@ -180,7 +180,7 @@ jobs:
SCCACHE_GHA_ENABLED: "on"
steps:
- uses: actions/checkout@v4
- uses: mozilla-actions/sccache-action@v0.0.4
- uses: mozilla-actions/sccache-action@v0.0.9
- uses: dtolnay/[email protected]
- run: cargo check --lib --all-features -p iroh-quinn-udp -p iroh-quinn-proto -p iroh-quinn

@@ -191,7 +191,7 @@ jobs:
SCCACHE_GHA_ENABLED: "on"
steps:
- uses: actions/checkout@v4
- uses: mozilla-actions/sccache-action@v0.0.4
- uses: mozilla-actions/sccache-action@v0.0.9
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt, clippy
51 changes: 19 additions & 32 deletions quinn/src/connection.rs
@@ -19,7 +19,7 @@ use tracing::{debug_span, Instrument, Span};
use crate::{
mutex::Mutex,
recv_stream::RecvStream,
runtime::{AsyncTimer, AsyncUdpSocket, Runtime, UdpPoller},
runtime::{AsyncTimer, Runtime, UdpSender},
send_stream::SendStream,
udp_transmit, ConnectionEvent, Duration, Instant, VarInt,
};
@@ -42,7 +42,7 @@ impl Connecting {
conn: proto::Connection,
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
socket: Arc<dyn AsyncUdpSocket>,
sender: Pin<Box<dyn UdpSender>>,
runtime: Arc<dyn Runtime>,
) -> Self {
let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel();
@@ -54,7 +54,7 @@
conn_events,
on_handshake_data_send,
on_connected_send,
socket,
sender,
runtime.clone(),
);

@@ -877,7 +877,7 @@ impl ConnectionRef {
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
on_handshake_data: oneshot::Sender<()>,
on_connected: oneshot::Sender<bool>,
socket: Arc<dyn AsyncUdpSocket>,
sender: Pin<Box<dyn UdpSender>>,
runtime: Arc<dyn Runtime>,
) -> Self {
Self(Arc::new(ConnectionInner {
@@ -897,8 +897,7 @@
stopped: FxHashMap::default(),
error: None,
ref_count: 0,
io_poller: socket.clone().create_io_poller(),
socket,
sender,
runtime,
send_buffer: Vec::new(),
buffered_transmit: None,
@@ -1017,8 +1016,7 @@ pub(crate) struct State {
pub(crate) error: Option<ConnectionError>,
/// Number of live handles that can be used to initiate or handle I/O; excludes the driver
ref_count: usize,
socket: Arc<dyn AsyncUdpSocket>,
io_poller: Pin<Box<dyn UdpPoller>>,
sender: Pin<Box<dyn UdpSender>>,
runtime: Arc<dyn Runtime>,
send_buffer: Vec<u8>,
/// We buffer a transmit when the underlying I/O would block
@@ -1032,7 +1030,7 @@ impl State {
let now = self.runtime.now();
let mut transmits = 0;

let max_datagrams = self.socket.max_transmit_segments();
let max_datagrams = self.sender.max_transmit_segments();

loop {
// Retry the last transmit, or get a new one.
@@ -1057,28 +1055,18 @@
}
};

if self.io_poller.as_mut().poll_writable(cx)?.is_pending() {
// Retry after a future wakeup
self.buffered_transmit = Some(t);
return Ok(false);
}

let len = t.size;
let retry = match self
.socket
.try_send(&udp_transmit(&t, &self.send_buffer[..len]))
match self
.sender
.as_mut()
.poll_send(&udp_transmit(&t, &self.send_buffer[..len]), cx)
{
Ok(()) => false,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => true,
Err(e) => return Err(e),
};
if retry {
// We thought the socket was writable, but it wasn't. Retry so that either another
// `poll_writable` call determines that the socket is indeed not writable and
// registers us for a wakeup, or the send succeeds if this really was just a
// transient failure.
self.buffered_transmit = Some(t);
continue;
Poll::Pending => {
self.buffered_transmit = Some(t);
return Ok(false);
}
Poll::Ready(Err(e)) => return Err(e),
Poll::Ready(Ok(())) => {}
}

if transmits >= MAX_TRANSMIT_DATAGRAMS {
@@ -1108,9 +1096,8 @@ impl State {
) -> Result<(), ConnectionError> {
loop {
match self.conn_events.poll_recv(cx) {
Poll::Ready(Some(ConnectionEvent::Rebind(socket))) => {
self.socket = socket;
self.io_poller = self.socket.clone().create_io_poller();
Poll::Ready(Some(ConnectionEvent::Rebind(sender))) => {
self.sender = sender;
self.inner.local_address_changed();
}
Poll::Ready(Some(ConnectionEvent::Proto(event))) => {
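The send loop above collapses the old sequence (poll for writability, `try_send`, retry on a spurious `WouldBlock`) into a single `poll_send` call that completes, fails, or parks the already-encoded transmit in `buffered_transmit` and returns `Pending`. Outside the driver the same pattern wraps cleanly in a future; a minimal sketch, assuming the placeholder `UdpSender` and `Transmit` types from the sketch above:

```rust
use std::{future::poll_fn, io, pin::Pin};

/// Send a single datagram through the new `&mut`-based API. One suspension
/// point replaces the old `poll_writable` + `try_send` + retry dance.
async fn send_one(
    sender: &mut Pin<Box<dyn UdpSender>>,
    transmit: &Transmit<'_>,
) -> io::Result<()> {
    poll_fn(|cx| sender.as_mut().poll_send(transmit, cx)).await
}
```

In the driver itself, `buffered_transmit` plays the role of the retry: when `poll_send` returns `Pending`, the encoded packet is parked and re-sent on the next wakeup rather than re-encoded.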
76 changes: 48 additions & 28 deletions quinn/src/endpoint.rs
@@ -15,7 +15,7 @@ use std::{
#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
use crate::runtime::default_runtime;
use crate::{
runtime::{AsyncUdpSocket, Runtime},
runtime::{AsyncUdpSocket, Runtime, UdpSender},
udp_transmit, Instant,
};
use bytes::{Bytes, BytesMut};
@@ -129,7 +129,7 @@ impl Endpoint {
pub fn new_with_abstract_socket(
config: EndpointConfig,
server_config: Option<ServerConfig>,
socket: Arc<dyn AsyncUdpSocket>,
socket: Box<dyn AsyncUdpSocket>,
runtime: Arc<dyn Runtime>,
) -> io::Result<Self> {
let addr = socket.local_addr()?;
@@ -224,12 +224,12 @@ impl Endpoint {
.inner
.connect(self.runtime.now(), config, addr, server_name)?;

let socket = endpoint.socket.clone();
let sender = endpoint.socket.create_sender();
endpoint.stats.outgoing_handshakes += 1;
Ok(endpoint
.recv_state
.connections
.insert(ch, conn, socket, self.runtime.clone()))
.insert(ch, conn, sender, self.runtime.clone()))
}

/// Switch to a new UDP socket
@@ -246,7 +246,7 @@
/// connections and connections to servers unreachable from the new address will be lost.
///
/// On error, the old UDP socket is retained.
pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
pub fn rebind_abstract(&self, socket: Box<dyn AsyncUdpSocket>) -> io::Result<()> {
let addr = socket.local_addr()?;
let mut inner = self.inner.state.lock().unwrap();
inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
@@ -255,7 +255,7 @@
// Update connection socket references
for sender in inner.recv_state.connections.senders.values() {
// Ignoring errors from dropped connections
let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
let _ = sender.send(ConnectionEvent::Rebind(inner.socket.create_sender()));
}

Ok(())
@@ -420,16 +420,16 @@ impl EndpointInner {
{
Ok((handle, conn)) => {
state.stats.accepted_handshakes += 1;
let socket = state.socket.clone();
let sender = state.socket.create_sender();
let runtime = state.runtime.clone();
Ok(state
.recv_state
.connections
.insert(handle, conn, socket, runtime))
.insert(handle, conn, sender, runtime))
}
Err(error) => {
if let Some(transmit) = error.response {
respond(transmit, &response_buffer, &*state.socket);
respond(transmit, &response_buffer, &mut state.sender);
}
Err(error.cause)
}
@@ -441,14 +441,14 @@
state.stats.refused_handshakes += 1;
let mut response_buffer = Vec::new();
let transmit = state.inner.refuse(incoming, &mut response_buffer);
respond(transmit, &response_buffer, &*state.socket);
respond(transmit, &response_buffer, &mut state.sender);
}

pub(crate) fn retry(&self, incoming: proto::Incoming) -> Result<(), proto::RetryError> {
let mut state = self.state.lock().unwrap();
let mut response_buffer = Vec::new();
let transmit = state.inner.retry(incoming, &mut response_buffer)?;
respond(transmit, &response_buffer, &*state.socket);
respond(transmit, &response_buffer, &mut state.sender);
Ok(())
}

@@ -461,10 +461,11 @@

#[derive(Debug)]
pub(crate) struct State {
socket: Arc<dyn AsyncUdpSocket>,
socket: Box<dyn AsyncUdpSocket>,
sender: Pin<Box<dyn UdpSender>>,
/// During an active migration, prev_socket receives traffic
/// until the first packet arrives on the new socket.
prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
prev_socket: Option<Box<dyn AsyncUdpSocket>>,
inner: proto::Endpoint,
recv_state: RecvState,
driver: Option<Waker>,
@@ -487,18 +488,28 @@ impl State {
fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
let get_time = || self.runtime.now();
self.recv_state.recv_limiter.start_cycle(get_time);
if let Some(socket) = &self.prev_socket {
if let Some(socket) = &mut self.prev_socket {
// We don't care about the `PollProgress` from old sockets.
let poll_res =
self.recv_state
.poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
let poll_res = self.recv_state.poll_socket(
cx,
&mut self.inner,
&mut *socket,
&mut self.sender,
&*self.runtime,
now,
);
if poll_res.is_err() {
self.prev_socket = None;
}
};
let poll_res =
self.recv_state
.poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
let poll_res = self.recv_state.poll_socket(
cx,
&mut self.inner,
&mut self.socket,
&mut self.sender,
&*self.runtime,
now,
);
self.recv_state.recv_limiter.finish_cycle(get_time);
let poll_res = poll_res?;
if poll_res.received_connection_packet {
@@ -550,7 +561,11 @@ impl Drop for State {
}
}

fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
fn respond(
transmit: proto::Transmit,
response_buffer: &[u8],
sender: &mut Pin<Box<dyn UdpSender>>,
) {
// Send if there's kernel buffer space; otherwise, drop it
//
// As an endpoint-generated packet, we know this is an
@@ -571,7 +586,9 @@ fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn Async
// to transmit. This is morally equivalent to the packet getting
// lost due to congestion further along the link, which
// similarly relies on peer retries for recovery.
_ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
_ = sender
.as_mut()
.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
}

#[inline]
Expand All @@ -598,7 +615,7 @@ impl ConnectionSet {
&mut self,
handle: ConnectionHandle,
conn: proto::Connection,
socket: Arc<dyn AsyncUdpSocket>,
sender: Pin<Box<dyn UdpSender>>,
runtime: Arc<dyn Runtime>,
) -> Connecting {
let (send, recv) = mpsc::unbounded_channel();
@@ -610,7 +627,7 @@
.unwrap();
}
self.senders.insert(handle, send);
Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
Connecting::new(handle, conn, self.sender.clone(), recv, sender, runtime)
}

fn is_empty(&self) -> bool {
@@ -669,20 +686,22 @@ pub(crate) struct EndpointRef(Arc<EndpointInner>);

impl EndpointRef {
pub(crate) fn new(
socket: Arc<dyn AsyncUdpSocket>,
socket: Box<dyn AsyncUdpSocket>,
inner: proto::Endpoint,
ipv6: bool,
runtime: Arc<dyn Runtime>,
) -> Self {
let (sender, events) = mpsc::unbounded_channel();
let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
let sender = socket.create_sender();
Self(Arc::new(EndpointInner {
shared: Shared {
incoming: Notify::new(),
idle: Notify::new(),
},
state: Mutex::new(State {
socket,
sender,
prev_socket: None,
inner,
ipv6,
@@ -764,7 +783,8 @@ impl RecvState {
&mut self,
cx: &mut Context,
endpoint: &mut proto::Endpoint,
socket: &dyn AsyncUdpSocket,
socket: &mut Box<dyn AsyncUdpSocket>,
sender: &mut Pin<Box<dyn UdpSender>>,
runtime: &dyn Runtime,
now: Instant,
) -> Result<PollProgress, io::Error> {
@@ -804,7 +824,7 @@ } else {
} else {
let transmit =
endpoint.refuse(incoming, &mut response_buffer);
respond(transmit, &response_buffer, socket);
respond(transmit, &response_buffer, sender);
}
}
Some(DatagramEvent::ConnectionEvent(handle, event)) => {
@@ -818,7 +838,7 @@
.send(ConnectionEvent::Proto(event));
}
Some(DatagramEvent::Response(transmit)) => {
respond(transmit, &response_buffer, socket);
respond(transmit, &response_buffer, sender);
}
None => {}
}
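On the endpoint side the socket becomes a `Box` owned by `State`, and every consumer (the `respond` path, each connection, and migration handling) holds its own sender. A minimal sketch of the rebind fan-out this implies, reusing the placeholder traits from the first sketch; `EndpointState` and `connection_senders` are illustrative stand-ins for the real `State` / `ConnectionSet` fields:

```rust
use std::{collections::HashMap, io, mem, pin::Pin};
use tokio::sync::mpsc;

/// Illustrative subset of the per-connection event type; the real enum has
/// more variants (protocol events, and so on).
enum ConnectionEvent {
    Rebind(Pin<Box<dyn UdpSender>>),
}

/// Illustrative stand-in for the endpoint's `State`.
struct EndpointState {
    socket: Box<dyn AsyncUdpSocket>,
    sender: Pin<Box<dyn UdpSender>>,
    prev_socket: Option<Box<dyn AsyncUdpSocket>>,
    connection_senders: HashMap<usize, mpsc::UnboundedSender<ConnectionEvent>>,
}

impl EndpointState {
    /// Switch to a new socket, keeping the old one receiving until the first
    /// packet arrives on the new one, and hand every live connection a fresh,
    /// independently pollable sender.
    fn rebind(&mut self, socket: Box<dyn AsyncUdpSocket>) -> io::Result<()> {
        let _addr = socket.local_addr()?; // the real code feeds this to proto::Endpoint
        self.prev_socket = Some(mem::replace(&mut self.socket, socket));
        self.sender = self.socket.create_sender();
        for conn in self.connection_senders.values() {
            // Errors from connections that have already shut down are ignored.
            let _ = conn.send(ConnectionEvent::Rebind(self.socket.create_sender()));
        }
        Ok(())
    }
}
```

This mirrors `rebind_abstract` above: senders are cheap to create and owned per consumer, so a migration never invalidates a shared socket handle, it just pushes a fresh sender into each connection's event channel.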