@@ -5,18 +5,45 @@ use std::{
5
5
mem,
6
6
net:: { self , IpAddr , SocketAddr } ,
7
7
os:: windows:: io:: AsRawSocket ,
8
+ ptr,
8
9
} ;
9
10
10
- use log:: error;
11
+ use lazy_static:: lazy_static;
12
+ use log:: { error, warn} ;
11
13
use net2:: TcpBuilder ;
12
14
use tokio:: net:: { TcpListener , TcpStream } ;
13
15
use winapi:: {
14
16
ctypes:: { c_char, c_int} ,
15
17
shared:: {
16
- minwindef:: DWORD ,
17
- ws2def:: { ADDRESS_FAMILY , AF_INET , AF_INET6 , IPPROTO_TCP , SOCKADDR , SOCKADDR_IN } ,
18
+ minwindef:: { BOOL , DWORD , FALSE , LPDWORD , LPVOID , TRUE } ,
19
+ ws2def:: {
20
+ ADDRESS_FAMILY ,
21
+ AF_INET ,
22
+ AF_INET6 ,
23
+ IPPROTO_TCP ,
24
+ SIO_GET_EXTENSION_FUNCTION_POINTER ,
25
+ SOCKADDR ,
26
+ SOCKADDR_IN ,
27
+ } ,
28
+ } ,
29
+ um:: {
30
+ minwinbase:: OVERLAPPED ,
31
+ mswsock:: { LPFN_CONNECTEX , WSAID_CONNECTEX } ,
32
+ winnt:: PVOID ,
33
+ winsock2:: {
34
+ bind,
35
+ closesocket,
36
+ setsockopt,
37
+ socket,
38
+ WSAGetLastError ,
39
+ WSAGetOverlappedResult ,
40
+ WSAIoctl ,
41
+ INVALID_SOCKET ,
42
+ SOCKET ,
43
+ SOCKET_ERROR ,
44
+ SOCK_STREAM ,
45
+ } ,
18
46
} ,
19
- um:: winsock2:: { bind, connect, setsockopt, WSAGetLastError , SOCKET , SOCKET_ERROR } ,
20
47
} ;
21
48
22
49
// ws2ipdef.h
@@ -61,7 +88,101 @@ pub async fn bind_listener(addr: &SocketAddr) -> io::Result<TcpListener> {
61
88
TcpListener :: from_std ( listener)
62
89
}
63
90
64
- pub async fn connect_stream ( addr : & SocketAddr ) -> io:: Result < TcpStream > {
91
+ lazy_static ! {
92
+ static ref PFN_CONNECTEX_OPT : LPFN_CONNECTEX = unsafe {
93
+ let socket = socket( AF_INET , SOCK_STREAM , 0 ) ;
94
+ if socket == INVALID_SOCKET {
95
+ return None ;
96
+ }
97
+
98
+ let mut guid = WSAID_CONNECTEX ;
99
+ let mut num_bytes: DWORD = 0 ;
100
+
101
+ let mut connectex: LPFN_CONNECTEX = None ;
102
+
103
+ let ret = WSAIoctl (
104
+ socket,
105
+ SIO_GET_EXTENSION_FUNCTION_POINTER ,
106
+ & mut guid as * mut _ as LPVOID ,
107
+ mem:: size_of_val( & guid) as DWORD ,
108
+ & mut connectex as * mut _ as LPVOID ,
109
+ mem:: size_of_val( & connectex) as DWORD ,
110
+ & mut num_bytes as * mut _,
111
+ ptr:: null_mut( ) ,
112
+ None ,
113
+ ) ;
114
+
115
+ if ret != 0 {
116
+ let err = WSAGetLastError ( ) ;
117
+ let e = Error :: from_raw_os_error( err) ;
118
+
119
+ warn!( "Failed to get ConnectEx function from WSA extension, error: {}" , e) ;
120
+ }
121
+
122
+ let _ = closesocket( socket) ;
123
+
124
+ connectex
125
+ } ;
126
+ }
127
+
128
+ pub struct ConnectContext {
129
+ // Reference to the partial connected socket fd
130
+ // This struct doesn't own the HANDLE, so do not close it while dropping
131
+ socket : SOCKET ,
132
+
133
+ // Target address for calling `ConnectEx`
134
+ remote_addr : SocketAddr ,
135
+ }
136
+
137
+ impl ConnectContext {
138
+ /// Performing actual connect operation
139
+ pub fn connect_with_data ( self , buf : & [ u8 ] ) -> io:: Result < usize > {
140
+ unsafe {
141
+ // https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nc-mswsock-lpfn_connectex
142
+ let connect_ex = PFN_CONNECTEX_OPT . expect ( "LPFN_CONNECTEX function doesn't exists" ) ;
143
+ let ( saddr, saddr_len) = addr2raw ( & self . remote_addr ) ;
144
+
145
+ let mut overlapped: OVERLAPPED = mem:: zeroed ( ) ;
146
+
147
+ let mut bytes_sent: DWORD = 0 ;
148
+ let ret: BOOL = connect_ex (
149
+ self . socket ,
150
+ saddr,
151
+ saddr_len,
152
+ buf. as_ptr ( ) as PVOID ,
153
+ buf. len ( ) as DWORD ,
154
+ & mut bytes_sent as * mut _ as LPDWORD ,
155
+ & mut overlapped as * mut _ ,
156
+ ) ;
157
+
158
+ if ret == FALSE {
159
+ let mut bytes_sent: DWORD = 0 ;
160
+ let mut flags: DWORD = 0 ;
161
+
162
+ // FIXME: Blocking call.
163
+ let ret: BOOL = WSAGetOverlappedResult (
164
+ self . socket ,
165
+ & mut overlapped as * mut _ ,
166
+ & mut bytes_sent as LPDWORD ,
167
+ TRUE ,
168
+ & mut flags as LPDWORD ,
169
+ ) ;
170
+
171
+ if ret == TRUE {
172
+ Ok ( bytes_sent as usize )
173
+ } else {
174
+ let err = WSAGetLastError ( ) ;
175
+ Err ( Error :: from_raw_os_error ( err) )
176
+ }
177
+ } else {
178
+ // Connect succeeded
179
+ Ok ( bytes_sent as usize )
180
+ }
181
+ }
182
+ }
183
+ }
184
+
185
+ pub async fn connect_stream ( addr : & SocketAddr ) -> io:: Result < ( TcpStream , ConnectContext ) > {
65
186
let builder = match addr. ip ( ) {
66
187
IpAddr :: V4 ( ..) => TcpBuilder :: new_v4 ( ) ?,
67
188
IpAddr :: V6 ( ..) => TcpBuilder :: new_v6 ( ) ?,
@@ -113,21 +234,17 @@ pub async fn connect_stream(addr: &SocketAddr) -> io::Result<TcpStream> {
113
234
let err = WSAGetLastError ( ) ;
114
235
return Err ( Error :: from_raw_os_error ( err) ) ;
115
236
}
116
-
117
- // FIXME: MSDN suggests to use ConnectEx instead of connect
118
- // But it requires dynamic load from WSAIoctl and cache it in a global variable
119
- // That sucks.
120
-
121
- let ( saddr, saddr_len) = addr2raw ( addr) ;
122
- let ret = connect ( socket, saddr, saddr_len) ;
123
-
124
- if ret == SOCKET_ERROR {
125
- let err = WSAGetLastError ( ) ;
126
- return Err ( Error :: from_raw_os_error ( err) ) ;
127
- }
128
237
}
129
238
130
- TcpStream :: from_std ( stream)
239
+ TcpStream :: from_std ( stream) . map ( |s| {
240
+ (
241
+ s,
242
+ ConnectContext {
243
+ socket,
244
+ remote_addr : * addr,
245
+ } ,
246
+ )
247
+ } )
131
248
}
132
249
133
250
// Borrowed from net2
0 commit comments