Skip to content

Commit d6a3d1c

Browse files
committed
Initial basic cmsg support for unix
Fixes #313
1 parent 7234537 commit d6a3d1c

File tree

5 files changed

+522
-9
lines changed

5 files changed

+522
-9
lines changed

src/cmsg.rs

Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
use std::convert::TryInto as _;
2+
use std::io::IoSlice;
3+
4+
#[derive(Debug, Clone)]
5+
struct MsgHdrWalker<B> {
6+
buffer: B,
7+
position: Option<usize>,
8+
}
9+
10+
impl<B: AsRef<[u8]>> MsgHdrWalker<B> {
11+
fn next_ptr(&mut self) -> Option<*const libc::cmsghdr> {
12+
// Build a msghdr so we can use the functionality in libc.
13+
let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() };
14+
let buffer = self.buffer.as_ref();
15+
// SAFETY: We're giving msghdr a mutable pointer to comply with the C
16+
// API. We'll only allow mutation of `cmsghdr`, however if `B` is
17+
// AsMut<[u8]>.
18+
msghdr.msg_control = buffer.as_ptr() as *mut _;
19+
msghdr.msg_controllen = buffer.len().try_into().expect("buffer is too long");
20+
21+
let nxt_hdr = if let Some(position) = self.position {
22+
if position >= buffer.len() {
23+
return None;
24+
}
25+
let cur_hdr = &buffer[position] as *const u8 as *const _;
26+
// Safety: msghdr is a valid pointer and cur_hdr is not null.
27+
unsafe { libc::CMSG_NXTHDR(&msghdr, cur_hdr) }
28+
} else {
29+
// Safety: msghdr is a valid pointer.
30+
unsafe { libc::CMSG_FIRSTHDR(&msghdr) }
31+
};
32+
33+
if nxt_hdr.is_null() {
34+
self.position = Some(buffer.len());
35+
return None;
36+
}
37+
38+
// SAFETY: nxt_hdr always points to data within the buffer, they must be
39+
// part of the same allocation.
40+
let distance = unsafe { (nxt_hdr as *const u8).offset_from(buffer.as_ptr()) };
41+
// nxt_hdr is always ahead of the buffer and not null if we're here,
42+
// meaning the distance is always positive.
43+
self.position = Some(distance.try_into().unwrap());
44+
Some(nxt_hdr)
45+
}
46+
47+
fn next(&mut self) -> Option<(&libc::cmsghdr, &[u8])> {
48+
self.next_ptr().map(|cmsghdr| {
49+
// SAFETY: cmsghdr is a valid pointer given to us by `next_ptr`.
50+
let data = unsafe { libc::CMSG_DATA(cmsghdr) };
51+
let cmsghdr = unsafe { &*cmsghdr };
52+
// SAFETY: data points to buffer and is controlled by control
53+
// message length.
54+
let data = unsafe {
55+
std::slice::from_raw_parts(
56+
data,
57+
(cmsghdr.cmsg_len as usize)
58+
.saturating_sub(std::mem::size_of::<libc::cmsghdr>()),
59+
)
60+
};
61+
(cmsghdr, data)
62+
})
63+
}
64+
}
65+
66+
impl<B: AsRef<[u8]> + AsMut<[u8]>> MsgHdrWalker<B> {
67+
fn next_mut(&mut self) -> Option<(&mut libc::cmsghdr, &mut [u8])> {
68+
match self.next_ptr() {
69+
Some(cmsghdr) => {
70+
// SAFETY: cmsghdr is a valid pointer given to us by `next_ptr`.
71+
let data = unsafe { libc::CMSG_DATA(cmsghdr) };
72+
// SAFETY: The mutable pointer is safe because we're not going to
73+
// vend any concurrent access to the same memory region and B is
74+
// AsMut<[u8]> guaranteeing we have exclusive access to the buffer.
75+
let cmsghdr = cmsghdr as *mut libc::cmsghdr;
76+
let cmsghdr = unsafe { &mut *cmsghdr };
77+
78+
// We'll always yield the entirety of the rest of the buffer.
79+
let distance = unsafe { data.offset_from(self.buffer.as_ref().as_ptr()) };
80+
// The data pointer is always part of the buffer, can't be before
81+
// it.
82+
let distance: usize = distance.try_into().unwrap();
83+
Some((cmsghdr, &mut self.buffer.as_mut()[distance..]))
84+
}
85+
None => None,
86+
}
87+
}
88+
}
89+
90+
/// A wrapper around a buffer that can be used to write ancillary control
91+
/// messages.
92+
#[derive(Debug)]
93+
pub struct CmsgWriter<'a> {
94+
walker: MsgHdrWalker<&'a mut [u8]>,
95+
last_push: usize,
96+
}
97+
98+
impl<'a> CmsgWriter<'a> {
99+
/// Creates a new [`CmsgBuffer`] backed by the bytes in `buffer`.
100+
pub fn new(buffer: &'a mut [u8]) -> Self {
101+
Self {
102+
walker: MsgHdrWalker {
103+
buffer,
104+
position: None,
105+
},
106+
last_push: 0,
107+
}
108+
}
109+
110+
/// Pushes a new control message `m` to the buffer.
111+
///
112+
/// # Panics
113+
///
114+
/// Panics if the contained buffer does not have enough space to fit `m`.
115+
pub fn push(&mut self, m: &Cmsg) {
116+
let (cmsg_level, cmsg_type, size) = m.level_type_size();
117+
let (nxt_hdr, data) = self
118+
.walker
119+
.next_mut()
120+
.unwrap_or_else(|| panic!("can't fit message {:?}", m));
121+
// Safety: All values are passed by copy.
122+
let cmsg_len = unsafe { libc::CMSG_LEN(size) }.try_into().unwrap();
123+
nxt_hdr.cmsg_len = cmsg_len;
124+
nxt_hdr.cmsg_level = cmsg_level;
125+
nxt_hdr.cmsg_type = cmsg_type;
126+
m.write(&mut data[..size as usize]);
127+
// Always store the space required for the last push because the walker
128+
// maintains its position cursor at the currently written option, we
129+
// must always add the space for the last control message when returning
130+
// the consolidated buffer.
131+
self.last_push = unsafe { libc::CMSG_SPACE(size) } as usize;
132+
}
133+
134+
pub(crate) fn io_slice(&self) -> IoSlice<'_> {
135+
IoSlice::new(self.buffer())
136+
}
137+
138+
pub(crate) fn buffer(&self) -> &[u8] {
139+
if let Some(position) = self.walker.position {
140+
&self.walker.buffer.as_ref()[..position + self.last_push]
141+
} else {
142+
&[]
143+
}
144+
}
145+
}
146+
147+
impl<'a, C: std::borrow::Borrow<Cmsg>> Extend<C> for CmsgWriter<'a> {
148+
fn extend<T: IntoIterator<Item = C>>(&mut self, iter: T) {
149+
for cmsg in iter {
150+
self.push(cmsg.borrow())
151+
}
152+
}
153+
}
154+
155+
/// An iterator over received control messages.
156+
#[derive(Debug, Clone)]
157+
pub struct CmsgIter<'a> {
158+
walker: MsgHdrWalker<&'a [u8]>,
159+
}
160+
161+
impl<'a> CmsgIter<'a> {
162+
pub(crate) fn new(buffer: &'a [u8]) -> Self {
163+
Self {
164+
walker: MsgHdrWalker {
165+
buffer,
166+
position: None,
167+
},
168+
}
169+
}
170+
}
171+
172+
impl<'a> Iterator for CmsgIter<'a> {
173+
type Item = Cmsg;
174+
175+
fn next(&mut self) -> Option<Self::Item> {
176+
self.walker.next().map(
177+
|(
178+
libc::cmsghdr {
179+
cmsg_len: _,
180+
cmsg_level,
181+
cmsg_type,
182+
..
183+
},
184+
data,
185+
)| Cmsg::from_raw(*cmsg_level, *cmsg_type, data),
186+
)
187+
}
188+
}
189+
190+
/// An unknown control message.
191+
#[derive(Debug, Eq, PartialEq)]
192+
pub struct UnknownCmsg {
193+
cmsg_level: libc::c_int,
194+
cmsg_type: libc::c_int,
195+
}
196+
197+
/// Control messages.
198+
#[derive(Debug, Eq, PartialEq)]
199+
pub enum Cmsg {
200+
/// The `IP_TOS` control message.
201+
#[cfg(not(any(target_os = "solaris", target_os = "illumos")))]
202+
IpTos(u8),
203+
/// The `IPV6_PKTINFO` control message.
204+
#[cfg(not(any(target_os = "fuchsia", target_os = "solaris", target_os = "illumos")))]
205+
Ipv6PktInfo {
206+
/// The address the packet is destined to/received from. Equivalent to
207+
/// `in6_pktinfo.ipi6_addr`.
208+
addr: std::net::Ipv6Addr,
209+
/// The interface index the packet is destined to/received from.
210+
/// Equivalent to `in6_pktinfo.ipi6_ifindex`.
211+
ifindex: u32,
212+
},
213+
/// An unrecognized control message.
214+
Unknown(UnknownCmsg),
215+
}
216+
217+
impl Cmsg {
218+
/// Returns the amount of buffer space required to hold this option.
219+
pub fn space(&self) -> usize {
220+
let (_, _, size) = self.level_type_size();
221+
// Safety: All values are passed by copy.
222+
let size = unsafe { libc::CMSG_SPACE(size) };
223+
size as usize
224+
}
225+
226+
fn level_type_size(&self) -> (libc::c_int, libc::c_int, libc::c_uint) {
227+
match self {
228+
#[cfg(not(any(target_os = "solaris", target_os = "illumos")))]
229+
Cmsg::IpTos(_) => (
230+
libc::IPPROTO_IP,
231+
libc::IP_TOS,
232+
std::mem::size_of::<u8>() as libc::c_uint,
233+
),
234+
#[cfg(not(any(target_os = "fuchsia", target_os = "solaris", target_os = "illumos")))]
235+
Cmsg::Ipv6PktInfo { .. } => (
236+
libc::IPPROTO_IPV6,
237+
libc::IPV6_PKTINFO,
238+
std::mem::size_of::<libc::in6_pktinfo>() as libc::c_uint,
239+
),
240+
Cmsg::Unknown(UnknownCmsg {
241+
cmsg_level,
242+
cmsg_type,
243+
}) => (*cmsg_level, *cmsg_type, 0),
244+
}
245+
}
246+
247+
fn write(&self, buffer: &mut [u8]) {
248+
match self {
249+
#[cfg(not(any(target_os = "solaris", target_os = "illumos")))]
250+
Cmsg::IpTos(tos) => {
251+
buffer[0] = *tos;
252+
}
253+
#[cfg(not(any(target_os = "fuchsia", target_os = "solaris", target_os = "illumos")))]
254+
Cmsg::Ipv6PktInfo { addr, ifindex } => {
255+
let pktinfo = libc::in6_pktinfo {
256+
ipi6_addr: crate::sys::to_in6_addr(addr),
257+
ipi6_ifindex: *ifindex as _,
258+
};
259+
let size = std::mem::size_of::<libc::in6_pktinfo>();
260+
assert_eq!(buffer.len(), size);
261+
// Safety: `pktinfo` is valid for reads for its size in bytes.
262+
// `buffer` is valid for write for the same length, as
263+
// guaranteed by the assertion above. Copy unit is byte, so
264+
// alignment is okay. The two regions do not overlap.
265+
unsafe {
266+
std::ptr::copy_nonoverlapping(
267+
&pktinfo as *const libc::in6_pktinfo as *const _,
268+
buffer.as_mut_ptr(),
269+
size,
270+
)
271+
}
272+
}
273+
Cmsg::Unknown(_) => {
274+
// NOTE: We don't actually allow users of the public API
275+
// serialize unknown control messages, but we use this code path
276+
// for testing.
277+
debug_assert_eq!(buffer.len(), 0);
278+
}
279+
}
280+
}
281+
282+
fn from_raw(cmsg_level: libc::c_int, cmsg_type: libc::c_int, bytes: &[u8]) -> Self {
283+
match (cmsg_level, cmsg_type) {
284+
#[cfg(not(any(target_os = "solaris", target_os = "illumos")))]
285+
(libc::IPPROTO_IP, libc::IP_TOS) => {
286+
assert_eq!(bytes.len(), std::mem::size_of::<u8>(), "{:?}", bytes);
287+
Cmsg::IpTos(bytes[0])
288+
}
289+
#[cfg(not(any(target_os = "fuchsia", target_os = "solaris", target_os = "illumos")))]
290+
(libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => {
291+
let mut pktinfo = unsafe { std::mem::zeroed::<libc::in6_pktinfo>() };
292+
let size = std::mem::size_of::<libc::in6_pktinfo>();
293+
assert!(bytes.len() >= size, "{:?}", bytes);
294+
// Safety: `pktinfo` is valid for writes for its size in bytes.
295+
// `buffer` is valid for read for the same length, as
296+
// guaranteed by the assertion above. Copy unit is byte, so
297+
// alignment is okay. The two regions do not overlap.
298+
unsafe {
299+
std::ptr::copy_nonoverlapping(
300+
bytes.as_ptr(),
301+
&mut pktinfo as *mut libc::in6_pktinfo as *mut _,
302+
size,
303+
)
304+
}
305+
Cmsg::Ipv6PktInfo {
306+
addr: crate::sys::from_in6_addr(pktinfo.ipi6_addr),
307+
ifindex: pktinfo.ipi6_ifindex as _,
308+
}
309+
}
310+
(cmsg_level, cmsg_type) => {
311+
let _ = bytes;
312+
Cmsg::Unknown(UnknownCmsg {
313+
cmsg_level,
314+
cmsg_type,
315+
})
316+
}
317+
}
318+
}
319+
}
320+
321+
#[cfg(test)]
322+
mod tests {
323+
use super::*;
324+
325+
#[test]
326+
fn ser_deser() {
327+
let cmsgs = [
328+
#[cfg(not(any(target_os = "solaris", target_os = "illumos")))]
329+
Cmsg::IpTos(2),
330+
#[cfg(not(any(target_os = "fuchsia", target_os = "solaris", target_os = "illumos")))]
331+
Cmsg::Ipv6PktInfo {
332+
addr: std::net::Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8),
333+
ifindex: 13,
334+
},
335+
Cmsg::Unknown(UnknownCmsg {
336+
cmsg_level: 12345678,
337+
cmsg_type: 87654321,
338+
}),
339+
];
340+
let mut buffer = [0u8; 256];
341+
let mut writer = CmsgWriter::new(&mut buffer[..]);
342+
writer.extend(cmsgs.iter());
343+
let deser = CmsgIter::new(writer.buffer()).collect::<Vec<_>>();
344+
assert_eq!(&cmsgs[..], &deser[..]);
345+
}
346+
347+
#[test]
348+
#[should_panic]
349+
#[cfg(not(any(target_os = "solaris", target_os = "illumos")))]
350+
fn ser_insufficient_space_panics() {
351+
let mut buffer = CmsgWriter::new(&mut []);
352+
buffer.push(&Cmsg::IpTos(2));
353+
}
354+
355+
#[test]
356+
fn empty_deser() {
357+
assert_eq!(CmsgIter::new(&[]).next(), None);
358+
}
359+
}

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ macro_rules! from {
115115
};
116116
}
117117

118+
#[cfg(all(unix, not(target_os = "redox")))]
119+
mod cmsg;
118120
mod sockaddr;
119121
mod socket;
120122
mod sockref;
@@ -141,6 +143,9 @@ pub use sockref::SockRef;
141143
)))]
142144
pub use socket::InterfaceIndexOrAddress;
143145

146+
#[cfg(all(unix, not(target_os = "redox")))]
147+
pub use cmsg::{Cmsg, CmsgIter, CmsgWriter};
148+
144149
/// Specification of the communication domain for a socket.
145150
///
146151
/// This is a newtype wrapper around an integer which provides a nicer API in

0 commit comments

Comments
 (0)