@@ -17,18 +17,26 @@ package state
1717import (
1818 "errors"
1919 "fmt"
20+ "sync"
2021
2122 "github.com/blinklabs-io/dingo/database"
23+ "github.com/blinklabs-io/dingo/event"
2224 ochainsync "github.com/blinklabs-io/gouroboros/protocol/chainsync"
2325 ocommon "github.com/blinklabs-io/gouroboros/protocol/common"
2426)
2527
2628var ErrIteratorChainTip = errors .New ("chain iterator is at chain tip" )
2729
2830type ChainIterator struct {
29- ls * LedgerState
30- startPoint ocommon.Point
31- blockNumber uint64
31+ mutex sync.Mutex
32+ ls * LedgerState
33+ startPoint ocommon.Point
34+ blockNumber uint64
35+ chainUpdateSubId event.EventSubscriberId
36+ chainUpdateChan <- chan event.Event
37+ needsRollback bool
38+ rollbackPoint ocommon.Point
39+ waitingChan chan event.Event
3240}
3341
3442type ChainIteratorResult struct {
@@ -42,9 +50,15 @@ func newChainIterator(
4250 startPoint ocommon.Point ,
4351 inclusive bool ,
4452) (* ChainIterator , error ) {
53+ // Subscribe to chain updates
54+ chainUpdateSubId , chainUpdateChan := ls .config .EventBus .Subscribe (
55+ ChainUpdateEventType ,
56+ )
4557 ci := & ChainIterator {
46- ls : ls ,
47- startPoint : startPoint ,
58+ ls : ls ,
59+ startPoint : startPoint ,
60+ chainUpdateSubId : chainUpdateSubId ,
61+ chainUpdateChan : chainUpdateChan ,
4862 }
4963 // Lookup start block in metadata DB if not origin
5064 if startPoint .Slot > 0 || len (startPoint .Hash ) > 0 {
@@ -61,15 +75,67 @@ func newChainIterator(
6175 ci .blockNumber ++
6276 }
6377 }
78+ go ci .handleChainUpdateEvents ()
6479 return ci , nil
6580}
6681
82+ func (ci * ChainIterator ) handleChainUpdateEvents () {
83+ for {
84+ evt , ok := <- ci .chainUpdateChan
85+ if ! ok {
86+ return
87+ }
88+ ci .mutex .Lock ()
89+ switch e := evt .Data .(type ) {
90+ case ChainBlockEvent :
91+ if ci .waitingChan != nil {
92+ // Send event without blocking
93+ select {
94+ case ci .waitingChan <- evt :
95+ default :
96+ }
97+ }
98+ case ChainRollbackEvent :
99+ if ci .blockNumber > 0 {
100+ ci .rollbackPoint = e .Point
101+ ci .needsRollback = true
102+ }
103+ if ci .waitingChan != nil {
104+ // Send event without blocking
105+ select {
106+ case ci .waitingChan <- evt :
107+ default :
108+ }
109+ }
110+ }
111+ ci .mutex .Unlock ()
112+ }
113+ }
114+
67115func (ci * ChainIterator ) Tip () (ochainsync.Tip , error ) {
68116 return ci .ls .chainTip (nil )
69117}
70118
71119func (ci * ChainIterator ) Next (blocking bool ) (* ChainIteratorResult , error ) {
72- ci .ls .RLock ()
120+ ci .mutex .Lock ()
121+ // Check for pending rollback
122+ if ci .needsRollback {
123+ ret := & ChainIteratorResult {}
124+ ret .Point = ci .rollbackPoint
125+ ret .Rollback = true
126+ ci .needsRollback = false
127+ if ci .rollbackPoint .Slot > 0 {
128+ // Lookup block number for rollback point
129+ tmpBlock , err := database .BlockByPoint (ci .ls .db , ci .rollbackPoint )
130+ if err != nil {
131+ ci .mutex .Unlock ()
132+ return nil , err
133+ }
134+ ci .blockNumber = tmpBlock .Number + 1
135+ }
136+ ci .mutex .Unlock ()
137+ return ret , nil
138+ }
73139 ret := & ChainIteratorResult {}
74140 // Lookup next block in metadata DB
75141 tmpBlock , err := database .BlockByNumber (ci .ls .db , ci .blockNumber )
@@ -78,42 +144,31 @@ func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) {
78144 ret .Point = ocommon .NewPoint (tmpBlock .Slot , tmpBlock .Hash )
79145 ret .Block = tmpBlock
80146 ci .blockNumber ++
81- ci .ls . RUnlock ()
147+ ci .mutex . Unlock ()
82148 return ret , nil
83149 }
84150 // Return any actual error
85151 if ! errors .Is (err , database .ErrBlockNotFound ) {
86- ci .ls . RUnlock ()
152+ ci .mutex . Unlock ()
87153 return ret , err
88154 }
89- // Check against current tip to see if it was rolled back
90- tip , err := ci .Tip ()
91- if err != nil {
92- return nil , err
93- }
94- if ci .blockNumber > 0 && ci .blockNumber - 1 > tip .BlockNumber {
95- ret .Point = tip .Point
96- ret .Rollback = true
97- ci .blockNumber = tip .BlockNumber + 1
98- ci .ls .RUnlock ()
99- return ret , nil
100- }
101155 // Return immediately if we're not blocking
102156 if ! blocking {
103- ci .ls . RUnlock ()
157+ ci .mutex . Unlock ()
104158 return nil , ErrIteratorChainTip
105159 }
106- // Wait for new block or a rollback
107- chainUpdateSubId , chainUpdateChan := ci .ls .config .EventBus .Subscribe (
108- ChainUpdateEventType ,
109- )
160+ // Wait for chain update
161+ ci .waitingChan = make (chan event.Event , 1 )
110162 // Release read lock while we wait for new event
111- ci .ls . RUnlock ()
112- evt , ok := <- chainUpdateChan
163+ ci .mutex . Unlock ()
164+ evt , ok := <- ci . waitingChan
113165 if ! ok {
114166 // TODO: return an actual error (#389)
115167 return nil , nil
116168 }
169+ ci .mutex .Lock ()
170+ defer ci .mutex .Unlock ()
171+ ci .waitingChan = nil
117172 switch e := evt .Data .(type ) {
118173 case ChainBlockEvent :
119174 ret .Point = e .Point
@@ -122,6 +177,7 @@ func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) {
122177 case ChainRollbackEvent :
123178 ret .Point = e .Point
124179 ret .Rollback = true
180+ ci .needsRollback = false
125181 if e .Point .Slot > 0 {
126182 // Lookup block number for rollback point
127183 tmpBlock , err := database .BlockByPoint (ci .ls .db , e .Point )
@@ -133,6 +189,5 @@ func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) {
133189 default :
134190 return nil , fmt .Errorf ("unexpected event type %T" , e )
135191 }
136- ci .ls .config .EventBus .Unsubscribe (ChainUpdateEventType , chainUpdateSubId )
137192 return ret , nil
138193}
0 commit comments