Skip to content

Commit 6a30136

Browse files
committed
fix: unaggregated fees update race condition with tokio mutex guards
Signed-off-by: Alexis Asseman <[email protected]>
1 parent 6cfa0a6 commit 6a30136

File tree

3 files changed

+106
-60
lines changed

3 files changed

+106
-60
lines changed

tap-agent/src/tap/sender_account.rs

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
// Copyright 2023-, GraphOps and Semiotic Labs.
22
// SPDX-License-Identifier: Apache-2.0
33

4+
use std::sync::Mutex as StdMutex;
45
use std::{
56
cmp::max,
67
collections::{HashMap, HashSet},
7-
sync::{Arc, Mutex},
8+
sync::Arc,
89
time::Duration,
910
};
1011

@@ -15,6 +16,7 @@ use eventuals::Eventual;
1516
use indexer_common::{escrow_accounts::EscrowAccounts, prelude::SubgraphClient};
1617
use sqlx::PgPool;
1718
use thegraph::types::Address;
19+
use tokio::sync::Mutex as TokioMutex;
1820
use tokio::{select, sync::Notify, time};
1921
use tracing::{error, warn};
2022

@@ -34,10 +36,11 @@ enum AllocationState {
3436
pub struct Inner {
3537
config: &'static config::Cli,
3638
pgpool: PgPool,
37-
allocations: Arc<Mutex<HashMap<Address, AllocationState>>>,
39+
allocations: Arc<StdMutex<HashMap<Address, AllocationState>>>,
3840
sender: Address,
3941
sender_aggregator_endpoint: String,
40-
unaggregated_fees: Arc<Mutex<UnaggregatedReceipts>>,
42+
unaggregated_fees: Arc<StdMutex<UnaggregatedReceipts>>,
43+
unaggregated_receipts_guard: Arc<TokioMutex<()>>,
4144
}
4245

4346
impl Inner {
@@ -136,7 +139,7 @@ impl Inner {
136139
return;
137140
};
138141

139-
if let Err(e) = self.recompute_unaggregated_fees_static() {
142+
if let Err(e) = self.recompute_unaggregated_fees().await {
140143
error!(
141144
"Error while recomputing unaggregated fees for sender {}: {}",
142145
self.sender, e
@@ -171,7 +174,11 @@ impl Inner {
171174
}
172175

173176
/// Recompute the sender's total unaggregated fees value and last receipt ID.
174-
fn recompute_unaggregated_fees_static(&self) -> Result<()> {
177+
async fn recompute_unaggregated_fees(&self) -> Result<()> {
178+
// Make sure to pause the handling of receipt notifications while we update the unaggregated
179+
// fees.
180+
let _guard = self.unaggregated_receipts_guard.lock().await;
181+
175182
// Similar pattern to get_heaviest_allocation().
176183
let allocations: Vec<_> = self.allocations.lock().unwrap().values().cloned().collect();
177184

@@ -236,6 +243,7 @@ pub struct SenderAccount {
236243
rav_requester_notify: Arc<Notify>,
237244
rav_requester_finalize_task: tokio::task::JoinHandle<()>,
238245
rav_requester_finalize_notify: Arc<Notify>,
246+
unaggregated_receipts_guard: Arc<TokioMutex<()>>,
239247
}
240248

241249
impl SenderAccount {
@@ -250,13 +258,16 @@ impl SenderAccount {
250258
tap_eip712_domain_separator: Eip712Domain,
251259
sender_aggregator_endpoint: String,
252260
) -> Self {
261+
let unaggregated_receipts_guard = Arc::new(TokioMutex::new(()));
262+
253263
let inner = Arc::new(Inner {
254264
config,
255265
pgpool,
256-
allocations: Arc::new(Mutex::new(HashMap::new())),
266+
allocations: Arc::new(StdMutex::new(HashMap::new())),
257267
sender: sender_id,
258268
sender_aggregator_endpoint,
259-
unaggregated_fees: Arc::new(Mutex::new(UnaggregatedReceipts::default())),
269+
unaggregated_fees: Arc::new(StdMutex::new(UnaggregatedReceipts::default())),
270+
unaggregated_receipts_guard: unaggregated_receipts_guard.clone(),
260271
});
261272

262273
let rav_requester_notify = Arc::new(Notify::new());
@@ -289,6 +300,7 @@ impl SenderAccount {
289300
rav_requester_notify,
290301
rav_requester_finalize_task,
291302
rav_requester_finalize_notify,
303+
unaggregated_receipts_guard,
292304
}
293305
}
294306

@@ -346,21 +358,32 @@ impl SenderAccount {
346358
&self,
347359
new_receipt_notification: NewReceiptNotification,
348360
) {
349-
let mut unaggregated_fees = self.inner.unaggregated_fees.lock().unwrap();
350-
351-
// Else we already processed that receipt, most likely from pulling the receipts
352-
// from the database.
353-
if new_receipt_notification.id > unaggregated_fees.last_id {
354-
if let Some(AllocationState::Active(allocation)) = self
355-
.inner
356-
.allocations
357-
.lock()
358-
.unwrap()
359-
.get(&new_receipt_notification.allocation_id)
361+
// Make sure to pause the handling of receipt notifications while we update the unaggregated
362+
// fees.
363+
let _guard = self.unaggregated_receipts_guard.lock().await;
364+
365+
let allocation_state = self
366+
.inner
367+
.allocations
368+
.lock()
369+
.unwrap()
370+
.get(&new_receipt_notification.allocation_id)
371+
.cloned();
372+
373+
if let Some(AllocationState::Active(allocation)) = allocation_state {
374+
// Try to add the receipt value to the allocation's unaggregated fees value.
375+
// If the fees were not added, it means the receipt was already processed, so we
376+
// don't need to do anything.
377+
if allocation
378+
.fees_add(new_receipt_notification.value, new_receipt_notification.id)
379+
.await
360380
{
361381
// Add the receipt value to the allocation's unaggregated fees value.
362-
allocation.fees_add(new_receipt_notification.value);
382+
allocation
383+
.fees_add(new_receipt_notification.value, new_receipt_notification.id)
384+
.await;
363385
// Add the receipt value to the sender's unaggregated fees value.
386+
let mut unaggregated_fees = self.inner.unaggregated_fees.lock().unwrap();
364387
*unaggregated_fees = UnaggregatedReceipts {
365388
value: self
366389
.inner
@@ -373,18 +396,18 @@ impl SenderAccount {
373396
{
374397
self.rav_requester_notify.notify_waiters();
375398
}
376-
} else {
377-
error!(
378-
"Received a new receipt notification for allocation {} that doesn't exist \
379-
or is ineligible for sender {}.",
380-
new_receipt_notification.allocation_id, self.inner.sender
381-
);
382399
}
400+
} else {
401+
error!(
402+
"Received a new receipt notification for allocation {} that doesn't exist \
403+
or is ineligible for sender {}.",
404+
new_receipt_notification.allocation_id, self.inner.sender
405+
);
383406
}
384407
}
385408

386-
pub fn recompute_unaggregated_fees(&self) -> Result<()> {
387-
self.inner.recompute_unaggregated_fees_static()
409+
pub async fn recompute_unaggregated_fees(&self) -> Result<()> {
410+
self.inner.recompute_unaggregated_fees().await
388411
}
389412
}
390413

@@ -507,7 +530,7 @@ mod tests {
507530
*ALLOCATION_ID_2,
508531
]))
509532
.await;
510-
sender.recompute_unaggregated_fees().unwrap();
533+
sender.recompute_unaggregated_fees().await.unwrap();
511534

512535
sender
513536
}

tap-agent/src/tap/sender_accounts_manager.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ impl SenderAccountsManager {
262262

263263
sender
264264
.recompute_unaggregated_fees()
265+
.await
265266
.expect("should be able to recompute unaggregated fees");
266267
}
267268
drop(sender_accounts_write_lock);

tap-agent/src/tap/sender_allocation.rs

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
// Copyright 2023-, GraphOps and Semiotic Labs.
22
// SPDX-License-Identifier: Apache-2.0
33

4-
use std::{
5-
str::FromStr,
6-
sync::{Arc, Mutex},
7-
time::Duration,
8-
};
4+
use std::sync::Mutex as StdMutex;
5+
use std::{str::FromStr, sync::Arc, time::Duration};
96

107
use alloy_primitives::hex::ToHex;
118
use alloy_sol_types::Eip712Domain;
@@ -22,6 +19,7 @@ use tap_core::{
2219
tap_receipt::{ReceiptCheck, ReceivedReceipt},
2320
};
2421
use thegraph::types::Address;
22+
use tokio::sync::Mutex as TokioMutex;
2523
use tracing::{error, warn};
2624

2725
use crate::{
@@ -48,10 +46,11 @@ pub struct SenderAllocation {
4846
allocation_id: Address,
4947
sender: Address,
5048
sender_aggregator_endpoint: String,
51-
unaggregated_fees: Arc<Mutex<UnaggregatedReceipts>>,
49+
unaggregated_fees: Arc<StdMutex<UnaggregatedReceipts>>,
5250
config: &'static config::Cli,
5351
escrow_accounts: Eventual<EscrowAccounts>,
54-
rav_request_guard: tokio::sync::Mutex<()>,
52+
rav_request_guard: TokioMutex<()>,
53+
unaggregated_receipts_guard: TokioMutex<()>,
5554
}
5655

5756
impl SenderAllocation {
@@ -110,10 +109,11 @@ impl SenderAllocation {
110109
allocation_id,
111110
sender,
112111
sender_aggregator_endpoint,
113-
unaggregated_fees: Arc::new(Mutex::new(UnaggregatedReceipts::default())),
112+
unaggregated_fees: Arc::new(StdMutex::new(UnaggregatedReceipts::default())),
114113
config,
115114
escrow_accounts,
116-
rav_request_guard: tokio::sync::Mutex::new(()),
115+
rav_request_guard: TokioMutex::new(()),
116+
unaggregated_receipts_guard: TokioMutex::new(()),
117117
};
118118

119119
sender_allocation
@@ -133,6 +133,10 @@ impl SenderAllocation {
133133
/// Delete obsolete receipts in the DB w.r.t. the last RAV in DB, then update the tap manager
134134
/// with the latest unaggregated fees from the database.
135135
async fn update_unaggregated_fees(&self) -> Result<()> {
136+
// Make sure to pause the handling of receipt notifications while we update the unaggregated
137+
// fees.
138+
let _guard = self.unaggregated_receipts_guard.lock().await;
139+
136140
self.tap_manager.remove_obsolete_receipts().await?;
137141

138142
let signers = signers_trimmed(&self.escrow_accounts, self.sender).await?;
@@ -176,19 +180,19 @@ impl SenderAllocation {
176180
.fetch_one(&self.pgpool)
177181
.await?;
178182

179-
let mut unaggregated_fees = self.unaggregated_fees.lock().unwrap();
180-
181183
ensure!(
182184
res.sum.is_none() == res.max.is_none(),
183185
"Exactly one of SUM(value) and MAX(id) is null. This should not happen."
184186
);
185187

186-
unaggregated_fees.last_id = res.max.unwrap_or(0).try_into()?;
187-
unaggregated_fees.value = res
188-
.sum
189-
.unwrap_or(BigDecimal::from(0))
190-
.to_string()
191-
.parse::<u128>()?;
188+
*self.unaggregated_fees.lock().unwrap() = UnaggregatedReceipts {
189+
last_id: res.max.unwrap_or(0).try_into()?,
190+
value: res
191+
.sum
192+
.unwrap_or(BigDecimal::from(0))
193+
.to_string()
194+
.parse::<u128>()?,
195+
};
192196

193197
// TODO: check if we need to run a RAV request here.
194198

@@ -360,22 +364,40 @@ impl SenderAllocation {
360364
Ok(())
361365
}
362366

363-
/// Safe add the fees to the unaggregated fees value, log an error if there is an overflow and
364-
/// set the unaggregated fees value to u128::MAX.
365-
pub fn fees_add(&self, fees: u128) {
367+
/// Safe add the fees to the unaggregated fees value if the receipt_id is greater than the
368+
/// last_id. If the addition would overflow u128, log an error and set the unaggregated fees
369+
/// value to u128::MAX.
370+
///
371+
/// Returns true if the fees were added, false otherwise.
372+
pub async fn fees_add(&self, fees: u128, receipt_id: u64) -> bool {
373+
// Make sure to pause the handling of receipt notifications while we update the unaggregated
374+
// fees.
375+
let _guard = self.unaggregated_receipts_guard.lock().await;
376+
377+
let mut fees_added = false;
366378
let mut unaggregated_fees = self.unaggregated_fees.lock().unwrap();
367-
unaggregated_fees.value = unaggregated_fees
368-
.value
369-
.checked_add(fees)
370-
.unwrap_or_else(|| {
371-
// This should never happen, but if it does, we want to know about it.
372-
error!(
373-
"Overflow when adding receipt value {} to total unaggregated fees {} for \
374-
allocation {} and sender {}. Setting total unaggregated fees to u128::MAX.",
375-
fees, unaggregated_fees.value, self.allocation_id, self.sender
376-
);
377-
u128::MAX
378-
});
379+
380+
if receipt_id > unaggregated_fees.last_id {
381+
*unaggregated_fees = UnaggregatedReceipts {
382+
last_id: receipt_id,
383+
value: unaggregated_fees
384+
.value
385+
.checked_add(fees)
386+
.unwrap_or_else(|| {
387+
// This should never happen, but if it does, we want to know about it.
388+
error!(
389+
"Overflow when adding receipt value {} to total unaggregated fees {} \
390+
for allocation {} and sender {}. Setting total unaggregated fees to \
391+
u128::MAX.",
392+
fees, unaggregated_fees.value, self.allocation_id, self.sender
393+
);
394+
u128::MAX
395+
}),
396+
};
397+
fees_added = true;
398+
}
399+
400+
fees_added
379401
}
380402

381403
pub fn get_unaggregated_fees(&self) -> UnaggregatedReceipts {

0 commit comments

Comments
 (0)