diff --git a/lightning-net-tokio/src/lib.rs b/lightning-net-tokio/src/lib.rs index 944033102c6..2ff88bc066a 100644 --- a/lightning-net-tokio/src/lib.rs +++ b/lightning-net-tokio/src/lib.rs @@ -689,6 +689,7 @@ mod tests { ) -> Result<(), ()> { Ok(()) } + fn peer_disconnected(&self, _their_node_id: PublicKey) {} fn handle_reply_channel_range( &self, _their_node_id: PublicKey, _msg: ReplyChannelRange, ) -> Result<(), LightningError> { diff --git a/lightning/src/ln/msgs.rs b/lightning/src/ln/msgs.rs index 659ec65f6cf..1323fab435f 100644 --- a/lightning/src/ln/msgs.rs +++ b/lightning/src/ln/msgs.rs @@ -1578,6 +1578,8 @@ pub trait ChannelMessageHandler : MessageSendEventsProvider { /// May return an `Err(())` if the features the peer supports are not sufficient to communicate /// with us. Implementors should be somewhat conservative about doing so, however, as other /// message handlers may still wish to communicate with this peer. + /// + /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. fn peer_connected(&self, their_node_id: PublicKey, msg: &Init, inbound: bool) -> Result<(), ()>; /// Handle an incoming `channel_reestablish` message from the given peer. fn handle_channel_reestablish(&self, their_node_id: PublicKey, msg: &ChannelReestablish); @@ -1656,7 +1658,11 @@ pub trait RoutingMessageHandler : MessageSendEventsProvider { /// May return an `Err(())` if the features the peer supports are not sufficient to communicate /// with us. Implementors should be somewhat conservative about doing so, however, as other /// message handlers may still wish to communicate with this peer. + /// + /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. fn peer_connected(&self, their_node_id: PublicKey, init: &Init, inbound: bool) -> Result<(), ()>; + /// Indicates a connection to the peer failed/an existing connection was lost. + fn peer_disconnected(&self, their_node_id: PublicKey); /// Handles the reply of a query we initiated to learn about channels /// for a given range of blocks. We can expect to receive one or more /// replies to a single query. @@ -1707,6 +1713,8 @@ pub trait OnionMessageHandler { /// May return an `Err(())` if the features the peer supports are not sufficient to communicate /// with us. Implementors should be somewhat conservative about doing so, however, as other /// message handlers may still wish to communicate with this peer. + /// + /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. fn peer_connected(&self, their_node_id: PublicKey, init: &Init, inbound: bool) -> Result<(), ()>; /// Indicates a connection to the peer failed/an existing connection was lost. Allows handlers to diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs index 80b92cec1bd..8df168fee12 100644 --- a/lightning/src/ln/peer_handler.rs +++ b/lightning/src/ln/peer_handler.rs @@ -88,6 +88,8 @@ pub trait CustomMessageHandler: wire::CustomMessageReader { /// May return an `Err(())` if the features the peer supports are not sufficient to communicate /// with us. Implementors should be somewhat conservative about doing so, however, as other /// message handlers may still wish to communicate with this peer. + /// + /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. fn peer_connected(&self, their_node_id: PublicKey, msg: &Init, inbound: bool) -> Result<(), ()>; /// Gets the node feature flags which this handler itself supports. All available handlers are @@ -119,6 +121,7 @@ impl RoutingMessageHandler for IgnoringMessageHandler { Option<(msgs::ChannelAnnouncement, Option, Option)> { None } fn get_next_node_announcement(&self, _starting_point: Option<&NodeId>) -> Option { None } fn peer_connected(&self, _their_node_id: PublicKey, _init: &msgs::Init, _inbound: bool) -> Result<(), ()> { Ok(()) } + fn peer_disconnected(&self, _their_node_id: PublicKey) { } fn handle_reply_channel_range(&self, _their_node_id: PublicKey, _msg: msgs::ReplyChannelRange) -> Result<(), LightningError> { Ok(()) } fn handle_reply_short_channel_ids_end(&self, _their_node_id: PublicKey, _msg: msgs::ReplyShortChannelIdsEnd) -> Result<(), LightningError> { Ok(()) } fn handle_query_channel_range(&self, _their_node_id: PublicKey, _msg: msgs::QueryChannelRange) -> Result<(), LightningError> { Ok(()) } @@ -1714,14 +1717,20 @@ impl Self { + Self { + features, + conn_tracker: test_utils::ConnectionTracker::new(), + } + } } impl wire::CustomMessageReader for TestCustomMessageHandler { @@ -2872,10 +2893,13 @@ mod tests { fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { Vec::new() } + fn peer_disconnected(&self, their_node_id: PublicKey) { + self.conn_tracker.peer_disconnected(their_node_id); + } - fn peer_disconnected(&self, _their_node_id: PublicKey) {} - - fn peer_connected(&self, _their_node_id: PublicKey, _msg: &Init, _inbound: bool) -> Result<(), ()> { Ok(()) } + fn peer_connected(&self, their_node_id: PublicKey, _msg: &Init, _inbound: bool) -> Result<(), ()> { + self.conn_tracker.peer_connected(their_node_id) + } fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() } @@ -2898,7 +2922,7 @@ mod tests { chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)), logger: test_utils::TestLogger::with_id(i.to_string()), routing_handler: test_utils::TestRoutingMessageHandler::new(), - custom_handler: TestCustomMessageHandler { features }, + custom_handler: TestCustomMessageHandler::new(features), node_signer: test_utils::TestNodeSigner::new(node_secret), } ); @@ -2921,7 +2945,7 @@ mod tests { chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)), logger: test_utils::TestLogger::new(), routing_handler: test_utils::TestRoutingMessageHandler::new(), - custom_handler: TestCustomMessageHandler { features }, + custom_handler: TestCustomMessageHandler::new(features), node_signer: test_utils::TestNodeSigner::new(node_secret), } ); @@ -2941,7 +2965,7 @@ mod tests { chan_handler: test_utils::TestChannelMessageHandler::new(network), logger: test_utils::TestLogger::new(), routing_handler: test_utils::TestRoutingMessageHandler::new(), - custom_handler: TestCustomMessageHandler { features }, + custom_handler: TestCustomMessageHandler::new(features), node_signer: test_utils::TestNodeSigner::new(node_secret), } ); @@ -2965,19 +2989,16 @@ mod tests { peers } - fn establish_connection<'a>(peer_a: &PeerManager, peer_b: &PeerManager) -> (FileDescriptor, FileDescriptor) { + fn try_establish_connection<'a>(peer_a: &PeerManager, peer_b: &PeerManager) -> (FileDescriptor, FileDescriptor, Result, Result) { + let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000}; + let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001}; + static FD_COUNTER: AtomicUsize = AtomicUsize::new(0); let fd = FD_COUNTER.fetch_add(1, Ordering::Relaxed) as u16; let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap(); let mut fd_a = FileDescriptor::new(fd); - let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000}; - - let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap(); - let features_a = peer_a.init_features(id_b); - let features_b = peer_b.init_features(id_a); let mut fd_b = FileDescriptor::new(fd); - let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001}; let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap(); peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap(); @@ -2989,11 +3010,30 @@ mod tests { peer_b.process_events(); let b_data = fd_b.outbound_data.lock().unwrap().split_off(0); - assert_eq!(peer_a.read_event(&mut fd_a, &b_data).unwrap(), false); + let a_refused = peer_a.read_event(&mut fd_a, &b_data); peer_a.process_events(); let a_data = fd_a.outbound_data.lock().unwrap().split_off(0); - assert_eq!(peer_b.read_event(&mut fd_b, &a_data).unwrap(), false); + let b_refused = peer_b.read_event(&mut fd_b, &a_data); + + (fd_a, fd_b, a_refused, b_refused) + } + + + fn establish_connection<'a>(peer_a: &PeerManager, peer_b: &PeerManager) -> (FileDescriptor, FileDescriptor) { + let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000}; + let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001}; + + let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap(); + let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap(); + + let features_a = peer_a.init_features(id_b); + let features_b = peer_b.init_features(id_a); + + let (fd_a, fd_b, a_refused, b_refused) = try_establish_connection(peer_a, peer_b); + + assert_eq!(a_refused.unwrap(), false); + assert_eq!(b_refused.unwrap(), false); assert_eq!(peer_a.peer_by_node_id(&id_b).unwrap().counterparty_node_id, id_b); assert_eq!(peer_a.peer_by_node_id(&id_b).unwrap().socket_address, Some(addr_b)); @@ -3246,6 +3286,50 @@ mod tests { assert_eq!(peers[0].peers.read().unwrap().len(), 0); } + fn do_test_peer_connected_error_disconnects(handler: usize) { + // Test that if a message handler fails a connection in `peer_connected` we reliably + // produce `peer_disconnected` events for all other message handlers (that saw a + // corresponding `peer_connected`). + let cfgs = create_peermgr_cfgs(2); + let peers = create_network(2, &cfgs); + + match handler & !1 { + 0 => { + peers[handler & 1].message_handler.chan_handler.conn_tracker.fail_connections.store(true, Ordering::Release); + } + 2 => { + peers[handler & 1].message_handler.route_handler.conn_tracker.fail_connections.store(true, Ordering::Release); + } + 4 => { + peers[handler & 1].message_handler.custom_message_handler.conn_tracker.fail_connections.store(true, Ordering::Release); + } + _ => panic!(), + } + let (_sd1, _sd2, a_refused, b_refused) = try_establish_connection(&peers[0], &peers[1]); + if handler & 1 == 0 { + assert!(a_refused.is_err()); + assert!(peers[0].list_peers().is_empty()); + } else { + assert!(b_refused.is_err()); + assert!(peers[1].list_peers().is_empty()); + } + // At least one message handler should have seen the connection. + assert!(peers[handler & 1].message_handler.chan_handler.conn_tracker.had_peers.load(Ordering::Acquire) || + peers[handler & 1].message_handler.route_handler.conn_tracker.had_peers.load(Ordering::Acquire) || + peers[handler & 1].message_handler.custom_message_handler.conn_tracker.had_peers.load(Ordering::Acquire)); + // And both message handlers doing tracking should see the disconnection + assert!(peers[handler & 1].message_handler.chan_handler.conn_tracker.connected_peers.lock().unwrap().is_empty()); + assert!(peers[handler & 1].message_handler.route_handler.conn_tracker.connected_peers.lock().unwrap().is_empty()); + assert!(peers[handler & 1].message_handler.custom_message_handler.conn_tracker.connected_peers.lock().unwrap().is_empty()); + } + + #[test] + fn test_peer_connected_error_disconnects() { + for i in 0..6 { + do_test_peer_connected_error_disconnects(i); + } + } + #[test] fn test_do_attempt_write_data() { // Create 2 peers with custom TestRoutingMessageHandlers and connect them. diff --git a/lightning/src/routing/gossip.rs b/lightning/src/routing/gossip.rs index c552005d9ca..6c1bc096437 100644 --- a/lightning/src/routing/gossip.rs +++ b/lightning/src/routing/gossip.rs @@ -701,6 +701,8 @@ where Ok(()) } + fn peer_disconnected(&self, _their_node_id: PublicKey) {} + fn handle_reply_channel_range( &self, _their_node_id: PublicKey, _msg: ReplyChannelRange, ) -> Result<(), LightningError> { diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 48bdfe2324a..2e89c51bd51 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -889,10 +889,45 @@ impl chaininterface::BroadcasterInterface for TestBroadcaster { } } +pub struct ConnectionTracker { + pub had_peers: AtomicBool, + pub connected_peers: Mutex>, + pub fail_connections: AtomicBool, +} + +impl ConnectionTracker { + pub fn new() -> Self { + Self { + had_peers: AtomicBool::new(false), + connected_peers: Mutex::new(Vec::new()), + fail_connections: AtomicBool::new(false), + } + } + + pub fn peer_connected(&self, their_node_id: PublicKey) -> Result<(), ()> { + self.had_peers.store(true, Ordering::Release); + let mut connected_peers = self.connected_peers.lock().unwrap(); + assert!(!connected_peers.contains(&their_node_id)); + if self.fail_connections.load(Ordering::Acquire) { + Err(()) + } else { + connected_peers.push(their_node_id); + Ok(()) + } + } + + pub fn peer_disconnected(&self, their_node_id: PublicKey) { + assert!(self.had_peers.load(Ordering::Acquire)); + let mut connected_peers = self.connected_peers.lock().unwrap(); + assert!(connected_peers.contains(&their_node_id)); + connected_peers.retain(|id| *id != their_node_id); + } +} + pub struct TestChannelMessageHandler { pub pending_events: Mutex>, expected_recv_msgs: Mutex>>>, - connected_peers: Mutex>, + pub conn_tracker: ConnectionTracker, chain_hash: ChainHash, } @@ -907,7 +942,7 @@ impl TestChannelMessageHandler { TestChannelMessageHandler { pending_events: Mutex::new(Vec::new()), expected_recv_msgs: Mutex::new(None), - connected_peers: Mutex::new(new_hash_set()), + conn_tracker: ConnectionTracker::new(), chain_hash, } } @@ -1019,15 +1054,14 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler { self.received_msg(wire::Message::ChannelReestablish(msg.clone())); } fn peer_disconnected(&self, their_node_id: PublicKey) { - assert!(self.connected_peers.lock().unwrap().remove(&their_node_id)); + self.conn_tracker.peer_disconnected(their_node_id) } fn peer_connected( &self, their_node_id: PublicKey, _msg: &msgs::Init, _inbound: bool, ) -> Result<(), ()> { - assert!(self.connected_peers.lock().unwrap().insert(their_node_id.clone())); // Don't bother with `received_msg` for Init as its auto-generated and we don't want to // bother re-generating the expected Init message in all tests. - Ok(()) + self.conn_tracker.peer_connected(their_node_id) } fn handle_error(&self, _their_node_id: PublicKey, msg: &msgs::ErrorMessage) { self.received_msg(wire::Message::Error(msg.clone())); @@ -1157,6 +1191,7 @@ pub struct TestRoutingMessageHandler { pub pending_events: Mutex>, pub request_full_sync: AtomicBool, pub announcement_available_for_sync: AtomicBool, + pub conn_tracker: ConnectionTracker, } impl TestRoutingMessageHandler { @@ -1168,6 +1203,7 @@ impl TestRoutingMessageHandler { pending_events, request_full_sync: AtomicBool::new(false), announcement_available_for_sync: AtomicBool::new(false), + conn_tracker: ConnectionTracker::new(), } } } @@ -1242,7 +1278,12 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { timestamp_range: u32::max_value(), }, }); - Ok(()) + + self.conn_tracker.peer_connected(their_node_id) + } + + fn peer_disconnected(&self, their_node_id: PublicKey) { + self.conn_tracker.peer_disconnected(their_node_id); } fn handle_reply_channel_range(