Skip to content

Commit

Permalink
Rehandshake
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus committed Apr 14, 2023
1 parent 397fe5f commit a86fd5e
Show file tree
Hide file tree
Showing 14 changed files with 671 additions and 157 deletions.
242 changes: 219 additions & 23 deletions connection_manager.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,39 @@
package nebula

import (
"bytes"
"context"
"sync"
"time"

"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)

type trafficDecision int

const (
doNothing trafficDecision = 0
deleteTunnel trafficDecision = 1 // delete the hostinfo on our side, do not notify the remote
closeTunnel trafficDecision = 2 // delete the hostinfo and notify the remote
swapPrimary trafficDecision = 3
migrateRelays trafficDecision = 4
)

type connectionManager struct {
in map[uint32]struct{}
inLock *sync.RWMutex

out map[uint32]struct{}
outLock *sync.RWMutex

// relayUsed holds which relay localIndexs are in use
relayUsed map[uint32]struct{}
relayUsedLock *sync.RWMutex

hostMap *HostMap
trafficTimer *LockingTimerWheel[uint32]
intf *Interface
Expand All @@ -44,6 +60,8 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
inLock: &sync.RWMutex{},
out: make(map[uint32]struct{}),
outLock: &sync.RWMutex{},
relayUsed: make(map[uint32]struct{}),
relayUsedLock: &sync.RWMutex{},
trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max),
intf: intf,
pendingDeletion: make(map[uint32]struct{}),
Expand Down Expand Up @@ -84,6 +102,19 @@ func (n *connectionManager) Out(localIndex uint32) {
n.outLock.Unlock()
}

func (n *connectionManager) RelayUsed(localIndex uint32) {
n.relayUsedLock.RLock()
// If this already exists, return
if _, ok := n.relayUsed[localIndex]; ok {
n.relayUsedLock.RUnlock()
return
}
n.relayUsedLock.RUnlock()
n.relayUsedLock.Lock()
n.relayUsed[localIndex] = struct{}{}
n.relayUsedLock.Unlock()
}

// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
// resets the state for this local index
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
Expand Down Expand Up @@ -136,18 +167,130 @@ func (n *connectionManager) Run(ctx context.Context) {
}

func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
hostinfo, err := n.hostMap.QueryIndex(localIndex)
if err != nil {
decision, hostinfo, primary := n.makeTrafficDecision(localIndex, p, nb, out, now)

switch decision {
case deleteTunnel:
n.hostMap.DeleteHostInfo(hostinfo)

case closeTunnel:
n.intf.sendCloseTunnel(hostinfo)
n.intf.closeTunnel(hostinfo)

case swapPrimary:
n.swapPrimary(hostinfo, primary)

case migrateRelays:
n.migrateRelayUsed(hostinfo, primary)
}

n.resetRelayTrafficCheck(hostinfo)
}

func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
if hostinfo != nil {
n.relayUsedLock.Lock()
defer n.relayUsedLock.Unlock()
// No need to migrate any relays, delete usage info now.
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
delete(n.relayUsed, idx)
}
}
}

func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
relayFor := oldhostinfo.relayState.CopyAllRelayFor()

for _, r := range relayFor {
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)

var index uint32
var relayFrom iputil.VpnIp
var relayTo iputil.VpnIp
switch {
case ok && existing.State == Established:
// This relay already exists in newhostinfo, then do nothing.
continue
case ok && existing.State == Requested:
// The relay exists in a Requested state; re-send the request
index = existing.LocalIndex
switch r.Type {
case TerminalType:
relayFrom = newhostinfo.vpnIp
relayTo = existing.PeerIp
case ForwardingType:
relayFrom = existing.PeerIp
relayTo = newhostinfo.vpnIp
default:
// should never happen
}
case !ok:
n.relayUsedLock.RLock()
if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
// The relay hasn't been used; don't migrate it.
n.relayUsedLock.RUnlock()
continue
}
n.relayUsedLock.RUnlock()
// The relay doesn't exist at all; create some relay state and send the request.
var err error
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
if err != nil {
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
continue
}
switch r.Type {
case TerminalType:
relayFrom = newhostinfo.vpnIp
relayTo = r.PeerIp
case ForwardingType:
relayFrom = r.PeerIp
relayTo = newhostinfo.vpnIp
default:
// should never happen
}
}

// Send a CreateRelayRequest to the peer.
req := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index,
RelayFromIp: uint32(relayFrom),
RelayToIp: uint32(relayTo),
}
msg, err := req.Marshal()
if err != nil {
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
} else {
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{
"relayFrom": iputil.VpnIp(req.RelayFromIp),
"relayTo": iputil.VpnIp(req.RelayToIp),
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": newhostinfo.vpnIp}).
Info("send CreateRelayRequest")
}
}
}

func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []byte, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
n.hostMap.RLock()
defer n.hostMap.RUnlock()

hostinfo := n.hostMap.Indexes[localIndex]
if hostinfo == nil {
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
delete(n.pendingDeletion, localIndex)
return
return doNothing, nil, nil
}

if n.handleInvalidCertificate(now, hostinfo) {
return
if n.isInvalidCertificate(now, hostinfo) {
delete(n.pendingDeletion, hostinfo.localIndexId)
return closeTunnel, hostinfo, nil
}

primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
primary := n.hostMap.Hosts[hostinfo.vpnIp]
mainHostInfo := true
if primary != nil && primary != hostinfo {
mainHostInfo = false
Expand All @@ -158,18 +301,22 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,

// A hostinfo is determined alive if there is incoming traffic
if inTraffic {
decision := doNothing
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
}
delete(n.pendingDeletion, hostinfo.localIndexId)

if !mainHostInfo {
if hostinfo.vpnIp > n.intf.myVpnIp {
// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
// This the primary and prime the old primary hostinfo for testing
n.hostMap.MakePrimary(hostinfo)
if mainHostInfo {
n.tryRehandshake(hostinfo)
} else {
if n.shouldSwapPrimary(hostinfo, primary) {
decision = swapPrimary
} else {
// migrate the relays to the primary, if in use.
decision = migrateRelays
}
}

Expand All @@ -180,7 +327,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
n.sendPunch(hostinfo)
}

return
return decision, hostinfo, primary
}

if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
Expand All @@ -189,9 +336,8 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
Info("Tunnel status")

n.hostMap.DeleteHostInfo(hostinfo)
delete(n.pendingDeletion, hostinfo.localIndexId)
return
return deleteTunnel, hostinfo, nil
}

hostinfo.logger(n.l).
Expand All @@ -204,7 +350,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
// Just maintain NAT state if configured to do so.
n.sendPunch(hostinfo)
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
return
return doNothing, nil, nil

}

Expand All @@ -218,22 +364,50 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
// We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
return
return doNothing, nil, nil
}

// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)

} else {
hostinfo.logger(n.l).Debugf("Hostinfo sadness")
}

n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
return doNothing, nil, nil
}

// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {

// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out.

if current.vpnIp < n.intf.myVpnIp {
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
// The remotes vpn ip is lower than mine. I will not flip.
return false
}

certState := n.intf.certState.Load()
return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature)
}

func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
n.hostMap.Lock()
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
if n.hostMap.Hosts[current.vpnIp] == primary {
n.hostMap.unlockedMakePrimary(current)
}
n.hostMap.Unlock()
}

// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
// the certificate is no longer valid
func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
if !n.intf.disconnectInvalid {
return false
}
Expand All @@ -253,10 +427,6 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *Ho
WithField("fingerprint", fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")

// Inform the remote and close the tunnel locally
n.intf.sendCloseTunnel(hostinfo)
n.intf.closeTunnel(hostinfo)
delete(n.pendingDeletion, hostinfo.localIndexId)
return true
}

Expand All @@ -277,3 +447,29 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
}

func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.certState.Load()
if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) {
return
}

n.l.WithField("vpnIp", hostinfo.vpnIp).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")

//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
if !newHostinfo.HandshakeReady {
ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
}

//If this is a static host, we don't need to wait for the HostQueryReply
//We can trigger the handshake right now
if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
select {
case n.intf.handshakeManager.trigger <- hostinfo.vpnIp:
default:
}
}
}
8 changes: 4 additions & 4 deletions connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,13 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
// Check if to disconnect with invalid certificate.
// Should be alive.
nextTick := now.Add(45 * time.Second)
destroyed := nc.handleInvalidCertificate(nextTick, hostinfo)
assert.False(t, destroyed)
invalid := nc.isInvalidCertificate(nextTick, hostinfo)
assert.False(t, invalid)

// Move ahead 61s.
// Check if to disconnect with invalid certificate.
// Should be disconnected.
nextTick = now.Add(61 * time.Second)
destroyed = nc.handleInvalidCertificate(nextTick, hostinfo)
assert.True(t, destroyed)
invalid = nc.isInvalidCertificate(nextTick, hostinfo)
assert.True(t, invalid)
}
14 changes: 14 additions & 0 deletions control_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,17 @@ func (c *Control) GetHostmap() *HostMap {
func (c *Control) GetCert() *cert.NebulaCertificate {
return c.f.certState.Load().certificate
}

func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo)
ixHandshakeStage0(c.f, vpnIp, hostinfo)

// If this is a static host, we don't need to wait for the HostQueryReply
// We can trigger the handshake right now
if _, ok := c.f.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
select {
case c.f.handshakeManager.trigger <- hostinfo.vpnIp:
default:
}
}
}
Loading

0 comments on commit a86fd5e

Please sign in to comment.