@@ -10,6 +10,7 @@ use bytes::{BufMut, BytesMut};
10
10
use futures:: ready;
11
11
use log:: trace;
12
12
use once_cell:: sync:: Lazy ;
13
+ use pin_project:: pin_project;
13
14
use tokio:: {
14
15
io:: { AsyncRead , AsyncWrite , ReadBuf } ,
15
16
net:: TcpStream ,
@@ -26,10 +27,18 @@ use crate::{
26
27
} ,
27
28
} ;
28
29
30
+ enum ProxyClientStreamWriteState {
31
+ Connect ( Address ) ,
32
+ Connecting ( BytesMut ) ,
33
+ Connected ,
34
+ }
35
+
29
36
/// A stream for sending / receiving data stream from remote server via shadowsocks' proxy server
37
+ #[ pin_project]
30
38
pub struct ProxyClientStream < S > {
39
+ #[ pin]
31
40
stream : CryptoStream < S > ,
32
- addr : Option < Address > ,
41
+ state : ProxyClientStreamWriteState ,
33
42
context : SharedContext ,
34
43
}
35
44
@@ -140,7 +149,7 @@ where
140
149
141
150
ProxyClientStream {
142
151
stream,
143
- addr : Some ( addr) ,
152
+ state : ProxyClientStreamWriteState :: Connect ( addr) ,
144
153
context,
145
154
}
146
155
}
@@ -166,63 +175,85 @@ where
166
175
S : AsyncRead + AsyncWrite + Unpin ,
167
176
{
168
177
#[ inline]
169
- fn poll_read ( mut self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < io:: Result < ( ) > > {
170
- let context = unsafe { & * ( self . context . as_ref ( ) as * const _ ) } ;
171
- self . stream . poll_read_decrypted ( cx, context, buf)
178
+ fn poll_read ( self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < io:: Result < ( ) > > {
179
+ let mut this = self . project ( ) ;
180
+ this . stream . poll_read_decrypted ( cx, & this . context , buf)
172
181
}
173
182
}
174
183
175
184
impl < S > AsyncWrite for ProxyClientStream < S >
176
185
where
177
186
S : AsyncRead + AsyncWrite + Unpin ,
178
187
{
179
- fn poll_write ( mut self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize , io:: Error > > {
180
- match self . addr {
181
- None => {
182
- // For all subsequence calls, just proxy it to self.stream
183
- return self . stream . poll_write_encrypted ( cx, buf) ;
184
- }
185
- Some ( ref addr) => {
186
- let addr_length = addr. serialized_len ( ) ;
187
-
188
- let mut buffer = BytesMut :: with_capacity ( addr_length + buf. len ( ) ) ;
189
- addr. write_to_buf ( & mut buffer) ;
190
- buffer. put_slice ( buf) ;
191
-
192
- ready ! ( self . stream. poll_write_encrypted( cx, & buffer) ) ?;
193
-
194
- // fallthrough. take the self.addr out
188
+ fn poll_write ( self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize , io:: Error > > {
189
+ let mut this = self . project ( ) ;
190
+
191
+ loop {
192
+ match this. state {
193
+ ProxyClientStreamWriteState :: Connect ( ref addr) => {
194
+ // Target Address should be sent with the first packet together,
195
+ // which would prevent from being detected by connection features.
196
+
197
+ let addr_length = addr. serialized_len ( ) ;
198
+
199
+ let mut buffer = BytesMut :: with_capacity ( addr_length + buf. len ( ) ) ;
200
+ addr. write_to_buf ( & mut buffer) ;
201
+ buffer. put_slice ( buf) ;
202
+
203
+ // Save the concatenated buffer before it is written successfully.
204
+ // APIs require buffer to be kept alive before Poll::Ready
205
+ //
206
+ // Proactor APIs like IOCP on Windows, pointers of buffers have to be kept alive
207
+ // before IO completion.
208
+ * ( this. state ) = ProxyClientStreamWriteState :: Connecting ( buffer) ;
209
+ }
210
+ ProxyClientStreamWriteState :: Connecting ( ref buffer) => {
211
+ let n = ready ! ( this. stream. poll_write_encrypted( cx, & buffer) ) ?;
212
+
213
+ // In general, poll_write_encrypted should perform like write_all.
214
+ debug_assert ! ( n == buffer. len( ) ) ;
215
+
216
+ * ( this. state ) = ProxyClientStreamWriteState :: Connected ;
217
+
218
+ // NOTE:
219
+ // poll_write will return Ok(0) if buf.len() == 0
220
+ // But for the first call, this function will eventually send the handshake packet (IV/Salt + ADDR) to the remote address.
221
+ //
222
+ // https://github.com/shadowsocks/shadowsocks-rust/issues/232
223
+ //
224
+ // For protocols that requires *Server Hello* message, like FTP, clients won't send anything to the server until server sends handshake messages.
225
+ // This could be achieved by calling poll_write with an empty input buffer.
226
+ return Ok ( buf. len ( ) ) . into ( ) ;
227
+ }
228
+ ProxyClientStreamWriteState :: Connected => {
229
+ return this. stream . poll_write_encrypted ( cx, buf) ;
230
+ }
195
231
}
196
232
}
197
-
198
- let _ = self . addr . take ( ) ;
199
-
200
- // NOTE:
201
- // poll_write will return Ok(0) if buf.len() == 0
202
- // But for the first call, this function will eventually send the handshake packet (IV/Salt + ADDR) to the remote address.
203
- //
204
- // https://github.com/shadowsocks/shadowsocks-rust/issues/232
205
- //
206
- // For protocols that requires *Server Hello* message, like FTP, clients won't send anything to the server until server sends handshake messages.
207
- // This could be achieved by calling poll_write with an empty input buffer.
208
-
209
- Ok ( buf. len ( ) ) . into ( )
210
233
}
211
234
212
- fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
213
- self . stream . poll_flush ( cx)
235
+ #[ inline]
236
+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
237
+ self . project ( ) . stream . poll_flush ( cx)
214
238
}
215
239
216
- fn poll_shutdown ( mut self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
217
- self . stream . poll_shutdown ( cx)
240
+ #[ inline]
241
+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
242
+ self . project ( ) . stream . poll_shutdown ( cx)
218
243
}
219
244
}
220
245
221
246
impl < S > ProxyClientStream < S >
222
247
where
223
248
S : AsyncRead + AsyncWrite + Unpin ,
224
249
{
250
+ /// Splits into reader and writer halves
225
251
pub fn into_split ( self ) -> ( ProxyClientStreamReadHalf < S > , ProxyClientStreamWriteHalf < S > ) {
252
+ // Cannot split if stream is still pending
253
+ assert ! (
254
+ !matches!( self . state, ProxyClientStreamWriteState :: Connecting ( ..) ) ,
255
+ "stream is pending on writing the first packet"
256
+ ) ;
226
257
let ( reader, writer) = self . stream . into_split ( ) ;
227
258
(
228
259
ProxyClientStreamReadHalf {
@@ -231,13 +262,16 @@ where
231
262
} ,
232
263
ProxyClientStreamWriteHalf {
233
264
writer,
234
- addr : self . addr ,
265
+ state : self . state ,
235
266
} ,
236
267
)
237
268
}
238
269
}
239
270
271
+ /// Owned read half produced by `ProxyClientStream::into_split`
272
+ #[ pin_project]
240
273
pub struct ProxyClientStreamReadHalf < S > {
274
+ #[ pin]
241
275
reader : CryptoStreamReadHalf < S > ,
242
276
context : SharedContext ,
243
277
}
@@ -247,53 +281,78 @@ where
247
281
S : AsyncRead + Unpin ,
248
282
{
249
283
#[ inline]
250
- fn poll_read ( mut self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < io:: Result < ( ) > > {
251
- let context = unsafe { & * ( self . context . as_ref ( ) as * const _ ) } ;
252
- self . reader . poll_read_decrypted ( cx, context, buf)
284
+ fn poll_read ( self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < io:: Result < ( ) > > {
285
+ let mut this = self . project ( ) ;
286
+ this . reader . poll_read_decrypted ( cx, & this . context , buf)
253
287
}
254
288
}
255
289
290
+ /// Owned write half produced by `ProxyClientStream::into_split`
291
+ #[ pin_project]
256
292
pub struct ProxyClientStreamWriteHalf < S > {
293
+ #[ pin]
257
294
writer : CryptoStreamWriteHalf < S > ,
258
- addr : Option < Address > ,
295
+ state : ProxyClientStreamWriteState ,
259
296
}
260
297
261
298
impl < S > AsyncWrite for ProxyClientStreamWriteHalf < S >
262
299
where
263
300
S : AsyncWrite + Unpin ,
264
301
{
265
- fn poll_write ( mut self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize , io:: Error > > {
266
- if self . addr . is_none ( ) {
267
- // For all subsequence calls, just proxy it to self.writer
268
- return self . writer . poll_write_encrypted ( cx, buf) ;
302
+ fn poll_write ( self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize , io:: Error > > {
303
+ let mut this = self . project ( ) ;
304
+
305
+ loop {
306
+ match this. state {
307
+ ProxyClientStreamWriteState :: Connect ( ref addr) => {
308
+ // Target Address should be sent with the first packet together,
309
+ // which would prevent from being detected by connection features.
310
+
311
+ let addr_length = addr. serialized_len ( ) ;
312
+
313
+ let mut buffer = BytesMut :: with_capacity ( addr_length + buf. len ( ) ) ;
314
+ addr. write_to_buf ( & mut buffer) ;
315
+ buffer. put_slice ( buf) ;
316
+
317
+ // Save the concatenated buffer before it is written successfully.
318
+ // APIs require buffer to be kept alive before Poll::Ready
319
+ //
320
+ // Proactor APIs like IOCP on Windows, pointers of buffers have to be kept alive
321
+ // before IO completion.
322
+ * ( this. state ) = ProxyClientStreamWriteState :: Connecting ( buffer) ;
323
+ }
324
+ ProxyClientStreamWriteState :: Connecting ( ref buffer) => {
325
+ let n = ready ! ( this. writer. poll_write_encrypted( cx, & buffer) ) ?;
326
+
327
+ // In general, poll_write_encrypted should perform like write_all.
328
+ debug_assert ! ( n == buffer. len( ) ) ;
329
+
330
+ * ( this. state ) = ProxyClientStreamWriteState :: Connected ;
331
+
332
+ // NOTE:
333
+ // poll_write will return Ok(0) if buf.len() == 0
334
+ // But for the first call, this function will eventually send the handshake packet (IV/Salt + ADDR) to the remote address.
335
+ //
336
+ // https://github.com/shadowsocks/shadowsocks-rust/issues/232
337
+ //
338
+ // For protocols that requires *Server Hello* message, like FTP, clients won't send anything to the server until server sends handshake messages.
339
+ // This could be achieved by calling poll_write with an empty input buffer.
340
+ return Ok ( buf. len ( ) ) . into ( ) ;
341
+ }
342
+ ProxyClientStreamWriteState :: Connected => {
343
+ return this. writer . poll_write_encrypted ( cx, buf) ;
344
+ }
345
+ }
269
346
}
270
-
271
- let addr = self . addr . take ( ) . unwrap ( ) ;
272
- let addr_length = addr. serialized_len ( ) ;
273
-
274
- let mut buffer = BytesMut :: with_capacity ( addr_length + buf. len ( ) ) ;
275
- addr. write_to_buf ( & mut buffer) ;
276
- buffer. put_slice ( buf) ;
277
-
278
- ready ! ( self . writer. poll_write_encrypted( cx, & buffer) ) ?;
279
-
280
- // NOTE:
281
- // poll_write will return Ok(0) if buf.len() == 0
282
- // But for the first call, this function will eventually send the handshake packet (IV/Salt + ADDR) to the remote address.
283
- //
284
- // https://github.com/shadowsocks/shadowsocks-rust/issues/232
285
- //
286
- // For protocols that requires *Server Hello* message, like FTP, clients won't send anything to the server until server sends handshake messages.
287
- // This could be achieved by calling poll_write with an empty input buffer.
288
-
289
- Ok ( buf. len ( ) ) . into ( )
290
347
}
291
348
292
- fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
293
- self . writer . poll_flush ( cx)
349
+ #[ inline]
350
+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
351
+ self . project ( ) . writer . poll_flush ( cx)
294
352
}
295
353
296
- fn poll_shutdown ( mut self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
297
- self . writer . poll_shutdown ( cx)
354
+ #[ inline]
355
+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut task:: Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
356
+ self . project ( ) . writer . poll_shutdown ( cx)
298
357
}
299
358
}
0 commit comments