Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions wstunnel/src/protocols/dns/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,32 @@ use tokio_rustls::rustls::client::EchConfig;

// Interleave v4 and v6 addresses as per RFC8305.
// The first address is v6 if we have any v6 addresses.
// Optimized to use a single-pass partition instead of multiple iterator chains.
#[inline]
fn sort_socket_addrs(socket_addrs: &[SocketAddr], prefer_ipv6: bool) -> impl Iterator<Item = &'_ SocketAddr> {
fn sort_socket_addrs(socket_addrs: &[SocketAddr], prefer_ipv6: bool) -> Vec<SocketAddr> {
let (v6_addrs, v4_addrs): (Vec<SocketAddr>, Vec<SocketAddr>) = socket_addrs.iter()
.partition(|s| matches!(s, SocketAddr::V6(_)));

let mut result = Vec::with_capacity(socket_addrs.len());
let mut v6_iter = v6_addrs.into_iter();
let mut v4_iter = v4_addrs.into_iter();
let mut pick_v6 = !prefer_ipv6;
let mut v6 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V6(_)));
let mut v4 = socket_addrs.iter().filter(|s| matches!(s, SocketAddr::V4(_)));
std::iter::from_fn(move || {

loop {
pick_v6 = !pick_v6;
if pick_v6 {
v6.next().or_else(|| v4.next())
let addr = if pick_v6 {
v6_iter.next().or_else(|| v4_iter.next())
} else {
v4.next().or_else(|| v6.next())
v4_iter.next().or_else(|| v6_iter.next())
};

match addr {
Some(addr) => result.push(addr),
None => break,
}
})
}

result
}

#[allow(clippy::large_enum_variant)] // System variant never used mostly
Expand All @@ -63,7 +76,7 @@ impl DnsResolver {
IpAddr::V6(ip) => SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)),
})
.collect();
sort_socket_addrs(&addrs, *prefer_ipv6).copied().collect()
sort_socket_addrs(&addrs, *prefer_ipv6)
}
};

Expand Down Expand Up @@ -332,7 +345,7 @@ mod tests {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 2), 1)),
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 3), 1)),
];
let actual: Vec<_> = sort_socket_addrs(&addrs, true).copied().collect();
assert_eq!(expected, *actual);
let actual = sort_socket_addrs(&addrs, true);
assert_eq!(expected.to_vec(), actual);
}
}
7 changes: 4 additions & 3 deletions wstunnel/src/protocols/udp/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,18 @@ impl UdpServer {

#[inline]
pub fn clean_dead_keys(&mut self) {
// Fast path: check if there are keys to delete without acquiring write lock
let nb_key_to_delete = self.keys_to_delete.read().len();
if nb_key_to_delete == 0 {
return;
}

debug!("Cleaning {} dead udp peers", nb_key_to_delete);
// Use drain to avoid separate iter + clear operations
let mut keys_to_delete = self.keys_to_delete.write();
for key in keys_to_delete.iter() {
self.peers.remove(key);
for key in keys_to_delete.drain(..) {
self.peers.remove(&key);
}
keys_to_delete.clear();
}
pub fn clone_socket(&self) -> Arc<UdpSocket> {
self.listener.clone()
Expand Down
4 changes: 3 additions & 1 deletion wstunnel/src/tunnel/transport/http2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ impl TunnelWrite for Http2TunnelWrite {
}

fn pending_operations_notify(&mut self) -> Arc<Notify> {
Arc::new(Notify::new())
// HTTP2 doesn't use pending operations, so return a static Arc to avoid allocation
static DUMMY_NOTIFY: std::sync::LazyLock<Arc<Notify>> = std::sync::LazyLock::new(|| Arc::new(Notify::new()));
Arc::clone(&DUMMY_NOTIFY)
}

fn handle_pending_operations(&mut self) -> impl Future<Output = Result<(), io::Error>> + Send {
Expand Down
11 changes: 5 additions & 6 deletions wstunnel/src/tunnel/transport/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,13 @@ impl TunnelWrite for WebsocketTunnelWrite {
const _32_MB: usize = 32 * 1024 * 1024;
buf.clear();
if buf.capacity() == read_len && buf.capacity() < _32_MB {
let new_size = buf.capacity() + (buf.capacity() / 4); // grow buffer by 1.25 %
buf.reserve(new_size);
// Grow buffer by 25% (capacity / 4 additional bytes)
let additional_capacity = buf.capacity() / 4;
buf.reserve(additional_capacity);
trace!(
"Buffer {} Mb {} {} {}",
"Buffer grown to {} Mb (added {} bytes)",
buf.capacity() as f64 / 1024.0 / 1024.0,
new_size,
buf.len(),
buf.capacity()
additional_capacity
)
}

Expand Down