@@ -36,6 +36,7 @@ use bitcoin::{BlockHash, ScriptBuf, Transaction, Txid};
36
36
use crate :: sync:: Arc ;
37
37
use core:: future:: Future ;
38
38
use core:: ops:: Deref ;
39
+ use core:: sync:: atomic:: { AtomicBool , Ordering } ;
39
40
use core:: task;
40
41
41
42
use super :: async_poll:: dummy_waker;
@@ -350,7 +351,8 @@ where
350
351
L :: Target : Logger ,
351
352
O :: Target : OutputSpender ,
352
353
{
353
- sweeper_state : Mutex < RuntimeSweeperState > ,
354
+ sweeper_state : Mutex < SweeperState > ,
355
+ pending_sweep : AtomicBool ,
354
356
broadcaster : B ,
355
357
fee_estimator : E ,
356
358
chain_data_source : Option < F > ,
@@ -380,12 +382,10 @@ where
380
382
output_spender : O , change_destination_source : D , kv_store : K , logger : L ,
381
383
) -> Self {
382
384
let outputs = Vec :: new ( ) ;
383
- let sweeper_state = Mutex :: new ( RuntimeSweeperState {
384
- persistent : SweeperState { outputs, best_block } ,
385
- sweep_pending : false ,
386
- } ) ;
385
+ let sweeper_state = Mutex :: new ( SweeperState { outputs, best_block } ) ;
387
386
Self {
388
387
sweeper_state,
388
+ pending_sweep : AtomicBool :: new ( false ) ,
389
389
broadcaster,
390
390
fee_estimator,
391
391
chain_data_source,
@@ -427,7 +427,7 @@ where
427
427
return Ok ( ( ) ) ;
428
428
}
429
429
430
- let state_lock = & mut self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
430
+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
431
431
for descriptor in relevant_descriptors {
432
432
let output_info = TrackedSpendableOutput {
433
433
descriptor,
@@ -444,20 +444,20 @@ where
444
444
445
445
state_lock. outputs . push ( output_info) ;
446
446
}
447
- self . persist_state ( & state_lock) . map_err ( |e| {
447
+ self . persist_state ( & * state_lock) . map_err ( |e| {
448
448
log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
449
449
} )
450
450
}
451
451
452
452
/// Returns a list of the currently tracked spendable outputs.
453
453
pub fn tracked_spendable_outputs ( & self ) -> Vec < TrackedSpendableOutput > {
454
- self . sweeper_state . lock ( ) . unwrap ( ) . persistent . outputs . clone ( )
454
+ self . sweeper_state . lock ( ) . unwrap ( ) . outputs . clone ( )
455
455
}
456
456
457
457
/// Gets the latest best block which was connected either via the [`Listen`] or
458
458
/// [`Confirm`] interfaces.
459
459
pub fn current_best_block ( & self ) -> BestBlock {
460
- self . sweeper_state . lock ( ) . unwrap ( ) . persistent . best_block
460
+ self . sweeper_state . lock ( ) . unwrap ( ) . best_block
461
461
}
462
462
463
463
/// Regenerates and broadcasts the spending transaction for any outputs that are pending
@@ -481,24 +481,29 @@ where
481
481
true
482
482
} ;
483
483
484
+ // Prevent concurrent sweeps.
485
+ if self . pending_sweep . load ( Ordering :: Relaxed ) {
486
+ return Ok ( ( ) ) ;
487
+ }
488
+
484
489
// See if there is anything to sweep before requesting a change address.
485
490
{
486
- let mut sweeper_state = self . sweeper_state . lock ( ) . unwrap ( ) ;
491
+ let sweeper_state = self . sweeper_state . lock ( ) . unwrap ( ) ;
487
492
488
- // Prevent concurrent sweeping.
489
- if sweeper_state. sweep_pending {
490
- return Ok ( ( ) ) ;
491
- }
492
-
493
- let cur_height = sweeper_state. persistent . best_block . height ;
494
- let has_respends =
495
- sweeper_state. persistent . outputs . iter ( ) . any ( |o| filter_fn ( o, cur_height) ) ;
493
+ let cur_height = sweeper_state. best_block . height ;
494
+ let has_respends = sweeper_state. outputs . iter ( ) . any ( |o| filter_fn ( o, cur_height) ) ;
496
495
if !has_respends {
497
496
return Ok ( ( ) ) ;
498
497
}
498
+ }
499
499
500
- // There is something to sweep. Block concurrent sweeps.
501
- sweeper_state. sweep_pending = true ;
500
+ // Mark sweep pending, if no other thread did so already.
501
+ if self
502
+ . pending_sweep
503
+ . compare_exchange ( false , true , Ordering :: Acquire , Ordering :: Relaxed )
504
+ . is_err ( )
505
+ {
506
+ return Ok ( ( ) ) ;
502
507
}
503
508
504
509
// Request a new change address outside of the mutex to avoid the mutex crossing await.
@@ -509,10 +514,7 @@ where
509
514
{
510
515
let mut runtime_sweeper_state = self . sweeper_state . lock ( ) . unwrap ( ) ;
511
516
512
- // Always allow a new sweep after this spend, also in the error case.
513
- runtime_sweeper_state. sweep_pending = false ;
514
-
515
- let sweeper_state = & mut runtime_sweeper_state. persistent ;
517
+ let sweeper_state = & mut runtime_sweeper_state;
516
518
517
519
let change_destination_script = change_destination_script_result?;
518
520
@@ -527,6 +529,8 @@ where
527
529
. collect ( ) ;
528
530
529
531
if respend_descriptors. is_empty ( ) {
532
+ self . pending_sweep . store ( false , Ordering :: Release ) ;
533
+
530
534
// It could be that a tx confirmed and there is now nothing to sweep anymore.
531
535
return Ok ( ( ) ) ;
532
536
}
@@ -545,6 +549,8 @@ where
545
549
spending_tx
546
550
} ,
547
551
Err ( e) => {
552
+ self . pending_sweep . store ( false , Ordering :: Release ) ;
553
+
548
554
log_error ! ( self . logger, "Error spending outputs: {:?}" , e) ;
549
555
return Ok ( ( ) ) ;
550
556
} ,
@@ -570,6 +576,8 @@ where
570
576
self . broadcaster . broadcast_transactions ( & [ & spending_tx] ) ;
571
577
}
572
578
579
+ self . pending_sweep . store ( false , Ordering :: Release ) ;
580
+
573
581
Ok ( ( ) )
574
582
}
575
583
@@ -668,22 +676,22 @@ where
668
676
fn filtered_block_connected (
669
677
& self , header : & Header , txdata : & chain:: transaction:: TransactionData , height : u32 ,
670
678
) {
671
- let state_lock = & mut self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
679
+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
672
680
assert_eq ! ( state_lock. best_block. block_hash, header. prev_blockhash,
673
681
"Blocks must be connected in chain-order - the connected header must build on the last connected header" ) ;
674
682
assert_eq ! ( state_lock. best_block. height, height - 1 ,
675
683
"Blocks must be connected in chain-order - the connected block height must be one greater than the previous height" ) ;
676
684
677
- self . transactions_confirmed_internal ( state_lock, header, txdata, height) ;
678
- self . best_block_updated_internal ( state_lock, header, height) ;
685
+ self . transactions_confirmed_internal ( & mut * state_lock, header, txdata, height) ;
686
+ self . best_block_updated_internal ( & mut * state_lock, header, height) ;
679
687
680
- let _ = self . persist_state ( & state_lock) . map_err ( |e| {
688
+ let _ = self . persist_state ( & * state_lock) . map_err ( |e| {
681
689
log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
682
690
} ) ;
683
691
}
684
692
685
693
fn block_disconnected ( & self , header : & Header , height : u32 ) {
686
- let state_lock = & mut self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
694
+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
687
695
688
696
let new_height = height - 1 ;
689
697
let block_hash = header. block_hash ( ) ;
@@ -721,15 +729,15 @@ where
721
729
fn transactions_confirmed (
722
730
& self , header : & Header , txdata : & chain:: transaction:: TransactionData , height : u32 ,
723
731
) {
724
- let state_lock = & mut self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
725
- self . transactions_confirmed_internal ( state_lock, header, txdata, height) ;
726
- self . persist_state ( state_lock) . unwrap_or_else ( |e| {
732
+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
733
+ self . transactions_confirmed_internal ( & mut * state_lock, header, txdata, height) ;
734
+ self . persist_state ( & * state_lock) . unwrap_or_else ( |e| {
727
735
log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
728
736
} ) ;
729
737
}
730
738
731
739
fn transaction_unconfirmed ( & self , txid : & Txid ) {
732
- let state_lock = & mut self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
740
+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
733
741
734
742
// Get what height was unconfirmed.
735
743
let unconf_height = state_lock
@@ -746,22 +754,22 @@ where
746
754
. filter ( |o| o. status . confirmation_height ( ) >= Some ( unconf_height) )
747
755
. for_each ( |o| o. status . unconfirmed ( ) ) ;
748
756
749
- self . persist_state ( state_lock) . unwrap_or_else ( |e| {
757
+ self . persist_state ( & * state_lock) . unwrap_or_else ( |e| {
750
758
log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
751
759
} ) ;
752
760
}
753
761
}
754
762
755
763
fn best_block_updated ( & self , header : & Header , height : u32 ) {
756
- let state_lock = & mut self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
757
- self . best_block_updated_internal ( state_lock, header, height) ;
758
- let _ = self . persist_state ( state_lock) . map_err ( |e| {
764
+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
765
+ self . best_block_updated_internal ( & mut * state_lock, header, height) ;
766
+ let _ = self . persist_state ( & * state_lock) . map_err ( |e| {
759
767
log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
760
768
} ) ;
761
769
}
762
770
763
771
fn get_relevant_txids ( & self ) -> Vec < ( Txid , u32 , Option < BlockHash > ) > {
764
- let state_lock = & self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
772
+ let state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
765
773
state_lock
766
774
. outputs
767
775
. iter ( )
@@ -782,11 +790,6 @@ where
782
790
}
783
791
}
784
792
785
- struct RuntimeSweeperState {
786
- persistent : SweeperState ,
787
- sweep_pending : bool ,
788
- }
789
-
790
793
#[ derive( Debug , Clone ) ]
791
794
struct SweeperState {
792
795
outputs : Vec < TrackedSpendableOutput > ,
@@ -849,10 +852,10 @@ where
849
852
}
850
853
}
851
854
852
- let sweeper_state =
853
- Mutex :: new ( RuntimeSweeperState { persistent : state, sweep_pending : false } ) ;
855
+ let sweeper_state = Mutex :: new ( state) ;
854
856
Ok ( Self {
855
857
sweeper_state,
858
+ pending_sweep : AtomicBool :: new ( false ) ,
856
859
broadcaster,
857
860
fee_estimator,
858
861
chain_data_source,
@@ -898,12 +901,12 @@ where
898
901
}
899
902
}
900
903
901
- let sweeper_state =
902
- Mutex :: new ( RuntimeSweeperState { persistent : state, sweep_pending : false } ) ;
904
+ let sweeper_state = Mutex :: new ( state) ;
903
905
Ok ( (
904
906
best_block,
905
907
OutputSweeper {
906
908
sweeper_state,
909
+ pending_sweep : AtomicBool :: new ( false ) ,
907
910
broadcaster,
908
911
fee_estimator,
909
912
chain_data_source,
0 commit comments