diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 79299d2..c81226e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,4 +32,7 @@ jobs: - name: Run linters run: golangci-lint run - name: Run tests - run: go test -v ./... + run: | + curl -sSfL https://git.io/GeoLite2-City.mmdb -o /tmp/GeoLite2-City.mmdb && \ + export KEEP_GEOIP_DB=true && \ + go test -v ./... diff --git a/README.md b/README.md index 13fc43d..f855643 100644 --- a/README.md +++ b/README.md @@ -41,43 +41,74 @@ sudo conntrackd run --sink.journal.enable ``` For further configuration, see the command-line options below. -| Flag | Description | Options | -|--------------------------------|------------------------------------|----------------------------------------| -| `filter.destinations` | Filter by destination networks | PUBLIC,PRIVATE,LOCAL,MULTICAST | -| `filter.sourcess` | Filter by source networks | PUBLIC,PRIVATE,LOCAL,MULTICAST | -| `filter.protocols` | Filter by protocols | TCP,UDP | -| `filter.types` | Filter by event types | NEW,UPDATE,DESTROY | -| `filter.destination.addresses` | Filter by destination IP addresses | | -| `filter.source.addresses` | Filter by source IP addresses | | -| `filter.destination.ports` | Filter by destination ports | | -| `filter.source.ports` | Filter by source ports | | -| `geoip.database` | Path to GeoIP database | | -| `service.log.format` | Log format | json,text; default: text | -| `service.log.level` | Log level | trace,debug,info,error; default: info | -| `sink.journal.enable` | Enable journald sink | | -| `sink.syslog.enable` | Enable syslog sink | | -| `sink.enable.loki` | Enable Loki sink | | -| `sink.stream.enable` | Enable stream sink | | -| `sink.syslog.address` | Syslog address | default: udp://localhost:514 | -| `sink.loki.address` | Loki address | default: http://localhost:3100 | -| `sink.loki.labels` | Loki labels | comma seperated key=value pairs | -| `sink.stream.writer` | Stream writer type | stdout,stderr,discard; default: stdout | - -All filters are exclusive; if any filter is not set, all related events are processed. - -Example run: +## Filtering + +conntrackd logs conntrack events to various sinks. + +**Protocol Support:** Only TCP and UDP events are processed. All other protocols +(ICMP, IGMP, etc.) are automatically ignored and never logged, regardless of filter rules. + +You can use filters to control which TCP/UDP events are logged using a +Domain-Specific Language (DSL). +The `--filter` flag lets you specify filter rules: ```bash sudo conntrackd run \ - --geoip.database /usr/local/share/GeoLite2-City.mmdb \ - --filter.destination PRIVATE \ - --filter.protocol UDP \ - --filter.destination.addresses 142.250.186.163,2a00:1450:4001:82b::2003 - --sink.journal.enable \ - --service.log.format json \ - --service.log.level debug + --filter "drop destination address 8.8.8.8" \ + --filter "log protocol TCP and destination network PUBLIC" \ + --filter "drop any" \ + --sink.journal.enable +``` + +**Filter Rules:** +- Rules are evaluated in order (first-match wins) +- Events are **logged by default** when no rule matches +- `--filter` flag can be repeated for multiple rules +- Use `drop any` as a final rule to block all non-matching events from being logged + +**Important:** Filters control which conntrack events are **logged**, +not network traffic. Traffic always flows normally; filters only affect logging. 
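+
+For a sense of how these rules evaluate, here is a minimal Go sketch that
+drives the `internal/filter` package added in this change directly. It only
+builds inside this repository (the package is internal), and the rule set and
+sample event are purely illustrative:
+
+```go
+package main
+
+import (
+	"fmt"
+	"net/netip"
+	"syscall"
+
+	"github.com/ti-mo/conntrack"
+	"github.com/tschaefer/conntrackd/internal/filter"
+)
+
+func main() {
+	// Rules are checked in order; the final "drop any" disables the
+	// log-by-default behaviour for everything that did not match earlier.
+	f, err := filter.NewFilter([]string{
+		"drop destination address 8.8.8.8",
+		"log protocol TCP and destination network PUBLIC",
+		"drop any",
+	})
+	if err != nil {
+		panic(err)
+	}
+
+	// A NEW TCP flow from a private source to a public destination,
+	// built the same way the package's own tests build events.
+	flow := conntrack.NewFlow(
+		syscall.IPPROTO_TCP, conntrack.StatusAssured,
+		netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("1.1.1.1"),
+		41220, 443, 60, 0,
+	)
+	event := conntrack.Event{Type: conntrack.EventNew, Flow: &flow}
+
+	matched, logged, rule := f.Evaluate(event)
+	fmt.Println(matched, logged, rule) // true true 1: the second rule logs it
+}
+```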
+ +**Common Filter Examples:** + +```bash +# Don't log events to a specific IP +--filter "drop destination address 8.8.8.8" + +# Log only NEW TCP connections (deny everything else) +--filter "log type NEW and protocol TCP" +--filter "drop any" + +# Don't log DNS to specific server +--filter "drop destination address 10.19.80.100 on port 53" + +# Don't log any traffic to private networks +--filter "drop destination network PRIVATE" + +# Log only traffic from public IPs using TCP +--filter "log source network PUBLIC and protocol TCP" +--filter "drop any" ``` +See [docs/filter.md](docs/filter.md) for complete DSL documentation, +including grammar, operators, and advanced examples. + +## Configuration Flags + +| Flag | Description | Default | +|-------------------------|---------------------------------------------------|--------------------------| +| `--filter` | Filter rule in DSL format (repeatable) | | +| `--geoip.database` | Path to GeoIP database | | +| `--service.log.level` | Log level (debug, info, warn, error) | info | +| `--sink.journal.enable` | Enable journald sink | | +| `--sink.syslog.enable` | Enable syslog sink | | +| `--sink.loki.enable` | Enable Loki sink | | +| `--sink.stream.enable` | Enable stream sink | | +| `--sink.syslog.address` | Syslog address | udp://localhost:514 | +| `--sink.loki.address` | Loki address | http://localhost:3100 | +| `--sink.loki.labels` | Loki labels (comma-separated key=value pairs) | | +| `--sink.stream.writer` | Stream writer (stdout, stderr, discard) | stdout | + ## Logging format conntrackd emits structured logs for each conntrack event. A typical log entry @@ -100,7 +131,8 @@ GEO location fields: - lat (latitude) - lon (longitude) -Example log entry recorded by sink `syslog`: +
+Example log entry recorded by sink `syslog` ```json { @@ -117,16 +149,18 @@ Example log entry recorded by sink `syslog`: "level": "INFO", "logger.name": "samber/slog-syslog", "logger.version": "v2.5.2", - "message": "UPDATE TCP connection from 2003:cf:1716:7b64:da80:83ff:fecd:da51/41348...", + "message": "UPDATE TCP connection from 2003:cf:1716:7b64:da80:83ff:fecd...", "timestamp": "2025-11-15T09:55:25.647544937Z" } ``` +
-Example log entry recorded by sink `journal`: +
+Example log entry recorded by sink `journal` ```json { - "__CURSOR" : "s=b3c7821dbfce47a59b06797aea9028ca;i=6772d3;b=100da27bd8...", + "__CURSOR" : "s=b3c7821dbfce47a59b06797aea9028ca;i=6772d3;b=100da27bd...", "_CAP_EFFECTIVE" : "1ffffffffff", "EVENT_SPORT" : "39790", "_SOURCE_REALTIME_TIMESTAMP" : "1763200187611509", @@ -154,7 +188,7 @@ Example log entry recorded by sink `journal`: "_SYSTEMD_INVOCATION_ID" : "021760b3373342b98aaeabf9d12d8d74", "EVENT_FLOW" : "3478798157", "_PID" : "3794900", - "_CMDLINE" : "conntrackd run --service.log.level debug --service.log.format ...", + "_CMDLINE" : "conntrackd run --service.log.level debug --service.log....", "EVENT_PROT" : "TCP", "_AUDIT_SESSION" : "1", "_BOOT_ID" : "100da27bd8b94096b5c80cdac34d6063", @@ -164,12 +198,14 @@ Example log entry recorded by sink `journal`: "_AUDIT_LOGINUID" : "1000", "_UID" : "0", "EVENT_TYPE" : "UPDATE", - "MESSAGE" : "UPDATE TCP connection from 2003:cf:1716:7b64:da80:83ff:fecd:da51/39790..." + "MESSAGE" : "UPDATE TCP connection from 2003:cf:1716:7b64:da80:83ff:fe..." } ``` +
-Example log entry recorded by sink `loki`: +
+Example log entry recorded by sink `loki` ```json { @@ -194,11 +230,33 @@ Example log entry recorded by sink `loki`: "values": [ [ "1763537351540294198", - "UPDATE TCP connection from 2003:cf:1716:7b64:d6e9:8aff:fe4f:7a59/44950..." + "UPDATE TCP connection from 2003:cf:1716:7b64:d6e9:8aff:fe4f:7a59/44..." ] ] } ``` +
+ +
+Example log entry recorded by sink `stream` + +```json +{ + "time": "2025-11-22T11:34:43.181432081+01:00", + "level": "INFO", + "msg": "NEW TCP connection from 2003:cf:1716:7b64:da80:83ff:fecd:da51/4...", + "type": "NEW", + "flow": 2899284024, + "prot": "TCP", + "src": "2003:cf:1716:7b64:da80:83ff:fecd:da51", + "dst": "2a01:4f8:160:5372::2", + "sport": 41220, + "dport": 80, + "state": "SYN_SENT" +} +``` +
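+
+The stream sink writes one JSON object per line, which makes it easy to
+post-process. A minimal Go sketch for tailing it follows; the struct only
+mirrors the keys shown in the example above, not an official schema:
+
+```go
+package main
+
+import (
+	"bufio"
+	"encoding/json"
+	"fmt"
+	"os"
+)
+
+// event mirrors a subset of the JSON keys emitted by the stream sink.
+type event struct {
+	Type  string `json:"type"`
+	Prot  string `json:"prot"`
+	Src   string `json:"src"`
+	Dst   string `json:"dst"`
+	Sport uint16 `json:"sport"`
+	Dport uint16 `json:"dport"`
+	State string `json:"state"`
+}
+
+func main() {
+	sc := bufio.NewScanner(os.Stdin)
+	for sc.Scan() {
+		var e event
+		if err := json.Unmarshal(sc.Bytes(), &e); err != nil {
+			continue // skip lines that are not stream-sink JSON
+		}
+		fmt.Printf("%s %s %s:%d -> %s:%d (%s)\n",
+			e.Type, e.Prot, e.Src, e.Sport, e.Dst, e.Dport, e.State)
+	}
+}
+```
+
+Feed it from the stream sink, for example
+`sudo conntrackd run --sink.stream.enable | go run ./tail.go`
+(the file name here is only for the example).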
+ ## Security Notes @@ -212,7 +270,8 @@ Example log entry recorded by sink `loki`: Contributions are welcome! Please fork the repository and submit a pull request. For major changes, open an issue first to discuss what you would like to change. -Ensure that your code adheres to the existing style and includes appropriate tests. +Ensure that your code adheres to the existing style and includes appropriate +tests. ## License diff --git a/cmd/run.go b/cmd/run.go index c313bd6..3ace724 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -5,81 +5,70 @@ Licensed under the MIT License, see LICENSE file in the project root for details package cmd import ( + "context" "fmt" - "net/netip" - "net/url" "os" - "slices" + "os/signal" "strings" + "syscall" "github.com/spf13/cobra" + "github.com/tschaefer/conntrackd/internal/filter" + "github.com/tschaefer/conntrackd/internal/geoip" + "github.com/tschaefer/conntrackd/internal/logger" "github.com/tschaefer/conntrackd/internal/service" + "github.com/tschaefer/conntrackd/internal/sink" ) -var srv = service.Service{} - -var ( - validEventTypes = []string{"NEW", "UPDATE", "DESTROY"} - validProtocols = []string{"TCP", "UDP"} - validNetworks = []string{"PUBLIC", "PRIVATE", "LOCAL", "MULTICAST"} - validLogLevels = []string{"trace", "debug", "info", "error"} - validLogFormats = []string{"text", "json"} - validSyslogSchemes = []string{"udp", "tcp", "unix", "unixgram", "unixpacket"} - validLokiSchemes = []string{"http", "https"} - validStreamWriters = []string{"stdout", "stderr", "discard"} -) +type Options struct { + logLevel string + geoipDatabase string + filterRules []string + sink sink.Config +} + +var options = Options{} var runCmd = &cobra.Command{ Use: "run", Short: "Run the conntrackd service", Run: func(cmd *cobra.Command, args []string) { - if !srv.Sink.Journal.Enable && - !srv.Sink.Syslog.Enable && - !srv.Sink.Loki.Enable && - !srv.Sink.Stream.Enable { - cobra.CheckErr(fmt.Errorf("at least one sink must be enabled")) + l, err := logger.NewLogger(options.logLevel) + if err != nil { + cobra.CheckErr(fmt.Sprintf("Failed to create logger: %v", err)) } - err := validateStringFlag("sink.syslog.address", srv.Sink.Syslog.Address, []string{}) - cobra.CheckErr(err) - - err = validateStringFlag("sink.loki.address", srv.Sink.Loki.Address, []string{}) - cobra.CheckErr(err) - - err = validateStringFlag("sink.stream.writer", srv.Sink.Stream.Writer, validStreamWriters) - cobra.CheckErr(err) - - err = validateStringSliceFlag("filter.types", srv.Filter.EventTypes, validEventTypes) - cobra.CheckErr(err) - - err = validateStringSliceFlag("filter.protocols", srv.Filter.Protocols, validProtocols) - cobra.CheckErr(err) - - err = validateStringSliceFlag("filter.destination.networks", srv.Filter.Networks.Destinations, validNetworks) - cobra.CheckErr(err) - - err = validateStringSliceFlag("filter.source.networks", srv.Filter.Networks.Sources, validNetworks) - cobra.CheckErr(err) - - err = validateStringSliceFlag("filter.destination.addresses", srv.Filter.Addresses.Destinations, []string{}) - cobra.CheckErr(err) + var g *geoip.GeoIP + if options.geoipDatabase != "" { + g, err = geoip.NewGeoIP(options.geoipDatabase) + if err != nil { + cobra.CheckErr(fmt.Sprintf("Failed to open geoip database: %v", err)) + } + defer func() { + _ = g.Close() + }() + } - err = validateStringSliceFlag("filter.source.addresses", srv.Filter.Addresses.Sources, []string{}) - cobra.CheckErr(err) + var f *filter.Filter + if len(options.filterRules) > 0 { + f, err = filter.NewFilter(options.filterRules) + if err != 
nil { + cobra.CheckErr(fmt.Sprintf("failed to compile filter rules: %v", err)) + } + } - err = validateStringFlag("service.log.level", srv.Logger.Level, validLogLevels) - cobra.CheckErr(err) + s, err := sink.NewSink(&options.sink) + if err != nil { + cobra.CheckErr(fmt.Sprintf("failed to initialize sink: %v", err)) + } - err = validateStringFlag("service.log.format", srv.Logger.Format, validLogFormats) + service, err := service.NewService(l, g, f, s) cobra.CheckErr(err) - if srv.GeoIP.Database != "" { - if _, err := os.Stat(srv.GeoIP.Database); os.IsNotExist(err) { - cobra.CheckErr(fmt.Errorf("GeoIP database file does not exist: %s", srv.GeoIP.Database)) - } - } + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() - if err := srv.Run(); err != nil { + if tranquil := service.Run(ctx); !tranquil { os.Exit(1) } }, @@ -88,99 +77,28 @@ var runCmd = &cobra.Command{ func init() { runCmd.CompletionOptions.SetDefaultShellCompDirective(cobra.ShellCompDirectiveNoFileComp) - runCmd.Flags().StringSliceVar(&srv.Filter.EventTypes, "filter.types", nil, "Filter by event type (NEW,UPDATE,DESTROY)") - runCmd.Flags().StringSliceVar(&srv.Filter.Protocols, "filter.protocols", nil, "Filter by protocol (TCP,UDP)") - runCmd.Flags().StringSliceVar(&srv.Filter.Networks.Destinations, "filter.destination.networks", nil, "Filter by destination networks (PUBLIC,PRIVATE,LOCAL,MULTICAST)") - runCmd.Flags().StringSliceVar(&srv.Filter.Networks.Sources, "filter.source.networks", nil, "Filter by sources networks (PUBLIC,PRIVATE,LOCAL,MULTICAST)") - runCmd.Flags().StringSliceVar(&srv.Filter.Addresses.Destinations, "filter.destination.addresses", nil, "Filter by destination IP addresses") - runCmd.Flags().StringSliceVar(&srv.Filter.Addresses.Sources, "filter.source.addresses", nil, "Filter by source IP addresses") - runCmd.Flags().UintSliceVar(&srv.Filter.Ports.Destinations, "filter.destination.ports", nil, "Filter by destination ports") - runCmd.Flags().UintSliceVar(&srv.Filter.Ports.Sources, "filter.source.ports", nil, "Filter by source ports") - - runCmd.Flags().StringVar(&srv.Logger.Format, "service.log.format", "", "Log format (text,json)") - runCmd.Flags().StringVar(&srv.Logger.Level, "service.log.level", "", "Log level (debug,info)") - - runCmd.Flags().StringVar(&srv.GeoIP.Database, "geoip.database", "", "Path to GeoIP database") - - runCmd.Flags().BoolVar(&srv.Sink.Journal.Enable, "sink.journal.enable", false, "Enable journald sink") - runCmd.Flags().BoolVar(&srv.Sink.Syslog.Enable, "sink.syslog.enable", false, "Enable syslog sink") - runCmd.Flags().StringVar(&srv.Sink.Syslog.Address, "sink.syslog.address", "udp://localhost:514", "Syslog address") - runCmd.Flags().BoolVar(&srv.Sink.Loki.Enable, "sink.loki.enable", false, "Enable Loki sink") - runCmd.Flags().StringVar(&srv.Sink.Loki.Address, "sink.loki.address", "http://localhost:3100", "Loki address") - runCmd.Flags().StringSliceVar(&srv.Sink.Loki.Labels, "sink.loki.labels", nil, "Additional labels for Loki sink in key=value format") - runCmd.Flags().BoolVar(&srv.Sink.Stream.Enable, "sink.stream.enable", false, "Enable stream sink") - runCmd.Flags().StringVar(&srv.Sink.Stream.Writer, "sink.stream.writer", "stdout", "Stream writer (stdout,stderr,discard)") + runCmd.Flags().StringArrayVar(&options.filterRules, "filter", nil, "Filter rules in DSL format (repeatable, first-match wins)") + + runCmd.Flags().StringVar(&options.logLevel, "log.level", "info", fmt.Sprintf("Log level (%s)", strings.Join(logger.Levels, ", "))) + 
_ = runCmd.RegisterFlagCompletionFunc("log.level", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + return logger.Levels, cobra.ShellCompDirectiveNoFileComp + }) + runCmd.Flags().StringVar(&options.geoipDatabase, "geoip.database", "", "Path to GeoIP database") _ = runCmd.RegisterFlagCompletionFunc("geoip.database", cobra.FixedCompletions(nil, cobra.ShellCompDirectiveDefault)) - _ = runCmd.RegisterFlagCompletionFunc("service.log.level", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return validLogLevels, cobra.ShellCompDirectiveNoFileComp - }) + runCmd.Flags().BoolVar(&options.sink.Journal.Enable, "sink.journal.enable", false, "Enable journald sink") + runCmd.Flags().BoolVar(&options.sink.Syslog.Enable, "sink.syslog.enable", false, "Enable syslog sink") + runCmd.Flags().StringVar(&options.sink.Syslog.Address, "sink.syslog.address", "udp://localhost:514", "Syslog address") - _ = runCmd.RegisterFlagCompletionFunc("service.log.format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return validLogFormats, cobra.ShellCompDirectiveNoFileComp - }) + runCmd.Flags().BoolVar(&options.sink.Loki.Enable, "sink.loki.enable", false, "Enable Loki sink") + runCmd.Flags().StringVar(&options.sink.Loki.Address, "sink.loki.address", "http://localhost:3100", "Loki address") + runCmd.Flags().StringSliceVar(&options.sink.Loki.Labels, "sink.loki.labels", nil, "Additional labels for Loki sink in key=value format") + runCmd.Flags().BoolVar(&options.sink.Stream.Enable, "sink.stream.enable", false, "Enable stream sink") + runCmd.Flags().StringVar(&options.sink.Stream.Writer, "sink.stream.writer", "stdout", fmt.Sprintf("Stream writer (%s)", strings.Join(sink.StreamWriters, ", "))) _ = runCmd.RegisterFlagCompletionFunc("sink.stream.writer", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - return validStreamWriters, cobra.ShellCompDirectiveNoFileComp + return sink.StreamWriters, cobra.ShellCompDirectiveNoFileComp }) -} - -func validateStringSliceFlag(flag string, values []string, validValues []string) error { - if flag == "filter.source.addresses" || flag == "filter.destination.addresses" { - for _, v := range values { - if _, err := netip.ParseAddr(v); err != nil { - return fmt.Errorf("invalid IP address '%s' for '--%s'", v, flag) - } - } - return nil - } - - for _, v := range values { - if !slices.Contains(validValues, v) { - return fmt.Errorf("invalid value '%s' for '--%s' . Valid values are: %s", v, flag, validValues) - } - } - return nil -} - -func validateStringFlag(flag string, value string, validValues []string) error { - if value == "" { - return nil - } - - if flag == "sink.syslog.address" || flag == "sink.loki.address" { - url, err := url.Parse(value) - if err != nil { - return fmt.Errorf("invalid URL '%s' for '--%s'", value, flag) - } - - if flag == "sink.syslog.address" { - if !slices.Contains(validSyslogSchemes, url.Scheme) { - return fmt.Errorf("invalid URL scheme '%s' for '--%s'. Valid schemes are: udp, tcp, unix, unixgram unixpacket", url.Scheme, flag) - } - if url.Host == "" && !strings.HasPrefix(url.Scheme, "unix") { - return fmt.Errorf("invalid URL '%s' for '--%s'. Host is missing", value, flag) - } - if url.Path == "" && strings.HasPrefix(url.Scheme, "unix") { - return fmt.Errorf("invalid URL '%s' for '--%s'. 
Path is missing", value, flag) - } - } - - if flag == "sink.loki.address" { - if !slices.Contains(validLokiSchemes, url.Scheme) { - return fmt.Errorf("invalid URL scheme '%s' for '--%s'. Valid schemes are: http, https", url.Scheme, flag) - } - if url.Host == "" { - return fmt.Errorf("invalid URL '%s' for '--%s'. Host is missing", value, flag) - } - } - - return nil - } - if !slices.Contains(validValues, value) { - return fmt.Errorf("invalid value '%s' for '--%s' . Valid values are: %s", value, flag, validValues) - } - return nil } diff --git a/cmd/run_test.go b/cmd/run_test.go deleted file mode 100644 index bff4566..0000000 --- a/cmd/run_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package cmd - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestValidateStringSliceFlag_Valid(t *testing.T) { - assert.NoError(t, validateStringSliceFlag("filter.types", []string{"NEW", "UPDATE"}, validEventTypes)) - assert.NoError(t, validateStringSliceFlag("filter.protocols", []string{"TCP"}, validProtocols)) - assert.NoError(t, validateStringSliceFlag("filter.destination.addresses", []string{"127.0.0.1", "::1"}, []string{})) -} - -func TestValidateStringSliceFlag_Invalid(t *testing.T) { - assert.Error(t, validateStringSliceFlag("filter.types", []string{"BAD"}, validEventTypes)) - assert.Error(t, validateStringSliceFlag("filter.addresses", []string{"not-an-ip"}, []string{})) -} - -func TestValidateStringFlag_LogLevelsAndFormats(t *testing.T) { - assert.NoError(t, validateStringFlag("service.log.level", "debug", validLogLevels)) - assert.NoError(t, validateStringFlag("service.log.format", "json", validLogFormats)) - - assert.Error(t, validateStringFlag("service.log.level", "verbose", validLogLevels)) - assert.Error(t, validateStringFlag("service.log.format", "xml", validLogFormats)) -} - -func TestValidateStringFlag_SyslogAddress_Valid(t *testing.T) { - valids := []string{ - "udp://localhost:514", - "tcp://127.0.0.1:514", - "unix:///var/run/syslog.sock", - "unixgram:///var/run/syslog.sock", - "unixpacket:///var/run/syslog.sock", - } - for _, v := range valids { - assert.NoErrorf(t, validateStringFlag("sink.syslog.address", v, []string{}), "valid syslog address %q should not error", v) - } -} - -func TestValidateStringFlag_SyslogAddress_Invalid(t *testing.T) { - assert.Error(t, validateStringFlag("sink.syslog.address", "http://localhost:514", []string{})) - assert.Error(t, validateStringFlag("sink.syslog.address", "tcp:///nohost", []string{})) - assert.Error(t, validateStringFlag("sink.syslog.address", "unix://", []string{})) -} - -func TestValidateStringFlag_LokiAddress_Valid(t *testing.T) { - assert.NoError(t, validateStringFlag("sink.loki.address", "http://localhost:3100", []string{})) - assert.NoError(t, validateStringFlag("sink.loki.address", "https://example.com", []string{})) -} - -func TestValidateStringFlag_LokiAddress_Invalid(t *testing.T) { - assert.Error(t, validateStringFlag("sink.loki.address", "tcp://localhost:3100", []string{})) - assert.Error(t, validateStringFlag("sink.loki.address", "http:///path", []string{})) -} - -func TestValidSlicesAreExplicit(t *testing.T) { - expectedEvents := []string{"NEW", "UPDATE", "DESTROY"} - assert.Equal(t, expectedEvents, validEventTypes, "validEventTypes mismatch") - - expectedProtocols := []string{"TCP", "UDP"} - assert.Equal(t, expectedProtocols, validProtocols, "validProtocols mismatch") - - expectedDest := []string{"PUBLIC", "PRIVATE", "LOCAL", "MULTICAST"} - assert.Equal(t, expectedDest, validNetworks, "validNetworks mismatch") -} 
diff --git a/cmd/version.go b/cmd/version.go
index 7fa8562..ed9a30c 100644
--- a/cmd/version.go
+++ b/cmd/version.go
@@ -18,5 +18,5 @@ var versionCmd = &cobra.Command{
 }
 
 func init() {
-	runCmd.CompletionOptions.SetDefaultShellCompDirective(cobra.ShellCompDirectiveNoFileComp)
+	versionCmd.CompletionOptions.SetDefaultShellCompDirective(cobra.ShellCompDirectiveNoFileComp)
 }
diff --git a/docs/filter.md b/docs/filter.md
new file mode 100644
index 0000000..9056826
--- /dev/null
+++ b/docs/filter.md
@@ -0,0 +1,357 @@
+# Filter DSL Documentation
+
+## Overview
+
+The conntrackd filter DSL (Domain-Specific Language) allows you to control
+which conntrack events are **logged** to your configured sinks
+(journal, syslog, Loki, etc.).
+
+**Protocol Support:** Only TCP and UDP events are processed. All other protocols
+(ICMP, IGMP, etc.) are automatically ignored and never logged, regardless of
+filter rules.
+
+**Important:** Filters do not affect network traffic; they only control which
+conntrack events are logged. All network traffic flows normally regardless of
+filter rules.
+
+Rules are evaluated in order (first-match wins), and events are
+**logged by default** when no rule matches.
+
+## Command-Line Usage
+
+Use the `--filter` flag to specify filter rules. This flag can be repeated
+multiple times:
+
+```bash
+conntrackd run \
+  --filter "drop destination address 8.8.8.8" \
+  --filter "log protocol TCP" \
+  --filter "drop ANY"
+```
+
+## Understanding Allow-by-Default
+
+By default, conntrackd logs all conntrack events. This means:
+
+- If no filters match an event, it **is logged** (allow-by-default)
+- A `log` rule means "log this event"
+- A `drop` rule means "don't log this event"
+
+To log **only** specific events, use a `log` rule followed by `drop ANY`:
+
+```bash
+# Log ONLY NEW TCP connections
+--filter "log type NEW and protocol TCP"
+--filter "drop ANY"
+```
+
+Without the `drop ANY`, all non-matching events would still be logged.
+
+## Grammar
+
+The filter DSL follows this grammar (EBNF notation):
+
+```ebnf
+rule        ::= action expression
+action      ::= "log" | "drop"
+expression  ::= orExpr
+orExpr      ::= andExpr { "or" andExpr }
+andExpr     ::= notExpr { "and" notExpr }
+notExpr     ::= [ "not" | "!" ] primary
+primary     ::= predicate | "(" expression ")"
+
+predicate   ::= eventPred | protoPred | addrPred | networkPred | portPred | anyPred
+
+eventPred   ::= "type" identList
+protoPred   ::= "protocol" identList
+addrPred    ::= direction "address" addrList [ "on" "port" portSpec ]
+networkPred ::= direction "network" identList
+portPred    ::= [ direction ] "port" portSpec
+              | "on" "port" portSpec
+anyPred     ::= "ANY"
+
+direction   ::= "source" | "src" | "destination" | "dst" | "dest"
+identList   ::= IDENT { "," IDENT }
+addrList    ::= ADDRESS { "," ADDRESS }
+portSpec    ::= NUMBER | NUMBER "-" NUMBER | NUMBER { "," NUMBER }
+```
+
+## Predicates
+
+### Any (Catch-All)
+
+The `ANY` predicate matches all events. It's typically used with `drop` to
+block all non-matching events:
+
+```bash
+# Log only NEW TCP connections (deny everything else)
+log type NEW and protocol TCP
+drop ANY
+
+# Log only traffic to specific IPs (deny everything else)
+log destination address 1.2.3.4,5.6.7.8
+drop ANY
+```
+
+### Event Type
+
+Match on conntrack event types:
+
+```bash
+# Don't log NEW events
+drop type NEW
+
+# Log UPDATE or DESTROY events
+log type UPDATE,DESTROY
+```
+
+Valid types: `NEW`, `UPDATE`, `DESTROY`
+
+### Protocol
+
+Match on protocol:
+
+```bash
+# Don't log TCP events
+drop protocol TCP
+
+# Log TCP or UDP
+log protocol TCP,UDP
+```
+
+Valid protocols: `TCP`, `UDP`
+
+### Network Classification
+
+Match on network categories:
+
+```bash
+# Don't log traffic to private networks
+drop destination network PRIVATE
+
+# Log traffic from public networks
+log source network PUBLIC
+```
+
+Valid network types:
+- `LOCAL` - Loopback addresses (127.0.0.0/8, ::1) and link-local addresses
+- `PRIVATE` - RFC1918 private addresses (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16) and IPv6 ULA (fc00::/7)
+- `PUBLIC` - Public addresses (not LOCAL, PRIVATE, or MULTICAST)
+- `MULTICAST` - Multicast addresses (224.0.0.0/4, ff00::/8)
+
+### IP Address
+
+Match on specific IP addresses or CIDR ranges:
+
+```bash
+# Don't log traffic to a specific IP
+drop destination address 8.8.8.8
+
+# Log traffic from a CIDR range
+log source address 192.168.1.0/24
+
+# Match IPv6 addresses
+drop destination address 2001:4860:4860::8888
+
+# Multiple addresses
+drop destination address 1.1.1.1,8.8.8.8
+```
+
+### Port
+
+Match on ports:
+
+```bash
+# Don't log traffic to port 22
+drop destination port 22
+
+# Log traffic from ports 1024-65535
+log source port 1024-65535
+
+# Match multiple ports
+drop destination port 22,23,3389
+
+# Match traffic on any port (source or destination)
+drop on port 53
+```
+
+### Address with Port
+
+Combine address and port matching:
+
+```bash
+# Don't log traffic to a specific IP on a specific port
+drop destination address 10.19.80.100 on port 53
+```
+
+## Operators
+
+### AND
+
+Both conditions must be true:
+
+```bash
+log type NEW and protocol TCP
+```
+
+### OR
+
+At least one condition must be true:
+
+```bash
+log type NEW or type UPDATE
+```
+
+### NOT
+
+Negate a condition:
+
+```bash
+log not type DESTROY
+```
+
+### Parentheses
+
+Group expressions:
+
+```bash
+log (type NEW and protocol TCP) or (type UPDATE and protocol UDP)
+```
+
+## Operator Precedence
+
+1. NOT (highest)
+2. AND
+3. OR (lowest)
+
+Example: `log type NEW or type UPDATE and protocol TCP` parses as `log type NEW or (type UPDATE and protocol TCP)`
+
+## Evaluation Semantics
+
+### First-Match Wins
+
+Rules are evaluated in the order they are specified. The first rule whose predicate matches determines the action (log or drop).
+ +```bash +# First rule matches TCP traffic - allows it +# Second rule never evaluated for TCP to 8.8.8.8 +conntrackd run \ + --filter "log protocol TCP" \ + --filter "drop destination address 8.8.8.8" +``` + +### Allow-by-Default + +If no rule matches an event, it is **logged** (allowed by default): + +```bash +# Don't log events to 8.8.8.8 +# All other events ARE logged +conntrackd run --filter "drop destination address 8.8.8.8" +``` + +To change this behavior and log **only** specific events, use `drop ANY` as the final rule: + +```bash +# Log ONLY events to 8.8.8.8 +# All other events are NOT logged +conntrackd run \ + --filter "log destination address 8.8.8.8" \ + --filter "drop ANY" +``` + +## Examples + +### Example 1: Don't Log Specific Destination + +Don't log events to a specific IP address, but log all TCP traffic to public networks: + +```bash +conntrackd run \ + --filter "drop destination address 8.8.8.8" \ + --filter "log protocol TCP and destination network PUBLIC" +``` + +**Evaluation:** +- Traffic to 8.8.8.8: Matches first rule → **NOT LOGGED** +- TCP to public IP (not 8.8.8.8): Matches second rule → **LOGGED** +- UDP to private network: No match → **LOGGED** (default) + +### Example 2: Don't Log DNS to Specific Server + +Don't log DNS traffic to a specific IP, log all other TCP/UDP: + +```bash +conntrackd run \ + --filter "drop destination address 10.19.80.100 on port 53" \ + --filter "log protocol TCP,UDP" +``` + +**Evaluation:** +- DNS to 10.19.80.100: Matches first rule → **NOT LOGGED** +- TCP/UDP to other destinations: Matches second rule → **LOGGED** +- Other protocols: No match → **LOGGED** (default) + +### Example 3: Log Only Specific Traffic + +Log only NEW TCP connections (don't log anything else): + +```bash +conntrackd run \ + --filter "log type NEW and protocol TCP" \ + --filter "drop ANY" +``` + +**Evaluation:** +- NEW TCP: Matches first rule → **LOGGED** +- NEW UDP: Matches second rule → **NOT LOGGED** +- UPDATE/DESTROY: Matches second rule → **NOT LOGGED** + +**Note:** Without `drop ANY`, all non-matching events would still be logged. + +### Example 4: Complex Filtering + +Don't log outbound traffic to private networks on specific ports: + +```bash +conntrackd run \ + --filter "drop destination network PRIVATE and destination port 22,23,3389" \ + --filter "log source network PRIVATE" +``` + +**Evaluation:** +- Private network destination on port 22: Matches first rule → **NOT LOGGED** +- Private network source: Matches second rule → **LOGGED** +- Other traffic: No match → **LOGGED** (default) + +## Best Practices + +1. **Order Matters**: Place more specific rules before general rules +2. **Use `drop ANY` for Exclusive Logging**: When you want to log ONLY specific events, end with `drop ANY` +3. **Use AND for Precision**: Combine multiple conditions to create precise filters +4. **Test Incrementally**: Start with simple rules and add complexity +5. **Document Complex Rules**: Add comments in your deployment scripts +6. 
**Use Parentheses**: Make precedence explicit in complex expressions
+
+## Case Insensitivity
+
+Keywords and identifiers are case-insensitive:
+
+```bash
+# These are all equivalent
+log type NEW
+LOG TYPE NEW
+Log Type New
+```
+
+## Abbreviations
+
+Supported abbreviations:
+- `src` = `source`
+- `dst` = `dest` = `destination`
+
+```bash
+# These are equivalent
+drop src network PRIVATE
+drop source network PRIVATE
+```
diff --git a/go.mod b/go.mod
index 2aca6d6..fea67d0 100644
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ require (
 	github.com/grafana/loki-client-go v0.0.0-20251015150631-c42bbddc310a
 	github.com/mdlayher/netlink v1.8.0
 	github.com/oschwald/geoip2-golang/v2 v2.0.0
+	github.com/prometheus/common v0.67.2
 	github.com/samber/slog-loki/v3 v3.6.0
 	github.com/samber/slog-multi v1.6.0
 	github.com/samber/slog-syslog/v2 v2.5.2
@@ -47,7 +48,6 @@ require (
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/prometheus/client_golang v1.20.4 // indirect
 	github.com/prometheus/client_model v0.6.2 // indirect
-	github.com/prometheus/common v0.67.2 // indirect
 	github.com/prometheus/procfs v0.15.1 // indirect
 	github.com/prometheus/prometheus v0.35.0 // indirect
 	github.com/samber/lo v1.52.0 // indirect
diff --git a/internal/filter/ast.go b/internal/filter/ast.go
new file mode 100644
index 0000000..9d1daa2
--- /dev/null
+++ b/internal/filter/ast.go
@@ -0,0 +1,121 @@
+/*
+Copyright (c) 2025 Tobias Schäfer. All rights reserved.
+Licensed under the MIT License, see LICENSE file in the project root for details.
+*/
+package filter
+
+// Action represents the action to take when a rule matches
+type Action int
+
+const (
+	ActionLog Action = iota
+	ActionDrop
+)
+
+func (a Action) String() string {
+	switch a {
+	case ActionLog:
+		return "log"
+	case ActionDrop:
+		return "drop"
+	default:
+		return "unknown"
+	}
+}
+
+// ExprNode represents a node in the expression AST
+type ExprNode interface {
+	isExprNode()
+}
+
+// BinaryExpr represents a binary expression (AND, OR)
+type BinaryExpr struct {
+	Op    BinaryOp
+	Left  ExprNode
+	Right ExprNode
+}
+
+func (BinaryExpr) isExprNode() {}
+
+type BinaryOp int
+
+const (
+	OpAnd BinaryOp = iota
+	OpOr
+)
+
+// UnaryExpr represents a unary expression (NOT)
+type UnaryExpr struct {
+	Op   UnaryOp
+	Expr ExprNode
+}
+
+func (UnaryExpr) isExprNode() {}
+
+type UnaryOp int
+
+const (
+	OpNot UnaryOp = iota
+)
+
+// Predicate represents a base predicate
+type Predicate interface {
+	ExprNode
+	isPredicate()
+}
+
+// TypePredicate matches event types
+type TypePredicate struct {
+	Types []string // NEW, UPDATE, DESTROY
+}
+
+func (TypePredicate) isExprNode()  {}
+func (TypePredicate) isPredicate() {}
+
+// ProtocolPredicate matches protocols
+type ProtocolPredicate struct {
+	Protocols []string // TCP, UDP
+}
+
+func (ProtocolPredicate) isExprNode()  {}
+func (ProtocolPredicate) isPredicate() {}
+
+// NetworkPredicate matches network types
+type NetworkPredicate struct {
+	Direction string   // source, destination
+	Networks  []string // LOCAL, PRIVATE, PUBLIC, MULTICAST
+}
+
+func (NetworkPredicate) isExprNode()  {}
+func (NetworkPredicate) isPredicate() {}
+
+// AddressPredicate matches IP addresses or CIDR ranges
+type AddressPredicate struct {
+	Direction string   // source, destination
+	Addresses []string // IP addresses or CIDR
+	Ports     []uint16 // optional ports
+}
+
+func (AddressPredicate) isExprNode()  {}
+func (AddressPredicate) isPredicate() {}
+
+// PortPredicate matches ports
+type PortPredicate struct {
+	Direction string // source, destination
+	Ports 
[]uint16 // port numbers or ranges +} + +func (PortPredicate) isExprNode() {} +func (PortPredicate) isPredicate() {} + +// AnyPredicate matches any event (catch-all) +type AnyPredicate struct{} + +func (AnyPredicate) isExprNode() {} +func (AnyPredicate) isPredicate() {} + +// Rule represents a complete filter rule +type Rule struct { + Action Action + Expr ExprNode +} diff --git a/internal/filter/eval.go b/internal/filter/eval.go new file mode 100644 index 0000000..3a99dd8 --- /dev/null +++ b/internal/filter/eval.go @@ -0,0 +1,249 @@ +/* +Copyright (c) 2025 Tobias Schäfer. All rights reserved. +Licensed under the MIT License, see LICENSE file in the project root for details. +*/ +package filter + +import ( + "fmt" + "net/netip" + "slices" + "strings" + "syscall" + + "github.com/ti-mo/conntrack" +) + +// PredicateFunc is a function that evaluates a predicate against an event +type PredicateFunc func(event conntrack.Event) bool + +// Compile compiles an expression AST into a predicate function +func Compile(expr ExprNode) (PredicateFunc, error) { + switch e := expr.(type) { + case BinaryExpr: + return compileBinaryExpr(e) + case UnaryExpr: + return compileUnaryExpr(e) + case TypePredicate: + return compileTypePredicate(e), nil + case ProtocolPredicate: + return compileProtocolPredicate(e), nil + case NetworkPredicate: + return compileNetworkPredicate(e), nil + case AddressPredicate: + return compileAddressPredicate(e) + case PortPredicate: + return compilePortPredicate(e), nil + case AnyPredicate: + return compileAnyPredicate(e), nil + default: + return nil, fmt.Errorf("unknown expression type: %T", expr) + } +} + +func compileBinaryExpr(expr BinaryExpr) (PredicateFunc, error) { + left, err := Compile(expr.Left) + if err != nil { + return nil, err + } + right, err := Compile(expr.Right) + if err != nil { + return nil, err + } + + switch expr.Op { + case OpAnd: + return func(event conntrack.Event) bool { + return left(event) && right(event) + }, nil + case OpOr: + return func(event conntrack.Event) bool { + return left(event) || right(event) + }, nil + default: + return nil, fmt.Errorf("unknown binary operator: %v", expr.Op) + } +} + +func compileUnaryExpr(expr UnaryExpr) (PredicateFunc, error) { + inner, err := Compile(expr.Expr) + if err != nil { + return nil, err + } + + switch expr.Op { + case OpNot: + return func(event conntrack.Event) bool { + return !inner(event) + }, nil + default: + return nil, fmt.Errorf("unknown unary operator: %v", expr.Op) + } +} + +func compileTypePredicate(pred TypePredicate) PredicateFunc { + return func(event conntrack.Event) bool { + var eventType string + switch event.Type { + case conntrack.EventNew: + eventType = "NEW" + case conntrack.EventUpdate: + eventType = "UPDATE" + case conntrack.EventDestroy: + eventType = "DESTROY" + default: + return false + } + return slices.Contains(pred.Types, eventType) + } +} + +func compileProtocolPredicate(pred ProtocolPredicate) PredicateFunc { + return func(event conntrack.Event) bool { + var protocol string + switch event.Flow.TupleOrig.Proto.Protocol { + case syscall.IPPROTO_TCP: + protocol = "TCP" + case syscall.IPPROTO_UDP: + protocol = "UDP" + default: + return false + } + return slices.Contains(pred.Protocols, protocol) + } +} + +func compileNetworkPredicate(pred NetworkPredicate) PredicateFunc { + return func(event conntrack.Event) bool { + var ip netip.Addr + switch strings.ToLower(pred.Direction) { + case "source", "src": + ip = event.Flow.TupleOrig.IP.SourceAddress + case "destination", "dst", "dest": + ip = 
event.Flow.TupleOrig.IP.DestinationAddress + default: + return false + } + + isLocal := ip.IsLoopback() || ip.IsLinkLocalUnicast() + isPrivate := ip.IsPrivate() + isMulticast := ip.IsMulticast() + isPublic := !isLocal && !isPrivate && !isMulticast + + for _, network := range pred.Networks { + switch network { + case "LOCAL": + if isLocal { + return true + } + case "PRIVATE": + if isPrivate { + return true + } + case "MULTICAST": + if isMulticast { + return true + } + case "PUBLIC": + if isPublic { + return true + } + } + } + return false + } +} + +func compileAddressPredicate(pred AddressPredicate) (PredicateFunc, error) { + // Pre-compile address matchers + var matchers []func(netip.Addr) bool + for _, addrStr := range pred.Addresses { + // Try to parse as CIDR first + if strings.Contains(addrStr, "/") { + prefix, err := netip.ParsePrefix(addrStr) + if err == nil { + // Capture prefix in a local variable to avoid loop variable capture + pfx := prefix + matchers = append(matchers, func(ip netip.Addr) bool { + return pfx.Contains(ip) + }) + continue + } + } + // Try to parse as IP + addr, err := netip.ParseAddr(addrStr) + if err == nil { + // Capture addr in a local variable to avoid loop variable capture + a := addr + matchers = append(matchers, func(ip netip.Addr) bool { + return ip == a + }) + } + } + + // If no addresses parsed successfully, return error + if len(matchers) == 0 { + return nil, fmt.Errorf("no valid addresses in predicate") + } + + return func(event conntrack.Event) bool { + var ip netip.Addr + var port uint16 + + switch strings.ToLower(pred.Direction) { + case "source", "src": + ip = event.Flow.TupleOrig.IP.SourceAddress + port = event.Flow.TupleOrig.Proto.SourcePort + case "destination", "dst", "dest": + ip = event.Flow.TupleOrig.IP.DestinationAddress + port = event.Flow.TupleOrig.Proto.DestinationPort + default: + return false + } + + // Check if IP matches + matched := false + for _, matcher := range matchers { + if matcher(ip) { + matched = true + break + } + } + + if !matched { + return false + } + + // If ports are specified, check them too + if len(pred.Ports) > 0 { + return slices.Contains(pred.Ports, port) + } + + return true + }, nil +} + +func compilePortPredicate(pred PortPredicate) PredicateFunc { + return func(event conntrack.Event) bool { + srcPort := event.Flow.TupleOrig.Proto.SourcePort + dstPort := event.Flow.TupleOrig.Proto.DestinationPort + + switch strings.ToLower(pred.Direction) { + case "source", "src": + return slices.Contains(pred.Ports, srcPort) + case "destination", "dst", "dest": + return slices.Contains(pred.Ports, dstPort) + case "both": + return slices.Contains(pred.Ports, srcPort) || slices.Contains(pred.Ports, dstPort) + default: + return false + } + } +} + +func compileAnyPredicate(pred AnyPredicate) PredicateFunc { + // AnyPredicate always matches + return func(event conntrack.Event) bool { + return true + } +} diff --git a/internal/filter/eval_test.go b/internal/filter/eval_test.go new file mode 100644 index 0000000..1be5761 --- /dev/null +++ b/internal/filter/eval_test.go @@ -0,0 +1,433 @@ +/* +Copyright (c) 2025 Tobias Schäfer. All rights reserved. +Licensed under the MIT License, see LICENSE file in the project root for details. 
+*/ +package filter + +import ( + "net/netip" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ti-mo/conntrack" +) + +func createEvent(eventTypeVal, proto uint8) conntrack.Event { + flow := conntrack.NewFlow( + proto, + conntrack.StatusAssured, + netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("8.8.8.8"), + 1234, 80, + 60, 0, + ) + // Create event and then set Type using constants to satisfy type system + event := conntrack.Event{Flow: &flow} + switch eventTypeVal { + case 1: // NEW + event.Type = conntrack.EventNew + case 2: // UPDATE + event.Type = conntrack.EventUpdate + case 3: // DESTROY + event.Type = conntrack.EventDestroy + } + return event +} + +func createEventWithAddrs(eventTypeVal, proto uint8, srcIP, dstIP string, srcPort, dstPort uint16) conntrack.Event { + flow := conntrack.NewFlow( + proto, + conntrack.StatusAssured, + netip.MustParseAddr(srcIP), netip.MustParseAddr(dstIP), + srcPort, dstPort, + 60, 0, + ) + // Create event and then set Type using constants to satisfy type system + event := conntrack.Event{Flow: &flow} + switch eventTypeVal { + case 1: // NEW + event.Type = conntrack.EventNew + case 2: // UPDATE + event.Type = conntrack.EventUpdate + case 3: // DESTROY + event.Type = conntrack.EventDestroy + } + return event +} + +func TestEval_TypePredicate(t *testing.T) { + tests := []struct { + name string + rule string + event conntrack.Event + expected bool + }{ + {"match NEW", "log type NEW", createEvent(1, syscall.IPPROTO_TCP), true}, + {"match UPDATE", "log type UPDATE", createEvent(2, syscall.IPPROTO_TCP), true}, + {"match DESTROY", "log type DESTROY", createEvent(3, syscall.IPPROTO_TCP), true}, + {"no match NEW", "log type UPDATE", createEvent(1, syscall.IPPROTO_TCP), false}, + {"match multiple", "log type NEW,UPDATE", createEvent(1, syscall.IPPROTO_TCP), true}, + {"match multiple 2", "log type NEW,UPDATE", createEvent(2, syscall.IPPROTO_TCP), true}, + {"no match multiple", "log type NEW,UPDATE", createEvent(3, syscall.IPPROTO_TCP), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.rule) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + pred, err := Compile(rule.Expr) + require.NoError(t, err) + + assert.Equal(t, tt.expected, pred(tt.event)) + }) + } +} + +func TestEval_ProtocolPredicate(t *testing.T) { + tests := []struct { + name string + rule string + proto uint8 + expected bool + }{ + {"match TCP", "log protocol TCP", syscall.IPPROTO_TCP, true}, + {"match UDP", "log protocol UDP", syscall.IPPROTO_UDP, true}, + {"no match TCP", "log protocol UDP", syscall.IPPROTO_TCP, false}, + {"match multiple", "log protocol TCP,UDP", syscall.IPPROTO_TCP, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.rule) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + pred, err := Compile(rule.Expr) + require.NoError(t, err) + + event := createEvent(1, tt.proto) + assert.Equal(t, tt.expected, pred(event)) + }) + } +} + +func TestEval_NetworkPredicate(t *testing.T) { + tests := []struct { + name string + rule string + srcIP string + dstIP string + expected bool + }{ + {"src private", "log source network PRIVATE", "10.0.0.1", "8.8.8.8", true}, + {"src public", "log source network PUBLIC", "8.8.8.8", "10.0.0.1", true}, + {"dst public", "log destination network PUBLIC", "10.0.0.1", "8.8.8.8", true}, + {"dst private", 
"log destination network PRIVATE", "8.8.8.8", "10.0.0.1", true}, + {"src local loopback", "log source network LOCAL", "127.0.0.1", "8.8.8.8", true}, + {"dst multicast", "log destination network MULTICAST", "10.0.0.1", "224.0.0.1", true}, + {"no match", "log source network PUBLIC", "10.0.0.1", "8.8.8.8", false}, + {"ipv6 private", "log source network PRIVATE", "fc00::1", "2001:db8::1", true}, + {"ipv6 public", "log destination network PUBLIC", "10.0.0.1", "2001:4860:4860::8888", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.rule) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + pred, err := Compile(rule.Expr) + require.NoError(t, err) + + event := createEventWithAddrs(1, syscall.IPPROTO_TCP, tt.srcIP, tt.dstIP, 1234, 80) + assert.Equal(t, tt.expected, pred(event)) + }) + } +} + +func TestEval_AddressPredicate(t *testing.T) { + tests := []struct { + name string + rule string + srcIP string + dstIP string + srcPort uint16 + dstPort uint16 + expected bool + }{ + {"dst exact match", "log destination address 8.8.8.8", "10.0.0.1", "8.8.8.8", 1234, 80, true}, + {"dst no match", "log destination address 8.8.8.8", "10.0.0.1", "8.8.4.4", 1234, 80, false}, + {"src exact match", "log source address 10.0.0.1", "10.0.0.1", "8.8.8.8", 1234, 80, true}, + {"cidr match", "log destination address 8.8.8.0/24", "10.0.0.1", "8.8.8.100", 1234, 80, true}, + {"cidr no match", "log destination address 8.8.8.0/24", "10.0.0.1", "8.8.9.1", 1234, 80, false}, + {"with port match", "log destination address 10.19.80.100 on port 53", "192.168.1.1", "10.19.80.100", 1234, 53, true}, + {"with port no match addr", "log destination address 10.19.80.100 on port 53", "192.168.1.1", "10.19.80.101", 1234, 53, false}, + {"with port no match port", "log destination address 10.19.80.100 on port 53", "192.168.1.1", "10.19.80.100", 1234, 80, false}, + {"ipv6 match", "log destination address 2001:4860:4860::8888", "10.0.0.1", "2001:4860:4860::8888", 1234, 80, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.rule) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + pred, err := Compile(rule.Expr) + require.NoError(t, err) + + event := createEventWithAddrs(1, syscall.IPPROTO_TCP, tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) + assert.Equal(t, tt.expected, pred(event)) + }) + } +} + +func TestEval_PortPredicate(t *testing.T) { + tests := []struct { + name string + rule string + srcPort uint16 + dstPort uint16 + expected bool + }{ + {"dst port match", "log destination port 80", 1234, 80, true}, + {"dst port no match", "log destination port 80", 1234, 443, false}, + {"src port match", "log source port 1234", 1234, 80, true}, + {"on port dst match", "log on port 80", 1234, 80, true}, + {"on port src match", "log on port 1234", 1234, 80, true}, + {"on port no match", "log on port 53", 1234, 80, false}, + {"port range", "log destination port 80,443,8080", 1234, 443, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.rule) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + pred, err := Compile(rule.Expr) + require.NoError(t, err) + + event := createEventWithAddrs(1, syscall.IPPROTO_TCP, "10.0.0.1", "8.8.8.8", tt.srcPort, tt.dstPort) + assert.Equal(t, tt.expected, pred(event)) + }) + } +} + +func TestEval_AnyPredicate(t *testing.T) { + tests := []struct 
{ + name string + rule string + event conntrack.Event + }{ + { + "any matches NEW TCP", + "log any", + createEvent(1, syscall.IPPROTO_TCP), + }, + { + "any matches UPDATE UDP", + "drop any", + createEvent(2, syscall.IPPROTO_UDP), + }, + { + "any matches DESTROY", + "log any", + createEvent(3, syscall.IPPROTO_TCP), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.rule) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + pred, err := Compile(rule.Expr) + require.NoError(t, err) + result := pred(tt.event) + assert.True(t, result, "any predicate should always match") + }) + } +} + +func TestEval_BinaryExpressions(t *testing.T) { + tests := []struct { + name string + rule string + event conntrack.Event + expected bool + }{ + { + "and both match", + "log type NEW and protocol TCP", + createEvent(1, syscall.IPPROTO_TCP), + true, + }, + { + "and first no match", + "log type UPDATE and protocol TCP", + createEvent(1, syscall.IPPROTO_TCP), + false, + }, + { + "and second no match", + "log type NEW and protocol UDP", + createEvent(1, syscall.IPPROTO_TCP), + false, + }, + { + "or first match", + "log type NEW or type UPDATE", + createEvent(1, syscall.IPPROTO_TCP), + true, + }, + { + "or second match", + "log type UPDATE or type NEW", + createEvent(1, syscall.IPPROTO_TCP), + true, + }, + { + "or both no match", + "log type UPDATE or type DESTROY", + createEvent(1, syscall.IPPROTO_TCP), + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.rule) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + pred, err := Compile(rule.Expr) + require.NoError(t, err) + + assert.Equal(t, tt.expected, pred(tt.event)) + }) + } +} + +func TestEval_UnaryExpression(t *testing.T) { + tests := []struct { + name string + rule string + event conntrack.Event + expected bool + }{ + { + "not match becomes false", + "log not type NEW", + createEvent(1, syscall.IPPROTO_TCP), + false, + }, + { + "not no-match becomes true", + "log not type UPDATE", + createEvent(1, syscall.IPPROTO_TCP), + true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.rule) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + pred, err := Compile(rule.Expr) + require.NoError(t, err) + + assert.Equal(t, tt.expected, pred(tt.event)) + }) + } +} + +func TestFilter_Evaluate(t *testing.T) { + rules := []string{ + "drop destination address 8.8.8.8", + "log protocol TCP and destination network PUBLIC", + } + + filter, err := NewFilter(rules) + require.NoError(t, err) + + tests := []struct { + name string + event conntrack.Event + matched bool + allow bool + matchedIndex int + }{ + { + "first rule denies", + createEventWithAddrs(1, syscall.IPPROTO_TCP, "10.0.0.1", "8.8.8.8", 1234, 80), + true, + false, + 0, + }, + { + "second rule allows", + createEventWithAddrs(1, syscall.IPPROTO_TCP, "10.0.0.1", "1.1.1.1", 1234, 80), + true, + true, + 1, + }, + { + "no match allows by default", + createEventWithAddrs(1, syscall.IPPROTO_UDP, "10.0.0.1", "192.168.1.1", 1234, 80), + false, + true, + -1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matched, allow, matchedIndex := filter.Evaluate(tt.event) + assert.Equal(t, tt.matched, matched) + assert.Equal(t, tt.allow, allow) + assert.Equal(t, tt.matchedIndex, matchedIndex) + }) + } +} + +func 
TestFilter_EmptyAllowsByDefault(t *testing.T) { + filter, err := NewFilter([]string{}) + require.NoError(t, err) + + event := createEventWithAddrs(1, syscall.IPPROTO_TCP, "10.0.0.1", "8.8.8.8", 1234, 80) + matched, allow, matchedIndex := filter.Evaluate(event) + assert.False(t, matched) + assert.True(t, allow) + assert.Equal(t, -1, matchedIndex) +} + +func TestFilter_FirstMatchWins(t *testing.T) { + rules := []string{ + "log protocol TCP", + "drop destination address 8.8.8.8", + } + + filter, err := NewFilter(rules) + require.NoError(t, err) + + // TCP to 8.8.8.8 should be allowed by first rule (first match wins) + event := createEventWithAddrs(1, syscall.IPPROTO_TCP, "10.0.0.1", "8.8.8.8", 1234, 80) + matched, allow, matchedIndex := filter.Evaluate(event) + assert.True(t, matched) + assert.True(t, allow) + assert.Equal(t, 0, matchedIndex) +} diff --git a/internal/filter/filter.go b/internal/filter/filter.go index 5e1befa..e736d5c 100644 --- a/internal/filter/filter.go +++ b/internal/filter/filter.go @@ -5,212 +5,72 @@ Licensed under the MIT License, see LICENSE file in the project root for details package filter import ( - "log/slog" - "slices" - "syscall" + "fmt" "github.com/ti-mo/conntrack" ) -type FilterAddresses struct { - Destinations []string - Sources []string -} - -type FilterNetworks struct { - Destinations []string - Sources []string -} - -type FilterPorts struct { - Destinations []uint - Sources []uint -} - +// Filter represents a compiled set of filter rules type Filter struct { - EventTypes []string - Protocols []string - Networks FilterNetworks - Addresses FilterAddresses - Ports FilterPorts + Rules []CompiledRule } -func (f *Filter) eventType(event conntrack.Event) bool { - if len(f.EventTypes) == 0 { - return false - } - - types := map[any]string{ - conntrack.EventNew: "NEW", - conntrack.EventUpdate: "UPDATE", - conntrack.EventDestroy: "DESTROY", - } - eventTypeStr, ok := types[event.Type] - if !ok { - return false - } - - return slices.Contains(f.EventTypes, eventTypeStr) +// CompiledRule represents a parsed and compiled filter rule +type CompiledRule struct { + Rule *Rule + Predicate PredicateFunc + RuleText string } -func (f *Filter) eventProtocol(event conntrack.Event) bool { - if len(f.Protocols) == 0 { - switch event.Flow.TupleOrig.Proto.Protocol { - case syscall.IPPROTO_TCP, syscall.IPPROTO_UDP: - return false - default: - return true - } +// NewFilter creates a new DSL-based filter from rule strings +func NewFilter(ruleStrings []string) (*Filter, error) { + filter := &Filter{ + Rules: make([]CompiledRule, 0, len(ruleStrings)), } - protocols := map[int]string{ - syscall.IPPROTO_TCP: "TCP", - syscall.IPPROTO_UDP: "UDP", - } - protocolStr, ok := protocols[int(event.Flow.TupleOrig.Proto.Protocol)] - if !ok { - return false - } - - return slices.Contains(f.Protocols, protocolStr) -} - -func (f *Filter) eventSource(event conntrack.Event) bool { - if len(f.Networks.Sources) == 0 { - return false - } - - src := event.Flow.TupleOrig.IP.SourceAddress - slog.Info("Source Address", "src", src.String()) - isLocal := src.IsLoopback() - slog.Info("Is Local", "isLocal", isLocal) - isPrivate := src.IsPrivate() - isMulticast := src.IsMulticast() - isPublic := !isLocal && !isPrivate && !isMulticast - - for _, filterSource := range f.Networks.Sources { - switch filterSource { - case "LOCAL": - if isLocal { - return true - } - case "PRIVATE": - if isPrivate { - return true - } - case "MULTICAST": - if isMulticast { - return true - } - case "PUBLIC": - if isPublic { - return true - } + 
for i, ruleStr := range ruleStrings { + parser, err := NewParser(ruleStr) + if err != nil { + return nil, fmt.Errorf("failed to initialize parser for rule %d: %w", i, err) } - } - return false -} - -func (f *Filter) eventDestination(event conntrack.Event) bool { - if len(f.Networks.Destinations) == 0 { - return false - } - - dest := event.Flow.TupleOrig.IP.DestinationAddress - isLocal := dest.IsLoopback() - isPrivate := dest.IsPrivate() - isMulticast := dest.IsMulticast() - isPublic := !isLocal && !isPrivate && !isMulticast - - for _, filterDest := range f.Networks.Destinations { - switch filterDest { - case "LOCAL": - if isLocal { - return true - } - case "PRIVATE": - if isPrivate { - return true - } - case "MULTICAST": - if isMulticast { - return true - } - case "PUBLIC": - if isPublic { - return true - } + rule, err := parser.ParseRule() + if err != nil { + return nil, fmt.Errorf("failed to parse rule %d (%s): %w", i, ruleStr, err) } - } - return false -} - -func (f *Filter) eventAddressDestination(event conntrack.Event) bool { - if len(f.Addresses.Destinations) == 0 { - return false - } - - return slices.Contains(f.Addresses.Destinations, event.Flow.TupleOrig.IP.DestinationAddress.String()) -} - -func (f *Filter) eventAddressSource(event conntrack.Event) bool { - if len(f.Addresses.Sources) == 0 { - return false - } - - return slices.Contains(f.Addresses.Sources, event.Flow.TupleOrig.IP.SourceAddress.String()) -} - -func (f *Filter) eventPortDestination(event conntrack.Event) bool { - if len(f.Ports.Destinations) == 0 { - return false - } - - return slices.Contains(f.Ports.Destinations, uint(event.Flow.TupleOrig.Proto.DestinationPort)) -} + predicate, err := Compile(rule.Expr) + if err != nil { + return nil, fmt.Errorf("failed to compile rule %d (%s): %w", i, ruleStr, err) + } -func (f *Filter) eventPortSource(event conntrack.Event) bool { - if len(f.Ports.Sources) == 0 { - return false + filter.Rules = append(filter.Rules, CompiledRule{ + Rule: rule, + Predicate: predicate, + RuleText: ruleStr, + }) } - return slices.Contains(f.Ports.Sources, uint(event.Flow.TupleOrig.Proto.SourcePort)) + return filter, nil } -func (f *Filter) Apply(event conntrack.Event) bool { - if f.eventType(event) { - return true +// Evaluate evaluates the filter against an event +// Returns: (matched bool, shouldLog bool, matchedRuleIndex int) +// If no rule matches, returns (false, true, -1) for log-by-default policy +func (f *Filter) Evaluate(event conntrack.Event) (bool, bool, int) { + if f == nil || len(f.Rules) == 0 { + // Log by default when no rules + return false, true, -1 } - if f.eventProtocol(event) { - return true - } - - if f.eventSource(event) { - return true - } - - if f.eventDestination(event) { - return true - } - - if f.eventAddressDestination(event) { - return true - } - - if f.eventAddressSource(event) { - return true - } - - if f.eventPortDestination(event) { - return true - } - - if f.eventPortSource(event) { - return true + // First-match wins + for i, compiledRule := range f.Rules { + if compiledRule.Predicate(event) { + shouldLog := compiledRule.Rule.Action == ActionLog + return true, shouldLog, i + } } - return false + // Log by default when no rule matches + return false, true, -1 } diff --git a/internal/filter/filter_test.go b/internal/filter/filter_test.go deleted file mode 100644 index d6c9664..0000000 --- a/internal/filter/filter_test.go +++ /dev/null @@ -1,246 +0,0 @@ -/* -Copyright (c) 2025 Tobias Schäfer. All rights reserved. 
-Licensed under the MIT License, see LICENSE file in the project root for details. -*/ -package filter - -import ( - "io" - "net/http" - "net/netip" - "os" - "syscall" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/ti-mo/conntrack" -) - -const geoDatabasePath = "/tmp/GeoLite2-City.mmdb" -const geoDatabaseUrl = "https://github.com/P3TERX/GeoLite.mmdb/releases/latest/download/GeoLite2-City.mmdb" - -func setup() { - setupGeoDatabase() -} - -func setupGeoDatabase() { - if _, err := os.Stat(geoDatabasePath); os.IsNotExist(err) { - resp, err := http.Get(geoDatabaseUrl) - if err != nil { - panic(err) - } - defer func() { - _ = resp.Body.Close() - }() - - out, err := os.Create(geoDatabasePath) - if err != nil { - panic(err) - } - defer func() { - _ = out.Close() - }() - - _, err = io.Copy(out, resp.Body) - if err != nil { - panic(err) - } - } -} - -func Test_Filter(t *testing.T) { - setup() - - flow := conntrack.NewFlow( - syscall.IPPROTO_TCP, - conntrack.StatusAssured, - netip.MustParseAddr("10.19.80.100"), netip.MustParseAddr("78.47.60.169"), - 4711, 443, - 60, 0, - ) - - event := conntrack.Event{ - Type: conntrack.EventNew, - Flow: &flow, - } - - f := &Filter{} - matched := f.Apply(event) - assert.False(t, matched, "no filters") - - f = &Filter{ - Protocols: []string{"TCP"}, - } - matched = f.Apply(event) - assert.True(t, matched, "protocol filter TCP") - - f = &Filter{ - EventTypes: []string{"NEW"}, - } - matched = f.Apply(event) - assert.True(t, matched, "event type filter NEW") - - f = &Filter{ - EventTypes: []string{"UPDATE"}, - } - matched = f.Apply(event) - assert.False(t, matched, "event type filter UPDATE") - - f = &Filter{ - EventTypes: []string{"DESTROY"}, - } - matched = f.Apply(event) - assert.False(t, matched, "event type filter DESTROY") - - f = &Filter{ - EventTypes: []string{"NEW", "DESTROY"}, - } - matched = f.Apply(event) - assert.True(t, matched, "event type filter NEW, DESTROY") - - f = &Filter{ - Protocols: []string{"UDP"}, - } - matched = f.Apply(event) - assert.False(t, matched, "protocol filter UDP") - - f = &Filter{ - Protocols: []string{"UDP", "TCP"}, - } - matched = f.Apply(event) - assert.True(t, matched, "protocol filter UDP, TCP") - - f = &Filter{ - Protocols: []string{"ICMP"}, - } - matched = f.Apply(event) - assert.False(t, matched, "bad protocol filter") - - f = &Filter{ - Networks: FilterNetworks{ - Destinations: []string{"PUBLIC"}, - }, - } - matched = f.Apply(event) - assert.True(t, matched, "destination network filter PUBLIC") - - f = &Filter{ - Networks: FilterNetworks{ - Destinations: []string{"PRIVATE"}, - }, - } - matched = f.Apply(event) - assert.False(t, matched, "destination network filter PRIVATE") - - f = &Filter{ - Networks: FilterNetworks{ - Destinations: []string{"LOCAL"}, - }, - } - matched = f.Apply(event) - assert.False(t, matched, "destination network filter LOCAL") - - f = &Filter{ - Networks: FilterNetworks{ - Destinations: []string{"MULTICAST"}, - }, - } - matched = f.Apply(event) - assert.False(t, matched, "source network filter MULTICAST") - - f = &Filter{ - Networks: FilterNetworks{ - Sources: []string{"PUBLIC"}, - }, - } - matched = f.Apply(event) - assert.False(t, matched, "source network filter PUBLIC") - - f = &Filter{ - Networks: FilterNetworks{ - Sources: []string{"PRIVATE"}, - }, - } - matched = f.Apply(event) - assert.True(t, matched, "source network filter PRIVATE") - - f = &Filter{ - Networks: FilterNetworks{ - Sources: []string{"LOCAL"}, - }, - } - matched = f.Apply(event) - assert.False(t, matched, 
"source network filter LOCAL") - - f = &Filter{ - Networks: FilterNetworks{ - Sources: []string{"MULTICAST"}, - }, - } - matched = f.Apply(event) - assert.False(t, matched, "source network filter MULTICAST") - - f = &Filter{ - Addresses: FilterAddresses{ - Destinations: []string{"78.47.60.169"}, - }, - } - matched = f.Apply(event) - assert.True(t, matched, "destination address filter match") - - f = &Filter{ - Addresses: FilterAddresses{ - Destinations: []string{"78.47.60.170"}, - }, - } - matched = f.Apply(event) - assert.False(t, matched, "destination address filter no match") - - f = &Filter{ - Addresses: FilterAddresses{ - Sources: []string{"10.19.80.100"}, - }, - } - matched = f.Apply(event) - assert.True(t, matched, "source address filter match") - - f = &Filter{ - Addresses: FilterAddresses{ - Sources: []string{"10.19.80.200"}, - }, - } - matched = f.Apply(event) - assert.False(t, matched, "source address filter no match") - - f = &Filter{ - Ports: FilterPorts{ - Destinations: []uint{443}, - }, - } - matched = f.Apply(event) - assert.True(t, matched, "destination port filter match") - - f = &Filter{ - Ports: FilterPorts{ - Destinations: []uint{80}, - }, - } - matched = f.Apply(event) - assert.False(t, matched, "destination port filter no match") - - f = &Filter{ - Ports: FilterPorts{ - Sources: []uint{4711}, - }, - } - matched = f.Apply(event) - assert.True(t, matched, "source port filter match") - - f = &Filter{ - Ports: FilterPorts{ - Sources: []uint{1234}, - }, - } - matched = f.Apply(event) - assert.False(t, matched, "source port filter no match") -} diff --git a/internal/filter/lexer.go b/internal/filter/lexer.go new file mode 100644 index 0000000..4f780d0 --- /dev/null +++ b/internal/filter/lexer.go @@ -0,0 +1,172 @@ +/* +Copyright (c) 2025 Tobias Schäfer. All rights reserved. +Licensed under the MIT License, see LICENSE file in the project root for details. 
+*/ +package filter + +import ( + "fmt" + "strings" + "unicode" +) + +type TokenType int + +const ( + TokenEOF TokenType = iota + TokenIdent + TokenNumber + TokenComma + TokenDash + TokenSlash + TokenColon + TokenDot + TokenLParen + TokenRParen + TokenAnd + TokenOr + TokenNot + TokenLog + TokenDrop + TokenEventType + TokenProtocol + TokenSource + TokenDestination + TokenAddress + TokenNetwork + TokenPort + TokenOn + TokenAny +) + +type Token struct { + Type TokenType + Value string + Pos int +} + +type Lexer struct { + input string + pos int + ch rune +} + +func NewLexer(input string) *Lexer { + l := &Lexer{input: input} + l.readChar() + return l +} + +func (l *Lexer) readChar() { + if l.pos >= len(l.input) { + l.ch = 0 + } else { + l.ch = rune(l.input[l.pos]) + } + l.pos++ +} + +func (l *Lexer) skipWhitespace() { + for l.ch == ' ' || l.ch == '\t' || l.ch == '\n' || l.ch == '\r' { + l.readChar() + } +} + +func (l *Lexer) readIdentifier() string { + start := l.pos - 1 + for unicode.IsLetter(l.ch) || unicode.IsDigit(l.ch) || l.ch == '_' { + l.readChar() + } + return l.input[start : l.pos-1] +} + +func (l *Lexer) readNumber() string { + start := l.pos - 1 + for unicode.IsDigit(l.ch) { + l.readChar() + } + return l.input[start : l.pos-1] +} + +func (l *Lexer) NextToken() (Token, error) { + l.skipWhitespace() + + tok := Token{Pos: l.pos - 1} + + switch l.ch { + case 0: + tok.Type = TokenEOF + case ',': + tok.Type = TokenComma + tok.Value = "," + l.readChar() + case '-': + tok.Type = TokenDash + tok.Value = "-" + l.readChar() + case '/': + tok.Type = TokenSlash + tok.Value = "/" + l.readChar() + case ':': + tok.Type = TokenColon + tok.Value = ":" + l.readChar() + case '.': + tok.Type = TokenDot + tok.Value = "." + l.readChar() + case '(': + tok.Type = TokenLParen + tok.Value = "(" + l.readChar() + case ')': + tok.Type = TokenRParen + tok.Value = ")" + l.readChar() + case '!': + tok.Type = TokenNot + tok.Value = "!" + l.readChar() + default: + if unicode.IsLetter(l.ch) { + tok.Value = l.readIdentifier() + tok.Type = l.lookupKeyword(tok.Value) + } else if unicode.IsDigit(l.ch) { + tok.Value = l.readNumber() + tok.Type = TokenNumber + } else { + return tok, fmt.Errorf("unexpected character: %c at position %d", l.ch, l.pos-1) + } + } + + return tok, nil +} + +func (l *Lexer) lookupKeyword(ident string) TokenType { + keywords := map[string]TokenType{ + "and": TokenAnd, + "or": TokenOr, + "not": TokenNot, + "log": TokenLog, + "drop": TokenDrop, + "type": TokenEventType, + "protocol": TokenProtocol, + "source": TokenSource, + "src": TokenSource, + "destination": TokenDestination, + "dst": TokenDestination, + "dest": TokenDestination, + "address": TokenAddress, + "network": TokenNetwork, + "port": TokenPort, + "on": TokenOn, + "any": TokenAny, + } + + lower := strings.ToLower(ident) + if tokType, ok := keywords[lower]; ok { + return tokType + } + return TokenIdent +} diff --git a/internal/filter/parser.go b/internal/filter/parser.go new file mode 100644 index 0000000..b6847ed --- /dev/null +++ b/internal/filter/parser.go @@ -0,0 +1,504 @@ +/* +Copyright (c) 2025 Tobias Schäfer. All rights reserved. +Licensed under the MIT License, see LICENSE file in the project root for details. 
+*/ +package filter + +import ( + "fmt" + "strconv" + "strings" +) + +type Parser struct { + lexer *Lexer + current Token + peek Token +} + +func NewParser(input string) (*Parser, error) { + p := &Parser{lexer: NewLexer(input)} + // Read two tokens to initialize current and peek + if err := p.nextToken(); err != nil { + return nil, err + } + if err := p.nextToken(); err != nil { + return nil, err + } + return p, nil +} + +func (p *Parser) nextToken() error { + p.current = p.peek + tok, err := p.lexer.NextToken() + if err != nil { + return err + } + p.peek = tok + return nil +} + +func (p *Parser) expect(tokType TokenType) error { + if p.current.Type != tokType { + return fmt.Errorf("expected token %v, got %v (%s) at position %d", + tokType, p.current.Type, p.current.Value, p.current.Pos) + } + return p.nextToken() +} + +// ParseRule parses a complete rule: action expression +func (p *Parser) ParseRule() (*Rule, error) { + rule := &Rule{} + + // Parse action + switch p.current.Type { + case TokenLog: + rule.Action = ActionLog + if err := p.nextToken(); err != nil { + return nil, err + } + case TokenDrop: + rule.Action = ActionDrop + if err := p.nextToken(); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("expected 'log' or 'drop', got '%s' at position %d", + p.current.Value, p.current.Pos) + } + + // Parse expression + expr, err := p.parseExpression() + if err != nil { + return nil, err + } + rule.Expr = expr + + // Expect EOF + if p.current.Type != TokenEOF { + return nil, fmt.Errorf("unexpected token after rule: %s at position %d", + p.current.Value, p.current.Pos) + } + + return rule, nil +} + +// parseExpression parses: orExpr +func (p *Parser) parseExpression() (ExprNode, error) { + return p.parseOrExpr() +} + +// parseOrExpr parses: andExpr { ("," | "or") andExpr } +func (p *Parser) parseOrExpr() (ExprNode, error) { + left, err := p.parseAndExpr() + if err != nil { + return nil, err + } + + for p.current.Type == TokenOr || p.current.Type == TokenComma { + if err := p.nextToken(); err != nil { + return nil, err + } + right, err := p.parseAndExpr() + if err != nil { + return nil, err + } + left = BinaryExpr{Op: OpOr, Left: left, Right: right} + } + + return left, nil +} + +// parseAndExpr parses: notExpr { "and" notExpr } +func (p *Parser) parseAndExpr() (ExprNode, error) { + left, err := p.parseNotExpr() + if err != nil { + return nil, err + } + + for p.current.Type == TokenAnd { + if err := p.nextToken(); err != nil { + return nil, err + } + right, err := p.parseNotExpr() + if err != nil { + return nil, err + } + left = BinaryExpr{Op: OpAnd, Left: left, Right: right} + } + + return left, nil +} + +// parseNotExpr parses: [ "not" | "!" 
] primary +func (p *Parser) parseNotExpr() (ExprNode, error) { + if p.current.Type == TokenNot { + if err := p.nextToken(); err != nil { + return nil, err + } + expr, err := p.parsePrimary() + if err != nil { + return nil, err + } + return UnaryExpr{Op: OpNot, Expr: expr}, nil + } + return p.parsePrimary() +} + +// parsePrimary parses: predicate | "(" expression ")" +func (p *Parser) parsePrimary() (ExprNode, error) { + if p.current.Type == TokenLParen { + if err := p.nextToken(); err != nil { + return nil, err + } + expr, err := p.parseExpression() + if err != nil { + return nil, err + } + if err := p.expect(TokenRParen); err != nil { + return nil, err + } + return expr, nil + } + return p.parsePredicate() +} + +// parsePredicate parses various predicate types +func (p *Parser) parsePredicate() (ExprNode, error) { + switch p.current.Type { + case TokenEventType: + return p.parseTypePredicate() + case TokenProtocol: + return p.parseProtocolPredicate() + case TokenSource, TokenDestination: + return p.parseDirectionalPredicate() + case TokenOn: + return p.parsePortPredicate() + case TokenAny: + // Parse "any" - matches everything + if err := p.nextToken(); err != nil { + return nil, err + } + return AnyPredicate{}, nil + default: + return nil, fmt.Errorf("expected predicate keyword, got '%s' at position %d", + p.current.Value, p.current.Pos) + } +} + +// parseTypePredicate parses: "type" IDENT_LIST +func (p *Parser) parseTypePredicate() (ExprNode, error) { + if err := p.nextToken(); err != nil { + return nil, err + } + + types, err := p.parseIdentList() + if err != nil { + return nil, err + } + + // Normalize and validate type names + validTypes := map[string]bool{ + "NEW": true, + "UPDATE": true, + "DESTROY": true, + } + normalized := make([]string, len(types)) + for i, t := range types { + normalized[i] = strings.ToUpper(t) + if !validTypes[normalized[i]] { + return nil, fmt.Errorf("invalid event type '%s' at position %d, valid types are: NEW, UPDATE, DESTROY", t, p.current.Pos) + } + } + + return TypePredicate{Types: normalized}, nil +} + +// parseProtocolPredicate parses: "protocol" IDENT_LIST +func (p *Parser) parseProtocolPredicate() (ExprNode, error) { + if err := p.nextToken(); err != nil { + return nil, err + } + + protocols, err := p.parseIdentList() + if err != nil { + return nil, err + } + + // Normalize and validate protocol names + validProtocols := map[string]bool{ + "TCP": true, + "UDP": true, + } + normalized := make([]string, len(protocols)) + for i, proto := range protocols { + normalized[i] = strings.ToUpper(proto) + if !validProtocols[normalized[i]] { + return nil, fmt.Errorf("invalid protocol '%s' at position %d, valid protocols are: TCP, UDP", proto, p.current.Pos) + } + } + + return ProtocolPredicate{Protocols: normalized}, nil +} + +// parseDirectionalPredicate handles source/destination address, network, or port +func (p *Parser) parseDirectionalPredicate() (ExprNode, error) { + direction := p.current.Value + if err := p.nextToken(); err != nil { + return nil, err + } + + switch p.current.Type { + case TokenNetwork: + return p.parseNetworkPredicate(direction) + case TokenAddress: + return p.parseAddressPredicate(direction) + case TokenPort: + return p.parsePortPredicateWithDirection(direction) + default: + return nil, fmt.Errorf("expected 'network', 'address', or 'port' after '%s', got '%s' at position %d", + direction, p.current.Value, p.current.Pos) + } +} + +// parseNetworkPredicate parses: direction "network" IDENT +func (p *Parser) 
parseNetworkPredicate(direction string) (ExprNode, error) { + if err := p.nextToken(); err != nil { + return nil, err + } + + networks, err := p.parseIdentList() + if err != nil { + return nil, err + } + + // Normalize and validate network names + validNetworks := map[string]bool{ + "LOCAL": true, + "PRIVATE": true, + "PUBLIC": true, + "MULTICAST": true, + } + normalized := make([]string, len(networks)) + for i, net := range networks { + normalized[i] = strings.ToUpper(net) + if !validNetworks[normalized[i]] { + return nil, fmt.Errorf("invalid network type '%s' at position %d, valid networks are: LOCAL, PRIVATE, PUBLIC, MULTICAST", net, p.current.Pos) + } + } + + return NetworkPredicate{Direction: direction, Networks: normalized}, nil +} + +// parseAddressPredicate parses: direction "address" (IP | CIDR) ["on" "port" PORT_SPEC] +func (p *Parser) parseAddressPredicate(direction string) (ExprNode, error) { + if err := p.nextToken(); err != nil { + return nil, err + } + + addresses, err := p.parseAddressList() + if err != nil { + return nil, err + } + + var ports []uint16 + // Check for optional "on port" + if p.current.Type == TokenOn { + if err := p.nextToken(); err != nil { + return nil, err + } + if err := p.expect(TokenPort); err != nil { + return nil, err + } + ports, err = p.parsePortSpec() + if err != nil { + return nil, err + } + } + + return AddressPredicate{Direction: direction, Addresses: addresses, Ports: ports}, nil +} + +// parsePortPredicate parses: "on" "port" PORT_SPEC (no direction) +func (p *Parser) parsePortPredicate() (ExprNode, error) { + if err := p.nextToken(); err != nil { + return nil, err + } + if err := p.expect(TokenPort); err != nil { + return nil, err + } + + ports, err := p.parsePortSpec() + if err != nil { + return nil, err + } + + // "on port" without direction means both source and destination + return PortPredicate{Direction: "both", Ports: ports}, nil +} + +// parsePortPredicateWithDirection parses: direction "port" PORT_SPEC +func (p *Parser) parsePortPredicateWithDirection(direction string) (ExprNode, error) { + if err := p.nextToken(); err != nil { + return nil, err + } + + ports, err := p.parsePortSpec() + if err != nil { + return nil, err + } + + return PortPredicate{Direction: direction, Ports: ports}, nil +} + +// parseIdentList parses: IDENT { "," IDENT } +func (p *Parser) parseIdentList() ([]string, error) { + var idents []string + + if p.current.Type != TokenIdent { + return nil, fmt.Errorf("expected identifier, got '%s' at position %d", + p.current.Value, p.current.Pos) + } + + idents = append(idents, p.current.Value) + if err := p.nextToken(); err != nil { + return nil, err + } + + for p.current.Type == TokenComma { + if err := p.nextToken(); err != nil { + return nil, err + } + if p.current.Type != TokenIdent { + return nil, fmt.Errorf("expected identifier after comma, got '%s' at position %d", + p.current.Value, p.current.Pos) + } + idents = append(idents, p.current.Value) + if err := p.nextToken(); err != nil { + return nil, err + } + } + + return idents, nil +} + +// parseAddressList parses IP addresses or CIDR ranges +func (p *Parser) parseAddressList() ([]string, error) { + var addresses []string + + addr, err := p.parseAddress() + if err != nil { + return nil, err + } + addresses = append(addresses, addr) + + for p.current.Type == TokenComma { + if err := p.nextToken(); err != nil { + return nil, err + } + addr, err := p.parseAddress() + if err != nil { + return nil, err + } + addresses = append(addresses, addr) + } + + return 
addresses, nil +} + +// parseAddress parses an IP address or CIDR range +func (p *Parser) parseAddress() (string, error) { + var parts []string + + // Parse IPv4, IPv6, or CIDR + for { + switch p.current.Type { + case TokenNumber, TokenIdent: + parts = append(parts, p.current.Value) + if err := p.nextToken(); err != nil { + return "", err + } + case TokenDot, TokenColon, TokenSlash: + parts = append(parts, p.current.Value) + if err := p.nextToken(); err != nil { + return "", err + } + default: + goto done + } + } +done: + + if len(parts) == 0 { + return "", fmt.Errorf("expected IP address at position %d", p.current.Pos) + } + + return strings.Join(parts, ""), nil +} + +// parsePortSpec parses: NUMBER | NUMBER "-" NUMBER | NUMBER { "," NUMBER } +func (p *Parser) parsePortSpec() ([]uint16, error) { + var ports []uint16 + + if p.current.Type != TokenNumber { + return nil, fmt.Errorf("expected port number, got '%s' at position %d", + p.current.Value, p.current.Pos) + } + + port, err := strconv.ParseUint(p.current.Value, 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid port number '%s' at position %d", + p.current.Value, p.current.Pos) + } + ports = append(ports, uint16(port)) + + if err := p.nextToken(); err != nil { + return nil, err + } + + // Check for range (e.g., 80-90) + if p.current.Type == TokenDash { + if err := p.nextToken(); err != nil { + return nil, err + } + if p.current.Type != TokenNumber { + return nil, fmt.Errorf("expected port number after '-', got '%s' at position %d", + p.current.Value, p.current.Pos) + } + endPort, err := strconv.ParseUint(p.current.Value, 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid port number '%s' at position %d", + p.current.Value, p.current.Pos) + } + // Expand range + for p := ports[0] + 1; p <= uint16(endPort); p++ { + ports = append(ports, p) + } + if err := p.nextToken(); err != nil { + return nil, err + } + return ports, nil + } + + // Check for comma-separated list + for p.current.Type == TokenComma { + if err := p.nextToken(); err != nil { + return nil, err + } + if p.current.Type != TokenNumber { + return nil, fmt.Errorf("expected port number after comma, got '%s' at position %d", + p.current.Value, p.current.Pos) + } + port, err := strconv.ParseUint(p.current.Value, 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid port number '%s' at position %d", + p.current.Value, p.current.Pos) + } + ports = append(ports, uint16(port)) + if err := p.nextToken(); err != nil { + return nil, err + } + } + + return ports, nil +} diff --git a/internal/filter/parser_test.go b/internal/filter/parser_test.go new file mode 100644 index 0000000..c1c1fbc --- /dev/null +++ b/internal/filter/parser_test.go @@ -0,0 +1,428 @@ +/* +Copyright (c) 2025 Tobias Schäfer. All rights reserved. +Licensed under the MIT License, see LICENSE file in the project root for details. 
+*/ +package filter + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParser_BasicActions(t *testing.T) { + tests := []struct { + name string + input string + action Action + valid bool + }{ + {"log type", "log type NEW", ActionLog, true}, + {"drop type", "drop type NEW", ActionDrop, true}, + {"LOG uppercase", "LOG type NEW", ActionLog, true}, + {"DROP uppercase", "DROP type NEW", ActionDrop, true}, + {"missing action", "type NEW", ActionLog, false}, + {"invalid action", "permit type NEW", ActionLog, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + if !tt.valid { + if err == nil { + _, err = parser.ParseRule() + } + assert.Error(t, err) + return + } + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + assert.Equal(t, tt.action, rule.Action) + }) + } +} + +func TestParser_TypePredicate(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + {"single type", "log type NEW", []string{"NEW"}}, + {"multiple types", "log type NEW,UPDATE", []string{"NEW", "UPDATE"}}, + {"all types", "log type NEW,UPDATE,DESTROY", []string{"NEW", "UPDATE", "DESTROY"}}, + {"lowercase", "log type new,update", []string{"NEW", "UPDATE"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + typePred, ok := rule.Expr.(TypePredicate) + require.True(t, ok, "expected TypePredicate") + assert.ElementsMatch(t, tt.expected, typePred.Types) + }) + } +} + +func TestParser_ProtocolPredicate(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + {"single protocol", "log protocol TCP", []string{"TCP"}}, + {"multiple protocols", "log protocol TCP,UDP", []string{"TCP", "UDP"}}, + {"lowercase", "log protocol tcp", []string{"TCP"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + protoPred, ok := rule.Expr.(ProtocolPredicate) + require.True(t, ok, "expected ProtocolPredicate") + assert.ElementsMatch(t, tt.expected, protoPred.Protocols) + }) + } +} + +func TestParser_NetworkPredicate(t *testing.T) { + tests := []struct { + name string + input string + direction string + networks []string + }{ + {"source private", "log source network PRIVATE", "source", []string{"PRIVATE"}}, + {"destination public", "log destination network PUBLIC", "destination", []string{"PUBLIC"}}, + {"dst abbreviation", "log dst network LOCAL", "dst", []string{"LOCAL"}}, + {"src abbreviation", "log src network MULTICAST", "src", []string{"MULTICAST"}}, + {"multiple networks", "log source network PRIVATE,LOCAL", "source", []string{"PRIVATE", "LOCAL"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + netPred, ok := rule.Expr.(NetworkPredicate) + require.True(t, ok, "expected NetworkPredicate") + assert.Equal(t, tt.direction, netPred.Direction) + assert.ElementsMatch(t, tt.networks, netPred.Networks) + }) + } +} + +func TestParser_AddressPredicate(t *testing.T) { + tests := []struct { + name string + input string + direction string + addresses []string + ports []uint16 + }{ + 
{"ipv4 address", "log destination address 8.8.8.8", "destination", []string{"8.8.8.8"}, nil}, + {"ipv6 address", "log source address 2001:db8::1", "source", []string{"2001:db8::1"}, nil}, + {"cidr", "log destination address 192.168.1.0/24", "destination", []string{"192.168.1.0/24"}, nil}, + {"address with port", "log destination address 10.19.80.100 on port 53", "destination", []string{"10.19.80.100"}, []uint16{53}}, + {"multiple addresses", "log source address 1.1.1.1,8.8.8.8", "source", []string{"1.1.1.1", "8.8.8.8"}, nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + addrPred, ok := rule.Expr.(AddressPredicate) + require.True(t, ok, "expected AddressPredicate") + assert.Equal(t, tt.direction, addrPred.Direction) + assert.ElementsMatch(t, tt.addresses, addrPred.Addresses) + if tt.ports != nil { + assert.ElementsMatch(t, tt.ports, addrPred.Ports) + } + }) + } +} + +func TestParser_PortPredicate(t *testing.T) { + tests := []struct { + name string + input string + direction string + ports []uint16 + }{ + {"single port", "log destination port 80", "destination", []uint16{80}}, + {"multiple ports", "log source port 80,443", "source", []uint16{80, 443}}, + {"port range", "log destination port 8000-8005", "destination", []uint16{8000, 8001, 8002, 8003, 8004, 8005}}, + {"on port both", "log on port 53", "both", []uint16{53}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + portPred, ok := rule.Expr.(PortPredicate) + require.True(t, ok, "expected PortPredicate") + assert.Equal(t, tt.direction, portPred.Direction) + assert.ElementsMatch(t, tt.ports, portPred.Ports) + }) + } +} + +func TestParser_AnyPredicate(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"log any", "log any"}, + {"drop any", "drop any"}, + {"ANY uppercase", "log ANY"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + _, ok := rule.Expr.(AnyPredicate) + assert.True(t, ok, "expected AnyPredicate") + }) + } +} + +func TestParser_BinaryExpressions(t *testing.T) { + tests := []struct { + name string + input string + op BinaryOp + }{ + {"and operator", "log type NEW and protocol TCP", OpAnd}, + {"or operator", "log type NEW or type UPDATE", OpOr}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + binExpr, ok := rule.Expr.(BinaryExpr) + require.True(t, ok, "expected BinaryExpr") + assert.Equal(t, tt.op, binExpr.Op) + }) + } +} + +func TestParser_UnaryExpression(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"not keyword", "log not type NEW"}, + {"exclamation", "log ! 
protocol TCP"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + unaryExpr, ok := rule.Expr.(UnaryExpr) + require.True(t, ok, "expected UnaryExpr") + assert.Equal(t, OpNot, unaryExpr.Op) + }) + } +} + +func TestParser_Parentheses(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"simple grouping", "log (type NEW)"}, + {"complex grouping", "log (type NEW and protocol TCP) or type UPDATE"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + require.NoError(t, err) + _, err = parser.ParseRule() + require.NoError(t, err) + }) + } +} + +func TestParser_Precedence(t *testing.T) { + // Test that AND binds tighter than OR + input := "log type NEW or type UPDATE and protocol TCP" + parser, err := NewParser(input) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + + // Should parse as: (type NEW) OR (type UPDATE AND protocol TCP) + orExpr, ok := rule.Expr.(BinaryExpr) + require.True(t, ok) + assert.Equal(t, OpOr, orExpr.Op) + + // Right side should be AND + andExpr, ok := orExpr.Right.(BinaryExpr) + require.True(t, ok) + assert.Equal(t, OpAnd, andExpr.Op) +} + +func TestParser_ComplexExamples(t *testing.T) { + tests := []struct { + name string + input string + }{ + { + "example 1", + "drop destination address 8.8.8.8", + }, + { + "example 2", + "log protocol TCP and destination network PUBLIC", + }, + { + "example 3", + "drop destination address 10.19.80.100 on port 53", + }, + { + "example 4", + "log protocol TCP,UDP", + }, + { + "complex with negation", + "log not (type DESTROY and destination network PRIVATE)", + }, + { + "multiple conditions", + "drop source network LOCAL and destination port 22,23,3389", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + require.NoError(t, err) + rule, err := parser.ParseRule() + require.NoError(t, err) + assert.NotNil(t, rule) + }) + } +} + +func TestParser_InvalidInputs(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"empty", ""}, + {"only action", "log"}, + {"missing action", "type NEW"}, + {"invalid keyword", "log typo NEW"}, + {"unclosed paren", "log (type NEW"}, + {"extra tokens", "log type NEW extra stuff"}, + {"invalid port", "log destination port abc"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + if err == nil { + _, err = parser.ParseRule() + } + assert.Error(t, err) + }) + } +} + +func TestParser_InvalidTypeValidation(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"invalid type single", "log type FOO"}, + {"invalid type in list", "log type NEW,BAR"}, + {"invalid type complex", "log (destination address 78.47.60.169/32) and type NE"}, + {"typo in type", "log type NWE"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + if err == nil { + _, err = parser.ParseRule() + } + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid event type") + }) + } +} + +func TestParser_InvalidProtocolValidation(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"invalid protocol single", "log protocol ICMP"}, + {"invalid protocol in list", "log protocol TCP,ICMP"}, + {"invalid protocol complex", "log destination address 
1.2.3.4 and protocol FOO"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + if err == nil { + _, err = parser.ParseRule() + } + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid protocol") + }) + } +} + +func TestParser_InvalidNetworkValidation(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"invalid network single", "log destination network INVALID"}, + {"invalid network in list", "log source network LOCAL,BOGUS"}, + {"invalid network complex", "log type NEW and destination network FOO"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.input) + if err == nil { + _, err = parser.ParseRule() + } + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid network type") + }) + } +} diff --git a/internal/geoip/geoip.go b/internal/geoip/geoip.go index 746c7b8..99ad1b2 100644 --- a/internal/geoip/geoip.go +++ b/internal/geoip/geoip.go @@ -5,17 +5,15 @@ Licensed under the MIT License, see LICENSE file in the project root for details package geoip import ( - "log/slog" + "fmt" "net/netip" + "strings" "github.com/oschwald/geoip2-golang/v2" ) -type Reader struct { - reader *geoip2.Reader -} - type GeoIP struct { + Reader *geoip2.Reader Database string } @@ -26,23 +24,30 @@ type Location struct { Lon float64 } -func Open(path string) (*Reader, error) { - slog.Debug("Opening GeoIP2 database", "path", path) - - reader, err := geoip2.Open(path) +func NewGeoIP(database string) (*GeoIP, error) { + reader, err := geoip2.Open(database) if err != nil { return nil, err } - return &Reader{reader: reader}, nil + metadata := reader.Metadata() + if !strings.HasSuffix(metadata.DatabaseType, "City") { + _ = reader.Close() + return nil, fmt.Errorf("invalid GeoIP2 database type: %s, expected City", metadata.DatabaseType) + } + + return &GeoIP{ + Reader: reader, + Database: database, + }, nil } -func (r *Reader) Close() error { - return r.reader.Close() +func (g *GeoIP) Close() error { + return g.Reader.Close() } -func (r *Reader) Location(ip netip.Addr) *Location { - record, err := r.reader.City(ip) +func (g *GeoIP) Location(ip netip.Addr) *Location { + record, err := g.Reader.City(ip) if err != nil { return nil } diff --git a/internal/geoip/geoip_test.go b/internal/geoip/geoip_test.go index 0e40a66..4022af7 100644 --- a/internal/geoip/geoip_test.go +++ b/internal/geoip/geoip_test.go @@ -11,15 +11,12 @@ import ( "os" "testing" + "github.com/oschwald/geoip2-golang/v2" "github.com/stretchr/testify/assert" ) const geoDatabasePath = "/tmp/GeoLite2-City.mmdb" -const geoDatabaseUrl = "https://github.com/P3TERX/GeoLite.mmdb/releases/latest/download/GeoLite2-City.mmdb" - -func setup() { - setupGeoDatabase() -} +const geoDatabaseUrl = "https://git.io/GeoLite2-City.mmdb" func setupGeoDatabase() { if _, err := os.Stat(geoDatabasePath); os.IsNotExist(err) { @@ -46,35 +43,59 @@ func setupGeoDatabase() { } } -func Test_GeoIP(t *testing.T) { - setup() +func newReturnsError_InvalidDatabase(t *testing.T) { + _, err := NewGeoIP("../../README.md") + assert.EqualError(t, err, "error opening database: invalid MaxMind DB file") +} - geo, err := Open("invalid-path.mmdb") - assert.Error(t, err, "open invalid path") - assert.Nil(t, geo, "no geoip instance") +func newReturnsInstance_ValidDatabase(t *testing.T) { + geoIP, err := NewGeoIP(geoDatabasePath) + assert.NoError(t, err) + assert.NotNil(t, geoIP) + assert.IsType(t, &GeoIP{}, geoIP) + assert.Equal(t, geoIP.Database, 
geoDatabasePath) + assert.IsType(t, geoIP.Reader, &geoip2.Reader{}) +} - geo, err = Open(geoDatabasePath) - assert.NoError(t, err, "open valid path") - assert.NotNil(t, geo, "geoip instance") - defer func() { - _ = geo.Close() - }() +func locationReturnsNil_UnresolvedIP(t *testing.T) { + geo, err := NewGeoIP(geoDatabasePath) + assert.NoError(t, err) + assert.NotNil(t, geo) - testees := map[string]string{ + for ipStr, desc := range map[string]string{ "::1": "local address", "10.19.80.12": "private address", "224.0.1.1": "multicast address", "172.66.43.195": "unresolved address", - } - - for ipStr, desc := range testees { + } { ip, _ := netip.ParseAddr(ipStr) location := geo.Location(ip) - assert.Nil(t, location, desc, "no location") + assert.Nil(t, location, desc) } +} + +func locationReturnsLocation_ResolvedIP(t *testing.T) { + geo, err := NewGeoIP(geoDatabasePath) + assert.NoError(t, err) + assert.NotNil(t, geo) ip, _ := netip.ParseAddr("63.176.75.230") location := geo.Location(ip) assert.NotNil(t, location, "resolved address") assert.IsType(t, &Location{}, location, "location type") } + +func TestGeoIP(t *testing.T) { + setupGeoDatabase() + + t.Run("New returns error for invalid database", newReturnsError_InvalidDatabase) + t.Run("New returns instance for valid database", newReturnsInstance_ValidDatabase) + t.Run("Location returns nil for unresolved IPs", locationReturnsNil_UnresolvedIP) + t.Run("Location returns location for resolved IPs", locationReturnsLocation_ResolvedIP) + + skip, ok := os.LookupEnv("KEEP_GEOIP_DB") + if ok && skip == "1" || skip == "true" { + return + } + _ = os.Remove(geoDatabasePath) +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go index a856d5d..95fe99e 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -5,79 +5,36 @@ Licensed under the MIT License, see LICENSE file in the project root for details package logger import ( - "context" "fmt" "log/slog" "os" ) type Logger struct { - Format string - Level string + Logger *slog.Logger + Level slog.Level } -const ( - LevelTrace = slog.Level(-8) +var ( + Levels = []string{"debug", "info", "warn", "error"} + level slog.Level ) -var format string -var level slog.Level - -func (l *Logger) Initialize() error { - var logLevel slog.Level - switch l.Level { - case "trace": - logLevel = LevelTrace - case "debug": - logLevel = slog.LevelDebug - case "info": - logLevel = slog.LevelInfo - case "error": - logLevel = slog.LevelError - case "": - logLevel = slog.LevelInfo - default: - return fmt.Errorf("unknown log level: %q", l.Level) +func NewLogger(levelStr string) (*Logger, error) { + err := level.UnmarshalText([]byte(levelStr)) + if err != nil { + return nil, fmt.Errorf("unknown log level: %q", levelStr) } - level = logLevel - loggerOptions := &slog.HandlerOptions{ - Level: logLevel, - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { - if a.Key == slog.LevelKey { - if a.Value.String() == "DEBUG-4" { - a.Value = slog.StringValue("TRACE") - } - } - return a - }, - } - if logLevel == LevelTrace { - loggerOptions.AddSource = true + o := &slog.HandlerOptions{Level: level} + if level == slog.LevelDebug { + o.AddSource = true } - var logger *slog.Logger - switch l.Format { - case "json": - logger = slog.New(slog.NewJSONHandler(os.Stderr, loggerOptions)) - case "text", "": - logger = slog.New(slog.NewTextHandler(os.Stderr, loggerOptions)) - default: - return fmt.Errorf("unknown log format: %q", l.Format) - } - format = l.Format - - slog.SetDefault(logger) - - return nil -} - -func 
Trace(msg string, args ...any) { - slog.Log(context.Background(), LevelTrace, msg, args...) -} - -func Format() string { - return format + return &Logger{ + Logger: slog.New(slog.NewJSONHandler(os.Stderr, o)), + Level: level, + }, nil } func Level() slog.Level { diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go index e5972a8..4658091 100644 --- a/internal/logger/logger_test.go +++ b/internal/logger/logger_test.go @@ -5,43 +5,45 @@ Licensed under the MIT License, see LICENSE file in the project root for details package logger import ( + "log/slog" + "strings" "testing" "github.com/stretchr/testify/assert" ) -func Test_Logger(t *testing.T) { - testees := []Logger{ - {Format: "json", Level: "trace"}, - {Format: "json", Level: "debug"}, - {Format: "json", Level: "info"}, - {Format: "json", Level: "error"}, - {Format: "text", Level: "trace"}, - {Format: "text", Level: "debug"}, - {Format: "text", Level: "info"}, - {Format: "text", Level: "error"}, - } +func newReturnsError_UnknownLogLevel(t *testing.T) { + _, err := NewLogger("unknown") + assert.Error(t, err) + assert.EqualError(t, err, `unknown log level: "unknown"`) +} - for _, logger := range testees { - err := logger.Initialize() - assert.NoError(t, err, "valid logger config") +func newReturnsLogger_KnownLogLevels(t *testing.T) { + for _, level := range []string{"debug", "info", "warn", "error"} { + logger, err := NewLogger(level) + assert.NoError(t, err) + assert.NotNil(t, logger) + assert.IsType(t, logger.Logger, &slog.Logger{}) + assert.Equal(t, strings.ToUpper(level), logger.Level.String()) } +} - logger := Logger{} - err := logger.Initialize() - assert.NoError(t, err, "default logger config") - - logger = Logger{Format: "xml", Level: "info"} - err = logger.Initialize() - assert.Errorf(t, err, "unknown log format: %q", "xml") - - logger = Logger{Format: "json", Level: "panic"} - err = logger.Initialize() - assert.Errorf(t, err, "unknown log level: %q", "error") +func levelReturnsCorrectLevel(t *testing.T) { + for str, level := range map[string]slog.Level{ + "debug": slog.LevelDebug, + "info": slog.LevelInfo, + "warn": slog.LevelWarn, + "error": slog.LevelError, + } { + _, err := NewLogger(str) + assert.NoError(t, err) + + assert.Equal(t, level, Level()) + } +} - logger = Logger{Format: "json", Level: "debug"} - err = logger.Initialize() - assert.NoError(t, err, "valid logger config") - assert.Equal(t, "json", logger.Format, "logger format set by Initialize args") - assert.Equal(t, "debug", logger.Level, "logger level set by Initialize args") +func TestLogger(t *testing.T) { + t.Run("New returns error for unknown log level", newReturnsError_UnknownLogLevel) + t.Run("New returns logger for known log levels", newReturnsLogger_KnownLogLevels) + t.Run("Level returns correct level", levelReturnsCorrectLevel) } diff --git a/internal/record/record.go b/internal/record/record.go index 3c3c8fd..b64f3da 100644 --- a/internal/record/record.go +++ b/internal/record/record.go @@ -11,11 +11,10 @@ import ( "github.com/ti-mo/conntrack" "github.com/tschaefer/conntrackd/internal/geoip" - "github.com/tschaefer/conntrackd/internal/logger" ) -func Record(event conntrack.Event, geo *geoip.Reader, sink *slog.Logger) { - logger.Trace("Conntrack Event", "data", event) +func Record(event conntrack.Event, geo *geoip.GeoIP, logger *slog.Logger) { + slog.Debug("Conntrack Event", "data", event) protocols := map[int]string{ syscall.IPPROTO_TCP: "TCP", @@ -80,5 +79,5 @@ func Record(event conntrack.Event, geo *geoip.Reader, sink *slog.Logger) 
{ event.Flow.TupleOrig.Proto.DestinationPort, ) - sink.Info(msg, append(established, location...)...) + logger.Info(msg, append(established, location...)...) } diff --git a/internal/record/record_test.go b/internal/record/record_test.go index 3308229..ef7aebc 100644 --- a/internal/record/record_test.go +++ b/internal/record/record_test.go @@ -23,7 +23,7 @@ import ( ) const geoDatabasePath = "/tmp/GeoLite2-City.mmdb" -const geoDatabaseUrl = "https://github.com/P3TERX/GeoLite.mmdb/releases/latest/download/GeoLite2-City.mmdb" +const geoDatabaseUrl = "https://git.io/GeoLite2-City.mmdb" var log bytes.Buffer @@ -65,9 +65,8 @@ func setupGeoDatabase() { } } -func Test_Record(t *testing.T) { +func recordLogsBasicData(t *testing.T) { logger := setupLogger() - setupGeoDatabase() flow := conntrack.NewFlow( syscall.IPPROTO_TCP, @@ -82,7 +81,7 @@ func Test_Record(t *testing.T) { Flow: &flow, } - var geo *geoip.Reader + var geo *geoip.GeoIP Record(event, geo, logger) var result map[string]any err := json.Unmarshal(log.Bytes(), &result) @@ -91,21 +90,56 @@ func Test_Record(t *testing.T) { wanted := []string{"level", "time", "type", "flow", "prot", "src", "dst", "sport", "dport"} got := slices.Sorted(maps.Keys(result)) - assert.ElementsMatch(t, wanted, got, "record keys without geoip") + assert.ElementsMatch(t, wanted, got, "record basic keys") + + log.Reset() +} + +func recordLogsWithGeoIPData(t *testing.T) { + logger := setupLogger() + + flow := conntrack.NewFlow( + syscall.IPPROTO_TCP, + conntrack.StatusAssured, + netip.MustParseAddr("10.19.80.100"), netip.MustParseAddr("78.47.60.169"), + 4711, 443, + 60, 0, + ) + + event := conntrack.Event{ + Type: conntrack.EventNew, + Flow: &flow, + } - geo, err = geoip.Open(geoDatabasePath) + geo, err := geoip.NewGeoIP(geoDatabasePath) assert.NoError(t, err) defer func() { _ = geo.Close() }() - log.Reset() Record(event, geo, logger) + var result map[string]any err = json.Unmarshal(log.Bytes(), &result) assert.NoError(t, err) - wanted = append(wanted, []string{"city", "country", "lat", "lon"}...) 
- got = slices.Sorted(maps.Keys(result)) + wanted := []string{"level", "time", + "type", "flow", "prot", "src", "dst", "sport", "dport", + "city", "country", "lat", "lon"} + got := slices.Sorted(maps.Keys(result)) assert.ElementsMatch(t, wanted, got, "record keys with geoip") + log.Reset() +} + +func TestRecord(t *testing.T) { + setupGeoDatabase() + + t.Run("Record logs basic data", recordLogsBasicData) + t.Run("Record logs with GeoIP data", recordLogsWithGeoIPData) + + skip, ok := os.LookupEnv("KEEP_GEOIP_DB") + if ok && skip == "1" || skip == "true" { + return + } + _ = os.Remove(geoDatabasePath) } diff --git a/internal/service/service.go b/internal/service/service.go index 38faaa6..3ee43c4 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -7,7 +7,6 @@ package service import ( "context" "log/slog" - "os/signal" "syscall" "github.com/mdlayher/netlink" @@ -23,39 +22,75 @@ import ( ) type Service struct { - Filter filter.Filter - Logger logger.Logger - GeoIP geoip.GeoIP - Sink sink.Sink + Filter *filter.Filter + GeoIP *geoip.GeoIP + Sink *sink.Sink + Logger *slog.Logger } -func (s *Service) handler(geo *geoip.Reader, sink *slog.Logger) error { - con, err := conntrack.Dial(nil) +func NewService(logger *logger.Logger, geoip *geoip.GeoIP, filter *filter.Filter, sink *sink.Sink) (*Service, error) { + slog.SetDefault(logger.Logger) + + return &Service{ + Filter: filter, + GeoIP: geoip, + Sink: sink, + Logger: logger.Logger, + }, nil +} + +func (s *Service) Run(ctx context.Context) bool { + slog.Info("Starting conntrack listener.", + "release", version.Release(), "commit", version.Commit(), + ) + + con, err := s.setupConntrack() if err != nil { - slog.Error("Failed to dial conntrack.", "error", err) - return err + return false } + defer func() { + _ = con.Close() + }() - evCh := make(chan conntrack.Event, 1024) - errCh, err := con.Listen(evCh, 4, netfilter.GroupsCT) + evCh, errCh, err := s.startEventListener(con) if err != nil { _ = con.Close() - slog.Error("Failed to listen to conntrack events.", "error", err) - return err + return false + } + + g := s.startEventProcessor(ctx, evCh) + + return s.handleShutdown(ctx, con, g, errCh) +} + +func (s *Service) setupConntrack() (*conntrack.Conn, error) { + con, err := conntrack.Dial(nil) + if err != nil { + slog.Error("Failed to dial conntrack.", "error", err) + return nil, err } if err := con.SetOption(netlink.ListenAllNSID|netlink.NoENOBUFS, true); err != nil { _ = con.Close() slog.Error("Failed to set conntrack listen options.", "error", err) - return err + return nil, err } - defer func() { - _ = con.Close() - }() - ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer stop() + return con, nil +} +func (s *Service) startEventListener(con *conntrack.Conn) (chan conntrack.Event, chan error, error) { + evCh := make(chan conntrack.Event, 1024) + errCh, err := con.Listen(evCh, 4, netfilter.GroupsCT) + if err != nil { + slog.Error("Failed to listen to conntrack events.", "error", err) + return nil, nil, err + } + + return evCh, errCh, nil +} + +func (s *Service) startEventProcessor(ctx context.Context, evCh chan conntrack.Event) *errgroup.Group { var g errgroup.Group g.Go(func() error { for { @@ -66,73 +101,53 @@ func (s *Service) handler(geo *geoip.Reader, sink *slog.Logger) error { if !ok { return nil } - go func() { - if !s.Filter.Apply(event) { - record.Record(event, geo, sink) - } - }() + go s.processEvent(event) } } }) + return &g +} + +func (s *Service) processEvent(event 
conntrack.Event) { + // Only process TCP and UDP events, ignore all other protocols (ICMP, etc.) + protocol := event.Flow.TupleOrig.Proto.Protocol + if protocol != syscall.IPPROTO_TCP && protocol != syscall.IPPROTO_UDP { + return + } + shouldRecord := true + if s.Filter != nil { + _, shouldLog, _ := s.Filter.Evaluate(event) + shouldRecord = shouldLog + } + + if shouldRecord { + record.Record(event, s.GeoIP, s.Sink.Logger) + } +} + +func (s *Service) handleShutdown(ctx context.Context, con *conntrack.Conn, g *errgroup.Group, errCh chan error) bool { select { case err := <-errCh: if err != nil { slog.Error("Conntrack listener error.", "error", err) - stop() _ = con.Close() if gErr := g.Wait(); gErr != nil { slog.Error("Event loop returned error during shutdown.", "error", gErr) } - return err + return false } - stop() _ = con.Close() if gErr := g.Wait(); gErr != nil { slog.Error("Event loop returned error during shutdown.", "error", gErr) } - return nil + return true case <-ctx.Done(): slog.Info("Shutting down conntrack listener.") _ = con.Close() if gErr := g.Wait(); gErr != nil { slog.Error("Event loop returned error during shutdown.", "error", gErr) } - return nil + return true } } - -func (s *Service) Run() error { - if err := s.Logger.Initialize(); err != nil { - slog.Error("Failed to initialize logger.", "error", err) - return err - } - - slog.Debug("Running Service.", "data", s) - - slog.Info("Starting conntrack listener.", - "release", version.Release(), "commit", version.Commit(), - "filter", s.Filter, - "geoip", s.GeoIP.Database, - ) - - sink, err := s.Sink.Initialize() - if err != nil { - slog.Error("Failed to initialize sink.", "error", err) - return err - } - - var geo *geoip.Reader - if s.GeoIP.Database != "" { - geo, err = geoip.Open(s.GeoIP.Database) - if err != nil { - slog.Error("Failed to open geoip database.", "error", err) - return err - } - defer func() { - _ = geo.Close() - }() - } - - return s.handler(geo, sink) -} diff --git a/internal/sink/journal.go b/internal/sink/journal.go index 474db64..cb10a3d 100644 --- a/internal/sink/journal.go +++ b/internal/sink/journal.go @@ -15,8 +15,6 @@ type Journal struct { } func (j *Journal) TargetJournal(options *slog.HandlerOptions) (slog.Handler, error) { - slog.Debug("Initializing systemd journal sink.") - slogjournal.FieldPrefix = "EVENT" o := &slogjournal.Option{ Level: options.Level, diff --git a/internal/sink/loki.go b/internal/sink/loki.go index 2939763..81d0860 100644 --- a/internal/sink/loki.go +++ b/internal/sink/loki.go @@ -12,7 +12,8 @@ import ( "os" "strings" - klog "github.com/go-kit/log" + kitlog "github.com/go-kit/log" + kitlevel "github.com/go-kit/log/level" "github.com/grafana/loki-client-go/loki" "github.com/grafana/loki-client-go/pkg/labelutil" "github.com/prometheus/common/model" @@ -20,100 +21,104 @@ import ( "github.com/tschaefer/conntrackd/internal/logger" ) +const ( + readyPath = "/ready" + pushPath = "/loki/api/v1/push" +) + type Loki struct { Enable bool Address string Labels []string } -const ( - readyPath = "/ready" - pushPath = "/loki/api/v1/push" -) +var LokiProtocols = []string{"http", "https"} -func (l *Loki) isReady() error { - uri, err := url.Parse(l.Address) - if err != nil { - return err - } - uri.Path = uri.Path + readyPath - - response, err := http.Get(uri.String()) +func (l *Loki) TargetLoki(options *slog.HandlerOptions) (slog.Handler, error) { + url, err := url.Parse(l.Address) if err != nil { - return err + return nil, err } - defer func() { - _ = response.Body.Close() - }() - if 
response.StatusCode != http.StatusOK { - return errors.New(response.Status) + if err := l.isReady(*url); err != nil { + return nil, err } - return nil -} - -func (l *Loki) TargetLoki(options *slog.HandlerOptions) (slog.Handler, error) { - slog.Debug("Initializing Grafana Loki sink.", "data", l) - - if err := l.isReady(); err != nil { + url.Path = url.Path + pushPath + config, err := loki.NewDefaultConfig(url.String()) + if err != nil { return nil, err } - uri, err := url.Parse(l.Address) + hostname, err := os.Hostname() if err != nil { - return nil, err + hostname = "unknown" } - uri.Path = uri.Path + pushPath + config.ExternalLabels = l.setLabels(hostname) - config, err := loki.NewDefaultConfig(uri.String()) + klogger := l.createLogger() + client, err := loki.NewWithLogger(config, klogger) if err != nil { return nil, err } - hostname, err := os.Hostname() + + o := &slogloki.Option{ + Client: client, + Level: options.Level, + } + return o.NewLokiHandler(), nil +} + +func (l *Loki) isReady(url url.URL) error { + url.Path = url.Path + readyPath + + response, err := http.Get(url.String()) if err != nil { - hostname = "unknown" + return err + } + defer func() { + _ = response.Body.Close() + }() + + if response.StatusCode != http.StatusOK { + return errors.New(response.Status) } - config.ExternalLabels = labelutil.LabelSet{ + return nil +} + +func (l *Loki) setLabels(hostname string) labelutil.LabelSet { + labels := labelutil.LabelSet{ LabelSet: model.LabelSet{ model.LabelName("service_name"): model.LabelValue("conntrackd"), model.LabelName("host"): model.LabelValue(hostname), }, } - if len(l.Labels) > 0 { - for _, label := range l.Labels { - if !strings.Contains(label, "=") { - continue - } - parts := strings.SplitN(label, "=", 2) - key := parts[0] - value := parts[1] - config.ExternalLabels.LabelSet[model.LabelName(key)] = model.LabelValue(value) - } + if len(l.Labels) == 0 { + return labels } - sw := klog.NewSyncWriter(os.Stderr) - var klogger klog.Logger - switch logger.Format() { - case "json": - klogger = klog.NewJSONLogger(sw) - case "text": - fallthrough - default: - klogger = klog.NewLogfmtLogger(sw) + for _, label := range l.Labels { + if !strings.Contains(label, "=") { + continue + } + parts := strings.SplitN(label, "=", 2) + key := parts[0] + value := parts[1] + labels.LabelSet[model.LabelName(key)] = model.LabelValue(value) } - klogger = klog.With(klogger, "time", klog.DefaultTimestamp, "sink", "loki") - client, err := loki.NewWithLogger(config, klogger) - if err != nil { - return nil, err - } + return labels +} - o := &slogloki.Option{ - Client: client, - Level: options.Level, - } - return o.NewLokiHandler(), nil +func (l *Loki) createLogger() kitlog.Logger { + level := logger.Level().String() + klevel := kitlevel.ParseDefault(level, kitlevel.InfoValue()) + + klogger := kitlog.NewJSONLogger(kitlog.NewSyncWriter(os.Stderr)) + klogger = kitlevel.NewFilter(klogger, kitlevel.Allow(klevel)) + klogger = kitlog.With(klogger, "time", kitlog.DefaultTimestamp, "sink", "loki") + + return klogger } diff --git a/internal/sink/sink.go b/internal/sink/sink.go index 66c2efb..85190c5 100644 --- a/internal/sink/sink.go +++ b/internal/sink/sink.go @@ -6,12 +6,22 @@ package sink import ( "errors" + "fmt" "log/slog" + "os" slogmulti "github.com/samber/slog-multi" ) +const ( + ExitOnWarningEnv string = "CONNTRACKD_SINK_EXIT_ON_WARNING" +) + type Sink struct { + Logger *slog.Logger +} + +type Config struct { Journal Journal Syslog Syslog Loki Loki @@ -20,53 +30,48 @@ type Sink struct { type SinkTarget 
func(*slog.HandlerOptions) (slog.Handler, error) -func (s *Sink) Initialize() (*slog.Logger, error) { - slog.Debug("Initializing sink targets.", "data", s) - +func NewSink(config *Config) (*Sink, error) { options := &slog.HandlerOptions{ Level: slog.LevelInfo, } - var handlers []slog.Handler - if s.Journal.Enable { - handler, err := s.Journal.TargetJournal(options) - if err != nil { - slog.Warn("Failed to initialize journal sink", "error", err) - } else { - handlers = append(handlers, handler) - } + exitOnWarning := false + envExitOnWarning, ok := os.LookupEnv(ExitOnWarningEnv) + if ok && envExitOnWarning == "1" || envExitOnWarning == "true" { + exitOnWarning = true } - if s.Syslog.Enable { - handler, err := s.Syslog.TargetSyslog(options) - if err != nil { - slog.Warn("Failed to initialize syslog sink", "error", err) - } else { - handlers = append(handlers, handler) - } - } + var handlers []slog.Handler - if s.Loki.Enable { - handler, err := s.Loki.TargetLoki(options) - if err != nil { - slog.Warn("Failed to initialize loki sink", "error", err) - } else { - handlers = append(handlers, handler) - } + targets := []struct { + name string + enabled bool + init SinkTarget + }{ + {"journal", config.Journal.Enable, config.Journal.TargetJournal}, + {"syslog", config.Syslog.Enable, config.Syslog.TargetSyslog}, + {"loki", config.Loki.Enable, config.Loki.TargetLoki}, + {"stream", config.Stream.Enable, config.Stream.TargetStream}, } - if s.Stream.Enable { - handler, err := s.Stream.TargetStream(options) + for _, t := range targets { + if !t.enabled { + continue + } + handler, err := t.init(options) if err != nil { - slog.Warn("Failed to initialize stream sink", "error", err) - } else { - handlers = append(handlers, handler) + fmt.Fprintf(os.Stderr, "Warning: Failed to initialize sink %q: %v\n", t.name, err) + if exitOnWarning { + os.Exit(1) + } + continue } + handlers = append(handlers, handler) } if len(handlers) == 0 { return nil, errors.New("no target sink available") } - return slog.New(slogmulti.Fanout(handlers...)), nil + return &Sink{Logger: slog.New(slogmulti.Fanout(handlers...))}, nil } diff --git a/internal/sink/sink_test.go b/internal/sink/sink_test.go new file mode 100644 index 0000000..047247a --- /dev/null +++ b/internal/sink/sink_test.go @@ -0,0 +1,76 @@ +/* +Copyright (c) 2025 Tobias Schäfer. All rights reserved. +Licensed under the MIT License, see LICENSE file in the project root for details. 
+*/ +package sink + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func capture(f func()) string { + originalStderr := os.Stderr + + r, w, _ := os.Pipe() + os.Stderr = w + + f() + + _ = w.Close() + os.Stderr = originalStderr + + var buf = make([]byte, 5096) + n, _ := r.Read(buf) + return string(buf[:n]) +} + +func newReturnsError_NoTargetsEnabled(t *testing.T) { + config := &Config{ + Journal: Journal{Enable: false}, + Syslog: Syslog{Enable: false}, + Loki: Loki{Enable: false}, + Stream: Stream{Enable: false}, + } + + sink, err := NewSink(config) + + assert.Nil(t, sink) + assert.NotNil(t, err) + assert.EqualError(t, err, "no target sink available") +} + +func newReturnsSink_TargetsEnabled(t *testing.T) { + config := &Config{ + Journal: Journal{Enable: false}, + Syslog: Syslog{Enable: false}, + Loki: Loki{Enable: false}, + Stream: Stream{Enable: true, Writer: "discard"}, + } + + sink, err := NewSink(config) + assert.NotNil(t, sink) + assert.Nil(t, err) + assert.IsType(t, &Sink{}, sink) +} + +func newPrintsWarning_TargetInitFails(t *testing.T) { + config := &Config{ + Journal: Journal{Enable: false}, + Syslog: Syslog{Enable: false}, + Loki: Loki{Enable: true, Address: "http://invalid-address"}, + Stream: Stream{Enable: true, Writer: "discard"}, + } + warning := capture(func() { + _, _ = NewSink(config) + }) + assert.Contains(t, warning, "Warning: Failed to initialize sink \"loki\"") +} + +func TestSink(t *testing.T) { + t.Run("NewSink returns error if no targets are enabled", newReturnsError_NoTargetsEnabled) + t.Run("NewSink returns sink if targets enabled", newReturnsSink_TargetsEnabled) + t.Run("NewSink prints warning if target init fails", newPrintsWarning_TargetInitFails) +} diff --git a/internal/sink/stream.go b/internal/sink/stream.go index b953937..155a1ab 100644 --- a/internal/sink/stream.go +++ b/internal/sink/stream.go @@ -16,16 +16,17 @@ type Stream struct { Writer string } +var StreamWriters = []string{"stdout", "stderr", "discard"} + func (s *Stream) TargetStream(options *slog.HandlerOptions) (slog.Handler, error) { - slog.Debug("Initializing stream sink.") + writer := map[string]io.Writer{ + "stdout": os.Stdout, + "stderr": os.Stderr, + "discard": io.Discard, + } - switch s.Writer { - case "stdout": - return slog.NewJSONHandler(os.Stdout, options), nil - case "stderr": - return slog.NewJSONHandler(os.Stderr, options), nil - case "discard": - return slog.NewJSONHandler(io.Discard, options), nil + if w, ok := writer[s.Writer]; ok { + return slog.NewJSONHandler(w, options), nil } return nil, fmt.Errorf("invalid stream writer specified: %q", s.Writer) diff --git a/internal/sink/syslog.go b/internal/sink/syslog.go index d98bc10..267c103 100644 --- a/internal/sink/syslog.go +++ b/internal/sink/syslog.go @@ -18,18 +18,18 @@ type Syslog struct { Address string } -func (s *Syslog) TargetSyslog(options *slog.HandlerOptions) (slog.Handler, error) { - slog.Debug("Initializing syslog sink", "data", s) +var SyslogProtocols = []string{"udp", "tcp", "unix", "unixgram", "unixpacket"} - uri, err := url.Parse(s.Address) +func (s *Syslog) TargetSyslog(options *slog.HandlerOptions) (slog.Handler, error) { + url, err := url.Parse(s.Address) if err != nil { return nil, err } - network := uri.Scheme - address := uri.Host + network := url.Scheme + address := url.Host if strings.HasPrefix(network, "unix") { - address = uri.Path + address = url.Path } writer, err := net.Dial(network, address)
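Taken together, the refactored constructors above compose roughly as sketched below. This is a minimal wiring sketch, not code from this change set: the `main` entry point, flag handling, and signal setup are assumptions (in the project this wiring sits behind the `internal` boundary, e.g. in a cmd package), and the database path and rule strings are example values taken from the tests in this diff. Only `NewLogger`, `NewGeoIP`, `NewFilter`, `NewSink`, `NewService`, and `Run` are APIs introduced by the change.

```go
package main // hypothetical entry point; real wiring lives inside the conntrackd module

import (
	"context"
	"os"
	"os/signal"
	"syscall"

	"github.com/tschaefer/conntrackd/internal/filter"
	"github.com/tschaefer/conntrackd/internal/geoip"
	"github.com/tschaefer/conntrackd/internal/logger"
	"github.com/tschaefer/conntrackd/internal/service"
	"github.com/tschaefer/conntrackd/internal/sink"
)

func main() {
	// Structured JSON logger; the level string is validated by NewLogger.
	log, err := logger.NewLogger("info")
	if err != nil {
		os.Exit(1)
	}

	// Example database path; NewGeoIP rejects non-City GeoIP2 databases.
	geo, err := geoip.NewGeoIP("/tmp/GeoLite2-City.mmdb")
	if err != nil {
		os.Exit(1)
	}
	defer func() { _ = geo.Close() }()

	// DSL rules, evaluated first-match-wins with log-by-default.
	flt, err := filter.NewFilter([]string{
		"drop destination address 8.8.8.8",
		"log protocol TCP and destination network PUBLIC",
		"drop any",
	})
	if err != nil {
		os.Exit(1)
	}

	// Single stream sink writing JSON records to stdout.
	snk, err := sink.NewSink(&sink.Config{
		Stream: sink.Stream{Enable: true, Writer: "stdout"},
	})
	if err != nil {
		os.Exit(1)
	}

	svc, err := service.NewService(log, geo, flt, snk)
	if err != nil {
		os.Exit(1)
	}

	// Run blocks until the context is cancelled and reports success as a bool.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	if !svc.Run(ctx) {
		os.Exit(1)
	}
}
```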