diff --git a/payjoin-ffi/src/receive/mod.rs b/payjoin-ffi/src/receive/mod.rs index 47d9833f0..0a5ccac36 100644 --- a/payjoin-ffi/src/receive/mod.rs +++ b/payjoin-ffi/src/receive/mod.rs @@ -344,9 +344,6 @@ impl InitialReceiveTransition { } } -#[derive(Clone, Debug, uniffi::Object)] -pub struct ReceiverBuilder(payjoin::receive::v2::ReceiverBuilder); - /// Primitive representation of a transaction output for the FFI boundary. #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, uniffi::Record)] pub struct TxOut { @@ -458,6 +455,9 @@ impl From for Weight { fn from(value: payjoin::bitcoin::Weight) -> Self { Weight { weight_units: value.to_wu() } } } +#[derive(Clone, Debug, uniffi::Object)] +pub struct ReceiverBuilder(payjoin::receive::v2::ReceiverBuilder); + #[uniffi::export] impl ReceiverBuilder { /// Creates a new [`Initialized`] with the provided parameters. @@ -728,6 +728,23 @@ impl UncheckedOriginalPayload { ))))) } + pub fn extract_tx_to_check_broadcast_suitability(&self) -> Vec { + payjoin::bitcoin::consensus::encode::serialize( + &self.0.clone().extract_tx_to_check_broadcast_suitability(), + ) + } + + pub fn apply_broadcast_suitability( + &self, + min_fee_rate_sat_per_kwu: Option, + can_broadcast: bool, + ) -> Result { + let min_fee_rate = validate_fee_rate_sat_per_kwu_opt(min_fee_rate_sat_per_kwu)?; + Ok(UncheckedOriginalPayloadTransition(Arc::new(RwLock::new(Some( + self.0.clone().apply_broadcast_suitability(min_fee_rate, can_broadcast), + ))))) + } + /// Call this method if the only way to initiate a Payjoin with this receiver /// requires manual intervention, as in most consumer wallets. /// @@ -740,6 +757,34 @@ impl UncheckedOriginalPayload { } } +#[derive(Debug, uniffi::Object)] +pub struct InputOwnedReference( + payjoin::receive::Reference, +); + +#[uniffi::export] +impl InputOwnedReference { + pub fn get_value(&self) -> Vec { self.0.get_value().to_bytes() } + + pub fn mark(&self, result: bool) -> Arc { + Arc::new(InputOwnedTaggedReference(self.0.mark(result))) + } +} + +#[derive(Debug, uniffi::Object)] +pub struct InputOwnedTaggedReference( + payjoin::receive::TaggedReference, +); + +#[uniffi::export] +impl InputOwnedTaggedReference { + pub fn get_value(&self) -> Vec { self.0.get_value().to_bytes() } + + pub fn get_result(&self) -> bool { self.0.get_result() } + + pub fn get_index(&self) -> u64 { self.0.get_index() as u64 } +} + #[derive(Clone, uniffi::Object)] pub struct MaybeInputsOwned(payjoin::receive::v2::Receiver); @@ -783,6 +828,7 @@ impl MaybeInputsOwned { &self.0.clone().extract_tx_to_schedule_broadcast(), ) } + pub fn check_inputs_not_owned( &self, is_owned: Arc, @@ -793,6 +839,58 @@ impl MaybeInputsOwned { }), )))) } + + pub fn get_input_script_refs(&self) -> Result>, ReceiverError> { + self.0 + .clone() + .get_input_script_refs() + .map(|iter| { + iter.map(|input_script_ref| Arc::new(InputOwnedReference(input_script_ref))) + .collect::>() + }) + .map_err(ReceiverError::from) + } + + pub fn apply_input_owned_checks( + &self, + checked_input_scripts: Vec>, + ) -> MaybeInputsOwnedTransition { + MaybeInputsOwnedTransition(Arc::new(RwLock::new(Some( + self.0.clone().apply_input_owned_checks(checked_input_scripts.into_iter().map(|r| { + Arc::try_unwrap(r) + .expect("InputOwnedTaggedReference Arc should have a single owner") + .0 + })), + )))) + } +} + +#[derive(Debug, uniffi::Object)] +pub struct InputSeenReference( + payjoin::receive::Reference, +); + +#[uniffi::export] +impl InputSeenReference { + pub fn get_value(&self) -> OutPoint { self.0.get_value().into() } + + pub fn mark(&self, result: bool) -> Arc { + Arc::new(InputSeenTaggedReference(self.0.mark(result))) + } +} + +#[derive(Debug, uniffi::Object)] +pub struct InputSeenTaggedReference( + payjoin::receive::TaggedReference, +); + +#[uniffi::export] +impl InputSeenTaggedReference { + pub fn get_value(&self) -> OutPoint { self.0.get_value().into() } + + pub fn get_result(&self) -> bool { self.0.get_result() } + + pub fn get_index(&self) -> u64 { self.0.get_index() as u64 } } #[derive(Clone, uniffi::Object)] @@ -844,6 +942,58 @@ impl MaybeInputsSeen { }), )))) } + + pub fn get_input_outpoint_refs(&self) -> Vec> { + self.0 + .clone() + .get_input_outpoint_refs() + .map(|input_outpoint_ref| Arc::new(InputSeenReference(input_outpoint_ref))) + .collect::>() + } + + pub fn apply_input_seen_checks( + &self, + checked_input_outpoints: Vec>, + ) -> MaybeInputsSeenTransition { + MaybeInputsSeenTransition(Arc::new(RwLock::new(Some( + self.0.clone().apply_input_seen_checks(checked_input_outpoints.into_iter().map(|r| { + Arc::try_unwrap(r) + .expect("InputSeenTaggedReference Arc should have a single owner") + .0 + })), + )))) + } +} + +#[derive(Debug, uniffi::Object)] +pub struct OutputOwnedReference( + payjoin::receive::Reference, +); + +#[uniffi::export] +impl OutputOwnedReference { + pub fn get_value(&self) -> Vec { self.0.get_value().to_bytes() } + + pub fn mark(&self, result: bool) -> Arc { + Arc::new(OutputOwnedTaggedReference(self.0.mark(result))) + } +} + +#[derive(Debug, uniffi::Object)] +pub struct OutputOwnedTaggedReference( + payjoin::receive::TaggedReference< + payjoin::bitcoin::ScriptBuf, + payjoin::receive::OutputOwnedTag, + >, +); + +#[uniffi::export] +impl OutputOwnedTaggedReference { + pub fn get_value(&self) -> Vec { self.0.get_value().to_bytes() } + + pub fn get_result(&self) -> bool { self.0.get_result() } + + pub fn get_index(&self) -> u64 { self.0.get_index() as u64 } } /// The receiver has not yet identified which outputs belong to the receiver. @@ -893,6 +1043,27 @@ impl OutputsUnknown { }), )))) } + + pub fn get_output_script_refs(&self) -> Vec> { + self.0 + .clone() + .get_output_script_refs() + .map(|output_script_ref| Arc::new(OutputOwnedReference(output_script_ref))) + .collect::>() + } + + pub fn apply_output_owned_checks( + &self, + checked_output_scripts: Vec>, + ) -> OutputsUnknownTransition { + OutputsUnknownTransition(Arc::new(RwLock::new(Some( + self.0.clone().apply_output_owned_checks(checked_output_scripts.into_iter().map(|r| { + Arc::try_unwrap(r) + .expect("OutputOwnedTaggedReference Arc should have a single owner") + .0 + })), + )))) + } } #[derive(uniffi::Object)] @@ -1177,6 +1348,14 @@ impl ProvisionalProposal { } pub fn psbt_to_sign(&self) -> String { self.0.clone().psbt_to_sign().to_string() } + + pub fn finalize_signed_proposal(&self, signed_psbt: String) -> ProvisionalProposalTransition { + ProvisionalProposalTransition(Arc::new(RwLock::new(Some( + self.0.clone().finalize_proposal(|_| { + Ok(Psbt::from_str(&signed_psbt).map_err(ImplementationError::new)?) + }), + )))) + } } #[derive(Clone, uniffi::Object)] @@ -1433,6 +1612,29 @@ impl Monitor { .map_err(|e| ImplementationError::new(e).into()) }))))) } + pub fn extract_fallback_txid(&self) -> String { + self.0.clone().extract_fallback_txid().to_string() + } + + pub fn extract_payjoin_proposal_txid(&self) -> String { + self.0.clone().extract_payjoin_proposal_txid().to_string() + } + + pub fn check_fallback_monitorable(&self) -> MonitorTransition { + MonitorTransition(Arc::new(RwLock::new(Some(self.0.clone().check_fallback_monitorable())))) + } + + pub fn fallback_tx_exists(&self) -> MonitorTransition { + MonitorTransition(Arc::new(RwLock::new(Some(self.0.clone().fallback_tx_exists())))) + } + + pub fn payjoin_tx_exists( + &self, + payjoin_tx: Vec, + ) -> Result { + let tx = try_deserialize_tx(payjoin_tx)?; + Ok(MonitorTransition(Arc::new(RwLock::new(Some(self.0.clone().payjoin_tx_exists(tx)))))) + } } /// Session persister that should save and load events as JSON strings. diff --git a/payjoin/src/core/receive/common/mod.rs b/payjoin/src/core/receive/common/mod.rs index 9a93fb8f0..635c96159 100644 --- a/payjoin/src/core/receive/common/mod.rs +++ b/payjoin/src/core/receive/common/mod.rs @@ -863,7 +863,7 @@ mod tests { .calculate_psbt_context_with_fee_range(None, None) .expect("Contributed inputs should allow for valid fee contributions"); let payjoin_proposal = - psbt_context.finalize_proposal(|_| Ok(processed_psbt.clone())).expect("Valid psbt"); + psbt_context.finalize_signed_proposal(processed_psbt.clone()).expect("Valid psbt"); assert!(payjoin_proposal.xpub.is_empty()); diff --git a/payjoin/src/core/receive/error.rs b/payjoin/src/core/receive/error.rs index d7bd23e39..bed31cec1 100644 --- a/payjoin/src/core/receive/error.rs +++ b/payjoin/src/core/receive/error.rs @@ -3,6 +3,7 @@ use std::{error, fmt}; use crate::error_codes::ErrorCode::{ self, NotEnoughMoney, OriginalPsbtRejected, Unavailable, VersionUnsupported, }; +use crate::ImplementationError; /// The top-level error type for the payjoin receiver #[derive(Debug)] @@ -29,6 +30,10 @@ impl From for Error { fn from(e: ProtocolError) -> Self { Error::Protocol(e) } } +impl From for Error { + fn from(e: ImplementationError) -> Self { Error::Implementation(e) } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { diff --git a/payjoin/src/core/receive/mod.rs b/payjoin/src/core/receive/mod.rs index f102d743a..5e9e97604 100644 --- a/payjoin/src/core/receive/mod.rs +++ b/payjoin/src/core/receive/mod.rs @@ -10,6 +10,7 @@ //! version 1, refer to the `receive::v1` module documentation after enabling the `v1` feature. use std::collections::BTreeMap; +use std::marker::PhantomData; use std::str::FromStr; use bitcoin::transaction::InputWeightPrediction; @@ -228,6 +229,142 @@ impl<'a> From<&'a InputPair> for InternalInputPair<'a> { fn from(pair: &'a InputPair) -> Self { Self { psbtin: &pair.psbtin, txin: &pair.txin } } } +/// Holds a value that requires some form of boolean check. +#[derive(Debug)] +pub struct Reference { + value: V, + index: usize, + /// The final index in the set of [`Reference`]s to be checked + final_index: usize, + _tag: PhantomData, +} + +impl Reference +where + V: Clone, + T: Tag, +{ + fn new(value: V, index: usize, final_index: usize) -> Self { + Reference { value, index, final_index, _tag: PhantomData } + } + + /// Returns a [`TaggedReference`] that has been marked with the result of the boolean + /// check. + pub fn mark(&self, result: bool) -> TaggedReference { + TaggedReference { + value: self.value.clone(), + index: self.index, + final_index: self.final_index, + tag: T::new(result), + } + } + + /// Extracts the value to to be checked + pub fn get_value(&self) -> V { self.value.clone() } +} + +/// Holds the result of a checked [`Reference`]. Can only be created with [`Reference::mark`]. +#[derive(Debug)] +pub struct TaggedReference { + value: V, + index: usize, + final_index: usize, + tag: T, +} + +impl TaggedReference +where + V: Clone, + T: Tag, +{ + pub fn get_result(&self) -> bool { self.tag.result() } + pub fn get_index(&self) -> usize { self.index } + pub fn get_value(&self) -> V { self.value.clone() } +} + +/// Trait used to distinguish different types of validation +pub trait Tag { + fn new(result: bool) -> Self; + fn result(&self) -> bool; +} + +#[derive(Debug)] +pub struct InputOwnedTag { + is_owned: bool, +} + +impl Tag for InputOwnedTag { + fn new(result: bool) -> InputOwnedTag { InputOwnedTag { is_owned: result } } + fn result(&self) -> bool { self.is_owned } +} + +#[derive(Debug)] +pub struct InputSeenTag { + is_seen: bool, +} + +impl Tag for InputSeenTag { + fn new(result: bool) -> InputSeenTag { InputSeenTag { is_seen: result } } + fn result(&self) -> bool { self.is_seen } +} + +#[derive(Debug)] +pub struct OutputOwnedTag { + is_owned: bool, +} + +impl Tag for OutputOwnedTag { + fn new(result: bool) -> OutputOwnedTag { OutputOwnedTag { is_owned: result } } + fn result(&self) -> bool { self.is_owned } +} + +/// Helper function to run validation callback over a list of [`Reference`]s +pub fn check_references( + references: impl Iterator>, + check: &mut impl FnMut(&V) -> Result, +) -> Result>, ImplementationError> +where + V: Clone, + T: Tag, +{ + let mut checked_references: Vec> = vec![]; + for reference in references { + let result = check(&reference.get_value())?; + checked_references.push(reference.mark(result)); + } + Ok(checked_references.into_iter()) +} + +/// Validate that the [`TaggedReference`]s are in the correct order and are a complete set. +fn validate_checks( + checked_references: impl IntoIterator>, +) -> Result>, ImplementationError> +where + V: Clone, + T: Tag, +{ + let mut current_index = 0; + let mut is_complete = false; + let mut validated_refs: Vec> = vec![]; + for reference in checked_references { + if reference.get_index() != current_index { + return Err(ImplementationError::from( + "Missing reference check at index {current_index}", + )); + } + if reference.get_index() == reference.final_index { + is_complete = true + } else { + current_index += 1; + } + validated_refs.push(reference); + } + if !is_complete { + return Err(ImplementationError::from("Missing reference check at index {current_index}")); + } + Ok(validated_refs.into_iter()) +} + /// Validate the payload of a Payjoin request for PSBT and Params sanity pub(crate) fn parse_payload( base64: &str, @@ -254,7 +391,7 @@ pub struct PsbtContext { impl PsbtContext { /// Prepare the PSBT by creating a new PSBT and copying only the fields allowed by the [spec](https://github.com/bitcoin/bips/blob/master/bip-0078.mediawiki#senders-payjoin-proposal-checklist) - fn prepare_psbt(self, processed_psbt: Psbt) -> Psbt { + fn prepare_psbt(&self, processed_psbt: Psbt) -> Psbt { tracing::trace!("Original PSBT from callback: {processed_psbt:#?}"); // Create a new PSBT and copy only the allowed fields @@ -329,18 +466,11 @@ impl PsbtContext { psbt } - /// Finalizes the Payjoin proposal into a PSBT which the sender will find acceptable before - /// they sign the transaction and broadcast it to the network. + /// Finalizes the signed proposal PSBT /// - /// Finalization consists of two steps: - /// 1. Remove all sender signatures which were received with the original PSBT as these signatures are now invalid. - /// 2. Sign and finalize the resulting PSBT using the passed `wallet_process_psbt` signing function. - fn finalize_proposal( - self, - wallet_process_psbt: impl Fn(&Psbt) -> Result, - ) -> Result { - let psbt = self.psbt_to_sign(); - let signed_psbt = wallet_process_psbt(&psbt)?; + /// Verifies that signed PSBT is for the same transaction as the payjoin proposal PSBT and prepares + /// the signed PSBT to be sent to the sender + fn finalize_signed_proposal(&self, signed_psbt: Psbt) -> Result { let expected_ntxid = self.payjoin_psbt.unsigned_tx.compute_ntxid(); let actual_ntxid = signed_psbt.unsigned_tx.compute_ntxid(); if expected_ntxid != actual_ntxid { @@ -372,6 +502,17 @@ impl OriginalPayload { &self, min_fee_rate: Option, can_broadcast: impl Fn(&bitcoin::Transaction) -> Result, + ) -> Result<(), Error> { + self.apply_broadcast_suitability( + min_fee_rate, + can_broadcast(&self.psbt.clone().extract_tx_unchecked_fee_rate())?, + ) + } + + pub fn apply_broadcast_suitability( + &self, + min_fee_rate: Option, + can_broadcast: bool, ) -> Result<(), Error> { let original_psbt_fee_rate = self.psbt_fee_rate()?; if let Some(min_fee_rate) = min_fee_rate { @@ -383,80 +524,144 @@ impl OriginalPayload { .into()); } } - if can_broadcast(&self.psbt.clone().extract_tx_unchecked_fee_rate()) - .map_err(Error::Implementation)? - { + if can_broadcast { Ok(()) } else { Err(InternalPayloadError::OriginalPsbtNotBroadcastable.into()) } } - /// Check that the original PSBT has no receiver-owned inputs. + /// Check that the original PSBT has no receiver owned inputs. /// /// An attacker can try to spend the receiver's own inputs. This check prevents that. pub fn check_inputs_not_owned( &self, is_owned: &mut impl FnMut(&Script) -> Result, ) -> Result<(), Error> { - let mut err: Result<(), Error> = Ok(()); - if let Some(e) = self + let checked_inputs = + check_references(self.get_input_script_refs()?, &mut |script: &ScriptBuf| { + is_owned(script.as_script()) + })?; + self.apply_input_owned_checks(checked_inputs) + } + + pub fn get_input_script_refs( + &self, + ) -> Result>, Error> { + let final_index = self.psbt.input_pairs().count() - 1; + let script_references = self .psbt .input_pairs() - .scan(&mut err, |err, input| match input.previous_txout() { - Ok(txout) => Some(txout.script_pubkey.to_owned()), - Err(e) => { - **err = Err(InternalPayloadError::PrevTxOut(e).into()); - None - } - }) - .find_map(|script| match is_owned(&script) { - Ok(false) => None, - Ok(true) => Some(InternalPayloadError::InputOwned(script).into()), - Err(e) => Some(Error::Implementation(e)), + .enumerate() + .map(|(index, input)| match input.previous_txout() { + Ok(txout) => Ok(Reference::::new( + txout.script_pubkey.to_owned(), + index, + final_index, + )), + Err(e) => Err(InternalPayloadError::PrevTxOut(e)), }) - { - return Err(e); + .collect::>, InternalPayloadError>>()?; + Ok(script_references.into_iter()) + } + + pub fn apply_input_owned_checks( + &self, + checked_input_scripts: impl IntoIterator>, + ) -> Result<(), Error> { + let validated_checks = validate_checks(checked_input_scripts)?; + match validated_checks.into_iter().find(|checked_input| checked_input.get_result()) { + Some(checked_input) => + Err(InternalPayloadError::InputOwned(checked_input.get_value()).into()), + None => Ok(()), } - err?; - Ok(()) } pub fn check_no_inputs_seen_before( &self, is_known: &mut impl FnMut(&OutPoint) -> Result, ) -> Result<(), Error> { - self.psbt.input_pairs().try_for_each(|input| { - match is_known(&input.txin.previous_output) { - Ok(false) => Ok::<(), Error>(()), - Ok(true) => { - tracing::warn!("Request contains an input we've seen before: {}. Preventing possible probing attack.", input.txin.previous_output); - Err(InternalPayloadError::InputSeen(input.txin.previous_output))? - }, - Err(e) => Err(Error::Implementation(e))?, + let checked_inputs = check_references(self.get_input_outpoint_refs(), is_known)?; + self.apply_input_seen_checks(checked_inputs) + } + + pub fn get_input_outpoint_refs( + &self, + ) -> impl Iterator> { + let final_index = self.psbt.input_pairs().count() - 1; + let outpoint_references = self + .psbt + .input_pairs() + .enumerate() + .map(|(index, input)| { + Reference::::new( + input.txin.previous_output, + index, + final_index, + ) + }) + .collect::>(); + outpoint_references.into_iter() + } + + pub fn apply_input_seen_checks( + &self, + checked_input_outpoints: impl IntoIterator>, + ) -> Result<(), Error> { + let validated_checks = validate_checks(checked_input_outpoints)?; + match validated_checks.into_iter().find(|checked_input| checked_input.get_result()) { + Some(checked_input) => { + tracing::warn!("Request contains an input we've seen before: {}. Preventing possible probing attack.", checked_input.get_value()); + Err(InternalPayloadError::InputSeen(checked_input.get_value()))? } - })?; - Ok(()) + None => Ok(()), + } } pub fn identify_receiver_outputs( self, is_receiver_output: &mut impl FnMut(&Script) -> Result, ) -> Result { - let owned_vouts: Vec = self + let checked_outputs = + check_references(self.get_output_script_refs(), &mut |script: &ScriptBuf| { + is_receiver_output(script.as_script()) + })?; + self.apply_output_owned_checks(checked_outputs) + } + + pub fn get_output_script_refs( + &self, + ) -> impl Iterator> { + let final_index = self.psbt.unsigned_tx.output.len() - 1; + let script_references = self .psbt .unsigned_tx .output .iter() .enumerate() - .filter_map(|(vout, txo)| match is_receiver_output(&txo.script_pubkey) { - Ok(true) => Some(Ok(vout)), - Ok(false) => None, - Err(e) => Some(Err(e)), + .map(|(index, output)| { + Reference::::new( + output.script_pubkey.clone(), + index, + final_index, + ) }) - .collect::, _>>() - .map_err(Error::Implementation)?; + .collect::>(); + script_references.into_iter() + } + pub fn apply_output_owned_checks( + &self, + checked_output_scripts: impl IntoIterator>, + ) -> Result { + let validated_checks = validate_checks(checked_output_scripts)?; + let owned_vouts = validated_checks + .into_iter() + .filter_map(|checked_output| match checked_output.get_result() { + true => Some(checked_output.get_index()), + false => None, + }) + .collect::>(); if owned_vouts.is_empty() { return Err(InternalPayloadError::MissingPayment.into()); } @@ -486,7 +691,9 @@ pub(crate) mod tests { witness, Amount, PubkeyHash, ScriptBuf, ScriptHash, Sequence, Txid, WScriptHash, XOnlyPublicKey, }; - use payjoin_test_utils::{DUMMY20, DUMMY32, PARSED_ORIGINAL_PSBT, QUERY_PARAMS}; + use payjoin_test_utils::{ + DUMMY20, DUMMY32, PARSED_ORIGINAL_PSBT, PARSED_PAYJOIN_PROPOSAL, QUERY_PARAMS, + }; use super::*; use crate::psbt::InternalPsbtInputError::InvalidScriptPubKey; @@ -498,6 +705,24 @@ pub(crate) mod tests { OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params } } + pub(crate) fn original_missing_prevtxout_from_test_vector() -> OriginalPayload { + let params = Params::from_query_str(QUERY_PARAMS, &[Version::One]) + .expect("Could not parse params from query str"); + let mut psbt: Psbt = PARSED_ORIGINAL_PSBT.clone(); + for psbtin in psbt.inputs_mut() { + psbtin.non_witness_utxo = None; + psbtin.witness_utxo = None; + } + OriginalPayload { psbt: psbt.clone(), params } + } + + pub(crate) fn psbt_context_from_test_vector() -> PsbtContext { + PsbtContext { + payjoin_psbt: PARSED_PAYJOIN_PROPOSAL.clone(), + original_psbt: PARSED_ORIGINAL_PSBT.clone(), + } + } + #[test] fn input_pair_with_expected_weight() { let p2wsh_txout = TxOut { @@ -832,6 +1057,141 @@ pub(crate) mod tests { assert_eq!(err, PsbtInputError::from(InternalPsbtInputError::ProvidedUnnecessaryWeight)); } + #[test] + fn test_check_broadcast_suitability() { + let original = original_from_test_vector(); + + // Outcome 1: min_fee_rate too high → PsbtBelowFeeRate error + let err = original + .clone() + .check_broadcast_suitability(Some(FeeRate::MAX), |_| Ok(true)) + .expect_err("Should fail when fee rate is below minimum"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::PsbtBelowFeeRate(original_fee_rate, min_fee_rate), + ))) => { + assert_eq!(original_fee_rate, original.psbt_fee_rate().unwrap()); + assert_eq!(min_fee_rate, FeeRate::MAX); + } + _ => panic!("Expected PsbtBelowFeeRate error, got: {err:?}"), + } + + // Outcome 2: can_broadcast returns false → OriginalPsbtNotBroadcastable error + let err = original + .clone() + .check_broadcast_suitability(None, |_| Ok(false)) + .expect_err("Should fail when can_broadcast returns false"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::OriginalPsbtNotBroadcastable, + ))) => {} + _ => panic!("Expected OriginalPsbtNotBroadcastable error, got: {err:?}"), + } + + // Outcome 3: can_broadcast returns an implementation error → Error::Implementation + let err = original + .clone() + .check_broadcast_suitability(None, |_| { + Err(ImplementationError::from("broadcast check failed")) + }) + .expect_err("Should fail when can_broadcast returns an implementation error"); + match err { + Error::Implementation(error_message) => { + assert_eq!(error_message.to_string(), "broadcast check failed".to_string()) + } + _ => panic!("Expected Error::Implementation, got: {err:?}"), + } + + // Outcome 4: success + original + .check_broadcast_suitability(None, |_| Ok(true)) + .expect("Should succeed when fee rate is acceptable and can_broadcast returns true"); + } + + #[test] + fn test_check_inputs_not_owned() { + let original = original_from_test_vector(); + let original_missing_prevtxout = original_missing_prevtxout_from_test_vector(); + + // Outcome 1: input_scripts returns a PrevTxOut error → Protocol error + let err = original_missing_prevtxout + .check_inputs_not_owned(&mut |_| Ok(false)) + .expect_err("Should fail when previous txout is missing"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::PrevTxOut(_), + ))) => {} + _ => panic!("Expected PrevTxOut error, got: {err:?}"), + } + + // Outcome 2: is_owned returns true → InputOwned error + let err = original + .clone() + .check_inputs_not_owned(&mut |_| Ok(true)) + .expect_err("Should fail when input is owned"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::InputOwned(_), + ))) => {} + _ => panic!("Expected InputOwned error, got: {err:?}"), + } + + // Outcome 3: is_owned returns an implementation error → Error::Implementation + let err = original + .clone() + .check_inputs_not_owned(&mut |_| { + Err(ImplementationError::from("ownership check failed")) + }) + .expect_err("Should fail when is_owned returns an implementation error"); + match err { + Error::Implementation(error_message) => { + assert_eq!(error_message.to_string(), "ownership check failed".to_string()) + } + _ => panic!("Expected Error::Implementation, got: {err:?}"), + } + + // Outcome 4: is_owned returns false → success + original + .check_inputs_not_owned(&mut |_| Ok(false)) + .expect("Should succeed when no inputs are owned"); + } + + #[test] + fn test_check_no_inputs_seen_before() { + let original = original_from_test_vector(); + + // Outcome 1: is_known returns true → InputSeen error + let err = original + .clone() + .check_no_inputs_seen_before(&mut |_| Ok(true)) + .expect_err("Should fail when input has been seen before"); + match err { + Error::Protocol(ProtocolError::OriginalPayload(PayloadError( + InternalPayloadError::InputSeen(_), + ))) => {} + _ => panic!("Expected InputSeen error, got: {err:?}"), + } + + // Outcome 2: is_known returns an implementation error → Error::Implementation + let err = original + .clone() + .check_no_inputs_seen_before(&mut |_| { + Err(ImplementationError::from("input seen check failed")) + }) + .expect_err("Should fail when is_known returns an implementation error"); + match err { + Error::Implementation(error_message) => { + assert_eq!(error_message.to_string(), "input seen check failed".to_string()) + } + _ => panic!("Expected Error::Implementation, got: {err:?}"), + } + + // Outcome 3: is_known returns false → success + original + .check_no_inputs_seen_before(&mut |_| Ok(false)) + .expect("Should succeed when no inputs have been seen before"); + } + #[test] fn test_identify_receiver_outputs() { let original = original_from_test_vector(); @@ -866,4 +1226,23 @@ pub(crate) mod tests { assert_eq!(wants_outputs.owned_vouts, vec![0, 1]); assert_eq!(wants_outputs.params.additional_fee_contribution, None); } + + #[test] + fn test_finalize_proposal() { + // Outcome 1: wallet_process_psbt returns a psbt with mismatched ntxid → ImplementationError + let psbt_context = psbt_context_from_test_vector(); + let err = psbt_context + .clone() + .finalize_signed_proposal( + // return a totally different psbt to trigger ntxid mismatch + PARSED_ORIGINAL_PSBT.clone(), + ) + .expect_err("Should fail when ntxid mismatches"); + assert!(err.to_string().contains("Ntxid mismatch")); + + // Outcome 2: wallet_process_psbt succeeds → Ok(Psbt) + let _psbt = psbt_context + .finalize_signed_proposal(PARSED_PAYJOIN_PROPOSAL.clone()) + .expect("Should succeed when wallet_process_psbt returns a valid signed psbt"); + } } diff --git a/payjoin/src/core/receive/v1/mod.rs b/payjoin/src/core/receive/v1/mod.rs index 8c659564e..61f739bc8 100644 --- a/payjoin/src/core/receive/v1/mod.rs +++ b/payjoin/src/core/receive/v1/mod.rs @@ -113,6 +113,42 @@ impl UncheckedOriginalPayload { Ok(MaybeInputsOwned { original: self.original }) } + /// Extracts the original PSBT so caller can check that the proposal can be broadcasted. + /// + /// Result of the broadcastibility check should then be returned to + /// [`Self::process_broadcast_suitability_result`]. + /// + /// If the receiver is a non-interactive payment processor (ex. a donation page which generates + /// a new QR code for each visit), then it should make sure that the original PSBT is broadcastable + /// as a fallback mechanism in case the payjoin fails. This validation would be equivalent to + /// `testmempoolaccept` Bitcoin Core RPC call returning `{"allowed": true,...}`. + pub fn extract_tx_to_check_broadcast_suitability(&self) -> bitcoin::Transaction { + self.original.psbt.clone().extract_tx_unchecked_fee_rate() + } + + /// Processes the result of whether the original PSBT in the proposal can be broadcasted. + /// + /// Call [`Self::extract_tx_to_check_broadcast_suitability`] first to acquire the tx + /// to be checked for broadcastibility. + /// + /// If the receiver is a non-interactive payment processor (ex. a donation page which generates + /// a new QR code for each visit), then it should make sure that the original PSBT is broadcastable + /// as a fallback mechanism in case the payjoin fails. This validation would be equivalent to + /// `testmempoolaccept` Bitcoin Core RPC call returning `{"allowed": true,...}`. + /// + /// Receiver can optionally set a minimum fee rate which will be enforced on the original PSBT in the proposal. + /// This can be used to further prevent probing attacks since the attacker would now need to probe the receiver + /// with transactions which are both broadcastable and pay high fee. Unrelated to the probing attack scenario, + /// this parameter also makes operating in a high fee environment easier for the receiver. + pub fn process_broadcast_suitability_result( + self, + min_fee_rate: Option, + is_broadcast_suitable: bool, + ) -> Result { + self.original.apply_broadcast_suitability(min_fee_rate, is_broadcast_suitable)?; + Ok(MaybeInputsOwned { original: self.original }) + } + /// Moves on to the next typestate without any of the current typestate's validations. /// /// Use this for interactive payment receivers, where there is no risk of a probing attack since the @@ -151,7 +187,35 @@ impl MaybeInputsOwned { self, is_owned: &mut impl FnMut(&Script) -> Result, ) -> Result { - self.original.check_inputs_not_owned(is_owned)?; + let checked_inputs = + check_references(self.get_input_script_refs()?, &mut |script: &ScriptBuf| { + is_owned(script.as_script()) + })?; + self.apply_input_owned_checks(checked_inputs) + } + + /// Get [`Reference`]s that hold the input scripts that need to be checked for ownership by the + /// receiver. + /// + /// Once completed, these checks should be submitted to [`Self::apply_input_owned_checks`]. + /// + /// An attacker can try to spend the receiver's own inputs. This check prevents that. + pub fn get_input_script_refs( + &self, + ) -> Result>, Error> { + self.original.get_input_script_refs() + } + + /// Applies the input ownership checks to advance the state machine. + /// + /// Use [`Self::get_input_script_refs`] to obtain the references that need to be checked. + /// + /// An attacker can try to spend the receiver's own inputs. This check prevents that. + pub fn apply_input_owned_checks( + self, + checked_input_scripts: impl IntoIterator>, + ) -> Result { + self.original.apply_input_owned_checks(checked_input_scripts)?; Ok(MaybeInputsSeen { original: self.original }) } } @@ -176,7 +240,42 @@ impl MaybeInputsSeen { self, is_known: &mut impl FnMut(&OutPoint) -> Result, ) -> Result { - self.original.check_no_inputs_seen_before(is_known)?; + let checked_inputs = check_references(self.get_input_outpoint_refs(), is_known)?; + self.apply_input_seen_checks(checked_inputs) + } + + /// Get [`Reference`]s that hold the input outpoints that need to be checked for whether they + /// have already been seen by the receiver. + /// + /// Once completed, these checks should be submitted to [`Self::apply_input_seen_checks`]. + /// + /// This check prevents the following attacks: + /// 1. Probing attacks, where the sender can use the exact same proposal (or with minimal change) + /// to have the receiver reveal their UTXO set by contributing to all proposals with different inputs + /// and sending them back to the receiver. + /// 2. Re-entrant payjoin, where the sender uses the payjoin PSBT of a previous payjoin as the + /// original proposal PSBT of the current, new payjoin. + pub fn get_input_outpoint_refs( + &self, + ) -> impl Iterator> { + self.original.get_input_outpoint_refs() + } + + /// Applies the input seen checks to advance the state machine. + /// + /// Use [`Self::get_input_outpoint_refs`] to obtain the references that need to be checked. + /// + /// This check prevents the following attacks: + /// 1. Probing attacks, where the sender can use the exact same proposal (or with minimal change) + /// to have the receiver reveal their UTXO set by contributing to all proposals with different inputs + /// and sending them back to the receiver. + /// 2. Re-entrant payjoin, where the sender uses the payjoin PSBT of a previous payjoin as the + /// original proposal PSBT of the current, new payjoin. + pub fn apply_input_seen_checks( + self, + checked_input_outpoints: impl IntoIterator>, + ) -> Result { + self.original.apply_input_seen_checks(checked_input_outpoints)?; Ok(OutputsUnknown { original: self.original }) } } @@ -208,7 +307,49 @@ impl OutputsUnknown { self, is_receiver_output: &mut impl FnMut(&Script) -> Result, ) -> Result { - self.original.identify_receiver_outputs(is_receiver_output) + let checked_outputs = + check_references(self.get_output_script_refs(), &mut |script: &ScriptBuf| { + is_receiver_output(script.as_script()) + })?; + self.apply_output_owned_checks(checked_outputs) + } + + /// Get [`Reference`]s that hold the output scripts that need to be checked for ownership + /// by the receiver. + /// + /// Once completed, these checks should be submitted to [`Self::apply_output_owned_checks`]. + /// + /// Additionally, this function also protects the receiver from accidentally subtracting fees + /// from their own outputs: when a sender is sending a proposal, + /// they can select an output which they want the receiver to subtract fees from to account for + /// the increased transaction size. If a sender specifies a receiver output for this purpose, this + /// function sets that parameter to None so that it is ignored in subsequent steps of the + /// receiver flow. This protects the receiver from accidentally subtracting fees from their own + /// outputs. + #[cfg_attr(not(feature = "v1"), allow(dead_code))] + pub fn get_output_script_refs( + &self, + ) -> impl Iterator> { + self.original.get_output_script_refs() + } + + /// Applies the output owned checks to advance the state machine. + /// + /// Use [`Self::get_output_script_refs`] to obtain the references that need to be checked. + /// + /// Additionally, this function also protects the receiver from accidentally subtracting fees + /// from their own outputs: when a sender is sending a proposal, + /// they can select an output which they want the receiver to subtract fees from to account for + /// the increased transaction size. If a sender specifies a receiver output for this purpose, this + /// function sets that parameter to None so that it is ignored in subsequent steps of the + /// receiver flow. This protects the receiver from accidentally subtracting fees from their own + /// outputs. + #[cfg_attr(not(feature = "v1"), allow(dead_code))] + pub fn apply_output_owned_checks( + &self, + checked_output_scripts: impl IntoIterator>, + ) -> Result { + self.original.apply_output_owned_checks(checked_output_scripts) } } @@ -292,11 +433,9 @@ impl ProvisionalProposal { self, wallet_process_psbt: impl Fn(&Psbt) -> Result, ) -> Result { - let finalized_psbt = self - .psbt_context - .finalize_proposal(wallet_process_psbt) - .map_err(|e| Error::Implementation(ImplementationError::new(e)))?; - Ok(PayjoinProposal { payjoin_psbt: finalized_psbt }) + let psbt = self.psbt_to_sign(); + let signed_psbt = wallet_process_psbt(&psbt)?; + self.finalize_signed_proposal(&signed_psbt) } /// The Payjoin proposal PSBT that the receiver needs to sign @@ -305,6 +444,17 @@ impl ProvisionalProposal { /// is different from the entity that has access to the private keys, /// so the PSBT to sign must be accessible to such implementers. pub fn psbt_to_sign(&self) -> Psbt { self.psbt_context.psbt_to_sign() } + + /// Finalizes the Payjoin proposal into a PSBT which the sender will find acceptable before + /// they sign the transaction and broadcast it to the network. + /// + /// This takes a receiver signed PSBT payjoin proposal and finalizes it for broadcast to + /// the sender. Use [`Self::psbt_to_sign`] to obtain the payjoin proposal's unsigned + /// PSBT for receiver to sign and return here. + pub fn finalize_signed_proposal(self, signed_psbt: &Psbt) -> Result { + let finalized_psbt = self.psbt_context.finalize_signed_proposal(signed_psbt.clone())?; + Ok(PayjoinProposal { payjoin_psbt: finalized_psbt }) + } } /// A finalized Payjoin proposal, complete with fees and receiver signatures, that the sender diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index 7ff08f0ee..e123d4c86 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -30,7 +30,7 @@ use std::time::Duration; use bitcoin::hashes::{sha256, Hash}; use bitcoin::psbt::Psbt; -use bitcoin::{Address, Amount, FeeRate, OutPoint, Script, TxOut, Txid}; +use bitcoin::{Address, Amount, FeeRate, OutPoint, Script, ScriptBuf, Transaction, TxOut, Txid}; pub(crate) use error::InternalSessionError; pub use error::SessionError; use serde::de::Deserializer; @@ -57,7 +57,10 @@ use crate::persist::{ MaybeFatalOrSuccessTransition, MaybeFatalTransition, MaybeFatalTransitionWithNoResults, MaybeSuccessTransition, MaybeTransientTransition, NextStateTransition, TerminalTransition, }; -use crate::receive::{parse_payload, InputPair, OriginalPayload, PsbtContext}; +use crate::receive::{ + check_references, parse_payload, InputOwnedTag, InputPair, InputSeenTag, OriginalPayload, + OutputOwnedTag, PsbtContext, Reference, TaggedReference, +}; use crate::time::Time; use crate::uri::ShortId; use crate::{ImplementationError, IntoUrl, IntoUrlError, Request, Version}; @@ -640,7 +643,68 @@ impl Receiver { Error, Receiver, > { - match self.state.original.check_broadcast_suitability(min_fee_rate, can_broadcast) { + let tx = self.extract_tx_to_check_broadcast_suitability(); + match can_broadcast(&tx) { + Ok(is_broadcast_suitable) => + self.apply_broadcast_suitability(min_fee_rate, is_broadcast_suitable), + Err(e) => MaybeFatalTransition::transient(e.into()), + } + } + + /// Moves on to the next typestate without any of the current typestate's validations. + /// + /// Use this for interactive payment receivers, where there is no risk of a probing attack since the + /// receiver needs to manually create payjoin URIs. + pub fn assume_interactive_receiver( + self, + ) -> NextStateTransition> { + NextStateTransition::success( + SessionEvent::CheckedBroadcastSuitability(), + Receiver { + state: MaybeInputsOwned { original: self.original.clone() }, + session_context: self.session_context, + }, + ) + } + + /// Extracts the original PSBT so caller can check that the proposal can be broadcasted. + /// + /// Result of the broadcastibility check should then be returned to + /// [`Receiver::apply_broadcast_suitability`]. + /// + /// If the receiver is a non-interactive payment processor (ex. a donation page which generates + /// a new QR code for each visit), then it should make sure that the original PSBT is broadcastable + /// as a fallback mechanism in case the payjoin fails. This validation would be equivalent to + /// `testmempoolaccept` Bitcoin Core RPC call returning `{"allowed": true,...}`. + pub fn extract_tx_to_check_broadcast_suitability(&self) -> bitcoin::Transaction { + self.original.psbt.clone().extract_tx_unchecked_fee_rate() + } + + /// Processes the result of whether the original PSBT in the proposal can be broadcasted. + /// + /// Call [`Receiver::extract_tx_to_check_broadcast_suitability`] first to + /// acquire the tx to be checked for broadcastibility. + /// + /// If the receiver is a non-interactive payment processor (ex. a donation page which generates + /// a new QR code for each visit), then it should make sure that the original PSBT is broadcastable + /// as a fallback mechanism in case the payjoin fails. This validation would be equivalent to + /// `testmempoolaccept` Bitcoin Core RPC call returning `{"allowed": true,...}`. + /// + /// Receiver can optionally set a minimum fee rate which will be enforced on the original PSBT in the proposal. + /// This can be used to further prevent probing attacks since the attacker would now need to probe the receiver + /// with transactions which are both broadcastable and pay high fee. Unrelated to the probing attack scenario, + /// this parameter also makes operating in a high fee environment easier for the receiver. + pub fn apply_broadcast_suitability( + self, + min_fee_rate: Option, + is_broadcast_suitable: bool, + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + Error, + Receiver, + > { + match self.state.original.apply_broadcast_suitability(min_fee_rate, is_broadcast_suitable) { Ok(()) => MaybeFatalTransition::success( SessionEvent::CheckedBroadcastSuitability(), Receiver { @@ -661,22 +725,6 @@ impl Receiver { } } - /// Moves on to the next typestate without any of the current typestate's validations. - /// - /// Use this for interactive payment receivers, where there is no risk of a probing attack since the - /// receiver needs to manually create payjoin URIs. - pub fn assume_interactive_receiver( - self, - ) -> NextStateTransition> { - NextStateTransition::success( - SessionEvent::CheckedBroadcastSuitability(), - Receiver { - state: MaybeInputsOwned { original: self.original.clone() }, - session_context: self.session_context, - }, - ) - } - pub(crate) fn apply_checked_broadcast_suitability(self) -> ReceiveSession { let new_state = Receiver { state: MaybeInputsOwned { original: self.original.clone() }, @@ -720,31 +768,74 @@ impl Receiver { Error, Receiver, > { - match self.state.original.check_inputs_not_owned(is_owned) { - Ok(inner) => inner, + match self.get_input_script_refs() { + Ok(input_scripts) => match check_references(input_scripts, &mut |script: &ScriptBuf| { + is_owned(script.as_script()) + }) { + Ok(checked_input_scripts) => self.apply_input_owned_checks(checked_input_scripts), + Err(e) => MaybeFatalTransition::transient(e.into()), + }, Err(e) => match e { - Error::Implementation(_) => { - return MaybeFatalTransition::transient(e); - } - _ => { - return MaybeFatalTransition::replyable_error( - SessionEvent::GotReplyableError((&e).into()), - Receiver { - state: HasReplyableError { error_reply: (&e).into() }, - session_context: self.session_context, - }, - e, - ); - } + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - }; - MaybeFatalTransition::success( - SessionEvent::CheckedInputsNotOwned(), - Receiver { - state: MaybeInputsSeen { original: self.original.clone() }, - session_context: self.session_context, + } + } + + /// Get [`Reference`]s that hold the input scripts that need to be checked for ownership by the + /// reciever. + /// + /// Once completed, these checks should be submitted to + /// [`Receiver::apply_input_owned_checks`]. + /// + /// An attacker can try to spend the receiver's own inputs. This check prevents that. + pub fn get_input_script_refs( + &self, + ) -> Result>, Error> { + self.state.original.get_input_script_refs() + } + + /// Applies the input ownership checks to advance the state machine. + /// + /// Use [`Receiver::get_input_script_refs`] to obtain the references that need to be checked. + /// + /// An attacker can try to spend the receiver's own inputs. This check prevents that. + pub fn apply_input_owned_checks( + self, + checked_input_scripts: impl IntoIterator>, + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + Error, + Receiver, + > { + match self.state.original.apply_input_owned_checks(checked_input_scripts) { + Ok(()) => MaybeFatalTransition::success( + SessionEvent::CheckedInputsNotOwned(), + Receiver { + state: MaybeInputsSeen { original: self.original.clone() }, + session_context: self.session_context, + }, + ), + Err(e) => match e { + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - ) + } } pub(crate) fn apply_checked_inputs_not_owned(self) -> ReceiveSession { @@ -782,31 +873,69 @@ impl Receiver { Error, Receiver, > { - match self.state.original.check_no_inputs_seen_before(is_known) { - Ok(inner) => inner, + match check_references(self.get_input_outpoint_refs(), is_known) { + Ok(checked_input_outpoints) => self.apply_input_seen_checks(checked_input_outpoints), + Err(e) => MaybeFatalTransition::transient(e.into()), + } + } + + /// Get [`Reference`]s that hold the input outpoints that need to be checked for whether they + /// have already been seen by the receiver. + /// + /// Once completed, these checks should be submitted to + /// [`Receiver::apply_input_seen_checks`]. + /// + /// This check prevents the following attacks: + /// 1. Probing attacks, where the sender can use the exact same proposal (or with minimal change) + /// to have the receiver reveal their UTXO set by contributing to all proposals with different inputs + /// and sending them back to the receiver. + /// 2. Re-entrant payjoin, where the sender uses the payjoin PSBT of a previous payjoin as the + /// original proposal PSBT of the current, new payjoin. + pub fn get_input_outpoint_refs( + &self, + ) -> impl Iterator> { + self.state.original.get_input_outpoint_refs() + } + + /// Applies the input seen checks to advance the state machine. + /// + /// Use [`Receiver::get_input_outpoint_refs`] to obtain the references that need to be checked. + /// + /// This check prevents the following attacks: + /// 1. Probing attacks, where the sender can use the exact same proposal (or with minimal change) + /// to have the receiver reveal their UTXO set by contributing to all proposals with different inputs + /// and sending them back to the receiver. + /// 2. Re-entrant payjoin, where the sender uses the payjoin PSBT of a previous payjoin as the + /// original proposal PSBT of the current, new payjoin. + pub fn apply_input_seen_checks( + self, + checked_input_outpoints: impl IntoIterator>, + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + Error, + Receiver, + > { + match self.state.original.apply_input_seen_checks(checked_input_outpoints) { + Ok(()) => MaybeFatalTransition::success( + SessionEvent::CheckedNoInputsSeenBefore(), + Receiver { + state: OutputsUnknown { original: self.original.clone() }, + session_context: self.session_context, + }, + ), Err(e) => match e { - Error::Implementation(_) => { - return MaybeFatalTransition::transient(e); - } - _ => { - return MaybeFatalTransition::replyable_error( - SessionEvent::GotReplyableError((&e).into()), - Receiver { - state: HasReplyableError { error_reply: (&e).into() }, - session_context: self.session_context, - }, - e, - ); - } - }, - }; - MaybeFatalTransition::success( - SessionEvent::CheckedNoInputsSeenBefore(), - Receiver { - state: OutputsUnknown { original: self.original.clone() }, - session_context: self.session_context, + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - ) + } } pub(crate) fn apply_checked_no_inputs_seen_before(self) -> ReceiveSession { @@ -849,28 +978,71 @@ impl Receiver { Error, Receiver, > { - let inner = match self.state.original.identify_receiver_outputs(is_receiver_output) { - Ok(inner) => inner, + match check_references(self.get_output_script_refs(), &mut |script: &ScriptBuf| { + is_receiver_output(script.as_script()) + }) { + Ok(checked_output_scripts) => self.apply_output_owned_checks(checked_output_scripts), + Err(e) => MaybeFatalTransition::transient(e.into()), + } + } + + /// Get [`Reference`]s that hold the output scripts that need to be checked for ownership + /// by the receiver. + /// + /// Once completed, these checks should be submitted to + /// [`Receiver::apply_output_owned_checks`]. + /// + /// Additionally, this function also protects the receiver from accidentally subtracting fees + /// from their own outputs: when a sender is sending a proposal, + /// they can select an output which they want the receiver to subtract fees from to account for + /// the increased transaction size. If a sender specifies a receiver output for this purpose, this + /// function sets that parameter to None so that it is ignored in subsequent steps of the + /// receiver flow. This protects the receiver from accidentally subtracting fees from their own + /// outputs. + pub fn get_output_script_refs( + &self, + ) -> impl Iterator> { + self.state.original.get_output_script_refs() + } + + /// Applies the output owned checks to advance the state machine. + /// + /// Use [`Receiver::get_output_script_refs`] to obtain the references that need + /// to be checked. + /// + /// Additionally, this function also protects the receiver from accidentally subtracting fees + /// from their own outputs: when a sender is sending a proposal, + /// they can select an output which they want the receiver to subtract fees from to account for + /// the increased transaction size. If a sender specifies a receiver output for this purpose, this + /// function sets that parameter to None so that it is ignored in subsequent steps of the + /// receiver flow. This protects the receiver from accidentally subtracting fees from their own + /// outputs. + pub fn apply_output_owned_checks( + self, + checked_output_scripts: impl IntoIterator>, + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + Error, + Receiver, + > { + match self.state.original.apply_output_owned_checks(checked_output_scripts) { + Ok(inner) => MaybeFatalTransition::success( + SessionEvent::IdentifiedReceiverOutputs(inner.owned_vouts.clone()), + Receiver { state: WantsOutputs { inner }, session_context: self.session_context }, + ), Err(e) => match e { - Error::Implementation(_) => { - return MaybeFatalTransition::transient(e); - } - _ => { - return MaybeFatalTransition::replyable_error( - SessionEvent::GotReplyableError((&e).into()), - Receiver { - state: HasReplyableError { error_reply: (&e).into() }, - session_context: self.session_context, - }, - e, - ); - } + Error::Implementation(_) => MaybeFatalTransition::transient(e), + _ => MaybeFatalTransition::replyable_error( + SessionEvent::GotReplyableError((&e).into()), + Receiver { + state: HasReplyableError { error_reply: (&e).into() }, + session_context: self.session_context, + }, + e, + ), }, - }; - MaybeFatalTransition::success( - SessionEvent::IdentifiedReceiverOutputs(inner.owned_vouts.clone()), - Receiver { state: WantsOutputs { inner }, session_context: self.session_context }, - ) + } } pub(crate) fn apply_identified_receiver_outputs( @@ -1054,23 +1226,20 @@ impl Receiver { ) -> MaybeFatalTransition, ProtocolError> { let max_effective_fee_rate = max_effective_fee_rate.or(Some(self.session_context.max_fee_rate)); - let psbt_context = match self + match self .state .inner .calculate_psbt_context_with_fee_range(min_fee_rate, max_effective_fee_rate) { - Ok(inner) => inner, - Err(e) => { - return MaybeFatalTransition::transient(ProtocolError::OriginalPayload(e.into())); - } - }; - MaybeFatalTransition::success( - SessionEvent::AppliedFeeRange(psbt_context.clone()), - Receiver { - state: ProvisionalProposal { psbt_context }, - session_context: self.session_context, - }, - ) + Ok(psbt_context) => MaybeFatalTransition::success( + SessionEvent::AppliedFeeRange(psbt_context.clone()), + Receiver { + state: ProvisionalProposal { psbt_context }, + session_context: self.session_context, + }, + ), + Err(e) => MaybeFatalTransition::transient(ProtocolError::OriginalPayload(e.into())), + } } pub(crate) fn apply_applied_fee_range(self, psbt_context: PsbtContext) -> ReceiveSession { @@ -1104,19 +1273,12 @@ impl Receiver { wallet_process_psbt: impl Fn(&Psbt) -> Result, ) -> MaybeTransientTransition, ImplementationError> { - let original_psbt = self.state.psbt_context.original_psbt.clone(); - let inner = match self.state.psbt_context.finalize_proposal(wallet_process_psbt) { - Ok(inner) => inner, - Err(e) => { - return MaybeTransientTransition::transient(e); - } - }; - let psbt_context = PsbtContext { payjoin_psbt: inner.clone(), original_psbt }; - let payjoin_proposal = PayjoinProposal { psbt_context: psbt_context.clone() }; - MaybeTransientTransition::success( - SessionEvent::FinalizedProposal(inner), - Receiver { state: payjoin_proposal, session_context: self.session_context }, - ) + let psbt = self.psbt_to_sign(); + let signed_psbt = wallet_process_psbt(&psbt); + match signed_psbt { + Ok(signed_psbt) => self.finalize_signed_proposal(&signed_psbt), + Err(e) => MaybeTransientTransition::transient(e), + } } /// The Payjoin proposal PSBT that the receiver needs to sign @@ -1126,6 +1288,33 @@ impl Receiver { /// so the PSBT to sign must be accessible to such implementers. pub fn psbt_to_sign(&self) -> Psbt { self.state.psbt_context.psbt_to_sign() } + /// Finalizes the Payjoin proposal into a PSBT which the sender will find acceptable before + /// they sign the transaction and broadcast it to the network. + /// + /// This takes a receiver signed PSBT payjoin proposal and finalizes it for broadcast to + /// the sender. Use [`Receiver::psbt_to_sign`] to obtain the payjoin + /// proposal's unsigned PSBT for receiver to sign and return here. + pub fn finalize_signed_proposal( + self, + signed_psbt: &Psbt, + ) -> MaybeTransientTransition, ImplementationError> + { + let original_psbt = self.state.psbt_context.original_psbt.clone(); + let payjoin_psbt = + match self.state.psbt_context.finalize_signed_proposal(signed_psbt.clone()) { + Ok(payjoin_psbt) => payjoin_psbt, + Err(e) => { + return MaybeTransientTransition::transient(e); + } + }; + let psbt_context = PsbtContext { payjoin_psbt: payjoin_psbt.clone(), original_psbt }; + let payjoin_proposal = PayjoinProposal { psbt_context: psbt_context.clone() }; + MaybeTransientTransition::success( + SessionEvent::FinalizedProposal(payjoin_psbt), + Receiver { state: payjoin_proposal, session_context: self.session_context }, + ) + } + pub(crate) fn apply_payjoin_proposal(self, payjoin_psbt: Psbt) -> ReceiveSession { let psbt_context = PsbtContext { payjoin_psbt, @@ -1342,7 +1531,7 @@ impl Receiver { // If the fallback transaction included any non-SegWit inputs, then the transaction ID of // the Payjoin proposal is going to change when the sender signs their non-SegWit address // one more time. The receiver cannot monitor the transaction, and should conclude the session. - if fallback_tx.input.iter().any(|txin| txin.witness.is_empty()) { + if has_empty_witness(&fallback_tx) { return MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( SessionOutcome::PayjoinProposalSent, )); @@ -1361,18 +1550,7 @@ impl Receiver { ImplementationError::from(format!("Payjoin transaction ID mismatch. Expected: {payjoin_txid}, Got: {tx_id}").as_str()), )); } - // TODO: should we check for witness and scriptsig on the tx? - let mut sender_witnesses = vec![]; - - for i in self.state.psbt_context.sender_input_indexes() { - let input = - tx.input.get(i).expect("sender_input_indexes should return valid indices"); - sender_witnesses.push((input.script_sig.clone(), input.witness.clone())); - } - // Payjoin transaction with SegWit inputs was detected. Log the signatures and complete the session. - return MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( - SessionOutcome::Success(sender_witnesses), - )); + return self.payjoin_tx_exists(tx); } Ok(None) => {} Err(e) => return MaybeFatalOrSuccessTransition::transient(Error::Implementation(e)), @@ -1381,16 +1559,72 @@ impl Receiver { // If the Payjoin proposal was not found, check the fallback transaction, as it is // the second of two transactions whose IDs the receiver is aware of. match transaction_exists(fallback_tx.compute_txid()) { - Ok(Some(_)) => - return MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( - SessionOutcome::FallbackBroadcasted, - )), + Ok(Some(_)) => return self.fallback_tx_exists(), Ok(None) => {} Err(e) => return MaybeFatalOrSuccessTransition::transient(Error::Implementation(e)), } MaybeFatalOrSuccessTransition::no_results(self.clone()) } + + pub fn extract_fallback_txid(&self) -> Txid { + self.state.psbt_context.original_psbt.clone().extract_tx_unchecked_fee_rate().compute_txid() + } + + pub fn extract_payjoin_proposal_txid(&self) -> Txid { + self.state.psbt_context.payjoin_psbt.clone().extract_tx_unchecked_fee_rate().compute_txid() + } + + pub fn check_fallback_monitorable( + &self, + ) -> MaybeFatalOrSuccessTransition { + let fallback_tx = self + .state + .psbt_context + .original_psbt + .clone() + .extract_tx_fee_rate_limit() + .expect("fallback transaction should be in the receiver context"); + + // If the fallback transaction included any non-SegWit inputs, then the transaction ID of + // the Payjoin proposal is going to change when the sender signs their non-SegWit address + // one more time. The receiver cannot monitor the transaction, and should conclude the session. + if has_empty_witness(&fallback_tx) { + return MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( + SessionOutcome::PayjoinProposalSent, + )); + } + + MaybeFatalOrSuccessTransition::no_results(self.clone()) + } + + pub fn fallback_tx_exists(&self) -> MaybeFatalOrSuccessTransition { + MaybeFatalOrSuccessTransition::success(SessionEvent::Closed( + SessionOutcome::FallbackBroadcasted, + )) + } + + pub fn payjoin_tx_exists( + &self, + payjoin_tx: Transaction, + ) -> MaybeFatalOrSuccessTransition { + // TODO: should we check for witness and scriptsig on the tx? + let mut sender_witnesses = vec![]; + + for i in self.state.psbt_context.sender_input_indexes() { + let input = + payjoin_tx.input.get(i).expect("sender_input_indexes should return valid indices"); + sender_witnesses.push((input.script_sig.clone(), input.witness.clone())); + } + // Payjoin transaction with SegWit inputs was detected. Log the signatures and complete the session. + MaybeFatalOrSuccessTransition::success(SessionEvent::Closed(SessionOutcome::Success( + sender_witnesses, + ))) + } +} + +fn has_empty_witness(tx: &Transaction) -> bool { + tx.input.iter().any(|txin| txin.witness.is_empty()) } /// Derive a mailbox endpoint on a directory given a [`ShortId`]. @@ -1601,20 +1835,22 @@ pub mod test { Ok(ret) } - let maybe_inputs_seen = - receiver.check_inputs_not_owned(&mut |_| mock_callback(&mut call_count, false)); + let maybe_inputs_seen = receiver + .check_inputs_not_owned(&mut |_| mock_callback(&mut call_count, false)) + .save(&persister) + .expect("Persister shouldn't fail"); assert_eq!(call_count, 1); let outputs_unknown = maybe_inputs_seen - .save(&persister) - .expect("Persister shouldn't fail") .check_no_inputs_seen_before(&mut |_| mock_callback(&mut call_count, false)) .save(&persister) .expect("Persister shouldn't fail"); assert_eq!(call_count, 2); let _wants_outputs = outputs_unknown - .identify_receiver_outputs(&mut |_| mock_callback(&mut call_count, true)); + .identify_receiver_outputs(&mut |_| mock_callback(&mut call_count, true)) + .save(&persister) + .expect("Persister shouldn't fail"); // there are 2 receiver outputs so we should expect this callback to run twice incrementing // call count twice assert_eq!(call_count, 4);