11use crate :: collections:: { HashMap , HashSet , VecDeque } ;
2- use crate :: tx_graph:: { TxAncestors , TxDescendants } ;
2+ use crate :: tx_graph:: { TxAncestors , TxDescendants , TxNode } ;
33use crate :: { Anchor , ChainOracle , TxGraph } ;
44use alloc:: boxed:: Box ;
55use alloc:: collections:: BTreeSet ;
@@ -36,6 +36,9 @@ pub struct CanonicalIter<'g, A, C> {
3636 canonical : CanonicalMap < A > ,
3737 not_canonical : NotCanonicalSet ,
3838
39+ canonical_ancestors : HashMap < Txid , Vec < Txid > > ,
40+ canonical_roots : VecDeque < Txid > ,
41+
3942 queue : VecDeque < Txid > ,
4043}
4144
@@ -75,6 +78,8 @@ impl<'g, A: Anchor, C: ChainOracle> CanonicalIter<'g, A, C> {
7578 unprocessed_leftover_txs : VecDeque :: new ( ) ,
7679 canonical : HashMap :: new ( ) ,
7780 not_canonical : HashSet :: new ( ) ,
81+ canonical_ancestors : HashMap :: new ( ) ,
82+ canonical_roots : VecDeque :: new ( ) ,
7883 queue : VecDeque :: new ( ) ,
7984 }
8085 }
@@ -160,7 +165,7 @@ impl<'g, A: Anchor, C: ChainOracle> CanonicalIter<'g, A, C> {
160165
161166 // Any conflicts with a canonical tx can be added to `not_canonical`. Descendants
162167 // of `not_canonical` txs can also be added to `not_canonical`.
163- for ( _, conflict_txid) in self . tx_graph . direct_conflicts ( & tx) {
168+ for ( _, conflict_txid) in self . tx_graph . direct_conflicts ( & tx. clone ( ) ) {
164169 TxDescendants :: new_include_root (
165170 self . tx_graph ,
166171 conflict_txid,
@@ -181,6 +186,18 @@ impl<'g, A: Anchor, C: ChainOracle> CanonicalIter<'g, A, C> {
181186 detected_self_double_spend = true ;
182187 return None ;
183188 }
189+
190+ // Calculates all the existing ancestors for the given Txid
191+ self . canonical_ancestors . insert (
192+ this_txid,
193+ tx. clone ( )
194+ . input
195+ . iter ( )
196+ . filter ( |txin| self . tx_graph . get_tx ( txin. previous_output . txid ) . is_some ( ) )
197+ . map ( |txin| txin. previous_output . txid )
198+ . collect ( ) ,
199+ ) ;
200+
184201 canonical_entry. insert ( ( tx, this_reason) ) ;
185202 Some ( this_txid)
186203 } ,
@@ -190,12 +207,29 @@ impl<'g, A: Anchor, C: ChainOracle> CanonicalIter<'g, A, C> {
190207 if detected_self_double_spend {
191208 for txid in staged_queue {
192209 self . canonical . remove ( & txid) ;
210+ self . canonical_ancestors . remove ( & txid) ;
193211 }
194212 for txid in undo_not_canonical {
195213 self . not_canonical . remove ( & txid) ;
196214 }
197215 } else {
198- self . queue . extend ( staged_queue) ;
216+ // TODO: (@oleonardolima) Can this be optimized somehow ?
217+ // Can we just do a simple lookup on the `canonical_ancestors` field ?
218+ for txid in staged_queue {
219+ let tx = self . tx_graph . get_tx ( txid) . expect ( "tx must exist" ) ;
220+ let ancestors = tx
221+ . input
222+ . iter ( )
223+ . map ( |txin| txin. previous_output . txid )
224+ . filter_map ( |prev_txid| self . tx_graph . get_tx ( prev_txid) )
225+ . collect :: < Vec < _ > > ( ) ;
226+
227+ // check if it's a root: it's either a coinbase transaction or has not known
228+ // ancestors in the tx_graph
229+ if tx. is_coinbase ( ) || ancestors. is_empty ( ) {
230+ self . canonical_roots . push_back ( txid) ;
231+ }
232+ }
199233 }
200234 }
201235}
@@ -204,52 +238,58 @@ impl<A: Anchor, C: ChainOracle> Iterator for CanonicalIter<'_, A, C> {
204238 type Item = Result < ( Txid , Arc < Transaction > , CanonicalReason < A > ) , C :: Error > ;
205239
206240 fn next ( & mut self ) -> Option < Self :: Item > {
207- loop {
208- if let Some ( txid) = self . queue . pop_front ( ) {
209- let ( tx, reason) = self
210- . canonical
211- . get ( & txid)
212- . cloned ( )
213- . expect ( "reason must exist" ) ;
214- return Some ( Ok ( ( txid, tx, reason) ) ) ;
241+ while let Some ( ( txid, tx) ) = self . unprocessed_assumed_txs . next ( ) {
242+ if !self . is_canonicalized ( txid) {
243+ self . mark_canonical ( txid, tx, CanonicalReason :: assumed ( ) ) ;
215244 }
245+ }
216246
217- if let Some ( ( txid, tx) ) = self . unprocessed_assumed_txs . next ( ) {
218- if !self . is_canonicalized ( txid) {
219- self . mark_canonical ( txid, tx, CanonicalReason :: assumed ( ) ) ;
247+ while let Some ( ( txid, tx, anchors) ) = self . unprocessed_anchored_txs . next ( ) {
248+ if !self . is_canonicalized ( txid) {
249+ if let Err ( err) = self . scan_anchors ( txid, tx, anchors) {
250+ return Some ( Err ( err) ) ;
220251 }
221252 }
253+ }
222254
223- if let Some ( ( txid, tx, anchors) ) = self . unprocessed_anchored_txs . next ( ) {
224- if !self . is_canonicalized ( txid) {
225- if let Err ( err) = self . scan_anchors ( txid, tx, anchors) {
226- return Some ( Err ( err) ) ;
227- }
228- }
229- continue ;
255+ while let Some ( ( txid, tx, last_seen) ) = self . unprocessed_seen_txs . next ( ) {
256+ debug_assert ! (
257+ !tx. is_coinbase( ) ,
258+ "Coinbase txs must not have `last_seen` (in mempool) value"
259+ ) ;
260+ if !self . is_canonicalized ( txid) {
261+ let observed_in = ObservedIn :: Mempool ( last_seen) ;
262+ self . mark_canonical ( txid, tx, CanonicalReason :: from_observed_in ( observed_in) ) ;
230263 }
264+ }
231265
232- if let Some ( ( txid, tx, last_seen) ) = self . unprocessed_seen_txs . next ( ) {
233- debug_assert ! (
234- !tx. is_coinbase( ) ,
235- "Coinbase txs must not have `last_seen` (in mempool) value"
236- ) ;
237- if !self . is_canonicalized ( txid) {
238- let observed_in = ObservedIn :: Mempool ( last_seen) ;
239- self . mark_canonical ( txid, tx, CanonicalReason :: from_observed_in ( observed_in) ) ;
240- }
241- continue ;
266+ while let Some ( ( txid, tx, height) ) = self . unprocessed_leftover_txs . pop_front ( ) {
267+ if !self . is_canonicalized ( txid) && !tx. is_coinbase ( ) {
268+ let observed_in = ObservedIn :: Block ( height) ;
269+ self . mark_canonical ( txid, tx, CanonicalReason :: from_observed_in ( observed_in) ) ;
242270 }
271+ }
243272
244- if let Some ( ( txid, tx, height) ) = self . unprocessed_leftover_txs . pop_front ( ) {
245- if !self . is_canonicalized ( txid) && !tx. is_coinbase ( ) {
246- let observed_in = ObservedIn :: Block ( height) ;
247- self . mark_canonical ( txid, tx, CanonicalReason :: from_observed_in ( observed_in) ) ;
248- }
249- continue ;
250- }
273+ if !self . canonical_roots . is_empty ( ) {
274+ let topological_iter = TopologicalIteratorWithLevels :: new (
275+ self . tx_graph ,
276+ self . chain ,
277+ self . chain_tip ,
278+ & self . canonical_ancestors ,
279+ self . canonical_roots . drain ( ..) . collect ( ) ,
280+ ) ;
281+ self . queue . extend ( topological_iter) ;
282+ }
251283
252- return None ;
284+ if let Some ( txid) = self . queue . pop_front ( ) {
285+ let ( tx, reason) = self
286+ . canonical
287+ . get ( & txid)
288+ . cloned ( )
289+ . expect ( "canonical reason must exist" ) ;
290+ Some ( Ok ( ( txid, tx, reason) ) )
291+ } else {
292+ None
253293 }
254294 }
255295}
@@ -342,3 +382,129 @@ impl<A: Clone> CanonicalReason<A> {
342382 }
343383 }
344384}
385+
386+ struct TopologicalIteratorWithLevels < ' a , A , C > {
387+ tx_graph : & ' a TxGraph < A > ,
388+ chain : & ' a C ,
389+ chain_tip : BlockId ,
390+
391+ current_level : Vec < Txid > ,
392+ next_level : Vec < Txid > ,
393+
394+ adj_list : HashMap < Txid , Vec < Txid > > ,
395+ parent_count : HashMap < Txid , usize > ,
396+
397+ current_index : usize ,
398+ }
399+
400+ impl < ' a , A : Anchor , C : ChainOracle > TopologicalIteratorWithLevels < ' a , A , C > {
401+ fn new (
402+ tx_graph : & ' a TxGraph < A > ,
403+ chain : & ' a C ,
404+ chain_tip : BlockId ,
405+ ancestors_by_txid : & HashMap < Txid , Vec < Txid > > ,
406+ roots : Vec < Txid > ,
407+ ) -> Self {
408+ let mut parent_count = HashMap :: new ( ) ;
409+ let mut adj_list: HashMap < Txid , Vec < Txid > > = HashMap :: new ( ) ;
410+
411+ for ( txid, ancestors) in ancestors_by_txid {
412+ for ancestor in ancestors {
413+ adj_list. entry ( * ancestor) . or_default ( ) . push ( * txid) ;
414+ * parent_count. entry ( * txid) . or_insert ( 0 ) += 1 ;
415+ }
416+ }
417+
418+ let mut current_level: Vec < Txid > = roots. to_vec ( ) ;
419+
420+ // Sort the initial level by confirmation height
421+ current_level. sort_by_key ( |& txid| {
422+ let tx_node = tx_graph. get_tx_node ( txid) . expect ( "tx should exist" ) ;
423+ Self :: find_direct_anchor ( & tx_node, chain, chain_tip)
424+ . expect ( "should not fail" )
425+ . map ( |anchor| anchor. confirmation_height_upper_bound ( ) )
426+ . unwrap_or ( u32:: MAX )
427+ } ) ;
428+
429+ Self {
430+ current_level,
431+ next_level : Vec :: new ( ) ,
432+ adj_list,
433+ parent_count,
434+ current_index : 0 ,
435+ tx_graph,
436+ chain,
437+ chain_tip,
438+ }
439+ }
440+
441+ fn find_direct_anchor (
442+ tx_node : & TxNode < ' _ , Arc < Transaction > , A > ,
443+ chain : & C ,
444+ chain_tip : BlockId ,
445+ ) -> Result < Option < A > , C :: Error > {
446+ tx_node
447+ . anchors
448+ . iter ( )
449+ . find_map ( |a| -> Option < Result < A , C :: Error > > {
450+ match chain. is_block_in_chain ( a. anchor_block ( ) , chain_tip) {
451+ Ok ( Some ( true ) ) => Some ( Ok ( a. clone ( ) ) ) ,
452+ Ok ( Some ( false ) ) | Ok ( None ) => None ,
453+ Err ( err) => Some ( Err ( err) ) ,
454+ }
455+ } )
456+ . transpose ( )
457+ }
458+
459+ fn advance_to_next_level ( & mut self ) {
460+ self . current_level = core:: mem:: take ( & mut self . next_level ) ;
461+
462+ // Sort by confirmation height
463+ self . current_level . sort_by_key ( |& txid| {
464+ let tx_node = self . tx_graph . get_tx_node ( txid) . expect ( "tx should exist" ) ;
465+
466+ Self :: find_direct_anchor ( & tx_node, self . chain , self . chain_tip )
467+ . expect ( "should not fail" )
468+ . map ( |anchor| anchor. confirmation_height_upper_bound ( ) )
469+ . unwrap_or ( u32:: MAX )
470+ } ) ;
471+
472+ self . current_index = 0 ;
473+ }
474+ }
475+
476+ impl < ' a , A : Anchor , C : ChainOracle > Iterator for TopologicalIteratorWithLevels < ' a , A , C > {
477+ type Item = Txid ;
478+
479+ fn next ( & mut self ) -> Option < Self :: Item > {
480+ // If we've exhausted the current level, move to next
481+ if self . current_index >= self . current_level . len ( ) {
482+ if self . next_level . is_empty ( ) {
483+ return None ;
484+ }
485+ self . advance_to_next_level ( ) ;
486+ }
487+
488+ let current = self . current_level [ self . current_index ] ;
489+ self . current_index += 1 ;
490+
491+ // If this is the last item in current level, prepare dependents for next level
492+ if self . current_index == self . current_level . len ( ) {
493+ // Process all dependents of all transactions in current level
494+ for & tx in & self . current_level {
495+ if let Some ( dependents) = self . adj_list . get ( & tx) {
496+ for & dependent in dependents {
497+ if let Some ( degree) = self . parent_count . get_mut ( & dependent) {
498+ * degree -= 1 ;
499+ if * degree == 0 {
500+ self . next_level . push ( dependent) ;
501+ }
502+ }
503+ }
504+ }
505+ }
506+ }
507+
508+ Some ( current)
509+ }
510+ }
0 commit comments