diff --git a/conn/bind_std.go b/conn/bind_std.go index f5c88160e..39145d6a4 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -46,9 +46,11 @@ type StdNetBind struct { blackhole4 bool blackhole6 bool + + extraFns []ControlFn } -func NewStdNetBind() Bind { +func NewStdNetBind(fns []ControlFn) Bind { return &StdNetBind{ udpAddrPool: sync.Pool{ New: func() any { @@ -70,6 +72,8 @@ func NewStdNetBind() Bind { return &msgs }, }, + + extraFns: fns, } } @@ -119,8 +123,8 @@ func (e *StdNetEndpoint) DstToString() string { return e.AddrPort.String() } -func listenNet(network string, port int) (*net.UDPConn, int, error) { - conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) +func listenNet(network string, port int, fns []ControlFn) (*net.UDPConn, int, error) { + conn, err := listenConfig(fns).ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { return nil, 0, err } @@ -156,13 +160,13 @@ again: var v4pc *ipv4.PacketConn var v6pc *ipv6.PacketConn - v4conn, port, err = listenNet("udp4", port) + v4conn, port, err = listenNet("udp4", port, s.extraFns) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } // Listen on the same port as we're using for ipv4. - v6conn, port, err = listenNet("udp6", port) + v6conn, port, err = listenNet("udp6", port, s.extraFns) if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { v4conn.Close() tries++ @@ -338,7 +342,7 @@ func (e ErrUDPGSODisabled) Unwrap() error { return e.RetryErr } -func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { +func (s *StdNetBind) Send(bufs [][]byte, services []Service, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 diff --git a/conn/bind_windows.go b/conn/bind_windows.go index a3b846067..620200bf1 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -81,7 +81,7 @@ func NewDefaultBind() Bind { return NewWinRingBind() } func NewWinRingBind() Bind { if !winrio.Initialize() { - return NewStdNetBind() + return NewStdNetBind([]ControlFn{}) } return new(WinRingBind) } @@ -486,7 +486,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { +func (bind *WinRingBind) Send(bufs [][]byte, services []Service, endpoint Endpoint) error { nend, ok := endpoint.(*WinRingEndpoint) if !ok { return ErrWrongEndpointType diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 46e20e68c..d81591fe5 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -107,7 +107,7 @@ func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { } } -func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error { +func (c *ChannelBind) Send(bufs [][]byte, services []conn.Service, ep conn.Endpoint) error { for _, b := range bufs { select { case <-c.closeSignal: diff --git a/conn/conn.go b/conn/conn.go index 1304657e5..263d8250e 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -47,7 +47,7 @@ type Bind interface { // Send writes one or more packets in bufs to address ep. The length of // bufs must not exceed BatchSize(). - Send(bufs [][]byte, ep Endpoint) error + Send(bufs [][]byte, services []Service, ep Endpoint) error // ParseEndpoint creates a new endpoint from a string. ParseEndpoint(s string) (Endpoint, error) diff --git a/conn/controlfns.go b/conn/controlfns.go index 27421bd26..61b7233f4 100644 --- a/conn/controlfns.go +++ b/conn/controlfns.go @@ -20,16 +20,16 @@ const socketBufferSize = 7 << 20 // controlFn is the callback function signature from net.ListenConfig.Control. // It is used to apply platform specific configuration to the socket prior to // bind. -type controlFn func(network, address string, c syscall.RawConn) error +type ControlFn func(network, address string, c syscall.RawConn) error // controlFns is a list of functions that are called from the listen config // that can apply socket options. -var controlFns = []controlFn{} +var controlFns = []ControlFn{} // listenConfig returns a net.ListenConfig that applies the controlFns to the // socket prior to bind. This is used to apply socket buffer sizing and packet // information OOB configuration for sticky sockets. -func listenConfig() *net.ListenConfig { +func listenConfig(extraFns []ControlFn) *net.ListenConfig { return &net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { for _, fn := range controlFns { @@ -37,6 +37,12 @@ func listenConfig() *net.ListenConfig { return err } } + + for _, fn := range extraFns { + if err := fn(network, address, c); err != nil { + return err + } + } return nil }, } diff --git a/conn/default.go b/conn/default.go index 2ce157956..6fdcabd04 100644 --- a/conn/default.go +++ b/conn/default.go @@ -7,4 +7,4 @@ package conn -func NewDefaultBind() Bind { return NewStdNetBind() } +func NewDefaultBind() Bind { return NewStdNetBind(nil) } diff --git a/conn/service.go b/conn/service.go new file mode 100644 index 000000000..9f26d8152 --- /dev/null +++ b/conn/service.go @@ -0,0 +1,31 @@ +package conn + +// Service pass inner packet info to outer bind +type Service interface { + ID() uint64 +} + +// ServiceFn process inner packet and return service info and drop flag +type ServiceFn func(buff []byte) (service Service, shouldDrop bool) + +var serviceFns []ServiceFn + +// RegisterServiceFn register service function to identify packet +func RegisterServiceFn(fn ServiceFn) { + serviceFns = append(serviceFns, fn) +} + +// ExecuteServiceFns to process packet data +func ExecuteServiceFns(buff []byte) (service Service, shouldDrop bool) { + finalService := Service(nil) + for _, fn := range serviceFns { + service, shouldDrop = fn(buff) + if service != nil { + finalService = service + } + if shouldDrop { + return finalService, true + } + } + return finalService, false +} diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go index 1b1ee6833..9a5f6e372 100644 --- a/conn/sticky_linux_test.go +++ b/conn/sticky_linux_test.go @@ -213,7 +213,7 @@ func Test_getSrcFromControl(t *testing.T) { func Test_listenConfig(t *testing.T) { t.Run("IPv4", func(t *testing.T) { - conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") + conn, err := listenConfig(nil).ListenPacket(context.Background(), "udp4", ":0") if err != nil { t.Fatal(err) } @@ -239,7 +239,7 @@ func Test_listenConfig(t *testing.T) { } }) t.Run("IPv6", func(t *testing.T) { - conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") + conn, err := listenConfig(nil).ListenPacket(context.Background(), "udp6", ":0") if err != nil { t.Fatal(err) } diff --git a/device/device_test.go b/device/device_test.go index 0091e2052..72311b7b2 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -426,9 +426,11 @@ type fakeBindSized struct { func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { return nil, 0, nil } -func (b *fakeBindSized) Close() error { return nil } -func (b *fakeBindSized) SetMark(mark uint32) error { return nil } -func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil } +func (b *fakeBindSized) Close() error { return nil } +func (b *fakeBindSized) SetMark(mark uint32) error { return nil } +func (b *fakeBindSized) Send(bufs [][]byte, services []conn.Service, ep conn.Endpoint) error { + return nil +} func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } func (b *fakeBindSized) BatchSize() int { return b.size } diff --git a/device/peer.go b/device/peer.go index ebf25f941..501b2b628 100644 --- a/device/peer.go +++ b/device/peer.go @@ -113,7 +113,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } -func (peer *Peer) SendBuffers(buffers [][]byte) error { +func (peer *Peer) SendBuffers(buffers [][]byte, services []conn.Service) error { peer.device.net.RLock() defer peer.device.net.RUnlock() @@ -133,7 +133,7 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { } peer.endpoint.Unlock() - err := peer.device.net.bind.Send(buffers, endpoint) + err := peer.device.net.bind.Send(buffers, services, endpoint) if err == nil { var totalLen uint64 for _, b := range buffers { diff --git a/device/send.go b/device/send.go index ff8f7da50..5f207858d 100644 --- a/device/send.go +++ b/device/send.go @@ -50,6 +50,8 @@ type QueueOutboundElement struct { nonce uint64 // nonce for encryption keypair *Keypair // keypair for encryption peer *Peer // related peer + service conn.Service // inner packet service + drop bool // service identifier result, should drop this packet } type QueueOutboundElementsContainer struct { @@ -130,7 +132,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffers([][]byte{packet}) + err = peer.SendBuffers([][]byte{packet}, []conn.Service{nil}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -167,7 +169,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{packet}) + err = peer.SendBuffers([][]byte{packet}, []conn.Service{nil}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -187,7 +189,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) packet := make([]byte, MessageCookieReplySize) _ = reply.marshal(packet) // TODO: allocation could be avoided - device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint) + device.net.bind.Send([][]byte{packet}, []conn.Service{nil}, initiatingElem.endpoint) return nil } @@ -445,6 +447,14 @@ func (device *Device) RoutineEncryption(id int) { for elemsContainer := range device.queue.encryption.c { for _, elem := range elemsContainer.elems { + // identify inner packet + service, shouldDrop := conn.ExecuteServiceFns(elem.packet) + if shouldDrop { + elem.drop = true + continue + } + elem.service = service + // populate header fields header := elem.buffer[:MessageTransportHeaderSize] @@ -483,9 +493,11 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device.log.Verbosef("%v - Routine: sequential sender - started", peer) bufs := make([][]byte, 0, maxBatchSize) + services := make([]conn.Service, 0, maxBatchSize) for elemsContainer := range peer.queue.outbound.c { bufs = bufs[:0] + services = services[:0] if elemsContainer == nil { return } @@ -507,16 +519,20 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { dataSent := false elemsContainer.Lock() for _, elem := range elemsContainer.elems { + if elem.drop { + continue + } if len(elem.packet) != MessageKeepaliveSize { dataSent = true } bufs = append(bufs, elem.packet) + services = append(services, elem.service) } peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err := peer.SendBuffers(bufs) + err := peer.SendBuffers(bufs, services) if dataSent { peer.timersDataSent() }