diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 79299d2..c81226e 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -32,4 +32,7 @@ jobs:
- name: Run linters
run: golangci-lint run
- name: Run tests
- run: go test -v ./...
+ run: |
+ curl -sSfL https://git.io/GeoLite2-City.mmdb -o /tmp/GeoLite2-City.mmdb && \
+ export KEEP_GEOIP_DB=true && \
+ go test -v ./...
diff --git a/README.md b/README.md
index 13fc43d..f855643 100644
--- a/README.md
+++ b/README.md
@@ -41,43 +41,74 @@ sudo conntrackd run --sink.journal.enable
```
For further configuration, see the command-line options below.
-| Flag | Description | Options |
-|--------------------------------|------------------------------------|----------------------------------------|
-| `filter.destinations` | Filter by destination networks | PUBLIC,PRIVATE,LOCAL,MULTICAST |
-| `filter.sourcess` | Filter by source networks | PUBLIC,PRIVATE,LOCAL,MULTICAST |
-| `filter.protocols` | Filter by protocols | TCP,UDP |
-| `filter.types` | Filter by event types | NEW,UPDATE,DESTROY |
-| `filter.destination.addresses` | Filter by destination IP addresses | |
-| `filter.source.addresses` | Filter by source IP addresses | |
-| `filter.destination.ports` | Filter by destination ports | |
-| `filter.source.ports` | Filter by source ports | |
-| `geoip.database` | Path to GeoIP database | |
-| `service.log.format` | Log format | json,text; default: text |
-| `service.log.level` | Log level | trace,debug,info,error; default: info |
-| `sink.journal.enable` | Enable journald sink | |
-| `sink.syslog.enable` | Enable syslog sink | |
-| `sink.enable.loki` | Enable Loki sink | |
-| `sink.stream.enable` | Enable stream sink | |
-| `sink.syslog.address` | Syslog address | default: udp://localhost:514 |
-| `sink.loki.address` | Loki address | default: http://localhost:3100 |
-| `sink.loki.labels` | Loki labels | comma seperated key=value pairs |
-| `sink.stream.writer` | Stream writer type | stdout,stderr,discard; default: stdout |
-
-All filters are exclusive; if any filter is not set, all related events are processed.
-
-Example run:
+## Filtering
+
+conntrackd logs conntrack events to various sinks.
+
+**Protocol Support:** Only TCP and UDP events are processed. All other protocols
+(ICMP, IGMP, etc.) are automatically ignored and never logged, regardless of filter rules.
+
+You can use filters to control which TCP/UDP events are logged using a
+Domain-Specific Language (DSL).
+The `--filter` flag lets you specify filter rules:
```bash
sudo conntrackd run \
- --geoip.database /usr/local/share/GeoLite2-City.mmdb \
- --filter.destination PRIVATE \
- --filter.protocol UDP \
- --filter.destination.addresses 142.250.186.163,2a00:1450:4001:82b::2003
- --sink.journal.enable \
- --service.log.format json \
- --service.log.level debug
+ --filter "drop destination address 8.8.8.8" \
+ --filter "log protocol TCP and destination network PUBLIC" \
+ --filter "drop any" \
+ --sink.journal.enable
+```
+
+**Filter Rules:**
+- Rules are evaluated in order (first-match wins)
+- Events are **logged by default** when no rule matches
+- `--filter` flag can be repeated for multiple rules
+- Use `drop any` as a final rule to block all non-matching events from being logged
+
+**Important:** Filters control which conntrack events are **logged**,
+not network traffic. Traffic always flows normally; filters only affect logging.
+
+**Common Filter Examples:**
+
+```bash
+# Don't log events to a specific IP
+--filter "drop destination address 8.8.8.8"
+
+# Log only NEW TCP connections (drop everything else)
+--filter "log type NEW and protocol TCP"
+--filter "drop any"
+
+# Don't log DNS to specific server
+--filter "drop destination address 10.19.80.100 on port 53"
+
+# Don't log any traffic to private networks
+--filter "drop destination network PRIVATE"
+
+# Log only traffic from public IPs using TCP
+--filter "log source network PUBLIC and protocol TCP"
+--filter "drop any"
```
+See [docs/filter.md](docs/filter.md) for complete DSL documentation,
+including grammar, operators, and advanced examples.
+
+## Configuration Flags
+
+| Flag | Description | Default |
+|-------------------------|---------------------------------------------------|--------------------------|
+| `--filter` | Filter rule in DSL format (repeatable) | |
+| `--geoip.database` | Path to GeoIP database | |
+| `--log.level`           | Log level (debug, info, warn, error)              | info                     |
+| `--sink.journal.enable` | Enable journald sink | |
+| `--sink.syslog.enable` | Enable syslog sink | |
+| `--sink.loki.enable` | Enable Loki sink | |
+| `--sink.stream.enable` | Enable stream sink | |
+| `--sink.syslog.address` | Syslog address | udp://localhost:514 |
+| `--sink.loki.address` | Loki address | http://localhost:3100 |
+| `--sink.loki.labels` | Loki labels (comma-separated key=value pairs) | |
+| `--sink.stream.writer` | Stream writer (stdout, stderr, discard) | stdout |
+
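+A minimal invocation combining these flags might look like the following
+(an illustrative sketch; adjust the database path and rules to your setup):
+
+```bash
+sudo conntrackd run \
+  --geoip.database /usr/local/share/GeoLite2-City.mmdb \
+  --filter "drop destination network PRIVATE" \
+  --sink.stream.enable \
+  --sink.stream.writer stdout
+```
+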
## Logging format
conntrackd emits structured logs for each conntrack event. A typical log entry
@@ -100,7 +131,8 @@ GEO location fields:
- lat (latitude)
- lon (longitude)
-Example log entry recorded by sink `syslog`:
+
+Example log entry recorded by sink `syslog`
```json
{
@@ -117,16 +149,18 @@ Example log entry recorded by sink `syslog`:
"level": "INFO",
"logger.name": "samber/slog-syslog",
"logger.version": "v2.5.2",
- "message": "UPDATE TCP connection from 2003:cf:1716:7b64:da80:83ff:fecd:da51/41348...",
+ "message": "UPDATE TCP connection from 2003:cf:1716:7b64:da80:83ff:fecd...",
"timestamp": "2025-11-15T09:55:25.647544937Z"
}
```
+
-Example log entry recorded by sink `journal`:
+
+Example log entry recorded by sink `journal`
```json
{
- "__CURSOR" : "s=b3c7821dbfce47a59b06797aea9028ca;i=6772d3;b=100da27bd8...",
+ "__CURSOR" : "s=b3c7821dbfce47a59b06797aea9028ca;i=6772d3;b=100da27bd...",
"_CAP_EFFECTIVE" : "1ffffffffff",
"EVENT_SPORT" : "39790",
"_SOURCE_REALTIME_TIMESTAMP" : "1763200187611509",
@@ -154,7 +188,7 @@ Example log entry recorded by sink `journal`:
"_SYSTEMD_INVOCATION_ID" : "021760b3373342b98aaeabf9d12d8d74",
"EVENT_FLOW" : "3478798157",
"_PID" : "3794900",
- "_CMDLINE" : "conntrackd run --service.log.level debug --service.log.format ...",
+ "_CMDLINE" : "conntrackd run --service.log.level debug --service.log....",
"EVENT_PROT" : "TCP",
"_AUDIT_SESSION" : "1",
"_BOOT_ID" : "100da27bd8b94096b5c80cdac34d6063",
@@ -164,12 +198,14 @@ Example log entry recorded by sink `journal`:
"_AUDIT_LOGINUID" : "1000",
"_UID" : "0",
"EVENT_TYPE" : "UPDATE",
- "MESSAGE" : "UPDATE TCP connection from 2003:cf:1716:7b64:da80:83ff:fecd:da51/39790..."
+ "MESSAGE" : "UPDATE TCP connection from 2003:cf:1716:7b64:da80:83ff:fe..."
}
```
+
-Example log entry recorded by sink `loki`:
+
+Example log entry recorded by sink `loki`
```json
{
@@ -194,11 +230,33 @@ Example log entry recorded by sink `loki`:
"values": [
[
"1763537351540294198",
- "UPDATE TCP connection from 2003:cf:1716:7b64:d6e9:8aff:fe4f:7a59/44950..."
+ "UPDATE TCP connection from 2003:cf:1716:7b64:d6e9:8aff:fe4f:7a59/44..."
]
]
}
```
+
+
+Example log entry recorded by sink `stream`
+```json
+{
+ "time": "2025-11-22T11:34:43.181432081+01:00",
+ "level": "INFO",
+ "msg": "NEW TCP connection from 2003:cf:1716:7b64:da80:83ff:fecd:da51/4...",
+ "type": "NEW",
+ "flow": 2899284024,
+ "prot": "TCP",
+ "src": "2003:cf:1716:7b64:da80:83ff:fecd:da51",
+ "dst": "2a01:4f8:160:5372::2",
+ "sport": 41220,
+ "dport": 80,
+ "state": "SYN_SENT"
+}
+```
+
+
## Security Notes
@@ -212,7 +270,8 @@ Example log entry recorded by sink `loki`:
Contributions are welcome! Please fork the repository and submit a pull request.
For major changes, open an issue first to discuss what you would like to change.
-Ensure that your code adheres to the existing style and includes appropriate tests.
+Ensure that your code adheres to the existing style and includes appropriate
+tests.
## License
diff --git a/cmd/run.go b/cmd/run.go
index c313bd6..3ace724 100644
--- a/cmd/run.go
+++ b/cmd/run.go
@@ -5,81 +5,70 @@ Licensed under the MIT License, see LICENSE file in the project root for details
package cmd
import (
+ "context"
"fmt"
- "net/netip"
- "net/url"
"os"
- "slices"
+ "os/signal"
"strings"
+ "syscall"
"github.com/spf13/cobra"
+ "github.com/tschaefer/conntrackd/internal/filter"
+ "github.com/tschaefer/conntrackd/internal/geoip"
+ "github.com/tschaefer/conntrackd/internal/logger"
"github.com/tschaefer/conntrackd/internal/service"
+ "github.com/tschaefer/conntrackd/internal/sink"
)
-var srv = service.Service{}
-
-var (
- validEventTypes = []string{"NEW", "UPDATE", "DESTROY"}
- validProtocols = []string{"TCP", "UDP"}
- validNetworks = []string{"PUBLIC", "PRIVATE", "LOCAL", "MULTICAST"}
- validLogLevels = []string{"trace", "debug", "info", "error"}
- validLogFormats = []string{"text", "json"}
- validSyslogSchemes = []string{"udp", "tcp", "unix", "unixgram", "unixpacket"}
- validLokiSchemes = []string{"http", "https"}
- validStreamWriters = []string{"stdout", "stderr", "discard"}
-)
+type Options struct {
+ logLevel string
+ geoipDatabase string
+ filterRules []string
+ sink sink.Config
+}
+
+var options = Options{}
var runCmd = &cobra.Command{
Use: "run",
Short: "Run the conntrackd service",
Run: func(cmd *cobra.Command, args []string) {
- if !srv.Sink.Journal.Enable &&
- !srv.Sink.Syslog.Enable &&
- !srv.Sink.Loki.Enable &&
- !srv.Sink.Stream.Enable {
- cobra.CheckErr(fmt.Errorf("at least one sink must be enabled"))
+ l, err := logger.NewLogger(options.logLevel)
+ if err != nil {
+ cobra.CheckErr(fmt.Sprintf("Failed to create logger: %v", err))
}
- err := validateStringFlag("sink.syslog.address", srv.Sink.Syslog.Address, []string{})
- cobra.CheckErr(err)
-
- err = validateStringFlag("sink.loki.address", srv.Sink.Loki.Address, []string{})
- cobra.CheckErr(err)
-
- err = validateStringFlag("sink.stream.writer", srv.Sink.Stream.Writer, validStreamWriters)
- cobra.CheckErr(err)
-
- err = validateStringSliceFlag("filter.types", srv.Filter.EventTypes, validEventTypes)
- cobra.CheckErr(err)
-
- err = validateStringSliceFlag("filter.protocols", srv.Filter.Protocols, validProtocols)
- cobra.CheckErr(err)
-
- err = validateStringSliceFlag("filter.destination.networks", srv.Filter.Networks.Destinations, validNetworks)
- cobra.CheckErr(err)
-
- err = validateStringSliceFlag("filter.source.networks", srv.Filter.Networks.Sources, validNetworks)
- cobra.CheckErr(err)
-
- err = validateStringSliceFlag("filter.destination.addresses", srv.Filter.Addresses.Destinations, []string{})
- cobra.CheckErr(err)
+ var g *geoip.GeoIP
+ if options.geoipDatabase != "" {
+ g, err = geoip.NewGeoIP(options.geoipDatabase)
+ if err != nil {
+ cobra.CheckErr(fmt.Sprintf("Failed to open geoip database: %v", err))
+ }
+ defer func() {
+ _ = g.Close()
+ }()
+ }
- err = validateStringSliceFlag("filter.source.addresses", srv.Filter.Addresses.Sources, []string{})
- cobra.CheckErr(err)
+ var f *filter.Filter
+ if len(options.filterRules) > 0 {
+ f, err = filter.NewFilter(options.filterRules)
+ if err != nil {
+ cobra.CheckErr(fmt.Sprintf("failed to compile filter rules: %v", err))
+ }
+ }
- err = validateStringFlag("service.log.level", srv.Logger.Level, validLogLevels)
- cobra.CheckErr(err)
+ s, err := sink.NewSink(&options.sink)
+ if err != nil {
+ cobra.CheckErr(fmt.Sprintf("failed to initialize sink: %v", err))
+ }
- err = validateStringFlag("service.log.format", srv.Logger.Format, validLogFormats)
+ service, err := service.NewService(l, g, f, s)
cobra.CheckErr(err)
- if srv.GeoIP.Database != "" {
- if _, err := os.Stat(srv.GeoIP.Database); os.IsNotExist(err) {
- cobra.CheckErr(fmt.Errorf("GeoIP database file does not exist: %s", srv.GeoIP.Database))
- }
- }
+ ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+ defer stop()
- if err := srv.Run(); err != nil {
+ if tranquil := service.Run(ctx); !tranquil {
os.Exit(1)
}
},
@@ -88,99 +77,28 @@ var runCmd = &cobra.Command{
func init() {
runCmd.CompletionOptions.SetDefaultShellCompDirective(cobra.ShellCompDirectiveNoFileComp)
- runCmd.Flags().StringSliceVar(&srv.Filter.EventTypes, "filter.types", nil, "Filter by event type (NEW,UPDATE,DESTROY)")
- runCmd.Flags().StringSliceVar(&srv.Filter.Protocols, "filter.protocols", nil, "Filter by protocol (TCP,UDP)")
- runCmd.Flags().StringSliceVar(&srv.Filter.Networks.Destinations, "filter.destination.networks", nil, "Filter by destination networks (PUBLIC,PRIVATE,LOCAL,MULTICAST)")
- runCmd.Flags().StringSliceVar(&srv.Filter.Networks.Sources, "filter.source.networks", nil, "Filter by sources networks (PUBLIC,PRIVATE,LOCAL,MULTICAST)")
- runCmd.Flags().StringSliceVar(&srv.Filter.Addresses.Destinations, "filter.destination.addresses", nil, "Filter by destination IP addresses")
- runCmd.Flags().StringSliceVar(&srv.Filter.Addresses.Sources, "filter.source.addresses", nil, "Filter by source IP addresses")
- runCmd.Flags().UintSliceVar(&srv.Filter.Ports.Destinations, "filter.destination.ports", nil, "Filter by destination ports")
- runCmd.Flags().UintSliceVar(&srv.Filter.Ports.Sources, "filter.source.ports", nil, "Filter by source ports")
-
- runCmd.Flags().StringVar(&srv.Logger.Format, "service.log.format", "", "Log format (text,json)")
- runCmd.Flags().StringVar(&srv.Logger.Level, "service.log.level", "", "Log level (debug,info)")
-
- runCmd.Flags().StringVar(&srv.GeoIP.Database, "geoip.database", "", "Path to GeoIP database")
-
- runCmd.Flags().BoolVar(&srv.Sink.Journal.Enable, "sink.journal.enable", false, "Enable journald sink")
- runCmd.Flags().BoolVar(&srv.Sink.Syslog.Enable, "sink.syslog.enable", false, "Enable syslog sink")
- runCmd.Flags().StringVar(&srv.Sink.Syslog.Address, "sink.syslog.address", "udp://localhost:514", "Syslog address")
- runCmd.Flags().BoolVar(&srv.Sink.Loki.Enable, "sink.loki.enable", false, "Enable Loki sink")
- runCmd.Flags().StringVar(&srv.Sink.Loki.Address, "sink.loki.address", "http://localhost:3100", "Loki address")
- runCmd.Flags().StringSliceVar(&srv.Sink.Loki.Labels, "sink.loki.labels", nil, "Additional labels for Loki sink in key=value format")
- runCmd.Flags().BoolVar(&srv.Sink.Stream.Enable, "sink.stream.enable", false, "Enable stream sink")
- runCmd.Flags().StringVar(&srv.Sink.Stream.Writer, "sink.stream.writer", "stdout", "Stream writer (stdout,stderr,discard)")
+ runCmd.Flags().StringArrayVar(&options.filterRules, "filter", nil, "Filter rules in DSL format (repeatable, first-match wins)")
+
+ runCmd.Flags().StringVar(&options.logLevel, "log.level", "info", fmt.Sprintf("Log level (%s)", strings.Join(logger.Levels, ", ")))
+ _ = runCmd.RegisterFlagCompletionFunc("log.level", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
+ return logger.Levels, cobra.ShellCompDirectiveNoFileComp
+ })
+ runCmd.Flags().StringVar(&options.geoipDatabase, "geoip.database", "", "Path to GeoIP database")
_ = runCmd.RegisterFlagCompletionFunc("geoip.database", cobra.FixedCompletions(nil, cobra.ShellCompDirectiveDefault))
- _ = runCmd.RegisterFlagCompletionFunc("service.log.level", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
- return validLogLevels, cobra.ShellCompDirectiveNoFileComp
- })
+ runCmd.Flags().BoolVar(&options.sink.Journal.Enable, "sink.journal.enable", false, "Enable journald sink")
+ runCmd.Flags().BoolVar(&options.sink.Syslog.Enable, "sink.syslog.enable", false, "Enable syslog sink")
+ runCmd.Flags().StringVar(&options.sink.Syslog.Address, "sink.syslog.address", "udp://localhost:514", "Syslog address")
- _ = runCmd.RegisterFlagCompletionFunc("service.log.format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
- return validLogFormats, cobra.ShellCompDirectiveNoFileComp
- })
+ runCmd.Flags().BoolVar(&options.sink.Loki.Enable, "sink.loki.enable", false, "Enable Loki sink")
+ runCmd.Flags().StringVar(&options.sink.Loki.Address, "sink.loki.address", "http://localhost:3100", "Loki address")
+ runCmd.Flags().StringSliceVar(&options.sink.Loki.Labels, "sink.loki.labels", nil, "Additional labels for Loki sink in key=value format")
+ runCmd.Flags().BoolVar(&options.sink.Stream.Enable, "sink.stream.enable", false, "Enable stream sink")
+ runCmd.Flags().StringVar(&options.sink.Stream.Writer, "sink.stream.writer", "stdout", fmt.Sprintf("Stream writer (%s)", strings.Join(sink.StreamWriters, ", ")))
_ = runCmd.RegisterFlagCompletionFunc("sink.stream.writer", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
- return validStreamWriters, cobra.ShellCompDirectiveNoFileComp
+ return sink.StreamWriters, cobra.ShellCompDirectiveNoFileComp
})
-}
-
-func validateStringSliceFlag(flag string, values []string, validValues []string) error {
- if flag == "filter.source.addresses" || flag == "filter.destination.addresses" {
- for _, v := range values {
- if _, err := netip.ParseAddr(v); err != nil {
- return fmt.Errorf("invalid IP address '%s' for '--%s'", v, flag)
- }
- }
- return nil
- }
-
- for _, v := range values {
- if !slices.Contains(validValues, v) {
- return fmt.Errorf("invalid value '%s' for '--%s' . Valid values are: %s", v, flag, validValues)
- }
- }
- return nil
-}
-
-func validateStringFlag(flag string, value string, validValues []string) error {
- if value == "" {
- return nil
- }
-
- if flag == "sink.syslog.address" || flag == "sink.loki.address" {
- url, err := url.Parse(value)
- if err != nil {
- return fmt.Errorf("invalid URL '%s' for '--%s'", value, flag)
- }
-
- if flag == "sink.syslog.address" {
- if !slices.Contains(validSyslogSchemes, url.Scheme) {
- return fmt.Errorf("invalid URL scheme '%s' for '--%s'. Valid schemes are: udp, tcp, unix, unixgram unixpacket", url.Scheme, flag)
- }
- if url.Host == "" && !strings.HasPrefix(url.Scheme, "unix") {
- return fmt.Errorf("invalid URL '%s' for '--%s'. Host is missing", value, flag)
- }
- if url.Path == "" && strings.HasPrefix(url.Scheme, "unix") {
- return fmt.Errorf("invalid URL '%s' for '--%s'. Path is missing", value, flag)
- }
- }
-
- if flag == "sink.loki.address" {
- if !slices.Contains(validLokiSchemes, url.Scheme) {
- return fmt.Errorf("invalid URL scheme '%s' for '--%s'. Valid schemes are: http, https", url.Scheme, flag)
- }
- if url.Host == "" {
- return fmt.Errorf("invalid URL '%s' for '--%s'. Host is missing", value, flag)
- }
- }
-
- return nil
- }
- if !slices.Contains(validValues, value) {
- return fmt.Errorf("invalid value '%s' for '--%s' . Valid values are: %s", value, flag, validValues)
- }
- return nil
}
diff --git a/cmd/run_test.go b/cmd/run_test.go
deleted file mode 100644
index bff4566..0000000
--- a/cmd/run_test.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package cmd
-
-import (
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestValidateStringSliceFlag_Valid(t *testing.T) {
- assert.NoError(t, validateStringSliceFlag("filter.types", []string{"NEW", "UPDATE"}, validEventTypes))
- assert.NoError(t, validateStringSliceFlag("filter.protocols", []string{"TCP"}, validProtocols))
- assert.NoError(t, validateStringSliceFlag("filter.destination.addresses", []string{"127.0.0.1", "::1"}, []string{}))
-}
-
-func TestValidateStringSliceFlag_Invalid(t *testing.T) {
- assert.Error(t, validateStringSliceFlag("filter.types", []string{"BAD"}, validEventTypes))
- assert.Error(t, validateStringSliceFlag("filter.addresses", []string{"not-an-ip"}, []string{}))
-}
-
-func TestValidateStringFlag_LogLevelsAndFormats(t *testing.T) {
- assert.NoError(t, validateStringFlag("service.log.level", "debug", validLogLevels))
- assert.NoError(t, validateStringFlag("service.log.format", "json", validLogFormats))
-
- assert.Error(t, validateStringFlag("service.log.level", "verbose", validLogLevels))
- assert.Error(t, validateStringFlag("service.log.format", "xml", validLogFormats))
-}
-
-func TestValidateStringFlag_SyslogAddress_Valid(t *testing.T) {
- valids := []string{
- "udp://localhost:514",
- "tcp://127.0.0.1:514",
- "unix:///var/run/syslog.sock",
- "unixgram:///var/run/syslog.sock",
- "unixpacket:///var/run/syslog.sock",
- }
- for _, v := range valids {
- assert.NoErrorf(t, validateStringFlag("sink.syslog.address", v, []string{}), "valid syslog address %q should not error", v)
- }
-}
-
-func TestValidateStringFlag_SyslogAddress_Invalid(t *testing.T) {
- assert.Error(t, validateStringFlag("sink.syslog.address", "http://localhost:514", []string{}))
- assert.Error(t, validateStringFlag("sink.syslog.address", "tcp:///nohost", []string{}))
- assert.Error(t, validateStringFlag("sink.syslog.address", "unix://", []string{}))
-}
-
-func TestValidateStringFlag_LokiAddress_Valid(t *testing.T) {
- assert.NoError(t, validateStringFlag("sink.loki.address", "http://localhost:3100", []string{}))
- assert.NoError(t, validateStringFlag("sink.loki.address", "https://example.com", []string{}))
-}
-
-func TestValidateStringFlag_LokiAddress_Invalid(t *testing.T) {
- assert.Error(t, validateStringFlag("sink.loki.address", "tcp://localhost:3100", []string{}))
- assert.Error(t, validateStringFlag("sink.loki.address", "http:///path", []string{}))
-}
-
-func TestValidSlicesAreExplicit(t *testing.T) {
- expectedEvents := []string{"NEW", "UPDATE", "DESTROY"}
- assert.Equal(t, expectedEvents, validEventTypes, "validEventTypes mismatch")
-
- expectedProtocols := []string{"TCP", "UDP"}
- assert.Equal(t, expectedProtocols, validProtocols, "validProtocols mismatch")
-
- expectedDest := []string{"PUBLIC", "PRIVATE", "LOCAL", "MULTICAST"}
- assert.Equal(t, expectedDest, validNetworks, "validNetworks mismatch")
-}
diff --git a/cmd/version.go b/cmd/version.go
index 7fa8562..ed9a30c 100644
--- a/cmd/version.go
+++ b/cmd/version.go
@@ -18,5 +18,5 @@ var versionCmd = &cobra.Command{
}
func init() {
- runCmd.CompletionOptions.SetDefaultShellCompDirective(cobra.ShellCompDirectiveNoFileComp)
+ versionCmd.CompletionOptions.SetDefaultShellCompDirective(cobra.ShellCompDirectiveNoFileComp)
}
diff --git a/docs/filter.md b/docs/filter.md
new file mode 100644
index 0000000..9056826
--- /dev/null
+++ b/docs/filter.md
@@ -0,0 +1,357 @@
+# Filter DSL Documentation
+
+## Overview
+
+The conntrackd filter DSL (Domain-Specific Language) allows you to control
+which conntrack events are **logged** to your configured sinks
+(journal, syslog, Loki, etc.).
+
+**Protocol Support:** Only TCP and UDP events are processed. All other protocols
+(ICMP, IGMP, etc.) are automatically ignored and never logged, regardless of
+filter rules.
+
+**Important:** Filters do not affect network traffic - they only control which
+conntrack events are logged. All network traffic flows normally regardless of
+filter rules.
+
+Rules are evaluated in order (first-match wins), and events are
+**logged by default** when no rule matches.
+
+## Command-Line Usage
+
+Use the `--filter` flag to specify filter rules. This flag can be repeated
+multiple times:
+
+```bash
+conntrackd run \
+ --filter "drop destination address 8.8.8.8" \
+ --filter "log protocol TCP" \
+ --filter "drop ANY"
+```
+
+## Understanding Log-by-Default
+
+By default, conntrackd logs all conntrack events. This means:
+
+- If no filter rule matches an event, it **is logged** (log-by-default)
+- A `log` rule means "log this event"
+- A `drop` rule means "don't log this event"
+
+To log **only** specific events, use a `log` rule followed by `drop ANY`:
+
+```bash
+# Log ONLY NEW TCP connections
+--filter "log type NEW and protocol TCP"
+--filter "drop ANY"
+```
+
+Without the `drop ANY`, all non-matching events would still be logged.
+
+## Grammar
+
+The filter DSL follows this grammar (EBNF notation):
+
+```ebnf
+rule ::= action expression
+action ::= "log" | "drop"
+expression ::= orExpr
+orExpr ::= andExpr { "or" andExpr }
+andExpr ::= notExpr { "and" notExpr }
+notExpr ::= [ "not" | "!" ] primary
+primary ::= predicate | "(" expression ")"
+
+predicate ::= eventPred | protoPred | addrPred | networkPred | portPred | anyPred
+
+eventPred ::= "type" identList
+protoPred ::= "protocol" identList
+addrPred ::= direction "address" addrList [ "on" "port" portSpec ]
+networkPred::= direction "network" identList
+portPred ::= [ direction ] "port" portSpec
+ | "on" "port" portSpec
+anyPred ::= "ANY"
+
+direction ::= "source" | "src" | "destination" | "dst" | "dest"
+identList ::= IDENT { "," IDENT }
+addrList ::= ADDRESS { "," ADDRESS }
+portSpec ::= NUMBER | NUMBER "-" NUMBER | NUMBER { "," NUMBER }
+```
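+
+For example, under this grammar the rule
+`drop destination address 10.19.80.100 on port 53 and protocol UDP`
+decomposes into the action `drop` and an `andExpr` whose operands are an
+`addrPred` (direction `destination`, address `10.19.80.100`, port `53`) and a
+`protoPred` (protocol `UDP`).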
+
+## Predicates
+
+### Any (Catch-All)
+
+The `ANY` predicate matches all events. It's typically used with `drop` to
+block all non-matching events:
+
+```bash
+# Log only NEW TCP connections (drop everything else)
+log type NEW and protocol TCP
+drop ANY
+
+# Log only traffic to specific IPs (drop everything else)
+log destination address 1.2.3.4,5.6.7.8
+drop ANY
+```
+
+### Event Type
+
+Match on conntrack event types:
+
+```bash
+# Don't log NEW events
+drop type NEW
+
+# Log UPDATE or DESTROY events
+log type UPDATE,DESTROY
+```
+
+Valid types: `NEW`, `UPDATE`, `DESTROY`
+
+### Protocol
+
+Match on protocol:
+
+```bash
+# Don't log TCP events
+drop protocol TCP
+
+# Log TCP or UDP
+log protocol TCP,UDP
+```
+
+Valid protocols: `TCP`, `UDP`
+
+### Network Classification
+
+Match on network categories:
+
+```bash
+# Don't log traffic to private networks
+drop destination network PRIVATE
+
+# Log traffic from public networks
+log source network PUBLIC
+```
+
+Valid network types:
+- `LOCAL` - Loopback addresses (127.0.0.0/8, ::1) and link-local addresses
+- `PRIVATE` - RFC1918 private addresses (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16) and IPv6 ULA (fc00::/7)
+- `PUBLIC` - Public addresses (not LOCAL, PRIVATE, or MULTICAST)
+- `MULTICAST` - Multicast addresses (224.0.0.0/4, ff00::/8)
+
+### IP Address
+
+Match on specific IP addresses or CIDR ranges:
+
+```bash
+# Don't log traffic to a specific IP
+drop destination address 8.8.8.8
+
+# Log traffic from a CIDR range
+log source address 192.168.1.0/24
+
+# Match IPv6 addresses
+drop destination address 2001:4860:4860::8888
+
+# Multiple addresses
+drop destination address 1.1.1.1,8.8.8.8
+```
+
+### Port
+
+Match on ports:
+
+```bash
+# Don't log traffic to port 22
+drop destination port 22
+
+# Log traffic from source ports 1024-65535
+log source port 1024-65535
+
+# Match multiple ports
+drop destination port 22,23,3389
+
+# Match traffic on any port (source or destination)
+deny on port 53
+```
+
+### Address with Port
+
+Combine address and port matching:
+
+```bash
+# Don't log traffic to a specific IP on a specific port
+drop destination address 10.19.80.100 on port 53
+```
+
+## Operators
+
+### AND
+
+Both conditions must be true:
+
+```bash
+log type NEW and protocol TCP
+```
+
+### OR
+
+At least one condition must be true:
+
+```bash
+log type NEW or type UPDATE
+```
+
+### NOT
+
+Negate a condition:
+
+```bash
+log not type DESTROY
+```
+
+### Parentheses
+
+Group expressions:
+
+```bash
+log (type NEW and protocol TCP) or (type UPDATE and protocol UDP)
+```
+
+## Operator Precedence
+
+1. NOT (highest)
+2. AND
+3. OR (lowest)
+
+Example: `log type NEW or type UPDATE and protocol TCP` parses as `log type NEW or (type UPDATE and protocol TCP)`
+
+## Evaluation Semantics
+
+### First-Match Wins
+
+Rules are evaluated in the order they are specified. The first rule whose predicate matches determines the action (log or drop).
+
+```bash
+# First rule matches TCP traffic - logs it
+# Second rule never evaluated for TCP to 8.8.8.8
+conntrackd run \
+ --filter "log protocol TCP" \
+ --filter "drop destination address 8.8.8.8"
+```
+
+### Log-by-Default
+
+If no rule matches an event, it is **logged** by default:
+
+```bash
+# Don't log events to 8.8.8.8
+# All other events ARE logged
+conntrackd run --filter "drop destination address 8.8.8.8"
+```
+
+To change this behavior and log **only** specific events, use `drop ANY` as the final rule:
+
+```bash
+# Log ONLY events to 8.8.8.8
+# All other events are NOT logged
+conntrackd run \
+ --filter "log destination address 8.8.8.8" \
+ --filter "drop ANY"
+```
+
+## Examples
+
+### Example 1: Don't Log Specific Destination
+
+Don't log events to a specific IP address, but log all TCP traffic to public networks:
+
+```bash
+conntrackd run \
+ --filter "drop destination address 8.8.8.8" \
+ --filter "log protocol TCP and destination network PUBLIC"
+```
+
+**Evaluation:**
+- Traffic to 8.8.8.8: Matches first rule → **NOT LOGGED**
+- TCP to public IP (not 8.8.8.8): Matches second rule → **LOGGED**
+- UDP to private network: No match → **LOGGED** (default)
+
+### Example 2: Don't Log DNS to Specific Server
+
+Don't log DNS traffic to a specific IP, log all other TCP/UDP:
+
+```bash
+conntrackd run \
+ --filter "drop destination address 10.19.80.100 on port 53" \
+ --filter "log protocol TCP,UDP"
+```
+
+**Evaluation:**
+- DNS to 10.19.80.100: Matches first rule → **NOT LOGGED**
+- TCP/UDP to other destinations: Matches second rule → **LOGGED**
+- Other protocols (ICMP, etc.): Never processed → **NOT LOGGED** (see Protocol Support)
+
+### Example 3: Log Only Specific Traffic
+
+Log only NEW TCP connections (don't log anything else):
+
+```bash
+conntrackd run \
+ --filter "log type NEW and protocol TCP" \
+ --filter "drop ANY"
+```
+
+**Evaluation:**
+- NEW TCP: Matches first rule → **LOGGED**
+- NEW UDP: Matches second rule → **NOT LOGGED**
+- UPDATE/DESTROY: Matches second rule → **NOT LOGGED**
+
+**Note:** Without `drop ANY`, all non-matching events would still be logged.
+
+### Example 4: Complex Filtering
+
+Don't log outbound traffic to private networks on specific ports:
+
+```bash
+conntrackd run \
+ --filter "drop destination network PRIVATE and destination port 22,23,3389" \
+ --filter "log source network PRIVATE"
+```
+
+**Evaluation:**
+- Private network destination on port 22: Matches first rule → **NOT LOGGED**
+- Private network source: Matches second rule → **LOGGED**
+- Other traffic: No match → **LOGGED** (default)
+
+## Best Practices
+
+1. **Order Matters**: Place more specific rules before general rules
+2. **Use `drop ANY` for Exclusive Logging**: When you want to log ONLY specific events, end with `drop ANY`
+3. **Use AND for Precision**: Combine multiple conditions to create precise filters
+4. **Test Incrementally**: Start with simple rules and add complexity
+5. **Document Complex Rules**: Add comments in your deployment scripts (see the example script below)
+6. **Use Parentheses**: Make precedence explicit in complex expressions
+
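+As an illustration of practices 1, 2, and 5, a small deployment wrapper might
+look like this (a sketch only; the rule set and the use of a wrapper script
+are hypothetical, not part of conntrackd):
+
+```bash
+#!/bin/sh
+# conntrackd filter policy:
+#   1. drop noisy internal DNS lookups to 10.19.80.100
+#   2. log outbound TCP to public networks
+#   3. drop everything else (exclusive logging)
+exec conntrackd run \
+    --filter "drop destination address 10.19.80.100 on port 53" \
+    --filter "log protocol TCP and destination network PUBLIC" \
+    --filter "drop any" \
+    --sink.journal.enable
+```
+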
+## Case Insensitivity
+
+Keywords and identifiers are case-insensitive:
+
+```bash
+# These are all equivalent
+log type NEW
+LOG TYPE NEW
+Log Type New
+```
+
+## Abbreviations
+
+Supported abbreviations:
+- `src` = `source`
+- `dst` = `dest` = `destination`
+
+```bash
+# These are equivalent
+drop src network PRIVATE
+drop source network PRIVATE
+```
diff --git a/go.mod b/go.mod
index 2aca6d6..fea67d0 100644
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ require (
github.com/grafana/loki-client-go v0.0.0-20251015150631-c42bbddc310a
github.com/mdlayher/netlink v1.8.0
github.com/oschwald/geoip2-golang/v2 v2.0.0
+ github.com/prometheus/common v0.67.2
github.com/samber/slog-loki/v3 v3.6.0
github.com/samber/slog-multi v1.6.0
github.com/samber/slog-syslog/v2 v2.5.2
@@ -47,7 +48,6 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.20.4 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
- github.com/prometheus/common v0.67.2 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/prometheus/prometheus v0.35.0 // indirect
github.com/samber/lo v1.52.0 // indirect
diff --git a/internal/filter/ast.go b/internal/filter/ast.go
new file mode 100644
index 0000000..9d1daa2
--- /dev/null
+++ b/internal/filter/ast.go
@@ -0,0 +1,121 @@
+/*
+Copyright (c) 2025 Tobias Schäfer. All rights reserved.
+Licensed under the MIT License, see LICENSE file in the project root for details.
+*/
+package filter
+
+// Action represents the action to take when a rule matches
+type Action int
+
+const (
+ ActionLog Action = iota
+ ActionDrop
+)
+
+func (a Action) String() string {
+ switch a {
+ case ActionLog:
+ return "log"
+ case ActionDrop:
+ return "drop"
+ default:
+ return "unknown"
+ }
+}
+
+// ExprNode represents a node in the expression AST
+type ExprNode interface {
+ isExprNode()
+}
+
+// BinaryExpr represents a binary expression (AND, OR)
+type BinaryExpr struct {
+ Op BinaryOp
+ Left ExprNode
+ Right ExprNode
+}
+
+func (BinaryExpr) isExprNode() {}
+
+type BinaryOp int
+
+const (
+ OpAnd BinaryOp = iota
+ OpOr
+)
+
+// UnaryExpr represents a unary expression (NOT)
+type UnaryExpr struct {
+ Op UnaryOp
+ Expr ExprNode
+}
+
+func (UnaryExpr) isExprNode() {}
+
+type UnaryOp int
+
+const (
+ OpNot UnaryOp = iota
+)
+
+// Predicate represents a base predicate
+type Predicate interface {
+ ExprNode
+ isPredicate()
+}
+
+// TypePredicate matches event types
+type TypePredicate struct {
+ Types []string // NEW, UPDATE, DESTROY
+}
+
+func (TypePredicate) isExprNode() {}
+func (TypePredicate) isPredicate() {}
+
+// ProtocolPredicate matches protocols
+type ProtocolPredicate struct {
+ Protocols []string // TCP, UDP
+}
+
+func (ProtocolPredicate) isExprNode() {}
+func (ProtocolPredicate) isPredicate() {}
+
+// NetworkPredicate matches network types
+type NetworkPredicate struct {
+ Direction string // source, destination
+ Networks []string // LOCAL, PRIVATE, PUBLIC, MULTICAST
+}
+
+func (NetworkPredicate) isExprNode() {}
+func (NetworkPredicate) isPredicate() {}
+
+// AddressPredicate matches IP addresses or CIDR ranges
+type AddressPredicate struct {
+ Direction string // source, destination
+ Addresses []string // IP addresses or CIDR
+ Ports []uint16 // optional ports
+}
+
+func (AddressPredicate) isExprNode() {}
+func (AddressPredicate) isPredicate() {}
+
+// PortPredicate matches ports
+type PortPredicate struct {
+ Direction string // source, destination
+ Ports []uint16 // port numbers or ranges
+}
+
+func (PortPredicate) isExprNode() {}
+func (PortPredicate) isPredicate() {}
+
+// AnyPredicate matches any event (catch-all)
+type AnyPredicate struct{}
+
+func (AnyPredicate) isExprNode() {}
+func (AnyPredicate) isPredicate() {}
+
+// Rule represents a complete filter rule
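+//
+// For example, the rule "drop destination address 8.8.8.8 on port 53"
+// corresponds roughly to the following value (illustrative only; the exact
+// Direction string is whatever the parser emits):
+//
+//	Rule{
+//		Action: ActionDrop,
+//		Expr: AddressPredicate{
+//			Direction: "destination",
+//			Addresses: []string{"8.8.8.8"},
+//			Ports:     []uint16{53},
+//		},
+//	}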
+type Rule struct {
+ Action Action
+ Expr ExprNode
+}
diff --git a/internal/filter/eval.go b/internal/filter/eval.go
new file mode 100644
index 0000000..3a99dd8
--- /dev/null
+++ b/internal/filter/eval.go
@@ -0,0 +1,249 @@
+/*
+Copyright (c) 2025 Tobias Schäfer. All rights reserved.
+Licensed under the MIT License, see LICENSE file in the project root for details.
+*/
+package filter
+
+import (
+ "fmt"
+ "net/netip"
+ "slices"
+ "strings"
+ "syscall"
+
+ "github.com/ti-mo/conntrack"
+)
+
+// PredicateFunc is a function that evaluates a predicate against an event
+type PredicateFunc func(event conntrack.Event) bool
+
+// Compile compiles an expression AST into a predicate function
+func Compile(expr ExprNode) (PredicateFunc, error) {
+ switch e := expr.(type) {
+ case BinaryExpr:
+ return compileBinaryExpr(e)
+ case UnaryExpr:
+ return compileUnaryExpr(e)
+ case TypePredicate:
+ return compileTypePredicate(e), nil
+ case ProtocolPredicate:
+ return compileProtocolPredicate(e), nil
+ case NetworkPredicate:
+ return compileNetworkPredicate(e), nil
+ case AddressPredicate:
+ return compileAddressPredicate(e)
+ case PortPredicate:
+ return compilePortPredicate(e), nil
+ case AnyPredicate:
+ return compileAnyPredicate(e), nil
+ default:
+ return nil, fmt.Errorf("unknown expression type: %T", expr)
+ }
+}
+
+func compileBinaryExpr(expr BinaryExpr) (PredicateFunc, error) {
+ left, err := Compile(expr.Left)
+ if err != nil {
+ return nil, err
+ }
+ right, err := Compile(expr.Right)
+ if err != nil {
+ return nil, err
+ }
+
+ switch expr.Op {
+ case OpAnd:
+ return func(event conntrack.Event) bool {
+ return left(event) && right(event)
+ }, nil
+ case OpOr:
+ return func(event conntrack.Event) bool {
+ return left(event) || right(event)
+ }, nil
+ default:
+ return nil, fmt.Errorf("unknown binary operator: %v", expr.Op)
+ }
+}
+
+func compileUnaryExpr(expr UnaryExpr) (PredicateFunc, error) {
+ inner, err := Compile(expr.Expr)
+ if err != nil {
+ return nil, err
+ }
+
+ switch expr.Op {
+ case OpNot:
+ return func(event conntrack.Event) bool {
+ return !inner(event)
+ }, nil
+ default:
+ return nil, fmt.Errorf("unknown unary operator: %v", expr.Op)
+ }
+}
+
+func compileTypePredicate(pred TypePredicate) PredicateFunc {
+ return func(event conntrack.Event) bool {
+ var eventType string
+ switch event.Type {
+ case conntrack.EventNew:
+ eventType = "NEW"
+ case conntrack.EventUpdate:
+ eventType = "UPDATE"
+ case conntrack.EventDestroy:
+ eventType = "DESTROY"
+ default:
+ return false
+ }
+ return slices.Contains(pred.Types, eventType)
+ }
+}
+
+func compileProtocolPredicate(pred ProtocolPredicate) PredicateFunc {
+ return func(event conntrack.Event) bool {
+ var protocol string
+ switch event.Flow.TupleOrig.Proto.Protocol {
+ case syscall.IPPROTO_TCP:
+ protocol = "TCP"
+ case syscall.IPPROTO_UDP:
+ protocol = "UDP"
+ default:
+ return false
+ }
+ return slices.Contains(pred.Protocols, protocol)
+ }
+}
+
+func compileNetworkPredicate(pred NetworkPredicate) PredicateFunc {
+ return func(event conntrack.Event) bool {
+ var ip netip.Addr
+ switch strings.ToLower(pred.Direction) {
+ case "source", "src":
+ ip = event.Flow.TupleOrig.IP.SourceAddress
+ case "destination", "dst", "dest":
+ ip = event.Flow.TupleOrig.IP.DestinationAddress
+ default:
+ return false
+ }
+
+ isLocal := ip.IsLoopback() || ip.IsLinkLocalUnicast()
+ isPrivate := ip.IsPrivate()
+ isMulticast := ip.IsMulticast()
+ isPublic := !isLocal && !isPrivate && !isMulticast
+
+ for _, network := range pred.Networks {
+ switch network {
+ case "LOCAL":
+ if isLocal {
+ return true
+ }
+ case "PRIVATE":
+ if isPrivate {
+ return true
+ }
+ case "MULTICAST":
+ if isMulticast {
+ return true
+ }
+ case "PUBLIC":
+ if isPublic {
+ return true
+ }
+ }
+ }
+ return false
+ }
+}
+
+func compileAddressPredicate(pred AddressPredicate) (PredicateFunc, error) {
+ // Pre-compile address matchers
+ var matchers []func(netip.Addr) bool
+ for _, addrStr := range pred.Addresses {
+ // Try to parse as CIDR first
+ if strings.Contains(addrStr, "/") {
+ prefix, err := netip.ParsePrefix(addrStr)
+ if err == nil {
+ // Capture prefix in a local variable to avoid loop variable capture
+ pfx := prefix
+ matchers = append(matchers, func(ip netip.Addr) bool {
+ return pfx.Contains(ip)
+ })
+ continue
+ }
+ }
+ // Try to parse as IP
+ addr, err := netip.ParseAddr(addrStr)
+ if err == nil {
+ // Capture addr in a local variable to avoid loop variable capture
+ a := addr
+ matchers = append(matchers, func(ip netip.Addr) bool {
+ return ip == a
+ })
+ }
+ }
+
+ // If no addresses parsed successfully, return error
+ if len(matchers) == 0 {
+ return nil, fmt.Errorf("no valid addresses in predicate")
+ }
+
+ return func(event conntrack.Event) bool {
+ var ip netip.Addr
+ var port uint16
+
+ switch strings.ToLower(pred.Direction) {
+ case "source", "src":
+ ip = event.Flow.TupleOrig.IP.SourceAddress
+ port = event.Flow.TupleOrig.Proto.SourcePort
+ case "destination", "dst", "dest":
+ ip = event.Flow.TupleOrig.IP.DestinationAddress
+ port = event.Flow.TupleOrig.Proto.DestinationPort
+ default:
+ return false
+ }
+
+ // Check if IP matches
+ matched := false
+ for _, matcher := range matchers {
+ if matcher(ip) {
+ matched = true
+ break
+ }
+ }
+
+ if !matched {
+ return false
+ }
+
+ // If ports are specified, check them too
+ if len(pred.Ports) > 0 {
+ return slices.Contains(pred.Ports, port)
+ }
+
+ return true
+ }, nil
+}
+
+func compilePortPredicate(pred PortPredicate) PredicateFunc {
+ return func(event conntrack.Event) bool {
+ srcPort := event.Flow.TupleOrig.Proto.SourcePort
+ dstPort := event.Flow.TupleOrig.Proto.DestinationPort
+
+ switch strings.ToLower(pred.Direction) {
+ case "source", "src":
+ return slices.Contains(pred.Ports, srcPort)
+ case "destination", "dst", "dest":
+ return slices.Contains(pred.Ports, dstPort)
+ case "both":
+ return slices.Contains(pred.Ports, srcPort) || slices.Contains(pred.Ports, dstPort)
+ default:
+ return false
+ }
+ }
+}
+
+func compileAnyPredicate(pred AnyPredicate) PredicateFunc {
+ // AnyPredicate always matches
+ return func(event conntrack.Event) bool {
+ return true
+ }
+}
diff --git a/internal/filter/eval_test.go b/internal/filter/eval_test.go
new file mode 100644
index 0000000..1be5761
--- /dev/null
+++ b/internal/filter/eval_test.go
@@ -0,0 +1,433 @@
+/*
+Copyright (c) 2025 Tobias Schäfer. All rights reserved.
+Licensed under the MIT License, see LICENSE file in the project root for details.
+*/
+package filter
+
+import (
+ "net/netip"
+ "syscall"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/ti-mo/conntrack"
+)
+
+func createEvent(eventTypeVal, proto uint8) conntrack.Event {
+ flow := conntrack.NewFlow(
+ proto,
+ conntrack.StatusAssured,
+ netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("8.8.8.8"),
+ 1234, 80,
+ 60, 0,
+ )
+ // Create event and then set Type using constants to satisfy type system
+ event := conntrack.Event{Flow: &flow}
+ switch eventTypeVal {
+ case 1: // NEW
+ event.Type = conntrack.EventNew
+ case 2: // UPDATE
+ event.Type = conntrack.EventUpdate
+ case 3: // DESTROY
+ event.Type = conntrack.EventDestroy
+ }
+ return event
+}
+
+func createEventWithAddrs(eventTypeVal, proto uint8, srcIP, dstIP string, srcPort, dstPort uint16) conntrack.Event {
+ flow := conntrack.NewFlow(
+ proto,
+ conntrack.StatusAssured,
+ netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP),
+ srcPort, dstPort,
+ 60, 0,
+ )
+ // Create event and then set Type using constants to satisfy type system
+ event := conntrack.Event{Flow: &flow}
+ switch eventTypeVal {
+ case 1: // NEW
+ event.Type = conntrack.EventNew
+ case 2: // UPDATE
+ event.Type = conntrack.EventUpdate
+ case 3: // DESTROY
+ event.Type = conntrack.EventDestroy
+ }
+ return event
+}
+
+func TestEval_TypePredicate(t *testing.T) {
+ tests := []struct {
+ name string
+ rule string
+ event conntrack.Event
+ expected bool
+ }{
+ {"match NEW", "log type NEW", createEvent(1, syscall.IPPROTO_TCP), true},
+ {"match UPDATE", "log type UPDATE", createEvent(2, syscall.IPPROTO_TCP), true},
+ {"match DESTROY", "log type DESTROY", createEvent(3, syscall.IPPROTO_TCP), true},
+ {"no match NEW", "log type UPDATE", createEvent(1, syscall.IPPROTO_TCP), false},
+ {"match multiple", "log type NEW,UPDATE", createEvent(1, syscall.IPPROTO_TCP), true},
+ {"match multiple 2", "log type NEW,UPDATE", createEvent(2, syscall.IPPROTO_TCP), true},
+ {"no match multiple", "log type NEW,UPDATE", createEvent(3, syscall.IPPROTO_TCP), false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.rule)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ pred, err := Compile(rule.Expr)
+ require.NoError(t, err)
+
+ assert.Equal(t, tt.expected, pred(tt.event))
+ })
+ }
+}
+
+func TestEval_ProtocolPredicate(t *testing.T) {
+ tests := []struct {
+ name string
+ rule string
+ proto uint8
+ expected bool
+ }{
+ {"match TCP", "log protocol TCP", syscall.IPPROTO_TCP, true},
+ {"match UDP", "log protocol UDP", syscall.IPPROTO_UDP, true},
+ {"no match TCP", "log protocol UDP", syscall.IPPROTO_TCP, false},
+ {"match multiple", "log protocol TCP,UDP", syscall.IPPROTO_TCP, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.rule)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ pred, err := Compile(rule.Expr)
+ require.NoError(t, err)
+
+ event := createEvent(1, tt.proto)
+ assert.Equal(t, tt.expected, pred(event))
+ })
+ }
+}
+
+func TestEval_NetworkPredicate(t *testing.T) {
+ tests := []struct {
+ name string
+ rule string
+ srcIP string
+ dstIP string
+ expected bool
+ }{
+ {"src private", "log source network PRIVATE", "10.0.0.1", "8.8.8.8", true},
+ {"src public", "log source network PUBLIC", "8.8.8.8", "10.0.0.1", true},
+ {"dst public", "log destination network PUBLIC", "10.0.0.1", "8.8.8.8", true},
+ {"dst private", "log destination network PRIVATE", "8.8.8.8", "10.0.0.1", true},
+ {"src local loopback", "log source network LOCAL", "127.0.0.1", "8.8.8.8", true},
+ {"dst multicast", "log destination network MULTICAST", "10.0.0.1", "224.0.0.1", true},
+ {"no match", "log source network PUBLIC", "10.0.0.1", "8.8.8.8", false},
+ {"ipv6 private", "log source network PRIVATE", "fc00::1", "2001:db8::1", true},
+ {"ipv6 public", "log destination network PUBLIC", "10.0.0.1", "2001:4860:4860::8888", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.rule)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ pred, err := Compile(rule.Expr)
+ require.NoError(t, err)
+
+ event := createEventWithAddrs(1, syscall.IPPROTO_TCP, tt.srcIP, tt.dstIP, 1234, 80)
+ assert.Equal(t, tt.expected, pred(event))
+ })
+ }
+}
+
+func TestEval_AddressPredicate(t *testing.T) {
+ tests := []struct {
+ name string
+ rule string
+ srcIP string
+ dstIP string
+ srcPort uint16
+ dstPort uint16
+ expected bool
+ }{
+ {"dst exact match", "log destination address 8.8.8.8", "10.0.0.1", "8.8.8.8", 1234, 80, true},
+ {"dst no match", "log destination address 8.8.8.8", "10.0.0.1", "8.8.4.4", 1234, 80, false},
+ {"src exact match", "log source address 10.0.0.1", "10.0.0.1", "8.8.8.8", 1234, 80, true},
+ {"cidr match", "log destination address 8.8.8.0/24", "10.0.0.1", "8.8.8.100", 1234, 80, true},
+ {"cidr no match", "log destination address 8.8.8.0/24", "10.0.0.1", "8.8.9.1", 1234, 80, false},
+ {"with port match", "log destination address 10.19.80.100 on port 53", "192.168.1.1", "10.19.80.100", 1234, 53, true},
+ {"with port no match addr", "log destination address 10.19.80.100 on port 53", "192.168.1.1", "10.19.80.101", 1234, 53, false},
+ {"with port no match port", "log destination address 10.19.80.100 on port 53", "192.168.1.1", "10.19.80.100", 1234, 80, false},
+ {"ipv6 match", "log destination address 2001:4860:4860::8888", "10.0.0.1", "2001:4860:4860::8888", 1234, 80, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.rule)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ pred, err := Compile(rule.Expr)
+ require.NoError(t, err)
+
+ event := createEventWithAddrs(1, syscall.IPPROTO_TCP, tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
+ assert.Equal(t, tt.expected, pred(event))
+ })
+ }
+}
+
+func TestEval_PortPredicate(t *testing.T) {
+ tests := []struct {
+ name string
+ rule string
+ srcPort uint16
+ dstPort uint16
+ expected bool
+ }{
+ {"dst port match", "log destination port 80", 1234, 80, true},
+ {"dst port no match", "log destination port 80", 1234, 443, false},
+ {"src port match", "log source port 1234", 1234, 80, true},
+ {"on port dst match", "log on port 80", 1234, 80, true},
+ {"on port src match", "log on port 1234", 1234, 80, true},
+ {"on port no match", "log on port 53", 1234, 80, false},
+ {"port range", "log destination port 80,443,8080", 1234, 443, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.rule)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ pred, err := Compile(rule.Expr)
+ require.NoError(t, err)
+
+ event := createEventWithAddrs(1, syscall.IPPROTO_TCP, "10.0.0.1", "8.8.8.8", tt.srcPort, tt.dstPort)
+ assert.Equal(t, tt.expected, pred(event))
+ })
+ }
+}
+
+func TestEval_AnyPredicate(t *testing.T) {
+ tests := []struct {
+ name string
+ rule string
+ event conntrack.Event
+ }{
+ {
+ "any matches NEW TCP",
+ "log any",
+ createEvent(1, syscall.IPPROTO_TCP),
+ },
+ {
+ "any matches UPDATE UDP",
+ "drop any",
+ createEvent(2, syscall.IPPROTO_UDP),
+ },
+ {
+ "any matches DESTROY",
+ "log any",
+ createEvent(3, syscall.IPPROTO_TCP),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.rule)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ pred, err := Compile(rule.Expr)
+ require.NoError(t, err)
+ result := pred(tt.event)
+ assert.True(t, result, "any predicate should always match")
+ })
+ }
+}
+
+func TestEval_BinaryExpressions(t *testing.T) {
+ tests := []struct {
+ name string
+ rule string
+ event conntrack.Event
+ expected bool
+ }{
+ {
+ "and both match",
+ "log type NEW and protocol TCP",
+ createEvent(1, syscall.IPPROTO_TCP),
+ true,
+ },
+ {
+ "and first no match",
+ "log type UPDATE and protocol TCP",
+ createEvent(1, syscall.IPPROTO_TCP),
+ false,
+ },
+ {
+ "and second no match",
+ "log type NEW and protocol UDP",
+ createEvent(1, syscall.IPPROTO_TCP),
+ false,
+ },
+ {
+ "or first match",
+ "log type NEW or type UPDATE",
+ createEvent(1, syscall.IPPROTO_TCP),
+ true,
+ },
+ {
+ "or second match",
+ "log type UPDATE or type NEW",
+ createEvent(1, syscall.IPPROTO_TCP),
+ true,
+ },
+ {
+ "or both no match",
+ "log type UPDATE or type DESTROY",
+ createEvent(1, syscall.IPPROTO_TCP),
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.rule)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ pred, err := Compile(rule.Expr)
+ require.NoError(t, err)
+
+ assert.Equal(t, tt.expected, pred(tt.event))
+ })
+ }
+}
+
+func TestEval_UnaryExpression(t *testing.T) {
+ tests := []struct {
+ name string
+ rule string
+ event conntrack.Event
+ expected bool
+ }{
+ {
+ "not match becomes false",
+ "log not type NEW",
+ createEvent(1, syscall.IPPROTO_TCP),
+ false,
+ },
+ {
+ "not no-match becomes true",
+ "log not type UPDATE",
+ createEvent(1, syscall.IPPROTO_TCP),
+ true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.rule)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ pred, err := Compile(rule.Expr)
+ require.NoError(t, err)
+
+ assert.Equal(t, tt.expected, pred(tt.event))
+ })
+ }
+}
+
+func TestFilter_Evaluate(t *testing.T) {
+ rules := []string{
+ "drop destination address 8.8.8.8",
+ "log protocol TCP and destination network PUBLIC",
+ }
+
+ filter, err := NewFilter(rules)
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ event conntrack.Event
+ matched bool
+ allow bool
+ matchedIndex int
+ }{
+ {
+ "first rule denies",
+ createEventWithAddrs(1, syscall.IPPROTO_TCP, "10.0.0.1", "8.8.8.8", 1234, 80),
+ true,
+ false,
+ 0,
+ },
+ {
+ "second rule allows",
+ createEventWithAddrs(1, syscall.IPPROTO_TCP, "10.0.0.1", "1.1.1.1", 1234, 80),
+ true,
+ true,
+ 1,
+ },
+ {
+ "no match allows by default",
+ createEventWithAddrs(1, syscall.IPPROTO_UDP, "10.0.0.1", "192.168.1.1", 1234, 80),
+ false,
+ true,
+ -1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ matched, allow, matchedIndex := filter.Evaluate(tt.event)
+ assert.Equal(t, tt.matched, matched)
+ assert.Equal(t, tt.allow, allow)
+ assert.Equal(t, tt.matchedIndex, matchedIndex)
+ })
+ }
+}
+
+func TestFilter_EmptyAllowsByDefault(t *testing.T) {
+ filter, err := NewFilter([]string{})
+ require.NoError(t, err)
+
+ event := createEventWithAddrs(1, syscall.IPPROTO_TCP, "10.0.0.1", "8.8.8.8", 1234, 80)
+ matched, allow, matchedIndex := filter.Evaluate(event)
+ assert.False(t, matched)
+ assert.True(t, allow)
+ assert.Equal(t, -1, matchedIndex)
+}
+
+func TestFilter_FirstMatchWins(t *testing.T) {
+ rules := []string{
+ "log protocol TCP",
+ "drop destination address 8.8.8.8",
+ }
+
+ filter, err := NewFilter(rules)
+ require.NoError(t, err)
+
+ // TCP to 8.8.8.8 should be allowed by first rule (first match wins)
+ event := createEventWithAddrs(1, syscall.IPPROTO_TCP, "10.0.0.1", "8.8.8.8", 1234, 80)
+ matched, allow, matchedIndex := filter.Evaluate(event)
+ assert.True(t, matched)
+ assert.True(t, allow)
+ assert.Equal(t, 0, matchedIndex)
+}
diff --git a/internal/filter/filter.go b/internal/filter/filter.go
index 5e1befa..e736d5c 100644
--- a/internal/filter/filter.go
+++ b/internal/filter/filter.go
@@ -5,212 +5,72 @@ Licensed under the MIT License, see LICENSE file in the project root for details
package filter
import (
- "log/slog"
- "slices"
- "syscall"
+ "fmt"
"github.com/ti-mo/conntrack"
)
-type FilterAddresses struct {
- Destinations []string
- Sources []string
-}
-
-type FilterNetworks struct {
- Destinations []string
- Sources []string
-}
-
-type FilterPorts struct {
- Destinations []uint
- Sources []uint
-}
-
+// Filter represents a compiled set of filter rules
type Filter struct {
- EventTypes []string
- Protocols []string
- Networks FilterNetworks
- Addresses FilterAddresses
- Ports FilterPorts
+ Rules []CompiledRule
}
-func (f *Filter) eventType(event conntrack.Event) bool {
- if len(f.EventTypes) == 0 {
- return false
- }
-
- types := map[any]string{
- conntrack.EventNew: "NEW",
- conntrack.EventUpdate: "UPDATE",
- conntrack.EventDestroy: "DESTROY",
- }
- eventTypeStr, ok := types[event.Type]
- if !ok {
- return false
- }
-
- return slices.Contains(f.EventTypes, eventTypeStr)
+// CompiledRule represents a parsed and compiled filter rule
+type CompiledRule struct {
+ Rule *Rule
+ Predicate PredicateFunc
+ RuleText string
}
-func (f *Filter) eventProtocol(event conntrack.Event) bool {
- if len(f.Protocols) == 0 {
- switch event.Flow.TupleOrig.Proto.Protocol {
- case syscall.IPPROTO_TCP, syscall.IPPROTO_UDP:
- return false
- default:
- return true
- }
+// NewFilter creates a new DSL-based filter from rule strings
+func NewFilter(ruleStrings []string) (*Filter, error) {
+ filter := &Filter{
+ Rules: make([]CompiledRule, 0, len(ruleStrings)),
}
- protocols := map[int]string{
- syscall.IPPROTO_TCP: "TCP",
- syscall.IPPROTO_UDP: "UDP",
- }
- protocolStr, ok := protocols[int(event.Flow.TupleOrig.Proto.Protocol)]
- if !ok {
- return false
- }
-
- return slices.Contains(f.Protocols, protocolStr)
-}
-
-func (f *Filter) eventSource(event conntrack.Event) bool {
- if len(f.Networks.Sources) == 0 {
- return false
- }
-
- src := event.Flow.TupleOrig.IP.SourceAddress
- slog.Info("Source Address", "src", src.String())
- isLocal := src.IsLoopback()
- slog.Info("Is Local", "isLocal", isLocal)
- isPrivate := src.IsPrivate()
- isMulticast := src.IsMulticast()
- isPublic := !isLocal && !isPrivate && !isMulticast
-
- for _, filterSource := range f.Networks.Sources {
- switch filterSource {
- case "LOCAL":
- if isLocal {
- return true
- }
- case "PRIVATE":
- if isPrivate {
- return true
- }
- case "MULTICAST":
- if isMulticast {
- return true
- }
- case "PUBLIC":
- if isPublic {
- return true
- }
+ for i, ruleStr := range ruleStrings {
+ parser, err := NewParser(ruleStr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to initialize parser for rule %d: %w", i, err)
}
- }
- return false
-}
-
-func (f *Filter) eventDestination(event conntrack.Event) bool {
- if len(f.Networks.Destinations) == 0 {
- return false
- }
-
- dest := event.Flow.TupleOrig.IP.DestinationAddress
- isLocal := dest.IsLoopback()
- isPrivate := dest.IsPrivate()
- isMulticast := dest.IsMulticast()
- isPublic := !isLocal && !isPrivate && !isMulticast
-
- for _, filterDest := range f.Networks.Destinations {
- switch filterDest {
- case "LOCAL":
- if isLocal {
- return true
- }
- case "PRIVATE":
- if isPrivate {
- return true
- }
- case "MULTICAST":
- if isMulticast {
- return true
- }
- case "PUBLIC":
- if isPublic {
- return true
- }
+ rule, err := parser.ParseRule()
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse rule %d (%s): %w", i, ruleStr, err)
}
- }
- return false
-}
-
-func (f *Filter) eventAddressDestination(event conntrack.Event) bool {
- if len(f.Addresses.Destinations) == 0 {
- return false
- }
-
- return slices.Contains(f.Addresses.Destinations, event.Flow.TupleOrig.IP.DestinationAddress.String())
-}
-
-func (f *Filter) eventAddressSource(event conntrack.Event) bool {
- if len(f.Addresses.Sources) == 0 {
- return false
- }
-
- return slices.Contains(f.Addresses.Sources, event.Flow.TupleOrig.IP.SourceAddress.String())
-}
-
-func (f *Filter) eventPortDestination(event conntrack.Event) bool {
- if len(f.Ports.Destinations) == 0 {
- return false
- }
-
- return slices.Contains(f.Ports.Destinations, uint(event.Flow.TupleOrig.Proto.DestinationPort))
-}
+ predicate, err := Compile(rule.Expr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to compile rule %d (%s): %w", i, ruleStr, err)
+ }
-func (f *Filter) eventPortSource(event conntrack.Event) bool {
- if len(f.Ports.Sources) == 0 {
- return false
+ filter.Rules = append(filter.Rules, CompiledRule{
+ Rule: rule,
+ Predicate: predicate,
+ RuleText: ruleStr,
+ })
}
- return slices.Contains(f.Ports.Sources, uint(event.Flow.TupleOrig.Proto.SourcePort))
+ return filter, nil
}
-func (f *Filter) Apply(event conntrack.Event) bool {
- if f.eventType(event) {
- return true
+// Evaluate evaluates the filter against an event
+// Returns: (matched bool, shouldLog bool, matchedRuleIndex int)
+// If no rule matches, returns (false, true, -1) for log-by-default policy
+func (f *Filter) Evaluate(event conntrack.Event) (bool, bool, int) {
+ if f == nil || len(f.Rules) == 0 {
+ // Log by default when no rules
+ return false, true, -1
}
- if f.eventProtocol(event) {
- return true
- }
-
- if f.eventSource(event) {
- return true
- }
-
- if f.eventDestination(event) {
- return true
- }
-
- if f.eventAddressDestination(event) {
- return true
- }
-
- if f.eventAddressSource(event) {
- return true
- }
-
- if f.eventPortDestination(event) {
- return true
- }
-
- if f.eventPortSource(event) {
- return true
+ // First-match wins
+ for i, compiledRule := range f.Rules {
+ if compiledRule.Predicate(event) {
+ shouldLog := compiledRule.Rule.Action == ActionLog
+ return true, shouldLog, i
+ }
}
- return false
+ // Log by default when no rule matches
+ return false, true, -1
}
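
The rewritten filter is driven through `NewFilter` and `Evaluate` instead of the old `Apply`. A minimal usage sketch, reusing the flow values from the test removed below; the exact match results depend on the `Compile` step, which lives outside this hunk:

```go
package main

import (
	"fmt"
	"net/netip"
	"syscall"

	"github.com/ti-mo/conntrack"
	"github.com/tschaefer/conntrackd/internal/filter"
)

func main() {
	// Rule order matters: first-match wins, unmatched events are logged by default.
	f, err := filter.NewFilter([]string{
		"drop destination address 8.8.8.8",
		"log protocol TCP and destination network PUBLIC",
		"drop any",
	})
	if err != nil {
		panic(err)
	}

	flow := conntrack.NewFlow(
		syscall.IPPROTO_TCP, conntrack.StatusAssured,
		netip.MustParseAddr("10.19.80.100"), netip.MustParseAddr("78.47.60.169"),
		4711, 443, 60, 0,
	)
	event := conntrack.Event{Type: conntrack.EventNew, Flow: &flow}

	matched, shouldLog, idx := f.Evaluate(event)
	fmt.Println(matched, shouldLog, idx) // e.g. true true 1 for a public TCP destination
}
```

With these three rules, events to 8.8.8.8 are kept out of the log, public TCP destinations are logged, and everything else falls through to the final `drop any`.
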
diff --git a/internal/filter/filter_test.go b/internal/filter/filter_test.go
deleted file mode 100644
index d6c9664..0000000
--- a/internal/filter/filter_test.go
+++ /dev/null
@@ -1,246 +0,0 @@
-/*
-Copyright (c) 2025 Tobias Schäfer. All rights reserved.
-Licensed under the MIT License, see LICENSE file in the project root for details.
-*/
-package filter
-
-import (
- "io"
- "net/http"
- "net/netip"
- "os"
- "syscall"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/ti-mo/conntrack"
-)
-
-const geoDatabasePath = "/tmp/GeoLite2-City.mmdb"
-const geoDatabaseUrl = "https://github.com/P3TERX/GeoLite.mmdb/releases/latest/download/GeoLite2-City.mmdb"
-
-func setup() {
- setupGeoDatabase()
-}
-
-func setupGeoDatabase() {
- if _, err := os.Stat(geoDatabasePath); os.IsNotExist(err) {
- resp, err := http.Get(geoDatabaseUrl)
- if err != nil {
- panic(err)
- }
- defer func() {
- _ = resp.Body.Close()
- }()
-
- out, err := os.Create(geoDatabasePath)
- if err != nil {
- panic(err)
- }
- defer func() {
- _ = out.Close()
- }()
-
- _, err = io.Copy(out, resp.Body)
- if err != nil {
- panic(err)
- }
- }
-}
-
-func Test_Filter(t *testing.T) {
- setup()
-
- flow := conntrack.NewFlow(
- syscall.IPPROTO_TCP,
- conntrack.StatusAssured,
- netip.MustParseAddr("10.19.80.100"), netip.MustParseAddr("78.47.60.169"),
- 4711, 443,
- 60, 0,
- )
-
- event := conntrack.Event{
- Type: conntrack.EventNew,
- Flow: &flow,
- }
-
- f := &Filter{}
- matched := f.Apply(event)
- assert.False(t, matched, "no filters")
-
- f = &Filter{
- Protocols: []string{"TCP"},
- }
- matched = f.Apply(event)
- assert.True(t, matched, "protocol filter TCP")
-
- f = &Filter{
- EventTypes: []string{"NEW"},
- }
- matched = f.Apply(event)
- assert.True(t, matched, "event type filter NEW")
-
- f = &Filter{
- EventTypes: []string{"UPDATE"},
- }
- matched = f.Apply(event)
- assert.False(t, matched, "event type filter UPDATE")
-
- f = &Filter{
- EventTypes: []string{"DESTROY"},
- }
- matched = f.Apply(event)
- assert.False(t, matched, "event type filter DESTROY")
-
- f = &Filter{
- EventTypes: []string{"NEW", "DESTROY"},
- }
- matched = f.Apply(event)
- assert.True(t, matched, "event type filter NEW, DESTROY")
-
- f = &Filter{
- Protocols: []string{"UDP"},
- }
- matched = f.Apply(event)
- assert.False(t, matched, "protocol filter UDP")
-
- f = &Filter{
- Protocols: []string{"UDP", "TCP"},
- }
- matched = f.Apply(event)
- assert.True(t, matched, "protocol filter UDP, TCP")
-
- f = &Filter{
- Protocols: []string{"ICMP"},
- }
- matched = f.Apply(event)
- assert.False(t, matched, "bad protocol filter")
-
- f = &Filter{
- Networks: FilterNetworks{
- Destinations: []string{"PUBLIC"},
- },
- }
- matched = f.Apply(event)
- assert.True(t, matched, "destination network filter PUBLIC")
-
- f = &Filter{
- Networks: FilterNetworks{
- Destinations: []string{"PRIVATE"},
- },
- }
- matched = f.Apply(event)
- assert.False(t, matched, "destination network filter PRIVATE")
-
- f = &Filter{
- Networks: FilterNetworks{
- Destinations: []string{"LOCAL"},
- },
- }
- matched = f.Apply(event)
- assert.False(t, matched, "destination network filter LOCAL")
-
- f = &Filter{
- Networks: FilterNetworks{
- Destinations: []string{"MULTICAST"},
- },
- }
- matched = f.Apply(event)
- assert.False(t, matched, "source network filter MULTICAST")
-
- f = &Filter{
- Networks: FilterNetworks{
- Sources: []string{"PUBLIC"},
- },
- }
- matched = f.Apply(event)
- assert.False(t, matched, "source network filter PUBLIC")
-
- f = &Filter{
- Networks: FilterNetworks{
- Sources: []string{"PRIVATE"},
- },
- }
- matched = f.Apply(event)
- assert.True(t, matched, "source network filter PRIVATE")
-
- f = &Filter{
- Networks: FilterNetworks{
- Sources: []string{"LOCAL"},
- },
- }
- matched = f.Apply(event)
- assert.False(t, matched, "source network filter LOCAL")
-
- f = &Filter{
- Networks: FilterNetworks{
- Sources: []string{"MULTICAST"},
- },
- }
- matched = f.Apply(event)
- assert.False(t, matched, "source network filter MULTICAST")
-
- f = &Filter{
- Addresses: FilterAddresses{
- Destinations: []string{"78.47.60.169"},
- },
- }
- matched = f.Apply(event)
- assert.True(t, matched, "destination address filter match")
-
- f = &Filter{
- Addresses: FilterAddresses{
- Destinations: []string{"78.47.60.170"},
- },
- }
- matched = f.Apply(event)
- assert.False(t, matched, "destination address filter no match")
-
- f = &Filter{
- Addresses: FilterAddresses{
- Sources: []string{"10.19.80.100"},
- },
- }
- matched = f.Apply(event)
- assert.True(t, matched, "source address filter match")
-
- f = &Filter{
- Addresses: FilterAddresses{
- Sources: []string{"10.19.80.200"},
- },
- }
- matched = f.Apply(event)
- assert.False(t, matched, "source address filter no match")
-
- f = &Filter{
- Ports: FilterPorts{
- Destinations: []uint{443},
- },
- }
- matched = f.Apply(event)
- assert.True(t, matched, "destination port filter match")
-
- f = &Filter{
- Ports: FilterPorts{
- Destinations: []uint{80},
- },
- }
- matched = f.Apply(event)
- assert.False(t, matched, "destination port filter no match")
-
- f = &Filter{
- Ports: FilterPorts{
- Sources: []uint{4711},
- },
- }
- matched = f.Apply(event)
- assert.True(t, matched, "source port filter match")
-
- f = &Filter{
- Ports: FilterPorts{
- Sources: []uint{1234},
- },
- }
- matched = f.Apply(event)
- assert.False(t, matched, "source port filter no match")
-}
diff --git a/internal/filter/lexer.go b/internal/filter/lexer.go
new file mode 100644
index 0000000..4f780d0
--- /dev/null
+++ b/internal/filter/lexer.go
@@ -0,0 +1,172 @@
+/*
+Copyright (c) 2025 Tobias Schäfer. All rights reserved.
+Licensed under the MIT License, see LICENSE file in the project root for details.
+*/
+package filter
+
+import (
+ "fmt"
+ "strings"
+ "unicode"
+)
+
+type TokenType int
+
+const (
+ TokenEOF TokenType = iota
+ TokenIdent
+ TokenNumber
+ TokenComma
+ TokenDash
+ TokenSlash
+ TokenColon
+ TokenDot
+ TokenLParen
+ TokenRParen
+ TokenAnd
+ TokenOr
+ TokenNot
+ TokenLog
+ TokenDrop
+ TokenEventType
+ TokenProtocol
+ TokenSource
+ TokenDestination
+ TokenAddress
+ TokenNetwork
+ TokenPort
+ TokenOn
+ TokenAny
+)
+
+type Token struct {
+ Type TokenType
+ Value string
+ Pos int
+}
+
+type Lexer struct {
+ input string
+ pos int
+ ch rune
+}
+
+func NewLexer(input string) *Lexer {
+ l := &Lexer{input: input}
+ l.readChar()
+ return l
+}
+
+func (l *Lexer) readChar() {
+ if l.pos >= len(l.input) {
+ l.ch = 0
+ } else {
+ l.ch = rune(l.input[l.pos])
+ }
+ l.pos++
+}
+
+func (l *Lexer) skipWhitespace() {
+ for l.ch == ' ' || l.ch == '\t' || l.ch == '\n' || l.ch == '\r' {
+ l.readChar()
+ }
+}
+
+func (l *Lexer) readIdentifier() string {
+ start := l.pos - 1
+ for unicode.IsLetter(l.ch) || unicode.IsDigit(l.ch) || l.ch == '_' {
+ l.readChar()
+ }
+ return l.input[start : l.pos-1]
+}
+
+func (l *Lexer) readNumber() string {
+ start := l.pos - 1
+ for unicode.IsDigit(l.ch) {
+ l.readChar()
+ }
+ return l.input[start : l.pos-1]
+}
+
+func (l *Lexer) NextToken() (Token, error) {
+ l.skipWhitespace()
+
+ tok := Token{Pos: l.pos - 1}
+
+ switch l.ch {
+ case 0:
+ tok.Type = TokenEOF
+ case ',':
+ tok.Type = TokenComma
+ tok.Value = ","
+ l.readChar()
+ case '-':
+ tok.Type = TokenDash
+ tok.Value = "-"
+ l.readChar()
+ case '/':
+ tok.Type = TokenSlash
+ tok.Value = "/"
+ l.readChar()
+ case ':':
+ tok.Type = TokenColon
+ tok.Value = ":"
+ l.readChar()
+ case '.':
+ tok.Type = TokenDot
+ tok.Value = "."
+ l.readChar()
+ case '(':
+ tok.Type = TokenLParen
+ tok.Value = "("
+ l.readChar()
+ case ')':
+ tok.Type = TokenRParen
+ tok.Value = ")"
+ l.readChar()
+ case '!':
+ tok.Type = TokenNot
+ tok.Value = "!"
+ l.readChar()
+ default:
+ if unicode.IsLetter(l.ch) {
+ tok.Value = l.readIdentifier()
+ tok.Type = l.lookupKeyword(tok.Value)
+ } else if unicode.IsDigit(l.ch) {
+ tok.Value = l.readNumber()
+ tok.Type = TokenNumber
+ } else {
+ return tok, fmt.Errorf("unexpected character: %c at position %d", l.ch, l.pos-1)
+ }
+ }
+
+ return tok, nil
+}
+
+func (l *Lexer) lookupKeyword(ident string) TokenType {
+ keywords := map[string]TokenType{
+ "and": TokenAnd,
+ "or": TokenOr,
+ "not": TokenNot,
+ "log": TokenLog,
+ "drop": TokenDrop,
+ "type": TokenEventType,
+ "protocol": TokenProtocol,
+ "source": TokenSource,
+ "src": TokenSource,
+ "destination": TokenDestination,
+ "dst": TokenDestination,
+ "dest": TokenDestination,
+ "address": TokenAddress,
+ "network": TokenNetwork,
+ "port": TokenPort,
+ "on": TokenOn,
+ "any": TokenAny,
+ }
+
+ lower := strings.ToLower(ident)
+ if tokType, ok := keywords[lower]; ok {
+ return tokType
+ }
+ return TokenIdent
+}
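
A purely illustrative loop over the new lexer, showing how a rule string is tokenized before it reaches the parser:

```go
package main

import (
	"fmt"

	"github.com/tschaefer/conntrackd/internal/filter"
)

func main() {
	lex := filter.NewLexer("drop destination address 8.8.8.8 on port 53")
	for {
		tok, err := lex.NextToken()
		if err != nil {
			panic(err)
		}
		if tok.Type == filter.TokenEOF {
			break
		}
		// Keywords map to their dedicated token types; numbers, dots and
		// identifiers are emitted individually and reassembled by the parser.
		fmt.Printf("%d %q\n", tok.Type, tok.Value)
	}
}
```
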
diff --git a/internal/filter/parser.go b/internal/filter/parser.go
new file mode 100644
index 0000000..b6847ed
--- /dev/null
+++ b/internal/filter/parser.go
@@ -0,0 +1,504 @@
+/*
+Copyright (c) 2025 Tobias Schäfer. All rights reserved.
+Licensed under the MIT License, see LICENSE file in the project root for details.
+*/
+package filter
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+)
+
+type Parser struct {
+ lexer *Lexer
+ current Token
+ peek Token
+}
+
+func NewParser(input string) (*Parser, error) {
+ p := &Parser{lexer: NewLexer(input)}
+ // Read two tokens to initialize current and peek
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ return p, nil
+}
+
+func (p *Parser) nextToken() error {
+ p.current = p.peek
+ tok, err := p.lexer.NextToken()
+ if err != nil {
+ return err
+ }
+ p.peek = tok
+ return nil
+}
+
+func (p *Parser) expect(tokType TokenType) error {
+ if p.current.Type != tokType {
+ return fmt.Errorf("expected token %v, got %v (%s) at position %d",
+ tokType, p.current.Type, p.current.Value, p.current.Pos)
+ }
+ return p.nextToken()
+}
+
+// ParseRule parses a complete rule: action expression
+func (p *Parser) ParseRule() (*Rule, error) {
+ rule := &Rule{}
+
+ // Parse action
+ switch p.current.Type {
+ case TokenLog:
+ rule.Action = ActionLog
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ case TokenDrop:
+ rule.Action = ActionDrop
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ default:
+ return nil, fmt.Errorf("expected 'log' or 'drop', got '%s' at position %d",
+ p.current.Value, p.current.Pos)
+ }
+
+ // Parse expression
+ expr, err := p.parseExpression()
+ if err != nil {
+ return nil, err
+ }
+ rule.Expr = expr
+
+ // Expect EOF
+ if p.current.Type != TokenEOF {
+ return nil, fmt.Errorf("unexpected token after rule: %s at position %d",
+ p.current.Value, p.current.Pos)
+ }
+
+ return rule, nil
+}
+
+// parseExpression parses: orExpr
+func (p *Parser) parseExpression() (ExprNode, error) {
+ return p.parseOrExpr()
+}
+
+// parseOrExpr parses: andExpr { ("," | "or") andExpr }
+func (p *Parser) parseOrExpr() (ExprNode, error) {
+ left, err := p.parseAndExpr()
+ if err != nil {
+ return nil, err
+ }
+
+ for p.current.Type == TokenOr || p.current.Type == TokenComma {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ right, err := p.parseAndExpr()
+ if err != nil {
+ return nil, err
+ }
+ left = BinaryExpr{Op: OpOr, Left: left, Right: right}
+ }
+
+ return left, nil
+}
+
+// parseAndExpr parses: notExpr { "and" notExpr }
+func (p *Parser) parseAndExpr() (ExprNode, error) {
+ left, err := p.parseNotExpr()
+ if err != nil {
+ return nil, err
+ }
+
+ for p.current.Type == TokenAnd {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ right, err := p.parseNotExpr()
+ if err != nil {
+ return nil, err
+ }
+ left = BinaryExpr{Op: OpAnd, Left: left, Right: right}
+ }
+
+ return left, nil
+}
+
+// parseNotExpr parses: [ "not" | "!" ] primary
+func (p *Parser) parseNotExpr() (ExprNode, error) {
+ if p.current.Type == TokenNot {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ expr, err := p.parsePrimary()
+ if err != nil {
+ return nil, err
+ }
+ return UnaryExpr{Op: OpNot, Expr: expr}, nil
+ }
+ return p.parsePrimary()
+}
+
+// parsePrimary parses: predicate | "(" expression ")"
+func (p *Parser) parsePrimary() (ExprNode, error) {
+ if p.current.Type == TokenLParen {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ expr, err := p.parseExpression()
+ if err != nil {
+ return nil, err
+ }
+ if err := p.expect(TokenRParen); err != nil {
+ return nil, err
+ }
+ return expr, nil
+ }
+ return p.parsePredicate()
+}
+
+// parsePredicate parses various predicate types
+func (p *Parser) parsePredicate() (ExprNode, error) {
+ switch p.current.Type {
+ case TokenEventType:
+ return p.parseTypePredicate()
+ case TokenProtocol:
+ return p.parseProtocolPredicate()
+ case TokenSource, TokenDestination:
+ return p.parseDirectionalPredicate()
+ case TokenOn:
+ return p.parsePortPredicate()
+ case TokenAny:
+ // Parse "any" - matches everything
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ return AnyPredicate{}, nil
+ default:
+ return nil, fmt.Errorf("expected predicate keyword, got '%s' at position %d",
+ p.current.Value, p.current.Pos)
+ }
+}
+
+// parseTypePredicate parses: "type" IDENT_LIST
+func (p *Parser) parseTypePredicate() (ExprNode, error) {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+
+ types, err := p.parseIdentList()
+ if err != nil {
+ return nil, err
+ }
+
+ // Normalize and validate type names
+ validTypes := map[string]bool{
+ "NEW": true,
+ "UPDATE": true,
+ "DESTROY": true,
+ }
+ normalized := make([]string, len(types))
+ for i, t := range types {
+ normalized[i] = strings.ToUpper(t)
+ if !validTypes[normalized[i]] {
+ return nil, fmt.Errorf("invalid event type '%s' at position %d, valid types are: NEW, UPDATE, DESTROY", t, p.current.Pos)
+ }
+ }
+
+ return TypePredicate{Types: normalized}, nil
+}
+
+// parseProtocolPredicate parses: "protocol" IDENT_LIST
+func (p *Parser) parseProtocolPredicate() (ExprNode, error) {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+
+ protocols, err := p.parseIdentList()
+ if err != nil {
+ return nil, err
+ }
+
+ // Normalize and validate protocol names
+ validProtocols := map[string]bool{
+ "TCP": true,
+ "UDP": true,
+ }
+ normalized := make([]string, len(protocols))
+ for i, proto := range protocols {
+ normalized[i] = strings.ToUpper(proto)
+ if !validProtocols[normalized[i]] {
+ return nil, fmt.Errorf("invalid protocol '%s' at position %d, valid protocols are: TCP, UDP", proto, p.current.Pos)
+ }
+ }
+
+ return ProtocolPredicate{Protocols: normalized}, nil
+}
+
+// parseDirectionalPredicate handles source/destination address, network, or port
+func (p *Parser) parseDirectionalPredicate() (ExprNode, error) {
+ direction := p.current.Value
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+
+ switch p.current.Type {
+ case TokenNetwork:
+ return p.parseNetworkPredicate(direction)
+ case TokenAddress:
+ return p.parseAddressPredicate(direction)
+ case TokenPort:
+ return p.parsePortPredicateWithDirection(direction)
+ default:
+ return nil, fmt.Errorf("expected 'network', 'address', or 'port' after '%s', got '%s' at position %d",
+ direction, p.current.Value, p.current.Pos)
+ }
+}
+
+// parseNetworkPredicate parses: direction "network" IDENT_LIST
+func (p *Parser) parseNetworkPredicate(direction string) (ExprNode, error) {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+
+ networks, err := p.parseIdentList()
+ if err != nil {
+ return nil, err
+ }
+
+ // Normalize and validate network names
+ validNetworks := map[string]bool{
+ "LOCAL": true,
+ "PRIVATE": true,
+ "PUBLIC": true,
+ "MULTICAST": true,
+ }
+ normalized := make([]string, len(networks))
+ for i, net := range networks {
+ normalized[i] = strings.ToUpper(net)
+ if !validNetworks[normalized[i]] {
+ return nil, fmt.Errorf("invalid network type '%s' at position %d, valid networks are: LOCAL, PRIVATE, PUBLIC, MULTICAST", net, p.current.Pos)
+ }
+ }
+
+ return NetworkPredicate{Direction: direction, Networks: normalized}, nil
+}
+
+// parseAddressPredicate parses: direction "address" (IP | CIDR) { "," (IP | CIDR) } ["on" "port" PORT_SPEC]
+func (p *Parser) parseAddressPredicate(direction string) (ExprNode, error) {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+
+ addresses, err := p.parseAddressList()
+ if err != nil {
+ return nil, err
+ }
+
+ var ports []uint16
+ // Check for optional "on port"
+ if p.current.Type == TokenOn {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ if err := p.expect(TokenPort); err != nil {
+ return nil, err
+ }
+ ports, err = p.parsePortSpec()
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return AddressPredicate{Direction: direction, Addresses: addresses, Ports: ports}, nil
+}
+
+// parsePortPredicate parses: "on" "port" PORT_SPEC (no direction)
+func (p *Parser) parsePortPredicate() (ExprNode, error) {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ if err := p.expect(TokenPort); err != nil {
+ return nil, err
+ }
+
+ ports, err := p.parsePortSpec()
+ if err != nil {
+ return nil, err
+ }
+
+ // "on port" without direction means both source and destination
+ return PortPredicate{Direction: "both", Ports: ports}, nil
+}
+
+// parsePortPredicateWithDirection parses: direction "port" PORT_SPEC
+func (p *Parser) parsePortPredicateWithDirection(direction string) (ExprNode, error) {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+
+ ports, err := p.parsePortSpec()
+ if err != nil {
+ return nil, err
+ }
+
+ return PortPredicate{Direction: direction, Ports: ports}, nil
+}
+
+// parseIdentList parses: IDENT { "," IDENT }
+func (p *Parser) parseIdentList() ([]string, error) {
+ var idents []string
+
+ if p.current.Type != TokenIdent {
+ return nil, fmt.Errorf("expected identifier, got '%s' at position %d",
+ p.current.Value, p.current.Pos)
+ }
+
+ idents = append(idents, p.current.Value)
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+
+ for p.current.Type == TokenComma {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ if p.current.Type != TokenIdent {
+ return nil, fmt.Errorf("expected identifier after comma, got '%s' at position %d",
+ p.current.Value, p.current.Pos)
+ }
+ idents = append(idents, p.current.Value)
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ }
+
+ return idents, nil
+}
+
+// parseAddressList parses IP addresses or CIDR ranges
+func (p *Parser) parseAddressList() ([]string, error) {
+ var addresses []string
+
+ addr, err := p.parseAddress()
+ if err != nil {
+ return nil, err
+ }
+ addresses = append(addresses, addr)
+
+ for p.current.Type == TokenComma {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ addr, err := p.parseAddress()
+ if err != nil {
+ return nil, err
+ }
+ addresses = append(addresses, addr)
+ }
+
+ return addresses, nil
+}
+
+// parseAddress parses an IP address or CIDR range
+func (p *Parser) parseAddress() (string, error) {
+ var parts []string
+
+ // Parse IPv4, IPv6, or CIDR
+ for {
+ switch p.current.Type {
+ case TokenNumber, TokenIdent:
+ parts = append(parts, p.current.Value)
+ if err := p.nextToken(); err != nil {
+ return "", err
+ }
+ case TokenDot, TokenColon, TokenSlash:
+ parts = append(parts, p.current.Value)
+ if err := p.nextToken(); err != nil {
+ return "", err
+ }
+ default:
+ goto done
+ }
+ }
+done:
+
+ if len(parts) == 0 {
+ return "", fmt.Errorf("expected IP address at position %d", p.current.Pos)
+ }
+
+ return strings.Join(parts, ""), nil
+}
+
+// parsePortSpec parses: NUMBER | NUMBER "-" NUMBER | NUMBER { "," NUMBER }
+func (p *Parser) parsePortSpec() ([]uint16, error) {
+ var ports []uint16
+
+ if p.current.Type != TokenNumber {
+ return nil, fmt.Errorf("expected port number, got '%s' at position %d",
+ p.current.Value, p.current.Pos)
+ }
+
+ port, err := strconv.ParseUint(p.current.Value, 10, 16)
+ if err != nil {
+ return nil, fmt.Errorf("invalid port number '%s' at position %d",
+ p.current.Value, p.current.Pos)
+ }
+ ports = append(ports, uint16(port))
+
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+
+ // Check for range (e.g., 80-90)
+ if p.current.Type == TokenDash {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ if p.current.Type != TokenNumber {
+ return nil, fmt.Errorf("expected port number after '-', got '%s' at position %d",
+ p.current.Value, p.current.Pos)
+ }
+ endPort, err := strconv.ParseUint(p.current.Value, 10, 16)
+ if err != nil {
+ return nil, fmt.Errorf("invalid port number '%s' at position %d",
+ p.current.Value, p.current.Pos)
+ }
+ if uint16(endPort) < ports[0] {
+ return nil, fmt.Errorf("invalid port range %d-%d at position %d", ports[0], endPort, p.current.Pos)
+ }
+ // Expand range; iterate with a wider type so an end port of 65535 cannot overflow uint16
+ for portNum := uint32(ports[0]) + 1; portNum <= uint32(endPort); portNum++ {
+ ports = append(ports, uint16(portNum))
+ }
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ return ports, nil
+ }
+
+ // Check for comma-separated list
+ for p.current.Type == TokenComma {
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ if p.current.Type != TokenNumber {
+ return nil, fmt.Errorf("expected port number after comma, got '%s' at position %d",
+ p.current.Value, p.current.Pos)
+ }
+ port, err := strconv.ParseUint(p.current.Value, 10, 16)
+ if err != nil {
+ return nil, fmt.Errorf("invalid port number '%s' at position %d",
+ p.current.Value, p.current.Pos)
+ }
+ ports = append(ports, uint16(port))
+ if err := p.nextToken(); err != nil {
+ return nil, err
+ }
+ }
+
+ return ports, nil
+}
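
A sketch of driving the parser directly; `Rule`, `BinaryExpr` and the operator constants are defined in the AST/compiler files outside this hunk, so their shapes here follow the usage above:

```go
package main

import (
	"fmt"

	"github.com/tschaefer/conntrackd/internal/filter"
)

func main() {
	p, err := filter.NewParser("log protocol TCP,UDP and destination port 8000-8005")
	if err != nil {
		panic(err)
	}
	rule, err := p.ParseRule()
	if err != nil {
		panic(err)
	}

	fmt.Println(rule.Action == filter.ActionLog) // true

	// The "and" yields a single BinaryExpr node; the port range 8000-8005 has
	// already been expanded into individual ports by parsePortSpec.
	if expr, ok := rule.Expr.(filter.BinaryExpr); ok {
		fmt.Println(expr.Op == filter.OpAnd) // true
	}
}
```
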
diff --git a/internal/filter/parser_test.go b/internal/filter/parser_test.go
new file mode 100644
index 0000000..c1c1fbc
--- /dev/null
+++ b/internal/filter/parser_test.go
@@ -0,0 +1,428 @@
+/*
+Copyright (c) 2025 Tobias Schäfer. All rights reserved.
+Licensed under the MIT License, see LICENSE file in the project root for details.
+*/
+package filter
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParser_BasicActions(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ action Action
+ valid bool
+ }{
+ {"log type", "log type NEW", ActionLog, true},
+ {"drop type", "drop type NEW", ActionDrop, true},
+ {"LOG uppercase", "LOG type NEW", ActionLog, true},
+ {"DROP uppercase", "DROP type NEW", ActionDrop, true},
+ {"missing action", "type NEW", ActionLog, false},
+ {"invalid action", "permit type NEW", ActionLog, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ if !tt.valid {
+ if err == nil {
+ _, err = parser.ParseRule()
+ }
+ assert.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+ assert.Equal(t, tt.action, rule.Action)
+ })
+ }
+}
+
+func TestParser_TypePredicate(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expected []string
+ }{
+ {"single type", "log type NEW", []string{"NEW"}},
+ {"multiple types", "log type NEW,UPDATE", []string{"NEW", "UPDATE"}},
+ {"all types", "log type NEW,UPDATE,DESTROY", []string{"NEW", "UPDATE", "DESTROY"}},
+ {"lowercase", "log type new,update", []string{"NEW", "UPDATE"}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ typePred, ok := rule.Expr.(TypePredicate)
+ require.True(t, ok, "expected TypePredicate")
+ assert.ElementsMatch(t, tt.expected, typePred.Types)
+ })
+ }
+}
+
+func TestParser_ProtocolPredicate(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expected []string
+ }{
+ {"single protocol", "log protocol TCP", []string{"TCP"}},
+ {"multiple protocols", "log protocol TCP,UDP", []string{"TCP", "UDP"}},
+ {"lowercase", "log protocol tcp", []string{"TCP"}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ protoPred, ok := rule.Expr.(ProtocolPredicate)
+ require.True(t, ok, "expected ProtocolPredicate")
+ assert.ElementsMatch(t, tt.expected, protoPred.Protocols)
+ })
+ }
+}
+
+func TestParser_NetworkPredicate(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ direction string
+ networks []string
+ }{
+ {"source private", "log source network PRIVATE", "source", []string{"PRIVATE"}},
+ {"destination public", "log destination network PUBLIC", "destination", []string{"PUBLIC"}},
+ {"dst abbreviation", "log dst network LOCAL", "dst", []string{"LOCAL"}},
+ {"src abbreviation", "log src network MULTICAST", "src", []string{"MULTICAST"}},
+ {"multiple networks", "log source network PRIVATE,LOCAL", "source", []string{"PRIVATE", "LOCAL"}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ netPred, ok := rule.Expr.(NetworkPredicate)
+ require.True(t, ok, "expected NetworkPredicate")
+ assert.Equal(t, tt.direction, netPred.Direction)
+ assert.ElementsMatch(t, tt.networks, netPred.Networks)
+ })
+ }
+}
+
+func TestParser_AddressPredicate(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ direction string
+ addresses []string
+ ports []uint16
+ }{
+ {"ipv4 address", "log destination address 8.8.8.8", "destination", []string{"8.8.8.8"}, nil},
+ {"ipv6 address", "log source address 2001:db8::1", "source", []string{"2001:db8::1"}, nil},
+ {"cidr", "log destination address 192.168.1.0/24", "destination", []string{"192.168.1.0/24"}, nil},
+ {"address with port", "log destination address 10.19.80.100 on port 53", "destination", []string{"10.19.80.100"}, []uint16{53}},
+ {"multiple addresses", "log source address 1.1.1.1,8.8.8.8", "source", []string{"1.1.1.1", "8.8.8.8"}, nil},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ addrPred, ok := rule.Expr.(AddressPredicate)
+ require.True(t, ok, "expected AddressPredicate")
+ assert.Equal(t, tt.direction, addrPred.Direction)
+ assert.ElementsMatch(t, tt.addresses, addrPred.Addresses)
+ if tt.ports != nil {
+ assert.ElementsMatch(t, tt.ports, addrPred.Ports)
+ }
+ })
+ }
+}
+
+func TestParser_PortPredicate(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ direction string
+ ports []uint16
+ }{
+ {"single port", "log destination port 80", "destination", []uint16{80}},
+ {"multiple ports", "log source port 80,443", "source", []uint16{80, 443}},
+ {"port range", "log destination port 8000-8005", "destination", []uint16{8000, 8001, 8002, 8003, 8004, 8005}},
+ {"on port both", "log on port 53", "both", []uint16{53}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ portPred, ok := rule.Expr.(PortPredicate)
+ require.True(t, ok, "expected PortPredicate")
+ assert.Equal(t, tt.direction, portPred.Direction)
+ assert.ElementsMatch(t, tt.ports, portPred.Ports)
+ })
+ }
+}
+
+func TestParser_AnyPredicate(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ }{
+ {"log any", "log any"},
+ {"drop any", "drop any"},
+ {"ANY uppercase", "log ANY"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ _, ok := rule.Expr.(AnyPredicate)
+ assert.True(t, ok, "expected AnyPredicate")
+ })
+ }
+}
+
+func TestParser_BinaryExpressions(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ op BinaryOp
+ }{
+ {"and operator", "log type NEW and protocol TCP", OpAnd},
+ {"or operator", "log type NEW or type UPDATE", OpOr},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ binExpr, ok := rule.Expr.(BinaryExpr)
+ require.True(t, ok, "expected BinaryExpr")
+ assert.Equal(t, tt.op, binExpr.Op)
+ })
+ }
+}
+
+func TestParser_UnaryExpression(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ }{
+ {"not keyword", "log not type NEW"},
+ {"exclamation", "log ! protocol TCP"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ unaryExpr, ok := rule.Expr.(UnaryExpr)
+ require.True(t, ok, "expected UnaryExpr")
+ assert.Equal(t, OpNot, unaryExpr.Op)
+ })
+ }
+}
+
+func TestParser_Parentheses(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ }{
+ {"simple grouping", "log (type NEW)"},
+ {"complex grouping", "log (type NEW and protocol TCP) or type UPDATE"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ require.NoError(t, err)
+ _, err = parser.ParseRule()
+ require.NoError(t, err)
+ })
+ }
+}
+
+func TestParser_Precedence(t *testing.T) {
+ // Test that AND binds tighter than OR
+ input := "log type NEW or type UPDATE and protocol TCP"
+ parser, err := NewParser(input)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+
+ // Should parse as: (type NEW) OR (type UPDATE AND protocol TCP)
+ orExpr, ok := rule.Expr.(BinaryExpr)
+ require.True(t, ok)
+ assert.Equal(t, OpOr, orExpr.Op)
+
+ // Right side should be AND
+ andExpr, ok := orExpr.Right.(BinaryExpr)
+ require.True(t, ok)
+ assert.Equal(t, OpAnd, andExpr.Op)
+}
+
+func TestParser_ComplexExamples(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ }{
+ {
+ "example 1",
+ "drop destination address 8.8.8.8",
+ },
+ {
+ "example 2",
+ "log protocol TCP and destination network PUBLIC",
+ },
+ {
+ "example 3",
+ "drop destination address 10.19.80.100 on port 53",
+ },
+ {
+ "example 4",
+ "log protocol TCP,UDP",
+ },
+ {
+ "complex with negation",
+ "log not (type DESTROY and destination network PRIVATE)",
+ },
+ {
+ "multiple conditions",
+ "drop source network LOCAL and destination port 22,23,3389",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ require.NoError(t, err)
+ rule, err := parser.ParseRule()
+ require.NoError(t, err)
+ assert.NotNil(t, rule)
+ })
+ }
+}
+
+func TestParser_InvalidInputs(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ }{
+ {"empty", ""},
+ {"only action", "log"},
+ {"missing action", "type NEW"},
+ {"invalid keyword", "log typo NEW"},
+ {"unclosed paren", "log (type NEW"},
+ {"extra tokens", "log type NEW extra stuff"},
+ {"invalid port", "log destination port abc"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ if err == nil {
+ _, err = parser.ParseRule()
+ }
+ assert.Error(t, err)
+ })
+ }
+}
+
+func TestParser_InvalidTypeValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ }{
+ {"invalid type single", "log type FOO"},
+ {"invalid type in list", "log type NEW,BAR"},
+ {"invalid type complex", "log (destination address 78.47.60.169/32) and type NE"},
+ {"typo in type", "log type NWE"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ if err == nil {
+ _, err = parser.ParseRule()
+ }
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid event type")
+ })
+ }
+}
+
+func TestParser_InvalidProtocolValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ }{
+ {"invalid protocol single", "log protocol ICMP"},
+ {"invalid protocol in list", "log protocol TCP,ICMP"},
+ {"invalid protocol complex", "log destination address 1.2.3.4 and protocol FOO"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ if err == nil {
+ _, err = parser.ParseRule()
+ }
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid protocol")
+ })
+ }
+}
+
+func TestParser_InvalidNetworkValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ }{
+ {"invalid network single", "log destination network INVALID"},
+ {"invalid network in list", "log source network LOCAL,BOGUS"},
+ {"invalid network complex", "log type NEW and destination network FOO"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser, err := NewParser(tt.input)
+ if err == nil {
+ _, err = parser.ParseRule()
+ }
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid network type")
+ })
+ }
+}
diff --git a/internal/geoip/geoip.go b/internal/geoip/geoip.go
index 746c7b8..99ad1b2 100644
--- a/internal/geoip/geoip.go
+++ b/internal/geoip/geoip.go
@@ -5,17 +5,15 @@ Licensed under the MIT License, see LICENSE file in the project root for details
package geoip
import (
- "log/slog"
+ "fmt"
"net/netip"
+ "strings"
"github.com/oschwald/geoip2-golang/v2"
)
-type Reader struct {
- reader *geoip2.Reader
-}
-
type GeoIP struct {
+ Reader *geoip2.Reader
Database string
}
@@ -26,23 +24,30 @@ type Location struct {
Lon float64
}
-func Open(path string) (*Reader, error) {
- slog.Debug("Opening GeoIP2 database", "path", path)
-
- reader, err := geoip2.Open(path)
+func NewGeoIP(database string) (*GeoIP, error) {
+ reader, err := geoip2.Open(database)
if err != nil {
return nil, err
}
- return &Reader{reader: reader}, nil
+ metadata := reader.Metadata()
+ if !strings.HasSuffix(metadata.DatabaseType, "City") {
+ _ = reader.Close()
+ return nil, fmt.Errorf("invalid GeoIP2 database type: %s, expected City", metadata.DatabaseType)
+ }
+
+ return &GeoIP{
+ Reader: reader,
+ Database: database,
+ }, nil
}
-func (r *Reader) Close() error {
- return r.reader.Close()
+func (g *GeoIP) Close() error {
+ return g.Reader.Close()
}
-func (r *Reader) Location(ip netip.Addr) *Location {
- record, err := r.reader.City(ip)
+func (g *GeoIP) Location(ip netip.Addr) *Location {
+ record, err := g.Reader.City(ip)
if err != nil {
return nil
}
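
A usage sketch of the reworked constructor; the database path is an example, and the nil results for private, loopback and multicast addresses match what the tests below expect:

```go
package main

import (
	"fmt"
	"net/netip"

	"github.com/tschaefer/conntrackd/internal/geoip"
)

func main() {
	// Any GeoLite2 "City" database is accepted; other database types are
	// rejected up front by the metadata check in NewGeoIP.
	geo, err := geoip.NewGeoIP("/usr/local/share/GeoLite2-City.mmdb")
	if err != nil {
		panic(err)
	}
	defer func() {
		_ = geo.Close()
	}()

	if loc := geo.Location(netip.MustParseAddr("63.176.75.230")); loc != nil {
		fmt.Printf("%+v\n", loc)
	}
}
```
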
diff --git a/internal/geoip/geoip_test.go b/internal/geoip/geoip_test.go
index 0e40a66..4022af7 100644
--- a/internal/geoip/geoip_test.go
+++ b/internal/geoip/geoip_test.go
@@ -11,15 +11,12 @@ import (
"os"
"testing"
+ "github.com/oschwald/geoip2-golang/v2"
"github.com/stretchr/testify/assert"
)
const geoDatabasePath = "/tmp/GeoLite2-City.mmdb"
-const geoDatabaseUrl = "https://github.com/P3TERX/GeoLite.mmdb/releases/latest/download/GeoLite2-City.mmdb"
-
-func setup() {
- setupGeoDatabase()
-}
+const geoDatabaseUrl = "https://git.io/GeoLite2-City.mmdb"
func setupGeoDatabase() {
if _, err := os.Stat(geoDatabasePath); os.IsNotExist(err) {
@@ -46,35 +43,59 @@ func setupGeoDatabase() {
}
}
-func Test_GeoIP(t *testing.T) {
- setup()
+func newReturnsError_InvalidDatabase(t *testing.T) {
+ _, err := NewGeoIP("../../README.md")
+ assert.EqualError(t, err, "error opening database: invalid MaxMind DB file")
+}
- geo, err := Open("invalid-path.mmdb")
- assert.Error(t, err, "open invalid path")
- assert.Nil(t, geo, "no geoip instance")
+func newReturnsInstance_ValidDatabase(t *testing.T) {
+ geoIP, err := NewGeoIP(geoDatabasePath)
+ assert.NoError(t, err)
+ assert.NotNil(t, geoIP)
+ assert.IsType(t, &GeoIP{}, geoIP)
+ assert.Equal(t, geoIP.Database, geoDatabasePath)
+ assert.IsType(t, geoIP.Reader, &geoip2.Reader{})
+}
- geo, err = Open(geoDatabasePath)
- assert.NoError(t, err, "open valid path")
- assert.NotNil(t, geo, "geoip instance")
- defer func() {
- _ = geo.Close()
- }()
+func locationReturnsNil_UnresolvedIP(t *testing.T) {
+ geo, err := NewGeoIP(geoDatabasePath)
+ assert.NoError(t, err)
+ assert.NotNil(t, geo)
- testees := map[string]string{
+ for ipStr, desc := range map[string]string{
"::1": "local address",
"10.19.80.12": "private address",
"224.0.1.1": "multicast address",
"172.66.43.195": "unresolved address",
- }
-
- for ipStr, desc := range testees {
+ } {
ip, _ := netip.ParseAddr(ipStr)
location := geo.Location(ip)
- assert.Nil(t, location, desc, "no location")
+ assert.Nil(t, location, desc)
}
+}
+
+func locationReturnsLocation_ResolvedIP(t *testing.T) {
+ geo, err := NewGeoIP(geoDatabasePath)
+ assert.NoError(t, err)
+ assert.NotNil(t, geo)
ip, _ := netip.ParseAddr("63.176.75.230")
location := geo.Location(ip)
assert.NotNil(t, location, "resolved address")
assert.IsType(t, &Location{}, location, "location type")
}
+
+func TestGeoIP(t *testing.T) {
+ setupGeoDatabase()
+
+ t.Run("New returns error for invalid database", newReturnsError_InvalidDatabase)
+ t.Run("New returns instance for valid database", newReturnsInstance_ValidDatabase)
+ t.Run("Location returns nil for unresolved IPs", locationReturnsNil_UnresolvedIP)
+ t.Run("Location returns location for resolved IPs", locationReturnsLocation_ResolvedIP)
+
+ skip, ok := os.LookupEnv("KEEP_GEOIP_DB")
+ if ok && (skip == "1" || skip == "true") {
+ return
+ }
+ _ = os.Remove(geoDatabasePath)
+}
diff --git a/internal/logger/logger.go b/internal/logger/logger.go
index a856d5d..95fe99e 100644
--- a/internal/logger/logger.go
+++ b/internal/logger/logger.go
@@ -5,79 +5,36 @@ Licensed under the MIT License, see LICENSE file in the project root for details
package logger
import (
- "context"
"fmt"
"log/slog"
"os"
)
type Logger struct {
- Format string
- Level string
+ Logger *slog.Logger
+ Level slog.Level
}
-const (
- LevelTrace = slog.Level(-8)
+var (
+ Levels = []string{"debug", "info", "warn", "error"}
+ level slog.Level
)
-var format string
-var level slog.Level
-
-func (l *Logger) Initialize() error {
- var logLevel slog.Level
- switch l.Level {
- case "trace":
- logLevel = LevelTrace
- case "debug":
- logLevel = slog.LevelDebug
- case "info":
- logLevel = slog.LevelInfo
- case "error":
- logLevel = slog.LevelError
- case "":
- logLevel = slog.LevelInfo
- default:
- return fmt.Errorf("unknown log level: %q", l.Level)
+func NewLogger(levelStr string) (*Logger, error) {
+ err := level.UnmarshalText([]byte(levelStr))
+ if err != nil {
+ return nil, fmt.Errorf("unknown log level: %q", levelStr)
}
- level = logLevel
- loggerOptions := &slog.HandlerOptions{
- Level: logLevel,
- ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
- if a.Key == slog.LevelKey {
- if a.Value.String() == "DEBUG-4" {
- a.Value = slog.StringValue("TRACE")
- }
- }
- return a
- },
- }
- if logLevel == LevelTrace {
- loggerOptions.AddSource = true
+ o := &slog.HandlerOptions{Level: level}
+ if level == slog.LevelDebug {
+ o.AddSource = true
}
- var logger *slog.Logger
- switch l.Format {
- case "json":
- logger = slog.New(slog.NewJSONHandler(os.Stderr, loggerOptions))
- case "text", "":
- logger = slog.New(slog.NewTextHandler(os.Stderr, loggerOptions))
- default:
- return fmt.Errorf("unknown log format: %q", l.Format)
- }
- format = l.Format
-
- slog.SetDefault(logger)
-
- return nil
-}
-
-func Trace(msg string, args ...any) {
- slog.Log(context.Background(), LevelTrace, msg, args...)
-}
-
-func Format() string {
- return format
+ return &Logger{
+ Logger: slog.New(slog.NewJSONHandler(os.Stderr, o)),
+ Level: level,
+ }, nil
}
func Level() slog.Level {
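
With the trace level and the text format removed, construction reduces to a level string; a short sketch:

```go
package main

import (
	"fmt"

	"github.com/tschaefer/conntrackd/internal/logger"
)

func main() {
	// Accepted levels are the slog names: debug, info, warn, error.
	log, err := logger.NewLogger("debug")
	if err != nil {
		panic(err)
	}

	// Output is always JSON now; source locations are added at debug level.
	log.Logger.Info("starting", "component", "example")
	fmt.Println(logger.Level()) // DEBUG
}
```
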
diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go
index e5972a8..4658091 100644
--- a/internal/logger/logger_test.go
+++ b/internal/logger/logger_test.go
@@ -5,43 +5,45 @@ Licensed under the MIT License, see LICENSE file in the project root for details
package logger
import (
+ "log/slog"
+ "strings"
"testing"
"github.com/stretchr/testify/assert"
)
-func Test_Logger(t *testing.T) {
- testees := []Logger{
- {Format: "json", Level: "trace"},
- {Format: "json", Level: "debug"},
- {Format: "json", Level: "info"},
- {Format: "json", Level: "error"},
- {Format: "text", Level: "trace"},
- {Format: "text", Level: "debug"},
- {Format: "text", Level: "info"},
- {Format: "text", Level: "error"},
- }
+func newReturnsError_UnknownLogLevel(t *testing.T) {
+ _, err := NewLogger("unknown")
+ assert.Error(t, err)
+ assert.EqualError(t, err, `unknown log level: "unknown"`)
+}
- for _, logger := range testees {
- err := logger.Initialize()
- assert.NoError(t, err, "valid logger config")
+func newReturnsLogger_KnownLogLevels(t *testing.T) {
+ for _, level := range []string{"debug", "info", "warn", "error"} {
+ logger, err := NewLogger(level)
+ assert.NoError(t, err)
+ assert.NotNil(t, logger)
+ assert.IsType(t, logger.Logger, &slog.Logger{})
+ assert.Equal(t, strings.ToUpper(level), logger.Level.String())
}
+}
- logger := Logger{}
- err := logger.Initialize()
- assert.NoError(t, err, "default logger config")
-
- logger = Logger{Format: "xml", Level: "info"}
- err = logger.Initialize()
- assert.Errorf(t, err, "unknown log format: %q", "xml")
-
- logger = Logger{Format: "json", Level: "panic"}
- err = logger.Initialize()
- assert.Errorf(t, err, "unknown log level: %q", "error")
+func levelReturnsCorrectLevel(t *testing.T) {
+ for str, level := range map[string]slog.Level{
+ "debug": slog.LevelDebug,
+ "info": slog.LevelInfo,
+ "warn": slog.LevelWarn,
+ "error": slog.LevelError,
+ } {
+ _, err := NewLogger(str)
+ assert.NoError(t, err)
+
+ assert.Equal(t, level, Level())
+ }
+}
- logger = Logger{Format: "json", Level: "debug"}
- err = logger.Initialize()
- assert.NoError(t, err, "valid logger config")
- assert.Equal(t, "json", logger.Format, "logger format set by Initialize args")
- assert.Equal(t, "debug", logger.Level, "logger level set by Initialize args")
+func TestLogger(t *testing.T) {
+ t.Run("New returns error for unknown log level", newReturnsError_UnknownLogLevel)
+ t.Run("New returns logger for known log levels", newReturnsLogger_KnownLogLevels)
+ t.Run("Level returns correct level", levelReturnsCorrectLevel)
}
diff --git a/internal/record/record.go b/internal/record/record.go
index 3c3c8fd..b64f3da 100644
--- a/internal/record/record.go
+++ b/internal/record/record.go
@@ -11,11 +11,10 @@ import (
"github.com/ti-mo/conntrack"
"github.com/tschaefer/conntrackd/internal/geoip"
- "github.com/tschaefer/conntrackd/internal/logger"
)
-func Record(event conntrack.Event, geo *geoip.Reader, sink *slog.Logger) {
- logger.Trace("Conntrack Event", "data", event)
+func Record(event conntrack.Event, geo *geoip.GeoIP, logger *slog.Logger) {
+ slog.Debug("Conntrack Event", "data", event)
protocols := map[int]string{
syscall.IPPROTO_TCP: "TCP",
@@ -80,5 +79,5 @@ func Record(event conntrack.Event, geo *geoip.Reader, sink *slog.Logger) {
event.Flow.TupleOrig.Proto.DestinationPort,
)
- sink.Info(msg, append(established, location...)...)
+ logger.Info(msg, append(established, location...)...)
}
diff --git a/internal/record/record_test.go b/internal/record/record_test.go
index 3308229..ef7aebc 100644
--- a/internal/record/record_test.go
+++ b/internal/record/record_test.go
@@ -23,7 +23,7 @@ import (
)
const geoDatabasePath = "/tmp/GeoLite2-City.mmdb"
-const geoDatabaseUrl = "https://github.com/P3TERX/GeoLite.mmdb/releases/latest/download/GeoLite2-City.mmdb"
+const geoDatabaseUrl = "https://git.io/GeoLite2-City.mmdb"
var log bytes.Buffer
@@ -65,9 +65,8 @@ func setupGeoDatabase() {
}
}
-func Test_Record(t *testing.T) {
+func recordLogsBasicData(t *testing.T) {
logger := setupLogger()
- setupGeoDatabase()
flow := conntrack.NewFlow(
syscall.IPPROTO_TCP,
@@ -82,7 +81,7 @@ func Test_Record(t *testing.T) {
Flow: &flow,
}
- var geo *geoip.Reader
+ var geo *geoip.GeoIP
Record(event, geo, logger)
var result map[string]any
err := json.Unmarshal(log.Bytes(), &result)
@@ -91,21 +90,56 @@ func Test_Record(t *testing.T) {
wanted := []string{"level", "time",
"type", "flow", "prot", "src", "dst", "sport", "dport"}
got := slices.Sorted(maps.Keys(result))
- assert.ElementsMatch(t, wanted, got, "record keys without geoip")
+ assert.ElementsMatch(t, wanted, got, "record basic keys")
+
+ log.Reset()
+}
+
+func recordLogsWithGeoIPData(t *testing.T) {
+ logger := setupLogger()
+
+ flow := conntrack.NewFlow(
+ syscall.IPPROTO_TCP,
+ conntrack.StatusAssured,
+ netip.MustParseAddr("10.19.80.100"), netip.MustParseAddr("78.47.60.169"),
+ 4711, 443,
+ 60, 0,
+ )
+
+ event := conntrack.Event{
+ Type: conntrack.EventNew,
+ Flow: &flow,
+ }
- geo, err = geoip.Open(geoDatabasePath)
+ geo, err := geoip.NewGeoIP(geoDatabasePath)
assert.NoError(t, err)
defer func() {
_ = geo.Close()
}()
- log.Reset()
Record(event, geo, logger)
+ var result map[string]any
err = json.Unmarshal(log.Bytes(), &result)
assert.NoError(t, err)
- wanted = append(wanted, []string{"city", "country", "lat", "lon"}...)
- got = slices.Sorted(maps.Keys(result))
+ wanted := []string{"level", "time",
+ "type", "flow", "prot", "src", "dst", "sport", "dport",
+ "city", "country", "lat", "lon"}
+ got := slices.Sorted(maps.Keys(result))
assert.ElementsMatch(t, wanted, got, "record keys with geoip")
+ log.Reset()
+}
+
+func TestRecord(t *testing.T) {
+ setupGeoDatabase()
+
+ t.Run("Record logs basic data", recordLogsBasicData)
+ t.Run("Record logs with GeoIP data", recordLogsWithGeoIPData)
+
+ skip, ok := os.LookupEnv("KEEP_GEOIP_DB")
+ if ok && (skip == "1" || skip == "true") {
+ return
+ }
+ _ = os.Remove(geoDatabasePath)
}
diff --git a/internal/service/service.go b/internal/service/service.go
index 38faaa6..3ee43c4 100644
--- a/internal/service/service.go
+++ b/internal/service/service.go
@@ -7,7 +7,6 @@ package service
import (
"context"
"log/slog"
- "os/signal"
"syscall"
"github.com/mdlayher/netlink"
@@ -23,39 +22,75 @@ import (
)
type Service struct {
- Filter filter.Filter
- Logger logger.Logger
- GeoIP geoip.GeoIP
- Sink sink.Sink
+ Filter *filter.Filter
+ GeoIP *geoip.GeoIP
+ Sink *sink.Sink
+ Logger *slog.Logger
}
-func (s *Service) handler(geo *geoip.Reader, sink *slog.Logger) error {
- con, err := conntrack.Dial(nil)
+func NewService(logger *logger.Logger, geoip *geoip.GeoIP, filter *filter.Filter, sink *sink.Sink) (*Service, error) {
+ slog.SetDefault(logger.Logger)
+
+ return &Service{
+ Filter: filter,
+ GeoIP: geoip,
+ Sink: sink,
+ Logger: logger.Logger,
+ }, nil
+}
+
+func (s *Service) Run(ctx context.Context) bool {
+ slog.Info("Starting conntrack listener.",
+ "release", version.Release(), "commit", version.Commit(),
+ )
+
+ con, err := s.setupConntrack()
if err != nil {
- slog.Error("Failed to dial conntrack.", "error", err)
- return err
+ return false
}
+ defer func() {
+ _ = con.Close()
+ }()
- evCh := make(chan conntrack.Event, 1024)
- errCh, err := con.Listen(evCh, 4, netfilter.GroupsCT)
+ evCh, errCh, err := s.startEventListener(con)
if err != nil {
_ = con.Close()
- slog.Error("Failed to listen to conntrack events.", "error", err)
- return err
+ return false
+ }
+
+ g := s.startEventProcessor(ctx, evCh)
+
+ return s.handleShutdown(ctx, con, g, errCh)
+}
+
+func (s *Service) setupConntrack() (*conntrack.Conn, error) {
+ con, err := conntrack.Dial(nil)
+ if err != nil {
+ slog.Error("Failed to dial conntrack.", "error", err)
+ return nil, err
}
if err := con.SetOption(netlink.ListenAllNSID|netlink.NoENOBUFS, true); err != nil {
_ = con.Close()
slog.Error("Failed to set conntrack listen options.", "error", err)
- return err
+ return nil, err
}
- defer func() {
- _ = con.Close()
- }()
- ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
- defer stop()
+ return con, nil
+}
+func (s *Service) startEventListener(con *conntrack.Conn) (chan conntrack.Event, chan error, error) {
+ evCh := make(chan conntrack.Event, 1024)
+ errCh, err := con.Listen(evCh, 4, netfilter.GroupsCT)
+ if err != nil {
+ slog.Error("Failed to listen to conntrack events.", "error", err)
+ return nil, nil, err
+ }
+
+ return evCh, errCh, nil
+}
+
+func (s *Service) startEventProcessor(ctx context.Context, evCh chan conntrack.Event) *errgroup.Group {
var g errgroup.Group
g.Go(func() error {
for {
@@ -66,73 +101,53 @@ func (s *Service) handler(geo *geoip.Reader, sink *slog.Logger) error {
if !ok {
return nil
}
- go func() {
- if !s.Filter.Apply(event) {
- record.Record(event, geo, sink)
- }
- }()
+ go s.processEvent(event)
}
}
})
+ return &g
+}
+
+func (s *Service) processEvent(event conntrack.Event) {
+ // Only process TCP and UDP events, ignore all other protocols (ICMP, etc.)
+ protocol := event.Flow.TupleOrig.Proto.Protocol
+ if protocol != syscall.IPPROTO_TCP && protocol != syscall.IPPROTO_UDP {
+ return
+ }
+ shouldRecord := true
+ if s.Filter != nil {
+ _, shouldLog, _ := s.Filter.Evaluate(event)
+ shouldRecord = shouldLog
+ }
+
+ if shouldRecord {
+ record.Record(event, s.GeoIP, s.Sink.Logger)
+ }
+}
+
+func (s *Service) handleShutdown(ctx context.Context, con *conntrack.Conn, g *errgroup.Group, errCh chan error) bool {
select {
case err := <-errCh:
if err != nil {
slog.Error("Conntrack listener error.", "error", err)
- stop()
_ = con.Close()
if gErr := g.Wait(); gErr != nil {
slog.Error("Event loop returned error during shutdown.", "error", gErr)
}
- return err
+ return false
}
- stop()
_ = con.Close()
if gErr := g.Wait(); gErr != nil {
slog.Error("Event loop returned error during shutdown.", "error", gErr)
}
- return nil
+ return true
case <-ctx.Done():
slog.Info("Shutting down conntrack listener.")
_ = con.Close()
if gErr := g.Wait(); gErr != nil {
slog.Error("Event loop returned error during shutdown.", "error", gErr)
}
- return nil
+ return true
}
}
-
-func (s *Service) Run() error {
- if err := s.Logger.Initialize(); err != nil {
- slog.Error("Failed to initialize logger.", "error", err)
- return err
- }
-
- slog.Debug("Running Service.", "data", s)
-
- slog.Info("Starting conntrack listener.",
- "release", version.Release(), "commit", version.Commit(),
- "filter", s.Filter,
- "geoip", s.GeoIP.Database,
- )
-
- sink, err := s.Sink.Initialize()
- if err != nil {
- slog.Error("Failed to initialize sink.", "error", err)
- return err
- }
-
- var geo *geoip.Reader
- if s.GeoIP.Database != "" {
- geo, err = geoip.Open(s.GeoIP.Database)
- if err != nil {
- slog.Error("Failed to open geoip database.", "error", err)
- return err
- }
- defer func() {
- _ = geo.Close()
- }()
- }
-
- return s.handler(geo, sink)
-}
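
Since `signal.NotifyContext` no longer lives inside the service, the caller owns context cancellation. A hedged wiring sketch (the cmd layer is not part of this diff, so flag and config handling are omitted):

```go
package main

import (
	"context"
	"os"
	"os/signal"
	"syscall"

	"github.com/tschaefer/conntrackd/internal/filter"
	"github.com/tschaefer/conntrackd/internal/geoip"
	"github.com/tschaefer/conntrackd/internal/logger"
	"github.com/tschaefer/conntrackd/internal/service"
	"github.com/tschaefer/conntrackd/internal/sink"
)

func main() {
	log, err := logger.NewLogger("info")
	if err != nil {
		os.Exit(1)
	}
	snk, err := sink.NewSink(&sink.Config{Stream: sink.Stream{Enable: true, Writer: "stdout"}})
	if err != nil {
		os.Exit(1)
	}
	flt, err := filter.NewFilter([]string{"drop destination network PRIVATE"})
	if err != nil {
		os.Exit(1)
	}

	var geo *geoip.GeoIP // optional; nil leaves GeoIP lookups disabled

	svc, err := service.NewService(log, geo, flt, snk)
	if err != nil {
		os.Exit(1)
	}

	// The caller installs signal handling and hands the context to Run.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	if ok := svc.Run(ctx); !ok {
		os.Exit(1)
	}
}
```
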
diff --git a/internal/sink/journal.go b/internal/sink/journal.go
index 474db64..cb10a3d 100644
--- a/internal/sink/journal.go
+++ b/internal/sink/journal.go
@@ -15,8 +15,6 @@ type Journal struct {
}
func (j *Journal) TargetJournal(options *slog.HandlerOptions) (slog.Handler, error) {
- slog.Debug("Initializing systemd journal sink.")
-
slogjournal.FieldPrefix = "EVENT"
o := &slogjournal.Option{
Level: options.Level,
diff --git a/internal/sink/loki.go b/internal/sink/loki.go
index 2939763..81d0860 100644
--- a/internal/sink/loki.go
+++ b/internal/sink/loki.go
@@ -12,7 +12,8 @@ import (
"os"
"strings"
- klog "github.com/go-kit/log"
+ kitlog "github.com/go-kit/log"
+ kitlevel "github.com/go-kit/log/level"
"github.com/grafana/loki-client-go/loki"
"github.com/grafana/loki-client-go/pkg/labelutil"
"github.com/prometheus/common/model"
@@ -20,100 +21,104 @@ import (
"github.com/tschaefer/conntrackd/internal/logger"
)
+const (
+ readyPath = "/ready"
+ pushPath = "/loki/api/v1/push"
+)
+
type Loki struct {
Enable bool
Address string
Labels []string
}
-const (
- readyPath = "/ready"
- pushPath = "/loki/api/v1/push"
-)
+var LokiProtocols = []string{"http", "https"}
-func (l *Loki) isReady() error {
- uri, err := url.Parse(l.Address)
- if err != nil {
- return err
- }
- uri.Path = uri.Path + readyPath
-
- response, err := http.Get(uri.String())
+func (l *Loki) TargetLoki(options *slog.HandlerOptions) (slog.Handler, error) {
+ uri, err := url.Parse(l.Address)
if err != nil {
- return err
+ return nil, err
}
- defer func() {
- _ = response.Body.Close()
- }()
- if response.StatusCode != http.StatusOK {
- return errors.New(response.Status)
+ if err := l.isReady(*uri); err != nil {
+ return nil, err
}
- return nil
-}
-
-func (l *Loki) TargetLoki(options *slog.HandlerOptions) (slog.Handler, error) {
- slog.Debug("Initializing Grafana Loki sink.", "data", l)
-
- if err := l.isReady(); err != nil {
+ uri.Path = uri.Path + pushPath
+ config, err := loki.NewDefaultConfig(uri.String())
+ if err != nil {
return nil, err
}
- uri, err := url.Parse(l.Address)
+ hostname, err := os.Hostname()
if err != nil {
- return nil, err
+ hostname = "unknown"
}
- uri.Path = uri.Path + pushPath
+ config.ExternalLabels = l.setLabels(hostname)
- config, err := loki.NewDefaultConfig(uri.String())
+ klogger := l.createLogger()
+ client, err := loki.NewWithLogger(config, klogger)
if err != nil {
return nil, err
}
- hostname, err := os.Hostname()
+
+ o := &slogloki.Option{
+ Client: client,
+ Level: options.Level,
+ }
+ return o.NewLokiHandler(), nil
+}
+
+func (l *Loki) isReady(uri url.URL) error {
+ uri.Path = uri.Path + readyPath
+
+ response, err := http.Get(uri.String())
if err != nil {
- hostname = "unknown"
+ return err
+ }
+ defer func() {
+ _ = response.Body.Close()
+ }()
+
+ if response.StatusCode != http.StatusOK {
+ return errors.New(response.Status)
}
- config.ExternalLabels = labelutil.LabelSet{
+ return nil
+}
+
+func (l *Loki) setLabels(hostname string) labelutil.LabelSet {
+ labels := labelutil.LabelSet{
LabelSet: model.LabelSet{
model.LabelName("service_name"): model.LabelValue("conntrackd"),
model.LabelName("host"): model.LabelValue(hostname),
},
}
- if len(l.Labels) > 0 {
- for _, label := range l.Labels {
- if !strings.Contains(label, "=") {
- continue
- }
- parts := strings.SplitN(label, "=", 2)
- key := parts[0]
- value := parts[1]
- config.ExternalLabels.LabelSet[model.LabelName(key)] = model.LabelValue(value)
- }
+ if len(l.Labels) == 0 {
+ return labels
}
- sw := klog.NewSyncWriter(os.Stderr)
- var klogger klog.Logger
- switch logger.Format() {
- case "json":
- klogger = klog.NewJSONLogger(sw)
- case "text":
- fallthrough
- default:
- klogger = klog.NewLogfmtLogger(sw)
+ for _, label := range l.Labels {
+ if !strings.Contains(label, "=") {
+ continue
+ }
+ parts := strings.SplitN(label, "=", 2)
+ key := parts[0]
+ value := parts[1]
+ labels.LabelSet[model.LabelName(key)] = model.LabelValue(value)
}
- klogger = klog.With(klogger, "time", klog.DefaultTimestamp, "sink", "loki")
- client, err := loki.NewWithLogger(config, klogger)
- if err != nil {
- return nil, err
- }
+ return labels
+}
- o := &slogloki.Option{
- Client: client,
- Level: options.Level,
- }
- return o.NewLokiHandler(), nil
+func (l *Loki) createLogger() kitlog.Logger {
+ level := logger.Level().String()
+ klevel := kitlevel.ParseDefault(level, kitlevel.InfoValue())
+
+ klogger := kitlog.NewJSONLogger(kitlog.NewSyncWriter(os.Stderr))
+ klogger = kitlevel.NewFilter(klogger, kitlevel.Allow(klevel))
+ klogger = kitlog.With(klogger, "time", kitlog.DefaultTimestamp, "sink", "loki")
+
+ return klogger
}
diff --git a/internal/sink/sink.go b/internal/sink/sink.go
index 66c2efb..85190c5 100644
--- a/internal/sink/sink.go
+++ b/internal/sink/sink.go
@@ -6,12 +6,22 @@ package sink
import (
"errors"
+ "fmt"
"log/slog"
+ "os"
slogmulti "github.com/samber/slog-multi"
)
+const (
+ ExitOnWarningEnv string = "CONNTRACKD_SINK_EXIT_ON_WARNING"
+)
+
type Sink struct {
+ Logger *slog.Logger
+}
+
+type Config struct {
Journal Journal
Syslog Syslog
Loki Loki
@@ -20,53 +30,48 @@ type Sink struct {
type SinkTarget func(*slog.HandlerOptions) (slog.Handler, error)
-func (s *Sink) Initialize() (*slog.Logger, error) {
- slog.Debug("Initializing sink targets.", "data", s)
-
+func NewSink(config *Config) (*Sink, error) {
options := &slog.HandlerOptions{
Level: slog.LevelInfo,
}
- var handlers []slog.Handler
- if s.Journal.Enable {
- handler, err := s.Journal.TargetJournal(options)
- if err != nil {
- slog.Warn("Failed to initialize journal sink", "error", err)
- } else {
- handlers = append(handlers, handler)
- }
+ exitOnWarning := false
+ envExitOnWarning, ok := os.LookupEnv(ExitOnWarningEnv)
+ if ok && (envExitOnWarning == "1" || envExitOnWarning == "true") {
+ exitOnWarning = true
}
- if s.Syslog.Enable {
- handler, err := s.Syslog.TargetSyslog(options)
- if err != nil {
- slog.Warn("Failed to initialize syslog sink", "error", err)
- } else {
- handlers = append(handlers, handler)
- }
- }
+ var handlers []slog.Handler
- if s.Loki.Enable {
- handler, err := s.Loki.TargetLoki(options)
- if err != nil {
- slog.Warn("Failed to initialize loki sink", "error", err)
- } else {
- handlers = append(handlers, handler)
- }
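+ // Describe each target once so the enable check, initialization, and error
+ // handling stay uniform across sinks.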
+ targets := []struct {
+ name string
+ enabled bool
+ init SinkTarget
+ }{
+ {"journal", config.Journal.Enable, config.Journal.TargetJournal},
+ {"syslog", config.Syslog.Enable, config.Syslog.TargetSyslog},
+ {"loki", config.Loki.Enable, config.Loki.TargetLoki},
+ {"stream", config.Stream.Enable, config.Stream.TargetStream},
}
- if s.Stream.Enable {
- handler, err := s.Stream.TargetStream(options)
+ for _, t := range targets {
+ if !t.enabled {
+ continue
+ }
+ handler, err := t.init(options)
if err != nil {
- slog.Warn("Failed to initialize stream sink", "error", err)
- } else {
- handlers = append(handlers, handler)
+ fmt.Fprintf(os.Stderr, "Warning: Failed to initialize sink %q: %v\n", t.name, err)
+ if exitOnWarning {
+ os.Exit(1)
+ }
+ continue
}
+ handlers = append(handlers, handler)
}
if len(handlers) == 0 {
return nil, errors.New("no target sink available")
}
- return slog.New(slogmulti.Fanout(handlers...)), nil
+ return &Sink{Logger: slog.New(slogmulti.Fanout(handlers...))}, nil
}
diff --git a/internal/sink/sink_test.go b/internal/sink/sink_test.go
new file mode 100644
index 0000000..047247a
--- /dev/null
+++ b/internal/sink/sink_test.go
@@ -0,0 +1,76 @@
+/*
+Copyright (c) 2025 Tobias Schäfer. All rights reserved.
+Licensed under the MIT License, see LICENSE file in the project root for details.
+*/
+package sink
+
+import (
+ "os"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
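+// capture redirects os.Stderr into a pipe while f runs and returns what was written.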
+func capture(f func()) string {
+ originalStderr := os.Stderr
+
+ r, w, _ := os.Pipe()
+ os.Stderr = w
+
+ f()
+
+ _ = w.Close()
+ os.Stderr = originalStderr
+
+ buf := make([]byte, 4096)
+ n, _ := r.Read(buf)
+ return string(buf[:n])
+}
+
+func newReturnsError_NoTargetsEnabled(t *testing.T) {
+ config := &Config{
+ Journal: Journal{Enable: false},
+ Syslog: Syslog{Enable: false},
+ Loki: Loki{Enable: false},
+ Stream: Stream{Enable: false},
+ }
+
+ sink, err := NewSink(config)
+
+ assert.Nil(t, sink)
+ assert.NotNil(t, err)
+ assert.EqualError(t, err, "no target sink available")
+}
+
+func newReturnsSink_TargetsEnabled(t *testing.T) {
+ config := &Config{
+ Journal: Journal{Enable: false},
+ Syslog: Syslog{Enable: false},
+ Loki: Loki{Enable: false},
+ Stream: Stream{Enable: true, Writer: "discard"},
+ }
+
+ sink, err := NewSink(config)
+ assert.NotNil(t, sink)
+ assert.Nil(t, err)
+ assert.IsType(t, &Sink{}, sink)
+}
+
+func newPrintsWarning_TargetInitFails(t *testing.T) {
+ config := &Config{
+ Journal: Journal{Enable: false},
+ Syslog: Syslog{Enable: false},
+ Loki: Loki{Enable: true, Address: "http://invalid-address"},
+ Stream: Stream{Enable: true, Writer: "discard"},
+ }
+ warning := capture(func() {
+ _, _ = NewSink(config)
+ })
+ assert.Contains(t, warning, "Warning: Failed to initialize sink \"loki\"")
+}
+
+func TestSink(t *testing.T) {
+ t.Run("NewSink returns error if no targets are enabled", newReturnsError_NoTargetsEnabled)
+ t.Run("NewSink returns sink if targets enabled", newReturnsSink_TargetsEnabled)
+ t.Run("NewSink prints warning if target init fails", newPrintsWarning_TargetInitFails)
+}
diff --git a/internal/sink/stream.go b/internal/sink/stream.go
index b953937..155a1ab 100644
--- a/internal/sink/stream.go
+++ b/internal/sink/stream.go
@@ -16,16 +16,17 @@ type Stream struct {
Writer string
}
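+// StreamWriters lists the writer names accepted by the stream sink.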
+var StreamWriters = []string{"stdout", "stderr", "discard"}
+
func (s *Stream) TargetStream(options *slog.HandlerOptions) (slog.Handler, error) {
- slog.Debug("Initializing stream sink.")
+ writer := map[string]io.Writer{
+ "stdout": os.Stdout,
+ "stderr": os.Stderr,
+ "discard": io.Discard,
+ }
- switch s.Writer {
- case "stdout":
- return slog.NewJSONHandler(os.Stdout, options), nil
- case "stderr":
- return slog.NewJSONHandler(os.Stderr, options), nil
- case "discard":
- return slog.NewJSONHandler(io.Discard, options), nil
+ if w, ok := writer[s.Writer]; ok {
+ return slog.NewJSONHandler(w, options), nil
}
return nil, fmt.Errorf("invalid stream writer specified: %q", s.Writer)
diff --git a/internal/sink/syslog.go b/internal/sink/syslog.go
index d98bc10..267c103 100644
--- a/internal/sink/syslog.go
+++ b/internal/sink/syslog.go
@@ -18,18 +18,18 @@ type Syslog struct {
Address string
}
-func (s *Syslog) TargetSyslog(options *slog.HandlerOptions) (slog.Handler, error) {
- slog.Debug("Initializing syslog sink", "data", s)
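+// SyslogProtocols lists the URL schemes usable in the syslog address; the scheme
+// is passed to net.Dial as the network.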
+var SyslogProtocols = []string{"udp", "tcp", "unix", "unixgram", "unixpacket"}
- uri, err := url.Parse(s.Address)
+func (s *Syslog) TargetSyslog(options *slog.HandlerOptions) (slog.Handler, error) {
+ url, err := url.Parse(s.Address)
if err != nil {
return nil, err
}
- network := uri.Scheme
- address := uri.Host
+ network := url.Scheme
+ address := url.Host
if strings.HasPrefix(network, "unix") {
- address = uri.Path
+ address = url.Path
}
writer, err := net.Dial(network, address)