diff --git a/lightning-liquidity/src/lsps2/service.rs b/lightning-liquidity/src/lsps2/service.rs index abab51366ff..6d69925fa2d 100644 --- a/lightning-liquidity/src/lsps2/service.rs +++ b/lightning-liquidity/src/lsps2/service.rs @@ -12,6 +12,7 @@ use alloc::string::{String, ToString};
 use alloc::vec::Vec;
+use core::cmp::Ordering as CmpOrdering;
 use core::ops::Deref;
 use core::sync::atomic::{AtomicUsize, Ordering};
@@ -645,13 +646,21 @@ where
 match self.remove_pending_request(&mut peer_state_lock, &request_id) {
 Some(LSPS2Request::GetInfo(_)) => {
- let response = LSPS2Response::GetInfo(LSPS2GetInfoResponse {
- opening_fee_params_menu: opening_fee_params_menu
+ let mut opening_fee_params_menu: Vec<LSPS2OpeningFeeParams> =
+ opening_fee_params_menu
 .into_iter()
 .map(|param| {
 param.into_opening_fee_params(&self.config.promise_secret)
 })
- .collect(),
+ .collect();
+ opening_fee_params_menu.sort_by(|a, b| {
+ match a.min_fee_msat.cmp(&b.min_fee_msat) {
+ CmpOrdering::Equal => a.proportional.cmp(&b.proportional),
+ other => other,
+ }
+ });
+ let response = LSPS2Response::GetInfo(LSPS2GetInfoResponse {
+ opening_fee_params_menu,
 });
 (Ok(()), Some(response))
 },
diff --git a/lightning-liquidity/tests/lsps2_integration_tests.rs b/lightning-liquidity/tests/lsps2_integration_tests.rs index ef88d6220a4..a2721cab1de 100644 --- a/lightning-liquidity/tests/lsps2_integration_tests.rs +++ b/lightning-liquidity/tests/lsps2_integration_tests.rs @@ -746,3 +746,75 @@ fn invalid_token_flow() {
 panic!("Expected LSPS2ClientEvent::GetInfoFailed event");
 }
 }
+
+#[test]
+fn opening_fee_params_menu_is_sorted_by_spec() {
+ let (service_node_id, client_node_id, service_node, client_node, _secret) =
+ setup_test_lsps2("opening_fee_params_menu_is_sorted_by_spec");
+
+ let client_handler = client_node.liquidity_manager.lsps2_client_handler().unwrap();
+ let service_handler = service_node.liquidity_manager.lsps2_service_handler().unwrap();
+
+ let _ = client_handler.request_opening_params(service_node_id, None);
+ let get_info_request = get_lsps_message!(client_node, service_node_id);
+ service_node.liquidity_manager.handle_custom_message(get_info_request, client_node_id).unwrap();
+
+ let get_info_event = service_node.liquidity_manager.next_event().unwrap();
+ let request_id = match get_info_event {
+ LiquidityEvent::LSPS2Service(LSPS2ServiceEvent::GetInfo { request_id, ..
}) => request_id, + _ => panic!("Unexpected event"), + }; + + let raw_params_generator = |min_fee_msat: u64, proportional: u32| LSPS2RawOpeningFeeParams { + min_fee_msat, + proportional, + valid_until: LSPSDateTime::from_str("2035-05-20T08:30:45Z").unwrap(), + min_lifetime: 144, + max_client_to_self_delay: 128, + min_payment_size_msat: 1, + max_payment_size_msat: 100_000_000, + }; + + let raw_params = vec![ + raw_params_generator(200, 20), // Will be sorted to position 2 + raw_params_generator(100, 10), // Will be sorted to position 0 (lowest min_fee, lowest proportional) + raw_params_generator(300, 30), // Will be sorted to position 4 (highest min_fee, highest proportional) + raw_params_generator(100, 20), // Will be sorted to position 1 (same min_fee as 0, higher proportional) + raw_params_generator(200, 30), // Will be sorted to position 3 (higher min_fee than 2, higher proportional) + ]; + + service_handler + .opening_fee_params_generated(&client_node_id, request_id.clone(), raw_params) + .unwrap(); + + let get_info_response = get_lsps_message!(service_node, client_node_id); + client_node + .liquidity_manager + .handle_custom_message(get_info_response, service_node_id) + .unwrap(); + + let event = client_node.liquidity_manager.next_event().unwrap(); + if let LiquidityEvent::LSPS2Client(LSPS2ClientEvent::OpeningParametersReady { + opening_fee_params_menu, + .. + }) = event + { + // The LSP, when ordering the opening_fee_params_menu array, MUST order by the following rules: + // The 0th item MAY have any parameters. + // Each succeeding item MUST, compared to the previous item, obey any one of the following: + // Have a larger min_fee_msat, and equal proportional. + // Have a larger proportional, and equal min_fee_msat. + // Have a larger min_fee_msat, AND larger proportional. + for (cur, next) in + opening_fee_params_menu.iter().zip(opening_fee_params_menu.iter().skip(1)) + { + let valid = (next.min_fee_msat > cur.min_fee_msat + && next.proportional == cur.proportional) + || (next.proportional > cur.proportional && next.min_fee_msat == cur.min_fee_msat) + || (next.min_fee_msat > cur.min_fee_msat && next.proportional > cur.proportional); + assert!(valid, "Params not sorted as per spec"); + } + } else { + panic!("Unexpected event"); + } +} diff --git a/lightning/src/events/mod.rs b/lightning/src/events/mod.rs index c16305bcca0..c9dc26f9e9c 100644 --- a/lightning/src/events/mod.rs +++ b/lightning/src/events/mod.rs @@ -1572,6 +1572,55 @@ pub enum Event { /// onion messages. peer_node_id: PublicKey, }, + /// Indicates that a funding transaction constructed via interactive transaction construction for a + /// channel is ready to be signed by the client. This event will only be triggered + /// if at least one input was contributed by the holder and needs to be signed. + /// + /// The transaction contains all inputs provided by both parties along with the channel's funding + /// output and a change output if applicable. + /// + /// No part of the transaction should be changed before signing as the content of the transaction + /// has already been negotiated with the counterparty. + /// + /// Each signature MUST use the SIGHASH_ALL flag to avoid invalidation of the initial commitment and + /// hence possible loss of funds. + /// + /// After signing, call [`ChannelManager::funding_transaction_signed`] with the (partially) signed + /// funding transaction. + /// + /// Generated in [`ChannelManager`] message handling. 
+ ///
+ /// [`ChannelManager`]: crate::ln::channelmanager::ChannelManager
+ /// [`ChannelManager::funding_transaction_signed`]: crate::ln::channelmanager::ChannelManager::funding_transaction_signed
+ FundingTransactionReadyForSigning {
+ /// The channel_id of the channel which you'll need to pass back into
+ /// [`ChannelManager::funding_transaction_signed`].
+ ///
+ /// [`ChannelManager::funding_transaction_signed`]: crate::ln::channelmanager::ChannelManager::funding_transaction_signed
+ channel_id: ChannelId,
+ /// The counterparty's node_id, which you'll need to pass back into
+ /// [`ChannelManager::funding_transaction_signed`].
+ ///
+ /// [`ChannelManager::funding_transaction_signed`]: crate::ln::channelmanager::ChannelManager::funding_transaction_signed
+ counterparty_node_id: PublicKey,
+ // TODO(dual_funding): Enable links when methods are implemented
+ /// The `user_channel_id` value passed in to `ChannelManager::create_dual_funded_channel` for outbound
+ /// channels, or to [`ChannelManager::accept_inbound_channel`] or `ChannelManager::accept_inbound_channel_with_contribution`
+ /// for inbound channels if [`UserConfig::manually_accept_inbound_channels`] config flag is set to true.
+ /// Otherwise `user_channel_id` will be randomized for an inbound channel.
+ /// This may be zero for objects serialized with LDK versions prior to 0.0.113.
+ ///
+ /// [`ChannelManager::accept_inbound_channel`]: crate::ln::channelmanager::ChannelManager::accept_inbound_channel
+ /// [`UserConfig::manually_accept_inbound_channels`]: crate::util::config::UserConfig::manually_accept_inbound_channels
+ // [`ChannelManager::create_dual_funded_channel`]: crate::ln::channelmanager::ChannelManager::create_dual_funded_channel
+ // [`ChannelManager::accept_inbound_channel_with_contribution`]: crate::ln::channelmanager::ChannelManager::accept_inbound_channel_with_contribution
+ user_channel_id: u128,
+ /// The unsigned transaction to be signed and passed back to
+ /// [`ChannelManager::funding_transaction_signed`].
+ ///
+ /// [`ChannelManager::funding_transaction_signed`]: crate::ln::channelmanager::ChannelManager::funding_transaction_signed
+ unsigned_transaction: Transaction,
+ },
 }

 impl Writeable for Event { @@ -1996,6 +2045,13 @@ impl Writeable for Event {
 (8, former_temporary_channel_id, required),
 });
 },
+ &Event::FundingTransactionReadyForSigning { .. } => {
+ 45u8.write(writer)?;
+ // We never write out FundingTransactionReadyForSigning events as, upon disconnection, peers
+ // drop any V2-established/spliced channels which have not yet exchanged the initial `commitment_signed`.
+ // We only exchange the initial `commitment_signed` after the client calls
+ // `ChannelManager::funding_transaction_signed` and ALWAYS before we send a `tx_signatures`.
+ },
 // Note that, going forward, all new events must only write data inside of
 // `write_tlv_fields`. Versions 0.0.101+ will ignore odd-numbered events that write
 // data via `write_tlv_fields`.
@@ -2560,6 +2616,10 @@ impl MaybeReadable for Event {
 former_temporary_channel_id: former_temporary_channel_id.0.unwrap(),
 }))
 },
+ 45u8 => {
+ // Value 45 is used for `Event::FundingTransactionReadyForSigning`.
+ Ok(None)
+ },
 // Versions prior to 0.0.100 did not ignore odd types, instead returning InvalidValue.
 // Version 0.0.100 failed to properly ignore odd types, possibly resulting in corrupt
 // reads.
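For illustration only (not part of the patch): a minimal sketch of how an event consumer might handle the new `Event::FundingTransactionReadyForSigning` and hand the signed transaction back via `ChannelManager::funding_transaction_signed`. Only the event fields and the `funding_transaction_signed` call come from this patch; the `wallet` handle and its `sign_funding_inputs_sighash_all` method are hypothetical placeholders for whatever signer the application uses.

    // Hypothetical event-loop fragment; `channel_manager` and `wallet` are application-provided.
    match event {
        Event::FundingTransactionReadyForSigning {
            channel_id, counterparty_node_id, unsigned_transaction, ..
        } => {
            // Sign only our own inputs, using SIGHASH_ALL, and do not otherwise modify the
            // negotiated transaction.
            let signed_tx = wallet.sign_funding_inputs_sighash_all(unsigned_transaction)?;
            // Hand the (partially) signed transaction back to LDK; LDK broadcasts the funding
            // transaction itself once both parties' signatures have been exchanged.
            channel_manager.funding_transaction_signed(
                &channel_id, &counterparty_node_id, signed_tx,
            )?;
        },
        _ => { /* handle other events */ },
    }
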
diff --git a/lightning/src/ln/channel.rs b/lightning/src/ln/channel.rs index 3861a7052f1..401075332bd 100644 --- a/lightning/src/ln/channel.rs +++ b/lightning/src/ln/channel.rs @@ -14,7 +14,7 @@ use bitcoin::constants::ChainHash; use bitcoin::script::{Builder, Script, ScriptBuf, WScriptHash}; use bitcoin::sighash::EcdsaSighashType; use bitcoin::transaction::{Transaction, TxIn, TxOut}; -use bitcoin::Weight; +use bitcoin::{Weight, Witness}; use bitcoin::hash_types::{BlockHash, Txid}; use bitcoin::hashes::sha256::Hash as Sha256; @@ -57,8 +57,7 @@ use crate::ln::channelmanager::{ use crate::ln::interactivetxs::{ calculate_change_output_value, get_output_weight, AbortReason, HandleTxCompleteResult, InteractiveTxConstructor, InteractiveTxConstructorArgs, InteractiveTxMessageSend, - InteractiveTxMessageSendResult, InteractiveTxSigningSession, OutputOwned, SharedOwnedOutput, - TX_COMMON_FIELDS_WEIGHT, + InteractiveTxMessageSendResult, InteractiveTxSigningSession, TX_COMMON_FIELDS_WEIGHT, }; use crate::ln::msgs; use crate::ln::msgs::{ClosingSigned, ClosingSignedFeeRange, DecodeError, OnionErrorPacket}; @@ -2628,24 +2627,12 @@ where // Note: For the error case when the inputs are insufficient, it will be handled after // the `calculate_change_output_value` call below let mut funding_outputs = Vec::new(); - let mut expected_remote_shared_funding_output = None; let shared_funding_output = TxOut { value: Amount::from_sat(self.funding.get_value_satoshis()), script_pubkey: self.funding.get_funding_redeemscript().to_p2wsh(), }; - if self.funding.is_outbound() { - funding_outputs.push( - OutputOwned::Shared(SharedOwnedOutput::new( - shared_funding_output, self.dual_funding_context.our_funding_satoshis, - )) - ); - } else { - let TxOut { value, script_pubkey } = shared_funding_output; - expected_remote_shared_funding_output = Some((script_pubkey, value.to_sat())); - } - // Optionally add change output let change_script = if let Some(script) = change_destination_opt { script @@ -2655,7 +2642,8 @@ where }; let change_value_opt = calculate_change_output_value( self.funding.is_outbound(), self.dual_funding_context.our_funding_satoshis, - &funding_inputs, &funding_outputs, + &funding_inputs, None, + &shared_funding_output.script_pubkey, &funding_outputs, self.dual_funding_context.funding_feerate_sat_per_1000_weight, change_script.minimal_non_dust().to_sat(), )?; @@ -2670,7 +2658,7 @@ where // Check dust limit again if change_value_decreased_with_fee > self.context.holder_dust_limit_satoshis { change_output.value = Amount::from_sat(change_value_decreased_with_fee); - funding_outputs.push(OutputOwned::Single(change_output)); + funding_outputs.push(change_output); } } @@ -2683,8 +2671,9 @@ where is_initiator: self.funding.is_outbound(), funding_tx_locktime: self.dual_funding_context.funding_tx_locktime, inputs_to_contribute: funding_inputs, + shared_funding_input: None, + shared_funding_output: (shared_funding_output, self.dual_funding_context.our_funding_satoshis), outputs_to_contribute: funding_outputs, - expected_remote_shared_funding_output, }; let mut tx_constructor = InteractiveTxConstructor::new(constructor_args)?; let msg = tx_constructor.take_initiator_first_message(); @@ -2811,7 +2800,7 @@ where }, }; - let funding_ready_for_sig_event = if signing_session.local_inputs_count() == 0 { + let funding_ready_for_sig_event_opt = if signing_session.local_inputs_count() == 0 { debug_assert_eq!(our_funding_satoshis, 0); if signing_session.provide_holder_witnesses(self.context.channel_id, Vec::new()).is_err() { 
debug_assert!( @@ -2825,28 +2814,12 @@
 }
 None
 } else {
- // TODO(dual_funding): Send event for signing if we've contributed funds.
- // Inform the user that SIGHASH_ALL must be used for all signatures when contributing
- // inputs/signatures.
- // Also warn the user that we don't do anything to prevent the counterparty from
- // providing non-standard witnesses which will prevent the funding transaction from
- // confirming. This warning must appear in doc comments wherever the user is contributing
- // funds, whether they are initiator or acceptor.
- //
- // The following warning can be used when the APIs allowing contributing inputs become available:
- // <div class="warning">
- // WARNING: LDK makes no attempt to prevent the counterparty from using non-standard inputs which
- // will prevent the funding transaction from being relayed on the bitcoin network and hence being
- // confirmed.
- // </div>
- debug_assert!(
- false,
- "We don't support users providing inputs but somehow we had more than zero inputs",
- );
- return Err(ChannelError::Close((
- "V2 channel rejected due to sender error".into(),
- ClosureReason::HolderForceClosed { broadcasted_latest_txn: Some(false) }
- )));
+ Some(Event::FundingTransactionReadyForSigning {
+ channel_id: self.context.channel_id,
+ counterparty_node_id: self.context.counterparty_node_id,
+ user_channel_id: self.context.user_id,
+ unsigned_transaction: signing_session.unsigned_tx().build_unsigned_tx(),
+ })
 };

 let mut channel_state = ChannelState::FundingNegotiated(FundingNegotiatedFlags::new());
@@ -2857,7 +2830,7 @@
 self.interactive_tx_constructor.take();
 self.interactive_tx_signing_session = Some(signing_session);

- Ok((commitment_signed, funding_ready_for_sig_event))
+ Ok((commitment_signed, funding_ready_for_sig_event_opt))
 }
 }

@@ -5630,7 +5603,7 @@ pub(super) struct DualFundingChannelContext {
 ///
 /// Note that this field may be emptied once the interactive negotiation has been started.
 #[allow(dead_code)] // TODO(dual_funding): Remove once contribution to V2 channels is enabled.
- pub our_funding_inputs: Vec<(TxIn, TransactionU16LenLimited)>,
+ pub our_funding_inputs: Vec<(TxIn, TransactionU16LenLimited, Weight)>,
 }

 // Holder designates channel data owned for the benefit of the user client.
@@ -6417,27 +6390,46 @@ where
 Ok(channel_monitor)
 }

- #[rustfmt::skip]
- pub fn commitment_signed<L: Deref>(&mut self, msg: &msgs::CommitmentSigned, logger: &L) -> Result<Option<ChannelMonitorUpdate>, ChannelError>
- where L::Target: Logger
+ pub fn commitment_signed<L: Deref>(
+ &mut self, msg: &msgs::CommitmentSigned, logger: &L,
+ ) -> Result<Option<ChannelMonitorUpdate>, ChannelError>
+ where
+ L::Target: Logger,
 {
 self.commitment_signed_check_state()?;

+ if !self.pending_funding.is_empty() {
+ return Err(ChannelError::close(
+ "Got a single commitment_signed message when expecting a batch".to_owned(),
+ ));
+ }
+
 let updates = self
 .context
 .validate_commitment_signed(&self.funding, &self.holder_commitment_point, msg, logger)
- .map(|LatestHolderCommitmentTXInfo { commitment_tx, htlc_outputs, nondust_htlc_sources }|
- vec![ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo {
- commitment_tx, htlc_outputs, claimed_htlcs: vec![], nondust_htlc_sources,
- }]
+ .map(
+ |LatestHolderCommitmentTXInfo {
+ commitment_tx,
+ htlc_outputs,
+ nondust_htlc_sources,
+ }| {
+ vec![ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo {
+ commitment_tx,
+ htlc_outputs,
+ claimed_htlcs: vec![],
+ nondust_htlc_sources,
+ }]
+ },
 )?;

 self.commitment_signed_update_monitor(updates, logger)
 }

- #[rustfmt::skip]
- pub fn commitment_signed_batch<L: Deref>(&mut self, batch: Vec<msgs::CommitmentSigned>, logger: &L) -> Result<Option<ChannelMonitorUpdate>, ChannelError>
- where L::Target: Logger
+ pub fn commitment_signed_batch<L: Deref>(
+ &mut self, batch: Vec<msgs::CommitmentSigned>, logger: &L,
+ ) -> Result<Option<ChannelMonitorUpdate>, ChannelError>
+ where
+ L::Target: Logger,
 {
 self.commitment_signed_check_state()?;

@@ -6446,15 +6438,22 @@
 let funding_txid = match msg.funding_txid {
 Some(funding_txid) => funding_txid,
 None => {
- return Err(ChannelError::close("Peer sent batched commitment_signed without a funding_txid".to_string()));
+ return Err(ChannelError::close(
+ "Peer sent batched commitment_signed without a funding_txid".to_string(),
+ ));
 },
 };

 match messages.entry(funding_txid) {
- btree_map::Entry::Vacant(entry) => { entry.insert(msg); },
+ btree_map::Entry::Vacant(entry) => {
+ entry.insert(msg);
+ },
 btree_map::Entry::Occupied(_) => {
- return Err(ChannelError::close(format!("Peer sent batched commitment_signed with duplicate funding_txid
{}", funding_txid))); - } + return Err(ChannelError::close(format!( + "Peer sent batched commitment_signed with duplicate funding_txid {}", + funding_txid + ))); + }, } } @@ -6464,36 +6463,56 @@ where .chain(self.pending_funding.iter()) .map(|funding| { let funding_txid = funding.get_funding_txo().unwrap().txid; - let msg = messages - .get(&funding_txid) - .ok_or_else(|| ChannelError::close(format!("Peer did not send a commitment_signed for pending splice transaction: {}", funding_txid)))?; + let msg = messages.get(&funding_txid).ok_or_else(|| { + ChannelError::close(format!( + "Peer did not send a commitment_signed for pending splice transaction: {}", + funding_txid + )) + })?; self.context .validate_commitment_signed(funding, &self.holder_commitment_point, msg, logger) - .map(|LatestHolderCommitmentTXInfo { commitment_tx, htlc_outputs, nondust_htlc_sources }| - ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo { - commitment_tx, htlc_outputs, claimed_htlcs: vec![], nondust_htlc_sources, - } + .map( + |LatestHolderCommitmentTXInfo { + commitment_tx, + htlc_outputs, + nondust_htlc_sources, + }| ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo { + commitment_tx, + htlc_outputs, + claimed_htlcs: vec![], + nondust_htlc_sources, + }, ) - } - ) + }) .collect::, ChannelError>>()?; self.commitment_signed_update_monitor(updates, logger) } - #[rustfmt::skip] fn commitment_signed_check_state(&self) -> Result<(), ChannelError> { if self.context.channel_state.is_quiescent() { - return Err(ChannelError::WarnAndDisconnect("Got commitment_signed message while quiescent".to_owned())); + return Err(ChannelError::WarnAndDisconnect( + "Got commitment_signed message while quiescent".to_owned(), + )); } if !matches!(self.context.channel_state, ChannelState::ChannelReady(_)) { - return Err(ChannelError::close("Got commitment signed message when channel was not in an operational state".to_owned())); + return Err(ChannelError::close( + "Got commitment signed message when channel was not in an operational state" + .to_owned(), + )); } if self.context.channel_state.is_peer_disconnected() { - return Err(ChannelError::close("Peer sent commitment_signed when we needed a channel_reestablish".to_owned())); + return Err(ChannelError::close( + "Peer sent commitment_signed when we needed a channel_reestablish".to_owned(), + )); } - if self.context.channel_state.is_both_sides_shutdown() && self.context.last_sent_closing_fee.is_some() { - return Err(ChannelError::close("Peer sent commitment_signed after we'd started exchanging closing_signeds".to_owned())); + if self.context.channel_state.is_both_sides_shutdown() + && self.context.last_sent_closing_fee.is_some() + { + return Err(ChannelError::close( + "Peer sent commitment_signed after we'd started exchanging closing_signeds" + .to_owned(), + )); } Ok(()) @@ -7069,6 +7088,45 @@ where } } + fn verify_interactive_tx_signatures(&mut self, _witnesses: &Vec) { + if let Some(ref mut _signing_session) = self.interactive_tx_signing_session { + // Check that sighash_all was used: + // TODO(dual_funding): Check sig for sighash + } + } + + pub fn funding_transaction_signed( + &mut self, witnesses: Vec, logger: &L, + ) -> Result, APIError> + where + L::Target: Logger, + { + self.verify_interactive_tx_signatures(&witnesses); + if let Some(ref mut signing_session) = self.interactive_tx_signing_session { + let logger = WithChannelContext::from(logger, &self.context, None); + if let Some(holder_tx_signatures) = signing_session + .provide_holder_witnesses(self.context.channel_id, 
witnesses) + .map_err(|err| APIError::APIMisuseError { err })? + { + if self.is_awaiting_initial_mon_persist() { + log_debug!(logger, "Not sending tx_signatures: a monitor update is in progress. Setting monitor_pending_tx_signatures."); + self.context.monitor_pending_tx_signatures = Some(holder_tx_signatures); + return Ok(None); + } + return Ok(Some(holder_tx_signatures)); + } else { + return Ok(None); + } + } else { + return Err(APIError::APIMisuseError { + err: format!( + "Channel with id {} not expecting funding signatures", + self.context.channel_id + ), + }); + } + } + #[rustfmt::skip] pub fn tx_signatures(&mut self, msg: &msgs::TxSignatures, logger: &L) -> Result<(Option, Option), ChannelError> where L::Target: Logger @@ -7353,6 +7411,18 @@ where assert!(self.context.channel_state.is_monitor_update_in_progress()); self.context.channel_state.clear_monitor_update_in_progress(); + // For channels established with V2 establishment we won't send a `tx_signatures` when we're in + // MonitorUpdateInProgress (and we assume the user will never directly broadcast the funding + // transaction and waits for us to do it). + let tx_signatures = self.context.monitor_pending_tx_signatures.take(); + if tx_signatures.is_some() { + if self.context.channel_state.is_their_tx_signatures_sent() { + self.context.channel_state = ChannelState::AwaitingChannelReady(AwaitingChannelReadyFlags::new()); + } else { + self.context.channel_state.set_our_tx_signatures_ready(); + } + } + // If we're past (or at) the AwaitingChannelReady stage on an outbound (or V2-established) channel, // try to (re-)broadcast the funding transaction as we may have declined to broadcast it when we // first received the funding_signed. @@ -7392,17 +7462,6 @@ where mem::swap(&mut finalized_claimed_htlcs, &mut self.context.monitor_pending_finalized_fulfills); let mut pending_update_adds = Vec::new(); mem::swap(&mut pending_update_adds, &mut self.context.monitor_pending_update_adds); - // For channels established with V2 establishment we won't send a `tx_signatures` when we're in - // MonitorUpdateInProgress (and we assume the user will never directly broadcast the funding - // transaction and waits for us to do it). 
- let tx_signatures = self.context.monitor_pending_tx_signatures.take(); - if tx_signatures.is_some() { - if self.context.channel_state.is_their_tx_signatures_sent() { - self.context.channel_state = ChannelState::AwaitingChannelReady(AwaitingChannelReadyFlags::new()); - } else { - self.context.channel_state.set_our_tx_signatures_ready(); - } - } if self.context.channel_state.is_peer_disconnected() { self.context.monitor_pending_revoke_and_ack = false; @@ -8691,7 +8750,7 @@ where #[rustfmt::skip] pub fn is_awaiting_initial_mon_persist(&self) -> bool { if !self.is_awaiting_monitor_update() { return false; } - if matches!( + if self.context.channel_state.is_interactive_signing() || matches!( self.context.channel_state, ChannelState::AwaitingChannelReady(flags) if flags.clone().clear(AwaitingChannelReadyFlags::THEIR_CHANNEL_READY | FundedStateFlags::PEER_DISCONNECTED | FundedStateFlags::MONITOR_UPDATE_IN_PROGRESS | AwaitingChannelReadyFlags::WAITING_FOR_BATCH).is_empty() ) { @@ -10837,7 +10896,7 @@ where pub fn new_outbound( fee_estimator: &LowerBoundedFeeEstimator, entropy_source: &ES, signer_provider: &SP, counterparty_node_id: PublicKey, their_features: &InitFeatures, funding_satoshis: u64, - funding_inputs: Vec<(TxIn, TransactionU16LenLimited)>, user_id: u128, config: &UserConfig, + funding_inputs: Vec<(TxIn, TransactionU16LenLimited, Weight)>, user_id: u128, config: &UserConfig, current_chain_height: u32, outbound_scid_alias: u64, funding_confirmation_target: ConfirmationTarget, logger: L, ) -> Result @@ -10979,23 +11038,19 @@ where /// Creates a new dual-funded channel from a remote side's request for one. /// Assumes chain_hash has already been checked and corresponds with what we expect! - /// TODO(dual_funding): Allow contributions, pass intended amount and inputs #[allow(dead_code)] // TODO(dual_funding): Remove once V2 channels is enabled. 
#[rustfmt::skip] pub fn new_inbound( fee_estimator: &LowerBoundedFeeEstimator, entropy_source: &ES, signer_provider: &SP, holder_node_id: PublicKey, counterparty_node_id: PublicKey, our_supported_features: &ChannelTypeFeatures, - their_features: &InitFeatures, msg: &msgs::OpenChannelV2, - user_id: u128, config: &UserConfig, current_chain_height: u32, logger: &L, + their_features: &InitFeatures, msg: &msgs::OpenChannelV2, user_id: u128, config: &UserConfig, + current_chain_height: u32, logger: &L, our_funding_satoshis: u64, + our_funding_inputs: Vec<(TxIn, TransactionU16LenLimited, Weight)>, ) -> Result where ES::Target: EntropySource, F::Target: FeeEstimator, L::Target: Logger, { - // TODO(dual_funding): Take these as input once supported - let our_funding_satoshis = 0u64; - let our_funding_inputs = Vec::new(); - let channel_value_satoshis = our_funding_satoshis.saturating_add(msg.common_fields.funding_satoshis); let counterparty_selected_channel_reserve_satoshis = get_v2_channel_reserve_satoshis( channel_value_satoshis, msg.common_fields.dust_limit_satoshis); @@ -11045,12 +11100,41 @@ where context.channel_id = channel_id; let dual_funding_context = DualFundingChannelContext { - our_funding_satoshis: our_funding_satoshis, + our_funding_satoshis, their_funding_satoshis: Some(msg.common_fields.funding_satoshis), funding_tx_locktime: LockTime::from_consensus(msg.locktime), funding_feerate_sat_per_1000_weight: msg.funding_feerate_sat_per_1000_weight, our_funding_inputs: our_funding_inputs.clone(), }; + let shared_funding_output = TxOut { + value: Amount::from_sat(funding.get_value_satoshis()), + script_pubkey: funding.get_funding_redeemscript().to_p2wsh(), + }; + + // Optionally add change output + let change_script = signer_provider.get_destination_script(context.channel_keys_id) + .map_err(|_| ChannelError::close("Error getting change destination script".to_string()))?; + let change_value_opt = calculate_change_output_value( + funding.is_outbound(), dual_funding_context.our_funding_satoshis, + &our_funding_inputs, None, &shared_funding_output.script_pubkey, &vec![], + dual_funding_context.funding_feerate_sat_per_1000_weight, + change_script.minimal_non_dust().to_sat(), + ).map_err(|_| ChannelError::close("Error calculating change output value".to_string()))?; + let mut our_funding_outputs = vec![]; + if let Some(change_value) = change_value_opt { + let mut change_output = TxOut { + value: Amount::from_sat(change_value), + script_pubkey: change_script, + }; + let change_output_weight = get_output_weight(&change_output.script_pubkey).to_wu(); + let change_output_fee = fee_for_weight(dual_funding_context.funding_feerate_sat_per_1000_weight, change_output_weight); + let change_value_decreased_with_fee = change_value.saturating_sub(change_output_fee); + // Check dust limit again + if change_value_decreased_with_fee > context.holder_dust_limit_satoshis { + change_output.value = Amount::from_sat(change_value_decreased_with_fee); + our_funding_outputs.push(change_output); + } + } let interactive_tx_constructor = Some(InteractiveTxConstructor::new( InteractiveTxConstructorArgs { @@ -11062,8 +11146,9 @@ where funding_tx_locktime: dual_funding_context.funding_tx_locktime, is_initiator: false, inputs_to_contribute: our_funding_inputs, - outputs_to_contribute: Vec::new(), - expected_remote_shared_funding_output: Some((funding.get_funding_redeemscript().to_p2wsh(), funding.get_value_satoshis())), + shared_funding_input: None, + shared_funding_output: (shared_funding_output, our_funding_satoshis), + 
outputs_to_contribute: our_funding_outputs, } ).map_err(|_| ChannelError::Close(( "V2 channel rejected due to sender error".into(), diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index 773fda9d769..49a64879f47 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -30,9 +30,7 @@ use bitcoin::hashes::{Hash, HashEngine, HmacEngine}; use bitcoin::secp256k1::Secp256k1; use bitcoin::secp256k1::{PublicKey, SecretKey}; -use bitcoin::{secp256k1, Sequence}; -#[cfg(splicing)] -use bitcoin::{TxIn, Weight}; +use bitcoin::{secp256k1, Sequence, TxIn, Weight}; use crate::blinded_path::message::MessageForwardNode; use crate::blinded_path::message::{AsyncPaymentsContext, OffersContext}; @@ -122,8 +120,8 @@ use crate::util::errors::APIError; use crate::util::logger::{Level, Logger, WithContext}; use crate::util::scid_utils::fake_scid; use crate::util::ser::{ - BigSize, FixedLengthReader, LengthReadable, MaybeReadable, Readable, ReadableArgs, VecWriter, - Writeable, Writer, + BigSize, FixedLengthReader, LengthReadable, MaybeReadable, Readable, ReadableArgs, + TransactionU16LenLimited, VecWriter, Writeable, Writer, }; use crate::util::string::UntrustedString; use crate::util::wakers::{Future, Notifier}; @@ -5863,6 +5861,76 @@ where result } + /// Handles a signed funding transaction generated by interactive transaction construction and + /// provided by the client. + /// + /// Do NOT broadcast the funding transaction yourself. When we have safely received our + /// counterparty's signature(s) the funding transaction will automatically be broadcast via the + /// [`BroadcasterInterface`] provided when this `ChannelManager` was constructed. + /// + /// SIGHASH_ALL MUST be used for all signatures when providing signatures. + /// + ///
+ /// WARNING: LDK makes no attempt to prevent the counterparty from using non-standard inputs which
+ /// will prevent the funding transaction from being relayed on the bitcoin network and hence being
+ /// confirmed.
+ ///
+ pub fn funding_transaction_signed( + &self, channel_id: &ChannelId, counterparty_node_id: &PublicKey, transaction: Transaction, + ) -> Result<(), APIError> { + let witnesses: Vec<_> = transaction + .input + .into_iter() + .filter_map(|input| if input.witness.is_empty() { None } else { Some(input.witness) }) + .collect(); + + let per_peer_state = self.per_peer_state.read().unwrap(); + let peer_state_mutex = per_peer_state.get(counterparty_node_id).ok_or_else(|| { + APIError::ChannelUnavailable { + err: format!( + "Can't find a peer matching the passed counterparty node_id {}", + counterparty_node_id + ), + } + })?; + + let mut peer_state_lock = peer_state_mutex.lock().unwrap(); + let peer_state = &mut *peer_state_lock; + + match peer_state.channel_by_id.get_mut(channel_id) { + Some(channel) => match channel.as_funded_mut() { + Some(chan) => { + if let Some(tx_signatures) = + chan.funding_transaction_signed(witnesses, &self.logger)? + { + peer_state.pending_msg_events.push(MessageSendEvent::SendTxSignatures { + node_id: *counterparty_node_id, + msg: tx_signatures, + }); + } + }, + None => { + return Err(APIError::APIMisuseError { + err: format!( + "Channel with id {} not expecting funding signatures", + channel_id + ), + }) + }, + }, + None => { + return Err(APIError::ChannelUnavailable { + err: format!( + "Channel with id {} not found for the passed counterparty node_id {}", + channel_id, counterparty_node_id + ), + }) + }, + } + + Ok(()) + } + /// Atomically applies partial updates to the [`ChannelConfig`] of the given channels. /// /// Once the updates are applied, each eligible channel (advertised with a known short channel @@ -8220,6 +8288,8 @@ This indicates a bug inside LDK. Please report this error at https://github.com/ false, user_channel_id, config_overrides, + 0, + vec![], ) } @@ -8251,14 +8321,82 @@ This indicates a bug inside LDK. Please report this error at https://github.com/ true, user_channel_id, config_overrides, + 0, + vec![], + ) + } + + /// Accepts a request to open a dual-funded channel with a contribution provided by us after an + /// [`Event::OpenChannelRequest`]. + /// + /// The [`Event::OpenChannelRequest::channel_negotiation_type`] field will indicate the open channel + /// request is for a dual-funded channel when the variant is `InboundChannelFunds::DualFunded`. + /// + /// The `temporary_channel_id` parameter indicates which inbound channel should be accepted, + /// and the `counterparty_node_id` parameter is the id of the peer which has requested to open + /// the channel. + /// + /// The `user_channel_id` parameter will be provided back in + /// [`Event::ChannelClosed::user_channel_id`] to allow tracking of which events correspond + /// with which `accept_inbound_channel_*` call. + /// + /// The `funding_inputs` parameter provides the `txin`s along with their previous transactions, and + /// a corresponding witness weight for each input that will be used to contribute towards our + /// portion of the channel value. Our contribution will be calculated as the total value of these + /// inputs minus the fees we need to cover for the interactive funding transaction. The witness + /// weights must correspond to the witnesses you will provide through [`ChannelManager::funding_transaction_signed`] + /// after receiving [`Event::FundingTransactionReadyForSigning`]. + /// + /// Note that this method will return an error and reject the channel if it requires support for + /// zero confirmations. 
+ // TODO(dual_funding): Discussion on complications with 0conf dual-funded channels where "locking"
+ // of UTXOs used for funding would be required and other issues.
+ // See https://diyhpl.us/~bryan/irc/bitcoin/bitcoin-dev/linuxfoundation-pipermail/lightning-dev/2023-May/003922.txt
+ ///
+ /// [`Event::OpenChannelRequest`]: events::Event::OpenChannelRequest
+ /// [`Event::OpenChannelRequest::channel_negotiation_type`]: events::Event::OpenChannelRequest::channel_negotiation_type
+ /// [`Event::ChannelClosed::user_channel_id`]: events::Event::ChannelClosed::user_channel_id
+ /// [`Event::FundingTransactionReadyForSigning`]: events::Event::FundingTransactionReadyForSigning
+ /// [`ChannelManager::funding_transaction_signed`]: ChannelManager::funding_transaction_signed
+ pub fn accept_inbound_channel_with_contribution(
+ &self, temporary_channel_id: &ChannelId, counterparty_node_id: &PublicKey,
+ user_channel_id: u128, config_overrides: Option<ChannelConfigOverrides>,
+ our_funding_satoshis: u64, funding_inputs: Vec<(TxIn, Transaction, Weight)>,
+ ) -> Result<(), APIError> {
+ let funding_inputs = Self::length_limit_holder_input_prev_txs(funding_inputs)?;
+ self.do_accept_inbound_channel(
+ temporary_channel_id,
+ counterparty_node_id,
+ false,
+ user_channel_id,
+ config_overrides,
+ our_funding_satoshis,
+ funding_inputs,
 )
 }

+ fn length_limit_holder_input_prev_txs(
+ funding_inputs: Vec<(TxIn, Transaction, Weight)>,
+ ) -> Result<Vec<(TxIn, TransactionU16LenLimited, Weight)>, APIError> {
+ funding_inputs
+ .into_iter()
+ .map(|(txin, tx, witness_weight)| match TransactionU16LenLimited::new(tx) {
+ Ok(tx) => Ok((txin, tx, witness_weight)),
+ Err(err) => Err(err),
+ })
+ .collect::<Result<Vec<_>, ()>>()
+ .map_err(|_| APIError::APIMisuseError {
+ err: "One or more transactions had a serialized length exceeding 65535 bytes"
+ .into(),
+ })
+ }
+
 /// TODO(dual_funding): Allow contributions, pass intended amount and inputs
 #[rustfmt::skip]
 fn do_accept_inbound_channel(
 &self, temporary_channel_id: &ChannelId, counterparty_node_id: &PublicKey, accept_0conf: bool,
- user_channel_id: u128, config_overrides: Option<ChannelConfigOverrides>
+ user_channel_id: u128, config_overrides: Option<ChannelConfigOverrides>, our_funding_satoshis: u64,
+ funding_inputs: Vec<(TxIn, TransactionU16LenLimited, Weight)>
 ) -> Result<(), APIError> {
 let mut config = self.default_configuration.clone();
@@ -8317,7 +8455,7 @@ This indicates a bug inside LDK. Please report this error at https://github.com/
 &self.channel_type_features(), &peer_state.latest_features, &open_channel_msg,
 user_channel_id, &config, best_block_height,
- &self.logger,
+ &self.logger, our_funding_satoshis, funding_inputs,
 ).map_err(|_| MsgHandleErrInternal::from_chan_no_close(
 ChannelError::Close(
 (
@@ -8604,7 +8742,7 @@ This indicates a bug inside LDK. Please report this error at https://github.com/
 &self.fee_estimator, &self.entropy_source, &self.signer_provider, self.get_our_node_id(),
 *counterparty_node_id, &self.channel_type_features(), &peer_state.latest_features, msg, user_channel_id,
- &self.default_configuration, best_block_height, &self.logger,
+ &self.default_configuration, best_block_height, &self.logger, 0, vec![],
 ).map_err(|e| MsgHandleErrInternal::from_chan_no_close(e, msg.common_fields.temporary_channel_id))?;
 let message_send_event = MessageSendEvent::SendAcceptChannelV2 {
 node_id: *counterparty_node_id,
diff --git a/lightning/src/ln/dual_funding_tests.rs b/lightning/src/ln/dual_funding_tests.rs index ed770d06e6d..3682fb39d6a 100644 --- a/lightning/src/ln/dual_funding_tests.rs +++ b/lightning/src/ln/dual_funding_tests.rs @@ -10,28 +10,35 @@ //!
Tests that test the creation of dual-funded channels in ChannelManager. use { - crate::chain::chaininterface::{ConfirmationTarget, LowerBoundedFeeEstimator}, - crate::events::Event, - crate::ln::chan_utils::{ - make_funding_redeemscript, ChannelPublicKeys, ChannelTransactionParameters, - CounterpartyChannelTransactionParameters, + crate::{ + chain::chaininterface::{ConfirmationTarget, LowerBoundedFeeEstimator}, + events::{Event, InboundChannelFunds}, + ln::{ + chan_utils::{ + make_funding_redeemscript, ChannelPublicKeys, ChannelTransactionParameters, + CounterpartyChannelTransactionParameters, + }, + channel::PendingV2Channel, + channel_keys::{DelayedPaymentBasepoint, HtlcBasepoint, RevocationBasepoint}, + functional_test_utils::*, + msgs::{ + BaseMessageHandler, ChannelMessageHandler, CommitmentSigned, MessageSendEvent, + TxAddInput, TxAddOutput, TxComplete, TxSignatures, + }, + types::ChannelId, + }, + prelude::*, + util::{ser::TransactionU16LenLimited, test_utils}, }, - crate::ln::channel::PendingV2Channel, - crate::ln::channel_keys::{DelayedPaymentBasepoint, HtlcBasepoint, RevocationBasepoint}, - crate::ln::functional_test_utils::*, - crate::ln::msgs::{BaseMessageHandler, ChannelMessageHandler, MessageSendEvent}, - crate::ln::msgs::{CommitmentSigned, TxAddInput, TxAddOutput, TxComplete, TxSignatures}, - crate::ln::types::ChannelId, - crate::prelude::*, - crate::util::ser::TransactionU16LenLimited, - crate::util::test_utils, bitcoin::Witness, }; // Dual-funding: V2 Channel Establishment Tests struct V2ChannelEstablishmentTestSession { - funding_input_sats: u64, + initiator_funding_satoshis: u64, initiator_input_value_satoshis: u64, + acceptor_funding_satoshis: u64, + acceptor_input_value_satoshis: u64, } // TODO(dual_funding): Use real node and API for creating V2 channels as initiator when available, @@ -41,28 +48,37 @@ fn do_test_v2_channel_establishment(session: V2ChannelEstablishmentTestSession) let node_cfgs = create_node_cfgs(2, &chanmon_cfgs); let mut node_1_user_config = test_default_channel_config(); node_1_user_config.enable_dual_funded_channels = true; + node_1_user_config.manually_accept_inbound_channels = true; let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, Some(node_1_user_config)]); let nodes = create_network(2, &node_cfgs, &node_chanmgrs); let logger_a = test_utils::TestLogger::with_id("node a".to_owned()); - // Create a funding input for the new channel along with its previous transaction. + // Create initiator funding input for the new channel along with its previous transaction. let initiator_funding_inputs: Vec<_> = create_dual_funding_utxos_with_prev_txs( &nodes[0], &[session.initiator_input_value_satoshis], ) .into_iter() - .map(|(txin, tx, _)| (txin, TransactionU16LenLimited::new(tx).unwrap())) + .map(|(txin, tx, weight)| (txin, TransactionU16LenLimited::new(tx).unwrap(), weight)) .collect(); + // Create acceptor funding input for the new channel along with its previous transaction. + let acceptor_funding_inputs: Vec<_> = if session.acceptor_input_value_satoshis == 0 { + vec![] + } else { + create_dual_funding_utxos_with_prev_txs(&nodes[1], &[session.acceptor_input_value_satoshis]) + }; + let acceptor_funding_inputs_count = acceptor_funding_inputs.len(); + // Alice creates a dual-funded channel as initiator. 
- let funding_satoshis = session.funding_input_sats; + let initiator_funding_satoshis = session.initiator_funding_satoshis; let mut channel = PendingV2Channel::new_outbound( &LowerBoundedFeeEstimator(node_cfgs[0].fee_estimator), &nodes[0].node.entropy_source, &nodes[0].node.signer_provider, nodes[1].node.get_our_node_id(), &nodes[1].node.init_features(), - funding_satoshis, + initiator_funding_satoshis, initiator_funding_inputs.clone(), 42, /* user_channel_id */ nodes[0].node.get_current_default_configuration(), @@ -76,11 +92,35 @@ fn do_test_v2_channel_establishment(session: V2ChannelEstablishmentTestSession) nodes[1].node.handle_open_channel_v2(nodes[0].node.get_our_node_id(), &open_channel_v2_msg); - let accept_channel_v2_msg = get_event_msg!( - nodes[1], - MessageSendEvent::SendAcceptChannelV2, - nodes[0].node.get_our_node_id() - ); + let events = nodes[1].node.get_and_clear_pending_events(); + let accept_channel_v2_msg = match &events[0] { + Event::OpenChannelRequest { + temporary_channel_id, + counterparty_node_id, + channel_negotiation_type, + .. + } => { + assert!(matches!(channel_negotiation_type, &InboundChannelFunds::DualFunded)); + nodes[1] + .node + .accept_inbound_channel_with_contribution( + temporary_channel_id, + counterparty_node_id, + u128::MAX - 2, + None, + session.acceptor_funding_satoshis, + acceptor_funding_inputs.clone(), + ) + .unwrap(); + get_event_msg!( + nodes[1], + MessageSendEvent::SendAcceptChannelV2, + nodes[0].node.get_our_node_id() + ) + }, + _ => panic!("Unexpected event"), + }; + let channel_id = ChannelId::v2_from_revocation_basepoints( &RevocationBasepoint::from(accept_channel_v2_msg.common_fields.revocation_basepoint), &RevocationBasepoint::from(open_channel_v2_msg.common_fields.revocation_basepoint), @@ -89,24 +129,36 @@ fn do_test_v2_channel_establishment(session: V2ChannelEstablishmentTestSession) let tx_add_input_msg = TxAddInput { channel_id, serial_id: 2, // Even serial_id from initiator. 
- prevtx: initiator_funding_inputs[0].1.clone(), + prevtx: Some(initiator_funding_inputs[0].1.clone()), prevtx_out: 0, sequence: initiator_funding_inputs[0].0.sequence.0, shared_input_txid: None, }; - let input_value = - tx_add_input_msg.prevtx.as_transaction().output[tx_add_input_msg.prevtx_out as usize].value; + let input_value = tx_add_input_msg.prevtx.as_ref().unwrap().as_transaction().output + [tx_add_input_msg.prevtx_out as usize] + .value; assert_eq!(input_value.to_sat(), session.initiator_input_value_satoshis); nodes[1].node.handle_tx_add_input(nodes[0].node.get_our_node_id(), &tx_add_input_msg); - let _tx_complete_msg = - get_event_msg!(nodes[1], MessageSendEvent::SendTxComplete, nodes[0].node.get_our_node_id()); + if acceptor_funding_inputs_count > 0 { + let _tx_add_input_msg = get_event_msg!( + nodes[1], + MessageSendEvent::SendTxAddInput, + nodes[0].node.get_our_node_id() + ); + } else { + let _tx_complete_msg = get_event_msg!( + nodes[1], + MessageSendEvent::SendTxComplete, + nodes[0].node.get_our_node_id() + ); + } let tx_add_output_msg = TxAddOutput { channel_id, serial_id: 4, - sats: funding_satoshis, + sats: initiator_funding_satoshis.saturating_add(session.acceptor_funding_satoshis), script: make_funding_redeemscript( &open_channel_v2_msg.common_fields.funding_pubkey, &accept_channel_v2_msg.common_fields.funding_pubkey, @@ -115,16 +167,44 @@ fn do_test_v2_channel_establishment(session: V2ChannelEstablishmentTestSession) }; nodes[1].node.handle_tx_add_output(nodes[0].node.get_our_node_id(), &tx_add_output_msg); - let _tx_complete_msg = - get_event_msg!(nodes[1], MessageSendEvent::SendTxComplete, nodes[0].node.get_our_node_id()); + let acceptor_change_value_satoshis = + session.initiator_input_value_satoshis.saturating_sub(session.initiator_funding_satoshis); + if acceptor_funding_inputs_count > 0 + && acceptor_change_value_satoshis > accept_channel_v2_msg.common_fields.dust_limit_satoshis + { + println!("Change: {acceptor_change_value_satoshis} satoshis"); + let _tx_add_output_msg = get_event_msg!( + nodes[1], + MessageSendEvent::SendTxAddOutput, + nodes[0].node.get_our_node_id() + ); + } else { + let _tx_complete_msg = get_event_msg!( + nodes[1], + MessageSendEvent::SendTxComplete, + nodes[0].node.get_our_node_id() + ); + } let tx_complete_msg = TxComplete { channel_id }; nodes[1].node.handle_tx_complete(nodes[0].node.get_our_node_id(), &tx_complete_msg); let msg_events = nodes[1].node.get_and_clear_pending_msg_events(); - assert_eq!(msg_events.len(), 1); - let _msg_commitment_signed_from_1 = match msg_events[0] { - MessageSendEvent::UpdateHTLCs { ref node_id, channel_id: _, ref updates } => { + let update_htlcs_msg_event = if acceptor_funding_inputs_count > 0 { + assert_eq!(msg_events.len(), 2); + match msg_events[0] { + MessageSendEvent::SendTxComplete { ref node_id, .. 
} => { + assert_eq!(*node_id, nodes[0].node.get_our_node_id()); + }, + _ => panic!("Unexpected event"), + }; + &msg_events[1] + } else { + assert_eq!(msg_events.len(), 1); + &msg_events[0] + }; + let _msg_commitment_signed_from_1 = match update_htlcs_msg_event { + MessageSendEvent::UpdateHTLCs { node_id, channel_id: _, updates } => { assert_eq!(*node_id, nodes[0].node.get_our_node_id()); updates.commitment_signed.clone() }, @@ -171,7 +251,8 @@ fn do_test_v2_channel_establishment(session: V2ChannelEstablishmentTestSession) funding_outpoint, splice_parent_funding_txid: None, channel_type_features, - channel_value_satoshis: funding_satoshis, + channel_value_satoshis: initiator_funding_satoshis + .saturating_add(session.acceptor_funding_satoshis), }; let msg_commitment_signed_from_0 = CommitmentSigned { @@ -201,33 +282,85 @@ fn do_test_v2_channel_establishment(session: V2ChannelEstablishmentTestSession) // The funding transaction should not have been broadcast before persisting initial monitor has // been completed. assert_eq!(nodes[1].tx_broadcaster.txn_broadcast().len(), 0); - assert_eq!(nodes[1].node.get_and_clear_pending_events().len(), 0); + + if acceptor_funding_inputs_count > 0 { + let events = nodes[1].node.get_and_clear_pending_events(); + match &events[0] { + Event::FundingTransactionReadyForSigning { + counterparty_node_id, + unsigned_transaction, + .. + } => { + assert_eq!(counterparty_node_id, &nodes[0].node.get_our_node_id()); + let mut transaction = unsigned_transaction.clone(); + for input in transaction.input.iter_mut() { + if input.previous_output.txid + == acceptor_funding_inputs[0].0.previous_output.txid + { + let mut witness = Witness::new(); + witness.push([0x0]); + input.witness = witness; + } + } + nodes[1] + .node + .funding_transaction_signed(&channel_id, counterparty_node_id, transaction) + .unwrap(); + }, + _ => panic!("Unexpected event"), + }; + } else { + assert_eq!(nodes[1].node.get_and_clear_pending_events().len(), 0); + } // Complete the persistence of the monitor. let events = nodes[1].node.get_and_clear_pending_events(); assert!(events.is_empty()); nodes[1].chain_monitor.complete_sole_pending_chan_update(&channel_id); - let tx_signatures_msg = get_event_msg!( - nodes[1], - MessageSendEvent::SendTxSignatures, - nodes[0].node.get_our_node_id() - ); + if session.acceptor_input_value_satoshis < session.initiator_input_value_satoshis { + let tx_signatures_msg = get_event_msg!( + nodes[1], + MessageSendEvent::SendTxSignatures, + nodes[0].node.get_our_node_id() + ); - assert_eq!(tx_signatures_msg.channel_id, channel_id); - - let mut witness = Witness::new(); - witness.push([0x0]); - // Receive tx_signatures from channel initiator. - nodes[1].node.handle_tx_signatures( - nodes[0].node.get_our_node_id(), - &TxSignatures { - channel_id, - tx_hash: funding_outpoint.unwrap().txid, - witnesses: vec![witness], - shared_input_signature: None, - }, - ); + assert_eq!(tx_signatures_msg.channel_id, channel_id); + + let mut witness = Witness::new(); + witness.push([0x0]); + // Receive tx_signatures from channel initiator. + nodes[1].node.handle_tx_signatures( + nodes[0].node.get_our_node_id(), + &TxSignatures { + channel_id, + tx_hash: funding_outpoint.unwrap().txid, + witnesses: vec![witness], + shared_input_signature: None, + }, + ); + } else { + let mut witness = Witness::new(); + witness.push([0x0]); + // Receive tx_signatures from channel initiator. 
+ nodes[1].node.handle_tx_signatures( + nodes[0].node.get_our_node_id(), + &TxSignatures { + channel_id, + tx_hash: funding_outpoint.unwrap().txid, + witnesses: vec![witness], + shared_input_signature: None, + }, + ); + + let tx_signatures_msg = get_event_msg!( + nodes[1], + MessageSendEvent::SendTxSignatures, + nodes[0].node.get_our_node_id() + ); + + assert_eq!(tx_signatures_msg.channel_id, channel_id); + } let events = nodes[1].node.get_and_clear_pending_events(); assert_eq!(events.len(), 1); @@ -244,8 +377,35 @@ fn do_test_v2_channel_establishment(session: V2ChannelEstablishmentTestSession) #[test] fn test_v2_channel_establishment() { + // Initiator contributes inputs, acceptor does not. + do_test_v2_channel_establishment(V2ChannelEstablishmentTestSession { + initiator_funding_satoshis: 100_00, + initiator_input_value_satoshis: 150_000, + acceptor_funding_satoshis: 0, + acceptor_input_value_satoshis: 0, + }); + // Initiator contributes more input value than acceptor. + do_test_v2_channel_establishment(V2ChannelEstablishmentTestSession { + initiator_funding_satoshis: 100_00, + initiator_input_value_satoshis: 150_000, + acceptor_funding_satoshis: 50_00, + acceptor_input_value_satoshis: 100_000, + }); + // Initiator contributes less input value than acceptor. + do_test_v2_channel_establishment(V2ChannelEstablishmentTestSession { + initiator_funding_satoshis: 100_00, + initiator_input_value_satoshis: 150_000, + acceptor_funding_satoshis: 125_00, + acceptor_input_value_satoshis: 200_000, + }); + // Initiator contributes the same input value as acceptor. + // nodes[0] node_id: 88ce8f35acfc... + // nodes[1] node_id: 236cdaa42692... + // Since nodes[1] has a node_id in earlier lexicographical order, it should send tx_signatures first. do_test_v2_channel_establishment(V2ChannelEstablishmentTestSession { - funding_input_sats: 100_00, + initiator_funding_satoshis: 100_00, initiator_input_value_satoshis: 150_000, + acceptor_funding_satoshis: 125_00, + acceptor_input_value_satoshis: 150_000, }); } diff --git a/lightning/src/ln/interactivetxs.rs b/lightning/src/ln/interactivetxs.rs index ee991a0ae8c..b91b6aca3d7 100644 --- a/lightning/src/ln/interactivetxs.rs +++ b/lightning/src/ln/interactivetxs.rs @@ -106,14 +106,20 @@ pub(crate) enum AbortReason { InsufficientFees, OutputsValueExceedsInputsValue, InvalidTx, + /// No funding (shared) input found. + MissingFundingInput, /// No funding (shared) output found. MissingFundingOutput, /// More than one funding (shared) output found. DuplicateFundingOutput, + /// More than one funding (shared) input found. + DuplicateFundingInput, /// The intended local part of the funding output is higher than the actual shared funding output, /// if funding output is provided by the peer this is an interop error, /// if provided by the same node than internal input consistency error. InvalidLowFundingOutputValue, + /// The intended local part of the funding input is higher than the actual shared funding input. 
+ InvalidLowFundingInputValue,
 /// Internal error
 InternalError(&'static str),
 }
@@ -158,13 +164,18 @@ impl Display for AbortReason {
 f.write_str("Total value of outputs exceeds total value of inputs")
 },
 AbortReason::InvalidTx => f.write_str("The transaction is invalid"),
+ AbortReason::MissingFundingInput => f.write_str("No shared funding input found"),
 AbortReason::MissingFundingOutput => f.write_str("No shared funding output found"),
 AbortReason::DuplicateFundingOutput => {
 f.write_str("More than one funding output found")
 },
+ AbortReason::DuplicateFundingInput => f.write_str("More than one funding input found"),
 AbortReason::InvalidLowFundingOutputValue => f.write_str(
 "Local part of funding output value is greater than the funding output value",
 ),
+ AbortReason::InvalidLowFundingInputValue => f.write_str(
+ "Local part of shared input value is greater than the shared input value",
+ ),
 AbortReason::InternalError(text) => {
 f.write_fmt(format_args!("Internal error: {}", text))
 },
@@ -396,9 +407,14 @@ impl InteractiveTxSigningSession {
 /// unsigned transaction.
 pub fn provide_holder_witnesses(
 &mut self, channel_id: ChannelId, witnesses: Vec<Witness>,
- ) -> Result<(), ()> {
- if self.local_inputs_count() != witnesses.len() {
- return Err(());
+ ) -> Result<Option<TxSignatures>, String> {
+ let local_inputs_count = self.local_inputs_count();
+ if local_inputs_count != witnesses.len() {
+ return Err(format!(
+ "Provided witness count of {} does not match required count for {} inputs",
+ witnesses.len(),
+ local_inputs_count
+ ));
 }

 self.unsigned_tx.add_local_witnesses(witnesses.clone());
@@ -409,7 +425,11 @@ impl InteractiveTxSigningSession {
 shared_input_signature: None,
 });

- Ok(())
+ if self.holder_sends_tx_signatures_first && self.has_received_commitment_signed {
+ Ok(self.holder_tx_signatures.clone())
+ } else {
+ Ok(None)
+ }
 }

 pub fn remote_inputs_count(&self) -> usize {
@@ -466,19 +486,27 @@ struct NegotiationContext {
 received_tx_add_input_count: u16,
 received_tx_add_output_count: u16,
 inputs: HashMap<SerialId, InteractiveTxInput>,
- /// The output script intended to be the new funding output script.
- /// The script pubkey is used to determine which output is the funding output.
- /// When an output with the same script pubkey is added by any of the nodes, it will be
- /// treated as the shared output.
- /// The value is the holder's intended contribution to the shared funding output.
- /// The rest is the counterparty's contribution.
- /// When the funding output is added (recognized by its output script pubkey), it will be marked
- /// as shared, and split between the peers according to the local value.
- /// If the local value is found to be larger than the actual funding output, an error is generated.
- expected_shared_funding_output: (ScriptBuf, u64),
- /// The actual new funding output, set only after the output has actually been added.
- /// NOTE: this output is also included in `outputs`.
- actual_new_funding_output: Option<SharedOwnedOutput>,
+ /// Optional intended/expected funding input, used during splicing.
+ /// The funding input is shared, it is usually co-owned by both peers.
+ /// - For the initiator:
+ /// The intended previous funding input. This will be added alongside the
+ /// provided inputs.
+ /// The values are the output value and the holder's part of the shared input.
+ /// - For the acceptor:
+ /// The expected previous funding input. It should be added by the initiator node.
+ /// The values are the output value and the holder's part of the shared input.
+ shared_funding_input: Option<(OutPoint, u64, u64)>,
+ /// The intended/expected funding output, potentially co-owned by both peers (shared).
+ /// - For the initiator:
+ /// The output intended to be the new funding output. This will be added alongside the
+ /// provided outputs.
+ /// The value is the holder's intended contribution to the shared funding output
+ /// (must be less than or equal to the amount of the output).
+ /// - For the acceptor:
+ /// The output expected as new funding output. It should be added by the initiator node.
+ /// The value is the holder's intended contribution to the shared funding output
+ /// (must be less than or equal to the amount of the output).
+ shared_funding_output: (TxOut, u64),
 prevtx_outpoints: HashSet<OutPoint>,
 /// The outputs added so far.
 outputs: HashMap<SerialId, InteractiveTxOutput>,
@@ -500,6 +528,12 @@ pub(crate) fn estimate_input_weight(prev_output: &TxOut) -> Weight {
 })
 }

+pub(crate) fn get_input_weight(witness_weight: Weight) -> Weight {
+ Weight::from_wu(
+ (BASE_INPUT_WEIGHT + EMPTY_SCRIPT_SIG_WEIGHT).saturating_add(witness_weight.to_wu()),
+ )
+}
+
 pub(crate) fn get_output_weight(script_pubkey: &ScriptBuf) -> Weight {
 Weight::from_wu(
 (8 /* value */ + script_pubkey.consensus_encode(&mut sink()).unwrap() as u64)
@@ -515,8 +549,8 @@ fn is_serial_id_valid_for_counterparty(holder_is_initiator: bool, serial_id: Ser
 impl NegotiationContext {
 fn new(
 holder_node_id: PublicKey, counterparty_node_id: PublicKey, holder_is_initiator: bool,
- expected_shared_funding_output: (ScriptBuf, u64), tx_locktime: AbsoluteLockTime,
- feerate_sat_per_kw: u32,
+ shared_funding_input: Option<(OutPoint, u64, u64)>, shared_funding_output: (TxOut, u64),
+ tx_locktime: AbsoluteLockTime, feerate_sat_per_kw: u32,
 ) -> Self {
 NegotiationContext {
 holder_node_id,
@@ -525,8 +559,8 @@ impl NegotiationContext {
 received_tx_add_input_count: 0,
 received_tx_add_output_count: 0,
 inputs: new_hash_map(),
- expected_shared_funding_output,
- actual_new_funding_output: None,
+ shared_funding_input,
+ shared_funding_output,
 prevtx_outpoints: new_hash_set(),
 outputs: new_hash_map(),
 tx_locktime,
@@ -534,23 +568,6 @@ impl NegotiationContext {
 }
 }

- fn set_actual_new_funding_output(
- &mut self, tx_out: TxOut,
- ) -> Result<SharedOwnedOutput, AbortReason> {
- if self.actual_new_funding_output.is_some() {
- return Err(AbortReason::DuplicateFundingOutput);
- }
- let value = tx_out.value.to_sat();
- let local_owned = self.expected_shared_funding_output.1;
- // Sanity check
- if local_owned > value {
- return Err(AbortReason::InvalidLowFundingOutputValue);
- }
- let shared_output = SharedOwnedOutput::new(tx_out, local_owned);
- self.actual_new_funding_output = Some(shared_output.clone());
- Ok(shared_output)
- }
-
 fn is_serial_id_valid_for_counterparty(&self, serial_id: &SerialId) -> bool {
 is_serial_id_valid_for_counterparty(self.holder_is_initiator, *serial_id)
 }
@@ -619,36 +636,79 @@ impl NegotiationContext {
 return Err(AbortReason::IncorrectInputSequenceValue);
 }

- let transaction = msg.prevtx.as_transaction();
- let txid = transaction.compute_txid();
-
- if let Some(tx_out) = transaction.output.get(msg.prevtx_out as usize) {
- if !tx_out.script_pubkey.is_witness_program() {
- // The receiving node:
- // - MUST fail the negotiation if:
- // - the `scriptPubKey` is not a witness program
- return Err(AbortReason::PrevTxOutInvalid);
+ // Extract info from msg, check if shared
+ let (input, prev_outpoint) = if let Some(shared_txid) = &msg.shared_input_txid {
+ // This is a shared input
+ if self.holder_is_initiator {
+ return
Err(AbortReason::DuplicateFundingInput); } - - if !self.prevtx_outpoints.insert(OutPoint { txid, vout: msg.prevtx_out }) { - // The receiving node: - // - MUST fail the negotiation if: - // - the `prevtx` and `prevtx_vout` are identical to a previously added - // (and not removed) input's - return Err(AbortReason::PrevTxOutInvalid); + if let Some(shared_funding_input) = &self.shared_funding_input { + // There can only be one shared input. + if self.inputs.values().any(|input| matches!(input.input, InputOwned::Shared(_))) { + return Err(AbortReason::DuplicateFundingInput); + } + // Check if the received shared input matches the expected one + if shared_funding_input.0.txid != *shared_txid { + // Shared input TXID differs from expected + return Err(AbortReason::MissingFundingInput); + } else { + let previous_output = OutPoint { txid: *shared_txid, vout: msg.prevtx_out }; + let txin = TxIn { + previous_output, + sequence: Sequence(msg.sequence), + ..Default::default() + }; + let prev_output = TxOut { + value: Amount::from_sat(shared_funding_input.1), + script_pubkey: txin.script_sig.to_p2wsh(), + }; + let local_owned_sats = shared_funding_input.2; + let shared_input = SharedOwnedInput::new(txin, prev_output, local_owned_sats); + (InputOwned::Shared(shared_input), previous_output) + } + } else { + // Unexpected shared input received + return Err(AbortReason::MissingFundingInput); + } } else { - // The receiving node: - // - MUST fail the negotiation if: - // - `prevtx_vout` is greater or equal to the number of outputs on `prevtx` - return Err(AbortReason::PrevTxOutInvalid); - } + // Non-shared input + if let Some(prevtx) = &msg.prevtx { + let transaction = prevtx.as_transaction(); + let txid = transaction.compute_txid(); + + if let Some(tx_out) = transaction.output.get(msg.prevtx_out as usize) { + if !tx_out.script_pubkey.is_witness_program() { + // The receiving node: + // - MUST fail the negotiation if: + // - the `scriptPubKey` is not a witness program + return Err(AbortReason::PrevTxOutInvalid); + } - let prev_out = if let Some(prev_out) = transaction.output.get(msg.prevtx_out as usize) { - prev_out.clone() - } else { - return Err(AbortReason::PrevTxOutInvalid); + let prev_outpoint = OutPoint { txid, vout: msg.prevtx_out }; + let txin = TxIn { + previous_output: prev_outpoint, + sequence: Sequence(msg.sequence), + ..Default::default() + }; + ( + InputOwned::Single(SingleOwnedInput { + input: txin, + prev_tx: prevtx.clone(), + prev_output: tx_out.clone(), + }), + prev_outpoint, + ) + } else { + // The receiving node: + // - MUST fail the negotiation if: + // - `prevtx_vout` is greater or equal to the number of outputs on `prevtx` + return Err(AbortReason::PrevTxOutInvalid); + } + } else { + return Err(AbortReason::MissingFundingInput); + } }; + match self.inputs.entry(msg.serial_id) { hash_map::Entry::Occupied(_) => { // The receiving node: @@ -657,17 +717,20 @@ impl NegotiationContext { Err(AbortReason::DuplicateSerialId) }, hash_map::Entry::Vacant(entry) => { - let prev_outpoint = OutPoint { txid, vout: msg.prevtx_out }; - entry.insert(InteractiveTxInput::Remote(LocalOrRemoteInput { + entry.insert(InteractiveTxInput { serial_id: msg.serial_id, - input: TxIn { - previous_output: prev_outpoint, - sequence: Sequence(msg.sequence), - ..Default::default() - }, - prev_output: prev_out, - })); + added_by: AddingRole::Remote, + input, + }); + if !self.prevtx_outpoints.insert(prev_outpoint) { + // The receiving node: + // - MUST fail the negotiation if: + // - the `prevtx` and `prevtx_vout` are identical to
a previously added + // (and not removed) input's + return Err(AbortReason::PrevTxOutInvalid); + } self.prevtx_outpoints.insert(prev_outpoint); + Ok(()) }, } @@ -745,22 +808,21 @@ impl NegotiationContext { } let txout = TxOut { value: Amount::from_sat(msg.sats), script_pubkey: msg.script.clone() }; - let is_shared = msg.script == self.expected_shared_funding_output.0; - let output = if is_shared { - // this is a shared funding output - let shared_output = self.set_actual_new_funding_output(txout)?; - InteractiveTxOutput { - serial_id: msg.serial_id, - added_by: AddingRole::Remote, - output: OutputOwned::Shared(shared_output), + let output = if txout == self.shared_funding_output.0 { + // This is a shared output + if self.holder_is_initiator { + return Err(AbortReason::DuplicateFundingOutput); } - } else { - InteractiveTxOutput { - serial_id: msg.serial_id, - added_by: AddingRole::Remote, - output: OutputOwned::Single(txout), + // There can only be one shared output. + if self.outputs.values().any(|output| matches!(output.output, OutputOwned::Shared(_))) { + return Err(AbortReason::DuplicateFundingOutput); } + OutputOwned::Shared(SharedOwnedOutput::new(txout, self.shared_funding_output.1)) + } else { + OutputOwned::Single(txout) }; + let output = + InteractiveTxOutput { serial_id: msg.serial_id, added_by: AddingRole::Remote, output }; match self.outputs.entry(msg.serial_id) { hash_map::Entry::Occupied(_) => { // The receiving node: @@ -791,45 +853,76 @@ impl NegotiationContext { } fn sent_tx_add_input(&mut self, msg: &msgs::TxAddInput) -> Result<(), AbortReason> { - let tx = msg.prevtx.as_transaction(); - let txin = TxIn { - previous_output: OutPoint { txid: tx.compute_txid(), vout: msg.prevtx_out }, - sequence: Sequence(msg.sequence), - ..Default::default() + let vout = msg.prevtx_out as usize; + let (prev_outpoint, input) = if let Some(shared_input_txid) = msg.shared_input_txid { + // This is the shared input + let prev_outpoint = OutPoint { txid: shared_input_txid, vout: msg.prevtx_out }; + let txin = TxIn { + previous_output: prev_outpoint, + sequence: Sequence(msg.sequence), + ..Default::default() + }; + if let Some(shared_funding_input) = &self.shared_funding_input { + let value = shared_funding_input.1; + let local_owned = shared_funding_input.2; + // Sanity check + if local_owned > value { + return Err(AbortReason::InvalidLowFundingInputValue); + } + let prev_output = TxOut { + value: Amount::from_sat(value), + script_pubkey: txin.script_sig.to_p2wsh(), + }; + ( + prev_outpoint, + InputOwned::Shared(SharedOwnedInput::new(txin, prev_output, local_owned)), + ) + } else { + return Err(AbortReason::MissingFundingInput); + } + } else { + // Non-shared input + if let Some(prevtx) = &msg.prevtx { + let prev_txid = prevtx.as_transaction().compute_txid(); + let prev_outpoint = OutPoint { txid: prev_txid, vout: msg.prevtx_out }; + let prev_output = prevtx + .as_transaction() + .output + .get(vout) + .ok_or(AbortReason::PrevTxOutInvalid)? 
+ .clone(); + let txin = TxIn { + previous_output: prev_outpoint, + sequence: Sequence(msg.sequence), + ..Default::default() + }; + let single_input = + SingleOwnedInput { input: txin, prev_tx: prevtx.clone(), prev_output }; + (prev_outpoint, InputOwned::Single(single_input)) + } else { + return Err(AbortReason::PrevTxOutInvalid); + } }; - if !self.prevtx_outpoints.insert(txin.previous_output) { + if !self.prevtx_outpoints.insert(prev_outpoint) { // We have added an input that already exists return Err(AbortReason::PrevTxOutInvalid); } - let vout = txin.previous_output.vout as usize; - let prev_output = tx.output.get(vout).ok_or(AbortReason::PrevTxOutInvalid)?.clone(); - let input = InteractiveTxInput::Local(LocalOrRemoteInput { - serial_id: msg.serial_id, - input: txin, - prev_output, - }); + let input = + InteractiveTxInput { serial_id: msg.serial_id, added_by: AddingRole::Local, input }; self.inputs.insert(msg.serial_id, input); Ok(()) } fn sent_tx_add_output(&mut self, msg: &msgs::TxAddOutput) -> Result<(), AbortReason> { let txout = TxOut { value: Amount::from_sat(msg.sats), script_pubkey: msg.script.clone() }; - let is_shared = msg.script == self.expected_shared_funding_output.0; - let output = if is_shared { - // this is a shared funding output - let shared_output = self.set_actual_new_funding_output(txout)?; - InteractiveTxOutput { - serial_id: msg.serial_id, - added_by: AddingRole::Local, - output: OutputOwned::Shared(shared_output), - } + let output = if txout == self.shared_funding_output.0 { + // this is the shared output + OutputOwned::Shared(SharedOwnedOutput::new(txout, self.shared_funding_output.1)) } else { - InteractiveTxOutput { - serial_id: msg.serial_id, - added_by: AddingRole::Local, - output: OutputOwned::Single(txout), - } + OutputOwned::Single(txout) }; + let output = + InteractiveTxOutput { serial_id: msg.serial_id, added_by: AddingRole::Local, output }; self.outputs.insert(msg.serial_id, output); Ok(()) } @@ -886,15 +979,25 @@ impl NegotiationContext { return Err(AbortReason::ExceededNumberOfInputsOrOutputs); } - if self.actual_new_funding_output.is_none() { - return Err(AbortReason::MissingFundingOutput); - } - // - the peer's paid feerate does not meet or exceed the agreed feerate (based on the minimum fee). 
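To illustrate the accounting behind the shared splice input handled above, here is a minimal standalone sketch of how the previous funding output's value splits into a local and a remote share; the function and variable names are simplified stand-ins for illustration and are not items from this patch:

```rust
// The holder claims `local_owned_sats` of the old funding output being spent;
// the counterparty implicitly owns the remainder.
fn split_shared_input_value(
    prev_output_sats: u64, local_owned_sats: u64,
) -> Result<(u64, u64), &'static str> {
    if local_owned_sats > prev_output_sats {
        // Corresponds to the `InvalidLowFundingInputValue` sanity check above.
        return Err("local part of shared input value exceeds the shared input value");
    }
    let remote_owned_sats = prev_output_sats.saturating_sub(local_owned_sats);
    Ok((local_owned_sats, remote_owned_sats))
}

fn main() {
    // Splicing a 100_000 sat channel of which the holder owns 40_000 sats.
    assert_eq!(split_shared_input_value(100_000, 40_000), Ok((40_000, 60_000)));
}
```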
self.check_counterparty_fees(remote_inputs_value.saturating_sub(remote_outputs_value))?; + let shared_funding_output = self.shared_funding_output.clone(); + let opt_shared_funding_input = self.shared_funding_input.clone(); let constructed_tx = ConstructedTransaction::new(self); - + if let Some(shared_funding_input) = &opt_shared_funding_input { + if !constructed_tx + .inputs + .iter() + .any(|input| input.txin().previous_output == shared_funding_input.0) + { + return Err(AbortReason::MissingFundingInput); + } + } + if !constructed_tx.outputs.iter().any(|output| *output.tx_out() == shared_funding_output.0) + { + return Err(AbortReason::MissingFundingOutput); + } if constructed_tx.weight().to_wu() > MAX_STANDARD_TX_WEIGHT as u64 { return Err(AbortReason::TransactionTooLarge); } @@ -1107,13 +1210,14 @@ impl StateMachine { fn new( holder_node_id: PublicKey, counterparty_node_id: PublicKey, feerate_sat_per_kw: u32, is_initiator: bool, tx_locktime: AbsoluteLockTime, - expected_shared_funding_output: (ScriptBuf, u64), + shared_funding_input: Option<(OutPoint, u64, u64)>, shared_funding_output: (TxOut, u64), ) -> Self { let context = NegotiationContext::new( holder_node_id, counterparty_node_id, is_initiator, - expected_shared_funding_output, + shared_funding_input, + shared_funding_output, tx_locktime, feerate_sat_per_kw, ); @@ -1188,32 +1292,136 @@ impl_writeable_tlv_based_enum!(AddingRole, /// Represents an input -- local or remote (both have the same fields) #[derive(Clone, Debug, Eq, PartialEq)] -pub struct LocalOrRemoteInput { - serial_id: SerialId, +struct SingleOwnedInput { input: TxIn, + prev_tx: TransactionU16LenLimited, prev_output: TxOut, } -impl_writeable_tlv_based!(LocalOrRemoteInput, { - (1, serial_id, required), - (3, input, required), +impl_writeable_tlv_based!(SingleOwnedInput, { + (1, input, required), + (3, prev_tx, required), (5, prev_output, required), }); #[derive(Clone, Debug, Eq, PartialEq)] -pub(crate) enum InteractiveTxInput { - Local(LocalOrRemoteInput), - Remote(LocalOrRemoteInput), - // TODO(splicing) SharedInput should be added +struct SharedOwnedInput { + input: TxIn, + prev_output: TxOut, + local_owned: u64, } -impl_writeable_tlv_based_enum!(InteractiveTxInput, - {1, Local} => (), - {3, Remote} => (), +impl_writeable_tlv_based!(SharedOwnedInput, { + (1, input, required), + (3, prev_output, required), + (5, local_owned, required), +}); + +impl SharedOwnedInput { + pub fn new(input: TxIn, prev_output: TxOut, local_owned: u64) -> Self { + debug_assert!( + local_owned <= prev_output.value.to_sat(), + "SharedOwnedInput: Inconsistent local_owned value {}, larger than prev out value {}", + local_owned, + prev_output.value.to_sat(), + ); + Self { input, prev_output, local_owned } + } + + fn remote_owned(&self) -> u64 { + self.prev_output.value.to_sat().saturating_sub(self.local_owned) + } +} + +/// A transaction input, differentiated by ownership: +/// - exclusive by the adder, or +/// - shared +#[derive(Clone, Debug, Eq, PartialEq)] +enum InputOwned { + /// Belongs to a single party -- controlled exclusively and fully belonging to a single party + /// Includes the input and the previous output + Single(SingleOwnedInput), + // Input with shared control and value split between the two ends (or fully at one side) + Shared(SharedOwnedInput), +} + +impl_writeable_tlv_based_enum!(InputOwned, + {1, Single} => (), + {3, Shared} => (), ); +impl InputOwned { + pub fn tx_in(&self) -> &TxIn { + match &self { + InputOwned::Single(single) => &single.input, + InputOwned::Shared(shared) 
=> &shared.input, + } + } + + pub fn tx_in_mut(&mut self) -> &mut TxIn { + match self { + InputOwned::Single(ref mut single) => &mut single.input, + InputOwned::Shared(shared) => &mut shared.input, + } + } + + pub fn into_tx_in(self) -> TxIn { + match self { + InputOwned::Single(single) => single.input, + InputOwned::Shared(shared) => shared.input, + } + } + + pub fn prev_output(&self) -> &TxOut { + match self { + InputOwned::Single(single) => &single.prev_output, + InputOwned::Shared(shared) => &shared.prev_output, + } + } + + fn is_shared(&self) -> bool { + match self { + InputOwned::Single(_) => false, + InputOwned::Shared(_) => true, + } + } + + fn local_value(&self, local_role: AddingRole) -> u64 { + match self { + InputOwned::Single(single) => match local_role { + AddingRole::Local => single.prev_output.value.to_sat(), + AddingRole::Remote => 0, + }, + InputOwned::Shared(shared) => shared.local_owned, + } + } + + fn remote_value(&self, local_role: AddingRole) -> u64 { + match self { + InputOwned::Single(single) => match local_role { + AddingRole::Local => 0, + AddingRole::Remote => single.prev_output.value.to_sat(), + }, + InputOwned::Shared(shared) => shared.remote_owned(), + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub(crate) struct InteractiveTxInput { + serial_id: SerialId, + added_by: AddingRole, + input: InputOwned, +} + +impl_writeable_tlv_based!(InteractiveTxInput, { + (1, serial_id, required), + (3, added_by, required), + (5, input, required), +}); + #[derive(Clone, Debug, Eq, PartialEq)] -pub(super) struct SharedOwnedOutput { +struct SharedOwnedOutput { tx_out: TxOut, local_owned: u64, } @@ -1224,14 +1432,14 @@ impl_writeable_tlv_based!(SharedOwnedOutput, { }); impl SharedOwnedOutput { - pub fn new(tx_out: TxOut, local_owned: u64) -> SharedOwnedOutput { + pub fn new(tx_out: TxOut, local_owned: u64) -> Self { debug_assert!( local_owned <= tx_out.value.to_sat(), "SharedOwnedOutput: Inconsistent local_owned value {}, larger than output value {}", local_owned, - tx_out.value + tx_out.value.to_sat(), ); - SharedOwnedOutput { tx_out, local_owned } + Self { tx_out, local_owned } } fn remote_owned(&self) -> u64 { @@ -1239,11 +1447,11 @@ impl SharedOwnedOutput { } } -/// Represents an output, with information about -/// its control -- exclusive by the adder or shared --, and -/// its ownership -- value fully owned by the adder or jointly +/// A transaction output, differentiated by ownership: +/// - exclusive by the adder, or +/// - shared #[derive(Clone, Debug, Eq, PartialEq)] -pub(super) enum OutputOwned { +enum OutputOwned { /// Belongs to a single party -- controlled exclusively and fully belonging to a single party Single(TxOut), /// Output with shared control and value split between the two ends (or fully at one side) @@ -1343,38 +1551,23 @@ impl InteractiveTxOutput { impl InteractiveTxInput { pub fn serial_id(&self) -> SerialId { - match self { - InteractiveTxInput::Local(input) => input.serial_id, - InteractiveTxInput::Remote(input) => input.serial_id, - } + self.serial_id } pub fn txin(&self) -> &TxIn { - match self { - InteractiveTxInput::Local(input) => &input.input, - InteractiveTxInput::Remote(input) => &input.input, - } + self.input.tx_in() } pub fn txin_mut(&mut self) -> &mut TxIn { - match self { - InteractiveTxInput::Local(input) => &mut input.input, - InteractiveTxInput::Remote(input) => &mut input.input, - } + self.input.tx_in_mut() } pub fn into_txin(self) -> TxIn { - match self { - InteractiveTxInput::Local(input) => input.input, - 
InteractiveTxInput::Remote(input) => input.input, - } + self.input.into_tx_in() } pub fn prev_output(&self) -> &TxOut { - match self { - InteractiveTxInput::Local(input) => &input.prev_output, - InteractiveTxInput::Remote(input) => &input.prev_output, - } + self.input.prev_output() } pub fn value(&self) -> u64 { @@ -1382,17 +1575,11 @@ impl InteractiveTxInput { } pub fn local_value(&self) -> u64 { - match self { - InteractiveTxInput::Local(input) => input.prev_output.value.to_sat(), - InteractiveTxInput::Remote(_input) => 0, - } + self.input.local_value(self.added_by) } pub fn remote_value(&self) -> u64 { - match self { - InteractiveTxInput::Local(_input) => 0, - InteractiveTxInput::Remote(input) => input.prev_output.value.to_sat(), - } + self.input.remote_value(self.added_by) } } @@ -1400,7 +1587,7 @@ pub(super) struct InteractiveTxConstructor { state_machine: StateMachine, initiator_first_message: Option, channel_id: ChannelId, - inputs_to_contribute: Vec<(SerialId, TxIn, TransactionU16LenLimited)>, + inputs_to_contribute: Vec<(SerialId, InputOwned)>, outputs_to_contribute: Vec<(SerialId, OutputOwned)>, } @@ -1526,22 +1713,15 @@ where pub feerate_sat_per_kw: u32, pub is_initiator: bool, pub funding_tx_locktime: AbsoluteLockTime, - pub inputs_to_contribute: Vec<(TxIn, TransactionU16LenLimited)>, - pub outputs_to_contribute: Vec, - pub expected_remote_shared_funding_output: Option<(ScriptBuf, u64)>, + pub inputs_to_contribute: Vec<(TxIn, TransactionU16LenLimited, Weight)>, + pub shared_funding_input: Option<(OutPoint, u64, u64)>, + pub shared_funding_output: (TxOut, u64), + pub outputs_to_contribute: Vec, } impl InteractiveTxConstructor { /// Instantiates a new `InteractiveTxConstructor`. /// - /// `expected_remote_shared_funding_output`: In the case when the local node doesn't - /// add a shared output, but it expects a shared output to be added by the remote node, - /// it has to specify the script pubkey, used to determine the shared output, - /// and its (local) contribution from the shared output: - /// 0 when the whole value belongs to the remote node, or - /// positive if owned also by local. - /// Note: The local value cannot be larger than the actual shared output. - /// /// If the holder is the initiator, they need to send the first message which is a `TxAddInput` /// message. 
pub fn new(args: InteractiveTxConstructorArgs) -> Result @@ -1557,81 +1737,99 @@ impl InteractiveTxConstructor { is_initiator, funding_tx_locktime, inputs_to_contribute, + shared_funding_input, + shared_funding_output, outputs_to_contribute, - expected_remote_shared_funding_output, } = args; - // Sanity check: There can be at most one shared output, local-added or remote-added - let mut expected_shared_funding_output: Option<(ScriptBuf, u64)> = None; - for output in &outputs_to_contribute { - let new_output = match output { - OutputOwned::Single(_tx_out) => None, - OutputOwned::Shared(output) => { - // Sanity check - if output.local_owned > output.tx_out.value.to_sat() { - return Err(AbortReason::InvalidLowFundingOutputValue); - } - Some((output.tx_out.script_pubkey.clone(), output.local_owned)) - }, - }; - if new_output.is_some() { - if expected_shared_funding_output.is_some() - || expected_remote_shared_funding_output.is_some() - { - // more than one local-added shared output or - // one local-added and one remote-expected shared output - return Err(AbortReason::DuplicateFundingOutput); - } - expected_shared_funding_output = new_output; + + let state_machine = StateMachine::new( + holder_node_id, + counterparty_node_id, + feerate_sat_per_kw, + is_initiator, + funding_tx_locktime, + shared_funding_input.clone(), + shared_funding_output.clone(), + ); + + // Check for the existence of prevouts' + for (txin, tx, _) in inputs_to_contribute.iter() { + let vout = txin.previous_output.vout as usize; + if tx.as_transaction().output.get(vout).is_none() { + return Err(AbortReason::PrevTxOutInvalid); } } - if let Some(expected_remote_shared_funding_output) = expected_remote_shared_funding_output { - expected_shared_funding_output = Some(expected_remote_shared_funding_output); - } - if let Some(expected_shared_funding_output) = expected_shared_funding_output { - let state_machine = StateMachine::new( - holder_node_id, - counterparty_node_id, - feerate_sat_per_kw, - is_initiator, - funding_tx_locktime, - expected_shared_funding_output, - ); - let mut inputs_to_contribute: Vec<(SerialId, TxIn, TransactionU16LenLimited)> = - inputs_to_contribute - .into_iter() - .map(|(input, tx)| { - let serial_id = generate_holder_serial_id(entropy_source, is_initiator); - (serial_id, input, tx) - }) - .collect(); - // We'll sort by the randomly generated serial IDs, effectively shuffling the order of the inputs - // as the user passed them to us to avoid leaking any potential categorization of transactions - // before we pass any of the inputs to the counterparty. - inputs_to_contribute.sort_unstable_by_key(|(serial_id, _, _)| *serial_id); - let mut outputs_to_contribute: Vec<_> = outputs_to_contribute - .into_iter() - .map(|output| { - let serial_id = generate_holder_serial_id(entropy_source, is_initiator); - (serial_id, output) - }) - .collect(); - // In the same manner and for the same rationale as the inputs above, we'll shuffle the outputs. - outputs_to_contribute.sort_unstable_by_key(|(serial_id, _)| *serial_id); - let mut constructor = Self { - state_machine, - initiator_first_message: None, - channel_id, - inputs_to_contribute, - outputs_to_contribute, - }; - // We'll store the first message for the initiator. 
+ let mut inputs_to_contribute: Vec<(SerialId, InputOwned)> = inputs_to_contribute + .into_iter() + .map(|(txin, tx, _)| { + let serial_id = generate_holder_serial_id(entropy_source, is_initiator); + let vout = txin.previous_output.vout as usize; + let prev_output = tx.as_transaction().output.get(vout).unwrap().clone(); // checked above + let input = + InputOwned::Single(SingleOwnedInput { input: txin, prev_tx: tx, prev_output }); + (serial_id, input) + }) + .collect(); + if let Some(shared_funding_input) = &shared_funding_input { if is_initiator { - constructor.initiator_first_message = Some(constructor.maybe_send_message()?); + // Add shared funding input + let serial_id = generate_holder_serial_id(entropy_source, is_initiator); + let value = shared_funding_input.1; + let local_owned = shared_funding_input.2; + // Sanity check + if local_owned > value { + return Err(AbortReason::InvalidLowFundingInputValue); + } + let txin = TxIn { + previous_output: shared_funding_input.0, + sequence: Sequence::ENABLE_RBF_NO_LOCKTIME, + ..Default::default() + }; + let prev_out = TxOut { + value: Amount::from_sat(value), + script_pubkey: txin.script_sig.to_p2wsh(), + }; + let input = SharedOwnedInput::new(txin, prev_out, local_owned); + inputs_to_contribute.push((serial_id, InputOwned::Shared(input))); } - Ok(constructor) - } else { - Err(AbortReason::MissingFundingOutput) } + // We'll sort by the randomly generated serial IDs, effectively shuffling the order of the inputs + // as the user passed them to us to avoid leaking any potential categorization of transactions + // before we pass any of the inputs to the counterparty. + inputs_to_contribute.sort_unstable_by_key(|(serial_id, _)| *serial_id); + + let mut outputs_to_contribute: Vec<_> = outputs_to_contribute + .into_iter() + .map(|output| { + let serial_id = generate_holder_serial_id(entropy_source, is_initiator); + let output = OutputOwned::Single(output); + (serial_id, output) + }) + .collect(); + if is_initiator { + // Add shared funding output + let serial_id = generate_holder_serial_id(entropy_source, is_initiator); + let output = OutputOwned::Shared(SharedOwnedOutput::new( + shared_funding_output.0, + shared_funding_output.1, + )); + outputs_to_contribute.push((serial_id, output)); + } + // In the same manner and for the same rationale as the inputs above, we'll shuffle the outputs. + outputs_to_contribute.sort_unstable_by_key(|(serial_id, _)| *serial_id); + + let mut constructor = Self { + state_machine, + initiator_first_message: None, + channel_id, + inputs_to_contribute, + outputs_to_contribute, + }; + // We'll store the first message for the initiator. + if is_initiator { + constructor.initiator_first_message = Some(constructor.maybe_send_message()?); + } + Ok(constructor) } pub fn take_initiator_first_message(&mut self) -> Option { @@ -1641,14 +1839,24 @@ impl InteractiveTxConstructor { fn maybe_send_message(&mut self) -> Result { // We first attempt to send inputs we want to add, then outputs. Once we are done sending // them both, then we always send tx_complete. 
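As a side note on the serial-id shuffling performed above, the following is a rough, self-contained sketch of how tagging contributions with random serial ids randomizes their announcement order; the helper name and the toy entropy closure are illustrative assumptions, and only the even/odd parity rule follows the interactive-tx convention:

```rust
fn tag_and_shuffle<T>(
    items: Vec<T>, is_initiator: bool, mut next_random_u64: impl FnMut() -> u64,
) -> Vec<(u64, T)> {
    let mut tagged: Vec<(u64, T)> = items
        .into_iter()
        .map(|item| {
            let mut serial_id = next_random_u64();
            // Interactive-tx convention: even serial ids for the initiator, odd otherwise.
            if is_initiator {
                serial_id &= !1;
            } else {
                serial_id |= 1;
            }
            (serial_id, item)
        })
        .collect();
    // Sorting by the random ids shuffles the announcement order, so the order in which
    // the caller originally passed its contributions is not leaked to the counterparty.
    tagged.sort_unstable_by_key(|(serial_id, _)| *serial_id);
    tagged
}

fn main() {
    let mut state = 7u64;
    let entropy = move || {
        // A toy PRNG standing in for a real entropy source; illustration only.
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        state
    };
    let shuffled = tag_and_shuffle(vec!["input_a", "input_b", "input_c"], true, entropy);
    assert!(shuffled.iter().all(|(serial_id, _)| serial_id % 2 == 0));
}
```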
- if let Some((serial_id, input, prevtx)) = self.inputs_to_contribute.pop() { - let msg = msgs::TxAddInput { - channel_id: self.channel_id, - serial_id, - prevtx, - prevtx_out: input.previous_output.vout, - sequence: input.sequence.to_consensus_u32(), - shared_input_txid: None, + if let Some((serial_id, input)) = self.inputs_to_contribute.pop() { + let msg = match input { + InputOwned::Single(single) => msgs::TxAddInput { + channel_id: self.channel_id, + serial_id, + prevtx: Some(single.prev_tx), + prevtx_out: single.input.previous_output.vout, + sequence: single.input.sequence.to_consensus_u32(), + shared_input_txid: None, + }, + InputOwned::Shared(shared) => msgs::TxAddInput { + channel_id: self.channel_id, + serial_id, + prevtx: None, + prevtx_out: shared.input.previous_output.vout, + sequence: shared.input.sequence.to_consensus_u32(), + shared_input_txid: Some(shared.input.previous_output.txid), + }, }; do_state_transition!(self, sent_tx_add_input, &msg)?; Ok(InteractiveTxMessageSend::TxAddInput(msg)) @@ -1740,17 +1948,17 @@ impl InteractiveTxConstructor { /// `Ok(None)` /// - Inputs are not sufficent to cover contribution and fees: /// `Err(AbortReason::InsufficientFees)` -#[allow(dead_code)] // TODO(dual_funding): Remove once begin_interactive_funding_tx_construction() is used pub(super) fn calculate_change_output_value( is_initiator: bool, our_contribution: u64, - funding_inputs: &Vec<(TxIn, TransactionU16LenLimited)>, funding_outputs: &Vec, + funding_inputs: &Vec<(TxIn, TransactionU16LenLimited, Weight)>, shared_input: Option, + shared_output_funding_script: &ScriptBuf, funding_outputs: &Vec, funding_feerate_sat_per_1000_weight: u32, change_output_dust_limit: u64, ) -> Result, AbortReason> { // Process inputs and their prev txs: // calculate value sum and weight sum of inputs, also perform checks let mut total_input_satoshis = 0u64; let mut our_funding_inputs_weight = 0u64; - for (txin, tx) in funding_inputs.iter() { + for (txin, tx, witness_weight) in funding_inputs.iter() { let txid = tx.as_transaction().compute_txid(); if txin.previous_output.txid != txid { return Err(AbortReason::PrevTxOutInvalid); @@ -1758,19 +1966,30 @@ pub(super) fn calculate_change_output_value( if let Some(output) = tx.as_transaction().output.get(txin.previous_output.vout as usize) { total_input_satoshis = total_input_satoshis.saturating_add(output.value.to_sat()); our_funding_inputs_weight = - our_funding_inputs_weight.saturating_add(estimate_input_weight(output).to_wu()); + our_funding_inputs_weight.saturating_add(get_input_weight(*witness_weight).to_wu()); } else { return Err(AbortReason::PrevTxOutInvalid); } } + // If there is a shared input, account for it, + // and for the initiator also consider the fee + if let Some(shared_input) = shared_input { + total_input_satoshis = total_input_satoshis.saturating_add(shared_input); + if is_initiator { + our_funding_inputs_weight = + our_funding_inputs_weight.saturating_add(P2WSH_INPUT_WEIGHT_LOWER_BOUND); + } + } let our_funding_outputs_weight = funding_outputs.iter().fold(0u64, |weight, out| { - weight.saturating_add(get_output_weight(&out.tx_out().script_pubkey).to_wu()) + weight.saturating_add(get_output_weight(&out.script_pubkey).to_wu()) }); let mut weight = our_funding_outputs_weight.saturating_add(our_funding_inputs_weight); - // If we are the initiator, we must pay for weight of all common fields in the funding transaction. 
+ // If we are the initiator, we must pay for the weight of the funding output and + // all common fields in the funding transaction. if is_initiator { + weight = weight.saturating_add(get_output_weight(shared_output_funding_script).to_wu()); weight = weight.saturating_add(TX_COMMON_FIELDS_WEIGHT); } @@ -1808,20 +2027,21 @@ mod tests { use crate::util::ser::TransactionU16LenLimited; use bitcoin::absolute::LockTime as AbsoluteLockTime; use bitcoin::amount::Amount; + use bitcoin::ecdsa::Signature; use bitcoin::hashes::Hash; + use bitcoin::hex::FromHex as _; use bitcoin::key::UntweakedPublicKey; - use bitcoin::opcodes; use bitcoin::script::Builder; use bitcoin::secp256k1::{Keypair, PublicKey, Secp256k1, SecretKey}; use bitcoin::transaction::Version; + use bitcoin::{opcodes, Weight}; use bitcoin::{ OutPoint, PubkeyHash, ScriptBuf, Sequence, Transaction, TxIn, TxOut, WPubkeyHash, Witness, }; use core::ops::Deref; use super::{ - get_output_weight, AddingRole, OutputOwned, SharedOwnedOutput, - P2TR_INPUT_WEIGHT_LOWER_BOUND, P2WPKH_INPUT_WEIGHT_LOWER_BOUND, + get_output_weight, P2TR_INPUT_WEIGHT_LOWER_BOUND, P2WPKH_INPUT_WEIGHT_LOWER_BOUND, P2WSH_INPUT_WEIGHT_LOWER_BOUND, TX_COMMON_FIELDS_WEIGHT, }; @@ -1869,15 +2089,17 @@ mod tests { struct TestSession { description: &'static str, - inputs_a: Vec<(TxIn, TransactionU16LenLimited)>, - outputs_a: Vec, - inputs_b: Vec<(TxIn, TransactionU16LenLimited)>, - outputs_b: Vec, + inputs_a: Vec<(TxIn, TransactionU16LenLimited, Weight)>, + a_shared_input: Option<(OutPoint, u64, u64)>, + /// The funding output, with the value contributed + shared_output_a: (TxOut, u64), + outputs_a: Vec, + inputs_b: Vec<(TxIn, TransactionU16LenLimited, Weight)>, + b_shared_input: Option<(OutPoint, u64, u64)>, + /// The funding output, with the value contributed + shared_output_b: (TxOut, u64), + outputs_b: Vec, expect_error: Option<(AbortReason, ErrorCulprit)>, - /// A node adds no shared output, but expects the peer to add one, with the specific script pubkey, and local contribution - a_expected_remote_shared_output: Option<(ScriptBuf, u64)>, - /// B node adds no shared output, but expects the peer to add one, with the specific script pubkey, and local contribution - b_expected_remote_shared_output: Option<(ScriptBuf, u64)>, } fn do_test_interactive_tx_constructor(session: TestSession) { @@ -1909,57 +2131,6 @@ mod tests { &SecretKey::from_slice(&[43; 32]).unwrap(), ); - // funding output sanity check - let shared_outputs_by_a: Vec<_> = - session.outputs_a.iter().filter(|o| o.is_shared()).collect(); - if shared_outputs_by_a.len() > 1 { - println!("Test warning: Expected at most one shared output. NodeA"); - } - let shared_output_by_a = if !shared_outputs_by_a.is_empty() { - Some(shared_outputs_by_a[0].value()) - } else { - None - }; - let shared_outputs_by_b: Vec<_> = - session.outputs_b.iter().filter(|o| o.is_shared()).collect(); - if shared_outputs_by_b.len() > 1 { - println!("Test warning: Expected at most one shared output. 
NodeB"); - } - let shared_output_by_b = if !shared_outputs_by_b.is_empty() { - Some(shared_outputs_by_b[0].value()) - } else { - None - }; - if session.a_expected_remote_shared_output.is_some() - || session.b_expected_remote_shared_output.is_some() - { - let expected_by_a = if let Some(a_expected_remote_shared_output) = - &session.a_expected_remote_shared_output - { - a_expected_remote_shared_output.1 - } else if !shared_outputs_by_a.is_empty() { - shared_outputs_by_a[0].local_value(AddingRole::Local) - } else { - 0 - }; - let expected_by_b = if let Some(b_expected_remote_shared_output) = - &session.b_expected_remote_shared_output - { - b_expected_remote_shared_output.1 - } else if !shared_outputs_by_b.is_empty() { - shared_outputs_by_b[0].local_value(AddingRole::Local) - } else { - 0 - }; - - let expected_sum = expected_by_a + expected_by_b; - let actual_shared_output = - shared_output_by_a.unwrap_or(shared_output_by_b.unwrap_or(0)); - if expected_sum != actual_shared_output { - println!("Test warning: Sum of expected shared output values does not match actual shared output value, {} {} {} {} {} {}", expected_sum, actual_shared_output, expected_by_a, expected_by_b, shared_output_by_a.unwrap_or(0), shared_output_by_b.unwrap_or(0)); - } - } - let mut constructor_a = match InteractiveTxConstructor::new(InteractiveTxConstructorArgs { entropy_source, channel_id, @@ -1969,8 +2140,9 @@ mod tests { is_initiator: true, funding_tx_locktime, inputs_to_contribute: session.inputs_a, - outputs_to_contribute: session.outputs_a.to_vec(), - expected_remote_shared_funding_output: session.a_expected_remote_shared_output, + shared_funding_input: session.a_shared_input, + shared_funding_output: (session.shared_output_a.0, session.shared_output_a.1), + outputs_to_contribute: session.outputs_a, }) { Ok(r) => r, Err(abort_reason) => { @@ -1992,8 +2164,9 @@ mod tests { is_initiator: false, funding_tx_locktime, inputs_to_contribute: session.inputs_b, - outputs_to_contribute: session.outputs_b.to_vec(), - expected_remote_shared_funding_output: session.b_expected_remote_shared_output, + shared_funding_input: session.b_shared_input, + shared_funding_output: (session.shared_output_b.0, session.shared_output_b.1), + outputs_to_contribute: session.outputs_b, }) { Ok(r) => r, Err(abort_reason) => { @@ -2144,24 +2317,42 @@ mod tests { } } - fn generate_inputs(outputs: &[TestOutput]) -> Vec<(TxIn, TransactionU16LenLimited)> { + fn generate_inputs(outputs: &[TestOutput]) -> Vec<(TxIn, TransactionU16LenLimited, Weight)> { let tx = generate_tx(outputs); let txid = tx.compute_txid(); tx.output .iter() .enumerate() .map(|(idx, _)| { - let input = TxIn { + let txin = TxIn { previous_output: OutPoint { txid, vout: idx as u32 }, script_sig: Default::default(), sequence: Sequence::ENABLE_RBF_NO_LOCKTIME, - witness: Default::default(), + witness: Witness::p2wpkh( + &Signature::sighash_all( + bitcoin::secp256k1::ecdsa::Signature::from_der(&>::from_hex("3044022008f4f37e2d8f74e18c1b8fde2374d5f28402fb8ab7fd1cc5b786aa40851a70cb022032b1374d1a0f125eae4f69d1bc0b7f896c964cfdba329f38a952426cf427484c").unwrap()[..]).unwrap() + ) + .into(), + &PublicKey::from_slice(&[2; 33]).unwrap(), + ), }; - (input, TransactionU16LenLimited::new(tx.clone()).unwrap()) + let witness_weight = Weight::from_wu_usize(txin.witness.size()); + (txin, TransactionU16LenLimited::new(tx.clone()).unwrap(), witness_weight) }) .collect() } + fn generate_shared_input( + prev_funding_tx: &Transaction, vout: u32, local_owned: u64, + ) -> (OutPoint, u64, u64) { + let txid = 
prev_funding_tx.compute_txid(); + let value = prev_funding_tx.output.get(vout as usize).unwrap().value.to_sat(); + if local_owned > value { + println!("Warning: local owned > value for shared input, {} {}", local_owned, value); + } + (OutPoint { txid, vout }, value, local_owned) + } + fn generate_p2wsh_script_pubkey() -> ScriptBuf { Builder::new().push_opcode(opcodes::OP_TRUE).into_script().to_p2wsh() } @@ -2174,45 +2365,31 @@ mod tests { Builder::new().push_int(33).into_script().to_p2wsh() } - fn generate_output_nonfunding_one(output: &TestOutput) -> OutputOwned { - OutputOwned::Single(generate_txout(output)) + fn generate_output_nonfunding_one(output: &TestOutput) -> TxOut { + generate_txout(output) } - fn generate_outputs(outputs: &[TestOutput]) -> Vec { + fn generate_outputs(outputs: &[TestOutput]) -> Vec { outputs.iter().map(generate_output_nonfunding_one).collect() } - /// Generate a single output that is the funding output - fn generate_output(output: &TestOutput) -> Vec { - let txout = generate_txout(output); - let value = txout.value.to_sat(); - vec![OutputOwned::Shared(SharedOwnedOutput::new(txout, value))] - } - - /// Generate a single P2WSH output that is the funding output - fn generate_funding_output(value: u64) -> Vec { - generate_output(&TestOutput::P2WSH(value)) - } - - /// Generate a single P2WSH output with shared contribution that is the funding output - fn generate_shared_funding_output_one(value: u64, local_value: u64) -> OutputOwned { - OutputOwned::Shared(SharedOwnedOutput { - tx_out: generate_txout(&TestOutput::P2WSH(value)), - local_owned: local_value, - }) - } - - /// Generate a single P2WSH output with shared contribution that is the funding output - fn generate_shared_funding_output(value: u64, local_value: u64) -> Vec { - vec![generate_shared_funding_output_one(value, local_value)] + /// Generate a single P2WSH output that is the funding output, with local contributions + fn generate_funding_txout(value: u64, local_value: u64) -> (TxOut, u64) { + if local_value > value { + println!("Warning: Invalid local value, {} {}", value, local_value); + } + (generate_txout(&TestOutput::P2WSH(value)), local_value) } - fn generate_fixed_number_of_inputs(count: u16) -> Vec<(TxIn, TransactionU16LenLimited)> { + fn generate_fixed_number_of_inputs( + count: u16, + ) -> Vec<(TxIn, TransactionU16LenLimited, Weight)> { // Generate transactions with a total `count` number of outputs such that no transaction has a // serialized length greater than u16::MAX. 
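For context on the `Weight` value the test helpers now attach to every contributed input, this is a rough sketch of how a per-input weight can be derived from its witness weight; the constants below are assumptions standing in for the crate's `BASE_INPUT_WEIGHT` and `EMPTY_SCRIPT_SIG_WEIGHT` rather than copies of them:

```rust
// Non-witness input bytes (32-byte txid, 4-byte vout, 4-byte sequence, plus one
// byte for the empty script_sig length) count four weight units each, while the
// witness bytes count one weight unit each.
const BASE_INPUT_WEIGHT_WU: u64 = (32 + 4 + 4) * 4;
const EMPTY_SCRIPT_SIG_WEIGHT_WU: u64 = 4; // one length byte, counted 4x

fn input_weight_wu(witness_weight_wu: u64) -> u64 {
    (BASE_INPUT_WEIGHT_WU + EMPTY_SCRIPT_SIG_WEIGHT_WU).saturating_add(witness_weight_wu)
}

fn main() {
    // A typical P2WPKH witness (DER signature plus compressed pubkey, with length
    // prefixes) is roughly 107-109 bytes, i.e. about 109 weight units.
    assert_eq!(input_weight_wu(109), 273);
}
```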
let max_outputs_per_prevtx = 1_500; let mut remaining = count; - let mut inputs: Vec<(TxIn, TransactionU16LenLimited)> = Vec::with_capacity(count as usize); + let mut inputs: Vec<(TxIn, TransactionU16LenLimited, Weight)> = + Vec::with_capacity(count as usize); while remaining > 0 { let tx_output_count = remaining.min(max_outputs_per_prevtx); @@ -2225,7 +2402,7 @@ mod tests { ); let txid = tx.compute_txid(); - let mut temp: Vec<(TxIn, TransactionU16LenLimited)> = tx + let mut temp: Vec<(TxIn, TransactionU16LenLimited, Weight)> = tx .output .iter() .enumerate() @@ -2234,9 +2411,16 @@ mod tests { previous_output: OutPoint { txid, vout: idx as u32 }, script_sig: Default::default(), sequence: Sequence::ENABLE_RBF_NO_LOCKTIME, - witness: Default::default(), + witness: Witness::p2wpkh( + &Signature::sighash_all( + bitcoin::secp256k1::ecdsa::Signature::from_der(&>::from_hex("3044022008f4f37e2d8f74e18c1b8fde2374d5f28402fb8ab7fd1cc5b786aa40851a70cb022032b1374d1a0f125eae4f69d1bc0b7f896c964cfdba329f38a952426cf427484c").unwrap()[..]).unwrap() + ) + .into(), + &PublicKey::from_slice(&[2; 33]).unwrap(), + ), }; - (input, TransactionU16LenLimited::new(tx.clone()).unwrap()) + let witness_weight = Weight::from_wu_usize(input.witness.size()); + (input, TransactionU16LenLimited::new(tx.clone()).unwrap(), witness_weight) }) .collect(); @@ -2246,7 +2430,7 @@ mod tests { inputs } - fn generate_fixed_number_of_outputs(count: u16) -> Vec { + fn generate_fixed_number_of_outputs(count: u16) -> Vec { // Set a constant value for each TxOut generate_outputs(&vec![TestOutput::P2WPKH(1_000_000); count as usize]) } @@ -2255,111 +2439,122 @@ mod tests { Builder::new().push_opcode(opcodes::OP_TRUE).into_script().to_p2sh() } - fn generate_non_witness_output(value: u64) -> OutputOwned { - OutputOwned::Single(TxOut { - value: Amount::from_sat(value), - script_pubkey: generate_p2sh_script_pubkey(), - }) + fn generate_non_witness_output(value: u64) -> TxOut { + TxOut { value: Amount::from_sat(value), script_pubkey: generate_p2sh_script_pubkey() } } #[test] fn test_interactive_tx_constructor() { - do_test_interactive_tx_constructor(TestSession { - description: "No contributions", - inputs_a: vec![], - outputs_a: vec![], - inputs_b: vec![], - outputs_b: vec![], - expect_error: Some((AbortReason::MissingFundingOutput, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: None, - }); + // A transaction that can be used as a previous funding transaction + let prev_funding_tx_1 = Transaction { + input: Vec::new(), + output: vec![TxOut { + value: Amount::from_sat(60_000), + script_pubkey: ScriptBuf::new(), + }], + lock_time: AbsoluteLockTime::ZERO, + version: Version::TWO, + }; + do_test_interactive_tx_constructor(TestSession { description: "Single contribution, no initiator inputs", inputs_a: vec![], - outputs_a: generate_output(&TestOutput::P2WPKH(1_000_000)), - inputs_b: vec![], - outputs_b: vec![], - expect_error: Some((AbortReason::OutputsValueExceedsInputsValue, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_p2wpkh_script_pubkey(), 0)), - }); - do_test_interactive_tx_constructor(TestSession { - description: "Single contribution, no initiator outputs", - inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000)]), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), 
outputs_b: vec![], - expect_error: Some((AbortReason::MissingFundingOutput, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: None, + expect_error: Some((AbortReason::OutputsValueExceedsInputsValue, ErrorCulprit::NodeA)), }); + do_test_interactive_tx_constructor(TestSession { description: "Single contribution, no fees", inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000)]), - outputs_a: generate_output(&TestOutput::P2WPKH(1_000_000)), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), + outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_p2wpkh_script_pubkey(), 0)), }); let p2wpkh_fee = fee_for_weight(TEST_FEERATE_SATS_PER_KW, P2WPKH_INPUT_WEIGHT_LOWER_BOUND); let outputs_fee = fee_for_weight( TEST_FEERATE_SATS_PER_KW, - get_output_weight(&generate_p2wpkh_script_pubkey()).to_wu(), + get_output_weight(&generate_p2wsh_script_pubkey()).to_wu(), ); let tx_common_fields_fee = fee_for_weight(TEST_FEERATE_SATS_PER_KW, TX_COMMON_FIELDS_WEIGHT); let amount_adjusted_with_p2wpkh_fee = - 1_000_000 - p2wpkh_fee - outputs_fee - tx_common_fields_fee; + 1_000_000 - p2wpkh_fee - outputs_fee - tx_common_fields_fee + 1; do_test_interactive_tx_constructor(TestSession { description: "Single contribution, with P2WPKH input, insufficient fees", inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000)]), - outputs_a: generate_output(&TestOutput::P2WPKH( - amount_adjusted_with_p2wpkh_fee + 1, /* makes fees insuffcient for initiator */ - )), + a_shared_input: None, + // makes fees insuffcient for initiator + shared_output_a: generate_funding_txout( + amount_adjusted_with_p2wpkh_fee + 1, + amount_adjusted_with_p2wpkh_fee + 1, + ), + outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(amount_adjusted_with_p2wpkh_fee + 1, 0), outputs_b: vec![], expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_p2wpkh_script_pubkey(), 0)), }); do_test_interactive_tx_constructor(TestSession { description: "Single contribution with P2WPKH input, sufficient fees", inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000)]), - outputs_a: generate_output(&TestOutput::P2WPKH(amount_adjusted_with_p2wpkh_fee)), + a_shared_input: None, + shared_output_a: generate_funding_txout( + amount_adjusted_with_p2wpkh_fee, + amount_adjusted_with_p2wpkh_fee, + ), + outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(amount_adjusted_with_p2wpkh_fee, 0), outputs_b: vec![], expect_error: None, - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_p2wpkh_script_pubkey(), 0)), }); let p2wsh_fee = fee_for_weight(TEST_FEERATE_SATS_PER_KW, P2WSH_INPUT_WEIGHT_LOWER_BOUND); let amount_adjusted_with_p2wsh_fee = - 1_000_000 - p2wsh_fee - outputs_fee - tx_common_fields_fee; + 1_000_000 - p2wsh_fee - outputs_fee - tx_common_fields_fee + 1; do_test_interactive_tx_constructor(TestSession { description: "Single contribution, with P2WSH input, insufficient fees", inputs_a: generate_inputs(&[TestOutput::P2WSH(1_000_000)]), - outputs_a: generate_output(&TestOutput::P2WPKH( - 
amount_adjusted_with_p2wsh_fee + 1, /* makes fees insuffcient for initiator */ - )), + a_shared_input: None, + // makes fees insuffcient for initiator + shared_output_a: generate_funding_txout( + amount_adjusted_with_p2wsh_fee + 1, + amount_adjusted_with_p2wsh_fee + 1, + ), + outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(amount_adjusted_with_p2wsh_fee + 1, 0), outputs_b: vec![], expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_p2wpkh_script_pubkey(), 0)), }); do_test_interactive_tx_constructor(TestSession { description: "Single contribution with P2WSH input, sufficient fees", inputs_a: generate_inputs(&[TestOutput::P2WSH(1_000_000)]), - outputs_a: generate_output(&TestOutput::P2WPKH(amount_adjusted_with_p2wsh_fee)), + a_shared_input: None, + shared_output_a: generate_funding_txout( + amount_adjusted_with_p2wsh_fee, + amount_adjusted_with_p2wsh_fee, + ), + outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(amount_adjusted_with_p2wsh_fee, 0), outputs_b: vec![], expect_error: None, - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_p2wpkh_script_pubkey(), 0)), }); let p2tr_fee = fee_for_weight(TEST_FEERATE_SATS_PER_KW, P2TR_INPUT_WEIGHT_LOWER_BOUND); let amount_adjusted_with_p2tr_fee = @@ -2367,61 +2562,73 @@ mod tests { do_test_interactive_tx_constructor(TestSession { description: "Single contribution, with P2TR input, insufficient fees", inputs_a: generate_inputs(&[TestOutput::P2TR(1_000_000)]), - outputs_a: generate_output(&TestOutput::P2WPKH( - amount_adjusted_with_p2tr_fee + 1, /* makes fees insuffcient for initiator */ - )), + a_shared_input: None, + // makes fees insuffcient for initiator + shared_output_a: generate_funding_txout( + amount_adjusted_with_p2tr_fee + 1, + amount_adjusted_with_p2tr_fee + 1, + ), + outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(amount_adjusted_with_p2tr_fee + 1, 0), outputs_b: vec![], expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_p2wpkh_script_pubkey(), 0)), }); do_test_interactive_tx_constructor(TestSession { description: "Single contribution with P2TR input, sufficient fees", inputs_a: generate_inputs(&[TestOutput::P2TR(1_000_000)]), - outputs_a: generate_output(&TestOutput::P2WPKH(amount_adjusted_with_p2tr_fee)), + a_shared_input: None, + shared_output_a: generate_funding_txout( + amount_adjusted_with_p2tr_fee, + amount_adjusted_with_p2tr_fee, + ), + outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(amount_adjusted_with_p2tr_fee, 0), outputs_b: vec![], expect_error: None, - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_p2wpkh_script_pubkey(), 0)), }); do_test_interactive_tx_constructor(TestSession { description: "Initiator contributes sufficient fees, but non-initiator does not", inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000)]), + a_shared_input: None, + shared_output_a: generate_funding_txout(100_000, 0), outputs_a: vec![], inputs_b: generate_inputs(&[TestOutput::P2WPKH(100_000)]), - outputs_b: generate_output(&TestOutput::P2WPKH(100_000)), + b_shared_input: None, + shared_output_b: generate_funding_txout(100_000, 
100_000), + outputs_b: vec![], expect_error: Some((AbortReason::InsufficientFees, ErrorCulprit::NodeB)), - a_expected_remote_shared_output: Some((generate_p2wpkh_script_pubkey(), 0)), - b_expected_remote_shared_output: None, }); do_test_interactive_tx_constructor(TestSession { description: "Multi-input-output contributions from both sides", inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000); 2]), - outputs_a: vec![ - generate_shared_funding_output_one(1_000_000, 200_000), - generate_output_nonfunding_one(&TestOutput::P2WPKH(200_000)), - ], + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 200_000), + outputs_a: vec![generate_output_nonfunding_one(&TestOutput::P2WPKH(200_000))], inputs_b: generate_inputs(&[ TestOutput::P2WPKH(1_000_000), TestOutput::P2WPKH(500_000), ]), + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 800_000), outputs_b: vec![generate_output_nonfunding_one(&TestOutput::P2WPKH(400_000))], expect_error: None, - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 800_000)), }); do_test_interactive_tx_constructor(TestSession { description: "Prevout from initiator is not a witness program", inputs_a: generate_inputs(&[TestOutput::P2PKH(1_000_000)]), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), }); let tx = @@ -2430,30 +2637,44 @@ mod tests { previous_output: OutPoint { txid: tx.as_transaction().compute_txid(), vout: 0 }, ..Default::default() }; + let invalid_sequence_input_witness_weight = + Weight::from_wu_usize(invalid_sequence_input.witness.size()); do_test_interactive_tx_constructor(TestSession { description: "Invalid input sequence from initiator", - inputs_a: vec![(invalid_sequence_input, tx.clone())], - outputs_a: generate_output(&TestOutput::P2WPKH(1_000_000)), + inputs_a: vec![( + invalid_sequence_input, + tx.clone(), + invalid_sequence_input_witness_weight, + )], + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), + outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some((AbortReason::IncorrectInputSequenceValue, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_p2wpkh_script_pubkey(), 0)), }); let duplicate_input = TxIn { previous_output: OutPoint { txid: tx.as_transaction().compute_txid(), vout: 0 }, sequence: Sequence::ENABLE_RBF_NO_LOCKTIME, ..Default::default() }; + let duplicate_input_witness_weight = Weight::from_wu_usize(duplicate_input.witness.size()); do_test_interactive_tx_constructor(TestSession { description: "Duplicate prevout from initiator", - inputs_a: vec![(duplicate_input.clone(), tx.clone()), (duplicate_input, tx.clone())], - outputs_a: generate_output(&TestOutput::P2WPKH(1_000_000)), + inputs_a: vec![ + (duplicate_input.clone(), tx.clone(), duplicate_input_witness_weight), + (duplicate_input, tx.clone(), duplicate_input_witness_weight), + ], + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 
1_000_000), + outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeB)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_p2wpkh_script_pubkey(), 0)), }); // Non-initiator uses same prevout as initiator. let duplicate_input = TxIn { @@ -2461,108 +2682,130 @@ mod tests { sequence: Sequence::ENABLE_RBF_NO_LOCKTIME, ..Default::default() }; + let duplicate_input_witness_weight = Weight::from_wu_usize(duplicate_input.witness.size()); do_test_interactive_tx_constructor(TestSession { description: "Non-initiator uses same prevout as initiator", - inputs_a: vec![(duplicate_input.clone(), tx.clone())], - outputs_a: generate_shared_funding_output(1_000_000, 905_000), - inputs_b: vec![(duplicate_input.clone(), tx.clone())], + inputs_a: vec![(duplicate_input.clone(), tx.clone(), duplicate_input_witness_weight)], + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 905_000), + outputs_a: vec![], + inputs_b: vec![(duplicate_input.clone(), tx.clone(), duplicate_input_witness_weight)], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 95_000), outputs_b: vec![], expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 95_000)), }); let duplicate_input = TxIn { previous_output: OutPoint { txid: tx.as_transaction().compute_txid(), vout: 0 }, sequence: Sequence::ENABLE_RBF_NO_LOCKTIME, ..Default::default() }; + let duplicate_input_witness_weight = Weight::from_wu_usize(duplicate_input.witness.size()); do_test_interactive_tx_constructor(TestSession { description: "Non-initiator uses same prevout as initiator", - inputs_a: vec![(duplicate_input.clone(), tx.clone())], - outputs_a: generate_output(&TestOutput::P2WPKH(1_000_000)), - inputs_b: vec![(duplicate_input.clone(), tx.clone())], + inputs_a: vec![(duplicate_input.clone(), tx.clone(), duplicate_input_witness_weight)], + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), + outputs_a: vec![], + inputs_b: vec![(duplicate_input.clone(), tx.clone(), duplicate_input_witness_weight)], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some((AbortReason::PrevTxOutInvalid, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_p2wpkh_script_pubkey(), 0)), }); do_test_interactive_tx_constructor(TestSession { description: "Initiator sends too many TxAddInputs", inputs_a: generate_fixed_number_of_inputs(MAX_RECEIVED_TX_ADD_INPUT_COUNT + 1), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some((AbortReason::ReceivedTooManyTxAddInputs, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), }); do_test_interactive_tx_constructor_with_entropy_source( TestSession { // We use a deliberately bad entropy source, `DuplicateEntropySource` to simulate this. 
description: "Attempt to queue up two inputs with duplicate serial ids", inputs_a: generate_fixed_number_of_inputs(2), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some((AbortReason::DuplicateSerialId, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), }, &DuplicateEntropySource, ); do_test_interactive_tx_constructor(TestSession { description: "Initiator sends too many TxAddOutputs", inputs_a: vec![], - outputs_a: generate_fixed_number_of_outputs(MAX_RECEIVED_TX_ADD_OUTPUT_COUNT + 1), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), + outputs_a: generate_fixed_number_of_outputs(MAX_RECEIVED_TX_ADD_OUTPUT_COUNT), inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some((AbortReason::ReceivedTooManyTxAddOutputs, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), }); + let dust_amount = generate_p2wsh_script_pubkey().minimal_non_dust().to_sat() - 1; do_test_interactive_tx_constructor(TestSession { description: "Initiator sends an output below dust value", inputs_a: vec![], - outputs_a: generate_funding_output( - generate_p2wsh_script_pubkey().minimal_non_dust().to_sat() - 1, - ), + a_shared_input: None, + shared_output_a: generate_funding_txout(dust_amount, dust_amount), + outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(dust_amount, 0), outputs_b: vec![], expect_error: Some((AbortReason::BelowDustLimit, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), }); do_test_interactive_tx_constructor(TestSession { description: "Initiator sends an output above maximum sats allowed", inputs_a: vec![], - outputs_a: generate_output(&TestOutput::P2WPKH(TOTAL_BITCOIN_SUPPLY_SATOSHIS + 1)), + a_shared_input: None, + shared_output_a: generate_funding_txout( + TOTAL_BITCOIN_SUPPLY_SATOSHIS + 1, + TOTAL_BITCOIN_SUPPLY_SATOSHIS + 1, + ), + outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(TOTAL_BITCOIN_SUPPLY_SATOSHIS + 1, 0), outputs_b: vec![], expect_error: Some((AbortReason::ExceededMaximumSatsAllowed, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), }); do_test_interactive_tx_constructor(TestSession { description: "Initiator sends an output without a witness program", inputs_a: vec![], + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), outputs_a: vec![generate_non_witness_output(1_000_000)], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some((AbortReason::InvalidOutputScript, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), }); do_test_interactive_tx_constructor_with_entropy_source( TestSession { 
// We use a deliberately bad entropy source, `DuplicateEntropySource` to simulate this. description: "Attempt to queue up two outputs with duplicate serial ids", inputs_a: vec![], + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), outputs_a: generate_fixed_number_of_outputs(2), inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some((AbortReason::DuplicateSerialId, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), }, &DuplicateEntropySource, ); @@ -2570,99 +2813,101 @@ mod tests { do_test_interactive_tx_constructor(TestSession { description: "Peer contributed more output value than inputs", inputs_a: generate_inputs(&[TestOutput::P2WPKH(100_000)]), - outputs_a: generate_output(&TestOutput::P2WPKH(1_000_000)), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), + outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some((AbortReason::OutputsValueExceedsInputsValue, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), }); do_test_interactive_tx_constructor(TestSession { description: "Peer contributed more than allowed number of inputs", inputs_a: generate_fixed_number_of_inputs(MAX_INPUTS_OUTPUTS_COUNT as u16 + 1), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), outputs_a: vec![], inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some(( AbortReason::ExceededNumberOfInputsOrOutputs, ErrorCulprit::Indeterminate, )), - a_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), }); do_test_interactive_tx_constructor(TestSession { description: "Peer contributed more than allowed number of outputs", inputs_a: generate_inputs(&[TestOutput::P2WPKH(TOTAL_BITCOIN_SUPPLY_SATOSHIS)]), - outputs_a: generate_fixed_number_of_outputs(MAX_INPUTS_OUTPUTS_COUNT as u16 + 1), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 1_000_000), + outputs_a: generate_fixed_number_of_outputs(MAX_INPUTS_OUTPUTS_COUNT as u16), inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 0), outputs_b: vec![], expect_error: Some(( AbortReason::ExceededNumberOfInputsOrOutputs, ErrorCulprit::Indeterminate, )), - a_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)), - }); - - // Adding multiple outputs to the funding output pubkey is an error - do_test_interactive_tx_constructor(TestSession { - description: "Adding two outputs to the funding output pubkey", - inputs_a: generate_inputs(&[TestOutput::P2WPKH(1_000_000)]), - outputs_a: generate_funding_output(100_000), - inputs_b: generate_inputs(&[TestOutput::P2WPKH(1_001_000)]), - outputs_b: generate_funding_output(100_000), - expect_error: Some((AbortReason::DuplicateFundingOutput, ErrorCulprit::NodeA)), - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: None, }); // We add the funding output, but we 
contribute a little do_test_interactive_tx_constructor(TestSession { description: "Funding output by us, small contribution", inputs_a: generate_inputs(&[TestOutput::P2WPKH(12_000)]), - outputs_a: generate_shared_funding_output(1_000_000, 10_000), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 10_000), + outputs_a: vec![], inputs_b: generate_inputs(&[TestOutput::P2WPKH(992_000)]), + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 990_000), outputs_b: vec![], expect_error: None, - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 990_000)), }); // They add the funding output, and we contribute a little do_test_interactive_tx_constructor(TestSession { description: "Funding output by them, small contribution", inputs_a: generate_inputs(&[TestOutput::P2WPKH(12_000)]), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 10_000), outputs_a: vec![], inputs_b: generate_inputs(&[TestOutput::P2WPKH(992_000)]), - outputs_b: generate_shared_funding_output(1_000_000, 990_000), + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 990_000), + outputs_b: vec![], expect_error: None, - a_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 10_000)), - b_expected_remote_shared_output: None, }); // We add the funding output, and we contribute most do_test_interactive_tx_constructor(TestSession { description: "Funding output by us, large contribution", inputs_a: generate_inputs(&[TestOutput::P2WPKH(992_000)]), - outputs_a: generate_shared_funding_output(1_000_000, 990_000), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 990_000), + outputs_a: vec![], inputs_b: generate_inputs(&[TestOutput::P2WPKH(12_000)]), + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 10_000), outputs_b: vec![], expect_error: None, - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 10_000)), }); // They add the funding output, but we contribute most do_test_interactive_tx_constructor(TestSession { description: "Funding output by them, large contribution", inputs_a: generate_inputs(&[TestOutput::P2WPKH(992_000)]), + a_shared_input: None, + shared_output_a: generate_funding_txout(1_000_000, 990_000), outputs_a: vec![], inputs_b: generate_inputs(&[TestOutput::P2WPKH(12_000)]), - outputs_b: generate_shared_funding_output(1_000_000, 10_000), + b_shared_input: None, + shared_output_b: generate_funding_txout(1_000_000, 10_000), + outputs_b: vec![], expect_error: None, - a_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 990_000)), - b_expected_remote_shared_output: None, }); // During a splice-out, with peer providing more output value than input value @@ -2671,12 +2916,14 @@ mod tests { do_test_interactive_tx_constructor(TestSession { description: "Splice out with sufficient initiator balance", inputs_a: generate_inputs(&[TestOutput::P2WPKH(100_000), TestOutput::P2WPKH(50_000)]), - outputs_a: generate_funding_output(120_000), + a_shared_input: None, + shared_output_a: generate_funding_txout(120_000, 120_000), + outputs_a: vec![], inputs_b: generate_inputs(&[TestOutput::P2WPKH(50_000)]), + b_shared_input: None, + shared_output_b: generate_funding_txout(120_000, 0), outputs_b: vec![], expect_error: None, - a_expected_remote_shared_output: None, - b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 
0)),
 		});
 		// During a splice-out, with peer providing more output value than input value
@@ -2685,37 +2932,70 @@ mod tests {
 		do_test_interactive_tx_constructor(TestSession {
 			description: "Splice out with insufficient initiator balance",
 			inputs_a: generate_inputs(&[TestOutput::P2WPKH(100_000), TestOutput::P2WPKH(15_000)]),
-			outputs_a: generate_funding_output(120_000),
+			a_shared_input: None,
+			shared_output_a: generate_funding_txout(120_000, 120_000),
+			outputs_a: vec![],
 			inputs_b: generate_inputs(&[TestOutput::P2WPKH(85_000)]),
+			b_shared_input: None,
+			shared_output_b: generate_funding_txout(120_000, 0),
 			outputs_b: vec![],
 			expect_error: Some((AbortReason::OutputsValueExceedsInputsValue, ErrorCulprit::NodeA)),
-			a_expected_remote_shared_output: None,
-			b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 0)),
 		});
-		// The actual funding output value is lower than the intended local contribution by the same node
+		// The intended and expected shared output values differ
 		do_test_interactive_tx_constructor(TestSession {
 			description: "Splice in, invalid intended local contribution",
 			inputs_a: generate_inputs(&[TestOutput::P2WPKH(100_000), TestOutput::P2WPKH(15_000)]),
-			outputs_a: generate_shared_funding_output(100_000, 120_000), // local value is higher than the output value
+			a_shared_input: None,
+			shared_output_a: generate_funding_txout(100_000, 100_000),
+			outputs_a: vec![],
 			inputs_b: generate_inputs(&[TestOutput::P2WPKH(85_000)]),
+			b_shared_input: None,
+			shared_output_b: generate_funding_txout(120_000, 0), // the value differs
 			outputs_b: vec![],
-			expect_error: Some((AbortReason::InvalidLowFundingOutputValue, ErrorCulprit::NodeA)),
-			a_expected_remote_shared_output: None,
-			b_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 20_000)),
+			expect_error: Some((AbortReason::MissingFundingOutput, ErrorCulprit::NodeA)),
 		});
-		// The actual funding output value is lower than the intended local contribution of the other node
+		// Provide and expect a shared input
 		do_test_interactive_tx_constructor(TestSession {
-			description: "Splice in, invalid intended local contribution",
-			inputs_a: generate_inputs(&[TestOutput::P2WPKH(100_000), TestOutput::P2WPKH(15_000)]),
+			description: "Provide and expect a shared input",
+			inputs_a: generate_inputs(&[TestOutput::P2WPKH(50_000)]),
+			a_shared_input: Some(generate_shared_input(&prev_funding_tx_1, 0, 60_000)),
+			shared_output_a: generate_funding_txout(108_000, 108_000),
 			outputs_a: vec![],
-			inputs_b: generate_inputs(&[TestOutput::P2WPKH(85_000)]),
-			outputs_b: generate_funding_output(100_000),
-			// The error is caused by NodeA, it occurs when nodeA prepares the message to be sent to NodeB, that's why here it shows up as NodeB
-			expect_error: Some((AbortReason::InvalidLowFundingOutputValue, ErrorCulprit::NodeB)),
-			a_expected_remote_shared_output: Some((generate_funding_script_pubkey(), 120_000)), // this is higher than the actual output value
-			b_expected_remote_shared_output: None,
+			inputs_b: vec![],
+			b_shared_input: Some(generate_shared_input(&prev_funding_tx_1, 0, 0)),
+			shared_output_b: generate_funding_txout(108_000, 0),
+			outputs_b: vec![],
+			expect_error: None,
+		});
+
+		// Expect a shared input, but it's missing
+		do_test_interactive_tx_constructor(TestSession {
+			description: "Expect a shared input, but it's missing",
+			inputs_a: generate_inputs(&[TestOutput::P2WPKH(110_000)]),
+			a_shared_input: None,
+			shared_output_a: generate_funding_txout(108_000, 108_000),
+			outputs_a: vec![],
+			inputs_b: vec![],
+
b_shared_input: Some(generate_shared_input(&prev_funding_tx_1, 0, 0)), + shared_output_b: generate_funding_txout(108_000, 0), + outputs_b: vec![], + expect_error: Some((AbortReason::MissingFundingInput, ErrorCulprit::NodeA)), + }); + + // Provide a shared input, but it's not expected + do_test_interactive_tx_constructor(TestSession { + description: "Provide a shared input, but it's not expected", + inputs_a: generate_inputs(&[TestOutput::P2WPKH(50_000)]), + a_shared_input: Some(generate_shared_input(&prev_funding_tx_1, 0, 60_000)), + shared_output_a: generate_funding_txout(108_000, 108_000), + outputs_a: vec![], + inputs_b: vec![], + b_shared_input: None, + shared_output_b: generate_funding_txout(108_000, 0), + outputs_b: vec![], + expect_error: Some((AbortReason::MissingFundingInput, ErrorCulprit::NodeA)), }); } @@ -2748,27 +3028,37 @@ mod tests { previous_output: OutPoint { txid, vout: 0 }, script_sig: ScriptBuf::new(), sequence: Sequence::ZERO, - witness: Witness::new(), + witness: Witness::p2wpkh( + &Signature::sighash_all( + bitcoin::secp256k1::ecdsa::Signature::from_der(&>::from_hex("3044022008f4f37e2d8f74e18c1b8fde2374d5f28402fb8ab7fd1cc5b786aa40851a70cb022032b1374d1a0f125eae4f69d1bc0b7f896c964cfdba329f38a952426cf427484c").unwrap()[..]).unwrap() + ) + .into(), + &PublicKey::from_slice(&[2; 33]).unwrap(), + ), }; - (txin, TransactionU16LenLimited::new(tx).unwrap()) + let witness_weight = Weight::from_wu_usize(txin.witness.size()); + (txin, TransactionU16LenLimited::new(tx).unwrap(), witness_weight) }) - .collect::>(); + .collect::>(); let our_contributed = 110_000; let txout = TxOut { value: Amount::from_sat(128_000), script_pubkey: ScriptBuf::new() }; - let value = txout.value.to_sat(); - let outputs = vec![OutputOwned::Shared(SharedOwnedOutput::new(txout, value))]; + let _value = txout.value.to_sat(); + // let outputs = OutputOwned::Shared(SharedOwnedOutput::new(txout, value))]; + let outputs = vec![]; let funding_feerate_sat_per_1000_weight = 3000; let total_inputs: u64 = input_prevouts.iter().map(|o| o.value.to_sat()).sum(); let gross_change = total_inputs - our_contributed; - let fees = 1746; - let common_fees = 126; + let fees = 1626; // 1734 - 108; + let common_fees = 234; // 126 + 108; { // There is leftover for change let res = calculate_change_output_value( true, our_contributed, &inputs, + None, + &ScriptBuf::new(), &outputs, funding_feerate_sat_per_1000_weight, 300, @@ -2781,6 +3071,8 @@ mod tests { false, our_contributed, &inputs, + None, + &ScriptBuf::new(), &outputs, funding_feerate_sat_per_1000_weight, 300, @@ -2789,9 +3081,17 @@ mod tests { } { // Larger fee, smaller change - let res = - calculate_change_output_value(true, our_contributed, &inputs, &outputs, 9000, 300); - assert_eq!(res.unwrap().unwrap(), 14384); + let res = calculate_change_output_value( + true, + our_contributed, + &inputs, + None, + &ScriptBuf::new(), + &outputs, + 9000, + 300, + ); + assert_eq!(res.unwrap().unwrap(), 14420); } { // Insufficient inputs, no leftover @@ -2799,6 +3099,8 @@ mod tests { false, 130_000, &inputs, + None, + &ScriptBuf::new(), &outputs, funding_feerate_sat_per_1000_weight, 300, @@ -2811,6 +3113,8 @@ mod tests { false, 128_100, &inputs, + None, + &ScriptBuf::new(), &outputs, funding_feerate_sat_per_1000_weight, 300, @@ -2823,11 +3127,13 @@ mod tests { false, 128_100, &inputs, + None, + &ScriptBuf::new(), &outputs, funding_feerate_sat_per_1000_weight, 100, ); - assert_eq!(res.unwrap().unwrap(), 154); + assert_eq!(res.unwrap().unwrap(), 274); } } } diff --git 
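The `calculate_change_output_value` cases above all reduce to the same arithmetic: the contributor's change is whatever remains of their input value after subtracting the amount they committed to the shared output and the fee for the weight they add, and the change output is dropped entirely if that remainder does not clear the dust limit. The following self-contained sketch illustrates only that arithmetic, before the `msgs.rs` changes below; the function name, the flat weight estimate, and the rounding are assumptions for illustration and do not reproduce LDK's actual helper.

```rust
/// Illustrative sketch only (not LDK's `calculate_change_output_value`): given the sum of a
/// contributor's input values, the value they promised to contribute, an estimated weight for
/// their inputs plus a change output, and a feerate, return the change value they could claim,
/// or `None` if it would not exceed the dust limit.
fn estimate_change_value(
	total_input_sat: u64, contributed_sat: u64, estimated_weight_wu: u64,
	feerate_sat_per_1000_weight: u64, dust_limit_sat: u64,
) -> Option<u64> {
	// Round the fee up so the estimated weight is never underpaid.
	let fee_sat = (estimated_weight_wu * feerate_sat_per_1000_weight + 999) / 1000;
	let gross_change = total_input_sat.checked_sub(contributed_sat)?;
	let net_change = gross_change.checked_sub(fee_sat)?;
	if net_change <= dust_limit_sat {
		None
	} else {
		Some(net_change)
	}
}

fn main() {
	// Mirrors the shape of the tests above: 128,000 sat of inputs and a 110,000 sat
	// contribution at 3,000 sat per 1,000 weight units leave room for change, while
	// contributing more than the inputs provide never does.
	assert!(estimate_change_value(128_000, 110_000, 800, 3_000, 300).is_some());
	assert!(estimate_change_value(128_000, 128_100, 800, 3_000, 300).is_none());
}
```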
a/lightning/src/ln/msgs.rs b/lightning/src/ln/msgs.rs
index 2b18901591e..762240ef4d3 100644
--- a/lightning/src/ln/msgs.rs
+++ b/lightning/src/ln/msgs.rs
@@ -29,7 +29,7 @@
 use bitcoin::hash_types::Txid;
 use bitcoin::script::ScriptBuf;
 use bitcoin::secp256k1::ecdsa::Signature;
 use bitcoin::secp256k1::PublicKey;
-use bitcoin::{secp256k1, Witness};
+use bitcoin::{secp256k1, Transaction, Witness};
 use crate::blinded_path::payment::{
 	BlindedPaymentTlvs, ForwardTlvs, ReceiveTlvs, UnauthenticatedReceiveTlvs,
@@ -525,9 +525,9 @@ pub struct TxAddInput {
 	/// A randomly chosen unique identifier for this input, which is even for initiators and odd for
 	/// non-initiators.
 	pub serial_id: SerialId,
-	/// Serialized transaction that contains the output this input spends to verify that it is non
-	/// malleable.
-	pub prevtx: TransactionU16LenLimited,
+	/// Serialized transaction that contains the output this input spends to verify that it is
+	/// non-malleable. Omitted for a shared input.
+	pub prevtx: Option<TransactionU16LenLimited>,
 	/// The index of the output being spent
 	pub prevtx_out: u32,
 	/// The sequence number of this input
@@ -2664,15 +2664,60 @@
 impl_writeable_msg!(SpliceLocked, {
 	splice_txid,
 }, {});
-impl_writeable_msg!(TxAddInput, {
-	channel_id,
-	serial_id,
-	prevtx,
-	prevtx_out,
-	sequence,
-}, {
-	(0, shared_input_txid, option), // `funding_txid`
-});
+impl Writeable for TxAddInput {
+	fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+		self.channel_id.write(w)?;
+		self.serial_id.write(w)?;
+		if let Some(prevtx) = self.prevtx.as_ref() {
+			debug_assert!(self.shared_input_txid.is_none());
+			prevtx.write(w)?;
+		} else {
+			debug_assert!(self.shared_input_txid.is_some());
+			0u16.write(w)?;
+		}
+		self.prevtx_out.write(w)?;
+		self.sequence.write(w)?;
+
+		if let Some(shared_input_txid) = self.shared_input_txid.as_ref() {
+			encode_tlv_stream!(w, { (0, shared_input_txid, required) });
+		} else {
+			encode_tlv_stream!(w, {});
+		}
+
+		Ok(())
+	}
+}
+
+impl LengthReadable for TxAddInput {
+	fn read_from_fixed_length_buffer<R: LengthLimitedRead>(r: &mut R) -> Result<Self, DecodeError> {
+		let channel_id = Readable::read(r)?;
+		let serial_id = Readable::read(r)?;
+		let prevtx_len = <u16 as Readable>::read(r)?;
+		let mut prevtx = None;
+		if prevtx_len > 0 {
+			let mut tx_reader = FixedLengthReader::new(r, prevtx_len as u64);
+			let tx: Transaction = Readable::read(&mut tx_reader)?;
+			if tx_reader.bytes_remain() {
+				return Err(DecodeError::BadLengthDescriptor);
+			}
+			prevtx =
+				Some(TransactionU16LenLimited::new(tx).map_err(|_| DecodeError::InvalidValue)?);
+		}
+		let prevtx_out = Readable::read(r)?;
+		let sequence = Readable::read(r)?;
+
+		let mut shared_input_txid = None;
+		if prevtx_len > 0 {
+			decode_tlv_stream!(r, {});
+		} else {
+			decode_tlv_stream!(r, {
+				(0, shared_input_txid, required),
+			});
+		}
+
+		Ok(Self { channel_id, serial_id, prevtx, prevtx_out, sequence, shared_input_txid })
+	}
+}
 impl_writeable_msg!(TxAddOutput, {
 	channel_id,
@@ -5206,7 +5251,7 @@ mod tests {
 		let tx_add_input = msgs::TxAddInput {
 			channel_id: ChannelId::from_bytes([2; 32]),
 			serial_id: 4886718345,
-			prevtx: TransactionU16LenLimited::new(Transaction {
+			prevtx: Some(TransactionU16LenLimited::new(Transaction {
 				version: Version::TWO,
 				lock_time: LockTime::ZERO,
 				input: vec![TxIn {
@@ -5227,13 +5272,31 @@ mod tests {
 					script_pubkey: Address::from_str("bc1qxmk834g5marzm227dgqvynd23y2nvt2ztwcw2z").unwrap().assume_checked().script_pubkey(),
 				},
 				],
-			}).unwrap(),
+			}).unwrap()),
+			prevtx_out: 305419896,
+			sequence: 305419896,
+			shared_input_txid: None,
+		};
+		let encoded_value = tx_add_input.encode();
+		let
target_value = "0202020202020202020202020202020202020202020202020202020202020202000000012345678900de02000000000101779ced6c148293f86b60cb222108553d22c89207326bb7b6b897e23e64ab5b300200000000fdffffff0236dbc1000000000016001417d29e4dd454bac3b1cde50d1926da80cfc5287b9cbd03000000000016001436ec78d514df462da95e6a00c24daa8915362d420247304402206af85b7dd67450ad12c979302fac49dfacbc6a8620f49c5da2b5721cf9565ca502207002b32fed9ce1bf095f57aeb10c36928ac60b12e723d97d2964a54640ceefa701210301ab7dc16488303549bfcdd80f6ae5ee4c20bf97ab5410bbd6b1bfa85dcd6944000000001234567812345678"; + assert_eq!(encoded_value.as_hex().to_string(), target_value); + } + + #[test] + fn encoding_tx_add_input_shared() { + let tx_add_input = msgs::TxAddInput { + channel_id: ChannelId::from_bytes([2; 32]), + serial_id: 4886718345, + prevtx: None, prevtx_out: 305419896, sequence: 305419896, - shared_input_txid: Some(Txid::from_str("c2d4449afa8d26140898dd54d3390b057ba2a5afcf03ba29d7dc0d8b9ffe966e").unwrap()), + shared_input_txid: Some( + Txid::from_str("c2d4449afa8d26140898dd54d3390b057ba2a5afcf03ba29d7dc0d8b9ffe966e") + .unwrap(), + ), }; let encoded_value = tx_add_input.encode(); - let target_value = "0202020202020202020202020202020202020202020202020202020202020202000000012345678900de02000000000101779ced6c148293f86b60cb222108553d22c89207326bb7b6b897e23e64ab5b300200000000fdffffff0236dbc1000000000016001417d29e4dd454bac3b1cde50d1926da80cfc5287b9cbd03000000000016001436ec78d514df462da95e6a00c24daa8915362d420247304402206af85b7dd67450ad12c979302fac49dfacbc6a8620f49c5da2b5721cf9565ca502207002b32fed9ce1bf095f57aeb10c36928ac60b12e723d97d2964a54640ceefa701210301ab7dc16488303549bfcdd80f6ae5ee4c20bf97ab5410bbd6b1bfa85dcd694400000000123456781234567800206e96fe9f8b0ddcd729ba03cfafa5a27b050b39d354dd980814268dfa9a44d4c2"; + let target_value = "020202020202020202020202020202020202020202020202020202020202020200000001234567890000123456781234567800206e96fe9f8b0ddcd729ba03cfafa5a27b050b39d354dd980814268dfa9a44d4c2"; assert_eq!(encoded_value.as_hex().to_string(), target_value); } diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs index a989d172687..73c831884f9 100644 --- a/lightning/src/ln/peer_handler.rs +++ b/lightning/src/ln/peer_handler.rs @@ -1,5 +1,3 @@ -#![cfg_attr(rustfmt, rustfmt_skip)] - // This file is Copyright its original authors, visible in version control // history. // @@ -18,50 +16,61 @@ //! messages they should handle, and encoding/sending response messages. 
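Recapping the `TxAddInput` wire format exercised by the two encoding tests above: exactly one of `prevtx` and `shared_input_txid` is present, and a zero-length `prevtx` field on the wire is the signal that the shared input's txid follows (carried as TLV type 0 in the real message). The sketch below models only that framing rule with plain byte buffers; the type and function names are illustrative, and it deliberately ignores LDK's TLV machinery and the other `TxAddInput` fields.

```rust
/// Minimal sketch of the framing rule, with plain types rather than LDK's serialization
/// traits: a zero-length `prevtx` means the input is the shared (previous funding) input and
/// a 32-byte txid follows; otherwise the raw previous transaction bytes follow.
#[derive(Debug, PartialEq)]
enum InputSource {
	/// Serialized previous transaction proving the spent output (non-shared input).
	PrevTx(Vec<u8>),
	/// Txid of the shared funding input (splicing); no previous transaction is sent.
	SharedTxid([u8; 32]),
}

fn encode(src: &InputSource, out: &mut Vec<u8>) {
	match src {
		InputSource::PrevTx(tx) => {
			out.extend_from_slice(&(tx.len() as u16).to_be_bytes());
			out.extend_from_slice(tx);
		},
		InputSource::SharedTxid(txid) => {
			// A zero length stands in for "no prevtx"; the txid follows instead.
			out.extend_from_slice(&0u16.to_be_bytes());
			out.extend_from_slice(txid);
		},
	}
}

fn decode(buf: &[u8]) -> Option<InputSource> {
	let len = u16::from_be_bytes([*buf.first()?, *buf.get(1)?]) as usize;
	let rest = buf.get(2..)?;
	if len > 0 {
		Some(InputSource::PrevTx(rest.get(..len)?.to_vec()))
	} else {
		Some(InputSource::SharedTxid(rest.get(..32)?.try_into().ok()?))
	}
}

fn main() {
	let mut buf = Vec::new();
	encode(&InputSource::SharedTxid([0x6e; 32]), &mut buf);
	assert_eq!(decode(&buf), Some(InputSource::SharedTxid([0x6e; 32])));
}
```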
use bitcoin::constants::ChainHash; -use bitcoin::secp256k1::{self, Secp256k1, SecretKey, PublicKey}; +use bitcoin::secp256k1::{self, PublicKey, Secp256k1, SecretKey}; use crate::blinded_path::message::{AsyncPaymentsContext, DNSResolverContext, OffersContext}; -use crate::sign::{NodeSigner, Recipient}; -use crate::ln::types::ChannelId; -use crate::types::features::{InitFeatures, NodeFeatures}; use crate::ln::msgs; -use crate::ln::msgs::{BaseMessageHandler, ChannelMessageHandler, Init, LightningError, SocketAddress, MessageSendEvent, OnionMessageHandler, RoutingMessageHandler}; -use crate::util::ser::{VecWriter, Writeable, Writer}; -use crate::ln::peer_channel_encryptor::{PeerChannelEncryptor, NextNoiseStep, MessageBuf, MSG_BUF_ALLOC_SIZE}; +use crate::ln::msgs::{ + BaseMessageHandler, ChannelMessageHandler, Init, LightningError, MessageSendEvent, + OnionMessageHandler, RoutingMessageHandler, SocketAddress, +}; +use crate::ln::peer_channel_encryptor::{ + MessageBuf, NextNoiseStep, PeerChannelEncryptor, MSG_BUF_ALLOC_SIZE, +}; +use crate::ln::types::ChannelId; use crate::ln::wire; use crate::ln::wire::{Encode, Type}; -use crate::onion_message::async_payments::{AsyncPaymentsMessageHandler, HeldHtlcAvailable, ReleaseHeldHtlc}; -use crate::onion_message::dns_resolution::{DNSResolverMessageHandler, DNSResolverMessage, DNSSECProof, DNSSECQuery}; -use crate::onion_message::messenger::{CustomOnionMessageHandler, Responder, ResponseInstruction, MessageSendInstructions}; +use crate::onion_message::async_payments::{ + AsyncPaymentsMessageHandler, HeldHtlcAvailable, ReleaseHeldHtlc, +}; +use crate::onion_message::dns_resolution::{ + DNSResolverMessage, DNSResolverMessageHandler, DNSSECProof, DNSSECQuery, +}; +use crate::onion_message::messenger::{ + CustomOnionMessageHandler, MessageSendInstructions, Responder, ResponseInstruction, +}; use crate::onion_message::offers::{OffersMessage, OffersMessageHandler}; use crate::onion_message::packet::OnionMessageContents; -use crate::routing::gossip::{NodeId, NodeAlias}; +use crate::routing::gossip::{NodeAlias, NodeId}; +use crate::sign::{NodeSigner, Recipient}; +use crate::types::features::{InitFeatures, NodeFeatures}; use crate::util::atomic_counter::AtomicCounter; use crate::util::logger::{Level, Logger, WithContext}; +use crate::util::ser::{VecWriter, Writeable, Writer}; use crate::util::string::PrintableString; #[allow(unused_imports)] use crate::prelude::*; use crate::io; -use crate::sync::{Mutex, MutexGuard, FairRwLock}; -use core::sync::atomic::{AtomicBool, AtomicU32, AtomicI32, Ordering}; -use core::{cmp, hash, fmt, mem}; -use core::ops::Deref; +use crate::sync::{FairRwLock, Mutex, MutexGuard}; use core::convert::Infallible; +use core::ops::Deref; +use core::sync::atomic::{AtomicBool, AtomicI32, AtomicU32, Ordering}; +use core::{cmp, fmt, hash, mem}; #[cfg(not(c_bindings))] use { crate::chain::chainmonitor::ChainMonitor, crate::ln::channelmanager::{SimpleArcChannelManager, SimpleRefChannelManager}, crate::onion_message::messenger::{SimpleArcOnionMessenger, SimpleRefOnionMessenger}, crate::routing::gossip::{NetworkGraph, P2PGossipSync}, - crate::sign::{KeysManager, InMemorySigner}, + crate::sign::{InMemorySigner, KeysManager}, crate::sync::Arc, }; use bitcoin::hashes::sha256::Hash as Sha256; use bitcoin::hashes::sha256::HashEngine as Sha256Engine; -use bitcoin::hashes::{HashEngine, Hash}; +use bitcoin::hashes::{Hash, HashEngine}; /// A handler provided to [`PeerManager`] for reading and handling custom messages. 
/// @@ -75,7 +84,9 @@ pub trait CustomMessageHandler: wire::CustomMessageReader { /// Handles the given message sent from `sender_node_id`, possibly producing messages for /// [`CustomMessageHandler::get_and_clear_pending_msg`] to return and thus for [`PeerManager`] /// to send. - fn handle_custom_message(&self, msg: Self::CustomMessage, sender_node_id: PublicKey) -> Result<(), LightningError>; + fn handle_custom_message( + &self, msg: Self::CustomMessage, sender_node_id: PublicKey, + ) -> Result<(), LightningError>; /// Returns the list of pending messages that were generated by the handler, clearing the list /// in the process. Each message is paired with the node id of the intended recipient. If no @@ -92,7 +103,8 @@ pub trait CustomMessageHandler: wire::CustomMessageReader { /// message handlers may still wish to communicate with this peer. /// /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. - fn peer_connected(&self, their_node_id: PublicKey, msg: &Init, inbound: bool) -> Result<(), ()>; + fn peer_connected(&self, their_node_id: PublicKey, msg: &Init, inbound: bool) + -> Result<(), ()>; /// Gets the node feature flags which this handler itself supports. All available handlers are /// queried similarly and their feature flags are OR'd together to form the [`NodeFeatures`] @@ -111,39 +123,89 @@ pub trait CustomMessageHandler: wire::CustomMessageReader { /// A dummy struct which implements `RoutingMessageHandler` without storing any routing information /// or doing any processing. You can provide one of these as the route_handler in a MessageHandler. -pub struct IgnoringMessageHandler{} +pub struct IgnoringMessageHandler {} impl BaseMessageHandler for IgnoringMessageHandler { fn peer_disconnected(&self, _their_node_id: PublicKey) {} - fn peer_connected(&self, _their_node_id: PublicKey, _init: &msgs::Init, _inbound: bool) -> Result<(), ()> { Ok(()) } - fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() } + fn peer_connected( + &self, _their_node_id: PublicKey, _init: &msgs::Init, _inbound: bool, + ) -> Result<(), ()> { + Ok(()) + } + fn provided_node_features(&self) -> NodeFeatures { + NodeFeatures::empty() + } fn provided_init_features(&self, _their_node_id: PublicKey) -> InitFeatures { InitFeatures::empty() } - - fn get_and_clear_pending_msg_events(&self) -> Vec { Vec::new() } + fn get_and_clear_pending_msg_events(&self) -> Vec { + Vec::new() + } } impl RoutingMessageHandler for IgnoringMessageHandler { - fn handle_node_announcement(&self, _their_node_id: Option, _msg: &msgs::NodeAnnouncement) -> Result { Ok(false) } - fn handle_channel_announcement(&self, _their_node_id: Option, _msg: &msgs::ChannelAnnouncement) -> Result { Ok(false) } - fn handle_channel_update(&self, _their_node_id: Option, _msg: &msgs::ChannelUpdate) -> Result { Ok(false) } - fn get_next_channel_announcement(&self, _starting_point: u64) -> - Option<(msgs::ChannelAnnouncement, Option, Option)> { None } - fn get_next_node_announcement(&self, _starting_point: Option<&NodeId>) -> Option { None } - fn handle_reply_channel_range(&self, _their_node_id: PublicKey, _msg: msgs::ReplyChannelRange) -> Result<(), LightningError> { Ok(()) } - fn handle_reply_short_channel_ids_end(&self, _their_node_id: PublicKey, _msg: msgs::ReplyShortChannelIdsEnd) -> Result<(), LightningError> { Ok(()) } - fn handle_query_channel_range(&self, _their_node_id: PublicKey, _msg: msgs::QueryChannelRange) -> Result<(), LightningError> { Ok(()) } - fn handle_query_short_channel_ids(&self, 
_their_node_id: PublicKey, _msg: msgs::QueryShortChannelIds) -> Result<(), LightningError> { Ok(()) } - fn processing_queue_high(&self) -> bool { false } + fn handle_node_announcement( + &self, _their_node_id: Option, _msg: &msgs::NodeAnnouncement, + ) -> Result { + Ok(false) + } + fn handle_channel_announcement( + &self, _their_node_id: Option, _msg: &msgs::ChannelAnnouncement, + ) -> Result { + Ok(false) + } + fn handle_channel_update( + &self, _their_node_id: Option, _msg: &msgs::ChannelUpdate, + ) -> Result { + Ok(false) + } + fn get_next_channel_announcement( + &self, _starting_point: u64, + ) -> Option<(msgs::ChannelAnnouncement, Option, Option)> + { + None + } + fn get_next_node_announcement( + &self, _starting_point: Option<&NodeId>, + ) -> Option { + None + } + fn handle_reply_channel_range( + &self, _their_node_id: PublicKey, _msg: msgs::ReplyChannelRange, + ) -> Result<(), LightningError> { + Ok(()) + } + fn handle_reply_short_channel_ids_end( + &self, _their_node_id: PublicKey, _msg: msgs::ReplyShortChannelIdsEnd, + ) -> Result<(), LightningError> { + Ok(()) + } + fn handle_query_channel_range( + &self, _their_node_id: PublicKey, _msg: msgs::QueryChannelRange, + ) -> Result<(), LightningError> { + Ok(()) + } + fn handle_query_short_channel_ids( + &self, _their_node_id: PublicKey, _msg: msgs::QueryShortChannelIds, + ) -> Result<(), LightningError> { + Ok(()) + } + fn processing_queue_high(&self) -> bool { + false + } } impl OnionMessageHandler for IgnoringMessageHandler { fn handle_onion_message(&self, _their_node_id: PublicKey, _msg: &msgs::OnionMessage) {} - fn next_onion_message_for_peer(&self, _peer_node_id: PublicKey) -> Option { None } + fn next_onion_message_for_peer(&self, _peer_node_id: PublicKey) -> Option { + None + } fn timer_tick_occurred(&self) {} } impl OffersMessageHandler for IgnoringMessageHandler { - fn handle_message(&self, _message: OffersMessage, _context: Option, _responder: Option) -> Option<(OffersMessage, ResponseInstruction)> { + fn handle_message( + &self, _message: OffersMessage, _context: Option, + _responder: Option, + ) -> Option<(OffersMessage, ResponseInstruction)> { None } } @@ -166,11 +228,18 @@ impl DNSResolverMessageHandler for IgnoringMessageHandler { } impl CustomOnionMessageHandler for IgnoringMessageHandler { type CustomMessage = Infallible; - fn handle_custom_message(&self, _message: Infallible, _context: Option>, _responder: Option) -> Option<(Infallible, ResponseInstruction)> { + fn handle_custom_message( + &self, _message: Infallible, _context: Option>, _responder: Option, + ) -> Option<(Infallible, ResponseInstruction)> { // Since we always return `None` in the read the handle method should never be called. 
unreachable!(); } - fn read_custom_message(&self, _msg_type: u64, _buffer: &mut R) -> Result, msgs::DecodeError> where Self: Sized { + fn read_custom_message( + &self, _msg_type: u64, _buffer: &mut R, + ) -> Result, msgs::DecodeError> + where + Self: Sized, + { Ok(None) } fn release_pending_custom_messages(&self) -> Vec<(Infallible, MessageSendInstructions)> { @@ -179,16 +248,24 @@ impl CustomOnionMessageHandler for IgnoringMessageHandler { } impl OnionMessageContents for Infallible { - fn tlv_type(&self) -> u64 { unreachable!(); } + fn tlv_type(&self) -> u64 { + unreachable!(); + } #[cfg(c_bindings)] - fn msg_type(&self) -> String { unreachable!(); } + fn msg_type(&self) -> String { + unreachable!(); + } #[cfg(not(c_bindings))] - fn msg_type(&self) -> &'static str { unreachable!(); } + fn msg_type(&self) -> &'static str { + unreachable!(); + } } impl Deref for IgnoringMessageHandler { type Target = IgnoringMessageHandler; - fn deref(&self) -> &Self { self } + fn deref(&self) -> &Self { + self + } } // Implement Type for Infallible, note that it cannot be constructed, and thus you can never call a @@ -206,24 +283,36 @@ impl Writeable for Infallible { impl wire::CustomMessageReader for IgnoringMessageHandler { type CustomMessage = Infallible; - fn read(&self, _message_type: u16, _buffer: &mut R) -> Result, msgs::DecodeError> { + fn read( + &self, _message_type: u16, _buffer: &mut R, + ) -> Result, msgs::DecodeError> { Ok(None) } } impl CustomMessageHandler for IgnoringMessageHandler { - fn handle_custom_message(&self, _msg: Infallible, _sender_node_id: PublicKey) -> Result<(), LightningError> { + fn handle_custom_message( + &self, _msg: Infallible, _sender_node_id: PublicKey, + ) -> Result<(), LightningError> { // Since we always return `None` in the read the handle method should never be called. unreachable!(); } - fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { Vec::new() } + fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { + Vec::new() + } fn peer_disconnected(&self, _their_node_id: PublicKey) {} - fn peer_connected(&self, _their_node_id: PublicKey, _msg: &Init, _inbound: bool) -> Result<(), ()> { Ok(()) } + fn peer_connected( + &self, _their_node_id: PublicKey, _msg: &Init, _inbound: bool, + ) -> Result<(), ()> { + Ok(()) + } - fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() } + fn provided_node_features(&self) -> NodeFeatures { + NodeFeatures::empty() + } fn provided_init_features(&self, _their_node_id: PublicKey) -> InitFeatures { InitFeatures::empty() @@ -233,7 +322,7 @@ impl CustomMessageHandler for IgnoringMessageHandler { /// A dummy struct which implements `ChannelMessageHandler` without having any channels. /// You can provide one of these as the route_handler in a MessageHandler. 
pub struct ErroringMessageHandler { - message_queue: Mutex> + message_queue: Mutex>, } impl ErroringMessageHandler { /// Constructs a new ErroringMessageHandler @@ -243,7 +332,10 @@ impl ErroringMessageHandler { fn push_error(&self, node_id: PublicKey, channel_id: ChannelId) { self.message_queue.lock().unwrap().push(MessageSendEvent::HandleError { action: msgs::ErrorAction::SendErrorMessage { - msg: msgs::ErrorMessage { channel_id, data: "We do not support channel messages, sorry.".to_owned() }, + msg: msgs::ErrorMessage { + channel_id, + data: "We do not support channel messages, sorry.".to_owned(), + }, }, node_id, }); @@ -251,8 +343,14 @@ impl ErroringMessageHandler { } impl BaseMessageHandler for ErroringMessageHandler { fn peer_disconnected(&self, _their_node_id: PublicKey) {} - fn peer_connected(&self, _their_node_id: PublicKey, _init: &msgs::Init, _inbound: bool) -> Result<(), ()> { Ok(()) } - fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() } + fn peer_connected( + &self, _their_node_id: PublicKey, _init: &msgs::Init, _inbound: bool, + ) -> Result<(), ()> { + Ok(()) + } + fn provided_node_features(&self) -> NodeFeatures { + NodeFeatures::empty() + } fn provided_init_features(&self, _their_node_id: PublicKey) -> InitFeatures { // Set a number of features which various nodes may require to talk to us. It's totally // reasonable to indicate we "support" all kinds of channel features...we just reject all @@ -284,10 +382,18 @@ impl ChannelMessageHandler for ErroringMessageHandler { // Any messages which are related to a specific channel generate an error message to let the // peer know we don't care about channels. fn handle_open_channel(&self, their_node_id: PublicKey, msg: &msgs::OpenChannel) { - ErroringMessageHandler::push_error(self, their_node_id, msg.common_fields.temporary_channel_id); + ErroringMessageHandler::push_error( + self, + their_node_id, + msg.common_fields.temporary_channel_id, + ); } fn handle_accept_channel(&self, their_node_id: PublicKey, msg: &msgs::AcceptChannel) { - ErroringMessageHandler::push_error(self, their_node_id, msg.common_fields.temporary_channel_id); + ErroringMessageHandler::push_error( + self, + their_node_id, + msg.common_fields.temporary_channel_id, + ); } fn handle_funding_created(&self, their_node_id: PublicKey, msg: &msgs::FundingCreated) { ErroringMessageHandler::push_error(self, their_node_id, msg.temporary_channel_id); @@ -328,7 +434,9 @@ impl ChannelMessageHandler for ErroringMessageHandler { fn handle_update_fail_htlc(&self, their_node_id: PublicKey, msg: &msgs::UpdateFailHTLC) { ErroringMessageHandler::push_error(self, their_node_id, msg.channel_id); } - fn handle_update_fail_malformed_htlc(&self, their_node_id: PublicKey, msg: &msgs::UpdateFailMalformedHTLC) { + fn handle_update_fail_malformed_htlc( + &self, their_node_id: PublicKey, msg: &msgs::UpdateFailMalformedHTLC, + ) { ErroringMessageHandler::push_error(self, their_node_id, msg.channel_id); } fn handle_commitment_signed(&self, their_node_id: PublicKey, msg: &msgs::CommitmentSigned) { @@ -345,7 +453,9 @@ impl ChannelMessageHandler for ErroringMessageHandler { fn handle_update_fee(&self, their_node_id: PublicKey, msg: &msgs::UpdateFee) { ErroringMessageHandler::push_error(self, their_node_id, msg.channel_id); } - fn handle_announcement_signatures(&self, their_node_id: PublicKey, msg: &msgs::AnnouncementSignatures) { + fn handle_announcement_signatures( + &self, their_node_id: PublicKey, msg: &msgs::AnnouncementSignatures, + ) { 
ErroringMessageHandler::push_error(self, their_node_id, msg.channel_id); } fn handle_channel_reestablish(&self, their_node_id: PublicKey, msg: &msgs::ChannelReestablish) { @@ -355,7 +465,10 @@ impl ChannelMessageHandler for ErroringMessageHandler { fn handle_channel_update(&self, _their_node_id: PublicKey, _msg: &msgs::ChannelUpdate) {} fn handle_peer_storage(&self, _their_node_id: PublicKey, _msg: msgs::PeerStorage) {} - fn handle_peer_storage_retrieval(&self, _their_node_id: PublicKey, _msg: msgs::PeerStorageRetrieval) {} + fn handle_peer_storage_retrieval( + &self, _their_node_id: PublicKey, _msg: msgs::PeerStorageRetrieval, + ) { + } fn handle_error(&self, _their_node_id: PublicKey, _msg: &msgs::ErrorMessage) {} @@ -367,11 +480,19 @@ impl ChannelMessageHandler for ErroringMessageHandler { } fn handle_open_channel_v2(&self, their_node_id: PublicKey, msg: &msgs::OpenChannelV2) { - ErroringMessageHandler::push_error(self, their_node_id, msg.common_fields.temporary_channel_id); + ErroringMessageHandler::push_error( + self, + their_node_id, + msg.common_fields.temporary_channel_id, + ); } fn handle_accept_channel_v2(&self, their_node_id: PublicKey, msg: &msgs::AcceptChannelV2) { - ErroringMessageHandler::push_error(self, their_node_id, msg.common_fields.temporary_channel_id); + ErroringMessageHandler::push_error( + self, + their_node_id, + msg.common_fields.temporary_channel_id, + ); } fn handle_tx_add_input(&self, their_node_id: PublicKey, msg: &msgs::TxAddInput) { @@ -415,11 +536,14 @@ impl ChannelMessageHandler for ErroringMessageHandler { impl Deref for ErroringMessageHandler { type Target = ErroringMessageHandler; - fn deref(&self) -> &Self { self } + fn deref(&self) -> &Self { + self + } } /// Provides references to trait impls which handle different types of messages. -pub struct MessageHandler where +pub struct MessageHandler +where CM::Target: ChannelMessageHandler, RM::Target: RoutingMessageHandler, OM::Target: OnionMessageHandler, @@ -449,7 +573,7 @@ pub struct MessageHandler Result<(), fmt::Error> { formatter.write_str("Peer Sent Invalid Data") @@ -529,7 +653,7 @@ impl fmt::Display for PeerHandleError { } /// Internal struct for keeping track of the gossip syncing progress with a given peer -enum InitSyncTracker{ +enum InitSyncTracker { /// Only sync ad-hoc gossip as it comes in, do not send historical gossip. /// Upon receipt of a GossipTimestampFilter message, this is the default initial state if the /// contained timestamp is less than 6 hours old. @@ -572,7 +696,8 @@ const FORWARD_INIT_SYNC_BUFFER_LIMIT_RATIO: usize = 2; const OUTBOUND_BUFFER_LIMIT_READ_PAUSE: usize = 12; /// When the outbound buffer has this many messages, we'll simply skip relaying gossip messages to /// the peer. -const OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP: usize = OUTBOUND_BUFFER_LIMIT_READ_PAUSE * FORWARD_INIT_SYNC_BUFFER_LIMIT_RATIO; +const OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP: usize = + OUTBOUND_BUFFER_LIMIT_READ_PAUSE * FORWARD_INIT_SYNC_BUFFER_LIMIT_RATIO; /// If we've sent a ping, and are still awaiting a response, we may need to churn our way through /// the socket receive buffer before receiving the ping. @@ -660,11 +785,14 @@ impl Peer { /// point and we shouldn't send it yet to avoid sending duplicate updates. If we've already /// sent the old versions, we should send the update, and so return true here. 
fn should_forward_channel_announcement(&self, channel_id: u64) -> bool { - if !self.handshake_complete() { return false; } - if self.their_features.as_ref().unwrap().supports_gossip_queries() && - !self.sent_gossip_timestamp_filter { - return false; - } + if !self.handshake_complete() { + return false; + } + if self.their_features.as_ref().unwrap().supports_gossip_queries() + && !self.sent_gossip_timestamp_filter + { + return false; + } match self.sync_status { InitSyncTracker::NoSyncRequested => true, InitSyncTracker::ChannelsSyncing(i) => i < channel_id, @@ -674,15 +802,20 @@ impl Peer { /// Similar to the above, but for node announcements indexed by node_id. fn should_forward_node_announcement(&self, node_id: NodeId) -> bool { - if !self.handshake_complete() { return false; } - if self.their_features.as_ref().unwrap().supports_gossip_queries() && - !self.sent_gossip_timestamp_filter { - return false; - } + if !self.handshake_complete() { + return false; + } + if self.their_features.as_ref().unwrap().supports_gossip_queries() + && !self.sent_gossip_timestamp_filter + { + return false; + } match self.sync_status { InitSyncTracker::NoSyncRequested => true, InitSyncTracker::ChannelsSyncing(_) => false, - InitSyncTracker::NodesSyncing(sync_node_id) => sync_node_id.as_slice() < node_id.as_slice(), + InitSyncTracker::NodesSyncing(sync_node_id) => { + sync_node_id.as_slice() < node_id.as_slice() + }, } } @@ -692,14 +825,15 @@ impl Peer { if !gossip_processing_backlogged { self.received_channel_announce_since_backlogged = false; } - self.pending_outbound_buffer.len() < OUTBOUND_BUFFER_LIMIT_READ_PAUSE && - (!gossip_processing_backlogged || !self.received_channel_announce_since_backlogged) + self.pending_outbound_buffer.len() < OUTBOUND_BUFFER_LIMIT_READ_PAUSE + && (!gossip_processing_backlogged || !self.received_channel_announce_since_backlogged) } /// Determines if we should push additional gossip background sync (aka "backfill") onto a peer's /// outbound buffer. This is checked every time the peer's buffer may have been drained. fn should_buffer_gossip_backfill(&self) -> bool { - self.pending_outbound_buffer.is_empty() && self.gossip_broadcast_buffer.is_empty() + self.pending_outbound_buffer.is_empty() + && self.gossip_broadcast_buffer.is_empty() && self.msgs_sent_since_pong < BUFFER_DRAIN_MSGS_PER_TICK && self.handshake_complete() } @@ -707,14 +841,16 @@ impl Peer { /// Determines if we should push an onion message onto a peer's outbound buffer. This is checked /// every time the peer's buffer may have been drained. fn should_buffer_onion_message(&self) -> bool { - self.pending_outbound_buffer.is_empty() && self.handshake_complete() + self.pending_outbound_buffer.is_empty() + && self.handshake_complete() && self.msgs_sent_since_pong < BUFFER_DRAIN_MSGS_PER_TICK } /// Determines if we should push additional gossip broadcast messages onto a peer's outbound /// buffer. This is checked every time the peer's buffer may have been drained. 
fn should_buffer_gossip_broadcast(&self) -> bool { - self.pending_outbound_buffer.is_empty() && self.handshake_complete() + self.pending_outbound_buffer.is_empty() + && self.handshake_complete() && self.msgs_sent_since_pong < BUFFER_DRAIN_MSGS_PER_TICK } @@ -723,8 +859,9 @@ impl Peer { let total_outbound_buffered = self.gossip_broadcast_buffer.len() + self.pending_outbound_buffer.len(); - total_outbound_buffered > OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP || - self.msgs_sent_since_pong > BUFFER_DRAIN_MSGS_PER_TICK * FORWARD_INIT_SYNC_BUFFER_LIMIT_RATIO + total_outbound_buffered > OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP + || self.msgs_sent_since_pong + > BUFFER_DRAIN_MSGS_PER_TICK * FORWARD_INIT_SYNC_BUFFER_LIMIT_RATIO } fn set_their_node_id(&mut self, node_id: PublicKey) { @@ -760,6 +897,7 @@ pub type SimpleArcPeerManager = PeerManager< /// /// This is not exported to bindings users as type aliases aren't supported in most languages. #[cfg(not(c_bindings))] +#[rustfmt::skip] pub type SimpleRefPeerManager< 'a, 'b, 'c, 'd, 'e, 'f, 'logger, 'h, 'i, 'j, 'graph, 'k, 'mr, SD, M, T, F, C, L > = PeerManager< @@ -773,7 +911,6 @@ pub type SimpleRefPeerManager< &'j ChainMonitor<&'a M, C, &'b T, &'c F, &'logger L, &'c KeysManager, &'c KeysManager>, >; - /// A generic trait which is implemented for all [`PeerManager`]s. This makes bounding functions or /// structs on any [`PeerManager`] much simpler as only this trait is needed as a bound, rather /// than the full set of bounds on [`PeerManager`] itself. @@ -784,25 +921,45 @@ pub type SimpleRefPeerManager< pub trait APeerManager { type Descriptor: SocketDescriptor; type CMT: ChannelMessageHandler + ?Sized; - type CM: Deref; + type CM: Deref; type RMT: RoutingMessageHandler + ?Sized; - type RM: Deref; + type RM: Deref; type OMT: OnionMessageHandler + ?Sized; - type OM: Deref; + type OM: Deref; type LT: Logger + ?Sized; - type L: Deref; + type L: Deref; type CMHT: CustomMessageHandler + ?Sized; - type CMH: Deref; + type CMH: Deref; type NST: NodeSigner + ?Sized; - type NS: Deref; + type NS: Deref; type SMT: BaseMessageHandler + ?Sized; - type SM: Deref; + type SM: Deref; /// Gets a reference to the underlying [`PeerManager`]. - fn as_ref(&self) -> &PeerManager; + fn as_ref( + &self, + ) -> &PeerManager< + Self::Descriptor, + Self::CM, + Self::RM, + Self::OM, + Self::L, + Self::CMH, + Self::NS, + Self::SM, + >; } -impl -APeerManager for PeerManager where +impl< + Descriptor: SocketDescriptor, + CM: Deref, + RM: Deref, + OM: Deref, + L: Deref, + CMH: Deref, + NS: Deref, + SM: Deref, + > APeerManager for PeerManager +where CM::Target: ChannelMessageHandler, RM::Target: RoutingMessageHandler, OM::Target: OnionMessageHandler, @@ -826,7 +983,9 @@ APeerManager for PeerManager where type NS = NS; type SMT = ::Target; type SM = SM; - fn as_ref(&self) -> &PeerManager { self } + fn as_ref(&self) -> &PeerManager { + self + } } /// A PeerManager manages a set of peers, described by their [`SocketDescriptor`] and marshalls @@ -848,14 +1007,23 @@ APeerManager for PeerManager where /// you're using lightning-net-tokio. 
/// /// [`read_event`]: PeerManager::read_event -pub struct PeerManager where - CM::Target: ChannelMessageHandler, - RM::Target: RoutingMessageHandler, - OM::Target: OnionMessageHandler, - L::Target: Logger, - CMH::Target: CustomMessageHandler, - NS::Target: NodeSigner, - SM::Target: BaseMessageHandler, +pub struct PeerManager< + Descriptor: SocketDescriptor, + CM: Deref, + RM: Deref, + OM: Deref, + L: Deref, + CMH: Deref, + NS: Deref, + SM: Deref, +> where + CM::Target: ChannelMessageHandler, + RM::Target: RoutingMessageHandler, + OM::Target: OnionMessageHandler, + L::Target: Logger, + CMH::Target: CustomMessageHandler, + NS::Target: NodeSigner, + SM::Target: BaseMessageHandler, { message_handler: MessageHandler, /// Connection state for each connected peer - we have an outer read-write lock which is taken @@ -899,7 +1067,7 @@ pub struct PeerManager + secp_ctx: Secp256k1, } enum LogicalMessage { @@ -929,15 +1097,17 @@ macro_rules! encode_msg { let mut buffer = VecWriter(Vec::with_capacity(MSG_BUF_ALLOC_SIZE)); wire::write($msg, &mut buffer).unwrap(); buffer.0 - }} + }}; } -impl PeerManager where - CM::Target: ChannelMessageHandler, - OM::Target: OnionMessageHandler, - L::Target: Logger, - NS::Target: NodeSigner, - SM::Target: BaseMessageHandler, +impl + PeerManager +where + CM::Target: ChannelMessageHandler, + OM::Target: OnionMessageHandler, + L::Target: Logger, + NS::Target: NodeSigner, + SM::Target: BaseMessageHandler, { /// Constructs a new `PeerManager` with the given `ChannelMessageHandler` and /// `OnionMessageHandler`. No routing message handler is used and network graph messages are @@ -952,21 +1122,42 @@ impl Self { - Self::new(MessageHandler { - chan_handler: channel_message_handler, - route_handler: IgnoringMessageHandler{}, - onion_message_handler, - custom_message_handler: IgnoringMessageHandler{}, - send_only_message_handler, - }, current_time, ephemeral_random_data, logger, node_signer) + pub fn new_channel_only( + channel_message_handler: CM, onion_message_handler: OM, current_time: u32, + ephemeral_random_data: &[u8; 32], logger: L, node_signer: NS, + send_only_message_handler: SM, + ) -> Self { + Self::new( + MessageHandler { + chan_handler: channel_message_handler, + route_handler: IgnoringMessageHandler {}, + onion_message_handler, + custom_message_handler: IgnoringMessageHandler {}, + send_only_message_handler, + }, + current_time, + ephemeral_random_data, + logger, + node_signer, + ) } } -impl PeerManager where - RM::Target: RoutingMessageHandler, - L::Target: Logger, - NS::Target: NodeSigner { +impl + PeerManager< + Descriptor, + ErroringMessageHandler, + RM, + IgnoringMessageHandler, + L, + IgnoringMessageHandler, + NS, + IgnoringMessageHandler, + > where + RM::Target: RoutingMessageHandler, + L::Target: Logger, + NS::Target: NodeSigner, +{ /// Constructs a new `PeerManager` with the given `RoutingMessageHandler`. No channel message /// handler or onion message handler is used and onion and channel messages will be ignored (or /// generate error messages). 
Note that some other lightning implementations time-out connections @@ -981,14 +1172,23 @@ impl PeerManager Self { - Self::new(MessageHandler { - chan_handler: ErroringMessageHandler::new(), - route_handler: routing_message_handler, - onion_message_handler: IgnoringMessageHandler{}, - custom_message_handler: IgnoringMessageHandler{}, - send_only_message_handler: IgnoringMessageHandler{}, - }, current_time, ephemeral_random_data, logger, node_signer) + pub fn new_routing_only( + routing_message_handler: RM, current_time: u32, ephemeral_random_data: &[u8; 32], + logger: L, node_signer: NS, + ) -> Self { + Self::new( + MessageHandler { + chan_handler: ErroringMessageHandler::new(), + route_handler: routing_message_handler, + onion_message_handler: IgnoringMessageHandler {}, + custom_message_handler: IgnoringMessageHandler {}, + send_only_message_handler: IgnoringMessageHandler {}, + }, + current_time, + ephemeral_random_data, + logger, + node_signer, + ) } } @@ -999,7 +1199,11 @@ impl PeerManager(&'a Option<(PublicKey, NodeId)>); impl core::fmt::Display for OptionalFromDebugger<'_> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> { - if let Some((node_id, _)) = self.0 { write!(f, " from {}", log_pubkey!(node_id)) } else { Ok(()) } + if let Some((node_id, _)) = self.0 { + write!(f, " from {}", log_pubkey!(node_id)) + } else { + Ok(()) + } } } @@ -1007,40 +1211,53 @@ impl core::fmt::Display for OptionalFromDebugger<'_> { /// /// fn filter_addresses(ip_address: Option) -> Option { - match ip_address{ + match ip_address { // For IPv4 range 10.0.0.0 - 10.255.255.255 (10/8) - Some(SocketAddress::TcpIpV4{addr: [10, _, _, _], port: _}) => None, + Some(SocketAddress::TcpIpV4 { addr: [10, _, _, _], port: _ }) => None, // For IPv4 range 0.0.0.0 - 0.255.255.255 (0/8) - Some(SocketAddress::TcpIpV4{addr: [0, _, _, _], port: _}) => None, + Some(SocketAddress::TcpIpV4 { addr: [0, _, _, _], port: _ }) => None, // For IPv4 range 100.64.0.0 - 100.127.255.255 (100.64/10) - Some(SocketAddress::TcpIpV4{addr: [100, 64..=127, _, _], port: _}) => None, + Some(SocketAddress::TcpIpV4 { addr: [100, 64..=127, _, _], port: _ }) => None, // For IPv4 range 127.0.0.0 - 127.255.255.255 (127/8) - Some(SocketAddress::TcpIpV4{addr: [127, _, _, _], port: _}) => None, + Some(SocketAddress::TcpIpV4 { addr: [127, _, _, _], port: _ }) => None, // For IPv4 range 169.254.0.0 - 169.254.255.255 (169.254/16) - Some(SocketAddress::TcpIpV4{addr: [169, 254, _, _], port: _}) => None, + Some(SocketAddress::TcpIpV4 { addr: [169, 254, _, _], port: _ }) => None, // For IPv4 range 172.16.0.0 - 172.31.255.255 (172.16/12) - Some(SocketAddress::TcpIpV4{addr: [172, 16..=31, _, _], port: _}) => None, + Some(SocketAddress::TcpIpV4 { addr: [172, 16..=31, _, _], port: _ }) => None, // For IPv4 range 192.168.0.0 - 192.168.255.255 (192.168/16) - Some(SocketAddress::TcpIpV4{addr: [192, 168, _, _], port: _}) => None, + Some(SocketAddress::TcpIpV4 { addr: [192, 168, _, _], port: _ }) => None, // For IPv4 range 192.88.99.0 - 192.88.99.255 (192.88.99/24) - Some(SocketAddress::TcpIpV4{addr: [192, 88, 99, _], port: _}) => None, + Some(SocketAddress::TcpIpV4 { addr: [192, 88, 99, _], port: _ }) => None, // For IPv6 range 2000:0000:0000:0000:0000:0000:0000:0000 - 3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff (2000::/3) - Some(SocketAddress::TcpIpV6{addr: [0x20..=0x3F, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], port: _}) => ip_address, + Some(SocketAddress::TcpIpV6 { + addr: [0x20..=0x3F, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], + 
port: _, + }) => ip_address, // For remaining addresses - Some(SocketAddress::TcpIpV6{addr: _, port: _}) => None, + Some(SocketAddress::TcpIpV6 { addr: _, port: _ }) => None, Some(..) => ip_address, None => None, } } -impl PeerManager where - CM::Target: ChannelMessageHandler, - RM::Target: RoutingMessageHandler, - OM::Target: OnionMessageHandler, - L::Target: Logger, - CMH::Target: CustomMessageHandler, - NS::Target: NodeSigner, - SM::Target: BaseMessageHandler, +impl< + Descriptor: SocketDescriptor, + CM: Deref, + RM: Deref, + OM: Deref, + L: Deref, + CMH: Deref, + NS: Deref, + SM: Deref, + > PeerManager +where + CM::Target: ChannelMessageHandler, + RM::Target: RoutingMessageHandler, + OM::Target: OnionMessageHandler, + L::Target: Logger, + CMH::Target: CustomMessageHandler, + NS::Target: NodeSigner, + SM::Target: BaseMessageHandler, { /// Constructs a new `PeerManager` with the given message handlers. /// @@ -1051,7 +1268,10 @@ impl, current_time: u32, ephemeral_random_data: &[u8; 32], logger: L, node_signer: NS) -> Self { + pub fn new( + message_handler: MessageHandler, current_time: u32, + ephemeral_random_data: &[u8; 32], logger: L, node_signer: NS, + ) -> Self { let mut ephemeral_key_midstate = Sha256::engine(); ephemeral_key_midstate.input(ephemeral_random_data); @@ -1079,7 +1299,7 @@ impl Vec { let peers = self.peers.read().unwrap(); - peers.values().filter_map(|peer_mutex| { + let filter_fn = |peer_mutex: &Mutex| { let p = peer_mutex.lock().unwrap(); if !p.handshake_complete() { return None; @@ -1095,7 +1315,8 @@ impl InitFeatures { @@ -1158,8 +1380,12 @@ impl) -> Result, PeerHandleError> { - let mut peer_encryptor = PeerChannelEncryptor::new_outbound(their_node_id.clone(), self.get_ephemeral_key()); + pub fn new_outbound_connection( + &self, their_node_id: PublicKey, descriptor: Descriptor, + remote_network_address: Option, + ) -> Result, PeerHandleError> { + let mut peer_encryptor = + PeerChannelEncryptor::new_outbound(their_node_id.clone(), self.get_ephemeral_key()); let res = peer_encryptor.get_act_one(&self.secp_ctx).to_vec(); let pending_read_buffer = [0; 50].to_vec(); // Noise act two is 50 bytes @@ -1198,7 +1424,7 @@ impl) -> Result<(), PeerHandleError> { + pub fn new_inbound_connection( + &self, descriptor: Descriptor, remote_network_address: Option, + ) -> Result<(), PeerHandleError> { let peer_encryptor = PeerChannelEncryptor::new_inbound(&self.node_signer); let pending_read_buffer = [0; 50].to_vec(); // Noise act one is 50 bytes @@ -1256,7 +1484,7 @@ impl { - if let Some(msg) = self.message_handler.route_handler.get_next_node_announcement(None) { + let handler = &self.message_handler.route_handler; + if let Some(msg) = handler.get_next_node_announcement(None) { self.enqueue_message(peer, &msg); peer.sync_status = InitSyncTracker::NodesSyncing(msg.contents.node_id); } else { @@ -1317,7 +1554,8 @@ impl unreachable!(), InitSyncTracker::NodesSyncing(sync_node_id) => { - if let Some(msg) = self.message_handler.route_handler.get_next_node_announcement(Some(&sync_node_id)) { + let handler = &self.message_handler.route_handler; + if let Some(msg) = handler.get_next_node_announcement(Some(&sync_node_id)) { self.enqueue_message(peer, &msg); peer.sync_status = InitSyncTracker::NodesSyncing(msg.contents.node_id); } else { @@ -1339,7 +1577,7 @@ impl buff, }; @@ -1376,20 +1614,22 @@ impl Result<(), PeerHandleError> { + pub fn write_buffer_space_avail( + &self, descriptor: &mut Descriptor, + ) -> Result<(), PeerHandleError> { let peers = self.peers.read().unwrap(); match 
peers.get(descriptor) { None => { // This is most likely a simple race condition where the user found that the socket // was writeable, then we told the user to `disconnect_socket()`, then they called // this method. Return an error to make sure we get disconnected. - return Err(PeerHandleError { }); + return Err(PeerHandleError {}); }, Some(peer_mutex) => { let mut peer = peer_mutex.lock().unwrap(); peer.awaiting_write_event = false; self.do_attempt_write_data(descriptor, &mut peer, false); - } + }, }; Ok(()) } @@ -1411,23 +1651,28 @@ impl Result { + pub fn read_event( + &self, peer_descriptor: &mut Descriptor, data: &[u8], + ) -> Result { match self.do_read_event(peer_descriptor, data) { Ok(res) => Ok(res), Err(e) => { self.disconnect_event_internal(peer_descriptor, "of a protocol error"); Err(e) - } + }, } } /// Append a message to a peer's pending outbound/write buffer fn enqueue_message(&self, peer: &mut Peer, message: &M) { - let logger = WithContext::from(&self.logger, peer.their_node_id.map(|p| p.0), None, None); + let their_node_id = peer.their_node_id.map(|p| p.0); + let logger = WithContext::from(&self.logger, their_node_id, None, None); + // `unwrap` SAFETY: `their_node_id` is guaranteed to be `Some` after the handshake + let node_id = peer.their_node_id.unwrap().0; if is_gossip_msg(message.type_id()) { - log_gossip!(logger, "Enqueueing message {:?} to {}", message, log_pubkey!(peer.their_node_id.unwrap().0)); + log_gossip!(logger, "Enqueueing message {:?} to {}", message, log_pubkey!(node_id)); } else { - log_trace!(logger, "Enqueueing message {:?} to {}", message, log_pubkey!(peer.their_node_id.unwrap().0)) + log_trace!(logger, "Enqueueing message {:?} to {}", message, log_pubkey!(node_id)); } peer.msgs_sent_since_pong += 1; peer.pending_outbound_buffer.push_back(peer.channel_encryptor.encrypt_message(message)); @@ -1439,249 +1684,310 @@ impl Result { + fn do_read_event( + &self, peer_descriptor: &mut Descriptor, data: &[u8], + ) -> Result { let mut pause_read = false; let peers = self.peers.read().unwrap(); let mut msgs_to_forward = Vec::new(); let mut peer_node_id = None; - match peers.get(peer_descriptor) { - None => { - // This is most likely a simple race condition where the user read some bytes - // from the socket, then we told the user to `disconnect_socket()`, then they - // called this method. Return an error to make sure we get disconnected. - return Err(PeerHandleError { }); - }, - Some(peer_mutex) => { - let mut read_pos = 0; - while read_pos < data.len() { - macro_rules! try_potential_handleerror { - ($peer: expr, $thing: expr) => {{ - let res = $thing; - let logger = WithContext::from(&self.logger, peer_node_id.map(|(id, _)| id), None, None); - match res { - Ok(x) => x, - Err(e) => { - match e.action { - msgs::ErrorAction::DisconnectPeer { .. } => { - // We may have an `ErrorMessage` to send to the peer, - // but writing to the socket while reading can lead to - // re-entrant code and possibly unexpected behavior. The - // message send is optimistic anyway, and in this case - // we immediately disconnect the peer. - log_debug!(logger, "Error handling message{}; disconnecting peer with: {}", OptionalFromDebugger(&peer_node_id), e.err); - return Err(PeerHandleError { }); - }, - msgs::ErrorAction::DisconnectPeerWithWarning { .. } => { - // We have a `WarningMessage` to send to the peer, but - // writing to the socket while reading can lead to - // re-entrant code and possibly unexpected behavior. 
The - // message send is optimistic anyway, and in this case - // we immediately disconnect the peer. - log_debug!(logger, "Error handling message{}; disconnecting peer with: {}", OptionalFromDebugger(&peer_node_id), e.err); - return Err(PeerHandleError { }); - }, - msgs::ErrorAction::IgnoreAndLog(level) => { - log_given_level!(logger, level, "Error handling {}message{}; ignoring: {}", - if level == Level::Gossip { "gossip " } else { "" }, - OptionalFromDebugger(&peer_node_id), e.err); - continue - }, - msgs::ErrorAction::IgnoreDuplicateGossip => continue, // Don't even bother logging these - msgs::ErrorAction::IgnoreError => { - log_debug!(logger, "Error handling message{}; ignoring: {}", OptionalFromDebugger(&peer_node_id), e.err); - continue; - }, - msgs::ErrorAction::SendErrorMessage { msg } => { - log_debug!(logger, "Error handling message{}; sending error message with: {}", OptionalFromDebugger(&peer_node_id), e.err); - self.enqueue_message($peer, &msg); - continue; - }, - msgs::ErrorAction::SendWarningMessage { msg, log_level } => { - log_given_level!(logger, log_level, "Error handling message{}; sending warning message with: {}", OptionalFromDebugger(&peer_node_id), e.err); - self.enqueue_message($peer, &msg); - continue; - }, - } + + if let Some(peer_mutex) = peers.get(peer_descriptor) { + let mut read_pos = 0; + while read_pos < data.len() { + macro_rules! try_potential_handleerror { + ($peer: expr, $thing: expr) => {{ + let res = $thing; + let logger = WithContext::from(&self.logger, peer_node_id.map(|(id, _)| id), None, None); + match res { + Ok(x) => x, + Err(e) => { + match e.action { + msgs::ErrorAction::DisconnectPeer { .. } => { + // We may have an `ErrorMessage` to send to the peer, + // but writing to the socket while reading can lead to + // re-entrant code and possibly unexpected behavior. The + // message send is optimistic anyway, and in this case + // we immediately disconnect the peer. + log_debug!(logger, "Error handling message{}; disconnecting peer with: {}", OptionalFromDebugger(&peer_node_id), e.err); + return Err(PeerHandleError { }); + }, + msgs::ErrorAction::DisconnectPeerWithWarning { .. } => { + // We have a `WarningMessage` to send to the peer, but + // writing to the socket while reading can lead to + // re-entrant code and possibly unexpected behavior. The + // message send is optimistic anyway, and in this case + // we immediately disconnect the peer. 
+ log_debug!(logger, "Error handling message{}; disconnecting peer with: {}", OptionalFromDebugger(&peer_node_id), e.err); + return Err(PeerHandleError { }); + }, + msgs::ErrorAction::IgnoreAndLog(level) => { + log_given_level!(logger, level, "Error handling {}message{}; ignoring: {}", + if level == Level::Gossip { "gossip " } else { "" }, + OptionalFromDebugger(&peer_node_id), e.err); + continue + }, + msgs::ErrorAction::IgnoreDuplicateGossip => continue, // Don't even bother logging these + msgs::ErrorAction::IgnoreError => { + log_debug!(logger, "Error handling message{}; ignoring: {}", OptionalFromDebugger(&peer_node_id), e.err); + continue; + }, + msgs::ErrorAction::SendErrorMessage { msg } => { + log_debug!(logger, "Error handling message{}; sending error message with: {}", OptionalFromDebugger(&peer_node_id), e.err); + self.enqueue_message($peer, &msg); + continue; + }, + msgs::ErrorAction::SendWarningMessage { msg, log_level } => { + log_given_level!(logger, log_level, "Error handling message{}; sending warning message with: {}", OptionalFromDebugger(&peer_node_id), e.err); + self.enqueue_message($peer, &msg); + continue; + }, } } - }} - } + } + }} + } - let mut peer_lock = peer_mutex.lock().unwrap(); - let peer = &mut *peer_lock; - let mut msg_to_handle = None; - if peer_node_id.is_none() { - peer_node_id.clone_from(&peer.their_node_id); - } + let mut peer_lock = peer_mutex.lock().unwrap(); + let peer = &mut *peer_lock; + let mut msg_to_handle = None; + if peer_node_id.is_none() { + peer_node_id.clone_from(&peer.their_node_id); + } - assert!(peer.pending_read_buffer.len() > 0); - assert!(peer.pending_read_buffer.len() > peer.pending_read_buffer_pos); + assert!(peer.pending_read_buffer.len() > 0); + assert!(peer.pending_read_buffer.len() > peer.pending_read_buffer_pos); - { - let data_to_copy = cmp::min(peer.pending_read_buffer.len() - peer.pending_read_buffer_pos, data.len() - read_pos); - peer.pending_read_buffer[peer.pending_read_buffer_pos..peer.pending_read_buffer_pos + data_to_copy].copy_from_slice(&data[read_pos..read_pos + data_to_copy]); - read_pos += data_to_copy; - peer.pending_read_buffer_pos += data_to_copy; - } + { + let data_to_copy = cmp::min( + peer.pending_read_buffer.len() - peer.pending_read_buffer_pos, + data.len() - read_pos, + ); + peer.pending_read_buffer + [peer.pending_read_buffer_pos..peer.pending_read_buffer_pos + data_to_copy] + .copy_from_slice(&data[read_pos..read_pos + data_to_copy]); + read_pos += data_to_copy; + peer.pending_read_buffer_pos += data_to_copy; + } - if peer.pending_read_buffer_pos == peer.pending_read_buffer.len() { - peer.pending_read_buffer_pos = 0; - - macro_rules! insert_node_id { - () => { - let logger = WithContext::from(&self.logger, peer.their_node_id.map(|p| p.0), None, None); - match self.node_id_to_descriptor.lock().unwrap().entry(peer.their_node_id.unwrap().0) { - hash_map::Entry::Occupied(e) => { - log_trace!(logger, "Got second connection with {}, closing", log_pubkey!(peer.their_node_id.unwrap().0)); - peer.their_node_id = None; // Unset so that we don't generate a peer_disconnected event - // Check that the peers map is consistent with the - // node_id_to_descriptor map, as this has been broken - // before. 
- debug_assert!(peers.get(e.get()).is_some()); - return Err(PeerHandleError { }) - }, - hash_map::Entry::Vacant(entry) => { - log_debug!(logger, "Finished noise handshake for connection with {}", log_pubkey!(peer.their_node_id.unwrap().0)); - entry.insert(peer_descriptor.clone()) - }, - }; - } + if peer.pending_read_buffer_pos == peer.pending_read_buffer.len() { + peer.pending_read_buffer_pos = 0; + + macro_rules! insert_node_id { + () => { + let their_node_id = peer.their_node_id.map(|p| p.0); + let logger = WithContext::from(&self.logger, their_node_id, None, None); + match self.node_id_to_descriptor.lock().unwrap().entry(peer.their_node_id.unwrap().0) { + hash_map::Entry::Occupied(e) => { + log_trace!(logger, "Got second connection with {}, closing", log_pubkey!(peer.their_node_id.unwrap().0)); + // Unset `their_node_id` so that we don't generate a peer_disconnected event + // Check that the peers map is consistent with the + // node_id_to_descriptor map, as this has been broken + // before. + peer.their_node_id = None; + debug_assert!(peers.get(e.get()).is_some()); + return Err(PeerHandleError { }) + }, + hash_map::Entry::Vacant(entry) => { + log_debug!(logger, "Finished noise handshake for connection with {}", log_pubkey!(peer.their_node_id.unwrap().0)); + entry.insert(peer_descriptor.clone()) + }, + }; } + } - let next_step = peer.channel_encryptor.get_noise_step(); - match next_step { - NextNoiseStep::ActOne => { - let act_two = try_potential_handleerror!(peer, peer.channel_encryptor - .process_act_one_with_keys(&peer.pending_read_buffer[..], - &self.node_signer, self.get_ephemeral_key(), &self.secp_ctx)).to_vec(); - peer.pending_outbound_buffer.push_back(act_two); - peer.pending_read_buffer = [0; 66].to_vec(); // act three is 66 bytes long - }, - NextNoiseStep::ActTwo => { - let (act_three, their_node_id) = try_potential_handleerror!(peer, - peer.channel_encryptor.process_act_two(&peer.pending_read_buffer[..], - &self.node_signer)); - peer.pending_outbound_buffer.push_back(act_three.to_vec()); - peer.pending_read_buffer = [0; 18].to_vec(); // Message length header is 18 bytes + let next_step = peer.channel_encryptor.get_noise_step(); + match next_step { + NextNoiseStep::ActOne => { + let res = peer.channel_encryptor.process_act_one_with_keys( + &peer.pending_read_buffer[..], + &self.node_signer, + self.get_ephemeral_key(), + &self.secp_ctx, + ); + let act_two = try_potential_handleerror!(peer, res).to_vec(); + peer.pending_outbound_buffer.push_back(act_two); + peer.pending_read_buffer = [0; 66].to_vec(); // act three is 66 bytes long + }, + NextNoiseStep::ActTwo => { + let res = peer + .channel_encryptor + .process_act_two(&peer.pending_read_buffer[..], &self.node_signer); + let (act_three, their_node_id) = try_potential_handleerror!(peer, res); + peer.pending_outbound_buffer.push_back(act_three.to_vec()); + peer.pending_read_buffer = [0; 18].to_vec(); // Message length header is 18 bytes + peer.pending_read_is_header = true; + + peer.set_their_node_id(their_node_id); + insert_node_id!(); + let features = self.init_features(their_node_id); + let networks = self.message_handler.chan_handler.get_chain_hashes(); + let resp = msgs::Init { + features, + networks, + remote_network_address: filter_addresses( + peer.their_socket_address.clone(), + ), + }; + self.enqueue_message(peer, &resp); + }, + NextNoiseStep::ActThree => { + let res = peer + .channel_encryptor + .process_act_three(&peer.pending_read_buffer[..]); + let their_node_id = try_potential_handleerror!(peer, res); + 
peer.pending_read_buffer = [0; 18].to_vec(); // Message length header is 18 bytes + peer.pending_read_is_header = true; + peer.set_their_node_id(their_node_id); + insert_node_id!(); + let features = self.init_features(their_node_id); + let networks = self.message_handler.chan_handler.get_chain_hashes(); + let resp = msgs::Init { + features, + networks, + remote_network_address: filter_addresses( + peer.their_socket_address.clone(), + ), + }; + self.enqueue_message(peer, &resp); + }, + NextNoiseStep::NoiseComplete => { + if peer.pending_read_is_header { + let res = peer + .channel_encryptor + .decrypt_length_header(&peer.pending_read_buffer[..]); + let msg_len = try_potential_handleerror!(peer, res); + if peer.pending_read_buffer.capacity() > 8192 { + peer.pending_read_buffer = Vec::new(); + } + peer.pending_read_buffer.resize(msg_len as usize + 16, 0); + if msg_len < 2 { + // Need at least the message type tag + return Err(PeerHandleError {}); + } + peer.pending_read_is_header = false; + } else { + debug_assert!(peer.pending_read_buffer.len() >= 2 + 16); + let res = peer + .channel_encryptor + .decrypt_message(&mut peer.pending_read_buffer[..]); + try_potential_handleerror!(peer, res); + + let message_result = wire::read( + &mut &peer.pending_read_buffer + [..peer.pending_read_buffer.len() - 16], + &*self.message_handler.custom_message_handler, + ); + + // Reset read buffer + if peer.pending_read_buffer.capacity() > 8192 { + peer.pending_read_buffer = Vec::new(); + } + peer.pending_read_buffer.resize(18, 0); peer.pending_read_is_header = true; - peer.set_their_node_id(their_node_id); - insert_node_id!(); - let features = self.init_features(their_node_id); - let networks = self.message_handler.chan_handler.get_chain_hashes(); - let resp = msgs::Init { features, networks, remote_network_address: filter_addresses(peer.their_socket_address.clone()) }; - self.enqueue_message(peer, &resp); - }, - NextNoiseStep::ActThree => { - let their_node_id = try_potential_handleerror!(peer, - peer.channel_encryptor.process_act_three(&peer.pending_read_buffer[..])); - peer.pending_read_buffer = [0; 18].to_vec(); // Message length header is 18 bytes - peer.pending_read_is_header = true; - peer.set_their_node_id(their_node_id); - insert_node_id!(); - let features = self.init_features(their_node_id); - let networks = self.message_handler.chan_handler.get_chain_hashes(); - let resp = msgs::Init { features, networks, remote_network_address: filter_addresses(peer.their_socket_address.clone()) }; - self.enqueue_message(peer, &resp); - }, - NextNoiseStep::NoiseComplete => { - if peer.pending_read_is_header { - let msg_len = try_potential_handleerror!(peer, - peer.channel_encryptor.decrypt_length_header(&peer.pending_read_buffer[..])); - if peer.pending_read_buffer.capacity() > 8192 { peer.pending_read_buffer = Vec::new(); } - peer.pending_read_buffer.resize(msg_len as usize + 16, 0); - if msg_len < 2 { // Need at least the message type tag - return Err(PeerHandleError { }); - } - peer.pending_read_is_header = false; - } else { - debug_assert!(peer.pending_read_buffer.len() >= 2 + 16); - try_potential_handleerror!(peer, - peer.channel_encryptor.decrypt_message(&mut peer.pending_read_buffer[..])); - - let message_result = wire::read( - &mut &peer.pending_read_buffer[..peer.pending_read_buffer.len() - 16], - &*self.message_handler.custom_message_handler - ); - - // Reset read buffer - if peer.pending_read_buffer.capacity() > 8192 { peer.pending_read_buffer = Vec::new(); } - peer.pending_read_buffer.resize(18, 0); - 
peer.pending_read_is_header = true; - - let logger = WithContext::from(&self.logger, peer.their_node_id.map(|p| p.0), None, None); - let message = match message_result { - Ok(x) => x, - Err(e) => { - match e { - // Note that to avoid re-entrancy we never call - // `do_attempt_write_data` from here, causing - // the messages enqueued here to not actually - // be sent before the peer is disconnected. - (msgs::DecodeError::UnknownRequiredFeature, Some(ty)) if is_gossip_msg(ty) => { - log_gossip!(logger, "Got a channel/node announcement with an unknown required feature flag, you may want to update!"); - continue; - } - (msgs::DecodeError::UnsupportedCompression, _) => { - log_gossip!(logger, "We don't support zlib-compressed message fields, sending a warning and ignoring message"); - self.enqueue_message(peer, &msgs::WarningMessage { channel_id: ChannelId::new_zero(), data: "Unsupported message compression: zlib".to_owned() }); - continue; - } - (_, Some(ty)) if is_gossip_msg(ty) => { - log_gossip!(logger, "Got an invalid value while deserializing a gossip message"); - self.enqueue_message(peer, &msgs::WarningMessage { - channel_id: ChannelId::new_zero(), - data: format!("Unreadable/bogus gossip message of type {}", ty), - }); - continue; - } - (msgs::DecodeError::UnknownRequiredFeature, _) => { - log_debug!(logger, "Received a message with an unknown required feature flag or TLV, you may want to update!"); - return Err(PeerHandleError { }); - } - (msgs::DecodeError::UnknownVersion, _) => return Err(PeerHandleError { }), - (msgs::DecodeError::InvalidValue, _) => { - log_debug!(logger, "Got an invalid value while deserializing message"); - return Err(PeerHandleError { }); - } - (msgs::DecodeError::ShortRead, _) => { - log_debug!(logger, "Deserialization failed due to shortness of message"); - return Err(PeerHandleError { }); - } - (msgs::DecodeError::BadLengthDescriptor, _) => return Err(PeerHandleError { }), - (msgs::DecodeError::Io(_), _) => return Err(PeerHandleError { }), - (msgs::DecodeError::DangerousValue, _) => return Err(PeerHandleError { }), - } + let their_node_id = peer.their_node_id.map(|p| p.0); + let logger = + WithContext::from(&self.logger, their_node_id, None, None); + let message = match message_result { + Ok(x) => x, + Err(e) => { + match e { + // Note that to avoid re-entrancy we never call + // `do_attempt_write_data` from here, causing + // the messages enqueued here to not actually + // be sent before the peer is disconnected. 
+ ( + msgs::DecodeError::UnknownRequiredFeature, + Some(ty), + ) if is_gossip_msg(ty) => { + log_gossip!(logger, "Got a channel/node announcement with an unknown required feature flag, you may want to update!"); + continue; + }, + (msgs::DecodeError::UnsupportedCompression, _) => { + log_gossip!(logger, "We don't support zlib-compressed message fields, sending a warning and ignoring message"); + let channel_id = ChannelId::new_zero(); + let data = "Unsupported message compression: zlib" + .to_owned(); + let msg = msgs::WarningMessage { channel_id, data }; + self.enqueue_message(peer, &msg); + continue; + }, + (_, Some(ty)) if is_gossip_msg(ty) => { + log_gossip!(logger, "Got an invalid value while deserializing a gossip message"); + let channel_id = ChannelId::new_zero(); + let data = format!( + "Unreadable/bogus gossip message of type {}", + ty + ); + let msg = msgs::WarningMessage { channel_id, data }; + self.enqueue_message(peer, &msg); + continue; + }, + (msgs::DecodeError::UnknownRequiredFeature, _) => { + log_debug!(logger, "Received a message with an unknown required feature flag or TLV, you may want to update!"); + return Err(PeerHandleError {}); + }, + (msgs::DecodeError::UnknownVersion, _) => { + return Err(PeerHandleError {}) + }, + (msgs::DecodeError::InvalidValue, _) => { + log_debug!(logger, "Got an invalid value while deserializing message"); + return Err(PeerHandleError {}); + }, + (msgs::DecodeError::ShortRead, _) => { + log_debug!(logger, "Deserialization failed due to shortness of message"); + return Err(PeerHandleError {}); + }, + (msgs::DecodeError::BadLengthDescriptor, _) => { + return Err(PeerHandleError {}) + }, + (msgs::DecodeError::Io(_), _) => { + return Err(PeerHandleError {}) + }, + (msgs::DecodeError::DangerousValue, _) => { + return Err(PeerHandleError {}) + }, } - }; + }, + }; - msg_to_handle = Some(message); - } + msg_to_handle = Some(message); } - } + }, } - pause_read = !self.peer_should_read(peer); - - if let Some(message) = msg_to_handle { - match self.handle_message(&peer_mutex, peer_lock, message) { - Err(handling_error) => match handling_error { - MessageHandlingError::PeerHandleError(e) => { return Err(e) }, - MessageHandlingError::LightningError(e) => { - try_potential_handleerror!(&mut peer_mutex.lock().unwrap(), Err(e)); - }, - }, - Ok(Some(msg)) => { - msgs_to_forward.push(msg); + } + pause_read = !self.peer_should_read(peer); + + if let Some(message) = msg_to_handle { + match self.handle_message(&peer_mutex, peer_lock, message) { + Err(handling_error) => match handling_error { + MessageHandlingError::PeerHandleError(e) => return Err(e), + MessageHandlingError::LightningError(e) => { + try_potential_handleerror!(&mut peer_mutex.lock().unwrap(), Err(e)); }, - Ok(None) => {}, - } + }, + Ok(Some(msg)) => { + msgs_to_forward.push(msg); + }, + Ok(None) => {}, } } } + } else { + // This is most likely a simple race condition where the user read some bytes + // from the socket, then we told the user to `disconnect_socket()`, then they + // called this method. Return an error to make sure we get disconnected. + return Err(PeerHandleError {}); } for msg in msgs_to_forward.drain(..) 
{ - self.forward_broadcast_msg(&*peers, &msg, peer_node_id.as_ref().map(|(pk, _)| pk), false); + self.forward_broadcast_msg( + &*peers, + &msg, + peer_node_id.as_ref().map(|(pk, _)| pk), + false, + ); } Ok(pause_read) @@ -1691,25 +1997,41 @@ impl, - peer_lock: MutexGuard, - message: wire::Message<<::Target as wire::CustomMessageReader>::CustomMessage> - ) -> Result::Target as wire::CustomMessageReader>::CustomMessage>>, MessageHandlingError> { - let their_node_id = peer_lock.their_node_id.expect("We know the peer's public key by the time we receive messages").0; + &self, peer_mutex: &Mutex, peer_lock: MutexGuard, + message: wire::Message< + <::Target as wire::CustomMessageReader>::CustomMessage, + >, + ) -> Result< + Option::Target as wire::CustomMessageReader>::CustomMessage>>, + MessageHandlingError, + > { + let their_node_id = peer_lock + .their_node_id + .expect("We know the peer's public key by the time we receive messages") + .0; let logger = WithContext::from(&self.logger, Some(their_node_id), None, None); - let unprocessed_message = self.do_handle_message_holding_peer_lock(peer_lock, message, their_node_id, &logger)?; + let unprocessed_message = + self.do_handle_message_holding_peer_lock(peer_lock, message, their_node_id, &logger)?; self.message_handler.chan_handler.message_received(); match unprocessed_message { - Some(LogicalMessage::FromWire(message)) => { - self.do_handle_message_without_peer_lock(peer_mutex, message, their_node_id, &logger) - }, + Some(LogicalMessage::FromWire(message)) => self.do_handle_message_without_peer_lock( + peer_mutex, + message, + their_node_id, + &logger, + ), Some(LogicalMessage::CommitmentSignedBatch(channel_id, batch)) => { - log_trace!(logger, "Received commitment_signed batch {:?} from {}", batch, log_pubkey!(their_node_id)); - self.message_handler.chan_handler.handle_commitment_signed_batch(their_node_id, channel_id, batch); + log_trace!( + logger, + "Received commitment_signed batch {:?} from {}", + batch, + log_pubkey!(their_node_id) + ); + let chan_handler = &self.message_handler.chan_handler; + chan_handler.handle_commitment_signed_batch(their_node_id, channel_id, batch); return Ok(None); }, None => Ok(None), @@ -1721,20 +2043,25 @@ impl( - &self, - mut peer_lock: MutexGuard, - message: wire::Message<<::Target as wire::CustomMessageReader>::CustomMessage>, - their_node_id: PublicKey, - logger: &WithContext<'a, L> - ) -> Result::Target as wire::CustomMessageReader>::CustomMessage>>, MessageHandlingError> - { + &self, mut peer_lock: MutexGuard, + message: wire::Message< + <::Target as wire::CustomMessageReader>::CustomMessage, + >, + their_node_id: PublicKey, logger: &WithContext<'a, L>, + ) -> Result< + Option< + LogicalMessage<<::Target as wire::CustomMessageReader>::CustomMessage>, + >, + MessageHandlingError, + > { peer_lock.received_message_since_timer_tick = true; // Need an Init as first message if let wire::Message::Init(msg) = message { // Check if we have any compatible chains if the `networks` field is specified. 
if let Some(networks) = &msg.networks { - if let Some(our_chains) = self.message_handler.chan_handler.get_chain_hashes() { + let chan_handler = &self.message_handler.chan_handler; + if let Some(our_chains) = chan_handler.get_chain_hashes() { let mut have_compatible_chains = false; 'our_chains: for our_chain in our_chains.iter() { for their_chain in networks { @@ -1746,56 +2073,90 @@ impl COMMITMENT_SIGNED_BATCH_LIMIT { - let error = format!("Peer {} sent start_batch for channel {} exceeding the limit", log_pubkey!(their_node_id), &msg.channel_id); + let error = format!( + "Peer {} sent start_batch for channel {} exceeding the limit", + log_pubkey!(their_node_id), + &msg.channel_id + ); log_debug!(logger, "{}", error); return Err(LightningError { err: error.clone(), action: msgs::ErrorAction::DisconnectPeerWithWarning { - msg: msgs::WarningMessage { - channel_id: msg.channel_id, - data: error, - }, + msg: msgs::WarningMessage { channel_id: msg.channel_id, data: error }, }, - }.into()); + } + .into()); } let messages = match msg.message_type { @@ -1860,25 +2227,23 @@ impl { - let error = format!("Peer {} sent start_batch for channel {} without a known message type", log_pubkey!(their_node_id), &msg.channel_id); + let error = format!( + "Peer {} sent start_batch for channel {} without a known message type", + log_pubkey!(their_node_id), + &msg.channel_id + ); log_debug!(logger, "{}", error); return Err(LightningError { err: error.clone(), action: msgs::ErrorAction::DisconnectPeerWithWarning { - msg: msgs::WarningMessage { - channel_id: msg.channel_id, - data: error, - }, + msg: msgs::WarningMessage { channel_id: msg.channel_id, data: error }, }, - }.into()); + } + .into()); }, }; - let message_batch = MessageBatch { - channel_id: msg.channel_id, - batch_size, - messages, - }; + let message_batch = MessageBatch { channel_id: msg.channel_id, batch_size, messages }; peer_lock.message_batch = Some(message_batch); return Ok(None); @@ -1886,7 +2251,8 @@ impl { - log_debug!(logger, "Peer {} sent an unexpected message for a commitment_signed batch", log_pubkey!(their_node_id)); + log_debug!( + logger, + "Peer {} sent an unexpected message for a commitment_signed batch", + log_pubkey!(their_node_id) + ); }, } - return Err(PeerHandleError { }.into()); + return Err(PeerHandleError {}.into()); } if let wire::Message::GossipTimestampFilter(_msg) = message { // When supporting gossip messages, start initial gossip sync only after we receive // a GossipTimestampFilter - if peer_lock.their_features.as_ref().unwrap().supports_gossip_queries() && - !peer_lock.sent_gossip_timestamp_filter { + if peer_lock.their_features.as_ref().unwrap().supports_gossip_queries() + && !peer_lock.sent_gossip_timestamp_filter + { peer_lock.sent_gossip_timestamp_filter = true; #[allow(unused_mut)] @@ -1939,7 +2309,10 @@ impl 1970").as_secs() - 6 * 3600; + let full_sync_threshold = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time must be > 1970") + .as_secs() - 6 * 3600; if (_msg.first_timestamp as u64) > full_sync_threshold { should_do_full_sync = false; } @@ -1964,17 +2337,29 @@ impl( - &self, - peer_mutex: &Mutex, - message: wire::Message<<::Target as wire::CustomMessageReader>::CustomMessage>, - their_node_id: PublicKey, - logger: &WithContext<'a, L> - ) -> Result::Target as wire::CustomMessageReader>::CustomMessage>>, MessageHandlingError> - { + &self, peer_mutex: &Mutex, + message: wire::Message< + <::Target as wire::CustomMessageReader>::CustomMessage, + >, + their_node_id: PublicKey, logger: &WithContext<'a, 
L>, + ) -> Result< + Option::Target as wire::CustomMessageReader>::CustomMessage>>, + MessageHandlingError, + > { if is_gossip_msg(message.type_id()) { - log_gossip!(logger, "Received message {:?} from {}", message, log_pubkey!(their_node_id)); + log_gossip!( + logger, + "Received message {:?} from {}", + message, + log_pubkey!(their_node_id) + ); } else { - log_trace!(logger, "Received message {:?} from {}", message, log_pubkey!(their_node_id)); + log_trace!( + logger, + "Received message {:?} from {}", + message, + log_pubkey!(their_node_id) + ); } let mut should_forward = None; @@ -1988,14 +2373,24 @@ impl { - log_debug!(logger, "Got Err message from {}: {}", log_pubkey!(their_node_id), PrintableString(&msg.data)); + log_debug!( + logger, + "Got Err message from {}: {}", + log_pubkey!(their_node_id), + PrintableString(&msg.data) + ); self.message_handler.chan_handler.handle_error(their_node_id, &msg); if msg.channel_id.is_zero() { - return Err(PeerHandleError { }.into()); + return Err(PeerHandleError {}.into()); } }, wire::Message::Warning(msg) => { - log_debug!(logger, "Got warning message from {}: {}", log_pubkey!(their_node_id), PrintableString(&msg.data)); + log_debug!( + logger, + "Got warning message from {}: {}", + log_pubkey!(their_node_id), + PrintableString(&msg.data) + ); }, wire::Message::Ping(msg) => { @@ -2046,21 +2441,21 @@ impl { self.message_handler.chan_handler.handle_stfu(their_node_id, &msg); - } + }, #[cfg(splicing)] // Splicing messages: wire::Message::SpliceInit(msg) => { self.message_handler.chan_handler.handle_splice_init(their_node_id, &msg); - } + }, #[cfg(splicing)] wire::Message::SpliceAck(msg) => { self.message_handler.chan_handler.handle_splice_ack(their_node_id, &msg); - } + }, #[cfg(splicing)] wire::Message::SpliceLocked(msg) => { self.message_handler.chan_handler.handle_splice_locked(their_node_id, &msg); - } + }, // Interactive transaction construction messages: wire::Message::TxAddInput(msg) => { @@ -2089,7 +2484,7 @@ impl { self.message_handler.chan_handler.handle_tx_abort(their_node_id, &msg); - } + }, wire::Message::Shutdown(msg) => { self.message_handler.chan_handler.handle_shutdown(their_node_id, &msg); @@ -2109,7 +2504,8 @@ impl { - self.message_handler.chan_handler.handle_update_fail_malformed_htlc(their_node_id, &msg); + let chan_handler = &self.message_handler.chan_handler; + chan_handler.handle_update_fail_malformed_htlc(their_node_id, &msg); }, wire::Message::CommitmentSigned(msg) => { @@ -2127,58 +2523,77 @@ impl { - self.message_handler.chan_handler.handle_announcement_signatures(their_node_id, &msg); + let chan_handler = &self.message_handler.chan_handler; + chan_handler.handle_announcement_signatures(their_node_id, &msg); }, wire::Message::ChannelAnnouncement(msg) => { - if self.message_handler.route_handler.handle_channel_announcement(Some(their_node_id), &msg) - .map_err(|e| -> MessageHandlingError { e.into() })? { + let route_handler = &self.message_handler.route_handler; + if route_handler + .handle_channel_announcement(Some(their_node_id), &msg) + .map_err(|e| -> MessageHandlingError { e.into() })? + { should_forward = Some(wire::Message::ChannelAnnouncement(msg)); } self.update_gossip_backlogged(); }, wire::Message::NodeAnnouncement(msg) => { - if self.message_handler.route_handler.handle_node_announcement(Some(their_node_id), &msg) - .map_err(|e| -> MessageHandlingError { e.into() })? 
{ + let route_handler = &self.message_handler.route_handler; + if route_handler + .handle_node_announcement(Some(their_node_id), &msg) + .map_err(|e| -> MessageHandlingError { e.into() })? + { should_forward = Some(wire::Message::NodeAnnouncement(msg)); } self.update_gossip_backlogged(); }, wire::Message::ChannelUpdate(msg) => { - self.message_handler.chan_handler.handle_channel_update(their_node_id, &msg); - if self.message_handler.route_handler.handle_channel_update(Some(their_node_id), &msg) - .map_err(|e| -> MessageHandlingError { e.into() })? { + let route_handler = &self.message_handler.route_handler; + if route_handler + .handle_channel_update(Some(their_node_id), &msg) + .map_err(|e| -> MessageHandlingError { e.into() })? + { should_forward = Some(wire::Message::ChannelUpdate(msg)); } self.update_gossip_backlogged(); }, wire::Message::QueryShortChannelIds(msg) => { - self.message_handler.route_handler.handle_query_short_channel_ids(their_node_id, msg)?; + let route_handler = &self.message_handler.route_handler; + route_handler.handle_query_short_channel_ids(their_node_id, msg)?; }, wire::Message::ReplyShortChannelIdsEnd(msg) => { - self.message_handler.route_handler.handle_reply_short_channel_ids_end(their_node_id, msg)?; + let route_handler = &self.message_handler.route_handler; + route_handler.handle_reply_short_channel_ids_end(their_node_id, msg)?; }, wire::Message::QueryChannelRange(msg) => { - self.message_handler.route_handler.handle_query_channel_range(their_node_id, msg)?; + let route_handler = &self.message_handler.route_handler; + route_handler.handle_query_channel_range(their_node_id, msg)?; }, wire::Message::ReplyChannelRange(msg) => { - self.message_handler.route_handler.handle_reply_channel_range(their_node_id, msg)?; + let route_handler = &self.message_handler.route_handler; + route_handler.handle_reply_channel_range(their_node_id, msg)?; }, // Onion message: wire::Message::OnionMessage(msg) => { - self.message_handler.onion_message_handler.handle_onion_message(their_node_id, &msg); + let onion_message_handler = &self.message_handler.onion_message_handler; + onion_message_handler.handle_onion_message(their_node_id, &msg); }, // Unknown messages: wire::Message::Unknown(type_id) if message.is_even() => { - log_debug!(logger, "Received unknown even message of type {}, disconnecting peer!", type_id); - return Err(PeerHandleError { }.into()); + log_debug!( + logger, + "Received unknown even message of type {}, disconnecting peer!", + type_id + ); + return Err(PeerHandleError {}.into()); }, wire::Message::Unknown(type_id) => { log_trace!(logger, "Received unknown odd message of type {}, ignoring", type_id); }, wire::Message::Custom(custom) => { - self.message_handler.custom_message_handler.handle_custom_message(custom, their_node_id)?; + let custom_message_handler = &self.message_handler.custom_message_handler; + custom_message_handler.handle_custom_message(custom, their_node_id)?; }, }; Ok(should_forward) @@ -2202,43 +2617,67 @@ impl { - log_gossip!(self.logger, "Sending message to all peers except {:?} or the announced node: {:?}", except_node, msg); + log_gossip!( + self.logger, + "Sending message to all peers except {:?} or the announced node: {:?}", + except_node, + msg + ); let encoded_msg = encode_msg!(msg); for (_, peer_mutex) in peers.iter() { let mut peer = peer_mutex.lock().unwrap(); - if !peer.handshake_complete() || - !peer.should_forward_node_announcement(msg.contents.node_id) { - continue + if !peer.handshake_complete() + || 
!peer.should_forward_node_announcement(msg.contents.node_id) + { + continue; } debug_assert!(peer.their_node_id.is_some()); debug_assert!(peer.channel_encryptor.is_ready_for_encryption()); - let logger = WithContext::from(&self.logger, peer.their_node_id.map(|p| p.0), None, None); + let their_node_id = peer.their_node_id.map(|p| p.0); + let logger = WithContext::from(&self.logger, their_node_id, None, None); if peer.buffer_full_drop_gossip_broadcast() && !allow_large_buffer { - log_gossip!(logger, "Skipping broadcast message to {:?} as its outbound buffer is full", peer.their_node_id); + log_gossip!( + logger, + "Skipping broadcast message to {:?} as its outbound buffer is full", + peer.their_node_id + ); continue; } if let Some((_, their_node_id)) = peer.their_node_id { @@ -2246,36 +2685,59 @@ impl { - log_gossip!(self.logger, "Sending message to all peers except {:?}: {:?}", except_node, msg); + log_gossip!( + self.logger, + "Sending message to all peers except {:?}: {:?}", + except_node, + msg + ); let encoded_msg = encode_msg!(msg); for (_, peer_mutex) in peers.iter() { let mut peer = peer_mutex.lock().unwrap(); - if !peer.handshake_complete() || - !peer.should_forward_channel_announcement(msg.contents.short_channel_id) { - continue + if !peer.handshake_complete() + || !peer.should_forward_channel_announcement(msg.contents.short_channel_id) + { + continue; } debug_assert!(peer.their_node_id.is_some()); debug_assert!(peer.channel_encryptor.is_ready_for_encryption()); - let logger = WithContext::from(&self.logger, peer.their_node_id.map(|p| p.0), None, None); + let their_node_id = peer.their_node_id.map(|p| p.0); + let logger = WithContext::from(&self.logger, their_node_id, None, None); if peer.buffer_full_drop_gossip_broadcast() && !allow_large_buffer { - log_gossip!(logger, "Skipping broadcast message to {:?} as its outbound buffer is full", peer.their_node_id); + log_gossip!( + logger, + "Skipping broadcast message to {:?} as its outbound buffer is full", + peer.their_node_id + ); continue; } - if except_node.is_some() && peer.their_node_id.as_ref().map(|(pk, _)| pk) == except_node { + if except_node.is_some() + && peer.their_node_id.as_ref().map(|(pk, _)| pk) == except_node + { continue; } - self.enqueue_encoded_gossip_broadcast(&mut *peer, MessageBuf::from_encoded(&encoded_msg)); + self.enqueue_encoded_gossip_broadcast( + &mut *peer, + MessageBuf::from_encoded(&encoded_msg), + ); } }, - _ => debug_assert!(false, "We shouldn't attempt to forward anything but gossip messages"), + _ => { + debug_assert!(false, "We shouldn't attempt to forward anything but gossip messages") + }, } } @@ -2308,7 +2770,8 @@ impl { - { - if peers_to_disconnect.get($node_id).is_some() { - // If we've "disconnected" this peer, do not send to it. - None - } else { - let descriptor_opt = self.node_id_to_descriptor.lock().unwrap().get($node_id).cloned(); - match descriptor_opt { - Some(descriptor) => match peers.get(&descriptor) { - Some(peer_mutex) => { - let peer_lock = peer_mutex.lock().unwrap(); - if !peer_lock.handshake_complete() { - None - } else { - Some(peer_lock) - } - }, - None => { - debug_assert!(false, "Inconsistent peers set state!"); + ($node_id: expr) => {{ + if peers_to_disconnect.get($node_id).is_some() { + // If we've "disconnected" this peer, do not send to it. 
+ None + } else { + let descriptor_opt = + self.node_id_to_descriptor.lock().unwrap().get($node_id).cloned(); + match descriptor_opt { + Some(descriptor) => match peers.get(&descriptor) { + Some(peer_mutex) => { + let peer_lock = peer_mutex.lock().unwrap(); + if !peer_lock.handshake_complete() { None + } else { + Some(peer_lock) } }, None => { + debug_assert!(false, "Inconsistent peers set state!"); None }, - } + }, + None => None, } } - } + }}; } + let route_handler = &self.message_handler.route_handler; + let chan_handler = &self.message_handler.chan_handler; + let onion_message_handler = &self.message_handler.onion_message_handler; + let custom_message_handler = &self.message_handler.custom_message_handler; + let send_only_message_handler = &self.message_handler.send_only_message_handler; + // Handles a `MessageSendEvent`, using `from_chan_handler` to decide if we should // robustly gossip broadcast events even if a peer's message buffer is full. let mut handle_event = |event, from_chan_handler| { match event { MessageSendEvent::SendPeerStorage { ref node_id, ref msg } => { - log_debug!(self.logger, "Handling SendPeerStorage event in peer_handler for {}", log_pubkey!(node_id)); + log_debug!( + self.logger, + "Handling SendPeerStorage event in peer_handler for {}", + log_pubkey!(node_id) + ); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendPeerStorageRetrieval { ref node_id, ref msg } => { - log_debug!(self.logger, "Handling SendPeerStorageRetrieval event in peer_handler for {}", log_pubkey!(node_id)); + log_debug!( + self.logger, + "Handling SendPeerStorageRetrieval event in peer_handler for {}", + log_pubkey!(node_id) + ); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); }, MessageSendEvent::SendAcceptChannel { ref node_id, ref msg } => { @@ -2405,34 +2879,54 @@ impl { - let logger = WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None); + MessageSendEvent::SendStfu { ref node_id, ref msg } => { + let logger = WithContext::from( + &self.logger, + Some(*node_id), + Some(msg.channel_id), + None, + ); log_debug!(logger, "Handling SendStfu event in peer_handler for node {} for channel {}", log_pubkey!(node_id), &msg.channel_id); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); - } - MessageSendEvent::SendSpliceInit { ref node_id, ref msg} => { - let logger = WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None); + }, + MessageSendEvent::SendSpliceInit { ref node_id, ref msg } => { + let logger = WithContext::from( + &self.logger, + Some(*node_id), + Some(msg.channel_id), + None, + ); log_debug!(logger, "Handling SendSpliceInit event in peer_handler for node {} for channel {}", log_pubkey!(node_id), &msg.channel_id); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); - } - MessageSendEvent::SendSpliceAck { ref node_id, ref msg} => { - let logger = WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None); + }, + MessageSendEvent::SendSpliceAck { ref node_id, ref msg } => { + let logger = WithContext::from( + &self.logger, + Some(*node_id), + Some(msg.channel_id), + None, + ); log_debug!(logger, "Handling SendSpliceAck event in peer_handler for node {} for channel {}", log_pubkey!(node_id), &msg.channel_id); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); - } - MessageSendEvent::SendSpliceLocked { ref node_id, ref msg} => { - let logger = WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), 
None); + }, + MessageSendEvent::SendSpliceLocked { ref node_id, ref msg } => { + let logger = WithContext::from( + &self.logger, + Some(*node_id), + Some(msg.channel_id), + None, + ); log_debug!(logger, "Handling SendSpliceLocked event in peer_handler for node {} for channel {}", log_pubkey!(node_id), &msg.channel_id); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); - } + }, MessageSendEvent::SendTxAddInput { ref node_id, ref msg } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(msg.channel_id), None), "Handling SendTxAddInput event in peer_handler for node {} for channel {}", log_pubkey!(node_id), @@ -2493,7 +2987,19 @@ impl { + MessageSendEvent::UpdateHTLCs { + ref node_id, + ref channel_id, + updates: + msgs::CommitmentUpdate { + ref update_add_htlcs, + ref update_fulfill_htlcs, + ref update_fail_htlcs, + ref update_fail_malformed_htlcs, + ref update_fee, + ref commitment_signed, + }, + } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), Some(*channel_id), None), "Handling UpdateHTLCs event in peer_handler for node {} with {} adds, {} fulfills, {} fails, {} commits for channel {}", log_pubkey!(node_id), update_add_htlcs.len(), @@ -2553,27 +3059,52 @@ impl { + MessageSendEvent::SendChannelAnnouncement { + ref node_id, + ref msg, + ref update_msg, + } => { log_debug!(WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendChannelAnnouncement event in peer_handler for node {} for short channel id {}", log_pubkey!(node_id), msg.contents.short_channel_id); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); - self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, update_msg); + self.enqueue_message( + &mut *get_peer_for_forwarding!(node_id)?, + update_msg, + ); }, MessageSendEvent::BroadcastChannelAnnouncement { msg, update_msg } => { log_debug!(self.logger, "Handling BroadcastChannelAnnouncement event in peer_handler for short channel id {}", msg.contents.short_channel_id); - match self.message_handler.route_handler.handle_channel_announcement(None, &msg) { - Ok(_) | Err(LightningError { action: msgs::ErrorAction::IgnoreDuplicateGossip, .. }) => { + match route_handler.handle_channel_announcement(None, &msg) { + Ok(_) + | Err(LightningError { + action: msgs::ErrorAction::IgnoreDuplicateGossip, + .. + }) => { let forward = wire::Message::ChannelAnnouncement(msg); - self.forward_broadcast_msg(peers, &forward, None, from_chan_handler); + self.forward_broadcast_msg( + peers, + &forward, + None, + from_chan_handler, + ); }, _ => {}, } if let Some(msg) = update_msg { - match self.message_handler.route_handler.handle_channel_update(None, &msg) { - Ok(_) | Err(LightningError { action: msgs::ErrorAction::IgnoreDuplicateGossip, .. }) => { + match route_handler.handle_channel_update(None, &msg) { + Ok(_) + | Err(LightningError { + action: msgs::ErrorAction::IgnoreDuplicateGossip, + .. + }) => { let forward = wire::Message::ChannelUpdate(msg); - self.forward_broadcast_msg(peers, &forward, None, from_chan_handler); + self.forward_broadcast_msg( + peers, + &forward, + None, + from_chan_handler, + ); }, _ => {}, } @@ -2581,20 +3112,38 @@ impl { log_debug!(self.logger, "Handling BroadcastChannelUpdate event in peer_handler for contents {:?}", msg.contents); - match self.message_handler.route_handler.handle_channel_update(None, &msg) { - Ok(_) | Err(LightningError { action: msgs::ErrorAction::IgnoreDuplicateGossip, .. 
}) => { + match route_handler.handle_channel_update(None, &msg) { + Ok(_) + | Err(LightningError { + action: msgs::ErrorAction::IgnoreDuplicateGossip, + .. + }) => { let forward = wire::Message::ChannelUpdate(msg); - self.forward_broadcast_msg(peers, &forward, None, from_chan_handler); + self.forward_broadcast_msg( + peers, + &forward, + None, + from_chan_handler, + ); }, _ => {}, } }, MessageSendEvent::BroadcastNodeAnnouncement { msg } => { log_debug!(self.logger, "Handling BroadcastNodeAnnouncement event in peer_handler for node {}", msg.contents.node_id); - match self.message_handler.route_handler.handle_node_announcement(None, &msg) { - Ok(_) | Err(LightningError { action: msgs::ErrorAction::IgnoreDuplicateGossip, .. }) => { + match route_handler.handle_node_announcement(None, &msg) { + Ok(_) + | Err(LightningError { + action: msgs::ErrorAction::IgnoreDuplicateGossip, + .. + }) => { let forward = wire::Message::NodeAnnouncement(msg); - self.forward_broadcast_msg(peers, &forward, None, from_chan_handler); + self.forward_broadcast_msg( + peers, + &forward, + None, + from_chan_handler, + ); }, _ => {}, } @@ -2618,7 +3167,9 @@ impl::Target as wire::CustomMessageReader>::CustomMessage>::Error(msg)); + let msg = msg.map(|msg| { + wire::Message::<<::Target as wire::CustomMessageReader>::CustomMessage>::Error(msg) + }); peers_to_disconnect.insert(node_id, msg); }, msgs::ErrorAction::DisconnectPeerWithWarning { msg } => { @@ -2627,26 +3178,45 @@ impl { - log_given_level!(logger, level, "Received a HandleError event to be ignored for node {}", log_pubkey!(node_id)); + log_given_level!( + logger, + level, + "Received a HandleError event to be ignored for node {}", + log_pubkey!(node_id) + ); }, msgs::ErrorAction::IgnoreDuplicateGossip => {}, msgs::ErrorAction::IgnoreError => { - log_debug!(logger, "Received a HandleError event to be ignored for node {}", log_pubkey!(node_id)); - }, + log_debug!( + logger, + "Received a HandleError event to be ignored for node {}", + log_pubkey!(node_id) + ); + }, msgs::ErrorAction::SendErrorMessage { ref msg } => { log_trace!(logger, "Handling SendErrorMessage HandleError event in peer_handler for node {} with message {}", log_pubkey!(node_id), msg.data); - self.enqueue_message(&mut *get_peer_for_forwarding!(&node_id)?, msg); + self.enqueue_message( + &mut *get_peer_for_forwarding!(&node_id)?, + msg, + ); }, - msgs::ErrorAction::SendWarningMessage { ref msg, ref log_level } => { + msgs::ErrorAction::SendWarningMessage { + ref msg, + ref log_level, + } => { log_given_level!(logger, *log_level, "Handling SendWarningMessage HandleError event in peer_handler for node {} with message {}", log_pubkey!(node_id), msg.data); - self.enqueue_message(&mut *get_peer_for_forwarding!(&node_id)?, msg); + self.enqueue_message( + &mut *get_peer_for_forwarding!(&node_id)?, + msg, + ); }, } }, @@ -2662,7 +3232,7 @@ impl { log_gossip!(WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendReplyChannelRange event in peer_handler for node {} with num_scids={} first_blocknum={} number_of_blocks={}, sync_complete={}", log_pubkey!(node_id), @@ -2671,47 +3241,60 @@ impl { log_gossip!(WithContext::from(&self.logger, Some(*node_id), None, None), "Handling SendGossipTimestampFilter event in peer_handler for node {} with first_timestamp={}, timestamp_range={}", log_pubkey!(node_id), msg.first_timestamp, msg.timestamp_range); self.enqueue_message(&mut *get_peer_for_forwarding!(node_id)?, msg); - } + }, } Some(()) }; - let chan_events = 
self.message_handler.chan_handler.get_and_clear_pending_msg_events(); + let chan_events = chan_handler.get_and_clear_pending_msg_events(); for event in chan_events { handle_event(event, true); } - let route_events = self.message_handler.route_handler.get_and_clear_pending_msg_events(); + let route_events = route_handler.get_and_clear_pending_msg_events(); for event in route_events { handle_event(event, false); } - let send_only_events = self.message_handler.send_only_message_handler.get_and_clear_pending_msg_events(); + let send_only_events = send_only_message_handler.get_and_clear_pending_msg_events(); for event in send_only_events { handle_event(event, false); } - let onion_msg_events = self.message_handler.onion_message_handler.get_and_clear_pending_msg_events(); + let onion_msg_events = onion_message_handler.get_and_clear_pending_msg_events(); for event in onion_msg_events { handle_event(event, false); } - for (node_id, msg) in self.message_handler.custom_message_handler.get_and_clear_pending_msg() { - if peers_to_disconnect.get(&node_id).is_some() { continue; } - self.enqueue_message(&mut *if let Some(peer) = get_peer_for_forwarding!(&node_id) { peer } else { continue; }, &msg); + for (node_id, msg) in custom_message_handler.get_and_clear_pending_msg() { + if peers_to_disconnect.get(&node_id).is_some() { + continue; + } + let mut peer = if let Some(peer) = get_peer_for_forwarding!(&node_id) { + peer + } else { + continue; + }; + self.enqueue_message(&mut peer, &msg); } for (descriptor, peer_mutex) in peers.iter() { let mut peer = peer_mutex.lock().unwrap(); - if flush_read_disabled { peer.received_channel_announce_since_backlogged = false; } - self.do_attempt_write_data(&mut (*descriptor).clone(), &mut *peer, flush_read_disabled); + if flush_read_disabled { + peer.received_channel_announce_since_backlogged = false; + } + self.do_attempt_write_data( + &mut (*descriptor).clone(), + &mut *peer, + flush_read_disabled, + ); } } if !peers_to_disconnect.is_empty() { @@ -2723,7 +3306,8 @@ impl 0 && !peer.received_message_since_timer_tick) - || peer.awaiting_pong_timer_tick_intervals as u64 > - MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER as u64 * peers_lock.len() as u64 - { + let not_recently_active = peer.awaiting_pong_timer_tick_intervals > 0 + && !peer.received_message_since_timer_tick; + let reached_threshold_intervals = peer.awaiting_pong_timer_tick_intervals + as u64 + > MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER as u64 * peers_lock.len() as u64; + if not_recently_active || reached_threshold_intervals { descriptors_needing_disconnect.push(descriptor.clone()); break; } @@ -2901,14 +3511,15 @@ impl) { + pub fn broadcast_node_announcement( + &self, rgb: [u8; 3], alias: [u8; 32], mut addresses: Vec, + ) { if addresses.len() > 100 { panic!("More than half the message size was taken up by public addresses!"); } @@ -2977,9 +3591,10 @@ impl sig, Err(_) => { log_error!(self.logger, "Failed to generate signature for node_announcement"); @@ -2987,27 +3602,32 @@ impl bool { match type_id { - msgs::ChannelAnnouncement::TYPE | - msgs::ChannelUpdate::TYPE | - msgs::NodeAnnouncement::TYPE | - msgs::QueryChannelRange::TYPE | - msgs::ReplyChannelRange::TYPE | - msgs::QueryShortChannelIds::TYPE | - msgs::ReplyShortChannelIdsEnd::TYPE => true, - _ => false + msgs::ChannelAnnouncement::TYPE + | msgs::ChannelUpdate::TYPE + | msgs::NodeAnnouncement::TYPE + | msgs::QueryChannelRange::TYPE + | msgs::ReplyChannelRange::TYPE + | msgs::QueryShortChannelIds::TYPE + | msgs::ReplyShortChannelIdsEnd::TYPE => true, + _ => 
false, } } @@ -3015,18 +3635,18 @@ fn is_gossip_msg(type_id: u16) -> bool { mod tests { use super::*; - use crate::sign::{NodeSigner, Recipient}; use crate::io; - use crate::ln::types::ChannelId; - use crate::types::features::{InitFeatures, NodeFeatures}; + use crate::ln::msgs::{Init, LightningError, SocketAddress}; use crate::ln::peer_channel_encryptor::PeerChannelEncryptor; + use crate::ln::types::ChannelId; use crate::ln::{msgs, wire}; - use crate::ln::msgs::{Init, LightningError, SocketAddress}; + use crate::sign::{NodeSigner, Recipient}; + use crate::types::features::{InitFeatures, NodeFeatures}; use crate::util::test_utils; - use bitcoin::Network; use bitcoin::constants::ChainHash; - use bitcoin::secp256k1::{PublicKey, SecretKey, Secp256k1}; + use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; + use bitcoin::Network; use crate::sync::{Arc, Mutex}; use core::convert::Infallible; @@ -3047,7 +3667,7 @@ mod tests { self.fd == other.fd } } - impl Eq for FileDescriptor { } + impl Eq for FileDescriptor {} impl core::hash::Hash for FileDescriptor { fn hash(&self, hasher: &mut H) { self.fd.hash(hasher) @@ -3064,7 +3684,9 @@ mod tests { } } - fn disconnect_socket(&mut self) { self.disconnect.store(true, Ordering::Release); } + fn disconnect_socket(&mut self) { + self.disconnect.store(true, Ordering::Release); + } } impl FileDescriptor { @@ -3093,16 +3715,15 @@ mod tests { impl TestCustomMessageHandler { fn new(features: InitFeatures) -> Self { - Self { - features, - conn_tracker: test_utils::ConnectionTracker::new(), - } + Self { features, conn_tracker: test_utils::ConnectionTracker::new() } } } impl wire::CustomMessageReader for TestCustomMessageHandler { type CustomMessage = Infallible; - fn read(&self, _: u16, _: &mut R) -> Result, msgs::DecodeError> { + fn read( + &self, _: u16, _: &mut R, + ) -> Result, msgs::DecodeError> { Ok(None) } } @@ -3112,17 +3733,23 @@ mod tests { unreachable!(); } - fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { Vec::new() } + fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { + Vec::new() + } fn peer_disconnected(&self, their_node_id: PublicKey) { self.conn_tracker.peer_disconnected(their_node_id); } - fn peer_connected(&self, their_node_id: PublicKey, _msg: &Init, _inbound: bool) -> Result<(), ()> { + fn peer_connected( + &self, their_node_id: PublicKey, _msg: &Init, _inbound: bool, + ) -> Result<(), ()> { self.conn_tracker.peer_connected(their_node_id) } - fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() } + fn provided_node_features(&self) -> NodeFeatures { + NodeFeatures::empty() + } fn provided_init_features(&self, _: PublicKey) -> InitFeatures { self.features.clone() @@ -3138,15 +3765,15 @@ mod tests { feature_bits[32] = 0b00000001; InitFeatures::from_le_bytes(feature_bits) }; - cfgs.push( - PeerManagerCfg{ - chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)), - logger: test_utils::TestLogger::with_id(i.to_string()), - routing_handler: test_utils::TestRoutingMessageHandler::new(), - custom_handler: TestCustomMessageHandler::new(features), - node_signer: test_utils::TestNodeSigner::new(node_secret), - } - ); + cfgs.push(PeerManagerCfg { + chan_handler: test_utils::TestChannelMessageHandler::new( + ChainHash::using_genesis_block(Network::Testnet), + ), + logger: test_utils::TestLogger::with_id(i.to_string()), + routing_handler: test_utils::TestRoutingMessageHandler::new(), + custom_handler: 
TestCustomMessageHandler::new(features), + node_signer: test_utils::TestNodeSigner::new(node_secret), + }); } cfgs @@ -3161,15 +3788,15 @@ mod tests { feature_bits[33 + i] = 0b00000001; InitFeatures::from_le_bytes(feature_bits) }; - cfgs.push( - PeerManagerCfg{ - chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)), - logger: test_utils::TestLogger::new(), - routing_handler: test_utils::TestRoutingMessageHandler::new(), - custom_handler: TestCustomMessageHandler::new(features), - node_signer: test_utils::TestNodeSigner::new(node_secret), - } - ); + cfgs.push(PeerManagerCfg { + chan_handler: test_utils::TestChannelMessageHandler::new( + ChainHash::using_genesis_block(Network::Testnet), + ), + logger: test_utils::TestLogger::new(), + routing_handler: test_utils::TestRoutingMessageHandler::new(), + custom_handler: TestCustomMessageHandler::new(features), + node_signer: test_utils::TestNodeSigner::new(node_secret), + }); } cfgs @@ -3181,38 +3808,76 @@ mod tests { let node_secret = SecretKey::from_slice(&[42 + i as u8; 32]).unwrap(); let features = InitFeatures::from_le_bytes(vec![0u8; 33]); let network = ChainHash::from(&[i as u8; 32]); - cfgs.push( - PeerManagerCfg{ - chan_handler: test_utils::TestChannelMessageHandler::new(network), - logger: test_utils::TestLogger::new(), - routing_handler: test_utils::TestRoutingMessageHandler::new(), - custom_handler: TestCustomMessageHandler::new(features), - node_signer: test_utils::TestNodeSigner::new(node_secret), - } - ); + cfgs.push(PeerManagerCfg { + chan_handler: test_utils::TestChannelMessageHandler::new(network), + logger: test_utils::TestLogger::new(), + routing_handler: test_utils::TestRoutingMessageHandler::new(), + custom_handler: TestCustomMessageHandler::new(features), + node_signer: test_utils::TestNodeSigner::new(node_secret), + }); } cfgs } - fn create_network<'a>(peer_count: usize, cfgs: &'a Vec) -> Vec> { + fn create_network<'a>( + peer_count: usize, cfgs: &'a Vec, + ) -> Vec< + PeerManager< + FileDescriptor, + &'a test_utils::TestChannelMessageHandler, + &'a test_utils::TestRoutingMessageHandler, + IgnoringMessageHandler, + &'a test_utils::TestLogger, + &'a TestCustomMessageHandler, + &'a test_utils::TestNodeSigner, + IgnoringMessageHandler, + >, + > { let mut peers = Vec::new(); for i in 0..peer_count { let ephemeral_bytes = [i as u8; 32]; let msg_handler = MessageHandler { - chan_handler: &cfgs[i].chan_handler, route_handler: &cfgs[i].routing_handler, - onion_message_handler: IgnoringMessageHandler {}, custom_message_handler: &cfgs[i].custom_handler, send_only_message_handler: IgnoringMessageHandler {}, + chan_handler: &cfgs[i].chan_handler, + route_handler: &cfgs[i].routing_handler, + onion_message_handler: IgnoringMessageHandler {}, + custom_message_handler: &cfgs[i].custom_handler, + send_only_message_handler: IgnoringMessageHandler {}, }; - let peer = PeerManager::new(msg_handler, 0, &ephemeral_bytes, &cfgs[i].logger, &cfgs[i].node_signer); + let peer = PeerManager::new( + msg_handler, + 0, + &ephemeral_bytes, + &cfgs[i].logger, + &cfgs[i].node_signer, + ); peers.push(peer); } peers } - fn try_establish_connection<'a>(peer_a: &PeerManager, peer_b: &PeerManager) -> (FileDescriptor, FileDescriptor, Result, Result) { - let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000}; - let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001}; + type TestPeer<'a> = PeerManager< + FileDescriptor, + &'a test_utils::TestChannelMessageHandler, + &'a 
test_utils::TestRoutingMessageHandler, + IgnoringMessageHandler, + &'a test_utils::TestLogger, + &'a TestCustomMessageHandler, + &'a test_utils::TestNodeSigner, + IgnoringMessageHandler, + >; + + fn try_establish_connection<'a>( + peer_a: &TestPeer<'a>, peer_b: &TestPeer<'a>, + ) -> ( + FileDescriptor, + FileDescriptor, + Result, + Result, + ) { + let addr_a = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1000 }; + let addr_b = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1001 }; static FD_COUNTER: AtomicUsize = AtomicUsize::new(0); let fd = FD_COUNTER.fetch_add(1, Ordering::Relaxed) as u16; @@ -3221,7 +3886,8 @@ mod tests { let mut fd_a = FileDescriptor::new(fd); let mut fd_b = FileDescriptor::new(fd); - let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap(); + let initial_data = + peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap(); peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap(); assert_eq!(peer_a.read_event(&mut fd_a, &initial_data).unwrap(), false); peer_a.process_events(); @@ -3240,10 +3906,11 @@ mod tests { (fd_a, fd_b, a_refused, b_refused) } - - fn establish_connection<'a>(peer_a: &PeerManager, peer_b: &PeerManager) -> (FileDescriptor, FileDescriptor) { - let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000}; - let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001}; + fn establish_connection<'a>( + peer_a: &TestPeer<'a>, peer_b: &TestPeer<'a>, + ) -> (FileDescriptor, FileDescriptor) { + let addr_a = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1000 }; + let addr_b = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1001 }; let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap(); let id_b = peer_b.node_signer.get_node_id(Recipient::Node).unwrap(); @@ -3276,62 +3943,83 @@ mod tests { let peers = Arc::new(create_network(2, unsafe { &*(&*cfgs as *const _) as &'static _ })); let start_time = std::time::Instant::now(); - macro_rules! spawn_thread { ($id: expr) => { { - let peers = Arc::clone(&peers); - let cfgs = Arc::clone(&cfgs); - std::thread::spawn(move || { - let mut ctr = 0; - while start_time.elapsed() < std::time::Duration::from_secs(1) { - let id_a = peers[0].node_signer.get_node_id(Recipient::Node).unwrap(); - let mut fd_a = FileDescriptor::new($id + ctr * 3); - let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000}; - let mut fd_b = FileDescriptor::new($id + ctr * 3); - let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001}; - let initial_data = peers[1].new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap(); - peers[0].new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap(); - if peers[0].read_event(&mut fd_a, &initial_data).is_err() { break; } - + macro_rules! 
spawn_thread { + ($id: expr) => {{ + let peers = Arc::clone(&peers); + let cfgs = Arc::clone(&cfgs); + std::thread::spawn(move || { + let mut ctr = 0; while start_time.elapsed() < std::time::Duration::from_secs(1) { - peers[0].process_events(); - if fd_a.disconnect.load(Ordering::Acquire) { break; } - let a_data = fd_a.outbound_data.lock().unwrap().split_off(0); - if peers[1].read_event(&mut fd_b, &a_data).is_err() { break; } - - peers[1].process_events(); - if fd_b.disconnect.load(Ordering::Acquire) { break; } - let b_data = fd_b.outbound_data.lock().unwrap().split_off(0); - if peers[0].read_event(&mut fd_a, &b_data).is_err() { break; } - - cfgs[0].chan_handler.pending_events.lock().unwrap() - .push(MessageSendEvent::SendShutdown { - node_id: peers[1].node_signer.get_node_id(Recipient::Node).unwrap(), + let id_a = peers[0].node_signer.get_node_id(Recipient::Node).unwrap(); + let mut fd_a = FileDescriptor::new($id + ctr * 3); + let addr_a = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1000 }; + let mut fd_b = FileDescriptor::new($id + ctr * 3); + let addr_b = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1001 }; + let initial_data = peers[1] + .new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())) + .unwrap(); + peers[0] + .new_inbound_connection(fd_a.clone(), Some(addr_b.clone())) + .unwrap(); + if peers[0].read_event(&mut fd_a, &initial_data).is_err() { + break; + } + + while start_time.elapsed() < std::time::Duration::from_secs(1) { + peers[0].process_events(); + if fd_a.disconnect.load(Ordering::Acquire) { + break; + } + let a_data = fd_a.outbound_data.lock().unwrap().split_off(0); + if peers[1].read_event(&mut fd_b, &a_data).is_err() { + break; + } + + peers[1].process_events(); + if fd_b.disconnect.load(Ordering::Acquire) { + break; + } + let b_data = fd_b.outbound_data.lock().unwrap().split_off(0); + if peers[0].read_event(&mut fd_a, &b_data).is_err() { + break; + } + + let node_id_1 = + peers[1].node_signer.get_node_id(Recipient::Node).unwrap(); + let msg_event_1 = MessageSendEvent::SendShutdown { + node_id: node_id_1, msg: msgs::Shutdown { channel_id: ChannelId::new_zero(), scriptpubkey: bitcoin::ScriptBuf::new(), }, - }); - cfgs[1].chan_handler.pending_events.lock().unwrap() - .push(MessageSendEvent::SendShutdown { - node_id: peers[0].node_signer.get_node_id(Recipient::Node).unwrap(), + }; + cfgs[0].chan_handler.pending_events.lock().unwrap().push(msg_event_1); + + let node_id_0 = + peers[0].node_signer.get_node_id(Recipient::Node).unwrap(); + let msg_event_0 = MessageSendEvent::SendShutdown { + node_id: node_id_0, msg: msgs::Shutdown { channel_id: ChannelId::new_zero(), scriptpubkey: bitcoin::ScriptBuf::new(), }, - }); + }; + cfgs[1].chan_handler.pending_events.lock().unwrap().push(msg_event_0); - if ctr % 2 == 0 { - peers[0].timer_tick_occurred(); - peers[1].timer_tick_occurred(); + if ctr % 2 == 0 { + peers[0].timer_tick_occurred(); + peers[1].timer_tick_occurred(); + } } - } - peers[0].socket_disconnected(&fd_a); - peers[1].socket_disconnected(&fd_b); - ctr += 1; - std::thread::sleep(std::time::Duration::from_micros(1)); - } - }) - } } } + peers[0].socket_disconnected(&fd_a); + peers[1].socket_disconnected(&fd_b); + ctr += 1; + std::thread::sleep(std::time::Duration::from_micros(1)); + } + }) + }}; + } let thrd_a = spawn_thread!(1); let thrd_b = spawn_thread!(2); @@ -3350,10 +4038,11 @@ mod tests { for (peer_a, peer_b) in peer_pairs.iter() { let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap(); let mut fd_a = FileDescriptor::new(1); - 
let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000}; + let addr_a = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1000 }; let mut fd_b = FileDescriptor::new(1); - let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001}; - let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap(); + let addr_b = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1001 }; + let initial_data = + peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap(); peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap(); assert_eq!(peer_a.read_event(&mut fd_a, &initial_data).unwrap(), false); peer_a.process_events(); @@ -3380,10 +4069,11 @@ mod tests { for (peer_a, peer_b) in peer_pairs.iter() { let id_a = peer_a.node_signer.get_node_id(Recipient::Node).unwrap(); let mut fd_a = FileDescriptor::new(1); - let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000}; + let addr_a = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1000 }; let mut fd_b = FileDescriptor::new(1); - let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001}; - let initial_data = peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap(); + let addr_b = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1001 }; + let initial_data = + peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap(); peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap(); assert_eq!(peer_a.read_event(&mut fd_a, &initial_data).unwrap(), false); peer_a.process_events(); @@ -3406,7 +4096,10 @@ mod tests { let cfgs = create_peermgr_cfgs(2); let peers = create_network(2, &cfgs); establish_connection(&peers[0], &peers[1]); - assert_eq!(peers[0].peers.read().unwrap().len(), 1); + { + let peers_len = peers[0].peers.read().unwrap().len(); + assert_eq!(peers_len, 1); + } let their_id = peers[1].node_signer.get_node_id(Recipient::Node).unwrap(); cfgs[0].chan_handler.pending_events.lock().unwrap().push(MessageSendEvent::HandleError { @@ -3415,7 +4108,10 @@ mod tests { }); peers[0].process_events(); - assert_eq!(peers[0].peers.read().unwrap().len(), 0); + { + let peers_len = peers[0].peers.read().unwrap().len(); + assert_eq!(peers_len, 0); + } } #[test] @@ -3423,18 +4119,30 @@ mod tests { // Simple test which builds a network of PeerManager, connects and brings them to NoiseState::Finished and // push a message from one peer to another. 
let cfgs = create_peermgr_cfgs(2); - let a_chan_handler = test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)); - let b_chan_handler = test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)); + let a_chan_handler = test_utils::TestChannelMessageHandler::new( + ChainHash::using_genesis_block(Network::Testnet), + ); + let b_chan_handler = test_utils::TestChannelMessageHandler::new( + ChainHash::using_genesis_block(Network::Testnet), + ); let mut peers = create_network(2, &cfgs); let (fd_a, mut fd_b) = establish_connection(&peers[0], &peers[1]); - assert_eq!(peers[0].peers.read().unwrap().len(), 1); + { + let peers_len = peers[0].peers.read().unwrap().len(); + assert_eq!(peers_len, 1); + } let their_id = peers[1].node_signer.get_node_id(Recipient::Node).unwrap(); - let msg = msgs::Shutdown { channel_id: ChannelId::from_bytes([42; 32]), scriptpubkey: bitcoin::ScriptBuf::new() }; - a_chan_handler.pending_events.lock().unwrap().push(MessageSendEvent::SendShutdown { - node_id: their_id, msg: msg.clone() - }); + let msg = msgs::Shutdown { + channel_id: ChannelId::from_bytes([42; 32]), + scriptpubkey: bitcoin::ScriptBuf::new(), + }; + a_chan_handler + .pending_events + .lock() + .unwrap() + .push(MessageSendEvent::SendShutdown { node_id: their_id, msg: msg.clone() }); peers[0].message_handler.chan_handler = &a_chan_handler; b_chan_handler.expect_receive_msg(wire::Message::Shutdown(msg)); @@ -3456,11 +4164,12 @@ mod tests { let peers = create_network(2, &cfgs); let mut fd_dup = FileDescriptor::new(3); - let addr_dup = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1003}; + let addr_dup = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 1], port: 1003 }; let id_a = cfgs[0].node_signer.get_node_id(Recipient::Node).unwrap(); peers[0].new_inbound_connection(fd_dup.clone(), Some(addr_dup.clone())).unwrap(); - let mut dup_encryptor = PeerChannelEncryptor::new_outbound(id_a, SecretKey::from_slice(&[42; 32]).unwrap()); + let mut dup_encryptor = + PeerChannelEncryptor::new_outbound(id_a, SecretKey::from_slice(&[42; 32]).unwrap()); let initial_data = dup_encryptor.get_act_one(&peers[1].secp_ctx); assert_eq!(peers[0].read_event(&mut fd_dup, &initial_data).unwrap(), false); peers[0].process_events(); @@ -3482,10 +4191,16 @@ mod tests { let cfgs = create_peermgr_cfgs(2); let peers = create_network(2, &cfgs); establish_connection(&peers[0], &peers[1]); - assert_eq!(peers[0].peers.read().unwrap().len(), 1); + { + let peers_len = peers[0].peers.read().unwrap().len(); + assert_eq!(peers_len, 1); + } peers[0].disconnect_all_peers(); - assert_eq!(peers[0].peers.read().unwrap().len(), 0); + { + let peers_len = peers[0].peers.read().unwrap().len(); + assert_eq!(peers_len, 0); + } } #[test] @@ -3494,17 +4209,26 @@ mod tests { let cfgs = create_peermgr_cfgs(2); let peers = create_network(2, &cfgs); establish_connection(&peers[0], &peers[1]); - assert_eq!(peers[0].peers.read().unwrap().len(), 1); + { + let peers_len = peers[0].peers.read().unwrap().len(); + assert_eq!(peers_len, 1); + } // peers[0] awaiting_pong is set to true, but the Peer is still connected peers[0].timer_tick_occurred(); peers[0].process_events(); - assert_eq!(peers[0].peers.read().unwrap().len(), 1); + { + let peers_len = peers[0].peers.read().unwrap().len(); + assert_eq!(peers_len, 1); + } // Since timer_tick_occurred() is called again when awaiting_pong is true, all Peers are disconnected peers[0].timer_tick_occurred(); peers[0].process_events(); - 
assert_eq!(peers[0].peers.read().unwrap().len(), 0); + { + let peers_len = peers[0].peers.read().unwrap().len(); + assert_eq!(peers_len, 0); + } } fn do_test_peer_connected_error_disconnects(handler: usize) { @@ -3514,16 +4238,20 @@ mod tests { let cfgs = create_peermgr_cfgs(2); let peers = create_network(2, &cfgs); + let chan_handler = peers[handler & 1].message_handler.chan_handler; + let route_handler = peers[handler & 1].message_handler.route_handler; + let custom_message_handler = peers[handler & 1].message_handler.custom_message_handler; + match handler & !1 { 0 => { - peers[handler & 1].message_handler.chan_handler.conn_tracker.fail_connections.store(true, Ordering::Release); - } + chan_handler.conn_tracker.fail_connections.store(true, Ordering::Release); + }, 2 => { - peers[handler & 1].message_handler.route_handler.conn_tracker.fail_connections.store(true, Ordering::Release); - } + route_handler.conn_tracker.fail_connections.store(true, Ordering::Release); + }, 4 => { - peers[handler & 1].message_handler.custom_message_handler.conn_tracker.fail_connections.store(true, Ordering::Release); - } + custom_message_handler.conn_tracker.fail_connections.store(true, Ordering::Release); + }, _ => panic!(), } let (_sd1, _sd2, a_refused, b_refused) = try_establish_connection(&peers[0], &peers[1]); @@ -3535,13 +4263,15 @@ mod tests { assert!(peers[1].list_peers().is_empty()); } // At least one message handler should have seen the connection. - assert!(peers[handler & 1].message_handler.chan_handler.conn_tracker.had_peers.load(Ordering::Acquire) || - peers[handler & 1].message_handler.route_handler.conn_tracker.had_peers.load(Ordering::Acquire) || - peers[handler & 1].message_handler.custom_message_handler.conn_tracker.had_peers.load(Ordering::Acquire)); + assert!( + chan_handler.conn_tracker.had_peers.load(Ordering::Acquire) + || route_handler.conn_tracker.had_peers.load(Ordering::Acquire) + || custom_message_handler.conn_tracker.had_peers.load(Ordering::Acquire) + ); // And both message handlers doing tracking should see the disconnection - assert!(peers[handler & 1].message_handler.chan_handler.conn_tracker.connected_peers.lock().unwrap().is_empty()); - assert!(peers[handler & 1].message_handler.route_handler.conn_tracker.connected_peers.lock().unwrap().is_empty()); - assert!(peers[handler & 1].message_handler.custom_message_handler.conn_tracker.connected_peers.lock().unwrap().is_empty()); + assert!(chan_handler.conn_tracker.connected_peers.lock().unwrap().is_empty()); + assert!(route_handler.conn_tracker.connected_peers.lock().unwrap().is_empty()); + assert!(custom_message_handler.conn_tracker.connected_peers.lock().unwrap().is_empty()); } #[test] @@ -3571,7 +4301,7 @@ mod tests { // Make each peer to read the messages that the other peer just wrote to them. Note that // due to the max-message-before-ping limits this may take a few iterations to complete. 
- for _ in 0..150/super::BUFFER_DRAIN_MSGS_PER_TICK + 1 { + for _ in 0..150 / super::BUFFER_DRAIN_MSGS_PER_TICK + 1 { peers[1].process_events(); let a_read_data = fd_b.outbound_data.lock().unwrap().split_off(0); assert!(!a_read_data.is_empty()); @@ -3584,7 +4314,11 @@ mod tests { peers[1].read_event(&mut fd_b, &b_read_data).unwrap(); peers[0].process_events(); - assert_eq!(fd_a.outbound_data.lock().unwrap().len(), 0, "Until A receives data, it shouldn't send more messages"); + assert_eq!( + fd_a.outbound_data.lock().unwrap().len(), + 0, + "Until A receives data, it shouldn't send more messages" + ); } // Check that each peer has received the expected number of channel updates and channel @@ -3611,9 +4345,15 @@ mod tests { peers[0].new_inbound_connection(fd_a.clone(), None).unwrap(); // If we get a single timer tick before completion, that's fine - assert_eq!(peers[0].peers.read().unwrap().len(), 1); + { + let peers_len = peers[0].peers.read().unwrap().len(); + assert_eq!(peers_len, 1); + } peers[0].timer_tick_occurred(); - assert_eq!(peers[0].peers.read().unwrap().len(), 1); + { + let peers_len = peers[0].peers.read().unwrap().len(); + assert_eq!(peers_len, 1); + } assert_eq!(peers[0].read_event(&mut fd_a, &initial_data).unwrap(), false); peers[0].process_events(); @@ -3623,7 +4363,10 @@ mod tests { // ...but if we get a second timer tick, we should disconnect the peer peers[0].timer_tick_occurred(); - assert_eq!(peers[0].peers.read().unwrap().len(), 0); + { + let peers_len = peers[0].peers.read().unwrap().len(); + assert_eq!(peers_len, 0); + } let b_data = fd_b.outbound_data.lock().unwrap().split_off(0); assert!(peers[0].read_event(&mut fd_a, &b_data).is_err()); @@ -3636,22 +4379,26 @@ mod tests { // two of the noise handshake along with our init message but before we receive their init // message. 
let logger = test_utils::TestLogger::new(); - let node_signer_a = test_utils::TestNodeSigner::new(SecretKey::from_slice(&[42; 32]).unwrap()); - let node_signer_b = test_utils::TestNodeSigner::new(SecretKey::from_slice(&[43; 32]).unwrap()); - let peer_a = PeerManager::new(MessageHandler { + let node_signer_a = + test_utils::TestNodeSigner::new(SecretKey::from_slice(&[42; 32]).unwrap()); + let node_signer_b = + test_utils::TestNodeSigner::new(SecretKey::from_slice(&[43; 32]).unwrap()); + let message_handler_a = MessageHandler { chan_handler: ErroringMessageHandler::new(), route_handler: IgnoringMessageHandler {}, onion_message_handler: IgnoringMessageHandler {}, custom_message_handler: IgnoringMessageHandler {}, send_only_message_handler: IgnoringMessageHandler {}, - }, 0, &[0; 32], &logger, &node_signer_a); - let peer_b = PeerManager::new(MessageHandler { + }; + let message_handler_b = MessageHandler { chan_handler: ErroringMessageHandler::new(), route_handler: IgnoringMessageHandler {}, onion_message_handler: IgnoringMessageHandler {}, custom_message_handler: IgnoringMessageHandler {}, send_only_message_handler: IgnoringMessageHandler {}, - }, 0, &[1; 32], &logger, &node_signer_b); + }; + let peer_a = PeerManager::new(message_handler_a, 0, &[0; 32], &logger, &node_signer_a); + let peer_b = PeerManager::new(message_handler_b, 0, &[1; 32], &logger, &node_signer_b); let a_id = node_signer_a.get_node_id(Recipient::Node).unwrap(); let mut fd_a = FileDescriptor::new(1); @@ -3672,21 +4419,48 @@ mod tests { peer_b.timer_tick_occurred(); let act_three_with_init_b = fd_b.outbound_data.lock().unwrap().split_off(0); - assert!(!peer_a.peers.read().unwrap().get(&fd_a).unwrap().lock().unwrap().handshake_complete()); + { + let peer_a_lock = peer_a.peers.read().unwrap(); + let handshake_complete = + peer_a_lock.get(&fd_a).unwrap().lock().unwrap().handshake_complete(); + assert!(!handshake_complete); + } + assert_eq!(peer_a.read_event(&mut fd_a, &act_three_with_init_b).unwrap(), false); peer_a.process_events(); - assert!(peer_a.peers.read().unwrap().get(&fd_a).unwrap().lock().unwrap().handshake_complete()); + + { + let peer_a_lock = peer_a.peers.read().unwrap(); + let handshake_complete = + peer_a_lock.get(&fd_a).unwrap().lock().unwrap().handshake_complete(); + assert!(handshake_complete); + } let init_a = fd_a.outbound_data.lock().unwrap().split_off(0); assert!(!init_a.is_empty()); - assert!(!peer_b.peers.read().unwrap().get(&fd_b).unwrap().lock().unwrap().handshake_complete()); + { + let peer_b_lock = peer_b.peers.read().unwrap(); + let handshake_complete = + peer_b_lock.get(&fd_b).unwrap().lock().unwrap().handshake_complete(); + assert!(!handshake_complete); + } + assert_eq!(peer_b.read_event(&mut fd_b, &init_a).unwrap(), false); peer_b.process_events(); - assert!(peer_b.peers.read().unwrap().get(&fd_b).unwrap().lock().unwrap().handshake_complete()); + + { + let peer_b_lock = peer_b.peers.read().unwrap(); + let handshake_complete = + peer_b_lock.get(&fd_b).unwrap().lock().unwrap().handshake_complete(); + assert!(handshake_complete); + } // Make sure we're still connected. - assert_eq!(peer_b.peers.read().unwrap().len(), 1); + { + let peers_len = peer_b.peers.read().unwrap().len(); + assert_eq!(peers_len, 1); + } // B should send a ping on the first timer tick after `handshake_complete`. 
assert!(fd_b.outbound_data.lock().unwrap().split_off(0).is_empty()); @@ -3698,10 +4472,13 @@ mod tests { { let peers = peer_a.peers.read().unwrap(); let mut peer_b = peers.get(&fd_a).unwrap().lock().unwrap(); - peer_a.enqueue_message(&mut peer_b, &msgs::WarningMessage { - channel_id: ChannelId([0; 32]), - data: "no disconnect plz".to_string(), - }); + peer_a.enqueue_message( + &mut peer_b, + &msgs::WarningMessage { + channel_id: ChannelId([0; 32]), + data: "no disconnect plz".to_string(), + }, + ); } peer_a.process_events(); let msg = fd_a.outbound_data.lock().unwrap().split_off(0); @@ -3717,11 +4494,17 @@ mod tests { peer_b.timer_tick_occurred(); send_warning(); } - assert_eq!(peer_b.peers.read().unwrap().len(), 1); + { + let peers_len = peer_b.peers.read().unwrap().len(); + assert_eq!(peers_len, 1); + } // One more tick should enforce the pong timeout. peer_b.timer_tick_occurred(); - assert_eq!(peer_b.peers.read().unwrap().len(), 0); + { + let peers_len = peer_b.peers.read().unwrap().len(); + assert_eq!(peers_len, 0); + } } #[test] @@ -3735,24 +4518,26 @@ mod tests { let peers = create_network(2, &cfgs); let (mut fd_a, mut fd_b) = establish_connection(&peers[0], &peers[1]); - macro_rules! drain_queues { () => { - loop { - peers[0].process_events(); - peers[1].process_events(); + macro_rules! drain_queues { + () => { + loop { + peers[0].process_events(); + peers[1].process_events(); - let msg = fd_a.outbound_data.lock().unwrap().split_off(0); - if !msg.is_empty() { - assert_eq!(peers[1].read_event(&mut fd_b, &msg).unwrap(), false); - continue; - } - let msg = fd_b.outbound_data.lock().unwrap().split_off(0); - if !msg.is_empty() { - assert_eq!(peers[0].read_event(&mut fd_a, &msg).unwrap(), false); - continue; + let msg = fd_a.outbound_data.lock().unwrap().split_off(0); + if !msg.is_empty() { + assert_eq!(peers[1].read_event(&mut fd_b, &msg).unwrap(), false); + continue; + } + let msg = fd_b.outbound_data.lock().unwrap().split_off(0); + if !msg.is_empty() { + assert_eq!(peers[0].read_event(&mut fd_a, &msg).unwrap(), false); + continue; + } + break; } - break; - } - } } + }; + } // First, make sure all pending messages have been processed and queues drained. drain_queues!(); @@ -3760,10 +4545,7 @@ mod tests { let secp_ctx = Secp256k1::new(); let key = SecretKey::from_slice(&[1; 32]).unwrap(); let msg = channel_announcement(&key, &key, ChannelFeatures::empty(), 42, &secp_ctx); - let msg_ev = MessageSendEvent::BroadcastChannelAnnouncement { - msg, - update_msg: None, - }; + let msg_ev = MessageSendEvent::BroadcastChannelAnnouncement { msg, update_msg: None }; fd_a.hang_writes.store(true, Ordering::Relaxed); @@ -3774,15 +4556,24 @@ mod tests { peers[0].process_events(); } - assert_eq!(peers[0].peers.read().unwrap().get(&fd_a).unwrap().lock().unwrap().gossip_broadcast_buffer.len(), - OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP); + { + let peer_a_lock = peers[0].peers.read().unwrap(); + let buf_len = + peer_a_lock.get(&fd_a).unwrap().lock().unwrap().gossip_broadcast_buffer.len(); + assert_eq!(buf_len, OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP); + } // Check that if a broadcast message comes in from the channel handler (i.e. it is an // announcement for our own channel), it gets queued anyway. 
cfgs[0].chan_handler.pending_events.lock().unwrap().push(msg_ev); peers[0].process_events(); - assert_eq!(peers[0].peers.read().unwrap().get(&fd_a).unwrap().lock().unwrap().gossip_broadcast_buffer.len(), - OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP + 1); + + { + let peer_a_lock = peers[0].peers.read().unwrap(); + let buf_len = + peer_a_lock.get(&fd_a).unwrap().lock().unwrap().gossip_broadcast_buffer.len(); + assert_eq!(buf_len, OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP + 1); + } // Finally, deliver all the messages and make sure we got the right count. Note that there // was an extra message that had already moved from the broadcast queue to the encrypted @@ -3792,101 +4583,127 @@ mod tests { peers[0].write_buffer_space_avail(&mut fd_a).unwrap(); drain_queues!(); - assert!(peers[0].peers.read().unwrap().get(&fd_a).unwrap().lock().unwrap().gossip_broadcast_buffer.is_empty()); - assert_eq!(cfgs[1].routing_handler.chan_anns_recvd.load(Ordering::Relaxed), - OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP + 2); + { + let peer_a_lock = peers[0].peers.read().unwrap(); + let empty = + peer_a_lock.get(&fd_a).unwrap().lock().unwrap().gossip_broadcast_buffer.is_empty(); + assert!(empty); + } + + assert_eq!( + cfgs[1].routing_handler.chan_anns_recvd.load(Ordering::Relaxed), + OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP + 2 + ); } #[test] - fn test_filter_addresses(){ + fn test_filter_addresses() { // Tests the filter_addresses function. // For (10/8) - let ip_address = SocketAddress::TcpIpV4{addr: [10, 0, 0, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [10, 0, 0, 0], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [10, 0, 255, 201], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [10, 0, 255, 201], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [10, 255, 255, 255], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [10, 255, 255, 255], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); // For (0/8) - let ip_address = SocketAddress::TcpIpV4{addr: [0, 0, 0, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [0, 0, 0, 0], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [0, 0, 255, 187], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [0, 0, 255, 187], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [0, 255, 255, 255], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [0, 255, 255, 255], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); // For (100.64/10) - let ip_address = SocketAddress::TcpIpV4{addr: [100, 64, 0, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [100, 64, 0, 0], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [100, 78, 255, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [100, 78, 255, 0], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [100, 127, 255, 255], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [100, 127, 255, 255], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); // For (127/8) - let ip_address = SocketAddress::TcpIpV4{addr: [127, 0, 0, 0], port: 
1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [127, 0, 0, 0], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [127, 65, 73, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [127, 65, 73, 0], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [127, 255, 255, 255], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [127, 255, 255, 255], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); // For (169.254/16) - let ip_address = SocketAddress::TcpIpV4{addr: [169, 254, 0, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [169, 254, 0, 0], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [169, 254, 221, 101], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [169, 254, 221, 101], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [169, 254, 255, 255], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [169, 254, 255, 255], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); // For (172.16/12) - let ip_address = SocketAddress::TcpIpV4{addr: [172, 16, 0, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [172, 16, 0, 0], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [172, 27, 101, 23], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [172, 27, 101, 23], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [172, 31, 255, 255], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [172, 31, 255, 255], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); // For (192.168/16) - let ip_address = SocketAddress::TcpIpV4{addr: [192, 168, 0, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [192, 168, 0, 0], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [192, 168, 205, 159], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [192, 168, 205, 159], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [192, 168, 255, 255], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [192, 168, 255, 255], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); // For (192.88.99/24) - let ip_address = SocketAddress::TcpIpV4{addr: [192, 88, 99, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [192, 88, 99, 0], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [192, 88, 99, 140], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [192, 88, 99, 140], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV4{addr: [192, 88, 99, 255], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [192, 88, 99, 255], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); // For other IPv4 addresses - let ip_address = SocketAddress::TcpIpV4{addr: [188, 255, 99, 0], port: 1000}; + let ip_address = 
SocketAddress::TcpIpV4 { addr: [188, 255, 99, 0], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), Some(ip_address.clone())); - let ip_address = SocketAddress::TcpIpV4{addr: [123, 8, 129, 14], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [123, 8, 129, 14], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), Some(ip_address.clone())); - let ip_address = SocketAddress::TcpIpV4{addr: [2, 88, 9, 255], port: 1000}; + let ip_address = SocketAddress::TcpIpV4 { addr: [2, 88, 9, 255], port: 1000 }; assert_eq!(filter_addresses(Some(ip_address.clone())), Some(ip_address.clone())); // For (2000::/3) - let ip_address = SocketAddress::TcpIpV6{addr: [32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV6 { + addr: [32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + port: 1000, + }; assert_eq!(filter_addresses(Some(ip_address.clone())), Some(ip_address.clone())); - let ip_address = SocketAddress::TcpIpV6{addr: [45, 34, 209, 190, 0, 123, 55, 34, 0, 0, 3, 27, 201, 0, 0, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV6 { + addr: [45, 34, 209, 190, 0, 123, 55, 34, 0, 0, 3, 27, 201, 0, 0, 0], + port: 1000, + }; assert_eq!(filter_addresses(Some(ip_address.clone())), Some(ip_address.clone())); - let ip_address = SocketAddress::TcpIpV6{addr: [63, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255], port: 1000}; + let ip_address = SocketAddress::TcpIpV6 { + addr: [63, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255], + port: 1000, + }; assert_eq!(filter_addresses(Some(ip_address.clone())), Some(ip_address.clone())); // For other IPv6 addresses - let ip_address = SocketAddress::TcpIpV6{addr: [24, 240, 12, 32, 0, 0, 0, 0, 20, 97, 0, 32, 121, 254, 0, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV6 { + addr: [24, 240, 12, 32, 0, 0, 0, 0, 20, 97, 0, 32, 121, 254, 0, 0], + port: 1000, + }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV6{addr: [68, 23, 56, 63, 0, 0, 2, 7, 75, 109, 0, 39, 0, 0, 0, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV6 { + addr: [68, 23, 56, 63, 0, 0, 2, 7, 75, 109, 0, 39, 0, 0, 0, 0], + port: 1000, + }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); - let ip_address = SocketAddress::TcpIpV6{addr: [101, 38, 140, 230, 100, 0, 30, 98, 0, 26, 0, 0, 57, 96, 0, 0], port: 1000}; + let ip_address = SocketAddress::TcpIpV6 { + addr: [101, 38, 140, 230, 100, 0, 30, 98, 0, 26, 0, 0, 57, 96, 0, 0], + port: 1000, + }; assert_eq!(filter_addresses(Some(ip_address.clone())), None); // For (None) @@ -3908,7 +4725,9 @@ mod tests { // sure we observe a value greater than one at least once. let cfg = Arc::new(create_peermgr_cfgs(1)); // Until we have std::thread::scoped we have to unsafe { turn off the borrow checker }. 
- let peer = Arc::new(create_network(1, unsafe { &*(&*cfg as *const _) as &'static _ }).pop().unwrap()); + let peer = Arc::new( + create_network(1, unsafe { &*(&*cfg as *const _) as &'static _ }).pop().unwrap(), + ); let end_time = Instant::now() + Duration::from_millis(100); let observed_loop = Arc::new(AtomicBool::new(false)); @@ -3917,9 +4736,13 @@ mod tests { let thread_observed_loop = Arc::clone(&observed_loop); move || { while Instant::now() < end_time || !thread_observed_loop.load(Ordering::Acquire) { - test_utils::TestChannelMessageHandler::MESSAGE_FETCH_COUNTER.with(|val| val.store(0, Ordering::Relaxed)); + test_utils::TestChannelMessageHandler::MESSAGE_FETCH_COUNTER + .with(|val| val.store(0, Ordering::Relaxed)); thread_peer.process_events(); - if test_utils::TestChannelMessageHandler::MESSAGE_FETCH_COUNTER.with(|val| val.load(Ordering::Relaxed)) > 1 { + if test_utils::TestChannelMessageHandler::MESSAGE_FETCH_COUNTER + .with(|val| val.load(Ordering::Relaxed)) + > 1 + { thread_observed_loop.store(true, Ordering::Release); return; }