@@ -12,7 +12,7 @@ use tokio::{
12
12
sync:: mpsc,
13
13
task:: { AbortHandle , JoinSet } ,
14
14
} ;
15
- use tracing:: { debug , error } ;
15
+ use tracing:: { error , warn } ;
16
16
17
17
use crate :: {
18
18
endpoint:: { get_remote_node_id, Connection } ,
@@ -22,17 +22,75 @@ use crate::{
22
22
const DUPLICATE_REASON : & [ u8 ] = b"abort_duplicate" ;
23
23
const DUPLICATE_CODE : u32 = 123 ;
24
24
25
- /// A connection manager.
25
+ /// Whether we accepted the connection or initiated it.
26
+ #[ derive( Debug , Clone , Copy , Eq , PartialEq ) ]
27
+ pub enum ConnDirection {
28
+ /// We accepted this connection from the other peer.
29
+ Accept ,
30
+ /// We initiated this connection by connecting to the other peer.
31
+ Dial ,
32
+ }
33
+
34
+ /// A new connection as emitted from [`ConnManager`].
35
+ #[ derive( Debug , Clone , derive_more:: Deref ) ]
36
+ pub struct ConnInfo {
37
+ /// The QUIC connection.
38
+ #[ deref]
39
+ pub conn : Connection ,
40
+ /// The node id of the other peer.
41
+ pub node_id : NodeId ,
42
+ /// Whether we accepted or initiated this connection.
43
+ pub direction : ConnDirection ,
44
+ }
45
+
46
+ /// A sender to push new connections into a [`ConnManager`].
47
+ ///
48
+ /// See [`ConnManager::accept_sender`] for details.
49
+ #[ derive( Debug , Clone ) ]
50
+ pub struct HandleConnectionSender {
51
+ tx : mpsc:: Sender < Connection > ,
52
+ }
53
+
54
+ impl HandleConnectionSender {
55
+ /// Send a new connection to the [`ConnManager`].
56
+ pub async fn send ( & self , conn : Connection ) -> anyhow:: Result < ( ) > {
57
+ self . tx . send ( conn) . await ?;
58
+ Ok ( ( ) )
59
+ }
60
+ }
61
+
62
+ /// The error returned from [`ConnManager::poll_next`].
63
+ #[ derive( thiserror:: Error , Debug ) ]
64
+ #[ error( "Connection to node {} direction {:?} failed: {:?}" , self . node_id, self . direction, self . reason) ]
65
+ pub struct ConnectError {
66
+ /// The node id of the peer to which the connection failed.
67
+ pub node_id : NodeId ,
68
+ /// The direction of the connection.
69
+ pub direction : ConnDirection ,
70
+ /// The actual error that ocurred.
71
+ #[ source]
72
+ pub reason : anyhow:: Error ,
73
+ }
74
+
75
+ /// A connection manager that ensures that only a single connection between two peers prevails.
76
+ ///
77
+ /// You can start to dial peers by calling [`ConnManager::dial`]. Note that the method only takes a
78
+ /// node id; if you have more addressing info, add it to the endpoint directly with
79
+ /// [`Endpoint::add_node_addr`] before calling `dial`;
26
80
///
27
81
/// The [`ConnManager`] does not accept connections from the endpoint by itself. Instead, you
28
82
/// should run an accept loop yourself, and push connections with a matching ALPN into the manager
29
- /// with [`ConnManager::accept`]. The connection will be dropped if we already have a connection to
30
- /// that node. If we are currently dialing the node, the connection will only be accepted if the
31
- /// peer's node id sorts lower than our node id. Through this, it is ensured that we will not get
32
- /// double connections with a node if both we and them dial each other at the same time.
83
+ /// with [`ConnManager::handle_connection`] or [`ConnManager::handle_connection_sender`].
33
84
///
34
- /// The [`ConnManager`] implements [`Stream`]. It will yield new connections, both from dialing and
35
- /// accepting.
85
+ /// The [`ConnManager`] is a [`Stream`] that yields all connections from both accepting and dialing.
86
+ ///
87
+ /// Before accepting incoming connections, the [`ConnManager`] makes sure that, if we are dialing
88
+ /// the same node, only one of the connections will prevail. In this case, the accepting side
89
+ /// rejects the connection if the peer's node id sorts higher than their own node id.
90
+ ///
91
+ /// To make this reliable even if the dials happen exactly at the same time, a single unidirectional
92
+ /// stream is opened, on which a single byte is sent. This additional rountrip ensures that no
93
+ /// double connections can prevail.
36
94
#[ derive( Debug ) ]
37
95
pub struct ConnManager {
38
96
endpoint : Endpoint ,
@@ -80,44 +138,14 @@ impl ConnManager {
80
138
}
81
139
}
82
140
83
- fn spawn (
84
- & mut self ,
85
- node_id : NodeId ,
86
- direction : ConnDirection ,
87
- fut : impl Future < Output = Result < Connection , InitError > > + Send + ' static ,
88
- ) {
89
- let abort_handle = self . tasks . spawn ( fut. map ( move |res| ( node_id, res) ) ) ;
90
- let pending_state = PendingState {
91
- direction,
92
- abort_handle,
93
- } ;
94
- self . pending . insert ( node_id, pending_state) ;
95
- if let Some ( waker) = self . waker . take ( ) {
96
- waker. wake ( ) ;
97
- }
98
- }
99
-
100
- /// Get a sender to push new connections towards the [`ConnManager`]
101
- ///
102
- /// This does not check the connection's ALPN, so you should make sure that the ALPN matches
103
- /// the [`ConnManager`]'s execpected ALPN before passing the connection to the sender.
104
- ///
105
- /// If we are currently dialing the node, the connection will be dropped if the peer's node id
106
- /// sorty higher than our node id. Otherwise, the connection will be yielded from the manager
107
- /// stream.
108
- pub fn accept_sender ( & self ) -> AcceptSender {
109
- let tx = self . accept_tx . clone ( ) ;
110
- AcceptSender { tx }
111
- }
112
-
113
141
/// Accept a connection.
114
142
///
115
143
/// This does not check the connection's ALPN, so you should make sure that the ALPN matches
116
144
/// the [`ConnManager`]'s execpected ALPN before passing the connection to the sender.
117
145
///
118
146
/// If we are currently dialing the node, the connection will be dropped if the peer's node id
119
147
/// sorty higher than our node id. Otherwise, the connection will be returned.
120
- pub fn accept ( & mut self , conn : quinn:: Connection ) -> anyhow:: Result < ( ) > {
148
+ pub fn handle_connection ( & mut self , conn : quinn:: Connection ) -> anyhow:: Result < ( ) > {
121
149
let node_id = get_remote_node_id ( & conn) ?;
122
150
// We are already connected: drop the connection, keep using the existing conn.
123
151
if self . is_connected ( & node_id) {
@@ -128,7 +156,7 @@ impl ConnManager {
128
156
// We are currently dialing the node, but the incoming conn "wins": accept and abort
129
157
// our dial.
130
158
Some ( state)
131
- if state. direction == ConnDirection :: Dial && node_id > self . our_node_id ( ) =>
159
+ if state. direction == ConnDirection :: Dial && node_id > self . endpoint . node_id ( ) =>
132
160
{
133
161
state. abort_handle . abort ( ) ;
134
162
true
@@ -147,6 +175,19 @@ impl ConnManager {
147
175
Ok ( ( ) )
148
176
}
149
177
178
+ /// Get a sender to push new connections towards the [`ConnManager`]
179
+ ///
180
+ /// This does not check the connection's ALPN, so you should make sure that the ALPN matches
181
+ /// the [`ConnManager`]'s execpected ALPN before passing the connection to the sender.
182
+ ///
183
+ /// If we are currently dialing the node, the connection will be dropped if the peer's node id
184
+ /// sorty higher than our node id. Otherwise, the connection will be yielded from the manager
185
+ /// stream.
186
+ pub fn handle_connection_sender ( & self ) -> HandleConnectionSender {
187
+ let tx = self . accept_tx . clone ( ) ;
188
+ HandleConnectionSender { tx }
189
+ }
190
+
150
191
/// Remove the connection to a node.
151
192
///
152
193
/// Also aborts pending dials to the node, if existing.
@@ -174,45 +215,57 @@ impl ConnManager {
174
215
self . active . contains_key ( node_id)
175
216
}
176
217
177
- fn our_node_id ( & self ) -> NodeId {
178
- self . endpoint . node_id ( )
218
+ fn spawn (
219
+ & mut self ,
220
+ node_id : NodeId ,
221
+ direction : ConnDirection ,
222
+ fut : impl Future < Output = Result < Connection , InitError > > + Send + ' static ,
223
+ ) {
224
+ let abort_handle = self . tasks . spawn ( fut. map ( move |res| ( node_id, res) ) ) ;
225
+ let pending_state = PendingState {
226
+ direction,
227
+ abort_handle,
228
+ } ;
229
+ self . pending . insert ( node_id, pending_state) ;
230
+ if let Some ( waker) = self . waker . take ( ) {
231
+ waker. wake ( ) ;
232
+ }
179
233
}
180
234
}
181
235
182
236
impl Stream for ConnManager {
183
237
type Item = Result < ConnInfo , ConnectError > ;
184
238
185
239
fn poll_next ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
186
- tracing:: debug!( "poll_next in" ) ;
187
240
// Create new tasks for incoming connections.
188
241
while let Poll :: Ready ( Some ( conn) ) = Pin :: new ( & mut self . accept_rx ) . poll_recv ( cx) {
189
- // self.accept(conn)
190
- debug ! ( "accept - polled" ) ;
191
- if let Err ( error) = self . accept ( conn) {
192
- tracing:: warn!( ?error, "skipping invalid connection attempt" ) ;
242
+ if let Err ( error) = self . handle_connection ( conn) {
243
+ warn ! ( ?error, "skipping invalid connection attempt" ) ;
193
244
}
194
245
}
195
246
196
247
// Poll for finished tasks,
197
248
loop {
198
249
let join_res = ready ! ( self . tasks. poll_join_next( cx) ) ;
199
- debug ! ( ?join_res, "join res" ) ;
200
250
let ( node_id, res) = match join_res {
201
251
None => {
202
252
self . waker = Some ( cx. waker ( ) . to_owned ( ) ) ;
203
253
return Poll :: Pending ;
204
254
}
205
255
Some ( Err ( err) ) if err. is_cancelled ( ) => continue ,
206
- // we are merely forwarding a panic here, which should never occur.
207
- Some ( Err ( err) ) => panic ! ( "connection manager task paniced with {err:?}" ) ,
256
+ Some ( Err ( err) ) => {
257
+ // TODO: unreachable?
258
+ warn ! ( "connection manager task paniced with {err:?}" ) ;
259
+ continue ;
260
+ }
208
261
Some ( Ok ( res) ) => res,
209
262
} ;
210
263
match res {
211
264
Err ( InitError :: IsDuplicate ) => continue ,
212
265
Err ( InitError :: Other ( reason) ) => {
213
266
let Some ( PendingState { direction, .. } ) = self . pending . remove ( & node_id) else {
214
267
// TODO: unreachable?
215
- tracing :: warn!( node_id=%node_id. fmt_short( ) , "missing pending state, dropping connection" ) ;
268
+ warn ! ( node_id=%node_id. fmt_short( ) , "missing pending state, dropping connection" ) ;
216
269
continue ;
217
270
} ;
218
271
let err = ConnectError {
@@ -225,7 +278,7 @@ impl Stream for ConnManager {
225
278
Ok ( conn) => {
226
279
let Some ( PendingState { direction, .. } ) = self . pending . remove ( & node_id) else {
227
280
// TODO: unreachable?
228
- tracing :: warn!( node_id=%node_id. fmt_short( ) , "missing pending state, dropping connection" ) ;
281
+ warn ! ( node_id=%node_id. fmt_short( ) , "missing pending state, dropping connection" ) ;
229
282
continue ;
230
283
} ;
231
284
let info = ConnInfo {
@@ -265,35 +318,6 @@ struct PendingState {
265
318
abort_handle : AbortHandle ,
266
319
}
267
320
268
- /// A sender to push new connections into a [`ConnManager`].
269
- ///
270
- /// See [`ConnManager::accept_sender`] for details.
271
- #[ derive( Debug , Clone ) ]
272
- pub struct AcceptSender {
273
- tx : mpsc:: Sender < Connection > ,
274
- }
275
-
276
- impl AcceptSender {
277
- /// Send a new connection to the [`ConnManager`].
278
- pub async fn send ( & self , conn : Connection ) -> anyhow:: Result < ( ) > {
279
- self . tx . send ( conn) . await ?;
280
- Ok ( ( ) )
281
- }
282
- }
283
-
284
- /// The error returned from [`ConnManager::poll_next`].
285
- #[ derive( thiserror:: Error , Debug ) ]
286
- #[ error( "Connection to node {} direction {:?} failed: {:?}" , self . node_id, self . direction, self . reason) ]
287
- pub struct ConnectError {
288
- /// The node id of the peer to which the connection failed.
289
- pub node_id : NodeId ,
290
- /// The direction of the connection.
291
- pub direction : ConnDirection ,
292
- /// The actual error that ocurred.
293
- #[ source]
294
- pub reason : anyhow:: Error ,
295
- }
296
-
297
321
#[ derive( Debug ) ]
298
322
enum InitError {
299
323
IsDuplicate ,
@@ -338,27 +362,6 @@ impl From<quinn::WriteError> for InitError {
338
362
}
339
363
}
340
364
341
- /// Whether we accepted the connection or initiated it.
342
- #[ derive( Debug , Clone , Copy , Eq , PartialEq ) ]
343
- pub enum ConnDirection {
344
- /// We accepted this connection from the other peer.
345
- Accept ,
346
- /// We initiated this connection by connecting to the other peer.
347
- Dial ,
348
- }
349
-
350
- /// A new connection as emitted from [`ConnManager`].
351
- #[ derive( Debug , Clone , derive_more:: Deref ) ]
352
- pub struct ConnInfo {
353
- /// The QUIC connection.
354
- #[ deref]
355
- pub conn : Connection ,
356
- /// The node id of the other peer.
357
- pub node_id : NodeId ,
358
- /// Whether we accepted or initiated this connection.
359
- pub direction : ConnDirection ,
360
- }
361
-
362
365
#[ cfg( test) ]
363
366
mod tests {
364
367
use std:: time:: Duration ;
@@ -368,11 +371,14 @@ mod tests {
368
371
369
372
use crate :: test_utils:: TestEndpointFactory ;
370
373
371
- use super :: { AcceptSender , ConnManager } ;
374
+ use super :: { ConnManager , HandleConnectionSender } ;
372
375
373
376
const TEST_ALPN : & [ u8 ] = b"test" ;
374
377
375
- async fn accept_loop ( ep : crate :: Endpoint , accept_sender : AcceptSender ) -> anyhow:: Result < ( ) > {
378
+ async fn accept_loop (
379
+ ep : crate :: Endpoint ,
380
+ accept_sender : HandleConnectionSender ,
381
+ ) -> anyhow:: Result < ( ) > {
376
382
while let Some ( conn) = ep. accept ( ) . await {
377
383
let conn = conn. await ?;
378
384
tracing:: debug!( me=%ep. node_id( ) . fmt_short( ) , "conn incoming" ) ;
@@ -398,8 +404,8 @@ mod tests {
398
404
let mut conn_manager1 = ConnManager :: new ( ep1. clone ( ) , TEST_ALPN ) ;
399
405
let mut conn_manager2 = ConnManager :: new ( ep2. clone ( ) , TEST_ALPN ) ;
400
406
401
- let accept1 = conn_manager1. accept_sender ( ) ;
402
- let accept2 = conn_manager2. accept_sender ( ) ;
407
+ let accept1 = conn_manager1. handle_connection_sender ( ) ;
408
+ let accept2 = conn_manager2. handle_connection_sender ( ) ;
403
409
let mut tasks = JoinSet :: new ( ) ;
404
410
tasks. spawn ( accept_loop ( ep1, accept1) ) ;
405
411
tasks. spawn ( accept_loop ( ep2, accept2) ) ;
0 commit comments