Skip to content

[client] Use native facilities to add routes on BSD and Windows #3862

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3),
RunE: tracePacket,
Expand Down
39 changes: 15 additions & 24 deletions client/firewall/iptables/manager_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package iptables

import (
"fmt"
"net"
"net/netip"
"testing"
"time"

Expand All @@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
}
},
}
Expand Down Expand Up @@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {

var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{
IsRange: true,
Values: []uint16{8043, 8046},
}
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")

for _, r := range rule2 {
Expand All @@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {

t.Run("reset check", func(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")

err = manager.Close(nil)
Expand All @@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
}
},
}
Expand All @@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {

var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{
Values: []uint16{443},
}
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
Expand Down Expand Up @@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
IP: netip.MustParseAddr("10.20.0.1"),
Network: netip.MustParsePrefix("10.20.0.0/24"),
}
},
}
Expand All @@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {

require.NoError(t, err)

ip := net.ParseIP("10.20.0.100")
ip := netip.MustParseAddr("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")

require.NoError(t, err, "failed to add rule")
}
Expand Down
31 changes: 11 additions & 20 deletions client/firewall/nftables/manager_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package nftables
import (
"bytes"
"fmt"
"net"
"net/netip"
"os/exec"
"testing"
Expand All @@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
}
},
}
Expand Down Expand Up @@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second)
}()

ip := net.ParseIP("100.96.0.1")
ip := netip.MustParseAddr("100.96.0.1").Unmap()

testClient := &nftables.Conn{}

rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule")

err = manager.Flush()
Expand Down Expand Up @@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
}
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)

ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{
&expr.Payload{
DestRegister: 1,
Expand All @@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: add.AsSlice(),
Data: ip.AsSlice(),
},
&expr.Payload{
DestRegister: 1,
Expand Down Expand Up @@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
IP: netip.MustParseAddr("100.96.0.1"),
Network: netip.MustParsePrefix("100.96.0.0/16"),
}
},
}
Expand All @@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second)
}()

ip := net.ParseIP("10.20.0.100")
ip := netip.MustParseAddr("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")

if i%100 == 0 {
Expand Down Expand Up @@ -282,8 +273,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr)
})

ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule")

_, err = manager.AddRouteFiltering(
Expand Down
12 changes: 5 additions & 7 deletions client/firewall/uspfilter/forwarder/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type Forwarder struct {
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip net.IP
ip tcpip.Address
netstack bool
}

Expand Down Expand Up @@ -71,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to create NIC: %v", err)
}

ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
PrefixLen: ones,
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
PrefixLen: iface.Address().Network.Bits(),
},
}

Expand Down Expand Up @@ -116,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: iface.Address().IP,
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
}

receiveWindow := defaultReceiveWindow
Expand Down Expand Up @@ -167,7 +166,7 @@ func (f *Forwarder) Stop() {
}

func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) {
if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1)
}
return addr.AsSlice()
Expand All @@ -179,7 +178,6 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
}

func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {

if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
Expand Down
48 changes: 28 additions & 20 deletions client/firewall/uspfilter/localip.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,26 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
}

func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
if ipv4 := ip.To4(); ipv4 != nil {
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if !ip.Is4() {
return
}
ipv4 := ip.AsSlice()

if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])

if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}

index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit

ipStr := ipv4.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
}
if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
}
}

Expand All @@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
}

func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil
}

func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
Expand All @@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue
}

if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}

if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
Expand All @@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}()

var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[string]struct{})
var ipv4Addresses []string
ipv4Set := make(map[netip.Addr]struct{})
var ipv4Addresses []netip.Addr

// 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{}
Expand Down
Loading
Loading