Skip to content

Commit c382222

Browse files
committed
device: remove nodes by peer in O(1) instead of O(n)
Now that we have parent pointers hooked up, we can simply go right to the node and remove it in place, rather than having to recursively walk the entire trie. Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent b41f4cc commit c382222

File tree

2 files changed

+82
-72
lines changed

2 files changed

+82
-72
lines changed

device/allowedips.go

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() {
8585
}
8686
}
8787

88-
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
89-
if node == nil {
90-
return node
91-
}
92-
93-
// walk recursively
94-
95-
node.child[0] = node.child[0].removeByPeer(p)
96-
node.child[1] = node.child[1].removeByPeer(p)
97-
98-
if node.peer != p {
99-
return node
100-
}
101-
102-
// remove peer & merge
103-
104-
node.removeFromPeerEntries()
105-
node.peer = nil
106-
if node.child[0] == nil {
107-
return node.child[1]
108-
}
109-
return node.child[0]
110-
}
111-
11288
func (node *trieEntry) choose(ip net.IP) byte {
11389
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
11490
}
@@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
261237
table.mutex.Lock()
262238
defer table.mutex.Unlock()
263239

264-
table.IPv4 = table.IPv4.removeByPeer(peer)
265-
table.IPv6 = table.IPv6.removeByPeer(peer)
240+
var next *list.Element
241+
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
242+
next = elem.Next()
243+
node := elem.Value.(*trieEntry)
244+
245+
node.removeFromPeerEntries()
246+
node.peer = nil
247+
if node.child[0] != nil && node.child[1] != nil {
248+
continue
249+
}
250+
bit := 0
251+
if node.child[0] == nil {
252+
bit = 1
253+
}
254+
child := node.child[bit]
255+
if child != nil {
256+
child.parent = node.parent
257+
}
258+
*node.parent.parentBit = child
259+
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
260+
continue
261+
}
262+
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
263+
if parent.peer != nil {
264+
continue
265+
}
266+
child = parent.child[node.parent.parentBitType^1]
267+
if child != nil {
268+
child.parent = parent.parent
269+
}
270+
*parent.parent.parentBit = child
271+
}
266272
}
267273

268274
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {

device/allowedips_rand_test.go

Lines changed: 50 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package device
77

88
import (
99
"math/rand"
10+
"net"
1011
"sort"
1112
"testing"
1213
)
@@ -64,68 +65,71 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
6465
return nil
6566
}
6667

67-
func TestTrieRandomIPv4(t *testing.T) {
68-
var slow SlowRouter
69-
var peers []*Peer
70-
var allowedIPs AllowedIPs
71-
72-
rand.Seed(1)
73-
74-
const AddressLength = 4
75-
76-
for n := 0; n < NumberOfPeers; n++ {
77-
peers = append(peers, &Peer{})
78-
}
79-
80-
for n := 0; n < NumberOfAddresses; n++ {
81-
var addr [AddressLength]byte
82-
rand.Read(addr[:])
83-
cidr := uint8(rand.Uint32() % (AddressLength * 8))
84-
index := rand.Int() % NumberOfPeers
85-
allowedIPs.Insert(addr[:], cidr, peers[index])
86-
slow = slow.Insert(addr[:], cidr, peers[index])
87-
}
88-
89-
for n := 0; n < NumberOfTests; n++ {
90-
var addr [AddressLength]byte
91-
rand.Read(addr[:])
92-
peer1 := slow.Lookup(addr[:])
93-
peer2 := allowedIPs.LookupIPv4(addr[:])
94-
if peer1 != peer2 {
95-
t.Error("Trie did not match naive implementation, for:", addr)
68+
func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
69+
n := 0
70+
for _, x := range r {
71+
if x.peer != peer {
72+
r[n] = x
73+
n++
9674
}
9775
}
76+
return r[:n]
9877
}
9978

100-
func TestTrieRandomIPv6(t *testing.T) {
101-
var slow SlowRouter
79+
func TestTrieRandom(t *testing.T) {
80+
var slow4, slow6 SlowRouter
10281
var peers []*Peer
10382
var allowedIPs AllowedIPs
10483

10584
rand.Seed(1)
10685

107-
const AddressLength = 16
108-
10986
for n := 0; n < NumberOfPeers; n++ {
11087
peers = append(peers, &Peer{})
11188
}
11289

11390
for n := 0; n < NumberOfAddresses; n++ {
114-
var addr [AddressLength]byte
115-
rand.Read(addr[:])
116-
cidr := uint8(rand.Uint32() % (AddressLength * 8))
117-
index := rand.Int() % NumberOfPeers
118-
allowedIPs.Insert(addr[:], cidr, peers[index])
119-
slow = slow.Insert(addr[:], cidr, peers[index])
91+
var addr4 [4]byte
92+
rand.Read(addr4[:])
93+
cidr := uint8(rand.Intn(32) + 1)
94+
index := rand.Intn(NumberOfPeers)
95+
allowedIPs.Insert(addr4[:], cidr, peers[index])
96+
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
97+
98+
var addr6 [16]byte
99+
rand.Read(addr6[:])
100+
cidr = uint8(rand.Intn(128) + 1)
101+
index = rand.Intn(NumberOfPeers)
102+
allowedIPs.Insert(addr6[:], cidr, peers[index])
103+
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
120104
}
121105

122-
for n := 0; n < NumberOfTests; n++ {
123-
var addr [AddressLength]byte
124-
rand.Read(addr[:])
125-
peer1 := slow.Lookup(addr[:])
126-
peer2 := allowedIPs.LookupIPv6(addr[:])
127-
if peer1 != peer2 {
128-
t.Error("Trie did not match naive implementation, for:", addr)
106+
for p := 0; ; p++ {
107+
for n := 0; n < NumberOfTests; n++ {
108+
var addr4 [4]byte
109+
rand.Read(addr4[:])
110+
peer1 := slow4.Lookup(addr4[:])
111+
peer2 := allowedIPs.LookupIPv4(addr4[:])
112+
if peer1 != peer2 {
113+
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
114+
}
115+
116+
var addr6 [16]byte
117+
rand.Read(addr6[:])
118+
peer1 = slow6.Lookup(addr6[:])
119+
peer2 = allowedIPs.LookupIPv6(addr6[:])
120+
if peer1 != peer2 {
121+
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
122+
}
123+
}
124+
if p >= len(peers) {
125+
break
129126
}
127+
allowedIPs.RemoveByPeer(peers[p])
128+
slow4 = slow4.RemoveByPeer(peers[p])
129+
slow6 = slow6.RemoveByPeer(peers[p])
130+
}
131+
132+
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
133+
t.Error("Failed to remove all nodes from trie by peer")
130134
}
131135
}

0 commit comments

Comments
 (0)