diff --git a/protocols/stream/src/behaviour.rs b/protocols/stream/src/behaviour.rs index e72af8fbfce..977e5a6fd3f 100644 --- a/protocols/stream/src/behaviour.rs +++ b/protocols/stream/src/behaviour.rs @@ -32,7 +32,7 @@ impl Default for Behaviour { impl Behaviour { pub fn new() -> Self { - let (dial_sender, dial_receiver) = mpsc::channel(0); + let (dial_sender, dial_receiver) = mpsc::channel(32); Self { shared: Arc::new(Mutex::new(Shared::new(dial_sender))), diff --git a/protocols/stream/src/control.rs b/protocols/stream/src/control.rs index 2149c6bca48..1c0ad4ecfd9 100644 --- a/protocols/stream/src/control.rs +++ b/protocols/stream/src/control.rs @@ -8,7 +8,7 @@ use std::{ use futures::{ channel::{mpsc, oneshot}, - SinkExt as _, StreamExt as _, + StreamExt as _, }; use libp2p_identity::PeerId; use libp2p_swarm::{Stream, StreamProtocol}; @@ -48,14 +48,15 @@ impl Control { ) -> Result { tracing::debug!(%peer, "Requesting new stream"); - let mut new_stream_sender = Shared::lock(&self.shared).sender(peer); - let (sender, receiver) = oneshot::channel(); - new_stream_sender - .send(NewStream { protocol, sender }) - .await - .map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e))?; + Shared::send_new_stream( + &self.shared, + peer, + NewStream { protocol, sender }, + ) + .await + .map_err(OpenStreamError::Io)?; let stream = receiver .await diff --git a/protocols/stream/src/shared.rs b/protocols/stream/src/shared.rs index bee04b39fb1..2fb214dafb1 100644 --- a/protocols/stream/src/shared.rs +++ b/protocols/stream/src/shared.rs @@ -4,7 +4,7 @@ use std::{ sync::{Arc, Mutex, MutexGuard}, }; -use futures::channel::mpsc; +use futures::{channel::mpsc, SinkExt as _}; use libp2p_identity::PeerId; use libp2p_swarm::{ConnectionId, Stream, StreamProtocol}; use rand::seq::IteratorRandom as _; @@ -122,7 +122,7 @@ impl Shared { } } - pub(crate) fn sender(&mut self, peer: PeerId) -> mpsc::Sender { + fn prepare_sender(&mut self, peer: PeerId) -> SenderAction { let maybe_sender = self .connections .iter() @@ -134,7 +134,9 @@ impl Shared { Some(sender) => { tracing::debug!("Returning sender to existing connection"); - sender.clone() + SenderAction::Connected { + sender: sender.clone(), + } } None => { tracing::debug!(%peer, "Not connected to peer, initiating dial"); @@ -144,9 +146,44 @@ impl Shared { .entry(peer) .or_insert_with(|| mpsc::channel(0)); - let _ = self.dial_sender.try_send(peer); + SenderAction::Dial { + pending_sender: sender.clone(), + dial_sender: self.dial_sender.clone(), + peer_to_dial: peer, + } + } + } + } + + pub(crate) async fn send_new_stream( + shared: &Arc>, + peer: PeerId, + new_stream: NewStream, + ) -> io::Result<()> { + let action = { + let mut shared = Shared::lock(shared); + shared.prepare_sender(peer) + }; - sender.clone() + match action { + SenderAction::Connected { mut sender } => sender + .send(new_stream) + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e)), + SenderAction::Dial { + mut pending_sender, + mut dial_sender, + peer_to_dial, + } => { + dial_sender + .send(peer_to_dial) + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e.clone()))?; + + pending_sender + .send(new_stream) + .await + .map_err(|e| io::Error::new(io::ErrorKind::ConnectionReset, e)) } } } @@ -171,3 +208,14 @@ impl Shared { receiver } } + +enum SenderAction { + Connected { + sender: mpsc::Sender, + }, + Dial { + pending_sender: mpsc::Sender, + dial_sender: mpsc::Sender, + peer_to_dial: PeerId, + }, +} diff --git a/protocols/stream/tests/lib.rs b/protocols/stream/tests/lib.rs index a00a00a8f62..374372a68bf 100644 --- a/protocols/stream/tests/lib.rs +++ b/protocols/stream/tests/lib.rs @@ -78,3 +78,47 @@ async fn dial_errors_are_propagated() { assert_eq!(e.kind(), io::ErrorKind::NotConnected); assert_eq!("Dial error: no addresses for peer.", e.to_string()); } + +#[tokio::test] +async fn backpressure_on_many_concurrent_dials() { + let _ = tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::builder() + .with_default_directive(LevelFilter::DEBUG.into()) + .from_env() + .unwrap(), + ) + .with_test_writer() + .try_init(); + + let swarm1 = Swarm::new_ephemeral_tokio(|_| stream::Behaviour::new()); + let control = swarm1.behaviour().new_control(); + + tokio::spawn(swarm1.loop_on_next()); + + // Spawn many concurrent dial attempts that will all fail + // Before the fix: some would silently drop and hang forever + // After the fix: all should fail with proper errors (backpressure propagated) + let mut handles = vec![]; + + for _ in 0..50 { + let mut control_clone = control.clone(); + let handle = tokio::spawn(async move { + let result = control_clone.open_stream(PeerId::random(), PROTOCOL).await; + // All should fail, none should hang + assert!(result.is_err()); + }); + handles.push(handle); + } + + // All tasks should complete (not hang indefinitely) + for handle in handles { + tokio::time::timeout( + std::time::Duration::from_secs(5), + handle, + ) + .await + .expect("Task should not hang - backpressure should work") + .expect("Task should complete successfully"); + } +}