1
1
// Copyright 2023-, GraphOps and Semiotic Labs.
2
2
// SPDX-License-Identifier: Apache-2.0
3
3
4
+ use std:: sync:: Mutex as StdMutex ;
4
5
use std:: {
5
6
cmp:: max,
6
7
collections:: { HashMap , HashSet } ,
7
- sync:: { Arc , Mutex } ,
8
+ sync:: Arc ,
8
9
time:: Duration ,
9
10
} ;
10
11
@@ -15,6 +16,7 @@ use eventuals::Eventual;
15
16
use indexer_common:: { escrow_accounts:: EscrowAccounts , prelude:: SubgraphClient } ;
16
17
use sqlx:: PgPool ;
17
18
use thegraph:: types:: Address ;
19
+ use tokio:: sync:: Mutex as TokioMutex ;
18
20
use tokio:: { select, sync:: Notify , time} ;
19
21
use tracing:: { error, warn} ;
20
22
@@ -34,10 +36,11 @@ enum AllocationState {
34
36
pub struct Inner {
35
37
config : & ' static config:: Cli ,
36
38
pgpool : PgPool ,
37
- allocations : Arc < Mutex < HashMap < Address , AllocationState > > > ,
39
+ allocations : Arc < StdMutex < HashMap < Address , AllocationState > > > ,
38
40
sender : Address ,
39
41
sender_aggregator_endpoint : String ,
40
- unaggregated_fees : Arc < Mutex < UnaggregatedReceipts > > ,
42
+ unaggregated_fees : Arc < StdMutex < UnaggregatedReceipts > > ,
43
+ unaggregated_receipts_guard : Arc < TokioMutex < ( ) > > ,
41
44
}
42
45
43
46
impl Inner {
@@ -136,7 +139,7 @@ impl Inner {
136
139
return ;
137
140
} ;
138
141
139
- if let Err ( e) = self . recompute_unaggregated_fees_static ( ) {
142
+ if let Err ( e) = self . recompute_unaggregated_fees ( ) . await {
140
143
error ! (
141
144
"Error while recomputing unaggregated fees for sender {}: {}" ,
142
145
self . sender, e
@@ -171,7 +174,11 @@ impl Inner {
171
174
}
172
175
173
176
/// Recompute the sender's total unaggregated fees value and last receipt ID.
174
- fn recompute_unaggregated_fees_static ( & self ) -> Result < ( ) > {
177
+ async fn recompute_unaggregated_fees ( & self ) -> Result < ( ) > {
178
+ // Make sure to pause the handling of receipt notifications while we update the unaggregated
179
+ // fees.
180
+ let _guard = self . unaggregated_receipts_guard . lock ( ) . await ;
181
+
175
182
// Similar pattern to get_heaviest_allocation().
176
183
let allocations: Vec < _ > = self . allocations . lock ( ) . unwrap ( ) . values ( ) . cloned ( ) . collect ( ) ;
177
184
@@ -236,6 +243,7 @@ pub struct SenderAccount {
236
243
rav_requester_notify : Arc < Notify > ,
237
244
rav_requester_finalize_task : tokio:: task:: JoinHandle < ( ) > ,
238
245
rav_requester_finalize_notify : Arc < Notify > ,
246
+ unaggregated_receipts_guard : Arc < TokioMutex < ( ) > > ,
239
247
}
240
248
241
249
impl SenderAccount {
@@ -250,13 +258,16 @@ impl SenderAccount {
250
258
tap_eip712_domain_separator : Eip712Domain ,
251
259
sender_aggregator_endpoint : String ,
252
260
) -> Self {
261
+ let unaggregated_receipts_guard = Arc :: new ( TokioMutex :: new ( ( ) ) ) ;
262
+
253
263
let inner = Arc :: new ( Inner {
254
264
config,
255
265
pgpool,
256
- allocations : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
266
+ allocations : Arc :: new ( StdMutex :: new ( HashMap :: new ( ) ) ) ,
257
267
sender : sender_id,
258
268
sender_aggregator_endpoint,
259
- unaggregated_fees : Arc :: new ( Mutex :: new ( UnaggregatedReceipts :: default ( ) ) ) ,
269
+ unaggregated_fees : Arc :: new ( StdMutex :: new ( UnaggregatedReceipts :: default ( ) ) ) ,
270
+ unaggregated_receipts_guard : unaggregated_receipts_guard. clone ( ) ,
260
271
} ) ;
261
272
262
273
let rav_requester_notify = Arc :: new ( Notify :: new ( ) ) ;
@@ -289,6 +300,7 @@ impl SenderAccount {
289
300
rav_requester_notify,
290
301
rav_requester_finalize_task,
291
302
rav_requester_finalize_notify,
303
+ unaggregated_receipts_guard,
292
304
}
293
305
}
294
306
@@ -346,21 +358,32 @@ impl SenderAccount {
346
358
& self ,
347
359
new_receipt_notification : NewReceiptNotification ,
348
360
) {
349
- let mut unaggregated_fees = self . inner . unaggregated_fees . lock ( ) . unwrap ( ) ;
350
-
351
- // Else we already processed that receipt, most likely from pulling the receipts
352
- // from the database.
353
- if new_receipt_notification. id > unaggregated_fees. last_id {
354
- if let Some ( AllocationState :: Active ( allocation) ) = self
355
- . inner
356
- . allocations
357
- . lock ( )
358
- . unwrap ( )
359
- . get ( & new_receipt_notification. allocation_id )
361
+ // Make sure to pause the handling of receipt notifications while we update the unaggregated
362
+ // fees.
363
+ let _guard = self . unaggregated_receipts_guard . lock ( ) . await ;
364
+
365
+ let allocation_state = self
366
+ . inner
367
+ . allocations
368
+ . lock ( )
369
+ . unwrap ( )
370
+ . get ( & new_receipt_notification. allocation_id )
371
+ . cloned ( ) ;
372
+
373
+ if let Some ( AllocationState :: Active ( allocation) ) = allocation_state {
374
+ // Try to add the receipt value to the allocation's unaggregated fees value.
375
+ // If the fees were not added, it means the receipt was already processed, so we
376
+ // don't need to do anything.
377
+ if allocation
378
+ . fees_add ( new_receipt_notification. value , new_receipt_notification. id )
379
+ . await
360
380
{
361
381
// Add the receipt value to the allocation's unaggregated fees value.
362
- allocation. fees_add ( new_receipt_notification. value ) ;
382
+ allocation
383
+ . fees_add ( new_receipt_notification. value , new_receipt_notification. id )
384
+ . await ;
363
385
// Add the receipt value to the sender's unaggregated fees value.
386
+ let mut unaggregated_fees = self . inner . unaggregated_fees . lock ( ) . unwrap ( ) ;
364
387
* unaggregated_fees = UnaggregatedReceipts {
365
388
value : self
366
389
. inner
@@ -373,18 +396,18 @@ impl SenderAccount {
373
396
{
374
397
self . rav_requester_notify . notify_waiters ( ) ;
375
398
}
376
- } else {
377
- error ! (
378
- "Received a new receipt notification for allocation {} that doesn't exist \
379
- or is ineligible for sender {}.",
380
- new_receipt_notification. allocation_id, self . inner. sender
381
- ) ;
382
399
}
400
+ } else {
401
+ error ! (
402
+ "Received a new receipt notification for allocation {} that doesn't exist \
403
+ or is ineligible for sender {}.",
404
+ new_receipt_notification. allocation_id, self . inner. sender
405
+ ) ;
383
406
}
384
407
}
385
408
386
- pub fn recompute_unaggregated_fees ( & self ) -> Result < ( ) > {
387
- self . inner . recompute_unaggregated_fees_static ( )
409
+ pub async fn recompute_unaggregated_fees ( & self ) -> Result < ( ) > {
410
+ self . inner . recompute_unaggregated_fees ( ) . await
388
411
}
389
412
}
390
413
@@ -507,7 +530,7 @@ mod tests {
507
530
* ALLOCATION_ID_2 ,
508
531
] ) )
509
532
. await ;
510
- sender. recompute_unaggregated_fees ( ) . unwrap ( ) ;
533
+ sender. recompute_unaggregated_fees ( ) . await . unwrap ( ) ;
511
534
512
535
sender
513
536
}
0 commit comments