Skip to content

Commit df4d9a0

Browse files
committed
ci: add, improve unit tests
1 parent 4499102 commit df4d9a0

5 files changed

Lines changed: 299 additions & 8 deletions

File tree

.github/workflows/coverage.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
- name: Get code coverage
2424
run: |
2525
curl -sSfL https://git.io/GeoLite2-City.mmdb -o /tmp/GeoLite2-City.mmdb && \
26-
go test -coverprofile=coverage.out -covermode=atomic ./...
26+
go test -coverprofile=coverage.out -covermode=atomic ./internal/...
2727
- name: Upload coverage reports to Codecov
2828
uses: codecov/codecov-action@v5
2929
with:

internal/geoip/geoip_test.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,20 @@ func newReturnsErrorIfDatabaseIsInvalid(t *testing.T) {
2121
}
2222

2323
func newReturnsInstanceIfDatabaseIsValid(t *testing.T) {
24-
geoIP, err := NewGeoIP(geoDatabasePath)
24+
geo, err := NewGeoIP(geoDatabasePath)
2525
assert.NoError(t, err)
26-
assert.NotNil(t, geoIP)
27-
assert.IsType(t, &GeoIP{}, geoIP)
28-
assert.Equal(t, geoIP.Database, geoDatabasePath)
29-
assert.IsType(t, geoIP.Reader, &geoip2.Reader{})
26+
assert.NotNil(t, geo)
27+
defer geo.Close()
28+
assert.IsType(t, &GeoIP{}, geo)
29+
assert.Equal(t, geo.Database, geoDatabasePath)
30+
assert.IsType(t, geo.Reader, &geoip2.Reader{})
3031
}
3132

3233
func locationReturnsNilIfAddressIsUnresolved(t *testing.T) {
3334
geo, err := NewGeoIP(geoDatabasePath)
3435
assert.NoError(t, err)
3536
assert.NotNil(t, geo)
37+
defer geo.Close()
3638

3739
for ipStr, desc := range map[string]string{
3840
"::1": "local address",
@@ -50,6 +52,7 @@ func locationReturnsLocationIfAddressIsResolved(t *testing.T) {
5052
geo, err := NewGeoIP(geoDatabasePath)
5153
assert.NoError(t, err)
5254
assert.NotNil(t, geo)
55+
defer geo.Close()
5356

5457
ip, _ := netip.ParseAddr("63.176.75.230")
5558
location := geo.Location(ip)

internal/profiler/profiler_test.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
Copyright (c) 2025 Tobias Schäfer. All rights reserved.
3+
Licensed under the MIT License, see LICENSE file in the project root for details.
4+
*/
5+
package profiler
6+
7+
import (
8+
"fmt"
9+
"log/slog"
10+
"net"
11+
"net/url"
12+
"testing"
13+
14+
"github.com/grafana/pyroscope-go"
15+
"github.com/stretchr/testify/assert"
16+
"github.com/tschaefer/conntrackd/internal/logger"
17+
)
18+
19+
func __setupLogger(t *testing.T, level string) {
20+
logger, err := logger.NewLogger(level)
21+
if err != nil {
22+
t.Fatalf("failed to create logger: %v", err)
23+
}
24+
slog.SetDefault(logger.Logger)
25+
}
26+
27+
func __address(t *testing.T) string {
28+
for port := 4096; port <= 65535; port++ {
29+
c, _ := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
30+
if c != nil {
31+
c.Close()
32+
continue
33+
}
34+
return fmt.Sprintf("http://localhost:%d", port)
35+
}
36+
t.Fatalf("failed to find free port")
37+
38+
return ""
39+
}
40+
41+
func newReturnsProfiler(t *testing.T) {
42+
__setupLogger(t, "info")
43+
address := __address(t)
44+
45+
profiler := NewProfiler(address)
46+
assert.NotNil(t, profiler)
47+
assert.Equal(t, address, profiler.Config.ServerAddress)
48+
assert.Nil(t, profiler.Config.Logger)
49+
assert.Nil(t, profiler.Instance)
50+
assert.Equal(t, "github.com/tschaefer/conntrackd", profiler.Config.ApplicationName)
51+
}
52+
53+
func newSetsLoggerIfLogLevelIsDebug(t *testing.T) {
54+
__setupLogger(t, "debug")
55+
address := __address(t)
56+
57+
profiler := NewProfiler(address)
58+
assert.NotNil(t, profiler)
59+
assert.Equal(t, address, profiler.Config.ServerAddress)
60+
assert.IsType(t, pyroscope.StandardLogger, profiler.Config.Logger)
61+
assert.Nil(t, profiler.Instance)
62+
assert.Equal(t, "github.com/tschaefer/conntrackd", profiler.Config.ApplicationName)
63+
}
64+
65+
func startSetsInstance(t *testing.T) {
66+
__setupLogger(t, "info")
67+
address := __address(t)
68+
69+
profiler := NewProfiler(address)
70+
err := profiler.Start()
71+
if err != nil {
72+
t.Fatalf("failed to start profiler: %v", err)
73+
}
74+
assert.NotNil(t, profiler.Instance)
75+
assert.IsType(t, &pyroscope.Profiler{}, profiler.Instance)
76+
defer profiler.Stop()
77+
}
78+
79+
func startReturnsErrorIfAddressIsInvalid(t *testing.T) {
80+
__setupLogger(t, "info")
81+
address := "http://invalid:address"
82+
83+
profiler := NewProfiler(address)
84+
err := profiler.Start()
85+
assert.NotNil(t, err)
86+
assert.Nil(t, profiler.Instance)
87+
assert.Error(t, err)
88+
assert.IsType(t, &url.Error{}, err)
89+
}
90+
91+
func TestProfiler(t *testing.T) {
92+
t.Run("profiler.New returns Profiler", newReturnsProfiler)
93+
t.Run("profiler.New sets logger if log level is debug", newSetsLoggerIfLogLevelIsDebug)
94+
t.Run("profiler.Start sets Instance", startSetsInstance)
95+
t.Run("profiler.Start returns error if address is invalid", startReturnsErrorIfAddressIsInvalid)
96+
}

internal/service/service_test.go

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/*
2+
Copyright (c) 2025 Tobias Schäfer. All rights reserved.
3+
Licensed under the MIT license, see LICENSE in the project root for details.
4+
*/
5+
package service
6+
7+
import (
8+
"bytes"
9+
"context"
10+
"log/slog"
11+
"net/netip"
12+
"syscall"
13+
"testing"
14+
"time"
15+
16+
"github.com/stretchr/testify/assert"
17+
"github.com/ti-mo/conntrack"
18+
"github.com/tschaefer/conntrackd/internal/filter"
19+
"github.com/tschaefer/conntrackd/internal/logger"
20+
"github.com/tschaefer/conntrackd/internal/sink"
21+
)
22+
23+
func __setupSinkAndLogger(t *testing.T) (*sink.Sink, *logger.Logger, *bytes.Buffer) {
24+
var record bytes.Buffer
25+
target := slog.New(slog.NewTextHandler(&record, &slog.HandlerOptions{
26+
Level: slog.LevelInfo,
27+
}))
28+
29+
logger, err := logger.NewLogger("info")
30+
if err != nil {
31+
t.Fatalf("failed to create logger: %v", err)
32+
}
33+
34+
sink := &sink.Sink{Logger: target}
35+
36+
return sink, logger, &record
37+
}
38+
39+
func __createEvent(proto uint8) conntrack.Event {
40+
flow := conntrack.NewFlow(
41+
proto,
42+
conntrack.StatusAssured,
43+
netip.MustParseAddr("9.0.0.1"),
44+
netip.MustParseAddr("7.8.8.8"),
45+
12344, 80,
46+
59, 0,
47+
)
48+
return conntrack.Event{Flow: &flow}
49+
}
50+
51+
func newReturnsService(t *testing.T) {
52+
logger, err := logger.NewLogger("debug")
53+
if err != nil {
54+
t.Fatalf("failed to create logger: %v", err)
55+
}
56+
57+
svc, err := NewService(logger, nil, nil, nil)
58+
assert.NoError(t, err)
59+
assert.NotNil(t, svc)
60+
}
61+
62+
func processEventDoesNotRecordIfEventNotTCPorUDP(t *testing.T) {
63+
sink, logger, record := __setupSinkAndLogger(t)
64+
svc, err := NewService(logger, nil, nil, sink)
65+
if err != nil {
66+
t.Fatalf("failed to create service: %v", err)
67+
}
68+
69+
event := __createEvent(syscall.IPPROTO_ICMP)
70+
svc.processEvent(event)
71+
assert.Len(t, record.String(), 0, "No log output expected for non-TCP/UDP event")
72+
}
73+
74+
func processEventDoesRecordIfEventTCPorUDP(t *testing.T) {
75+
sink, logger, record := __setupSinkAndLogger(t)
76+
svc, err := NewService(logger, nil, nil, sink)
77+
if err != nil {
78+
t.Fatalf("failed to create service: %v", err)
79+
}
80+
81+
event := __createEvent(syscall.IPPROTO_TCP)
82+
svc.processEvent(event)
83+
assert.Greater(t, len(record.String()), 0, "Log output expected for TCP event")
84+
85+
record.Reset()
86+
event = __createEvent(syscall.IPPROTO_UDP)
87+
svc.processEvent(event)
88+
assert.Greater(t, len(record.String()), 0, "Log output expected for UDP event")
89+
}
90+
91+
func processEventDoesNotRecordIfFilteredOut(t *testing.T) {
92+
sink, logger, record := __setupSinkAndLogger(t)
93+
filter, err := filter.NewFilter([]string{"drop any"})
94+
if err != nil {
95+
t.Fatalf("failed to create filter: %v", err)
96+
}
97+
svc, err := NewService(logger, nil, filter, sink)
98+
if err != nil {
99+
t.Fatalf("failed to create service: %v", err)
100+
}
101+
102+
event := __createEvent(syscall.IPPROTO_TCP)
103+
svc.processEvent(event)
104+
assert.Len(t, record.String(), 0, "No log output expected for filtered out event")
105+
}
106+
107+
func startEventProcessorStartsGoroutine(t *testing.T) {
108+
sink, logger, _ := __setupSinkAndLogger(t)
109+
svc, err := NewService(logger, nil, nil, sink)
110+
if err != nil {
111+
t.Fatalf("failed to create service: %v", err)
112+
}
113+
114+
evCh := make(chan conntrack.Event)
115+
ctx, cancel := context.WithCancel(context.Background())
116+
defer cancel()
117+
g := svc.startEventProcessor(ctx, evCh)
118+
assert.NotNil(t, g, "Errgroup expected to be returned")
119+
}
120+
121+
func startEventProcessorDoesRecordOnEvent(t *testing.T) {
122+
sink, logger, record := __setupSinkAndLogger(t)
123+
svc, err := NewService(logger, nil, nil, sink)
124+
if err != nil {
125+
t.Fatalf("failed to create service: %v", err)
126+
}
127+
128+
evCh := make(chan conntrack.Event)
129+
ctx, cancel := context.WithCancel(context.Background())
130+
defer cancel()
131+
g := svc.startEventProcessor(ctx, evCh)
132+
133+
event := __createEvent(syscall.IPPROTO_TCP)
134+
evCh <- event
135+
136+
time.Sleep(100 * time.Millisecond)
137+
cancel()
138+
g.Wait()
139+
140+
assert.Greater(t, len(record.String()), 0, "Log output expected for processed event")
141+
}
142+
143+
func TestService(t *testing.T) {
144+
t.Run("service.New returns service", newReturnsService)
145+
t.Run("service.processEvent does not record if event not TCP or UDP", processEventDoesNotRecordIfEventNotTCPorUDP)
146+
t.Run("service.processEvent does record if event TCP or UDP", processEventDoesRecordIfEventTCPorUDP)
147+
t.Run("service.processEvent does not record if filtered out", processEventDoesNotRecordIfFilteredOut)
148+
t.Run("service.startEventProcessor starts goroutine", startEventProcessorStartsGoroutine)
149+
t.Run("service.startEventProcessor does record on event", startEventProcessorDoesRecordOnEvent)
150+
}

internal/version/version_test.go

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,28 @@ Licensed under the MIT license, see LICENSE in the project root for details.
55
package version
66

77
import (
8+
"os"
89
"testing"
910

1011
"github.com/stretchr/testify/assert"
1112
)
1213

14+
func __capture(f func()) string {
15+
originalStdout := os.Stdout
16+
17+
r, w, _ := os.Pipe()
18+
os.Stdout = w
19+
20+
f()
21+
22+
_ = w.Close()
23+
os.Stdout = originalStdout
24+
25+
var buf = make([]byte, 5096)
26+
n, _ := r.Read(buf)
27+
return string(buf[:n])
28+
}
29+
1330
func releaseReturnsDevWhenVersionIsEmpty(t *testing.T) {
1431
Version = ""
1532
expected := "dev"
@@ -23,8 +40,8 @@ func commitReturnsEmptyStringWhenGitCommitIsEmpty(t *testing.T) {
2340
}
2441

2542
func releaseReturnsVersionWhenVersionIsSet(t *testing.T) {
26-
Version = "1.0.0"
27-
expected := "1.0.0"
43+
Version = "v1.0.0"
44+
expected := "v1.0.0"
2845
assert.Equal(t, expected, Release())
2946
}
3047

@@ -34,9 +51,34 @@ func commitReturnsCommitHashWhenGitCommitIsSet(t *testing.T) {
3451
assert.Equal(t, expected, Commit())
3552
}
3653

54+
func bannerReturnsLogo(t *testing.T) {
55+
expected := `
56+
_ _ _
57+
___ ___ _ __ _ __ | |_ _ __ ____ ___| | ____| |
58+
/ __/ _ \| '_ \| '_ \| __| '__/ _ |/ __| |/ / _ |
59+
| (_| (_) | | | | | | | |_| | | (_| | (__| < (_| |
60+
\___\___/|_| |_|_| |_|\__|_| \__,_|\___|_|\_\__,_|
61+
`
62+
assert.Equal(t, expected, Banner())
63+
assert.Len(t, Banner(), 266)
64+
}
65+
66+
func printReturnsVersionAndCommit(t *testing.T) {
67+
Version = "v1.0.0"
68+
GitCommit = "f98352c5101f5097c183cb667401a4f459dc7221"
69+
70+
output := __capture(func() {
71+
Print()
72+
})
73+
assert.Contains(t, output, "Release: "+Release())
74+
assert.Contains(t, output, "Commit: "+Commit())
75+
}
76+
3777
func TestVersion(t *testing.T) {
3878
t.Run("version.Release returns 'dev' when Version is empty", releaseReturnsDevWhenVersionIsEmpty)
3979
t.Run("version.Commit returns empty string when GitCommit is empty", commitReturnsEmptyStringWhenGitCommitIsEmpty)
4080
t.Run("version.Release returns Version when Version is set", releaseReturnsVersionWhenVersionIsSet)
4181
t.Run("version.Commit returns commit hash when GitCommit is set", commitReturnsCommitHashWhenGitCommitIsSet)
82+
t.Run("version.Banner returns the correct logo", bannerReturnsLogo)
83+
t.Run("version.Print prints logo, version and commit", printReturnsVersionAndCommit)
4284
}

0 commit comments

Comments
 (0)