diff --git a/src/sys/windows.rs b/src/sys/windows.rs index 2d43e2bf..e33fc095 100644 --- a/src/sys/windows.rs +++ b/src/sys/windows.rs @@ -25,7 +25,7 @@ use windows_sys::Win32::Networking::WinSock::{ self, tcp_keepalive, FIONBIO, IN6_ADDR, IN6_ADDR_0, INVALID_SOCKET, IN_ADDR, IN_ADDR_0, POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM, SD_BOTH, SD_RECEIVE, SD_SEND, SIO_KEEPALIVE_VALS, SOCKET_ERROR, WSABUF, WSAEMSGSIZE, WSAESHUTDOWN, WSAPOLLFD, WSAPROTOCOL_INFOW, - WSA_FLAG_NO_HANDLE_INHERIT, WSA_FLAG_OVERLAPPED, + WSA_FLAG_NO_HANDLE_INHERIT, WSA_FLAG_OVERLAPPED, WSA_FLAG_REGISTERED_IO, }; #[cfg(feature = "all")] use windows_sys::Win32::Networking::WinSock::{ @@ -125,6 +125,8 @@ impl Type { /// Our custom flag to set `WSA_FLAG_NO_HANDLE_INHERIT` on socket creation. /// Trying to mimic `Type::cloexec` on windows. const NO_INHERIT: c_int = 1 << ((size_of::() * 8) - 1); // Last bit. + /// Our custom flag to set `WSA_FLAG_REGISTERED_IO` on socket creation. + const REGISTERED_IO: c_int = 1 << ((size_of::() * 8) - 2); // Second last bit. /// Set `WSA_FLAG_NO_HANDLE_INHERIT` on the socket. #[cfg(feature = "all")] @@ -132,6 +134,12 @@ impl Type { self._no_inherit() } + /// Set `WSA_FLAG_REGISTERED_IO` on the socket. + #[cfg(feature = "all")] + pub const fn registered_io(self) -> Type { + Type(self.0 | Type::REGISTERED_IO) + } + pub(crate) const fn _no_inherit(self) -> Type { Type(self.0 | Type::NO_INHERIT) } @@ -252,13 +260,19 @@ pub(crate) fn socket_into_raw(socket: Socket) -> RawSocket { pub(crate) fn socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Result { init(); - // Check if we set our custom flag. + // Check if we set our custom flags. let flags = if ty & Type::NO_INHERIT != 0 { ty = ty & !Type::NO_INHERIT; WSA_FLAG_NO_HANDLE_INHERIT } else { 0 }; + let flags = if ty & Type::REGISTERED_IO != 0 { + ty = ty & !Type::REGISTERED_IO; + flags | WSA_FLAG_REGISTERED_IO + } else { + flags + }; syscall!( WSASocketW( diff --git a/tests/socket.rs b/tests/socket.rs index cc98b8a9..fe69a406 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -418,6 +418,48 @@ where ); } +#[cfg(all(feature = "all", windows))] +#[test] +fn type_registered_io() { + let ty = Type::DGRAM.registered_io(); + let socket = Socket::new(Domain::IPV4, ty, None).unwrap(); + assert_registered_io(&socket); +} + +/// Assert that registered I/O is enabled on `socket`. +#[cfg(windows)] +#[track_caller] +pub fn assert_registered_io(socket: &S) +where + S: AsRawSocket, +{ + use std::ptr; + use windows_sys::core::GUID; + use windows_sys::Win32::Networking::WinSock; + + let mut table = MaybeUninit::::uninit(); + let guid = WinSock::WSAID_MULTIPLE_RIO; + let mut bytes = 0; + + let r = unsafe { + WinSock::WSAIoctl( + socket.as_raw_socket() as _, + WinSock::SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER, + (&guid as *const GUID) as *const _, + size_of_val(&guid) as u32, + table.as_mut_ptr() as *mut _, + size_of_val(&table) as u32, + (&mut bytes as *mut i32) as *mut _, + ptr::null_mut(), + None, + ) + }; + if r != 0 { + let err = io::Error::last_os_error(); + panic!("unexpected error: {err}"); + } +} + #[cfg(all( feature = "all", any(