Skip to content

Commit 153572c

Browse files
committed
Implement active TCP candidate type
1 parent 9b4e7d9 commit 153572c

7 files changed

+189
-42
lines changed

agent.go

+47-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"fmt"
1111
"net"
12+
"strconv"
1213
"strings"
1314
"sync"
1415
"sync/atomic"
@@ -579,6 +580,44 @@ func (a *Agent) getBestValidCandidatePair() *CandidatePair {
579580
}
580581

581582
func (a *Agent) addPair(local, remote Candidate) *CandidatePair {
583+
if local.TCPType() == TCPTypeActive && remote.TCPType() == TCPTypeActive {
584+
return nil
585+
}
586+
587+
if local.TCPType() == TCPTypeActive && remote.TCPType() == TCPTypePassive {
588+
addressToConnect := net.JoinHostPort(remote.Address(), strconv.Itoa(remote.Port()))
589+
590+
conn, err := net.Dial("tcp", addressToConnect)
591+
if err != nil {
592+
a.log.Errorf("Failed to dial TCP address %s: %v", addressToConnect, err)
593+
return nil
594+
}
595+
596+
packetConn := newTCPPacketConn(tcpPacketParams{
597+
ReadBuffer: tcpReadBufferSize,
598+
LocalAddr: conn.LocalAddr(),
599+
Logger: a.log,
600+
})
601+
602+
if err = packetConn.AddConn(conn, nil); err != nil {
603+
a.log.Errorf("Failed to add TCP connection: %v", err)
604+
return nil
605+
}
606+
607+
localAddress, ok := conn.LocalAddr().(*net.TCPAddr)
608+
if !ok {
609+
a.log.Errorf("Failed to cast local address to TCP address")
610+
return nil
611+
}
612+
613+
localCandidateHost, ok := local.(*CandidateHost)
614+
if !ok {
615+
a.log.Errorf("Failed to cast local candidate to CandidateHost")
616+
return nil
617+
}
618+
localCandidateHost.port = localAddress.Port // this causes a data race with candidateBase.Port()
619+
local.start(a, packetConn, a.startedCh)
620+
}
582621
p := newCandidatePair(local, remote, a.isControlling)
583622
a.checklist = append(a.checklist, p)
584623
return p
@@ -755,7 +794,9 @@ func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net
755794
}
756795
}
757796

758-
c.start(a, candidateConn, a.startedCh)
797+
if c.TCPType() != TCPTypeActive {
798+
c.start(a, candidateConn, a.startedCh)
799+
}
759800

760801
set = append(set, c)
761802
a.localCandidates[c.NetworkType()] = set
@@ -1023,13 +1064,18 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
10231064
return
10241065
}
10251066

1067+
tcpType := TCPTypeUnspecified
1068+
if networkType == NetworkTypeTCP4 && local.NetworkType() == NetworkTypeTCP4 && local.TCPType() == TCPTypePassive {
1069+
tcpType = TCPTypeActive
1070+
}
10261071
prflxCandidateConfig := CandidatePeerReflexiveConfig{
10271072
Network: networkType.String(),
10281073
Address: ip.String(),
10291074
Port: port,
10301075
Component: local.Component(),
10311076
RelAddr: "",
10321077
RelPort: 0,
1078+
TCPType: tcpType,
10331079
}
10341080

10351081
prflxCandidate, err := NewCandidatePeerReflexive(&prflxCandidateConfig)

agent_active_tcp_test.go

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
2+
// SPDX-License-Identifier: MIT
3+
4+
//go:build !js
5+
// +build !js
6+
7+
package ice
8+
9+
import (
10+
"net"
11+
"testing"
12+
13+
"github.com/pion/logging"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func TestAgentActiveTCP(t *testing.T) {
18+
r := require.New(t)
19+
20+
const port = 7686
21+
22+
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
23+
IP: net.IPv4(127, 0, 0, 1),
24+
Port: port,
25+
})
26+
r.NoError(err)
27+
defer func() {
28+
_ = listener.Close()
29+
}()
30+
31+
loggerFactory := logging.NewDefaultLoggerFactory()
32+
loggerFactory.DefaultLogLevel.Set(logging.LogLevelTrace)
33+
34+
tcpMux := NewTCPMuxDefault(TCPMuxParams{
35+
Listener: listener,
36+
Logger: loggerFactory.NewLogger("passive-ice-tcp-mux"),
37+
ReadBufferSize: 20,
38+
})
39+
40+
defer func() {
41+
_ = tcpMux.Close()
42+
}()
43+
44+
r.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
45+
46+
passiveAgent, err := NewAgent(&AgentConfig{
47+
TCPMux: tcpMux,
48+
CandidateTypes: []CandidateType{CandidateTypeHost},
49+
NetworkTypes: []NetworkType{NetworkTypeTCP4},
50+
LoggerFactory: loggerFactory,
51+
IncludeLoopback: true,
52+
})
53+
r.NoError(err)
54+
r.NotNil(passiveAgent)
55+
56+
activeAgent, err := NewAgent(&AgentConfig{
57+
CandidateTypes: []CandidateType{CandidateTypeHost},
58+
NetworkTypes: []NetworkType{NetworkTypeTCP4},
59+
LoggerFactory: loggerFactory,
60+
})
61+
r.NoError(err)
62+
r.NotNil(activeAgent)
63+
64+
passiveAgentConn, activeAgenConn := connect(passiveAgent, activeAgent)
65+
r.NotNil(passiveAgentConn)
66+
r.NotNil(activeAgenConn)
67+
68+
pair := passiveAgent.getSelectedPair()
69+
r.NotNil(pair)
70+
r.Equal(port, pair.Local.Port())
71+
72+
data := []byte("hello world")
73+
_, err = passiveAgentConn.Write(data)
74+
r.NoError(err)
75+
76+
buffer := make([]byte, 1024)
77+
n, err := activeAgenConn.Read(buffer)
78+
r.NoError(err)
79+
r.Equal(data, buffer[:n])
80+
81+
data2 := []byte("hello world 2")
82+
_, err = activeAgenConn.Write(data2)
83+
r.NoError(err)
84+
85+
n, err = passiveAgentConn.Read(buffer)
86+
r.NoError(err)
87+
r.Equal(data2, buffer[:n])
88+
89+
r.NoError(activeAgenConn.Close())
90+
r.NoError(passiveAgentConn.Close())
91+
}

agent_config.go

+3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ const (
4646

4747
// maxBindingRequestTimeout is the wait time before binding requests can be deleted
4848
maxBindingRequestTimeout = 4000 * time.Millisecond
49+
50+
// tcpReadBufferSize is the size of the read buffer of tcpPacketConn used by active tcp candidate
51+
tcpReadBufferSize = 8
4952
)
5053

5154
func defaultCandidateTypes() []CandidateType {

candidate_base.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ func UnmarshalCandidate(raw string) (Candidate, error) {
533533
case "srflx":
534534
return NewCandidateServerReflexive(&CandidateServerReflexiveConfig{"", protocol, address, port, component, priority, foundation, relatedAddress, relatedPort})
535535
case "prflx":
536-
return NewCandidatePeerReflexive(&CandidatePeerReflexiveConfig{"", protocol, address, port, component, priority, foundation, relatedAddress, relatedPort})
536+
return NewCandidatePeerReflexive(&CandidatePeerReflexiveConfig{"", protocol, address, port, component, priority, foundation, relatedAddress, relatedPort, tcpType})
537537
case "relay":
538538
return NewCandidateRelay(&CandidateRelayConfig{"", protocol, address, port, component, priority, foundation, relatedAddress, relatedPort, "", nil})
539539
default:

candidate_peer_reflexive.go

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ type CandidatePeerReflexiveConfig struct {
2424
Foundation string
2525
RelAddr string
2626
RelPort int
27+
TCPType TCPType
2728
}
2829

2930
// NewCandidatePeerReflexive creates a new peer reflective candidate
@@ -49,6 +50,7 @@ func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*Candidate
4950
id: candidateID,
5051
networkType: networkType,
5152
candidateType: CandidateTypePeerReflexive,
53+
tcpType: config.TCPType,
5254
address: config.Address,
5355
port: config.Port,
5456
resolvedAddr: createAddr(networkType, ip, config.Port),

gather.go

+44-39
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ const (
2525
stunGatherTimeout = time.Second * 5
2626
)
2727

28+
type connAndPort struct {
29+
conn net.PacketConn
30+
port int
31+
tcpType TCPType
32+
}
33+
2834
// Close a net.Conn and log if we have a failure
2935
func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ...interface{}) {
3036
if c == nil || (reflect.ValueOf(c).Kind() == reflect.Ptr && reflect.ValueOf(c).IsNil()) {
@@ -155,53 +161,21 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
155161
}
156162

157163
for network := range networks {
158-
type connAndPort struct {
159-
conn net.PacketConn
160-
port int
161-
}
162-
var (
163-
conns []connAndPort
164-
tcpType TCPType
165-
)
164+
var conns []connAndPort
166165

167166
switch network {
168167
case tcp:
169-
if a.tcpMux == nil {
170-
continue
171-
}
168+
// Handle ICE TCP active mode
169+
conns = append(conns, connAndPort{nil, 0, TCPTypeActive})
172170

173171
// Handle ICE TCP passive mode
174-
var muxConns []net.PacketConn
175-
if multi, ok := a.tcpMux.(AllConnsGetter); ok {
176-
a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag)
177-
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil, ip)
178-
if err != nil {
179-
a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
180-
continue
181-
}
182-
} else {
183-
a.log.Debugf("GetConn by ufrag: %s", a.localUfrag)
184-
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil, ip)
185-
if err != nil {
186-
a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
187-
continue
188-
}
189-
muxConns = []net.PacketConn{conn}
190-
}
191-
192-
// Extract the port for each PacketConn we got.
193-
for _, conn := range muxConns {
194-
if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok {
195-
conns = append(conns, connAndPort{conn, tcpConn.Port})
196-
} else {
197-
a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, ip, a.localUfrag)
198-
}
172+
if a.tcpMux != nil {
173+
conns = a.getTCPMuxConnections(mappedIP, ip, network, conns)
199174
}
200175
if len(conns) == 0 {
201176
// Didn't succeed with any, try the next network.
202177
continue
203178
}
204-
tcpType = TCPTypePassive
205179
// Is there a way to verify that the listen address is even
206180
// accessible from the current interface.
207181
case udp:
@@ -212,7 +186,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
212186
}
213187

214188
if udpConn, ok := conn.LocalAddr().(*net.UDPAddr); ok {
215-
conns = append(conns, connAndPort{conn, udpConn.Port})
189+
conns = append(conns, connAndPort{conn, udpConn.Port, TCPTypeUnspecified})
216190
} else {
217191
a.log.Warnf("Failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, ip, a.localUfrag)
218192
continue
@@ -225,7 +199,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
225199
Address: address,
226200
Port: connAndPort.port,
227201
Component: ComponentRTP,
228-
TCPType: tcpType,
202+
TCPType: connAndPort.tcpType,
229203
}
230204

231205
c, err := NewCandidateHost(&hostConfig)
@@ -252,6 +226,37 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
252226
}
253227
}
254228

229+
func (a *Agent) getTCPMuxConnections(mappedIP net.IP, ip net.IP, network string, conns []connAndPort) []connAndPort {
230+
var muxConns []net.PacketConn
231+
if multi, ok := a.tcpMux.(AllConnsGetter); ok {
232+
a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag)
233+
var err error
234+
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil, ip)
235+
if err != nil {
236+
a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
237+
return conns
238+
}
239+
} else {
240+
a.log.Debugf("GetConn by ufrag: %s", a.localUfrag)
241+
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil, ip)
242+
if err != nil {
243+
a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
244+
return conns
245+
}
246+
muxConns = []net.PacketConn{conn}
247+
}
248+
249+
// Extract the port for each PacketConn we got.
250+
for _, conn := range muxConns {
251+
if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok {
252+
conns = append(conns, connAndPort{conn, tcpConn.Port, TCPTypePassive})
253+
} else {
254+
a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, ip, a.localUfrag)
255+
}
256+
}
257+
return conns
258+
}
259+
255260
func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit
256261
if a.udpMux == nil {
257262
return errUDPMuxDisabled

gather_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,7 @@ func TestMultiUDPMuxUsage(t *testing.T) {
675675
}
676676

677677
a, err := NewAgent(&AgentConfig{
678-
NetworkTypes: supportedNetworkTypes(),
678+
NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6},
679679
CandidateTypes: []CandidateType{CandidateTypeHost},
680680
UDPMux: NewMultiUDPMuxDefault(udpMuxInstances...),
681681
})

0 commit comments

Comments
 (0)