Skip to content

Commit 546b9d7

Browse files
authored
refactor: chain iterator rollback event handling (#571)
We watch for chain update events when starting a chain iterator to make sure that we don't miss a rollback event Signed-off-by: Aurora Gaffney <[email protected]>
1 parent 84d930a commit 546b9d7

File tree

1 file changed

+83
-28
lines changed

1 file changed

+83
-28
lines changed

state/chain.go

Lines changed: 83 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,26 @@ package state
1717
import (
1818
"errors"
1919
"fmt"
20+
"sync"
2021

2122
"github.com/blinklabs-io/dingo/database"
23+
"github.com/blinklabs-io/dingo/event"
2224
ochainsync "github.com/blinklabs-io/gouroboros/protocol/chainsync"
2325
ocommon "github.com/blinklabs-io/gouroboros/protocol/common"
2426
)
2527

2628
var ErrIteratorChainTip = errors.New("chain iterator is at chain tip")
2729

2830
type ChainIterator struct {
29-
ls *LedgerState
30-
startPoint ocommon.Point
31-
blockNumber uint64
31+
mutex sync.Mutex
32+
ls *LedgerState
33+
startPoint ocommon.Point
34+
blockNumber uint64
35+
chainUpdateSubId event.EventSubscriberId
36+
chainUpdateChan <-chan event.Event
37+
needsRollback bool
38+
rollbackPoint ocommon.Point
39+
waitingChan chan event.Event
3240
}
3341

3442
type ChainIteratorResult struct {
@@ -42,9 +50,15 @@ func newChainIterator(
4250
startPoint ocommon.Point,
4351
inclusive bool,
4452
) (*ChainIterator, error) {
53+
// Subscribe to chain updates
54+
chainUpdateSubId, chainUpdateChan := ls.config.EventBus.Subscribe(
55+
ChainUpdateEventType,
56+
)
4557
ci := &ChainIterator{
46-
ls: ls,
47-
startPoint: startPoint,
58+
ls: ls,
59+
startPoint: startPoint,
60+
chainUpdateSubId: chainUpdateSubId,
61+
chainUpdateChan: chainUpdateChan,
4862
}
4963
// Lookup start block in metadata DB if not origin
5064
if startPoint.Slot > 0 || len(startPoint.Hash) > 0 {
@@ -61,15 +75,67 @@ func newChainIterator(
6175
ci.blockNumber++
6276
}
6377
}
78+
go ci.handleChainUpdateEvents()
6479
return ci, nil
6580
}
6681

82+
func (ci *ChainIterator) handleChainUpdateEvents() {
83+
for {
84+
evt, ok := <-ci.chainUpdateChan
85+
if !ok {
86+
return
87+
}
88+
ci.mutex.Lock()
89+
switch e := evt.Data.(type) {
90+
case ChainBlockEvent:
91+
if ci.waitingChan != nil {
92+
// Send event without blocking
93+
select {
94+
case ci.waitingChan <- evt:
95+
default:
96+
}
97+
}
98+
case ChainRollbackEvent:
99+
if ci.blockNumber > 0 {
100+
ci.rollbackPoint = e.Point
101+
ci.needsRollback = true
102+
}
103+
if ci.waitingChan != nil {
104+
// Send event without blocking
105+
select {
106+
case ci.waitingChan <- evt:
107+
default:
108+
}
109+
}
110+
}
111+
ci.mutex.Unlock()
112+
}
113+
}
114+
67115
func (ci *ChainIterator) Tip() (ochainsync.Tip, error) {
68116
return ci.ls.chainTip(nil)
69117
}
70118

71119
func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) {
72-
ci.ls.RLock()
120+
ci.mutex.Lock()
121+
// Check for pending rollback
122+
if ci.needsRollback {
123+
ret := &ChainIteratorResult{}
124+
ret.Point = ci.rollbackPoint
125+
ret.Rollback = true
126+
ci.needsRollback = false
127+
if ci.rollbackPoint.Slot > 0 {
128+
// Lookup block number for rollback point
129+
tmpBlock, err := database.BlockByPoint(ci.ls.db, ci.rollbackPoint)
130+
if err != nil {
131+
ci.mutex.Unlock()
132+
return nil, err
133+
}
134+
ci.blockNumber = tmpBlock.Number + 1
135+
}
136+
ci.mutex.Unlock()
137+
return ret, nil
138+
}
73139
ret := &ChainIteratorResult{}
74140
// Lookup next block in metadata DB
75141
tmpBlock, err := database.BlockByNumber(ci.ls.db, ci.blockNumber)
@@ -78,42 +144,31 @@ func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) {
78144
ret.Point = ocommon.NewPoint(tmpBlock.Slot, tmpBlock.Hash)
79145
ret.Block = tmpBlock
80146
ci.blockNumber++
81-
ci.ls.RUnlock()
147+
ci.mutex.Unlock()
82148
return ret, nil
83149
}
84150
// Return any actual error
85151
if !errors.Is(err, database.ErrBlockNotFound) {
86-
ci.ls.RUnlock()
152+
ci.mutex.Unlock()
87153
return ret, err
88154
}
89-
// Check against current tip to see if it was rolled back
90-
tip, err := ci.Tip()
91-
if err != nil {
92-
return nil, err
93-
}
94-
if ci.blockNumber > 0 && ci.blockNumber-1 > tip.BlockNumber {
95-
ret.Point = tip.Point
96-
ret.Rollback = true
97-
ci.blockNumber = tip.BlockNumber + 1
98-
ci.ls.RUnlock()
99-
return ret, nil
100-
}
101155
// Return immediately if we're not blocking
102156
if !blocking {
103-
ci.ls.RUnlock()
157+
ci.mutex.Unlock()
104158
return nil, ErrIteratorChainTip
105159
}
106-
// Wait for new block or a rollback
107-
chainUpdateSubId, chainUpdateChan := ci.ls.config.EventBus.Subscribe(
108-
ChainUpdateEventType,
109-
)
160+
// Wait for chain update
161+
ci.waitingChan = make(chan event.Event, 1)
110162
// Release read lock while we wait for new event
111-
ci.ls.RUnlock()
112-
evt, ok := <-chainUpdateChan
163+
ci.mutex.Unlock()
164+
evt, ok := <-ci.waitingChan
113165
if !ok {
114166
// TODO: return an actual error (#389)
115167
return nil, nil
116168
}
169+
ci.mutex.Lock()
170+
defer ci.mutex.Unlock()
171+
ci.waitingChan = nil
117172
switch e := evt.Data.(type) {
118173
case ChainBlockEvent:
119174
ret.Point = e.Point
@@ -122,6 +177,7 @@ func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) {
122177
case ChainRollbackEvent:
123178
ret.Point = e.Point
124179
ret.Rollback = true
180+
ci.needsRollback = false
125181
if e.Point.Slot > 0 {
126182
// Lookup block number for rollback point
127183
tmpBlock, err := database.BlockByPoint(ci.ls.db, e.Point)
@@ -133,6 +189,5 @@ func (ci *ChainIterator) Next(blocking bool) (*ChainIteratorResult, error) {
133189
default:
134190
return nil, fmt.Errorf("unexpected event type %T", e)
135191
}
136-
ci.ls.config.EventBus.Unsubscribe(ChainUpdateEventType, chainUpdateSubId)
137192
return ret, nil
138193
}

0 commit comments

Comments
 (0)