Skip to content

Commit

Permalink
Connect tests, initial RunCommand
Browse files Browse the repository at this point in the history
  • Loading branch information
dsnidr committed Feb 1, 2022
1 parent 8829bb0 commit 6b6d0ad
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 12 deletions.
87 changes: 83 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"encoding/binary"
"fmt"
"net"
"sync"
"time"

"github.com/pkg/errors"
"github.com/refractorgscm/rcon2/packet"
)

type Client struct {
Expand All @@ -15,6 +17,8 @@ type Client struct {
conn *net.TCPConn

log Logger

cmdMutex sync.Mutex
}

type Config struct {
Expand All @@ -33,6 +37,8 @@ type Config struct {

ReadDeadline time.Duration
WriteDeadline time.Duration

RestrictedPacketIDs []int32
}


Expand All @@ -41,6 +47,7 @@ var DefaultConfig = &Config{
EndianMode: binary.LittleEndian,
ReadDeadline: time.Second*2,
WriteDeadline: time.Second*2,
RestrictedPacketIDs: []int32{},
}

func NewClient(host string, port uint16, password string) *Client {
Expand All @@ -49,15 +56,35 @@ func NewClient(host string, port uint16, password string) *Client {
config.Port = port
config.Password = password

return &Client{
config: config,
}
return NewClientFromConfig(DefaultConfig)
}

func NewClientFromConfig(config *Config) *Client {
return &Client{
c := &Client{
config: config,
}

if c.log == nil {
c.log = &DefaultLogger{}
}

if c.config.EndianMode == nil {
c.config.EndianMode = binary.LittleEndian
}

if c.config.ConnTimeout == 0 {
c.config.ConnTimeout = DefaultConfig.ConnTimeout
}

if c.config.ReadDeadline == 0 {
c.config.ConnTimeout = DefaultConfig.ReadDeadline
}

if c.config.WriteDeadline == 0 {
c.config.ConnTimeout = DefaultConfig.WriteDeadline
}

return c
}

func (c *Client) SetLogger(logger Logger) {
Expand All @@ -77,5 +104,57 @@ func (c *Client) Connect() error {
return errors.Wrap(err, "tcp dial error")
}

if err := c.authenticate(); err != nil {
c.log.Debug("Authentication failed", err)
return err
}

return nil
}

func (c *Client) authenticate() error {
p := c.newPacket(packet.TypeAuth, c.config.Password)

if err := c.sendPacket(p); err != nil {
return errors.Wrap(err, "could not send packet")
}

res, err := c.readPacketTimeout()
if err != nil {
return errors.Wrap(err, "could not get auth response")
}

if res.Type != packet.TypeAuthRes {
return errors.New("packet was not of the type auth response")
}

if res.ID == packet.AuthFailedID {
return errors.Wrap(ErrAuthentication, "authentication failed")
}

c.log.Debug("Authenticated")

return nil
}

func (c *Client) RunCommand(cmd string) (string, error) {
c.cmdMutex.Lock()
defer c.cmdMutex.Unlock()

p := c.newPacket(packet.TypeCommand, cmd)

if err := c.sendPacket(p); err != nil {
return "", err
}

res, err := c.readPacket()
if err != nil {
return "", err
}

return string(res.Body), nil
}

func (c *Client) newPacket(pType packet.PacketType, body string) *packet.Packet {
return packet.New(c.config.EndianMode, pType, []byte(body), c.config.RestrictedPacketIDs)
}
24 changes: 24 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package rcon

import (
"encoding/binary"
"testing"

"github.com/refractorgscm/rcon2/fakeserver"
"github.com/stretchr/testify/assert"
)

Expand All @@ -26,4 +28,26 @@ func TestNewClientFromConfig(t *testing.T) {
assert.Equal(t, "localhost", client.config.Host)
assert.Equal(t, uint16(1234), client.config.Port)
assert.Equal(t, "suPerSecure123", client.config.Password)
}

func TestClient_Connect(t *testing.T) {
fakeServer := fakeserver.New(9898, binary.LittleEndian)
go fakeServer.Listen()

client := NewClient("localhost", 9898, "suPerSecure123")
err := client.Connect()
assert.Nil(t, err)
}

func TestClient_RunCommand(t *testing.T) {
fakeServer := fakeserver.New(9899, binary.LittleEndian)
go fakeServer.Listen()

client := NewClient("localhost", 9899, "suPerSecure123")
err := client.Connect()
assert.Nil(t, err)

res, err := client.RunCommand("help")
assert.Nil(t, err)
assert.Equal(t, "firstplayer\notherplayer\nlastplayer", res)
}
84 changes: 84 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package rcon

import (
"bufio"
"strings"
"time"

"github.com/pkg/errors"
"github.com/refractorgscm/rcon2/packet"
)

func (c *Client) sendPacket(p *packet.Packet) error {
out, err := p.Encode()
if err != nil {
return errors.Wrap(err, "could not encode packet")
}

if err := c.writeToConn(out); err != nil {
return errors.Wrap(err, "could not send packet")
}

return nil
}

func (c *Client) readPacket() (*packet.Packet, error) {
if c.conn == nil {
return nil, ErrNotConnected
}

if err := c.conn.SetDeadline(time.Time{}); err != nil {
if strings.HasSuffix(err.Error(), "use of closed network connection") {
return nil, ErrNotConnected
}

return nil, errors.Wrap(err, "could not set connection deadline")
}

reader := bufio.NewReader(c.conn)

res, err := packet.Decode(c.config.EndianMode, reader)
if err != nil {
if strings.HasSuffix(err.Error(), "use of closed network connection") {
return nil, ErrNotConnected
}

return nil, errors.Wrap(err, "could not read packet")
}

c.log.Debug("Read packet ID: ", res.ID, ", Body: ", string(res.Body))

return res, nil
}

func (c *Client) readPacketTimeout() (*packet.Packet, error) {
if c.conn == nil {
return nil, ErrNotConnected
}

if err := c.conn.SetDeadline(time.Now().Add(c.config.ReadDeadline)); err != nil {
if strings.HasSuffix(err.Error(), "use of closed network connection") {
return nil, ErrNotConnected
}

return nil, errors.Wrap(err, "could not set connection deadline")
}

reader := bufio.NewReader(c.conn)

res, err := packet.Decode(c.config.EndianMode, reader)
if err != nil {
if strings.HasSuffix(err.Error(), "use of closed network connection") {
return nil, ErrNotConnected
}

return nil, errors.Wrap(err, "could not read packet")
}

return res, nil
}

func (c *Client) writeToConn(data []byte) error {
_, err := c.conn.Write(data)
return err
}
8 changes: 8 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package rcon

import "errors"

var ErrNotConnected = errors.New("not connected")
var ErrAuthentication = errors.New("authentication failed")
var ErrQueueTimeout = errors.New("queue timeout")
var ErrReadTimeout = errors.New("read timeout")
109 changes: 109 additions & 0 deletions fakeserver/fakeserver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package fakeserver

import (
"bufio"
"context"
"encoding/binary"
"fmt"
"net"

"github.com/refractorgscm/rcon2/packet"
)

var AuthShouldFail = false

type Server struct {
port string
mode binary.ByteOrder
ctx context.Context
terminate bool
}

func New(port uint16, mode binary.ByteOrder) *Server {
portStr := fmt.Sprintf(":%d", port)

return &Server{
port: portStr,
mode: mode,
}
}

func (s *Server) Listen() {
l, err := net.Listen("tcp4", s.port)
if err != nil {
panic(err)
}
defer l.Close()

fmt.Printf("Fake server listening on port %s\n", s.port)

for {
if s.terminate {
break
}

c, err := l.Accept()
if err != nil {
panic(err)
}

go s.handleConnection(c)
}
}

func (s *Server) Stop() {
s.terminate = true
}

func (s *Server) handleConnection(conn net.Conn) {
for {
reader := bufio.NewReader(conn)

p, err := packet.Decode(s.mode, reader)
if err != nil {
panic(err)
}

fmt.Println(p)

var res *packet.Packet

switch p.Type {
case packet.TypeAuth:
res = s.handleAuthReq(p)
case packet.TypeCommand:
switch string(p.Body) {
case "help":
res = packet.New(p.Mode, packet.TypeCommandRes, []byte("firstplayer\notherplayer\nlastplayer"), nil)
default:
res = packet.New(p.Mode, packet.TypeCommandRes, []byte("Unknown command"), nil)
}
}

out, _ := res.Encode()

printBytes(out)

conn.Write(out)
}
}

func (s *Server) handleAuthReq(p *packet.Packet) *packet.Packet {
res := packet.New(s.mode, packet.TypeAuthRes, []byte{}, []int32{})
res.ID = p.ID

if AuthShouldFail {
res.ID = packet.AuthFailedID
return res
}

return res
}

func printBytes(arr []byte) {
fmt.Printf("Bytes (%d): ", len(arr))
for _, b := range arr {
fmt.Printf("%x ", b)
}
fmt.Print("\n")
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module github.com/refractorgscm/rcon
module github.com/refractorgscm/rcon2

go 1.17

Expand Down
Loading

0 comments on commit 6b6d0ad

Please sign in to comment.