Skip to content

Commit f805b5d

Browse files
committed
Fixed bug, concatenated buffer kept available before first write successfully
1 parent 1746c62 commit f805b5d

File tree

8 files changed

+197
-99
lines changed

8 files changed

+197
-99
lines changed

Cargo.lock

+3-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "shadowsocks-rust"
3-
version = "1.10.4"
3+
version = "1.10.5"
44
authors = ["Shadowsocks Contributors"]
55
description = "shadowsocks is a fast tunnel proxy that helps you bypass firewalls."
66
repository = "https://github.com/shadowsocks/shadowsocks-rust"
@@ -122,7 +122,7 @@ mimalloc = { version = "0.1", optional = true }
122122
tcmalloc = { version = "0.3", optional = true }
123123
jemallocator = { version = "0.3", optional = true }
124124

125-
shadowsocks-service = { version = "1.10.3", path = "./crates/shadowsocks-service" }
125+
shadowsocks-service = { version = "1.10.4", path = "./crates/shadowsocks-service" }
126126

127127
[target.'cfg(unix)'.dependencies]
128128
daemonize = "0.4"

crates/shadowsocks-service/Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "shadowsocks-service"
3-
version = "1.10.3"
3+
version = "1.10.4"
44
authors = ["Shadowsocks Contributors"]
55
description = "shadowsocks is a fast tunnel proxy that helps you bypass firewalls."
66
repository = "https://github.com/shadowsocks/shadowsocks-rust"
@@ -105,7 +105,7 @@ regex = "1.4"
105105
serde = { version = "1.0", features = ["derive"] }
106106
json5 = "0.3"
107107

108-
shadowsocks = { version = "1.10.1", path = "../shadowsocks" }
108+
shadowsocks = { version = "1.10.2", path = "../shadowsocks" }
109109

110110
strum = { version = "0.20", optional = true }
111111
strum_macros = { version = "0.20", optional = true }

crates/shadowsocks/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "shadowsocks"
3-
version = "1.10.1"
3+
version = "1.10.2"
44
authors = ["Shadowsocks Contributors"]
55
description = "shadowsocks is a fast tunnel proxy that helps you bypass firewalls."
66
repository = "https://github.com/shadowsocks/shadowsocks-rust"

crates/shadowsocks/src/relay/tcprelay/proxy_stream/client.rs

+130-71
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use bytes::{BufMut, BytesMut};
1010
use futures::ready;
1111
use log::trace;
1212
use once_cell::sync::Lazy;
13+
use pin_project::pin_project;
1314
use tokio::{
1415
io::{AsyncRead, AsyncWrite, ReadBuf},
1516
net::TcpStream,
@@ -26,10 +27,18 @@ use crate::{
2627
},
2728
};
2829

30+
enum ProxyClientStreamWriteState {
31+
Connect(Address),
32+
Connecting(BytesMut),
33+
Connected,
34+
}
35+
2936
/// A stream for sending / receiving data stream from remote server via shadowsocks' proxy server
37+
#[pin_project]
3038
pub struct ProxyClientStream<S> {
39+
#[pin]
3140
stream: CryptoStream<S>,
32-
addr: Option<Address>,
41+
state: ProxyClientStreamWriteState,
3342
context: SharedContext,
3443
}
3544

@@ -140,7 +149,7 @@ where
140149

141150
ProxyClientStream {
142151
stream,
143-
addr: Some(addr),
152+
state: ProxyClientStreamWriteState::Connect(addr),
144153
context,
145154
}
146155
}
@@ -166,63 +175,85 @@ where
166175
S: AsyncRead + AsyncWrite + Unpin,
167176
{
168177
#[inline]
169-
fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
170-
let context = unsafe { &*(self.context.as_ref() as *const _) };
171-
self.stream.poll_read_decrypted(cx, context, buf)
178+
fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
179+
let mut this = self.project();
180+
this.stream.poll_read_decrypted(cx, &this.context, buf)
172181
}
173182
}
174183

175184
impl<S> AsyncWrite for ProxyClientStream<S>
176185
where
177186
S: AsyncRead + AsyncWrite + Unpin,
178187
{
179-
fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
180-
match self.addr {
181-
None => {
182-
// For all subsequence calls, just proxy it to self.stream
183-
return self.stream.poll_write_encrypted(cx, buf);
184-
}
185-
Some(ref addr) => {
186-
let addr_length = addr.serialized_len();
187-
188-
let mut buffer = BytesMut::with_capacity(addr_length + buf.len());
189-
addr.write_to_buf(&mut buffer);
190-
buffer.put_slice(buf);
191-
192-
ready!(self.stream.poll_write_encrypted(cx, &buffer))?;
193-
194-
// fallthrough. take the self.addr out
188+
fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
189+
let mut this = self.project();
190+
191+
loop {
192+
match this.state {
193+
ProxyClientStreamWriteState::Connect(ref addr) => {
194+
// Target Address should be sent with the first packet together,
195+
// which would prevent from being detected by connection features.
196+
197+
let addr_length = addr.serialized_len();
198+
199+
let mut buffer = BytesMut::with_capacity(addr_length + buf.len());
200+
addr.write_to_buf(&mut buffer);
201+
buffer.put_slice(buf);
202+
203+
// Save the concatenated buffer before it is written successfully.
204+
// APIs require buffer to be kept alive before Poll::Ready
205+
//
206+
// Proactor APIs like IOCP on Windows, pointers of buffers have to be kept alive
207+
// before IO completion.
208+
*(this.state) = ProxyClientStreamWriteState::Connecting(buffer);
209+
}
210+
ProxyClientStreamWriteState::Connecting(ref buffer) => {
211+
let n = ready!(this.stream.poll_write_encrypted(cx, &buffer))?;
212+
213+
// In general, poll_write_encrypted should perform like write_all.
214+
debug_assert!(n == buffer.len());
215+
216+
*(this.state) = ProxyClientStreamWriteState::Connected;
217+
218+
// NOTE:
219+
// poll_write will return Ok(0) if buf.len() == 0
220+
// But for the first call, this function will eventually send the handshake packet (IV/Salt + ADDR) to the remote address.
221+
//
222+
// https://github.com/shadowsocks/shadowsocks-rust/issues/232
223+
//
224+
// For protocols that requires *Server Hello* message, like FTP, clients won't send anything to the server until server sends handshake messages.
225+
// This could be achieved by calling poll_write with an empty input buffer.
226+
return Ok(buf.len()).into();
227+
}
228+
ProxyClientStreamWriteState::Connected => {
229+
return this.stream.poll_write_encrypted(cx, buf);
230+
}
195231
}
196232
}
197-
198-
let _ = self.addr.take();
199-
200-
// NOTE:
201-
// poll_write will return Ok(0) if buf.len() == 0
202-
// But for the first call, this function will eventually send the handshake packet (IV/Salt + ADDR) to the remote address.
203-
//
204-
// https://github.com/shadowsocks/shadowsocks-rust/issues/232
205-
//
206-
// For protocols that requires *Server Hello* message, like FTP, clients won't send anything to the server until server sends handshake messages.
207-
// This could be achieved by calling poll_write with an empty input buffer.
208-
209-
Ok(buf.len()).into()
210233
}
211234

212-
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
213-
self.stream.poll_flush(cx)
235+
#[inline]
236+
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
237+
self.project().stream.poll_flush(cx)
214238
}
215239

216-
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
217-
self.stream.poll_shutdown(cx)
240+
#[inline]
241+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
242+
self.project().stream.poll_shutdown(cx)
218243
}
219244
}
220245

221246
impl<S> ProxyClientStream<S>
222247
where
223248
S: AsyncRead + AsyncWrite + Unpin,
224249
{
250+
/// Splits into reader and writer halves
225251
pub fn into_split(self) -> (ProxyClientStreamReadHalf<S>, ProxyClientStreamWriteHalf<S>) {
252+
// Cannot split if stream is still pending
253+
assert!(
254+
!matches!(self.state, ProxyClientStreamWriteState::Connecting(..)),
255+
"stream is pending on writing the first packet"
256+
);
226257
let (reader, writer) = self.stream.into_split();
227258
(
228259
ProxyClientStreamReadHalf {
@@ -231,13 +262,16 @@ where
231262
},
232263
ProxyClientStreamWriteHalf {
233264
writer,
234-
addr: self.addr,
265+
state: self.state,
235266
},
236267
)
237268
}
238269
}
239270

271+
/// Owned read half produced by `ProxyClientStream::into_split`
272+
#[pin_project]
240273
pub struct ProxyClientStreamReadHalf<S> {
274+
#[pin]
241275
reader: CryptoStreamReadHalf<S>,
242276
context: SharedContext,
243277
}
@@ -247,53 +281,78 @@ where
247281
S: AsyncRead + Unpin,
248282
{
249283
#[inline]
250-
fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
251-
let context = unsafe { &*(self.context.as_ref() as *const _) };
252-
self.reader.poll_read_decrypted(cx, context, buf)
284+
fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
285+
let mut this = self.project();
286+
this.reader.poll_read_decrypted(cx, &this.context, buf)
253287
}
254288
}
255289

290+
/// Owned write half produced by `ProxyClientStream::into_split`
291+
#[pin_project]
256292
pub struct ProxyClientStreamWriteHalf<S> {
293+
#[pin]
257294
writer: CryptoStreamWriteHalf<S>,
258-
addr: Option<Address>,
295+
state: ProxyClientStreamWriteState,
259296
}
260297

261298
impl<S> AsyncWrite for ProxyClientStreamWriteHalf<S>
262299
where
263300
S: AsyncWrite + Unpin,
264301
{
265-
fn poll_write(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
266-
if self.addr.is_none() {
267-
// For all subsequence calls, just proxy it to self.writer
268-
return self.writer.poll_write_encrypted(cx, buf);
302+
fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
303+
let mut this = self.project();
304+
305+
loop {
306+
match this.state {
307+
ProxyClientStreamWriteState::Connect(ref addr) => {
308+
// Target Address should be sent with the first packet together,
309+
// which would prevent from being detected by connection features.
310+
311+
let addr_length = addr.serialized_len();
312+
313+
let mut buffer = BytesMut::with_capacity(addr_length + buf.len());
314+
addr.write_to_buf(&mut buffer);
315+
buffer.put_slice(buf);
316+
317+
// Save the concatenated buffer before it is written successfully.
318+
// APIs require buffer to be kept alive before Poll::Ready
319+
//
320+
// Proactor APIs like IOCP on Windows, pointers of buffers have to be kept alive
321+
// before IO completion.
322+
*(this.state) = ProxyClientStreamWriteState::Connecting(buffer);
323+
}
324+
ProxyClientStreamWriteState::Connecting(ref buffer) => {
325+
let n = ready!(this.writer.poll_write_encrypted(cx, &buffer))?;
326+
327+
// In general, poll_write_encrypted should perform like write_all.
328+
debug_assert!(n == buffer.len());
329+
330+
*(this.state) = ProxyClientStreamWriteState::Connected;
331+
332+
// NOTE:
333+
// poll_write will return Ok(0) if buf.len() == 0
334+
// But for the first call, this function will eventually send the handshake packet (IV/Salt + ADDR) to the remote address.
335+
//
336+
// https://github.com/shadowsocks/shadowsocks-rust/issues/232
337+
//
338+
// For protocols that requires *Server Hello* message, like FTP, clients won't send anything to the server until server sends handshake messages.
339+
// This could be achieved by calling poll_write with an empty input buffer.
340+
return Ok(buf.len()).into();
341+
}
342+
ProxyClientStreamWriteState::Connected => {
343+
return this.writer.poll_write_encrypted(cx, buf);
344+
}
345+
}
269346
}
270-
271-
let addr = self.addr.take().unwrap();
272-
let addr_length = addr.serialized_len();
273-
274-
let mut buffer = BytesMut::with_capacity(addr_length + buf.len());
275-
addr.write_to_buf(&mut buffer);
276-
buffer.put_slice(buf);
277-
278-
ready!(self.writer.poll_write_encrypted(cx, &buffer))?;
279-
280-
// NOTE:
281-
// poll_write will return Ok(0) if buf.len() == 0
282-
// But for the first call, this function will eventually send the handshake packet (IV/Salt + ADDR) to the remote address.
283-
//
284-
// https://github.com/shadowsocks/shadowsocks-rust/issues/232
285-
//
286-
// For protocols that requires *Server Hello* message, like FTP, clients won't send anything to the server until server sends handshake messages.
287-
// This could be achieved by calling poll_write with an empty input buffer.
288-
289-
Ok(buf.len()).into()
290347
}
291348

292-
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
293-
self.writer.poll_flush(cx)
349+
#[inline]
350+
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
351+
self.project().writer.poll_flush(cx)
294352
}
295353

296-
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
297-
self.writer.poll_shutdown(cx)
354+
#[inline]
355+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), io::Error>> {
356+
self.project().writer.poll_shutdown(cx)
298357
}
299358
}

0 commit comments

Comments
 (0)