Skip to content

Commit 9afe3ad

Browse files
committed
Implement active TCP candidate type
1 parent 148a905 commit 9afe3ad

File tree

4 files changed

+173
-32
lines changed

4 files changed

+173
-32
lines changed

agent.go

+41-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"fmt"
1111
"net"
12+
"strconv"
1213
"strings"
1314
"sync"
1415
"sync/atomic"
@@ -138,6 +139,7 @@ type Agent struct {
138139

139140
interfaceFilter func(string) bool
140141
ipFilter func(net.IP) bool
142+
ActiveTCP bool
141143
includeLoopback bool
142144

143145
insecureSkipVerify bool
@@ -312,6 +314,8 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
312314

313315
ipFilter: config.IPFilter,
314316

317+
ActiveTCP: config.ActiveTCP,
318+
315319
insecureSkipVerify: config.InsecureSkipVerify,
316320

317321
includeLoopback: config.IncludeLoopback,
@@ -578,6 +582,40 @@ func (a *Agent) getBestValidCandidatePair() *CandidatePair {
578582
}
579583

580584
func (a *Agent) addPair(local, remote Candidate) *CandidatePair {
585+
if local.TCPType() == TCPTypeActive && remote.TCPType() == TCPTypePassive {
586+
addressToConnect := net.JoinHostPort(remote.Address(), strconv.Itoa(remote.Port()))
587+
588+
conn, err := net.Dial("tcp", addressToConnect)
589+
if err != nil {
590+
a.log.Errorf("Failed to dial TCP address %s: %v", addressToConnect, err)
591+
return nil
592+
}
593+
594+
packetConn := newTCPPacketConn(tcpPacketParams{
595+
ReadBuffer: 8,
596+
LocalAddr: conn.LocalAddr(),
597+
Logger: a.log,
598+
})
599+
600+
if err = packetConn.AddConn(conn, nil); err != nil {
601+
a.log.Errorf("Failed to add TCP connection: %v", err)
602+
return nil
603+
}
604+
605+
localAddress, ok := conn.LocalAddr().(*net.TCPAddr)
606+
if !ok {
607+
a.log.Errorf("Failed to cast local address to TCP address")
608+
return nil
609+
}
610+
611+
localCandidateHost, ok := local.(*CandidateHost)
612+
if !ok {
613+
a.log.Errorf("Failed to cast local candidate to CandidateHost")
614+
return nil
615+
}
616+
localCandidateHost.port = localAddress.Port // this causes a data race with candidateBase.Port()
617+
local.start(a, packetConn, a.startedCh)
618+
}
581619
p := newCandidatePair(local, remote, a.isControlling)
582620
a.checklist = append(a.checklist, p)
583621
return p
@@ -754,7 +792,9 @@ func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net
754792
}
755793
}
756794

757-
c.start(a, candidateConn, a.startedCh)
795+
if c.TCPType() != TCPTypeActive {
796+
c.start(a, candidateConn, a.startedCh)
797+
}
758798

759799
set = append(set, c)
760800
a.localCandidates[c.NetworkType()] = set

agent_active_tcp_test.go

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
2+
// SPDX-License-Identifier: MIT
3+
4+
//go:build !js
5+
// +build !js
6+
7+
package ice
8+
9+
import (
10+
"net"
11+
"testing"
12+
13+
"github.com/pion/logging"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func TestAgentActiveTCP(t *testing.T) {
18+
r := require.New(t)
19+
20+
const port = 7686
21+
22+
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
23+
IP: net.IPv4(127, 0, 0, 1),
24+
Port: port,
25+
})
26+
r.NoError(err)
27+
defer func() {
28+
_ = listener.Close()
29+
}()
30+
31+
loggerFactory := logging.NewDefaultLoggerFactory()
32+
loggerFactory.DefaultLogLevel.Set(logging.LogLevelTrace)
33+
34+
tcpMux := NewTCPMuxDefault(TCPMuxParams{
35+
Listener: listener,
36+
Logger: loggerFactory.NewLogger("passive-ice"),
37+
ReadBufferSize: 20,
38+
})
39+
40+
defer func() {
41+
_ = tcpMux.Close()
42+
}()
43+
44+
r.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
45+
46+
passiveAgent, err := NewAgent(&AgentConfig{
47+
TCPMux: tcpMux,
48+
CandidateTypes: []CandidateType{CandidateTypeHost},
49+
NetworkTypes: []NetworkType{NetworkTypeTCP4},
50+
LoggerFactory: loggerFactory,
51+
ActiveTCP: false,
52+
IncludeLoopback: true,
53+
})
54+
r.NoError(err)
55+
r.NotNil(passiveAgent)
56+
57+
activeAgent, err := NewAgent(&AgentConfig{
58+
CandidateTypes: []CandidateType{CandidateTypeHost},
59+
NetworkTypes: []NetworkType{NetworkTypeTCP4},
60+
LoggerFactory: loggerFactory,
61+
ActiveTCP: true,
62+
})
63+
r.NoError(err)
64+
r.NotNil(activeAgent)
65+
66+
passiveAgentConn, activeAgenConn := connect(passiveAgent, activeAgent)
67+
r.NotNil(passiveAgentConn)
68+
r.NotNil(activeAgenConn)
69+
70+
pair := passiveAgent.getSelectedPair()
71+
r.NotNil(pair)
72+
r.Equal(port, pair.Local.Port())
73+
74+
data := []byte("hello world")
75+
_, err = passiveAgentConn.Write(data)
76+
r.NoError(err)
77+
78+
buffer := make([]byte, 1024)
79+
n, err := activeAgenConn.Read(buffer)
80+
r.NoError(err)
81+
r.Equal(data, buffer[:n])
82+
83+
data2 := []byte("hello world 2")
84+
_, err = activeAgenConn.Write(data2)
85+
r.NoError(err)
86+
87+
n, err = passiveAgentConn.Read(buffer)
88+
r.NoError(err)
89+
r.Equal(data2, buffer[:n])
90+
91+
r.NoError(activeAgenConn.Close())
92+
r.NoError(passiveAgentConn.Close())
93+
r.NoError(tcpMux.Close())
94+
}

agent_config.go

+2
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ type AgentConfig struct {
144144
// the ips which are used to gather ICE candidates.
145145
IPFilter func(net.IP) bool
146146

147+
ActiveTCP bool
148+
147149
// InsecureSkipVerify controls if self-signed certificates are accepted when connecting
148150
// to TURN servers via TLS or DTLS
149151
InsecureSkipVerify bool

gather.go

+36-31
Original file line numberDiff line numberDiff line change
@@ -165,44 +165,49 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
165165

166166
switch network {
167167
case tcp:
168-
if a.tcpMux == nil {
169-
continue
170-
}
171-
172-
// Handle ICE TCP passive mode
173-
var muxConns []net.PacketConn
174-
if multi, ok := a.tcpMux.(AllConnsGetter); ok {
175-
a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag)
176-
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil, ip)
177-
if err != nil {
178-
a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
179-
continue
180-
}
168+
if a.ActiveTCP {
169+
conns = append(conns, connAndPort{nil, 0})
170+
tcpType = TCPTypeActive
181171
} else {
182-
a.log.Debugf("GetConn by ufrag: %s", a.localUfrag)
183-
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil, ip)
184-
if err != nil {
185-
a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
172+
// Handle ICE TCP passive mode
173+
if a.tcpMux == nil {
186174
continue
187175
}
188-
muxConns = []net.PacketConn{conn}
189-
}
190176

191-
// Extract the port for each PacketConn we got.
192-
for _, conn := range muxConns {
193-
if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok {
194-
conns = append(conns, connAndPort{conn, tcpConn.Port})
177+
var muxConns []net.PacketConn
178+
if multi, ok := a.tcpMux.(AllConnsGetter); ok {
179+
a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag)
180+
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil, ip)
181+
if err != nil {
182+
a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
183+
continue
184+
}
195185
} else {
196-
a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, ip, a.localUfrag)
186+
a.log.Debugf("GetConn by ufrag: %s", a.localUfrag)
187+
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil, ip)
188+
if err != nil {
189+
a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
190+
continue
191+
}
192+
muxConns = []net.PacketConn{conn}
197193
}
194+
195+
// Extract the port for each PacketConn we got.
196+
for _, conn := range muxConns {
197+
if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok {
198+
conns = append(conns, connAndPort{conn, tcpConn.Port})
199+
} else {
200+
a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, ip, a.localUfrag)
201+
}
202+
}
203+
if len(conns) == 0 {
204+
// Didn't succeed with any, try the next network.
205+
continue
206+
}
207+
tcpType = TCPTypePassive
208+
// Is there a way to verify that the listen address is even
209+
// accessible from the current interface.
198210
}
199-
if len(conns) == 0 {
200-
// Didn't succeed with any, try the next network.
201-
continue
202-
}
203-
tcpType = TCPTypePassive
204-
// Is there a way to verify that the listen address is even
205-
// accessible from the current interface.
206211
case udp:
207212
conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{IP: ip, Port: 0})
208213
if err != nil {

0 commit comments

Comments
 (0)