Skip to content

Commit 17a5449

Browse files
committed
try-out: atomicbool
1 parent 87e4bc6 commit 17a5449

File tree

1 file changed

+50
-47
lines changed

1 file changed

+50
-47
lines changed

lightning/src/util/sweep.rs

+50-47
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ use bitcoin::{BlockHash, ScriptBuf, Transaction, Txid};
3636
use crate::sync::Arc;
3737
use core::future::Future;
3838
use core::ops::Deref;
39+
use core::sync::atomic::{AtomicBool, Ordering};
3940
use core::task;
4041

4142
use super::async_poll::dummy_waker;
@@ -350,7 +351,8 @@ where
350351
L::Target: Logger,
351352
O::Target: OutputSpender,
352353
{
353-
sweeper_state: Mutex<RuntimeSweeperState>,
354+
sweeper_state: Mutex<SweeperState>,
355+
pending_sweep: AtomicBool,
354356
broadcaster: B,
355357
fee_estimator: E,
356358
chain_data_source: Option<F>,
@@ -380,12 +382,10 @@ where
380382
output_spender: O, change_destination_source: D, kv_store: K, logger: L,
381383
) -> Self {
382384
let outputs = Vec::new();
383-
let sweeper_state = Mutex::new(RuntimeSweeperState {
384-
persistent: SweeperState { outputs, best_block },
385-
sweep_pending: false,
386-
});
385+
let sweeper_state = Mutex::new(SweeperState { outputs, best_block });
387386
Self {
388387
sweeper_state,
388+
pending_sweep: AtomicBool::new(false),
389389
broadcaster,
390390
fee_estimator,
391391
chain_data_source,
@@ -427,7 +427,7 @@ where
427427
return Ok(());
428428
}
429429

430-
let state_lock = &mut self.sweeper_state.lock().unwrap().persistent;
430+
let mut state_lock = self.sweeper_state.lock().unwrap();
431431
for descriptor in relevant_descriptors {
432432
let output_info = TrackedSpendableOutput {
433433
descriptor,
@@ -444,20 +444,20 @@ where
444444

445445
state_lock.outputs.push(output_info);
446446
}
447-
self.persist_state(&state_lock).map_err(|e| {
447+
self.persist_state(&*state_lock).map_err(|e| {
448448
log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
449449
})
450450
}
451451

452452
/// Returns a list of the currently tracked spendable outputs.
453453
pub fn tracked_spendable_outputs(&self) -> Vec<TrackedSpendableOutput> {
454-
self.sweeper_state.lock().unwrap().persistent.outputs.clone()
454+
self.sweeper_state.lock().unwrap().outputs.clone()
455455
}
456456

457457
/// Gets the latest best block which was connected either via the [`Listen`] or
458458
/// [`Confirm`] interfaces.
459459
pub fn current_best_block(&self) -> BestBlock {
460-
self.sweeper_state.lock().unwrap().persistent.best_block
460+
self.sweeper_state.lock().unwrap().best_block
461461
}
462462

463463
/// Regenerates and broadcasts the spending transaction for any outputs that are pending
@@ -481,24 +481,29 @@ where
481481
true
482482
};
483483

484+
// Prevent concurrent sweeps.
485+
if self.pending_sweep.load(Ordering::Relaxed) {
486+
return Ok(());
487+
}
488+
484489
// See if there is anything to sweep before requesting a change address.
485490
{
486-
let mut sweeper_state = self.sweeper_state.lock().unwrap();
491+
let sweeper_state = self.sweeper_state.lock().unwrap();
487492

488-
// Prevent concurrent sweeping.
489-
if sweeper_state.sweep_pending {
490-
return Ok(());
491-
}
492-
493-
let cur_height = sweeper_state.persistent.best_block.height;
494-
let has_respends =
495-
sweeper_state.persistent.outputs.iter().any(|o| filter_fn(o, cur_height));
493+
let cur_height = sweeper_state.best_block.height;
494+
let has_respends = sweeper_state.outputs.iter().any(|o| filter_fn(o, cur_height));
496495
if !has_respends {
497496
return Ok(());
498497
}
498+
}
499499

500-
// There is something to sweep. Block concurrent sweeps.
501-
sweeper_state.sweep_pending = true;
500+
// Mark sweep pending, if no other thread did so already.
501+
if self
502+
.pending_sweep
503+
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
504+
.is_err()
505+
{
506+
return Ok(());
502507
}
503508

504509
// Request a new change address outside of the mutex to avoid the mutex crossing await.
@@ -509,10 +514,7 @@ where
509514
{
510515
let mut runtime_sweeper_state = self.sweeper_state.lock().unwrap();
511516

512-
// Always allow a new sweep after this spend, also in the error case.
513-
runtime_sweeper_state.sweep_pending = false;
514-
515-
let sweeper_state = &mut runtime_sweeper_state.persistent;
517+
let sweeper_state = &mut runtime_sweeper_state;
516518

517519
let change_destination_script = change_destination_script_result?;
518520

@@ -527,6 +529,8 @@ where
527529
.collect();
528530

529531
if respend_descriptors.is_empty() {
532+
self.pending_sweep.store(false, Ordering::Release);
533+
530534
// It could be that a tx confirmed and there is now nothing to sweep anymore.
531535
return Ok(());
532536
}
@@ -545,6 +549,8 @@ where
545549
spending_tx
546550
},
547551
Err(e) => {
552+
self.pending_sweep.store(false, Ordering::Release);
553+
548554
log_error!(self.logger, "Error spending outputs: {:?}", e);
549555
return Ok(());
550556
},
@@ -570,6 +576,8 @@ where
570576
self.broadcaster.broadcast_transactions(&[&spending_tx]);
571577
}
572578

579+
self.pending_sweep.store(false, Ordering::Release);
580+
573581
Ok(())
574582
}
575583

@@ -668,22 +676,22 @@ where
668676
fn filtered_block_connected(
669677
&self, header: &Header, txdata: &chain::transaction::TransactionData, height: u32,
670678
) {
671-
let state_lock = &mut self.sweeper_state.lock().unwrap().persistent;
679+
let mut state_lock = self.sweeper_state.lock().unwrap();
672680
assert_eq!(state_lock.best_block.block_hash, header.prev_blockhash,
673681
"Blocks must be connected in chain-order - the connected header must build on the last connected header");
674682
assert_eq!(state_lock.best_block.height, height - 1,
675683
"Blocks must be connected in chain-order - the connected block height must be one greater than the previous height");
676684

677-
self.transactions_confirmed_internal(state_lock, header, txdata, height);
678-
self.best_block_updated_internal(state_lock, header, height);
685+
self.transactions_confirmed_internal(&mut *state_lock, header, txdata, height);
686+
self.best_block_updated_internal(&mut *state_lock, header, height);
679687

680-
let _ = self.persist_state(&state_lock).map_err(|e| {
688+
let _ = self.persist_state(&*state_lock).map_err(|e| {
681689
log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
682690
});
683691
}
684692

685693
fn block_disconnected(&self, header: &Header, height: u32) {
686-
let state_lock = &mut self.sweeper_state.lock().unwrap().persistent;
694+
let mut state_lock = self.sweeper_state.lock().unwrap();
687695

688696
let new_height = height - 1;
689697
let block_hash = header.block_hash();
@@ -721,15 +729,15 @@ where
721729
fn transactions_confirmed(
722730
&self, header: &Header, txdata: &chain::transaction::TransactionData, height: u32,
723731
) {
724-
let state_lock = &mut self.sweeper_state.lock().unwrap().persistent;
725-
self.transactions_confirmed_internal(state_lock, header, txdata, height);
726-
self.persist_state(state_lock).unwrap_or_else(|e| {
732+
let mut state_lock = self.sweeper_state.lock().unwrap();
733+
self.transactions_confirmed_internal(&mut *state_lock, header, txdata, height);
734+
self.persist_state(&*state_lock).unwrap_or_else(|e| {
727735
log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
728736
});
729737
}
730738

731739
fn transaction_unconfirmed(&self, txid: &Txid) {
732-
let state_lock = &mut self.sweeper_state.lock().unwrap().persistent;
740+
let mut state_lock = self.sweeper_state.lock().unwrap();
733741

734742
// Get what height was unconfirmed.
735743
let unconf_height = state_lock
@@ -746,22 +754,22 @@ where
746754
.filter(|o| o.status.confirmation_height() >= Some(unconf_height))
747755
.for_each(|o| o.status.unconfirmed());
748756

749-
self.persist_state(state_lock).unwrap_or_else(|e| {
757+
self.persist_state(&*state_lock).unwrap_or_else(|e| {
750758
log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
751759
});
752760
}
753761
}
754762

755763
fn best_block_updated(&self, header: &Header, height: u32) {
756-
let state_lock = &mut self.sweeper_state.lock().unwrap().persistent;
757-
self.best_block_updated_internal(state_lock, header, height);
758-
let _ = self.persist_state(state_lock).map_err(|e| {
764+
let mut state_lock = self.sweeper_state.lock().unwrap();
765+
self.best_block_updated_internal(&mut *state_lock, header, height);
766+
let _ = self.persist_state(&*state_lock).map_err(|e| {
759767
log_error!(self.logger, "Error persisting OutputSweeper: {:?}", e);
760768
});
761769
}
762770

763771
fn get_relevant_txids(&self) -> Vec<(Txid, u32, Option<BlockHash>)> {
764-
let state_lock = &self.sweeper_state.lock().unwrap().persistent;
772+
let state_lock = self.sweeper_state.lock().unwrap();
765773
state_lock
766774
.outputs
767775
.iter()
@@ -782,11 +790,6 @@ where
782790
}
783791
}
784792

785-
struct RuntimeSweeperState {
786-
persistent: SweeperState,
787-
sweep_pending: bool,
788-
}
789-
790793
#[derive(Debug, Clone)]
791794
struct SweeperState {
792795
outputs: Vec<TrackedSpendableOutput>,
@@ -849,10 +852,10 @@ where
849852
}
850853
}
851854

852-
let sweeper_state =
853-
Mutex::new(RuntimeSweeperState { persistent: state, sweep_pending: false });
855+
let sweeper_state = Mutex::new(state);
854856
Ok(Self {
855857
sweeper_state,
858+
pending_sweep: AtomicBool::new(false),
856859
broadcaster,
857860
fee_estimator,
858861
chain_data_source,
@@ -898,12 +901,12 @@ where
898901
}
899902
}
900903

901-
let sweeper_state =
902-
Mutex::new(RuntimeSweeperState { persistent: state, sweep_pending: false });
904+
let sweeper_state = Mutex::new(state);
903905
Ok((
904906
best_block,
905907
OutputSweeper {
906908
sweeper_state,
909+
pending_sweep: AtomicBool::new(false),
907910
broadcaster,
908911
fee_estimator,
909912
chain_data_source,

0 commit comments

Comments
 (0)