Skip to content

Commit a9fdfc5

Browse files
committed
[#184] Send IV/Salt with the first payload packet
TFO on macOS (seems) the second send call must wait until the first recv has called. Don't know why.
1 parent 353e7fc commit a9fdfc5

File tree

4 files changed

+68
-65
lines changed

4 files changed

+68
-65
lines changed

src/relay/tcprelay/aead.rs

+16-4
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use std::{
4343
};
4444

4545
use byteorder::{BigEndian, ByteOrder};
46-
use bytes::{BufMut, BytesMut};
46+
use bytes::{BufMut, Bytes, BytesMut};
4747
use futures::ready;
4848
use tokio::prelude::*;
4949

@@ -217,15 +217,17 @@ pub struct EncryptedWriter {
217217
cipher: BoxAeadEncryptor,
218218
tag_size: usize,
219219
steps: EncryptWriteStep,
220+
nonce_opt: Option<Bytes>,
220221
}
221222

222223
impl EncryptedWriter {
223224
/// Creates a new EncryptedWriter
224-
pub fn new(t: CipherType, key: &[u8], nonce: &[u8]) -> EncryptedWriter {
225+
pub fn new(t: CipherType, key: &[u8], nonce: Bytes) -> EncryptedWriter {
225226
EncryptedWriter {
226-
cipher: crypto::new_aead_encryptor(t, key, nonce),
227+
cipher: crypto::new_aead_encryptor(t, key, &nonce),
227228
tag_size: t.tag_size(),
228229
steps: EncryptWriteStep::Nothing,
230+
nonce_opt: Some(nonce),
229231
}
230232
}
231233

@@ -253,7 +255,17 @@ impl EncryptedWriter {
253255
let output_length = self.buffer_size(data);
254256
let data_length = data.len() as u16;
255257

256-
let mut buf = BytesMut::with_capacity(output_length);
258+
// First packet is IV
259+
let iv_len = match self.nonce_opt {
260+
Some(ref v) => v.len(),
261+
None => 0,
262+
};
263+
264+
let mut buf = BytesMut::with_capacity(iv_len + output_length);
265+
266+
if let Some(iv) = self.nonce_opt.take() {
267+
buf.extend(iv);
268+
}
257269

258270
let mut data_len_buf = [0u8; 2];
259271
BigEndian::write_u16(&mut data_len_buf, data_length);

src/relay/tcprelay/crypto_io.rs

+15-43
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use std::{
99
};
1010

1111
use byte_string::ByteStr;
12-
use bytes::Bytes;
1312
use futures::ready;
1413
use log::trace;
1514
use tokio::prelude::*;
@@ -36,18 +35,12 @@ enum ReadStatus {
3635
Established,
3736
}
3837

39-
enum WriteStatus {
40-
SendIv(Bytes, usize),
41-
Established,
42-
}
43-
4438
pub struct CryptoStream<S> {
4539
stream: S,
4640
dec: Option<DecryptedReader>,
47-
enc: Option<EncryptedWriter>,
41+
enc: EncryptedWriter,
4842
svr_cfg: Arc<ServerConfig>,
4943
read_status: ReadStatus,
50-
write_status: WriteStatus,
5144
}
5245

5346
impl<S: Unpin> Unpin for CryptoStream<S> {}
@@ -73,13 +66,24 @@ impl<S> CryptoStream<S> {
7366
}
7467
};
7568

69+
let method = svr_cfg.method();
70+
let enc = match method.category() {
71+
CipherCategory::Stream => {
72+
trace!("Sent Stream cipher IV {:?}", ByteStr::new(&local_iv));
73+
EncryptedWriter::Stream(StreamEncryptedWriter::new(method, svr_cfg.key(), local_iv))
74+
}
75+
CipherCategory::Aead => {
76+
trace!("Sent AEAD cipher salt {:?}", ByteStr::new(&local_iv));
77+
EncryptedWriter::Aead(AeadEncryptedWriter::new(method, svr_cfg.key(), local_iv))
78+
}
79+
};
80+
7681
CryptoStream {
7782
stream,
7883
dec: None,
79-
enc: None,
84+
enc,
8085
svr_cfg,
8186
read_status: ReadStatus::WaitIv(vec![0u8; prev_len], 0usize),
82-
write_status: WriteStatus::SendIv(local_iv, 0usize),
8387
}
8488
}
8589
}
@@ -133,41 +137,9 @@ impl<S> CryptoStream<S>
133137
where
134138
S: AsyncWrite + Unpin,
135139
{
136-
fn poll_write_handshake(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
137-
if let WriteStatus::SendIv(ref iv, ref mut pos) = self.write_status {
138-
while *pos < iv.len() {
139-
let n = ready!(Pin::new(&mut self.stream).poll_write(cx, &iv[*pos..]))?;
140-
if n == 0 {
141-
use std::io::ErrorKind;
142-
return Poll::Ready(Err(ErrorKind::UnexpectedEof.into()));
143-
}
144-
*pos += n;
145-
}
146-
147-
let method = self.svr_cfg.method();
148-
let enc = match method.category() {
149-
CipherCategory::Stream => {
150-
trace!("Sent Stream cipher IV {:?}", ByteStr::new(&iv));
151-
EncryptedWriter::Stream(StreamEncryptedWriter::new(method, self.svr_cfg.key(), &iv))
152-
}
153-
CipherCategory::Aead => {
154-
trace!("Sent AEAD cipher salt {:?}", ByteStr::new(&iv));
155-
EncryptedWriter::Aead(AeadEncryptedWriter::new(method, self.svr_cfg.key(), &iv))
156-
}
157-
};
158-
159-
self.enc = Some(enc);
160-
self.write_status = WriteStatus::Established;
161-
}
162-
163-
Poll::Ready(Ok(()))
164-
}
165-
166140
fn priv_poll_write(mut self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
167-
ready!(self.poll_write_handshake(ctx))?;
168-
169141
let stream = unsafe { &mut *(&mut self.stream as *mut _) };
170-
match *self.enc.as_mut().unwrap() {
142+
match self.enc {
171143
EncryptedWriter::Aead(ref mut w) => w.poll_write_encrypted(ctx, stream, buf),
172144
EncryptedWriter::Stream(ref mut w) => w.poll_write_encrypted(ctx, stream, buf),
173145
}

src/relay/tcprelay/stream.rs

+15-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::{
99
};
1010

1111
use crate::crypto::{new_stream, BoxStreamCipher, CipherType, CryptoMode};
12-
use bytes::{BufMut, BytesMut};
12+
use bytes::{BufMut, Bytes, BytesMut};
1313
use futures::ready;
1414
use tokio::prelude::*;
1515

@@ -94,14 +94,16 @@ enum EncryptWriteStep {
9494
pub struct EncryptedWriter {
9595
cipher: BoxStreamCipher,
9696
steps: EncryptWriteStep,
97+
iv_opt: Option<Bytes>,
9798
}
9899

99100
impl EncryptedWriter {
100101
/// Creates a new EncryptedWriter
101-
pub fn new(t: CipherType, key: &[u8], iv: &[u8]) -> EncryptedWriter {
102+
pub fn new(t: CipherType, key: &[u8], iv: Bytes) -> EncryptedWriter {
102103
EncryptedWriter {
103-
cipher: new_stream(t, key, iv, CryptoMode::Encrypt),
104+
cipher: new_stream(t, key, &iv, CryptoMode::Encrypt),
104105
steps: EncryptWriteStep::Nothing,
106+
iv_opt: Some(iv),
105107
}
106108
}
107109

@@ -122,7 +124,16 @@ impl EncryptedWriter {
122124
loop {
123125
match self.steps {
124126
EncryptWriteStep::Nothing => {
125-
let mut buf = BytesMut::with_capacity(self.buffer_size(data));
127+
let iv_len = match self.iv_opt {
128+
Some(ref iv) => iv.len(),
129+
None => 0,
130+
};
131+
132+
let mut buf = BytesMut::with_capacity(iv_len + self.buffer_size(data));
133+
if let Some(iv) = self.iv_opt.take() {
134+
buf.extend(iv);
135+
}
136+
126137
self.cipher_update(data, &mut buf)?;
127138

128139
self.steps = EncryptWriteStep::Writing(buf, 0);

src/relay/tcprelay/utils/split.rs

+22-14
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@ pub struct ReadHalf<'a> {
1919
}
2020

2121
impl<'a> ReadHalf<'a> {
22-
fn stream(&self) -> &'a mut TcpStream {
22+
fn stream(&self) -> &'a TcpStream {
23+
unsafe { &mut *self.stream }
24+
}
25+
26+
fn stream_mut(&mut self) -> &'a mut TcpStream {
2327
unsafe { &mut *self.stream }
2428
}
2529
}
@@ -35,16 +39,20 @@ impl AsyncRead for ReadHalf<'_> {
3539
self.stream().prepare_uninitialized_buffer(buf)
3640
}
3741

38-
fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8]) -> task::Poll<io::Result<usize>> {
39-
Pin::new(self.stream()).poll_read(cx, buf)
42+
fn poll_read(
43+
mut self: Pin<&mut Self>,
44+
cx: &mut task::Context<'_>,
45+
buf: &mut [u8],
46+
) -> task::Poll<io::Result<usize>> {
47+
Pin::new(self.stream_mut()).poll_read(cx, buf)
4048
}
4149

4250
fn poll_read_buf<B: BufMut>(
43-
self: Pin<&mut Self>,
51+
mut self: Pin<&mut Self>,
4452
cx: &mut task::Context<'_>,
4553
buf: &mut B,
4654
) -> task::Poll<io::Result<usize>> {
47-
Pin::new(self.stream()).poll_read_buf(cx, buf)
55+
Pin::new(self.stream_mut()).poll_read_buf(cx, buf)
4856
}
4957
}
5058

@@ -61,7 +69,7 @@ pub struct WriteHalf<'a> {
6169
}
6270

6371
impl<'a> WriteHalf<'a> {
64-
fn stream(&self) -> &'a mut TcpStream {
72+
fn stream_mut(&mut self) -> &'a mut TcpStream {
6573
unsafe { &mut *self.stream }
6674
}
6775
}
@@ -74,27 +82,27 @@ impl AsRef<TcpStream> for WriteHalf<'_> {
7482

7583
impl AsyncWrite for WriteHalf<'_> {
7684
fn poll_write(
77-
self: Pin<&mut Self>,
85+
mut self: Pin<&mut Self>,
7886
cx: &mut task::Context<'_>,
7987
buf: &[u8],
8088
) -> task::Poll<Result<usize, io::Error>> {
81-
Pin::new(self.stream()).poll_write(cx, buf)
89+
Pin::new(self.stream_mut()).poll_write(cx, buf)
8290
}
8391

84-
fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Result<(), io::Error>> {
85-
Pin::new(self.stream()).poll_flush(cx)
92+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Result<(), io::Error>> {
93+
Pin::new(self.stream_mut()).poll_flush(cx)
8694
}
8795

88-
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Result<(), io::Error>> {
89-
Pin::new(self.stream()).poll_shutdown(cx)
96+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Result<(), io::Error>> {
97+
Pin::new(self.stream_mut()).poll_shutdown(cx)
9098
}
9199

92100
fn poll_write_buf<B: Buf>(
93-
self: Pin<&mut Self>,
101+
mut self: Pin<&mut Self>,
94102
cx: &mut task::Context<'_>,
95103
buf: &mut B,
96104
) -> task::Poll<Result<usize, io::Error>> {
97-
Pin::new(self.stream()).poll_write_buf(cx, buf)
105+
Pin::new(self.stream_mut()).poll_write_buf(cx, buf)
98106
}
99107
}
100108

0 commit comments

Comments
 (0)