diff --git a/protocols/mdns/src/behaviour.rs b/protocols/mdns/src/behaviour.rs index 68e28cf3d63..35e5474f105 100644 --- a/protocols/mdns/src/behaviour.rs +++ b/protocols/mdns/src/behaviour.rs @@ -31,17 +31,17 @@ use std::{ convert::Infallible, fmt, future::Future, - io, + io, mem, net::IpAddr, pin::Pin, - sync::{Arc, RwLock}, - task::{Context, Poll}, + task::{Context, Poll, Waker}, time::Instant, }; use futures::{channel::mpsc, Stream, StreamExt}; use if_watch::IfEvent; -use libp2p_core::{transport::PortUse, Endpoint, Multiaddr}; +use iface::ListenAddressUpdate; +use libp2p_core::{multiaddr::Protocol, transport::PortUse, Endpoint, Multiaddr}; use libp2p_identity::PeerId; use libp2p_swarm::{ behaviour::FromSwarm, dummy, ConnectionDenied, ConnectionId, ListenAddresses, NetworkBehaviour, @@ -64,18 +64,11 @@ pub trait Provider: 'static { /// The IfWatcher type. type Watcher: Stream> + fmt::Debug + Unpin; - type TaskHandle: Abort; - /// Create a new instance of the `IfWatcher` type. fn new_watcher() -> Result; #[track_caller] - fn spawn(task: impl Future + Send + 'static) -> Self::TaskHandle; -} - -#[allow(unreachable_pub)] // Not re-exported. -pub trait Abort { - fn abort(self); + fn spawn(task: impl Future + Send + 'static); } /// The type of a [`Behaviour`] using the `async-io` implementation. @@ -83,11 +76,10 @@ pub trait Abort { pub mod async_io { use std::future::Future; - use async_std::task::JoinHandle; use if_watch::smol::IfWatcher; use super::Provider; - use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer, Abort}; + use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer}; #[doc(hidden)] pub enum AsyncIo {} @@ -96,20 +88,13 @@ pub mod async_io { type Socket = AsyncUdpSocket; type Timer = AsyncTimer; type Watcher = IfWatcher; - type TaskHandle = JoinHandle<()>; fn new_watcher() -> Result { IfWatcher::new() } - fn spawn(task: impl Future + Send + 'static) -> JoinHandle<()> { - async_std::task::spawn(task) - } - } - - impl Abort for JoinHandle<()> { - fn abort(self) { - async_std::task::spawn(self.cancel()); + fn spawn(task: impl Future + Send + 'static) { + async_std::task::spawn(task); } } @@ -122,10 +107,9 @@ pub mod tokio { use std::future::Future; use if_watch::tokio::IfWatcher; - use tokio::task::JoinHandle; use super::Provider; - use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer, Abort}; + use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer}; #[doc(hidden)] pub enum Tokio {} @@ -134,20 +118,13 @@ pub mod tokio { type Socket = TokioUdpSocket; type Timer = TokioTimer; type Watcher = IfWatcher; - type TaskHandle = JoinHandle<()>; fn new_watcher() -> Result { IfWatcher::new() } - fn spawn(task: impl Future + Send + 'static) -> Self::TaskHandle { - tokio::spawn(task) - } - } - - impl Abort for JoinHandle<()> { - fn abort(self) { - JoinHandle::abort(&self) + fn spawn(task: impl Future + Send + 'static) { + tokio::spawn(task); } } @@ -167,8 +144,8 @@ where /// Iface watcher. if_watch: P::Watcher, - /// Handles to tasks running the mDNS queries. - if_tasks: HashMap, + /// Channel for sending address updates to interface tasks. + if_tasks: HashMap>, query_response_receiver: mpsc::Receiver<(PeerId, Multiaddr, Instant)>, query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>, @@ -185,16 +162,17 @@ where closest_expiration: Option, /// The current set of listen addresses. - /// - /// This is shared across all interface tasks using an [`RwLock`]. - /// The [`Behaviour`] updates this upon new [`FromSwarm`] - /// events where as [`InterfaceState`]s read from it to answer inbound mDNS queries. - listen_addresses: Arc>, + listen_addresses: ListenAddresses, local_peer_id: PeerId, /// Pending behaviour events to be emitted. pending_events: VecDeque>, + + /// Pending address updates to send to interfaces. + pending_address_updates: Vec, + + waker: Waker, } impl

Behaviour

@@ -216,6 +194,8 @@ where listen_addresses: Default::default(), local_peer_id, pending_events: Default::default(), + pending_address_updates: Default::default(), + waker: Waker::noop().clone(), }) } @@ -241,6 +221,30 @@ where } self.closest_expiration = Some(P::Timer::at(now)); } + + /// Try to send an address update to the interface task that matches the address' IP. + /// + /// Returns the address if the sending failed due to a full channel. + fn try_send_address_update( + &mut self, + cx: &mut Context<'_>, + update: ListenAddressUpdate, + ) -> Option { + let ip = update.ip_addr()?; + let tx = self.if_tasks.get_mut(&ip)?; + match tx.poll_ready(cx) { + Poll::Ready(Ok(())) => { + tx.start_send(update).expect("Channel is ready."); + None + } + Poll::Ready(Err(e)) if e.is_disconnected() => { + tracing::error!("`InterfaceState` for ip {ip} dropped"); + self.if_tasks.remove(&ip); + None + } + _ => Some(update), + } + } } impl

NetworkBehaviour for Behaviour

@@ -301,10 +305,14 @@ where } fn on_swarm_event(&mut self, event: FromSwarm) { - self.listen_addresses - .write() - .unwrap_or_else(|e| e.into_inner()) - .on_swarm_event(&event); + if !self.listen_addresses.on_swarm_event(&event) { + return; + } + if let Some(update) = ListenAddressUpdate::from_swarm(event).and_then(|update| { + self.try_send_address_update(&mut Context::from_waker(&self.waker.clone()), update) + }) { + self.pending_address_updates.push(update); + } } #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self, cx))] @@ -313,6 +321,13 @@ where cx: &mut Context<'_>, ) -> Poll>> { loop { + // Send address updates to interface tasks. + for update in mem::take(&mut self.pending_address_updates) { + if let Some(update) = self.try_send_address_update(cx, update) { + self.pending_address_updates.push(update); + } + } + // Check for pending events and emit them. if let Some(event) = self.pending_events.pop_front() { return Poll::Ready(event); @@ -322,25 +337,34 @@ where while let Poll::Ready(Some(event)) = Pin::new(&mut self.if_watch).poll_next(cx) { match event { Ok(IfEvent::Up(inet)) => { - let addr = inet.addr(); - if addr.is_loopback() { + let ip_addr = inet.addr(); + if ip_addr.is_loopback() { continue; } - if addr.is_ipv4() && self.config.enable_ipv6 - || addr.is_ipv6() && !self.config.enable_ipv6 + if ip_addr.is_ipv4() && self.config.enable_ipv6 + || ip_addr.is_ipv6() && !self.config.enable_ipv6 { continue; } - if let Entry::Vacant(e) = self.if_tasks.entry(addr) { + if let Entry::Vacant(e) = self.if_tasks.entry(ip_addr) { + let (addr_tx, addr_rx) = mpsc::channel(10); // Chosen arbitrarily. + let listen_addresses = self + .listen_addresses + .iter() + .filter(|multiaddr| multiaddr_matches_ip(multiaddr, &ip_addr)) + .cloned() + .collect(); match InterfaceState::::new( - addr, + ip_addr, self.config.clone(), self.local_peer_id, - self.listen_addresses.clone(), + listen_addresses, + addr_rx, self.query_response_sender.clone(), ) { Ok(iface_state) => { - e.insert(P::spawn(iface_state)); + P::spawn(iface_state); + e.insert(addr_tx); } Err(err) => { tracing::error!("failed to create `InterfaceState`: {}", err) @@ -349,10 +373,8 @@ where } } Ok(IfEvent::Down(inet)) => { - if let Some(handle) = self.if_tasks.remove(&inet.addr()) { + if self.if_tasks.remove(&inet.addr()).is_some() { tracing::info!(instance=%inet.addr(), "dropping instance"); - - handle.abort(); } } Err(err) => tracing::error!("if watch returned an error: {}", err), @@ -417,11 +439,20 @@ where self.closest_expiration = Some(timer); } + self.waker = cx.waker().clone(); return Poll::Pending; } } } +fn multiaddr_matches_ip(addr: &Multiaddr, ip: &IpAddr) -> bool { + match addr.iter().next() { + Some(Protocol::Ip4(ipv4)) => &IpAddr::V4(ipv4) == ip, + Some(Protocol::Ip6(ipv6)) => &IpAddr::V6(ipv6) == ip, + _ => false, + } +} + /// Event that can be produced by the `Mdns` behaviour. #[derive(Debug, Clone)] pub enum Event { diff --git a/protocols/mdns/src/behaviour/iface.rs b/protocols/mdns/src/behaviour/iface.rs index 873bb8a307b..15ee016fcf3 100644 --- a/protocols/mdns/src/behaviour/iface.rs +++ b/protocols/mdns/src/behaviour/iface.rs @@ -27,15 +27,14 @@ use std::{ io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, pin::Pin, - sync::{Arc, RwLock}, task::{Context, Poll}, time::{Duration, Instant}, }; use futures::{channel::mpsc, SinkExt, StreamExt}; -use libp2p_core::Multiaddr; +use libp2p_core::{multiaddr::Protocol, Multiaddr}; use libp2p_identity::PeerId; -use libp2p_swarm::ListenAddresses; +use libp2p_swarm::{ExpiredListenAddr, FromSwarm, NewListenAddr}; use socket2::{Domain, Socket, Type}; use self::{ @@ -71,6 +70,38 @@ impl ProbeState { } } +/// Event to inform the [`InterfaceState`] of a change in listening addresses. +#[derive(Debug, Clone)] +pub(crate) enum ListenAddressUpdate { + New(Multiaddr), + Expired(Multiaddr), +} + +impl ListenAddressUpdate { + pub(crate) fn from_swarm(event: FromSwarm) -> Option { + match event { + FromSwarm::NewListenAddr(NewListenAddr { addr, .. }) => { + Some(ListenAddressUpdate::New(addr.clone())) + } + FromSwarm::ExpiredListenAddr(ExpiredListenAddr { addr, .. }) => { + Some(ListenAddressUpdate::Expired(addr.clone())) + } + _ => None, + } + } + + pub(crate) fn ip_addr(&self) -> Option { + let addr = match self { + ListenAddressUpdate::New(a) | ListenAddressUpdate::Expired(a) => a, + }; + match addr.iter().next()? { + Protocol::Ip4(a) => Some(IpAddr::V4(a)), + Protocol::Ip6(a) => Some(IpAddr::V6(a)), + _ => None, + } + } +} + /// An mDNS instance for a networking interface. To discover all peers when having multiple /// interfaces an [`InterfaceState`] is required for each interface. #[derive(Debug)] @@ -81,8 +112,10 @@ pub(crate) struct InterfaceState { recv_socket: U, /// Send socket. send_socket: U, - - listen_addresses: Arc>, + /// Current listening addresses. + listen_addresses: Vec, + /// Receiver for listening-address updates from the swarm. + listen_addresses_rx: mpsc::Receiver, query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>, @@ -119,7 +152,8 @@ where addr: IpAddr, config: Config, local_peer_id: PeerId, - listen_addresses: Arc>, + listen_addresses: Vec, + listen_addresses_rx: mpsc::Receiver, query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>, ) -> io::Result { tracing::info!(address=%addr, "creating instance on iface address"); @@ -175,6 +209,7 @@ where recv_socket, send_socket, listen_addresses, + listen_addresses_rx, query_response_sender, recv_buffer: [0; 4096], send_buffer: Default::default(), @@ -210,7 +245,21 @@ where let this = self.get_mut(); loop { - // 1st priority: Low latency: Create packet ASAP after timeout. + // 1st priority: Poll for a change in listen addresses. + match this.listen_addresses_rx.poll_next_unpin(cx) { + Poll::Ready(Some(ListenAddressUpdate::New(addr))) => { + this.listen_addresses.push(addr); + continue; + } + Poll::Ready(Some(ListenAddressUpdate::Expired(addr))) => { + this.listen_addresses.retain(|a| a != &addr); + continue; + } + Poll::Ready(None) => return Poll::Ready(()), + Poll::Pending => {} + } + + // 2nd priority: Low latency: Create packet ASAP after timeout. if this.timeout.poll_next_unpin(cx).is_ready() { tracing::trace!(address=%this.addr, "sending query on iface"); this.send_buffer.push_back(build_query()); @@ -229,7 +278,7 @@ where this.reset_timer(); } - // 2nd priority: Keep local buffers small: Send packets to remote. + // 3d priority: Keep local buffers small: Send packets to remote. if let Some(packet) = this.send_buffer.pop_front() { match this.send_socket.poll_write(cx, &packet, this.mdns_socket()) { Poll::Ready(Ok(_)) => { @@ -246,7 +295,7 @@ where } } - // 3rd priority: Keep local buffers small: Return discovered addresses. + // 4th priority: Keep local buffers small: Return discovered addresses. if this.query_response_sender.poll_ready_unpin(cx).is_ready() { if let Some(discovered) = this.discovered.pop_front() { match this.query_response_sender.try_send(discovered) { @@ -263,7 +312,7 @@ where } } - // 4th priority: Remote work: Answer incoming requests. + // 5th priority: Remote work: Answer incoming requests. match this .recv_socket .poll_read(cx, &mut this.recv_buffer) @@ -279,10 +328,7 @@ where this.send_buffer.extend(build_query_response( query.query_id(), this.local_peer_id, - this.listen_addresses - .read() - .unwrap_or_else(|e| e.into_inner()) - .iter(), + this.listen_addresses.iter(), this.ttl, )); continue; diff --git a/protocols/mdns/tests/use-async-std.rs b/protocols/mdns/tests/use-async-std.rs index df08b39af07..4eaf3dac7e2 100644 --- a/protocols/mdns/tests/use-async-std.rs +++ b/protocols/mdns/tests/use-async-std.rs @@ -159,18 +159,18 @@ async fn create_swarm(config: Config) -> Swarm { Swarm::new_ephemeral(|key| Behaviour::new(config, key.public().to_peer_id()).unwrap()); // Manually listen on all interfaces because mDNS only works for non-loopback addresses. - let expected_listener_id = swarm + let expected_listener_id_ip4 = swarm .listen_on("/ip4/0.0.0.0/tcp/0".parse().unwrap()) .unwrap(); + let expected_listener_id_ip6 = swarm.listen_on("/ip6/::/tcp/0".parse().unwrap()).unwrap(); - swarm - .wait(|e| match e { - SwarmEvent::NewListenAddr { listener_id, .. } => { - (listener_id == expected_listener_id).then_some(()) - } - _ => None, - }) - .await; + let mut listen_both = false; + while !listen_both { + if let SwarmEvent::NewListenAddr { listener_id, .. } = swarm.next_swarm_event().await { + listen_both |= listener_id == expected_listener_id_ip4; + listen_both |= listener_id == expected_listener_id_ip6; + } + } swarm } diff --git a/protocols/mdns/tests/use-tokio.rs b/protocols/mdns/tests/use-tokio.rs index 0ec90a52b90..2bc6071395d 100644 --- a/protocols/mdns/tests/use-tokio.rs +++ b/protocols/mdns/tests/use-tokio.rs @@ -114,18 +114,19 @@ async fn create_swarm(config: Config) -> Swarm { Swarm::new_ephemeral(|key| Behaviour::new(config, key.public().to_peer_id()).unwrap()); // Manually listen on all interfaces because mDNS only works for non-loopback addresses. - let expected_listener_id = swarm + let expected_listener_id_ip4 = swarm .listen_on("/ip4/0.0.0.0/tcp/0".parse().unwrap()) .unwrap(); + let expected_listener_id_ip6 = swarm.listen_on("/ip6/::/tcp/0".parse().unwrap()).unwrap(); - swarm - .wait(|e| match e { - SwarmEvent::NewListenAddr { listener_id, .. } => { - (listener_id == expected_listener_id).then_some(()) - } - _ => None, - }) - .await; + let mut listen_both = false; + + while !listen_both { + if let SwarmEvent::NewListenAddr { listener_id, .. } = swarm.next_swarm_event().await { + listen_both |= listener_id == expected_listener_id_ip4; + listen_both |= listener_id == expected_listener_id_ip6; + } + } swarm }