@@ -3,7 +3,7 @@ use std::{
3
3
io,
4
4
net:: SocketAddr ,
5
5
pin:: Pin ,
6
- sync:: { atomic:: AtomicBool , RwLock , RwLockReadGuard , TryLockError } ,
6
+ sync:: { atomic:: AtomicBool , Arc , RwLock , RwLockReadGuard , TryLockError } ,
7
7
task:: { Context , Poll } ,
8
8
} ;
9
9
@@ -321,7 +321,7 @@ impl UdpSocket {
321
321
panic ! ( "lock poisoned: {:?}" , e) ;
322
322
}
323
323
Err ( TryLockError :: WouldBlock ) => {
324
- return Err ( io:: Error :: new ( io:: ErrorKind :: WouldBlock , "" ) ) ;
324
+ return Err ( io:: Error :: new ( io:: ErrorKind :: WouldBlock , "locked " ) ) ;
325
325
}
326
326
} ;
327
327
let ( socket, state) = guard. try_get_connected ( ) ?;
@@ -340,6 +340,50 @@ impl UdpSocket {
340
340
}
341
341
}
342
342
343
+ /// poll send a quinn based `Transmit`.
344
+ pub fn poll_send_quinn (
345
+ & self ,
346
+ cx : & mut Context ,
347
+ transmit : & Transmit < ' _ > ,
348
+ ) -> Poll < io:: Result < ( ) > > {
349
+ loop {
350
+ if let Err ( err) = self . maybe_rebind ( ) {
351
+ return Poll :: Ready ( Err ( err) ) ;
352
+ }
353
+
354
+ let guard = n0_future:: ready!( self . poll_read_socket( & self . send_waker, cx) ) ;
355
+ let ( socket, state) = guard. try_get_connected ( ) ?;
356
+
357
+ match socket. poll_send_ready ( cx) {
358
+ Poll :: Pending => {
359
+ self . send_waker . register ( cx. waker ( ) ) ;
360
+ return Poll :: Pending ;
361
+ }
362
+ Poll :: Ready ( Ok ( ( ) ) ) => {
363
+ let res =
364
+ socket. try_io ( Interest :: WRITABLE , || state. send ( socket. into ( ) , transmit) ) ;
365
+ if let Err ( err) = res {
366
+ if err. kind ( ) == io:: ErrorKind :: WouldBlock {
367
+ continue ;
368
+ }
369
+
370
+ if let Some ( err) = self . handle_write_error ( err) {
371
+ return Poll :: Ready ( Err ( err) ) ;
372
+ }
373
+ continue ;
374
+ }
375
+ return Poll :: Ready ( res) ;
376
+ }
377
+ Poll :: Ready ( Err ( err) ) => {
378
+ if let Some ( err) = self . handle_write_error ( err) {
379
+ return Poll :: Ready ( Err ( err) ) ;
380
+ }
381
+ continue ;
382
+ }
383
+ }
384
+ }
385
+ }
386
+
343
387
/// quinn based `poll_recv`
344
388
pub fn poll_recv_quinn (
345
389
& self ,
@@ -401,6 +445,11 @@ impl UdpSocket {
401
445
}
402
446
}
403
447
448
+ /// Creates a [`UdpSender`] sender.
449
+ pub fn create_sender ( self : Arc < Self > ) -> UdpSender {
450
+ UdpSender :: new ( self . clone ( ) )
451
+ }
452
+
404
453
/// Whether transmitted datagrams might get fragmented by the IP layer
405
454
///
406
455
/// Returns `false` on targets which employ e.g. the `IPV6_DONTFRAG` socket option.
@@ -806,6 +855,151 @@ impl Drop for UdpSocket {
806
855
}
807
856
}
808
857
858
+ pin_project_lite:: pin_project! {
859
+ pub struct UdpSender {
860
+ socket: Arc <UdpSocket >,
861
+ #[ pin]
862
+ fut: Option <Pin <Box <dyn Future <Output = io:: Result <( ) >> + Send + Sync + ' static >>>,
863
+ }
864
+ }
865
+
866
+ impl std:: fmt:: Debug for UdpSender {
867
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
868
+ f. write_str ( "UdpSender" )
869
+ }
870
+ }
871
+
872
+ impl UdpSender {
873
+ fn new ( socket : Arc < UdpSocket > ) -> Self {
874
+ Self { socket, fut : None }
875
+ }
876
+
877
+ /// Async sending
878
+ pub fn send < ' a , ' b > ( & self , transmit : & ' a quinn_udp:: Transmit < ' b > ) -> SendFutQuinn < ' a , ' b > {
879
+ SendFutQuinn {
880
+ socket : self . socket . clone ( ) ,
881
+ transmit,
882
+ }
883
+ }
884
+
885
+ /// Poll send
886
+ pub fn poll_send (
887
+ self : Pin < & mut Self > ,
888
+ transmit : & quinn_udp:: Transmit ,
889
+ cx : & mut Context ,
890
+ ) -> Poll < io:: Result < ( ) > > {
891
+ let mut this = self . project ( ) ;
892
+ loop {
893
+ if let Err ( err) = this. socket . maybe_rebind ( ) {
894
+ return Poll :: Ready ( Err ( err) ) ;
895
+ }
896
+
897
+ let guard =
898
+ n0_future:: ready!( this. socket. poll_read_socket( & this. socket. send_waker, cx) ) ;
899
+
900
+ if this. fut . is_none ( ) {
901
+ let socket = this. socket . clone ( ) ;
902
+ this. fut . set ( Some ( Box :: pin ( async move {
903
+ n0_future:: future:: poll_fn ( |cx| socket. poll_writable ( cx) ) . await
904
+ } ) ) ) ;
905
+ }
906
+ // We're forced to `unwrap` here because `Fut` may be `!Unpin`, which means we can't safely
907
+ // obtain an `&mut Fut` after storing it in `this.fut` when `this` is already behind `Pin`,
908
+ // and if we didn't store it then we wouldn't be able to keep it alive between
909
+ // `poll_writable` calls.
910
+ let result = n0_future:: ready!( this. fut. as_mut( ) . as_pin_mut( ) . unwrap( ) . poll( cx) ) ;
911
+
912
+ // Polling an arbitrary `Future` after it becomes ready is a logic error, so arrange for
913
+ // a new `Future` to be created on the next call.
914
+ this. fut . set ( None ) ;
915
+
916
+ // If .writable() fails, propagate the error
917
+ result?;
918
+
919
+ let ( socket, state) = guard. try_get_connected ( ) ?;
920
+ let result = socket. try_io ( Interest :: WRITABLE , || state. send ( socket. into ( ) , transmit) ) ;
921
+
922
+ match result {
923
+ // We thought the socket was writable, but it wasn't, then retry so that either another
924
+ // `writable().await` call determines that the socket is indeed not writable and
925
+ // registers us for a wakeup, or the send succeeds if this really was just a
926
+ // transient failure.
927
+ Err ( ref e) if e. kind ( ) == io:: ErrorKind :: WouldBlock => continue ,
928
+ // In all other cases, either propagate the error or we're Ok
929
+ _ => return Poll :: Ready ( result) ,
930
+ }
931
+ }
932
+ }
933
+
934
+ /// Best effort sending
935
+ pub fn try_send ( & self , transmit : & quinn_udp:: Transmit ) -> io:: Result < ( ) > {
936
+ self . socket . maybe_rebind ( ) ?;
937
+
938
+ match self . socket . socket . try_read ( ) {
939
+ Ok ( guard) => {
940
+ let ( socket, state) = guard. try_get_connected ( ) ?;
941
+ socket. try_io ( Interest :: WRITABLE , || state. send ( socket. into ( ) , transmit) )
942
+ }
943
+ Err ( TryLockError :: Poisoned ( e) ) => panic ! ( "socket lock poisoned: {e}" ) ,
944
+ Err ( TryLockError :: WouldBlock ) => {
945
+ Err ( io:: Error :: new ( io:: ErrorKind :: WouldBlock , "locked" ) )
946
+ }
947
+ }
948
+ }
949
+ }
950
+
951
+ /// Send future quinn
952
+ #[ derive( Debug ) ]
953
+ pub struct SendFutQuinn < ' a , ' b > {
954
+ socket : Arc < UdpSocket > ,
955
+ transmit : & ' a quinn_udp:: Transmit < ' b > ,
956
+ }
957
+
958
+ impl Future for SendFutQuinn < ' _ , ' _ > {
959
+ type Output = io:: Result < ( ) > ;
960
+
961
+ fn poll ( self : Pin < & mut Self > , cx : & mut std:: task:: Context < ' _ > ) -> Poll < Self :: Output > {
962
+ loop {
963
+ if let Err ( err) = self . socket . maybe_rebind ( ) {
964
+ return Poll :: Ready ( Err ( err) ) ;
965
+ }
966
+
967
+ let guard =
968
+ n0_future:: ready!( self . socket. poll_read_socket( & self . socket. send_waker, cx) ) ;
969
+ let ( socket, state) = guard. try_get_connected ( ) ?;
970
+
971
+ match socket. poll_send_ready ( cx) {
972
+ Poll :: Pending => {
973
+ self . socket . send_waker . register ( cx. waker ( ) ) ;
974
+ return Poll :: Pending ;
975
+ }
976
+ Poll :: Ready ( Ok ( ( ) ) ) => {
977
+ let res = socket. try_io ( Interest :: WRITABLE , || {
978
+ state. send ( socket. into ( ) , self . transmit )
979
+ } ) ;
980
+
981
+ if let Err ( err) = res {
982
+ if err. kind ( ) == io:: ErrorKind :: WouldBlock {
983
+ continue ;
984
+ }
985
+ if let Some ( err) = self . socket . handle_write_error ( err) {
986
+ return Poll :: Ready ( Err ( err) ) ;
987
+ }
988
+ continue ;
989
+ }
990
+ return Poll :: Ready ( res) ;
991
+ }
992
+ Poll :: Ready ( Err ( err) ) => {
993
+ if let Some ( err) = self . socket . handle_write_error ( err) {
994
+ return Poll :: Ready ( Err ( err) ) ;
995
+ }
996
+ continue ;
997
+ }
998
+ }
999
+ }
1000
+ }
1001
+ }
1002
+
809
1003
#[ cfg( test) ]
810
1004
mod tests {
811
1005
use testresult:: TestResult ;
0 commit comments