Skip to content

Commit db5a2ca

Browse files
committed
[#184] Uses blocking ConnectEx on Windows (FIXME)
1 parent b35758c commit db5a2ca

File tree

4 files changed

+151
-27
lines changed

4 files changed

+151
-27
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ cfg-if = "0.1"
7979

8080
[target.'cfg(windows)'.dependencies]
8181
winapi = { version = "0.3", features = ["mswsock", "winsock2"] }
82+
lazy_static = "1.4"
8283

8384
# [patch.crates-io]
8485
# libc = { git = "https://github.com/zonyitoo/libc.git", branch = "feature-linux-fastopen-connect", optional = true }

build/build-release

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,5 @@ function build() {
5757
echo "* Done build package ${PKG_NAME}"
5858
}
5959

60-
build "x86_64-unknown-linux-musl"
61-
#build "x86_64-pc-windows-gnu"
60+
#build "x86_64-unknown-linux-musl"
61+
build "x86_64-pc-windows-gnu"

src/relay/tcprelay/utils/tfo/bsd.rs

+13-7
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
use std::{
44
io::{self, Error},
55
mem,
6-
net::{self, SocketAddr},
6+
net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream},
77
os::unix::io::AsRawFd,
88
};
99

1010
use libc;
1111
use log::error;
12-
use tokio::net::{TcpListener, TcpStream};
12+
use tokio::net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream};
1313

1414
fn create_socket(domain: libc::c_int) -> io::Result<libc::c_int> {
1515
unsafe {
@@ -87,7 +87,7 @@ pub async fn bind_listener(addr: &SocketAddr) -> io::Result<TcpListener> {
8787
return Err(Error::last_os_error());
8888
}
8989

90-
TcpListener::from_std(net::TcpListener::from_raw_fd(sockfd))
90+
TcpListener::from_std(StdTcpListener::from_raw_fd(sockfd))
9191
}
9292
}
9393

@@ -124,7 +124,7 @@ impl ConnectContext {
124124
}
125125
}
126126

127-
pub async fn connect_stream(addr: &SocketAddr) -> io::Result<TcpStream> {
127+
pub async fn connect_stream(addr: &SocketAddr) -> io::Result<(TcpStream, ConnectContext)> {
128128
let domain = match addr {
129129
SocketAddr::V4(..) => libc::AF_INET,
130130
SocketAddr::V6(..) => libc::AF_INET6,
@@ -176,10 +176,16 @@ pub async fn connect_stream(addr: &SocketAddr) -> io::Result<TcpStream> {
176176
return Err(Error::last_os_error());
177177
}
178178

179-
TcpStream::from_std(net::TcpStream::from_raw_fd(sockfd))
179+
TcpStream::from_std(StdTcpStream::from_raw_fd(sockfd)).map(|s| {
180+
(
181+
s,
182+
ConnectContext {
183+
socket: sockfd,
184+
remote_addr: *addr,
185+
},
186+
)
187+
})
180188
}
181-
182-
TcpStream::from_std(stream)
183189
}
184190

185191
// Borrowed from net2

src/relay/tcprelay/utils/tfo/windows.rs

+135-18
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,45 @@ use std::{
55
mem,
66
net::{self, IpAddr, SocketAddr},
77
os::windows::io::AsRawSocket,
8+
ptr,
89
};
910

10-
use log::error;
11+
use lazy_static::lazy_static;
12+
use log::{error, warn};
1113
use net2::TcpBuilder;
1214
use tokio::net::{TcpListener, TcpStream};
1315
use winapi::{
1416
ctypes::{c_char, c_int},
1517
shared::{
16-
minwindef::DWORD,
17-
ws2def::{ADDRESS_FAMILY, AF_INET, AF_INET6, IPPROTO_TCP, SOCKADDR, SOCKADDR_IN},
18+
minwindef::{BOOL, DWORD, FALSE, LPDWORD, LPVOID, TRUE},
19+
ws2def::{
20+
ADDRESS_FAMILY,
21+
AF_INET,
22+
AF_INET6,
23+
IPPROTO_TCP,
24+
SIO_GET_EXTENSION_FUNCTION_POINTER,
25+
SOCKADDR,
26+
SOCKADDR_IN,
27+
},
28+
},
29+
um::{
30+
minwinbase::OVERLAPPED,
31+
mswsock::{LPFN_CONNECTEX, WSAID_CONNECTEX},
32+
winnt::PVOID,
33+
winsock2::{
34+
bind,
35+
closesocket,
36+
setsockopt,
37+
socket,
38+
WSAGetLastError,
39+
WSAGetOverlappedResult,
40+
WSAIoctl,
41+
INVALID_SOCKET,
42+
SOCKET,
43+
SOCKET_ERROR,
44+
SOCK_STREAM,
45+
},
1846
},
19-
um::winsock2::{bind, connect, setsockopt, WSAGetLastError, SOCKET, SOCKET_ERROR},
2047
};
2148

2249
// ws2ipdef.h
@@ -61,7 +88,101 @@ pub async fn bind_listener(addr: &SocketAddr) -> io::Result<TcpListener> {
6188
TcpListener::from_std(listener)
6289
}
6390

64-
pub async fn connect_stream(addr: &SocketAddr) -> io::Result<TcpStream> {
91+
lazy_static! {
92+
static ref PFN_CONNECTEX_OPT: LPFN_CONNECTEX = unsafe {
93+
let socket = socket(AF_INET, SOCK_STREAM, 0);
94+
if socket == INVALID_SOCKET {
95+
return None;
96+
}
97+
98+
let mut guid = WSAID_CONNECTEX;
99+
let mut num_bytes: DWORD = 0;
100+
101+
let mut connectex: LPFN_CONNECTEX = None;
102+
103+
let ret = WSAIoctl(
104+
socket,
105+
SIO_GET_EXTENSION_FUNCTION_POINTER,
106+
&mut guid as *mut _ as LPVOID,
107+
mem::size_of_val(&guid) as DWORD,
108+
&mut connectex as *mut _ as LPVOID,
109+
mem::size_of_val(&connectex) as DWORD,
110+
&mut num_bytes as *mut _,
111+
ptr::null_mut(),
112+
None,
113+
);
114+
115+
if ret != 0 {
116+
let err = WSAGetLastError();
117+
let e = Error::from_raw_os_error(err);
118+
119+
warn!("Failed to get ConnectEx function from WSA extension, error: {}", e);
120+
}
121+
122+
let _ = closesocket(socket);
123+
124+
connectex
125+
};
126+
}
127+
128+
pub struct ConnectContext {
129+
// Reference to the partial connected socket fd
130+
// This struct doesn't own the HANDLE, so do not close it while dropping
131+
socket: SOCKET,
132+
133+
// Target address for calling `ConnectEx`
134+
remote_addr: SocketAddr,
135+
}
136+
137+
impl ConnectContext {
138+
/// Performing actual connect operation
139+
pub fn connect_with_data(self, buf: &[u8]) -> io::Result<usize> {
140+
unsafe {
141+
// https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nc-mswsock-lpfn_connectex
142+
let connect_ex = PFN_CONNECTEX_OPT.expect("LPFN_CONNECTEX function doesn't exists");
143+
let (saddr, saddr_len) = addr2raw(&self.remote_addr);
144+
145+
let mut overlapped: OVERLAPPED = mem::zeroed();
146+
147+
let mut bytes_sent: DWORD = 0;
148+
let ret: BOOL = connect_ex(
149+
self.socket,
150+
saddr,
151+
saddr_len,
152+
buf.as_ptr() as PVOID,
153+
buf.len() as DWORD,
154+
&mut bytes_sent as *mut _ as LPDWORD,
155+
&mut overlapped as *mut _,
156+
);
157+
158+
if ret == FALSE {
159+
let mut bytes_sent: DWORD = 0;
160+
let mut flags: DWORD = 0;
161+
162+
// FIXME: Blocking call.
163+
let ret: BOOL = WSAGetOverlappedResult(
164+
self.socket,
165+
&mut overlapped as *mut _,
166+
&mut bytes_sent as LPDWORD,
167+
TRUE,
168+
&mut flags as LPDWORD,
169+
);
170+
171+
if ret == TRUE {
172+
Ok(bytes_sent as usize)
173+
} else {
174+
let err = WSAGetLastError();
175+
Err(Error::from_raw_os_error(err))
176+
}
177+
} else {
178+
// Connect succeeded
179+
Ok(bytes_sent as usize)
180+
}
181+
}
182+
}
183+
}
184+
185+
pub async fn connect_stream(addr: &SocketAddr) -> io::Result<(TcpStream, ConnectContext)> {
65186
let builder = match addr.ip() {
66187
IpAddr::V4(..) => TcpBuilder::new_v4()?,
67188
IpAddr::V6(..) => TcpBuilder::new_v6()?,
@@ -113,21 +234,17 @@ pub async fn connect_stream(addr: &SocketAddr) -> io::Result<TcpStream> {
113234
let err = WSAGetLastError();
114235
return Err(Error::from_raw_os_error(err));
115236
}
116-
117-
// FIXME: MSDN suggests to use ConnectEx instead of connect
118-
// But it requires dynamic load from WSAIoctl and cache it in a global variable
119-
// That sucks.
120-
121-
let (saddr, saddr_len) = addr2raw(addr);
122-
let ret = connect(socket, saddr, saddr_len);
123-
124-
if ret == SOCKET_ERROR {
125-
let err = WSAGetLastError();
126-
return Err(Error::from_raw_os_error(err));
127-
}
128237
}
129238

130-
TcpStream::from_std(stream)
239+
TcpStream::from_std(stream).map(|s| {
240+
(
241+
s,
242+
ConnectContext {
243+
socket,
244+
remote_addr: *addr,
245+
},
246+
)
247+
})
131248
}
132249

133250
// Borrowed from net2

0 commit comments

Comments
 (0)