Skip to content

Commit ac23d7d

Browse files
BobAnkhduskmoon314
andauthored
Add support for TCP_CONGESTION socketopt
Co-authored-by: Campbell He <[email protected]>
1 parent ed23383 commit ac23d7d

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

src/sys/unix.rs

+53
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,10 @@ const MAX_BUF_LEN: usize = ssize_t::MAX as usize;
204204
#[cfg(target_vendor = "apple")]
205205
const MAX_BUF_LEN: usize = c_int::MAX as usize - 1;
206206

207+
// TCP_CA_NAME_MAX isn't defined in user space include files(not in libc)
208+
#[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))]
209+
const TCP_CA_NAME_MAX: usize = 16;
210+
207211
#[cfg(any(
208212
all(
209213
target_os = "linux",
@@ -2154,6 +2158,55 @@ impl crate::Socket {
21542158
)
21552159
}
21562160
}
2161+
2162+
/// Get the value of the `TCP_CONGESTION` option for this socket.
2163+
///
2164+
/// For more information about this option, see [`set_tcp_congestion`].
2165+
///
2166+
/// [`set_tcp_congestion`]: Socket::set_tcp_congestion
2167+
#[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))]
2168+
#[cfg_attr(
2169+
docsrs,
2170+
doc(cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux"))))
2171+
)]
2172+
pub fn tcp_congestion(&self) -> io::Result<Vec<u8>> {
2173+
let mut payload: [u8; TCP_CA_NAME_MAX] = [0; TCP_CA_NAME_MAX];
2174+
let mut len = payload.len() as libc::socklen_t;
2175+
syscall!(getsockopt(
2176+
self.as_raw(),
2177+
IPPROTO_TCP,
2178+
libc::TCP_CONGESTION,
2179+
payload.as_mut_ptr().cast(),
2180+
&mut len,
2181+
))
2182+
.map(|_| {
2183+
let buf = &payload[..len as usize];
2184+
// TODO: use `MaybeUninit::slice_assume_init_ref` once stable.
2185+
unsafe { &*(buf as *const [_] as *const [u8]) }.into()
2186+
})
2187+
}
2188+
2189+
/// Set the value of the `TCP_CONGESTION` option for this socket.
2190+
///
2191+
/// Specifies the TCP congestion control algorithm to use for this socket.
2192+
///
2193+
/// The value must be a valid TCP congestion control algorithm name of the
2194+
/// platform. For example, Linux may supports "reno", "cubic".
2195+
#[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))]
2196+
#[cfg_attr(
2197+
docsrs,
2198+
doc(cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux"))))
2199+
)]
2200+
pub fn set_tcp_congestion(&self, tcp_ca_name: &[u8]) -> io::Result<()> {
2201+
syscall!(setsockopt(
2202+
self.as_raw(),
2203+
IPPROTO_TCP,
2204+
libc::TCP_CONGESTION,
2205+
tcp_ca_name.as_ptr() as *const _,
2206+
tcp_ca_name.len() as libc::socklen_t,
2207+
))
2208+
.map(|_| ())
2209+
}
21572210
}
21582211

21592212
#[cfg_attr(docsrs, doc(cfg(unix)))]

tests/socket.rs

+46
Original file line numberDiff line numberDiff line change
@@ -1341,3 +1341,49 @@ fn original_dst_ipv6() {
13411341
Err(err) => assert_eq!(err.raw_os_error(), Some(libc::EOPNOTSUPP)),
13421342
}
13431343
}
1344+
1345+
#[test]
1346+
#[cfg(all(feature = "all", any(target_os = "freebsd", target_os = "linux")))]
1347+
fn tcp_congestion() {
1348+
let socket: Socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap();
1349+
// Get and set current tcp_ca
1350+
let origin_tcp_ca = socket
1351+
.tcp_congestion()
1352+
.expect("failed to get tcp congestion algorithm");
1353+
socket
1354+
.set_tcp_congestion(&origin_tcp_ca)
1355+
.expect("failed to set tcp congestion algorithm");
1356+
// Return a Err when set a non-exist tcp_ca
1357+
socket
1358+
.set_tcp_congestion(b"tcp_congestion_does_not_exist")
1359+
.unwrap_err();
1360+
let cur_tcp_ca = socket.tcp_congestion().unwrap();
1361+
assert_eq!(
1362+
cur_tcp_ca, origin_tcp_ca,
1363+
"expected {origin_tcp_ca:?} but get {cur_tcp_ca:?}"
1364+
);
1365+
let cur_tcp_ca = cur_tcp_ca.splitn(2, |num| *num == 0).next().unwrap();
1366+
const OPTIONS: [&[u8]; 2] = [
1367+
b"cubic",
1368+
#[cfg(target_os = "linux")] // or Android.
1369+
b"reno",
1370+
#[cfg(target_os = "freebsd")]
1371+
b"newreno",
1372+
];
1373+
// Set a new tcp ca
1374+
#[cfg(target_os = "linux")]
1375+
let new_tcp_ca = if cur_tcp_ca == OPTIONS[0] {
1376+
OPTIONS[1]
1377+
} else {
1378+
OPTIONS[0]
1379+
};
1380+
#[cfg(target_os = "freebsd")]
1381+
let new_tcp_ca = OPTIONS[1];
1382+
socket.set_tcp_congestion(new_tcp_ca).unwrap();
1383+
// Check if new tcp ca is successfully set
1384+
let cur_tcp_ca = socket.tcp_congestion().unwrap();
1385+
assert_eq!(
1386+
cur_tcp_ca.splitn(2, |num| *num == 0).next().unwrap(),
1387+
new_tcp_ca,
1388+
);
1389+
}

0 commit comments

Comments
 (0)