Skip to content

Commit

Permalink
NATNEG: Optimize connect process
Browse files Browse the repository at this point in the history
  • Loading branch information
mkwcat committed Feb 2, 2024
1 parent 0edfa96 commit 4c98fd2
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 76 deletions.
233 changes: 162 additions & 71 deletions natneg/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,28 @@ const (
)

type NATNEGSession struct {
Version byte
Cookie uint32
Mutex sync.RWMutex
Clients map[byte]*NATNEGClient
}

type NATNEGClient struct {
Cookie uint32
Connected bool
NegotiateIP string
LocalIP string
ServerIP string
GameName string
Cookie uint32
Index byte
ConnectingIndex byte
ConnectAck bool
Connected map[byte]bool
NegotiateIP string
LocalIP string
ServerIP string
GameName string
}

var (
sessions = map[uint32]*NATNEGSession{}
mutex = sync.RWMutex{}
sessions = map[uint32]*NATNEGSession{}
mutex = sync.RWMutex{}
natnegConn net.PacketConn
)

func StartServer() {
Expand All @@ -85,6 +90,8 @@ func StartServer() {
panic(err)
}

natnegConn = conn

// Close the listener when the application closes.
defer conn.Close()
logging.Notice("NATNEG", "Listening on", address)
Expand Down Expand Up @@ -119,105 +126,116 @@ func handleConnection(conn net.PacketConn, addr net.Addr, buffer []byte) {

moduleName := "NATNEG:" + fmt.Sprintf("%08x/", cookie) + addr.String()

mutex.Lock()
session, exists := sessions[cookie]
if !exists {
logging.Info(moduleName, "Creating session")
session = &NATNEGSession{
Cookie: cookie,
Mutex: sync.RWMutex{},
Clients: map[byte]*NATNEGClient{},
var session *NATNEGSession

if command != NNNatifyRequest && command != NNAddressCheckRequest {
mutex.Lock()
var exists bool
session, exists = sessions[cookie]
if !exists {
logging.Info(moduleName, "Creating session")
session = &NATNEGSession{
Version: version,
Cookie: cookie,
Mutex: sync.RWMutex{},
Clients: map[byte]*NATNEGClient{},
}
sessions[cookie] = session

// Session has TTL of 30 seconds
time.AfterFunc(30*time.Second, func() {
mutex.Lock()
delete(sessions, cookie)
mutex.Unlock()

logging.Info(moduleName, "Deleted session")
})
}
sessions[cookie] = session
mutex.Unlock()

// Session has TTL of 30 seconds
time.AfterFunc(30*time.Second, func() {
mutex.Lock()
delete(sessions, cookie)
mutex.Unlock()
if session.Version != version {
logging.Error(moduleName, "Version mismatch")
return
}

logging.Info(moduleName, "Deleted session")
})
session.Mutex.Lock()
defer session.Mutex.Unlock()
}
mutex.Unlock()

session.Mutex.Lock()
defer session.Mutex.Unlock()

switch command {
default:
logging.Error(moduleName, "Received unknown command type:", aurora.Cyan(command))
break

case NNInitRequest:
logging.Info(moduleName, "Command:", aurora.Yellow("NNInitRequest"))
// logging.Info(moduleName, "Command:", aurora.Yellow("NN_INIT"))
session.handleInit(conn, addr, buffer[12:], moduleName, version)
break

case NNInitReply:
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NNInitReply"))
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NN_INITACK"))
break

case NNErtTestRequest:
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NNErtTestRequest"))
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NN_ERTTEST"))
break

case NNErtTestReply:
logging.Info(moduleName, "Command:", aurora.Yellow("NNErtReply"))
logging.Info(moduleName, "Command:", aurora.Yellow("NN_ERTACK"))
break

case NNStateUpdate:
logging.Info(moduleName, "Command:", aurora.Yellow("NNStateUpdate"))
logging.Info(moduleName, "Command:", aurora.Yellow("NN_STATEUPDATE"))
break

case NNConnectRequest:
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NNConnectRequest"))
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NN_CONNECT"))
break

case NNConnectReply:
logging.Info(moduleName, "Command:", aurora.Yellow("NNConnectReply"))
// TODO: Set the client Connected value to true here
// logging.Info(moduleName, "Command:", aurora.Yellow("NN_CONNECT_ACK"))
session.handleConnectReply(conn, addr, buffer[12:], moduleName, version)
break

case NNConnectPing:
logging.Info(moduleName, "Command:", aurora.Yellow("NNConnectPing"))
logging.Info(moduleName, "Command:", aurora.Yellow("NN_CONNECT_PING"))
break

case NNBackupTestRequest:
logging.Info(moduleName, "Command:", aurora.Yellow("NNBackupTestRequest"))
logging.Info(moduleName, "Command:", aurora.Yellow("NN_BACKUP_TEST"))
break

case NNBackupTestReply:
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NNBackupTestReply"))
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NN_BACKUP_ACK"))
break

case NNAddressCheckRequest:
logging.Info(moduleName, "Command:", aurora.Yellow("NNAddressCheckRequest"))
logging.Info(moduleName, "Command:", aurora.Yellow("NN_ADDRESS_CHECK"))
break

case NNAddressCheckReply:
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NNAddressCheckReply"))
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NN_ADDRESS_REPLY"))
break

case NNNatifyRequest:
logging.Info(moduleName, "Command:", aurora.Yellow("NNNatifyRequest"))
logging.Info(moduleName, "Command:", aurora.Yellow("NN_NATIFY_REQUEST"))
break

case NNReportRequest:
logging.Info(moduleName, "Command:", aurora.Yellow("NNReportRequest"))
// logging.Info(moduleName, "Command:", aurora.Yellow("NN_REPORT"))
session.handleReport(conn, addr, buffer[12:], moduleName, version)
break

case NNReportReply:
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NNReportReply"))
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NN_REPORT_ACK"))
break

case NNPreInitRequest:
logging.Info(moduleName, "Command:", aurora.Yellow("NNPreInitRequest"))
logging.Info(moduleName, "Command:", aurora.Yellow("NN_PREINIT"))
break

case NNPreInitReply:
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NNPreInitReply"))
logging.Warn(moduleName, "Received server command:", aurora.Yellow("NN_PREINIT_ACK"))
break
}
}
Expand Down Expand Up @@ -265,9 +283,6 @@ func (session *NATNEGSession) handleInit(conn net.PacketConn, addr net.Addr, buf

localIPStr := fmt.Sprintf("%d.%d.%d.%d:%d", localIPBytes[0], localIPBytes[1], localIPBytes[2], localIPBytes[3], localPort)

logging.Info(moduleName, "Game Name:", aurora.Cyan(gameName), "Version:", aurora.Cyan(version), "Port Type:", aurora.Yellow(getPortTypeName(portType)), "Client Index:", aurora.Cyan(clientIndex), "Use Game Port:", aurora.Cyan(useGamePort))
logging.Info(moduleName, "Local IP:", aurora.Cyan(localIPStr))

if portType > 0x03 {
logging.Error(moduleName, "Invalid port type")
return
Expand All @@ -290,18 +305,27 @@ func (session *NATNEGSession) handleInit(conn net.PacketConn, addr net.Addr, buf
sender, exists := session.Clients[clientIndex]
if !exists {
logging.Notice(moduleName, "Creating client index", aurora.Cyan(clientIndex))

for _, other := range session.Clients {
if other.GameName != gameName {
logging.Error(moduleName, "Game name mismatch", aurora.Cyan(other.GameName), "!=", aurora.Cyan(gameName))
return
}
}

sender = &NATNEGClient{
Cookie: session.Cookie,
Connected: false,
NegotiateIP: "",
LocalIP: "",
ServerIP: "",
GameName: "",
Cookie: session.Cookie,
Index: clientIndex,
ConnectingIndex: clientIndex,
Connected: map[byte]bool{},
NegotiateIP: "",
LocalIP: "",
ServerIP: "",
GameName: "",
}
session.Clients[clientIndex] = sender
}

sender.Connected = false
sender.GameName = gameName

if portType != PortTypeGamePort {
Expand All @@ -317,24 +341,14 @@ func (session *NATNEGSession) handleInit(conn net.PacketConn, addr net.Addr, buf
if !sender.isMapped() {
return
}
logging.Info(moduleName, "Mapped", aurora.BrightCyan(sender.NegotiateIP), aurora.BrightCyan(sender.LocalIP), aurora.BrightCyan(sender.ServerIP))
// logging.Info(moduleName, "Mapped", aurora.BrightCyan(sender.NegotiateIP), aurora.BrightCyan(sender.LocalIP), aurora.BrightCyan(sender.ServerIP))

for id, destination := range session.Clients {
if id == clientIndex || destination.Connected || !destination.isMapped() {
continue
}

logging.Notice(moduleName, "Exchange connect requests")

// Send the requests back and forth
// TODO: Send again if no reply received from client
sender.sendConnectRequest(conn, destination, version)
destination.sendConnectRequest(conn, sender, version)
}
// Send the connect requests
session.sendConnectRequests(moduleName)
}

func (client *NATNEGClient) isMapped() bool {
if client.NegotiateIP == "" || client.LocalIP == "" || client.ServerIP == "" {
if client.NegotiateIP == "" || client.ServerIP == "" {
return false
}

Expand All @@ -346,7 +360,50 @@ func createPacketHeader(version byte, command byte, cookie uint32) []byte {
return binary.BigEndian.AppendUint32(header, cookie)
}

func (client *NATNEGClient) sendConnectRequest(conn net.PacketConn, destination *NATNEGClient, version byte) {
func (session *NATNEGSession) sendConnectRequests(moduleName string) {
for id, sender := range session.Clients {
if !sender.isMapped() || sender.ConnectingIndex != id {
continue
}

for destID, destination := range session.Clients {
if id == destID || !destination.isMapped() || destination.ConnectingIndex != destID || destination.Connected[id] {
continue
}

logging.Notice(moduleName, "Exchange connect requests between", aurora.BrightCyan(id), "and", aurora.BrightCyan(destID))
sender.ConnectingIndex = destID
sender.ConnectAck = false
destination.ConnectingIndex = id
destination.ConnectAck = false

go func(session *NATNEGSession, sender *NATNEGClient, destination *NATNEGClient) {
for {
check := false

if !destination.ConnectAck && destination.ConnectingIndex == sender.Index {
check = true
sender.sendConnectRequestPacket(natnegConn, destination, session.Version)
}

if !sender.ConnectAck && sender.ConnectingIndex == destination.Index {
check = true
destination.sendConnectRequestPacket(natnegConn, sender, session.Version)
}

if !check {
logging.Notice(moduleName, "No connect requests to send")
return
}

time.Sleep(1 * time.Second)
}
}(session, sender, destination)
}
}
}

func (client *NATNEGClient) sendConnectRequestPacket(conn net.PacketConn, destination *NATNEGClient, version byte) {
connectHeader := createPacketHeader(version, NNConnectRequest, destination.Cookie)
connectHeader = append(connectHeader, common.IPFormatBytes(client.ServerIP)...)
_, port := common.IPFormatToInt(client.ServerIP)
Expand All @@ -361,9 +418,43 @@ func (client *NATNEGClient) sendConnectRequest(conn net.PacketConn, destination
conn.WriteTo(connectHeader, destIPAddr)
}

func (session *NATNEGSession) handleConnectReply(conn net.PacketConn, addr net.Addr, buffer []byte, moduleName string, version byte) {
// portType := buffer[0]
clientIndex := buffer[1]
// useGamePort := buffer[2]
// localIPBytes := buffer[3:7]

session.Mutex.Lock()
defer session.Mutex.Unlock()

if client, exists := session.Clients[clientIndex]; exists {
client.ConnectAck = true
}
}

func (session *NATNEGSession) handleReport(conn net.PacketConn, addr net.Addr, buffer []byte, _ string, version byte) {
response := createPacketHeader(version, NNReportReply, session.Cookie)
response = append(response, buffer[:9]...)
response[14] = 0
conn.WriteTo(response, addr)

// portType := buffer[0]
clientIndex := buffer[1]
result := buffer[2]
// natType := buffer[3]
// mappingScheme := buffer[7]
// gameName, err := common.GetString(buffer[11:])

moduleName := "NATNEG:" + fmt.Sprintf("%08x/", session.Cookie) + addr.String()

if client, exists := session.Clients[clientIndex]; exists {
logging.Notice(moduleName, "Report from", aurora.BrightCyan(clientIndex), "result:", aurora.Cyan(result))

client.Connected[client.ConnectingIndex] = true
client.ConnectingIndex = clientIndex
client.ConnectAck = false
}

// Send remaining requests
session.sendConnectRequests(moduleName)
}
Loading

0 comments on commit 4c98fd2

Please sign in to comment.