diff --git a/Cargo.toml b/Cargo.toml index d0c1b16aae..68c073d86d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ name = "measure_startup_time" harness = false [features] -default = ["pci", "pci-ids", "acpi", "fsgsbase", "smp", "tcp", "dhcpv4", "fuse"] +default = ["pci", "pci-ids", "acpi", "fsgsbase", "smp", "tcp", "dhcpv4", "fuse", "vsock"] acpi = [] dhcpv4 = [ "smoltcp", @@ -54,6 +54,7 @@ dhcpv4 = [ ] fs = ["fuse"] fuse = ["pci", "dep:fuse-abi", "fuse-abi/num_enum"] +vsock = ["pci"] fsgsbase = [] gem-net = ["tcp", "dep:tock-registers"] newlib = [] diff --git a/src/arch/aarch64/mm/paging.rs b/src/arch/aarch64/mm/paging.rs index 5e03396a35..e9598eca46 100644 --- a/src/arch/aarch64/mm/paging.rs +++ b/src/arch/aarch64/mm/paging.rs @@ -578,7 +578,7 @@ pub fn virtual_to_physical(virtual_address: VirtAddr) -> Option { get_physical_address::(virtual_address) } -#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "fuse", feature = "vsock", feature = "tcp", feature = "udp"))] pub fn virt_to_phys(virtual_address: VirtAddr) -> PhysAddr { virtual_to_physical(virtual_address).unwrap() } diff --git a/src/arch/riscv64/mm/paging.rs b/src/arch/riscv64/mm/paging.rs index f44c3256cb..3aaf7c0008 100644 --- a/src/arch/riscv64/mm/paging.rs +++ b/src/arch/riscv64/mm/paging.rs @@ -584,7 +584,7 @@ pub fn virtual_to_physical(virtual_address: VirtAddr) -> Option { panic!("virtual_to_physical should never reach this point"); } -#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "fuse", feature = "vsock", feature = "tcp", feature = "udp"))] pub fn virt_to_phys(virtual_address: VirtAddr) -> PhysAddr { virtual_to_physical(virtual_address).unwrap() } diff --git a/src/arch/x86_64/kernel/interrupts.rs b/src/arch/x86_64/kernel/interrupts.rs index c5709dbd3b..6a3c854755 100644 --- a/src/arch/x86_64/kernel/interrupts.rs +++ b/src/arch/x86_64/kernel/interrupts.rs @@ -9,7 +9,7 @@ use hermit_sync::{InterruptSpinMutex, InterruptTicketMutex}; use x86_64::instructions::interrupts::enable_and_hlt; pub use x86_64::instructions::interrupts::{disable, enable}; use x86_64::set_general_handler; -#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp", feature = "vsock"))] use x86_64::structures::idt; use x86_64::structures::idt::InterruptDescriptorTable; pub use x86_64::structures::idt::InterruptStackFrame as ExceptionStackFrame; @@ -155,7 +155,7 @@ pub(crate) fn install() { IRQ_NAMES.lock().insert(7, "FPU"); } -#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp", feature = "vsock"))] pub fn irq_install_handler(irq_number: u8, handler: idt::HandlerFunc) { debug!("Install handler for interrupt {}", irq_number); diff --git a/src/arch/x86_64/mm/paging.rs b/src/arch/x86_64/mm/paging.rs index bc7cd99f57..b55ba41464 100644 --- a/src/arch/x86_64/mm/paging.rs +++ b/src/arch/x86_64/mm/paging.rs @@ -118,7 +118,7 @@ pub fn virtual_to_physical(virtual_address: VirtAddr) -> Option { } } -#[cfg(any(feature = "fuse", feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "fuse", feature = "vsock", feature = "tcp", feature = "udp"))] pub fn virt_to_phys(virtual_address: VirtAddr) -> PhysAddr { virtual_to_physical(virtual_address).unwrap() } diff --git a/src/config.rs b/src/config.rs index 8b32b25ca8..805acb44b6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -12,5 +12,8 @@ pub(crate) const VIRTIO_MAX_QUEUE_SIZE: u16 = 2048; pub(crate) const VIRTIO_MAX_QUEUE_SIZE: u16 = 1024; /// Default keep alive interval in milliseconds -#[cfg(any(feature = "tcp", feature = "udp"))] +#[cfg(feature = "tcp")] pub(crate) const DEFAULT_KEEP_ALIVE_INTERVAL: u64 = 75000; + +#[cfg(feature = "vsock")] +pub(crate) const VSOCK_PACKET_SIZE: u32 = 8192; diff --git a/src/drivers/mod.rs b/src/drivers/mod.rs index 9d31c91c5c..580171c434 100644 --- a/src/drivers/mod.rs +++ b/src/drivers/mod.rs @@ -10,9 +10,12 @@ pub mod net; pub mod pci; #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] pub mod virtio; +#[cfg(feature = "vsock")] +pub mod vsock; /// A common error module for drivers. /// [DriverError](error::DriverError) values will be @@ -26,7 +29,8 @@ pub mod error { use crate::drivers::net::rtl8139::RTL8139Error; #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] use crate::drivers::virtio::error::VirtioError; @@ -34,7 +38,8 @@ pub mod error { pub enum DriverError { #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] InitVirtioDevFail(VirtioError), #[cfg(feature = "rtl8139")] @@ -45,7 +50,8 @@ pub mod error { #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] impl From for DriverError { fn from(err: VirtioError) -> Self { @@ -73,7 +79,8 @@ pub mod error { match *self { #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] DriverError::InitVirtioDevFail(ref err) => { write!(f, "Virtio driver failed: {err:?}") diff --git a/src/drivers/net/gem.rs b/src/drivers/net/gem.rs index 89bb76e362..f2060230e4 100644 --- a/src/drivers/net/gem.rs +++ b/src/drivers/net/gem.rs @@ -16,10 +16,14 @@ use tock_registers::{register_bitfields, register_structs}; use crate::arch::kernel::core_local::core_scheduler; use crate::arch::kernel::interrupts::*; +#[cfg(all(any(feature = "tcp", feature = "udp"), not(feature = "pci")))] +use crate::arch::kernel::mmio as hardware; use crate::arch::mm::paging::virt_to_phys; use crate::arch::mm::VirtAddr; use crate::drivers::error::DriverError; -use crate::drivers::net::{network_irqhandler, NetworkDriver}; +use crate::drivers::net::NetworkDriver; +#[cfg(all(any(feature = "tcp", feature = "udp"), feature = "pci"))] +use crate::drivers::pci as hardware; use crate::executor::device::{RxToken, TxToken}; //Base address of the control registers @@ -197,6 +201,22 @@ pub enum GEMError { Unknown, } +fn gem_irqhandler() { + use crate::scheduler::PerCoreSchedulerExt; + + debug!("Receive network interrupt"); + + // PLIC end of interrupt + crate::arch::kernel::interrupts::external_eoi(); + if let Some(driver) = hardware::get_network_driver() { + driver.lock().handle_interrupt() + } + + crate::executor::run(); + + core_scheduler().reschedule(); +} + /// GEM network driver struct. /// /// Struct allows to control device queus as also @@ -349,7 +369,7 @@ impl NetworkDriver for GEMDriver { } } - fn handle_interrupt(&mut self) -> bool { + fn handle_interrupt(&mut self) { let int_status = unsafe { (*self.gem).int_status.extract() }; let receive_status = unsafe { (*self.gem).receive_status.extract() }; @@ -393,8 +413,8 @@ impl NetworkDriver for GEMDriver { // handle incoming packets todo!(); } - // increment_irq_counter((32 + self.irq).into()); - ret + + //increment_irq_counter((32 + self.irq).into()); } } @@ -674,9 +694,9 @@ pub fn init_device( // Configure Interrupts debug!( "Install interrupt handler for GEM at {:x}", - network_irqhandler as usize + gem_irqhandler as usize ); - irq_install_handler(irq, network_irqhandler); + irq_install_handler(irq, gem_irqhandler); (*gem).int_enable.write(Interrupts::FRAMERX::SET); // + Interrupts::TXCOMPL::SET // Enable the Controller (again?) diff --git a/src/drivers/net/mod.rs b/src/drivers/net/mod.rs index 810db6a447..d4f72652a0 100644 --- a/src/drivers/net/mod.rs +++ b/src/drivers/net/mod.rs @@ -7,18 +7,8 @@ pub mod virtio; use smoltcp::phy::ChecksumCapabilities; -#[cfg(target_arch = "x86_64")] -use crate::arch::kernel::apic; #[allow(unused_imports)] use crate::arch::kernel::core_local::*; -#[cfg(target_arch = "x86_64")] -use crate::arch::kernel::interrupts::ExceptionStackFrame; -#[cfg(not(feature = "pci"))] -use crate::arch::kernel::mmio as hardware; -#[cfg(target_arch = "aarch64")] -use crate::arch::scheduler::State; -#[cfg(feature = "pci")] -use crate::drivers::pci as hardware; use crate::executor::device::{RxToken, TxToken}; /// A trait for accessing the network interface @@ -43,52 +33,5 @@ pub(crate) trait NetworkDriver { /// Enable / disable the polling mode of the network interface fn set_polling_mode(&mut self, value: bool); /// Handle interrupt and check if a packet is available - fn handle_interrupt(&mut self) -> bool; -} - -#[inline] -fn _irqhandler() -> bool { - let result = if let Some(driver) = hardware::get_network_driver() { - driver.lock().handle_interrupt() - } else { - debug!("Unable to handle interrupt!"); - false - }; - - // TODO: do we need it? - crate::executor::run(); - - result -} - -#[cfg(target_arch = "aarch64")] -pub(crate) fn network_irqhandler(_state: &State) -> bool { - debug!("Receive network interrupt"); - _irqhandler() -} - -#[cfg(target_arch = "x86_64")] -pub(crate) extern "x86-interrupt" fn network_irqhandler(stack_frame: ExceptionStackFrame) { - crate::arch::x86_64::swapgs(&stack_frame); - use crate::scheduler::PerCoreSchedulerExt; - - debug!("Receive network interrupt"); - apic::eoi(); - let _ = _irqhandler(); - - core_scheduler().reschedule(); - crate::arch::x86_64::swapgs(&stack_frame); -} - -#[cfg(target_arch = "riscv64")] -pub fn network_irqhandler() { - use crate::scheduler::PerCoreSchedulerExt; - - debug!("Receive network interrupt"); - - // PLIC end of interrupt - crate::arch::kernel::interrupts::external_eoi(); - let _ = _irqhandler(); - - core_scheduler().reschedule(); + fn handle_interrupt(&mut self); } diff --git a/src/drivers/net/rtl8139.rs b/src/drivers/net/rtl8139.rs index cc0dffede0..dddedce8e8 100644 --- a/src/drivers/net/rtl8139.rs +++ b/src/drivers/net/rtl8139.rs @@ -9,12 +9,15 @@ use pci_types::{Bar, CommandRegister, InterruptLine, MAX_BARS}; use x86::io::*; use crate::arch::kernel::core_local::increment_irq_counter; +#[cfg(target_arch = "x86_64")] +use crate::arch::kernel::interrupts::ExceptionStackFrame; use crate::arch::kernel::interrupts::*; use crate::arch::mm::paging::virt_to_phys; use crate::arch::mm::VirtAddr; use crate::arch::pci::PciConfigRegion; use crate::drivers::error::DriverError; -use crate::drivers::net::{network_irqhandler, NetworkDriver}; +use crate::drivers::net::NetworkDriver; +use crate::drivers::pci as hardware; use crate::drivers::pci::PciDevice; use crate::executor::device::{RxToken, TxToken}; @@ -317,7 +320,7 @@ impl NetworkDriver for RTL8139Driver { } } - fn handle_interrupt(&mut self) -> bool { + fn handle_interrupt(&mut self) { increment_irq_counter(32 + self.irq); let isr_contents = unsafe { inw(self.iobase + ISR) }; @@ -338,16 +341,12 @@ impl NetworkDriver for RTL8139Driver { trace!("RTL88139: RX overflow detected!\n"); } - let ret = (isr_contents & ISR_ROK) == ISR_ROK; - unsafe { outw( self.iobase + ISR, isr_contents & (ISR_RXOVW | ISR_TER | ISR_RER | ISR_TOK | ISR_ROK), ); } - - ret } } @@ -419,6 +418,25 @@ impl Drop for RTL8139Driver { } } +extern "x86-interrupt" fn rtl8139_irqhandler(stack_frame: ExceptionStackFrame) { + crate::arch::x86_64::swapgs(&stack_frame); + use crate::arch::kernel::core_local::core_scheduler; + use crate::scheduler::PerCoreSchedulerExt; + + debug!("Receive network interrupt"); + crate::arch::x86_64::kernel::apic::eoi(); + if let Some(driver) = hardware::get_network_driver() { + driver.lock().handle_interrupt() + } else { + debug!("Unable to handle interrupt!"); + } + + crate::executor::run(); + + core_scheduler().reschedule(); + crate::arch::x86_64::swapgs(&stack_frame); +} + pub(crate) fn init_device( device: &PciDevice, ) -> Result { @@ -573,7 +591,7 @@ pub(crate) fn init_device( // Install interrupt handler for RTL8139 debug!("Install interrupt handler for RTL8139 at {}", irq); - irq_install_handler(irq, network_irqhandler); + irq_install_handler(irq, rtl8139_irqhandler); add_irq_name(irq, "rtl8139_net"); Ok(RTL8139Driver { diff --git a/src/drivers/net/virtio/mmio.rs b/src/drivers/net/virtio/mmio.rs index 8e86c2b796..ff73ef3b8e 100644 --- a/src/drivers/net/virtio/mmio.rs +++ b/src/drivers/net/virtio/mmio.rs @@ -20,7 +20,6 @@ impl VirtioNetDriver { pub fn new( dev_id: u16, mut registers: VolatileRef<'static, DeviceRegisters>, - irq: u8, ) -> Result { let dev_cfg_raw: &'static virtio::net::Config = unsafe { &*registers @@ -58,7 +57,6 @@ impl VirtioNetDriver { recv_vqs, send_vqs, num_vqs: 0, - irq, mtu, checksums: ChecksumCapabilities::default(), }) @@ -79,9 +77,8 @@ impl VirtioNetDriver { pub fn init( dev_id: u16, registers: VolatileRef<'static, DeviceRegisters>, - irq_no: u8, ) -> Result { - if let Ok(mut drv) = VirtioNetDriver::new(dev_id, registers, irq_no) { + if let Ok(mut drv) = VirtioNetDriver::new(dev_id, registers) { match drv.init_dev() { Err(error_code) => Err(VirtioError::NetDriver(error_code)), _ => { diff --git a/src/drivers/net/virtio/mod.rs b/src/drivers/net/virtio/mod.rs index 83cdf07fbb..ca88d50681 100644 --- a/src/drivers/net/virtio/mod.rs +++ b/src/drivers/net/virtio/mod.rs @@ -14,7 +14,6 @@ use alloc::boxed::Box; use alloc::vec::Vec; use core::mem::MaybeUninit; -use pci_types::InterruptLine; use smoltcp::phy::{Checksum, ChecksumCapabilities}; use smoltcp::wire::{EthernetFrame, Ipv4Packet, Ipv6Packet, ETHERNET_HEADER_LEN}; use virtio::net::{ConfigVolatileFieldAccess, Hdr, HdrF}; @@ -24,8 +23,6 @@ use volatile::VolatileRef; use self::constants::MAX_NUM_VQ; use self::error::VirtioNetError; -#[cfg(not(target_arch = "riscv64"))] -use crate::arch::kernel::core_local::increment_irq_counter; use crate::config::VIRTIO_MAX_QUEUE_SIZE; use crate::drivers::net::NetworkDriver; #[cfg(not(feature = "pci"))] @@ -249,8 +246,6 @@ pub(crate) struct VirtioNetDriver { pub(super) send_vqs: TxQueues, pub(super) num_vqs: u16, - #[cfg_attr(target_arch = "riscv64", allow(dead_code))] - pub(super) irq: InterruptLine, pub(super) mtu: u16, pub(super) checksums: ChecksumCapabilities, } @@ -394,22 +389,15 @@ impl NetworkDriver for VirtioNetDriver { } } - fn handle_interrupt(&mut self) -> bool { - #[cfg(not(target_arch = "riscv64"))] - increment_irq_counter(32 + self.irq); + fn handle_interrupt(&mut self) { + let _ = self.isr_stat.is_interrupt(); - let result = if self.isr_stat.is_interrupt() { - true - } else if self.isr_stat.is_cfg_change() { + if self.isr_stat.is_cfg_change() { info!("Configuration changes are not possible! Aborting"); todo!("Implement possibility to change config on the fly...") - } else { - false - }; + } self.isr_stat.acknowledge(); - - result } } diff --git a/src/drivers/net/virtio/pci.rs b/src/drivers/net/virtio/pci.rs index bcd68cfab7..a477b6509e 100644 --- a/src/drivers/net/virtio/pci.rs +++ b/src/drivers/net/virtio/pci.rs @@ -97,7 +97,6 @@ impl VirtioNetDriver { recv_vqs, send_vqs, num_vqs: 0, - irq: device.get_irq().unwrap(), mtu, checksums: ChecksumCapabilities::default(), }) diff --git a/src/drivers/pci.rs b/src/drivers/pci.rs index ee6239db4a..a00f4a8637 100644 --- a/src/drivers/pci.rs +++ b/src/drivers/pci.rs @@ -4,7 +4,7 @@ use alloc::vec::Vec; use core::fmt; use hermit_sync::without_interrupts; -#[cfg(any(feature = "tcp", feature = "udp", feature = "fuse"))] +#[cfg(any(feature = "tcp", feature = "udp", feature = "fuse", feature = "vsock"))] use hermit_sync::InterruptTicketMutex; use pci_types::capability::CapabilityIterator; use pci_types::{ @@ -22,14 +22,18 @@ use crate::drivers::net::rtl8139::{self, RTL8139Driver}; use crate::drivers::net::virtio::VirtioNetDriver; #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] use crate::drivers::virtio::transport::pci as pci_virtio; #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] use crate::drivers::virtio::transport::pci::VirtioDriver; +#[cfg(feature = "vsock")] +use crate::drivers::vsock::VirtioVsockDriver; pub(crate) static mut PCI_DEVICES: Vec> = Vec::new(); static mut PCI_DRIVERS: Vec = Vec::new(); @@ -131,6 +135,10 @@ impl PciDevice { } }; + if address == 0 { + return None; + } + debug!( "Mapping bar {} at {:#x} with length {:#x}", index, address, size @@ -294,9 +302,12 @@ pub(crate) fn print_information() { } #[allow(clippy::large_enum_variant)] +#[allow(clippy::enum_variant_names)] pub(crate) enum PciDriver { #[cfg(feature = "fuse")] VirtioFs(InterruptTicketMutex), + #[cfg(feature = "vsock")] + VirtioVsock(InterruptTicketMutex), #[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] VirtioNet(InterruptTicketMutex), #[cfg(all(feature = "rtl8139", any(feature = "tcp", feature = "udp")))] @@ -322,6 +333,15 @@ impl PciDriver { } } + #[cfg(feature = "vsock")] + fn get_vsock_driver(&self) -> Option<&InterruptTicketMutex> { + #[allow(unreachable_patterns)] + match self { + Self::VirtioVsock(drv) => Some(drv), + _ => None, + } + } + #[cfg(feature = "fuse")] fn get_filesystem_driver(&self) -> Option<&InterruptTicketMutex> { match self { @@ -348,6 +368,11 @@ pub(crate) fn get_network_driver() -> Option<&'static InterruptTicketMutex Option<&'static InterruptTicketMutex> { + unsafe { PCI_DRIVERS.iter().find_map(|drv| drv.get_vsock_driver()) } +} + #[cfg(feature = "fuse")] pub(crate) fn get_filesystem_driver() -> Option<&'static InterruptTicketMutex> { unsafe { @@ -367,19 +392,24 @@ pub(crate) fn init_drivers() { }) } { info!( - "Found virtio network device with device id {:#x}", + "Found virtio device with device id {:#x}", adapter.device_id() ); #[cfg(any( all(any(feature = "tcp", feature = "udp"), not(feature = "rtl8139")), - feature = "fuse" + feature = "fuse", + feature = "vsock" ))] match pci_virtio::init_device(adapter) { #[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] Ok(VirtioDriver::Network(drv)) => { register_driver(PciDriver::VirtioNet(InterruptTicketMutex::new(drv))) } + #[cfg(feature = "vsock")] + Ok(VirtioDriver::Vsock(drv)) => { + register_driver(PciDriver::VirtioVsock(InterruptTicketMutex::new(drv))) + } #[cfg(feature = "fuse")] Ok(VirtioDriver::FileSystem(drv)) => { register_driver(PciDriver::VirtioFs(InterruptTicketMutex::new(drv))) diff --git a/src/drivers/virtio/mod.rs b/src/drivers/virtio/mod.rs index cbd2719cc4..e32cd5f375 100644 --- a/src/drivers/virtio/mod.rs +++ b/src/drivers/virtio/mod.rs @@ -14,6 +14,8 @@ pub mod error { pub use crate::drivers::net::virtio::error::VirtioNetError; #[cfg(feature = "pci")] use crate::drivers::pci::error::PciError; + #[cfg(feature = "vsock")] + pub use crate::drivers::vsock::error::VirtioVsockError; #[allow(dead_code)] #[derive(Debug)] @@ -25,6 +27,8 @@ pub mod error { NetDriver(VirtioNetError), #[cfg(feature = "fuse")] FsDriver(VirtioFsError), + #[cfg(feature = "vsock")] + VsockDriver(VirtioVsockError), #[cfg(not(feature = "pci"))] Unknown, } @@ -71,6 +75,20 @@ pub mod error { VirtioFsError::IncompatibleFeatureSets(driver_features, device_features) => write!(f, "Feature set: {driver_features:?} , is incompatible with the device features: {device_features:?}", ), VirtioFsError::Unknown => write!(f, "Virtio filesystem failed, driver failed due unknown reason!"), }, + #[cfg(feature = "vsock")] + VirtioError::VsockDriver(vsock_error) => match vsock_error { + #[cfg(feature = "pci")] + VirtioVsockError::NoDevCfg(id) => write!(f, "Virtio socket device driver failed, for device {id:x}, due to a missing or malformed device config!"), + #[cfg(feature = "pci")] + VirtioVsockError::NoComCfg(id) => write!(f, "Virtio socket device driver failed, for device {id:x}, due to a missing or malformed common config!"), + #[cfg(feature = "pci")] + VirtioVsockError::NoIsrCfg(id) => write!(f, "Virtio socket device driver failed, for device {id:x}, due to a missing or malformed ISR status config!"), + #[cfg(feature = "pci")] + VirtioVsockError::NoNotifCfg(id) => write!(f, "Virtio socket device driver failed, for device {id:x}, due to a missing or malformed notification config!"), + VirtioVsockError::FailFeatureNeg(id) => write!(f, "Virtio socket device driver failed, for device {id:x}, device did not acknowledge negotiated feature set!"), + VirtioVsockError::FeatureRequirementsNotMet(features) => write!(f, "Virtio socket driver tried to set feature bit without setting dependency feature. Feat set: {features:?}"), + VirtioVsockError::IncompatibleFeatureSets(driver_features, device_features) => write!(f, "Feature set: {driver_features:?} , is incompatible with the device features: {device_features:?}"), + }, } } } diff --git a/src/drivers/virtio/transport/mmio.rs b/src/drivers/virtio/transport/mmio.rs index da826ec3b0..cf8b7f3114 100644 --- a/src/drivers/virtio/transport/mmio.rs +++ b/src/drivers/virtio/transport/mmio.rs @@ -18,10 +18,9 @@ use crate::arch::kernel::interrupts::*; use crate::arch::mm::PhysAddr; use crate::drivers::error::DriverError; #[cfg(any(feature = "tcp", feature = "udp"))] -use crate::drivers::net::network_irqhandler; -#[cfg(any(feature = "tcp", feature = "udp"))] use crate::drivers::net::virtio::VirtioNetDriver; use crate::drivers::virtio::error::VirtioError; +use crate::drivers::virtio::transport::virtio_irqhandler; pub struct VqCfgHandler<'a> { vq_index: u16, @@ -382,13 +381,16 @@ pub(crate) fn init_device( match registers.as_ptr().device_id().read() { #[cfg(any(feature = "tcp", feature = "udp"))] virtio::Id::Net => { - match VirtioNetDriver::init(dev_id, registers, irq_no) { + match VirtioNetDriver::init(dev_id, registers) { Ok(virt_net_drv) => { + use crate::drivers::virtio::transport::VIRTIO_IRQ; + info!("Virtio network driver initialized."); // Install interrupt handler - irq_install_handler(irq_no, network_irqhandler); + irq_install_handler(irq_no, virtio_irqhandler); #[cfg(not(target_arch = "riscv64"))] - add_irq_name(irq_no, "virtio_net"); + add_irq_name(irq_no, "virtio"); + let _ = VIRTIO_IRQ.try_insert(irq_no); Ok(VirtioDriver::Network(virt_net_drv)) } @@ -398,6 +400,27 @@ pub(crate) fn init_device( } } } + #[cfg(feature = "vsock")] + virtio::Id::Vsock => { + match VirtioVsockDriver::init(dev_id, registers) { + Ok(virt_net_drv) => { + use crate::drivers::virtio::transport::VIRTIO_IRQ; + + info!("Virtio sock driver initialized."); + // Install interrupt handler + irq_install_handler(irq_no, virtio_irqhandler); + #[cfg(not(target_arch = "riscv64"))] + add_irq_name(irq_no, "virtio"); + let _ = VIRTIO_IRQ.try_insert(irq_no); + + Ok(VirtioDriver::Vsock(virt_vsock_drv)) + } + Err(virtio_error) => { + error!("Virtio sock driver could not be initialized with device"); + Err(DriverError::InitVirtioDevFail(virtio_error)) + } + } + } device_id => { error!("Device with id {device_id:?} is currently not supported!"); // Return Driver error inidacting device is not supported diff --git a/src/drivers/virtio/transport/mod.rs b/src/drivers/virtio/transport/mod.rs index 50df4b30c9..26df517264 100644 --- a/src/drivers/virtio/transport/mod.rs +++ b/src/drivers/virtio/transport/mod.rs @@ -8,3 +8,100 @@ pub mod mmio; #[cfg(feature = "pci")] pub mod pci; + +use hermit_sync::OnceCell; + +#[cfg(not(target_arch = "riscv64"))] +use crate::arch::kernel::core_local::increment_irq_counter; +#[cfg(target_arch = "x86_64")] +use crate::arch::kernel::interrupts::ExceptionStackFrame; +#[cfg(all( + any(feature = "vsock", feature = "tcp", feature = "udp"), + not(feature = "pci") +))] +use crate::arch::kernel::mmio as hardware; +#[cfg(target_arch = "aarch64")] +use crate::arch::scheduler::State; +#[cfg(any(feature = "tcp", feature = "udp"))] +use crate::drivers::net::NetworkDriver; +#[cfg(all( + any(feature = "vsock", feature = "tcp", feature = "udp"), + feature = "pci" +))] +use crate::drivers::pci as hardware; + +/// All virtio devices share the interrupt number `VIRTIO_IRQ` +static VIRTIO_IRQ: OnceCell = OnceCell::new(); + +#[cfg(target_arch = "aarch64")] +pub(crate) fn virtio_irqhandler(_state: &State) -> bool { + debug!("Receive virtio interrupt"); + + increment_irq_counter(32 + VIRTIO_IRQ.get().unwrap()); + + #[cfg(any(feature = "tcp", feature = "udp"))] + if let Some(driver) = hardware::get_network_driver() { + driver.lock().handle_interrupt() + } + + #[cfg(feature = "vsock")] + if let Some(driver) = hardware::get_vsock_driver() { + driver.lock().handle_interrupt(); + } + + crate::executor::run(); + + true +} + +#[cfg(target_arch = "x86_64")] +pub(crate) extern "x86-interrupt" fn virtio_irqhandler(stack_frame: ExceptionStackFrame) { + crate::arch::x86_64::swapgs(&stack_frame); + use crate::arch::kernel::core_local::core_scheduler; + use crate::scheduler::PerCoreSchedulerExt; + + debug!("Receive virtio interrupt"); + + increment_irq_counter(32 + VIRTIO_IRQ.get().unwrap()); + + crate::kernel::apic::eoi(); + + #[cfg(any(feature = "tcp", feature = "udp"))] + if let Some(driver) = hardware::get_network_driver() { + driver.lock().handle_interrupt() + } + + #[cfg(feature = "vsock")] + if let Some(driver) = hardware::get_vsock_driver() { + driver.lock().handle_interrupt(); + } + + crate::executor::run(); + + core_scheduler().reschedule(); + crate::arch::x86_64::swapgs(&stack_frame); +} + +#[cfg(target_arch = "riscv64")] +pub(crate) fn virtio_irqhandler() { + use crate::arch::kernel::core_local::core_scheduler; + use crate::scheduler::PerCoreSchedulerExt; + + debug!("Receive virtio interrupt"); + + // PLIC end of interrupt + crate::arch::kernel::interrupts::external_eoi(); + #[cfg(any(feature = "tcp", feature = "udp"))] + if let Some(driver) = hardware::get_network_driver() { + driver.lock().handle_interrupt() + } + + #[cfg(feature = "vsock")] + if let Some(driver) = hardware::get_vsock_driver() { + driver.lock().handle_interrupt(); + } + + crate::executor::run(); + + core_scheduler().reschedule(); +} diff --git a/src/drivers/virtio/transport/pci.rs b/src/drivers/virtio/transport/pci.rs index 7dc5e6770e..f725b06499 100644 --- a/src/drivers/virtio/transport/pci.rs +++ b/src/drivers/virtio/transport/pci.rs @@ -16,7 +16,10 @@ use virtio::{le16, le32, DeviceStatus}; use volatile::access::ReadOnly; use volatile::{VolatilePtr, VolatileRef}; -#[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] +#[cfg(all( + not(feature = "rtl8139"), + any(feature = "tcp", feature = "udp", feature = "vsock") +))] use crate::arch::kernel::interrupts::*; use crate::arch::memory_barrier; use crate::arch::mm::PhysAddr; @@ -25,12 +28,13 @@ use crate::drivers::error::DriverError; #[cfg(feature = "fuse")] use crate::drivers::fs::virtio_fs::VirtioFsDriver; #[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] -use crate::drivers::net::network_irqhandler; -#[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] use crate::drivers::net::virtio::VirtioNetDriver; use crate::drivers::pci::error::PciError; use crate::drivers::pci::PciDevice; use crate::drivers::virtio::error::VirtioError; +use crate::drivers::virtio::transport::virtio_irqhandler; +#[cfg(feature = "vsock")] +use crate::drivers::vsock::VirtioVsockDriver; /// Maps a given device specific pci configuration structure and /// returns a static reference to it. @@ -915,6 +919,20 @@ pub(crate) fn init_device( Err(DriverError::InitVirtioDevFail(virtio_error)) } }, + #[cfg(feature = "vsock")] + virtio::Id::Vsock => match VirtioVsockDriver::init(device) { + Ok(virt_sock_drv) => { + info!("Virtio sock driver initialized."); + Ok(VirtioDriver::Vsock(virt_sock_drv)) + } + Err(virtio_error) => { + error!( + "Virtio sock driver could not be initialized with device: {:x}", + device_id + ); + Err(DriverError::InitVirtioDevFail(virtio_error)) + } + }, #[cfg(feature = "fuse")] virtio::Id::Fs => { // TODO: check subclass @@ -948,11 +966,29 @@ pub(crate) fn init_device( match &drv { #[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] VirtioDriver::Network(_) => { + use crate::drivers::virtio::transport::VIRTIO_IRQ; + + let irq = device.get_irq().unwrap(); + let _ = VIRTIO_IRQ.try_insert(irq); + + info!("Install virtio interrupt handler at line {}", irq); + // Install interrupt handler + irq_install_handler(irq, virtio_irqhandler); + add_irq_name(irq, "virtio"); + + Ok(drv) + } + #[cfg(feature = "vsock")] + VirtioDriver::Vsock(_) => { + use crate::drivers::virtio::transport::VIRTIO_IRQ; + let irq = device.get_irq().unwrap(); + let _ = VIRTIO_IRQ.try_insert(irq); + info!("Install virtio interrupt handler at line {}", irq); // Install interrupt handler - irq_install_handler(irq, network_irqhandler); - add_irq_name(irq, "virtio_net"); + irq_install_handler(irq, virtio_irqhandler); + add_irq_name(irq, "virtio"); Ok(drv) } @@ -967,6 +1003,8 @@ pub(crate) fn init_device( pub(crate) enum VirtioDriver { #[cfg(all(not(feature = "rtl8139"), any(feature = "tcp", feature = "udp")))] Network(VirtioNetDriver), + #[cfg(feature = "vsock")] + Vsock(VirtioVsockDriver), #[cfg(feature = "fuse")] FileSystem(VirtioFsDriver), } diff --git a/src/drivers/vsock/mod.rs b/src/drivers/vsock/mod.rs new file mode 100644 index 0000000000..e906220671 --- /dev/null +++ b/src/drivers/vsock/mod.rs @@ -0,0 +1,477 @@ +#![allow(dead_code)] + +#[cfg(feature = "pci")] +pub mod pci; + +use alloc::boxed::Box; +use alloc::vec::Vec; +use core::mem; +use core::mem::MaybeUninit; + +use pci_types::InterruptLine; +use virtio::vsock::Hdr; +use virtio::FeatureBits; + +use crate::config::VIRTIO_MAX_QUEUE_SIZE; +use crate::drivers::virtio::error::VirtioVsockError; +#[cfg(feature = "pci")] +use crate::drivers::virtio::transport::pci::{ComCfg, IsrStatus, NotifCfg}; +use crate::drivers::virtio::virtqueue::split::SplitVq; +use crate::drivers::virtio::virtqueue::{ + AvailBufferToken, BufferElem, BufferType, UsedBufferToken, Virtq, VqIndex, VqSize, +}; +#[cfg(feature = "pci")] +use crate::drivers::vsock::pci::VsockDevCfgRaw; +use crate::mm::device_alloc::DeviceAlloc; + +fn fill_queue( + vq: &mut dyn Virtq, + num_packets: u16, + packet_size: u32, + poll_sender: async_channel::Sender, +) { + for _ in 0..num_packets { + let buff_tkn = match AvailBufferToken::new( + vec![], + vec![ + BufferElem::Sized(Box::::new_uninit_in(DeviceAlloc)), + BufferElem::Vector(Vec::with_capacity_in( + packet_size.try_into().unwrap(), + DeviceAlloc, + )), + ], + ) { + Ok(tkn) => tkn, + Err(_vq_err) => { + error!("Setup of network queue failed, which should not happen!"); + panic!("setup of network queue failed!"); + } + }; + + // BufferTokens are directly provided to the queue + // TransferTokens are directly dispatched + // Transfers will be awaited at the queue + match vq.dispatch( + buff_tkn, + Some(poll_sender.clone()), + false, + BufferType::Direct, + ) { + Ok(_) => (), + Err(err) => { + error!("{:#?}", err); + break; + } + } + } +} + +pub(crate) struct RxQueue { + vq: Option>, + poll_sender: async_channel::Sender, + poll_receiver: async_channel::Receiver, + packet_size: u32, +} + +impl RxQueue { + pub fn new() -> Self { + let (poll_sender, poll_receiver) = async_channel::unbounded(); + + Self { + vq: None, + poll_sender, + poll_receiver, + packet_size: crate::VSOCK_PACKET_SIZE + mem::size_of::() as u32, + } + } + + pub fn add(&mut self, mut vq: Box) { + const BUFF_PER_PACKET: u16 = 2; + let num_packets: u16 = u16::from(vq.size()) / BUFF_PER_PACKET; + info!("num_packets {}", num_packets); + fill_queue( + vq.as_mut(), + num_packets, + self.packet_size, + self.poll_sender.clone(), + ); + + self.vq = Some(vq); + } + + pub fn enable_notifs(&mut self) { + if let Some(ref mut vq) = self.vq { + vq.enable_notifs(); + } + } + + pub fn disable_notifs(&mut self) { + if let Some(ref mut vq) = self.vq { + vq.disable_notifs(); + } + } + + fn get_next(&mut self) -> Option { + let transfer = self.poll_receiver.try_recv(); + + transfer + .or_else(|_| { + // Check if any not yet provided transfers are in the queue. + self.poll(); + + self.poll_receiver.try_recv() + }) + .ok() + } + + fn poll(&mut self) { + if let Some(ref mut vq) = self.vq { + vq.poll(); + } + } + + pub fn process_packet(&mut self, mut f: F) + where + F: FnMut(&Hdr, &[u8]), + { + while let Some(mut buffer_tkn) = self.get_next() { + let header = buffer_tkn + .used_recv_buff + .pop_front_downcast::() + .unwrap(); + let packet = buffer_tkn.used_recv_buff.pop_front_vec().unwrap(); + + if let Some(ref mut vq) = self.vq { + f(&header, &packet[..]); + + fill_queue(vq.as_mut(), 1, self.packet_size, self.poll_sender.clone()); + } else { + panic!("Invalid length of receive queue"); + } + } + } +} + +pub(crate) struct TxQueue { + vq: Option>, + /// Indicates, whether the Driver/Device are using multiple + /// queues for communication. + packet_length: u32, +} + +impl TxQueue { + pub fn new() -> Self { + Self { + vq: None, + packet_length: crate::VSOCK_PACKET_SIZE + mem::size_of::() as u32, + } + } + + pub fn add(&mut self, vq: Box) { + self.vq = Some(vq); + } + + pub fn enable_notifs(&mut self) { + if let Some(ref mut vq) = self.vq { + vq.enable_notifs(); + } + } + + pub fn disable_notifs(&mut self) { + if let Some(ref mut vq) = self.vq { + vq.disable_notifs(); + } + } + + fn poll(&mut self) { + if let Some(ref mut vq) = self.vq { + vq.poll(); + } + } + + /// Provides a slice to copy the packet and transfer the packet + /// to the send queue. The caller has to create the header + /// for the vsock interface. + pub fn send_packet(&mut self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + // We need to poll to get the queue to remove elements from the table and make space for + // what we are about to add + if let Some(ref mut vq) = self.vq { + vq.poll(); + + assert!(len < usize::try_from(self.packet_length).unwrap()); + let mut packet = Vec::with_capacity_in(len, DeviceAlloc); + let result = unsafe { + let result = f(MaybeUninit::slice_assume_init_mut( + packet.spare_capacity_mut(), + )); + packet.set_len(len); + result + }; + + let buff_tkn = AvailBufferToken::new(vec![BufferElem::Vector(packet)], vec![]).unwrap(); + + vq.dispatch(buff_tkn, None, false, BufferType::Direct) + .unwrap(); + + result + } else { + panic!("Unable to get send queue"); + } + } +} + +pub(crate) struct EventQueue { + vq: Option>, + poll_sender: async_channel::Sender, + poll_receiver: async_channel::Receiver, + packet_size: u32, +} + +impl EventQueue { + pub fn new() -> Self { + let (poll_sender, poll_receiver) = async_channel::unbounded(); + + Self { + vq: None, + poll_sender, + poll_receiver, + packet_size: 128u32, + } + } + + /// Adds a given queue to the underlying vector and populates the queue with RecvBuffers. + /// + /// Queues are all populated according to Virtio specification v1.1. - 5.1.6.3.1 + fn add(&mut self, mut vq: Box) { + const BUFF_PER_PACKET: u16 = 2; + let num_packets: u16 = u16::from(vq.size()) / BUFF_PER_PACKET; + fill_queue( + vq.as_mut(), + num_packets, + self.packet_size, + self.poll_sender.clone(), + ); + self.vq = Some(vq); + } + + pub fn enable_notifs(&mut self) { + if let Some(ref mut vq) = self.vq { + vq.enable_notifs(); + } + } + + pub fn disable_notifs(&mut self) { + if let Some(ref mut vq) = self.vq { + vq.disable_notifs(); + } + } +} + +/// A wrapper struct for the raw configuration structure. +/// Handling the right access to fields, as some are read-only +/// for the driver. +pub(crate) struct VsockDevCfg { + pub raw: &'static VsockDevCfgRaw, + pub dev_id: u16, + pub features: virtio::vsock::F, +} + +pub(crate) struct VirtioVsockDriver { + pub(super) dev_cfg: VsockDevCfg, + pub(super) com_cfg: ComCfg, + pub(super) isr_stat: IsrStatus, + pub(super) notif_cfg: NotifCfg, + pub(super) irq: InterruptLine, + + pub(super) event_vq: EventQueue, + pub(super) recv_vq: RxQueue, + pub(super) send_vq: TxQueue, +} + +impl VirtioVsockDriver { + #[cfg(feature = "pci")] + pub fn get_dev_id(&self) -> u16 { + self.dev_cfg.dev_id + } + + #[inline] + pub fn get_cid(&self) -> u64 { + self.dev_cfg.raw.guest_cid + } + + #[cfg(feature = "pci")] + pub fn set_failed(&mut self) { + self.com_cfg.set_failed(); + } + + pub fn disable_interrupts(&mut self) { + // For send and receive queues? + // Only for receive? Because send is off anyway? + self.recv_vq.disable_notifs(); + } + + pub fn enable_interrupts(&mut self) { + // For send and receive queues? + // Only for receive? Because send is off anyway? + self.recv_vq.enable_notifs(); + } + + pub fn handle_interrupt(&mut self) { + let _ = self.isr_stat.is_interrupt(); + + if self.isr_stat.is_cfg_change() { + info!("Configuration changes are not possible! Aborting"); + todo!("Implement possibility to change config on the fly...") + } + + self.isr_stat.acknowledge(); + } + + /// Negotiates a subset of features, understood and wanted by both the OS + /// and the device. + fn negotiate_features( + &mut self, + driver_features: virtio::vsock::F, + ) -> Result<(), VirtioVsockError> { + let device_features = virtio::vsock::F::from(self.com_cfg.dev_features()); + + if device_features.requirements_satisfied() { + info!("Feature set wanted by vsock driver are in conformance with specification."); + } else { + return Err(VirtioVsockError::FeatureRequirementsNotMet(device_features)); + } + + if device_features.contains(driver_features) { + // If device supports subset of features write feature set to common config + self.com_cfg.set_drv_features(driver_features.into()); + Ok(()) + } else { + Err(VirtioVsockError::IncompatibleFeatureSets( + driver_features, + device_features, + )) + } + } + + /// Initializes the device in adherence to specification. Returns Some(VirtioVsockError) + /// upon failure and None in case everything worked as expected. + /// + /// See Virtio specification v1.1. - 3.1.1. + /// and v1.1. - 5.10.6 + pub fn init_dev(&mut self) -> Result<(), VirtioVsockError> { + // Reset + self.com_cfg.reset_dev(); + + // Indiacte device, that OS noticed it + self.com_cfg.ack_dev(); + + // Indicate device, that driver is able to handle it + self.com_cfg.set_drv(); + + let features = virtio::vsock::F::VERSION_1; + self.negotiate_features(features)?; + + // Indicates the device, that the current feature set is final for the driver + // and will not be changed. + self.com_cfg.features_ok(); + + // Checks if the device has accepted final set. This finishes feature negotiation. + if self.com_cfg.check_features() { + info!( + "Features have been negotiated between virtio socket device {:x} and driver.", + self.dev_cfg.dev_id + ); + // Set feature set in device config fur future use. + self.dev_cfg.features = features; + } else { + return Err(VirtioVsockError::FailFeatureNeg(self.dev_cfg.dev_id)); + } + + // create the queues and tell device about them + self.recv_vq.add(Box::new( + SplitVq::new( + &mut self.com_cfg, + &self.notif_cfg, + VqSize::from(VIRTIO_MAX_QUEUE_SIZE), + VqIndex::from(0u16), + self.dev_cfg.features.into(), + ) + .unwrap(), + )); + // Interrupt for receiving packets is wanted + self.recv_vq.enable_notifs(); + + self.send_vq.add(Box::new( + SplitVq::new( + &mut self.com_cfg, + &self.notif_cfg, + VqSize::from(VIRTIO_MAX_QUEUE_SIZE), + VqIndex::from(1u16), + self.dev_cfg.features.into(), + ) + .unwrap(), + )); + // Interrupt for communicating that a sended packet left, is not needed + self.send_vq.disable_notifs(); + + // create the queues and tell device about them + self.event_vq.add(Box::new( + SplitVq::new( + &mut self.com_cfg, + &self.notif_cfg, + VqSize::from(VIRTIO_MAX_QUEUE_SIZE), + VqIndex::from(2u16), + self.dev_cfg.features.into(), + ) + .unwrap(), + )); + // Interrupt for event packets is wanted + self.event_vq.enable_notifs(); + + // At this point the device is "live" + self.com_cfg.drv_ok(); + + Ok(()) + } + + #[inline] + pub fn process_packet(&mut self, f: F) + where + F: FnMut(&Hdr, &[u8]), + { + self.recv_vq.process_packet(f) + } + + /// Provides a slice to copy the packet and transfer the packet + /// to the send queue. The caller has to creatde the header + /// for the vsock interface. + #[inline] + pub fn send_packet(&mut self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + self.send_vq.send_packet(len, f) + } +} + +/// Error module of virtio socket device driver. +pub mod error { + /// Virtio socket device error enum. + #[derive(Debug, Copy, Clone)] + pub enum VirtioVsockError { + NoDevCfg(u16), + NoComCfg(u16), + NoIsrCfg(u16), + NoNotifCfg(u16), + FailFeatureNeg(u16), + /// Set of features does not adhere to the requirements of features + /// indicated by the specification + FeatureRequirementsNotMet(virtio::vsock::F), + /// The first u64 contains the feature bits wanted by the driver. + /// but which are incompatible with the device feature set, second u64. + IncompatibleFeatureSets(virtio::vsock::F, virtio::vsock::F), + } +} diff --git a/src/drivers/vsock/pci.rs b/src/drivers/vsock/pci.rs new file mode 100644 index 0000000000..536e4700de --- /dev/null +++ b/src/drivers/vsock/pci.rs @@ -0,0 +1,126 @@ +use crate::arch::pci::PciConfigRegion; +use crate::drivers::pci::PciDevice; +use crate::drivers::virtio::error::{self, VirtioError}; +use crate::drivers::virtio::transport::pci; +use crate::drivers::virtio::transport::pci::{PciCap, UniCapsColl}; +use crate::drivers::vsock::{EventQueue, RxQueue, TxQueue, VirtioVsockDriver, VsockDevCfg}; + +/// Virtio's socket device configuration structure. +/// See specification v1.1. - 5.11.4 +/// +#[derive(Debug, Copy, Clone)] +#[repr(C)] +pub(crate) struct VsockDevCfgRaw { + /// The guest_cid field contains the guest’s context ID, which uniquely identifies the device + /// for its lifetime. The upper 32 bits of the CID are reserved and zeroed. + pub guest_cid: u64, +} + +impl VirtioVsockDriver { + fn map_cfg(cap: &PciCap) -> Option { + let dev_cfg: &'static VsockDevCfgRaw = match pci::map_dev_cfg::(cap) { + Some(cfg) => cfg, + None => return None, + }; + + Some(VsockDevCfg { + raw: dev_cfg, + dev_id: cap.dev_id(), + features: virtio::vsock::F::empty(), + }) + } + + /// Instantiates a new VirtioVsockDriver struct, by checking the available + /// configuration structures and moving them into the struct. + pub fn new( + mut caps_coll: UniCapsColl, + device: &PciDevice, + ) -> Result { + let device_id = device.device_id(); + + let com_cfg = match caps_coll.get_com_cfg() { + Some(com_cfg) => com_cfg, + None => { + error!("No common config. Aborting!"); + return Err(error::VirtioVsockError::NoComCfg(device_id)); + } + }; + + let isr_stat = match caps_coll.get_isr_cfg() { + Some(isr_stat) => isr_stat, + None => { + error!("No ISR status config. Aborting!"); + return Err(error::VirtioVsockError::NoIsrCfg(device_id)); + } + }; + + let notif_cfg = match caps_coll.get_notif_cfg() { + Some(notif_cfg) => notif_cfg, + None => { + error!("No notif config. Aborting!"); + return Err(error::VirtioVsockError::NoNotifCfg(device_id)); + } + }; + + let dev_cfg = loop { + match caps_coll.get_dev_cfg() { + Some(cfg) => { + if let Some(dev_cfg) = VirtioVsockDriver::map_cfg(&cfg) { + break dev_cfg; + } + } + None => { + error!("No dev config. Aborting!"); + return Err(error::VirtioVsockError::NoDevCfg(device_id)); + } + } + }; + + Ok(VirtioVsockDriver { + dev_cfg, + com_cfg, + isr_stat, + notif_cfg, + irq: device.get_irq().unwrap(), + event_vq: EventQueue::new(), + recv_vq: RxQueue::new(), + send_vq: TxQueue::new(), + }) + } + + /// Initializes virtio socket device + /// + /// Returns a driver instance of VirtioVsockDriver. + pub(crate) fn init( + device: &PciDevice, + ) -> Result { + let mut drv = match pci::map_caps(device) { + Ok(caps) => match VirtioVsockDriver::new(caps, device) { + Ok(driver) => driver, + Err(vsock_err) => { + error!("Initializing new virtio socket device driver failed. Aborting!"); + return Err(VirtioError::VsockDriver(vsock_err)); + } + }, + Err(pci_error) => { + error!("Mapping capabilities failed. Aborting!"); + return Err(VirtioError::FromPci(pci_error)); + } + }; + + match drv.init_dev() { + Ok(_) => { + info!( + "Socket device with cid {:x}, has been initialized by driver!", + drv.dev_cfg.raw.guest_cid + ); + + Ok(drv) + } + Err(fs_err) => { + drv.set_failed(); + Err(VirtioError::VsockDriver(fs_err)) + } + } + } +} diff --git a/src/executor/mod.rs b/src/executor/mod.rs index 232a72224f..5cb48939eb 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -5,6 +5,8 @@ pub(crate) mod device; #[cfg(any(feature = "tcp", feature = "udp"))] pub(crate) mod network; pub(crate) mod task; +#[cfg(feature = "vsock")] +pub(crate) mod vsock; use alloc::sync::Arc; use alloc::task::Wake; @@ -91,6 +93,8 @@ where pub fn init() { #[cfg(any(feature = "tcp", feature = "udp"))] crate::executor::network::init(); + #[cfg(feature = "vsock")] + crate::executor::vsock::init(); } #[inline] diff --git a/src/executor/vsock.rs b/src/executor/vsock.rs new file mode 100644 index 0000000000..4cd43ea4d3 --- /dev/null +++ b/src/executor/vsock.rs @@ -0,0 +1,241 @@ +use alloc::collections::BTreeMap; +use alloc::vec::Vec; +use core::future; +use core::task::{Poll, Waker}; + +use hermit_sync::InterruptTicketMutex; +use virtio::vsock::{Hdr, Op, Type}; +use virtio::{le16, le32}; + +#[cfg(not(feature = "pci"))] +use crate::arch::kernel::mmio as hardware; +#[cfg(feature = "pci")] +use crate::drivers::pci as hardware; +use crate::executor::spawn; +use crate::io; +use crate::io::Error::EADDRINUSE; + +pub(crate) static VSOCK_MAP: InterruptTicketMutex = + InterruptTicketMutex::new(VsockMap::new()); + +#[derive(Debug, Copy, Clone, PartialEq)] +pub(crate) enum VsockState { + Listen, + ReceiveRequest, + Connected, + Connecting, + Shutdown, +} + +/// WakerRegistration is derived from smoltcp's +/// implementation. +#[derive(Debug)] +pub(crate) struct WakerRegistration { + waker: Option, +} + +impl WakerRegistration { + pub const fn new() -> Self { + Self { waker: None } + } + + /// Register a waker. Overwrites the previous waker, if any. + pub fn register(&mut self, w: &Waker) { + match self.waker { + // Optimization: If both the old and new Wakers wake the same task, we can simply + // keep the old waker, skipping the clone. + Some(ref w2) if (w2.will_wake(w)) => {} + // In all other cases + // - we have no waker registered + // - we have a waker registered but it's for a different task. + // then clone the new waker and store it + _ => self.waker = Some(w.clone()), + } + } + + /// Wake the registered waker, if any. + pub fn wake(&mut self) { + if let Some(w) = self.waker.take() { + w.wake() + } + } +} + +pub(crate) const RAW_SOCKET_BUFFER_SIZE: usize = 256 * 1024; + +#[derive(Debug)] +pub(crate) struct RawSocket { + pub remote_cid: u32, + pub remote_port: u32, + pub fwd_cnt: u32, + pub peer_fwd_cnt: u32, + pub peer_buf_alloc: u32, + pub tx_cnt: u32, + pub state: VsockState, + pub rx_waker: WakerRegistration, + pub tx_waker: WakerRegistration, + pub buffer: Vec, +} + +impl RawSocket { + pub fn new(state: VsockState) -> Self { + Self { + remote_cid: 0, + remote_port: 0, + fwd_cnt: 0, + peer_fwd_cnt: 0, + peer_buf_alloc: 0, + tx_cnt: 0, + state, + rx_waker: WakerRegistration::new(), + tx_waker: WakerRegistration::new(), + buffer: Vec::with_capacity(RAW_SOCKET_BUFFER_SIZE), + } + } +} + +async fn vsock_run() { + future::poll_fn(|_cx| { + if let Some(driver) = hardware::get_vsock_driver() { + const HEADER_SIZE: usize = core::mem::size_of::(); + let mut driver_guard = driver.lock(); + let mut hdr: Option = None; + let mut fwd_cnt: u32 = 0; + + driver_guard.process_packet(|header, data| { + let op = Op::try_from(header.op.to_ne()).unwrap(); + let port = header.dst_port.to_ne(); + let type_ = Type::try_from(header.type_.to_ne()).unwrap(); + let mut vsock_guard = VSOCK_MAP.lock(); + let header_cid: u32 = header.src_cid.to_ne().try_into().unwrap(); + + if let Some(raw) = vsock_guard.get_mut_socket(port) { + if op == Op::Request && raw.state == VsockState::Listen && type_ == Type::Stream + { + raw.state = VsockState::ReceiveRequest; + raw.remote_cid = header_cid; + raw.remote_port = header.src_port.to_ne(); + raw.peer_buf_alloc = header.buf_alloc.to_ne(); + raw.rx_waker.wake(); + } else if (raw.state == VsockState::Connected + || raw.state == VsockState::Shutdown) + && type_ == Type::Stream + && op == Op::Rw + { + if raw.remote_cid == header_cid { + raw.buffer.extend_from_slice(data); + raw.fwd_cnt = + raw.fwd_cnt.wrapping_add(u32::try_from(data.len()).unwrap()); + raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); + raw.tx_waker.wake(); + raw.rx_waker.wake(); + hdr = Some(*header); + fwd_cnt = raw.fwd_cnt; + } else { + trace!("Receive message from invalid source {}", header_cid); + } + } else if op == Op::CreditUpdate { + if raw.remote_cid == header_cid { + raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); + raw.tx_waker.wake(); + } else { + trace!("Receive message from invalid source {}", header_cid); + } + } else if op == Op::Shutdown { + if raw.remote_cid == header_cid { + raw.state = VsockState::Shutdown; + } else { + trace!("Receive message from invalid source {}", header_cid); + } + } else if op == Op::Response && type_ == Type::Stream { + if raw.remote_cid == header_cid && raw.state == VsockState::Connecting { + raw.state = VsockState::Connected; + } + } else if raw.remote_cid == header_cid { + hdr = Some(*header); + fwd_cnt = raw.fwd_cnt; + } + } + }); + + if let Some(hdr) = hdr { + driver_guard.send_packet(HEADER_SIZE, |buffer| { + let response = unsafe { &mut *(buffer.as_mut_ptr() as *mut Hdr) }; + + response.src_cid = hdr.dst_cid; + response.dst_cid = hdr.src_cid; + response.src_port = hdr.dst_port; + response.dst_port = hdr.src_port; + response.len = le32::from_ne(0); + response.type_ = hdr.type_; + if hdr.op.to_ne() == Op::CreditRequest.into() || hdr.op.to_ne() == Op::Rw.into() + { + response.op = le16::from_ne(Op::CreditUpdate.into()); + } else { + // reset connection + response.op = le16::from_ne(Op::Rst.into()); + } + response.flags = le32::from_ne(0); + response.buf_alloc = le32::from_ne(RAW_SOCKET_BUFFER_SIZE as u32); + response.fwd_cnt = le32::from_ne(fwd_cnt); + }); + } + + Poll::Pending + } else { + Poll::Ready(()) + } + }) + .await +} + +pub(crate) struct VsockMap { + port_map: BTreeMap, +} + +impl VsockMap { + pub const fn new() -> Self { + Self { + port_map: BTreeMap::new(), + } + } + + pub fn bind(&mut self, port: u32) -> io::Result<()> { + self.port_map + .try_insert(port, RawSocket::new(VsockState::Listen)) + .map_err(|_| EADDRINUSE)?; + Ok(()) + } + + pub fn connect(&mut self, port: u32, cid: u32) -> io::Result { + for i in u32::MAX / 4..u32::MAX { + let mut raw = RawSocket::new(VsockState::Connecting); + raw.remote_cid = cid; + raw.remote_port = port; + + if self.port_map.try_insert(i, raw).is_ok() { + return Ok(i); + } + } + + Err(io::Error::EBADF) + } + + pub fn get_socket(&self, port: u32) -> Option<&RawSocket> { + self.port_map.get(&port) + } + + pub fn get_mut_socket(&mut self, port: u32) -> Option<&mut RawSocket> { + self.port_map.get_mut(&port) + } + + pub fn remove_socket(&mut self, port: u32) { + let _ = self.port_map.remove(&port); + } +} + +pub(crate) fn init() { + info!("Try to initialize vsock interface!"); + + spawn(vsock_run()); +} diff --git a/src/fd/mod.rs b/src/fd/mod.rs index c2018b2324..1f0c2f652d 100644 --- a/src/fd/mod.rs +++ b/src/fd/mod.rs @@ -16,7 +16,7 @@ use crate::fs::{DirectoryEntry, FileAttr, SeekWhence}; use crate::io; mod eventfd; -#[cfg(any(feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] pub(crate) mod socket; pub(crate) mod stdio; @@ -24,6 +24,24 @@ pub(crate) const STDIN_FILENO: FileDescriptor = 0; pub(crate) const STDOUT_FILENO: FileDescriptor = 1; pub(crate) const STDERR_FILENO: FileDescriptor = 2; +#[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] +#[derive(Debug)] +pub(crate) enum Endpoint { + #[cfg(any(feature = "tcp", feature = "udp"))] + Ip(IpEndpoint), + #[cfg(feature = "vsock")] + Vsock(socket::vsock::VsockEndpoint), +} + +#[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] +#[derive(Debug)] +pub(crate) enum ListenEndpoint { + #[cfg(any(feature = "tcp", feature = "udp"))] + Ip(IpListenEndpoint), + #[cfg(feature = "vsock")] + Vsock(socket::vsock::VsockListenEndpoint), +} + #[allow(dead_code)] #[derive(Debug, PartialEq)] pub(crate) enum SocketOption { @@ -186,57 +204,57 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug + DynClone { } /// `accept` a connection on a socket - #[cfg(any(feature = "tcp", feature = "udp"))] - fn accept(&self) -> io::Result { + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + fn accept(&self) -> io::Result { Err(io::Error::EINVAL) } /// initiate a connection on a socket - #[cfg(any(feature = "tcp", feature = "udp"))] - fn connect(&self, _endpoint: IpEndpoint) -> io::Result<()> { + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + fn connect(&self, _endpoint: Endpoint) -> io::Result<()> { Err(io::Error::EINVAL) } /// `bind` a name to a socket - #[cfg(any(feature = "tcp", feature = "udp"))] - fn bind(&self, _name: IpListenEndpoint) -> io::Result<()> { + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + fn bind(&self, _name: ListenEndpoint) -> io::Result<()> { Err(io::Error::EINVAL) } /// `listen` for connections on a socket - #[cfg(any(feature = "tcp", feature = "udp"))] + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] fn listen(&self, _backlog: i32) -> io::Result<()> { Err(io::Error::EINVAL) } /// `setsockopt` sets options on sockets - #[cfg(any(feature = "tcp", feature = "udp"))] + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] fn setsockopt(&self, _opt: SocketOption, _optval: bool) -> io::Result<()> { Err(io::Error::EINVAL) } /// `getsockopt` gets options on sockets - #[cfg(any(feature = "tcp", feature = "udp"))] + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] fn getsockopt(&self, _opt: SocketOption) -> io::Result { Err(io::Error::EINVAL) } /// `getsockname` gets socket name - #[cfg(any(feature = "tcp", feature = "udp"))] - fn getsockname(&self) -> Option { + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + fn getsockname(&self) -> Option { None } /// `getpeername` get address of connected peer - #[cfg(any(feature = "tcp", feature = "udp"))] + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] #[allow(dead_code)] - fn getpeername(&self) -> Option { + fn getpeername(&self) -> Option { None } /// receive a message from a socket - #[cfg(any(feature = "tcp", feature = "udp"))] - fn recvfrom(&self, _buffer: &mut [u8]) -> io::Result<(usize, IpEndpoint)> { + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + fn recvfrom(&self, _buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> { Err(io::Error::ENOSYS) } @@ -247,13 +265,13 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug + DynClone { /// If a peer address has been prespecified, either the message shall /// be sent to the address specified by dest_addr (overriding the pre-specified peer /// address). - #[cfg(any(feature = "tcp", feature = "udp"))] - fn sendto(&self, _buffer: &[u8], _endpoint: IpEndpoint) -> io::Result { + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] + fn sendto(&self, _buffer: &[u8], _endpoint: Endpoint) -> io::Result { Err(io::Error::ENOSYS) } /// shut down part of a full-duplex connection - #[cfg(any(feature = "tcp", feature = "udp"))] + #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] fn shutdown(&self, _how: i32) -> io::Result<()> { Err(io::Error::ENOSYS) } diff --git a/src/fd/socket/mod.rs b/src/fd/socket/mod.rs index 7a41790273..1ccbb1f1bc 100644 --- a/src/fd/socket/mod.rs +++ b/src/fd/socket/mod.rs @@ -2,3 +2,5 @@ pub(crate) mod tcp; #[cfg(feature = "udp")] pub(crate) mod udp; +#[cfg(feature = "vsock")] +pub(crate) mod vsock; diff --git a/src/fd/socket/tcp.rs b/src/fd/socket/tcp.rs index 79c03b82f4..9aedf492ea 100644 --- a/src/fd/socket/tcp.rs +++ b/src/fd/socket/tcp.rs @@ -8,11 +8,11 @@ use async_trait::async_trait; use smoltcp::iface; use smoltcp::socket::tcp; use smoltcp::time::Duration; -use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; +use smoltcp::wire::IpEndpoint; use crate::executor::block_on; use crate::executor::network::{now, Handle, NetworkState, NIC}; -use crate::fd::{IoCtl, ObjectInterface, PollEvent, SocketOption}; +use crate::fd::{Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent, SocketOption}; use crate::{io, DEFAULT_KEEP_ALIVE_INTERVAL}; /// further receives will be disallowed @@ -304,45 +304,59 @@ impl ObjectInterface for Socket { Ok(pos) } - fn bind(&self, endpoint: IpListenEndpoint) -> io::Result<()> { - self.port.store(endpoint.port, Ordering::Release); - Ok(()) + fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { + #[allow(irrefutable_let_patterns)] + if let ListenEndpoint::Ip(endpoint) = endpoint { + self.port.store(endpoint.port, Ordering::Release); + Ok(()) + } else { + Err(io::Error::EIO) + } } - fn connect(&self, endpoint: IpEndpoint) -> io::Result<()> { - if self.nonblocking.load(Ordering::Acquire) { - block_on(self.async_connect(endpoint), Some(Duration::ZERO.into())).map_err(|x| { - if x == io::Error::ETIME { - io::Error::EAGAIN - } else { - x - } - }) + fn connect(&self, endpoint: Endpoint) -> io::Result<()> { + #[allow(irrefutable_let_patterns)] + if let Endpoint::Ip(endpoint) = endpoint { + if self.nonblocking.load(Ordering::Acquire) { + block_on(self.async_connect(endpoint), Some(Duration::ZERO.into())).map_err(|x| { + if x == io::Error::ETIME { + io::Error::EAGAIN + } else { + x + } + }) + } else { + block_on(self.async_connect(endpoint), None) + } } else { - block_on(self.async_connect(endpoint), None) + Err(io::Error::EIO) } } - fn accept(&self) -> io::Result { - if self.is_nonblocking() { + fn accept(&self) -> io::Result { + let endpoint = if self.is_nonblocking() { block_on(self.async_accept(), Some(Duration::ZERO.into())).map_err(|x| { if x == io::Error::ETIME { io::Error::EAGAIN } else { x } - }) + })? } else { - block_on(self.async_accept(), None) - } + block_on(self.async_accept(), None)? + }; + + Ok(Endpoint::Ip(endpoint)) } - fn getpeername(&self) -> Option { + fn getpeername(&self) -> Option { self.with(|socket| socket.remote_endpoint()) + .map(Endpoint::Ip) } - fn getsockname(&self) -> Option { + fn getsockname(&self) -> Option { self.with(|socket| socket.local_endpoint()) + .map(Endpoint::Ip) } fn is_nonblocking(&self) -> bool { diff --git a/src/fd/socket/udp.rs b/src/fd/socket/udp.rs index 59e62c04a6..8004dd68a2 100644 --- a/src/fd/socket/udp.rs +++ b/src/fd/socket/udp.rs @@ -9,11 +9,11 @@ use crossbeam_utils::atomic::AtomicCell; use smoltcp::socket::udp; use smoltcp::socket::udp::UdpMetadata; use smoltcp::time::Duration; -use smoltcp::wire::{IpEndpoint, IpListenEndpoint}; +use smoltcp::wire::{IpEndpoint, IpVersion}; use crate::executor::network::{now, Handle, NetworkState, NIC}; use crate::executor::{block_on, poll_on}; -use crate::fd::{IoCtl, ObjectInterface, PollEvent}; +use crate::fd::{Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent}; use crate::io; #[derive(Debug)] @@ -51,7 +51,7 @@ impl Socket { .await } - async fn async_recvfrom(&self, buffer: &mut [u8]) -> io::Result<(usize, IpEndpoint)> { + async fn async_recvfrom(&self, buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> { future::poll_fn(|cx| { self.with(|socket| { if socket.is_open() { @@ -81,6 +81,7 @@ impl Socket { }) }) .await + .map(|(len, endpoint)| (len, Endpoint::Ip(endpoint))) } async fn async_write_with_meta(&self, buffer: &[u8], meta: &UdpMetadata) -> io::Result { @@ -154,29 +155,44 @@ impl ObjectInterface for Socket { .await } - fn bind(&self, endpoint: IpListenEndpoint) -> io::Result<()> { - self.with(|socket| socket.bind(endpoint).map_err(|_| io::Error::EADDRINUSE)) + fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { + #[allow(irrefutable_let_patterns)] + if let ListenEndpoint::Ip(endpoint) = endpoint { + self.with(|socket| socket.bind(endpoint).map_err(|_| io::Error::EADDRINUSE)) + } else { + Err(io::Error::EIO) + } } - fn connect(&self, endpoint: IpEndpoint) -> io::Result<()> { - self.endpoint.store(Some(endpoint)); - Ok(()) + fn connect(&self, endpoint: Endpoint) -> io::Result<()> { + #[allow(irrefutable_let_patterns)] + if let Endpoint::Ip(endpoint) = endpoint { + self.endpoint.store(Some(endpoint)); + Ok(()) + } else { + Err(io::Error::EIO) + } } - fn sendto(&self, buf: &[u8], endpoint: IpEndpoint) -> io::Result { - let meta = UdpMetadata::from(endpoint); + fn sendto(&self, buf: &[u8], endpoint: Endpoint) -> io::Result { + #[allow(irrefutable_let_patterns)] + if let Endpoint::Ip(endpoint) = endpoint { + let meta = UdpMetadata::from(endpoint); - if self.nonblocking.load(Ordering::Acquire) { - poll_on( - self.async_write_with_meta(buf, &meta), - Some(Duration::ZERO.into()), - ) + if self.nonblocking.load(Ordering::Acquire) { + poll_on( + self.async_write_with_meta(buf, &meta), + Some(Duration::ZERO.into()), + ) + } else { + poll_on(self.async_write_with_meta(buf, &meta), None) + } } else { - poll_on(self.async_write_with_meta(buf, &meta), None) + Err(io::Error::EIO) } } - fn recvfrom(&self, buf: &mut [u8]) -> io::Result<(usize, IpEndpoint)> { + fn recvfrom(&self, buf: &mut [u8]) -> io::Result<(usize, Endpoint)> { if self.nonblocking.load(Ordering::Acquire) { poll_on(self.async_recvfrom(buf), Some(Duration::ZERO.into())).map_err(|x| { if x == io::Error::ETIME { diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs new file mode 100644 index 0000000000..402c57d130 --- /dev/null +++ b/src/fd/socket/vsock.rs @@ -0,0 +1,423 @@ +use alloc::boxed::Box; +use alloc::vec::Vec; +use core::future; +use core::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use core::task::Poll; +use core::time::Duration; + +use async_trait::async_trait; +use virtio::vsock::{Hdr, Op, Type}; +use virtio::{le16, le32, le64}; + +#[cfg(not(feature = "pci"))] +use crate::arch::kernel::mmio as hardware; +#[cfg(feature = "pci")] +use crate::drivers::pci as hardware; +use crate::executor::vsock::{VsockState, VSOCK_MAP}; +use crate::fd::{block_on, poll_on, Endpoint, IoCtl, ListenEndpoint, ObjectInterface, PollEvent}; +use crate::io::{self, Error}; + +#[derive(Debug)] +pub(crate) struct VsockListenEndpoint { + pub port: u32, + pub cid: Option, +} + +impl VsockListenEndpoint { + pub const fn new(port: u32, cid: Option) -> Self { + Self { port, cid } + } +} + +#[derive(Debug)] +pub(crate) struct VsockEndpoint { + pub port: u32, + pub cid: u32, +} + +impl VsockEndpoint { + pub const fn new(port: u32, cid: u32) -> Self { + Self { port, cid } + } +} + +#[derive(Debug)] +pub struct Socket { + port: AtomicU32, + cid: AtomicU32, + nonblocking: AtomicBool, +} + +impl Socket { + pub fn new() -> Self { + Self { + port: AtomicU32::new(0), + cid: AtomicU32::new(u32::MAX), + nonblocking: AtomicBool::new(false), + } + } +} + +#[async_trait] +impl ObjectInterface for Socket { + async fn poll(&self, event: PollEvent) -> io::Result { + let port = self.port.load(Ordering::Acquire); + + future::poll_fn(|cx| { + let mut guard = VSOCK_MAP.lock(); + let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?; + + match raw.state { + VsockState::Shutdown | VsockState::ReceiveRequest => { + let available = PollEvent::POLLOUT + | PollEvent::POLLWRNORM + | PollEvent::POLLWRBAND + | PollEvent::POLLIN + | PollEvent::POLLRDNORM + | PollEvent::POLLRDBAND; + + let ret = event & available; + + if ret.is_empty() { + Poll::Ready(Ok(PollEvent::POLLHUP)) + } else { + Poll::Ready(Ok(ret)) + } + } + VsockState::Listen | VsockState::Connecting => { + raw.rx_waker.register(cx.waker()); + raw.tx_waker.register(cx.waker()); + Poll::Pending + } + VsockState::Connected => { + let mut available = PollEvent::empty(); + + if !raw.buffer.is_empty() { + // In case, we just establish a fresh connection in non-blocking mode, we try to read data. + available.insert( + PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND, + ); + } + + let diff = raw.tx_cnt.abs_diff(raw.peer_fwd_cnt); + if diff < raw.peer_buf_alloc { + available.insert( + PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND, + ); + } + + let ret = event & available; + + if ret.is_empty() { + if event.intersects( + PollEvent::POLLIN | PollEvent::POLLRDNORM | PollEvent::POLLRDBAND, + ) { + raw.rx_waker.register(cx.waker()); + } + + if event.intersects( + PollEvent::POLLOUT | PollEvent::POLLWRNORM | PollEvent::POLLWRBAND, + ) { + raw.tx_waker.register(cx.waker()); + } + + Poll::Pending + } else { + Poll::Ready(Ok(ret)) + } + } + } + }) + .await + } + + fn bind(&self, endpoint: ListenEndpoint) -> io::Result<()> { + match endpoint { + ListenEndpoint::Vsock(ep) => { + self.port.store(ep.port, Ordering::Release); + if let Some(cid) = ep.cid { + self.cid.store(cid, Ordering::Release); + } else { + self.cid.store(u32::MAX, Ordering::Release); + } + VSOCK_MAP.lock().bind(ep.port) + } + #[cfg(any(feature = "tcp", feature = "udp"))] + _ => Err(io::Error::EINVAL), + } + } + + fn connect(&self, endpoint: Endpoint) -> io::Result<()> { + match endpoint { + Endpoint::Vsock(ep) => { + const HEADER_SIZE: usize = core::mem::size_of::(); + let port = VSOCK_MAP.lock().connect(ep.port, ep.cid)?; + self.port.store(port, Ordering::Release); + self.port.store(ep.cid, Ordering::Release); + + let mut driver_guard = hardware::get_vsock_driver().unwrap().lock(); + let local_cid = driver_guard.get_cid(); + + driver_guard.send_packet(HEADER_SIZE, |buffer| { + let response = unsafe { &mut *(buffer.as_mut_ptr() as *mut Hdr) }; + + response.src_cid = le64::from_ne(local_cid); + response.dst_cid = le64::from_ne(ep.cid as u64); + response.src_port = le32::from_ne(port); + response.dst_port = le32::from_ne(ep.port); + response.len = le32::from_ne(0); + response.type_ = le16::from_ne(Type::Stream.into()); + response.op = le16::from_ne(Op::Request.into()); + response.flags = le32::from_ne(0); + response.buf_alloc = + le32::from_ne(crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32); + response.fwd_cnt = le32::from_ne(0); + }); + + drop(driver_guard); + + poll_on( + async { + future::poll_fn(|cx| { + let mut guard = VSOCK_MAP.lock(); + let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?; + + match raw.state { + VsockState::Connected => Poll::Ready(Ok(())), + VsockState::Connecting => { + raw.rx_waker.register(cx.waker()); + Poll::Pending + } + _ => Poll::Ready(Err(io::Error::EBADF)), + } + }) + .await + }, + Some(Duration::from_millis(1000)), + ) + } + #[cfg(any(feature = "tcp", feature = "udp"))] + _ => Err(io::Error::EINVAL), + } + } + + fn getpeername(&self) -> Option { + let port = self.port.load(Ordering::Acquire); + let guard = VSOCK_MAP.lock(); + let raw = guard.get_socket(port)?; + + Some(Endpoint::Vsock(VsockEndpoint::new( + raw.remote_port, + raw.remote_cid, + ))) + } + + fn getsockname(&self) -> Option { + let local_cid = hardware::get_vsock_driver().unwrap().lock().get_cid(); + + Some(Endpoint::Vsock(VsockEndpoint::new( + self.port.load(Ordering::Acquire), + local_cid.try_into().unwrap(), + ))) + } + + fn is_nonblocking(&self) -> bool { + self.nonblocking.load(Ordering::Acquire) + } + + fn listen(&self, _backlog: i32) -> io::Result<()> { + Ok(()) + } + + fn accept(&self) -> io::Result { + let port = self.port.load(Ordering::Acquire); + let cid = self.cid.load(Ordering::Acquire); + + let endpoint = block_on( + async { + future::poll_fn(|cx| { + let mut guard = VSOCK_MAP.lock(); + let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?; + + match raw.state { + VsockState::Listen => { + raw.rx_waker.register(cx.waker()); + Poll::Pending + } + VsockState::ReceiveRequest => { + let result = { + const HEADER_SIZE: usize = core::mem::size_of::(); + let mut driver_guard = hardware::get_vsock_driver().unwrap().lock(); + let local_cid = driver_guard.get_cid(); + + driver_guard.send_packet(HEADER_SIZE, |buffer| { + let response = + unsafe { &mut *(buffer.as_mut_ptr() as *mut Hdr) }; + + response.src_cid = le64::from_ne(local_cid); + response.dst_cid = le64::from_ne(raw.remote_cid as u64); + response.src_port = le32::from_ne(port); + response.dst_port = le32::from_ne(raw.remote_port); + response.len = le32::from_ne(0); + response.type_ = le16::from_ne(Type::Stream.into()); + if local_cid != cid.into() && cid != u32::MAX { + response.op = le16::from_ne(Op::Rst.into()) + } else { + response.op = le16::from_ne(Op::Response.into()); + } + response.flags = le32::from_ne(0); + response.buf_alloc = le32::from_ne( + crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32, + ); + response.fwd_cnt = le32::from_ne(raw.fwd_cnt); + }); + + raw.state = VsockState::Connected; + + Ok(VsockEndpoint::new(raw.remote_port, raw.remote_cid)) + }; + + Poll::Ready(result) + } + _ => Poll::Ready(Err(Error::EBADF)), + } + }) + .await + }, + None, + )?; + + Ok(Endpoint::Vsock(endpoint)) + } + + fn shutdown(&self, _how: i32) -> io::Result<()> { + Ok(()) + } + + fn ioctl(&self, cmd: IoCtl, value: bool) -> io::Result<()> { + if cmd == IoCtl::NonBlocking { + if value { + trace!("set vsock device to nonblocking mode"); + self.nonblocking.store(true, Ordering::Release); + } else { + trace!("set vsock device to blocking mode"); + self.nonblocking.store(false, Ordering::Release); + } + + Ok(()) + } else { + Err(io::Error::EINVAL) + } + } + + // TODO: Remove allow once fixed: + // https://github.com/rust-lang/rust-clippy/issues/11380 + #[allow(clippy::needless_pass_by_ref_mut)] + async fn async_read(&self, buffer: &mut [u8]) -> io::Result { + let port = self.port.load(Ordering::Acquire); + future::poll_fn(|cx| { + let mut guard = VSOCK_MAP.lock(); + let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?; + + match raw.state { + VsockState::Connected => { + let len = core::cmp::min(buffer.len(), raw.buffer.len()); + + if len == 0 { + raw.rx_waker.register(cx.waker()); + Poll::Pending + } else { + let tmp: Vec<_> = raw.buffer.drain(..len).collect(); + buffer[..len].copy_from_slice(tmp.as_slice()); + + Poll::Ready(Ok(len)) + } + } + VsockState::Shutdown => { + let len = core::cmp::min(buffer.len(), raw.buffer.len()); + + if len == 0 { + Poll::Ready(Ok(0)) + } else { + let tmp: Vec<_> = raw.buffer.drain(..len).collect(); + buffer[..len].copy_from_slice(tmp.as_slice()); + + Poll::Ready(Ok(len)) + } + } + _ => Poll::Ready(Err(Error::EIO)), + } + }) + .await + } + + async fn async_write(&self, buffer: &[u8]) -> io::Result { + let port = self.port.load(Ordering::Acquire); + future::poll_fn(|cx| { + let mut guard = VSOCK_MAP.lock(); + let raw = guard.get_mut_socket(port).ok_or(Error::EINVAL)?; + let diff = raw.tx_cnt.abs_diff(raw.peer_fwd_cnt); + + match raw.state { + VsockState::Connected => { + if diff >= raw.peer_buf_alloc { + raw.tx_waker.register(cx.waker()); + Poll::Pending + } else { + const HEADER_SIZE: usize = core::mem::size_of::(); + let mut driver_guard = hardware::get_vsock_driver().unwrap().lock(); + let local_cid = driver_guard.get_cid(); + let len = core::cmp::min( + buffer.len(), + usize::try_from(raw.peer_buf_alloc - diff).unwrap(), + ); + + driver_guard.send_packet(HEADER_SIZE + len, |virtio_buffer| { + let response = + unsafe { &mut *(virtio_buffer.as_mut_ptr() as *mut Hdr) }; + + raw.tx_cnt = raw.tx_cnt.wrapping_add(len.try_into().unwrap()); + response.src_cid = le64::from_ne(local_cid); + response.dst_cid = le64::from_ne(raw.remote_cid as u64); + response.src_port = le32::from_ne(port); + response.dst_port = le32::from_ne(raw.remote_port); + response.len = le32::from_ne(len.try_into().unwrap()); + response.type_ = le16::from_ne(Type::Stream.into()); + response.op = le16::from_ne(Op::Rw.into()); + response.flags = le32::from_ne(0); + response.buf_alloc = le32::from_ne( + crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32, + ); + response.fwd_cnt = le32::from_ne(raw.fwd_cnt); + + virtio_buffer[HEADER_SIZE..HEADER_SIZE + len] + .copy_from_slice(&buffer[..len]); + }); + + Poll::Ready(Ok(len)) + } + } + _ => Poll::Ready(Err(Error::EIO)), + } + }) + .await + } +} + +impl Clone for Socket { + fn clone(&self) -> Self { + Self { + port: AtomicU32::new(self.port.load(Ordering::Acquire)), + cid: AtomicU32::new(self.cid.load(Ordering::Acquire)), + nonblocking: AtomicBool::new(self.nonblocking.load(Ordering::Acquire)), + } + } +} + +impl Drop for Socket { + fn drop(&mut self) { + let port = self.port.load(Ordering::Acquire); + let mut guard = VSOCK_MAP.lock(); + guard.remove_socket(port); + } +} diff --git a/src/lib.rs b/src/lib.rs index 68d84da879..89eb303463 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ #![feature(asm_const)] #![feature(exposed_provenance)] #![feature(linked_list_cursors)] +#![feature(map_try_insert)] #![feature(maybe_uninit_as_bytes)] #![feature(maybe_uninit_slice)] #![feature(naked_functions)] diff --git a/src/syscalls/mod.rs b/src/syscalls/mod.rs index 7c8b662eb0..ed50da6025 100644 --- a/src/syscalls/mod.rs +++ b/src/syscalls/mod.rs @@ -40,7 +40,7 @@ mod processor; #[cfg(feature = "newlib")] mod recmutex; mod semaphore; -#[cfg(any(feature = "tcp", feature = "udp"))] +#[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] pub mod socket; mod spinlock; mod system; diff --git a/src/syscalls/socket.rs b/src/syscalls/socket.rs index 3a2a118724..5137e56590 100644 --- a/src/syscalls/socket.rs +++ b/src/syscalls/socket.rs @@ -3,22 +3,31 @@ use alloc::sync::Arc; use core::ffi::{c_char, c_void}; use core::mem::size_of; +#[allow(unused_imports)] use core::ops::DerefMut; +use cfg_if::cfg_if; #[cfg(any(feature = "tcp", feature = "udp"))] use smoltcp::wire::{IpAddress, IpEndpoint, IpListenEndpoint}; use crate::errno::*; +#[cfg(any(feature = "tcp", feature = "udp"))] use crate::executor::network::{NetworkState, NIC}; #[cfg(feature = "tcp")] use crate::fd::socket::tcp; #[cfg(feature = "udp")] use crate::fd::socket::udp; -use crate::fd::{get_object, insert_object, replace_object, ObjectInterface, SocketOption}; +#[cfg(feature = "vsock")] +use crate::fd::socket::vsock::{self, VsockEndpoint, VsockListenEndpoint}; +use crate::fd::{ + get_object, insert_object, replace_object, Endpoint, ListenEndpoint, ObjectInterface, + SocketOption, +}; use crate::syscalls::IoCtl; pub const AF_INET: i32 = 0; pub const AF_INET6: i32 = 1; +pub const AF_VSOCK: i32 = 2; pub const IPPROTO_IP: i32 = 0; pub const IPPROTO_IPV6: i32 = 41; pub const IPPROTO_TCP: i32 = 6; @@ -92,6 +101,55 @@ pub struct sockaddr { pub sa_data: [c_char; 14], } +#[cfg(feature = "vsock")] +#[repr(C)] +#[derive(Debug, Copy, Clone, Default)] +pub struct sockaddr_vm { + pub svm_len: u8, + pub svm_family: sa_family_t, + pub svm_reserved1: u16, + pub svm_port: u32, + pub svm_cid: u32, + pub svm_zero: [u8; 4], +} + +#[cfg(feature = "vsock")] +impl From for VsockListenEndpoint { + fn from(addr: sockaddr_vm) -> VsockListenEndpoint { + let port = addr.svm_port; + let cid = if addr.svm_cid < u32::MAX { + Some(addr.svm_cid) + } else { + None + }; + + VsockListenEndpoint::new(port, cid) + } +} + +#[cfg(feature = "vsock")] +impl From for VsockEndpoint { + fn from(addr: sockaddr_vm) -> VsockEndpoint { + let port = addr.svm_port; + let cid = addr.svm_cid; + + VsockEndpoint::new(port, cid) + } +} + +#[cfg(feature = "vsock")] +impl From for sockaddr_vm { + fn from(endpoint: VsockEndpoint) -> Self { + Self { + svm_len: core::mem::size_of::().try_into().unwrap(), + svm_family: AF_VSOCK.try_into().unwrap(), + svm_port: endpoint.port, + svm_cid: endpoint.cid, + ..Default::default() + } + } +} + #[repr(C)] #[derive(Debug, Default, Copy, Clone)] pub struct sockaddr_in { @@ -357,20 +415,16 @@ pub extern "C" fn sys_socket(domain: i32, type_: SockType, protocol: i32) -> i32 domain, type_, protocol ); - if (domain != AF_INET && domain != AF_INET6) + if (domain != AF_INET && domain != AF_INET6 && domain != AF_VSOCK) || !type_.intersects(SockType::SOCK_STREAM | SockType::SOCK_DGRAM) || protocol != 0 { -EINVAL } else { - let mut guard = NIC.lock(); - - if let NetworkState::Initialized(nic) = guard.deref_mut() { - #[cfg(feature = "udp")] - if type_.contains(SockType::SOCK_DGRAM) { - let handle = nic.create_udp_handle().unwrap(); - drop(guard); - let socket = udp::Socket::new(handle); + #[cfg(feature = "vsock")] + { + if type_.contains(SockType::SOCK_STREAM) { + let socket = vsock::Socket::new(); if type_.contains(SockType::SOCK_NONBLOCK) { socket.ioctl(IoCtl::NonBlocking, true).unwrap(); @@ -380,26 +434,45 @@ pub extern "C" fn sys_socket(domain: i32, type_: SockType, protocol: i32) -> i32 return fd; } + } + #[cfg(any(feature = "tcp", feature = "udp"))] + { + let mut guard = NIC.lock(); + + if let NetworkState::Initialized(nic) = guard.deref_mut() { + #[cfg(feature = "udp")] + if type_.contains(SockType::SOCK_DGRAM) { + let handle = nic.create_udp_handle().unwrap(); + drop(guard); + let socket = udp::Socket::new(handle); + + if type_.contains(SockType::SOCK_NONBLOCK) { + socket.ioctl(IoCtl::NonBlocking, true).unwrap(); + } - #[cfg(feature = "tcp")] - if type_.contains(SockType::SOCK_STREAM) { - let handle = nic.create_tcp_handle().unwrap(); - drop(guard); - let socket = tcp::Socket::new(handle); + let fd = insert_object(Arc::new(socket)).expect("FD is already used"); - if type_.contains(SockType::SOCK_NONBLOCK) { - socket.ioctl(IoCtl::NonBlocking, true).unwrap(); + return fd; } - let fd = insert_object(Arc::new(socket)).expect("FD is already used"); + #[cfg(feature = "tcp")] + if type_.contains(SockType::SOCK_STREAM) { + let handle = nic.create_tcp_handle().unwrap(); + drop(guard); + let socket = tcp::Socket::new(handle); - return fd; - } + if type_.contains(SockType::SOCK_NONBLOCK) { + socket.ioctl(IoCtl::NonBlocking, true).unwrap(); + } - -EINVAL - } else { - -EINVAL + let fd = insert_object(Arc::new(socket)).expect("FD is already used"); + + return fd; + } + } } + + -EINVAL } } @@ -412,33 +485,54 @@ pub unsafe extern "C" fn sys_accept(fd: i32, addr: *mut sockaddr, addrlen: *mut |v| { (*v).accept().map_or_else( |e| -num::ToPrimitive::to_i32(&e).unwrap(), - |endpoint| { - let new_obj = dyn_clone::clone_box(&*v); - replace_object(fd, Arc::from(new_obj)).unwrap(); - let new_fd = insert_object(v).unwrap(); - - if !addr.is_null() && !addrlen.is_null() { - let addrlen = unsafe { &mut *addrlen }; - - match endpoint.addr { - IpAddress::Ipv4(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; - *addr = sockaddr_in::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); + |endpoint| match endpoint { + #[cfg(any(feature = "tcp", feature = "udp"))] + Endpoint::Ip(endpoint) => { + let new_obj = dyn_clone::clone_box(&*v); + replace_object(fd, Arc::from(new_obj)).unwrap(); + let new_fd = insert_object(v).unwrap(); + + if !addr.is_null() && !addrlen.is_null() { + let addrlen = unsafe { &mut *addrlen }; + + match endpoint.addr { + IpAddress::Ipv4(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; + *addr = sockaddr_in::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } } - } - IpAddress::Ipv6(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; - *addr = sockaddr_in6::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); + IpAddress::Ipv6(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; + *addr = sockaddr_in6::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } } } } + + new_fd } + #[cfg(feature = "vsock")] + Endpoint::Vsock(endpoint) => { + //let new_obj = dyn_clone::clone_box(&*v); + //replace_object(fd, Arc::from(new_obj)).unwrap(); + let new_fd = insert_object(v.clone()).unwrap(); + + if !addr.is_null() && !addrlen.is_null() { + let addrlen = unsafe { &mut *addrlen }; + + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_vm) }; + *addr = sockaddr_vm::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } + } - new_fd + new_fd + } }, ) }, @@ -461,20 +555,44 @@ pub extern "C" fn sys_listen(fd: i32, backlog: i32) -> i32 { #[hermit_macro::system] #[no_mangle] pub unsafe extern "C" fn sys_bind(fd: i32, name: *const sockaddr, namelen: socklen_t) -> i32 { - let endpoint = if namelen == size_of::().try_into().unwrap() { - IpListenEndpoint::from(unsafe { *(name as *const sockaddr_in) }) - } else if namelen == size_of::().try_into().unwrap() { - IpListenEndpoint::from(unsafe { *(name as *const sockaddr_in6) }) - } else { + if name.is_null() { return -crate::errno::EINVAL; - }; + } + + let family: i32 = unsafe { (*name).sa_family.into() }; let obj = get_object(fd); obj.map_or_else( |e| -num::ToPrimitive::to_i32(&e).unwrap(), - |v| { - (*v).bind(endpoint) - .map_or_else(|e| -num::ToPrimitive::to_i32(&e).unwrap(), |_| 0) + |v| match family { + #[cfg(any(feature = "tcp", feature = "udp"))] + AF_INET => { + if namelen < size_of::().try_into().unwrap() { + return -crate::errno::EINVAL; + } + let endpoint = IpListenEndpoint::from(unsafe { *(name as *const sockaddr_in) }); + (*v).bind(ListenEndpoint::Ip(endpoint)) + .map_or_else(|e| -num::ToPrimitive::to_i32(&e).unwrap(), |_| 0) + } + #[cfg(any(feature = "tcp", feature = "udp"))] + AF_INET6 => { + if namelen < size_of::().try_into().unwrap() { + return -crate::errno::EINVAL; + } + let endpoint = IpListenEndpoint::from(unsafe { *(name as *const sockaddr_in6) }); + (*v).bind(ListenEndpoint::Ip(endpoint)) + .map_or_else(|e| -num::ToPrimitive::to_i32(&e).unwrap(), |_| 0) + } + #[cfg(feature = "vsock")] + AF_VSOCK => { + if namelen < size_of::().try_into().unwrap() { + return -crate::errno::EINVAL; + } + let endpoint = VsockListenEndpoint::from(unsafe { *(name as *const sockaddr_vm) }); + (*v).bind(ListenEndpoint::Vsock(endpoint)) + .map_or_else(|e| -num::ToPrimitive::to_i32(&e).unwrap(), |_| 0) + } + _ => -crate::errno::EINVAL, }, ) } @@ -482,12 +600,39 @@ pub unsafe extern "C" fn sys_bind(fd: i32, name: *const sockaddr, namelen: sockl #[hermit_macro::system] #[no_mangle] pub unsafe extern "C" fn sys_connect(fd: i32, name: *const sockaddr, namelen: socklen_t) -> i32 { - let endpoint = if namelen == size_of::().try_into().unwrap() { - IpEndpoint::from(unsafe { *(name as *const sockaddr_in) }) - } else if namelen == size_of::().try_into().unwrap() { - IpEndpoint::from(unsafe { *(name as *const sockaddr_in6) }) - } else { + if name.is_null() { return -crate::errno::EINVAL; + } + + let sa_family = unsafe { (*name).sa_family as i32 }; + + let endpoint = match sa_family { + #[cfg(any(feature = "tcp", feature = "udp"))] + AF_INET => { + if namelen < size_of::().try_into().unwrap() { + return -crate::errno::EINVAL; + } + Endpoint::Ip(IpEndpoint::from(unsafe { *(name as *const sockaddr_in) })) + } + #[cfg(any(feature = "tcp", feature = "udp"))] + AF_INET6 => { + if namelen < size_of::().try_into().unwrap() { + return -crate::errno::EINVAL; + } + Endpoint::Ip(IpEndpoint::from(unsafe { *(name as *const sockaddr_in6) })) + } + #[cfg(feature = "vsock")] + AF_VSOCK => { + if namelen < size_of::().try_into().unwrap() { + return -crate::errno::EINVAL; + } + Endpoint::Vsock(VsockEndpoint::from(unsafe { + *(name as *const sockaddr_vm) + })) + } + _ => { + return -crate::errno::EINVAL; + } }; let obj = get_object(fd); @@ -515,21 +660,33 @@ pub unsafe extern "C" fn sys_getsockname( if !addr.is_null() && !addrlen.is_null() { let addrlen = unsafe { &mut *addrlen }; - match endpoint.addr { - IpAddress::Ipv4(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; - *addr = sockaddr_in::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); - } else { - return -crate::errno::EINVAL; + match endpoint { + #[cfg(any(feature = "tcp", feature = "udp"))] + Endpoint::Ip(endpoint) => match endpoint.addr { + IpAddress::Ipv4(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; + *addr = sockaddr_in::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return -crate::errno::EINVAL; + } } - } - IpAddress::Ipv6(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; - *addr = sockaddr_in6::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); + #[cfg(any(feature = "tcp", feature = "udp"))] + IpAddress::Ipv6(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; + *addr = sockaddr_in6::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return -crate::errno::EINVAL; + } + } + }, + #[cfg(feature = "vsock")] + Endpoint::Vsock(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + warn!("unsupported device"); } else { return -crate::errno::EINVAL; } @@ -643,21 +800,32 @@ pub unsafe extern "C" fn sys_getpeername( if !addr.is_null() && !addrlen.is_null() { let addrlen = unsafe { &mut *addrlen }; - match endpoint.addr { - IpAddress::Ipv4(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; - *addr = sockaddr_in::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); - } else { - return -crate::errno::EINVAL; + match endpoint { + #[cfg(any(feature = "tcp", feature = "udp"))] + Endpoint::Ip(endpoint) => match endpoint.addr { + IpAddress::Ipv4(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; + *addr = sockaddr_in::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return -crate::errno::EINVAL; + } } - } - IpAddress::Ipv6(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; - *addr = sockaddr_in6::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); + IpAddress::Ipv6(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; + *addr = sockaddr_in6::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return -crate::errno::EINVAL; + } + } + }, + #[cfg(feature = "vsock")] + Endpoint::Vsock(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + warn!("unsupported device"); } else { return -crate::errno::EINVAL; } @@ -741,25 +909,52 @@ pub unsafe extern "C" fn sys_sendto( addr: *const sockaddr, addr_len: socklen_t, ) -> isize { - let endpoint = if addr_len == size_of::().try_into().unwrap() { - IpEndpoint::from(unsafe { *(addr as *const sockaddr_in) }) - } else if addr_len == size_of::().try_into().unwrap() { - IpEndpoint::from(unsafe { *(addr as *const sockaddr_in6) }) - } else { + let endpoint; + + if addr.is_null() || addr_len == 0 { return (-crate::errno::EINVAL).try_into().unwrap(); - }; - let slice = unsafe { core::slice::from_raw_parts(buf, len) }; - let obj = get_object(fd); + } - obj.map_or_else( - |e| -num::ToPrimitive::to_isize(&e).unwrap(), - |v| { - (*v).sendto(slice, endpoint).map_or_else( - |e| -num::ToPrimitive::to_isize(&e).unwrap(), - |v| v.try_into().unwrap(), - ) - }, - ) + cfg_if! { + if #[cfg(any(feature = "tcp", feature = "udp"))] { + let sa_family = unsafe { (*addr).sa_family as i32 }; + + if sa_family == AF_INET { + if addr_len < size_of::().try_into().unwrap() { + return (-crate::errno::EINVAL).try_into().unwrap(); + } + + endpoint = Some(Endpoint::Ip(IpEndpoint::from(unsafe {*(addr as *const sockaddr_in)}))); + } else if sa_family == AF_INET6 { + if addr_len < size_of::().try_into().unwrap() { + return (-crate::errno::EINVAL).try_into().unwrap(); + } + + endpoint = Some(Endpoint::Ip(IpEndpoint::from(unsafe { *(addr as *const sockaddr_in6) }))); + } else { + endpoint = None; + } + } else { + endpoint = None; + } + } + + if let Some(endpoint) = endpoint { + let slice = unsafe { core::slice::from_raw_parts(buf, len) }; + let obj = get_object(fd); + + obj.map_or_else( + |e| -num::ToPrimitive::to_isize(&e).unwrap(), + |v| { + (*v).sendto(slice, endpoint).map_or_else( + |e| -num::ToPrimitive::to_isize(&e).unwrap(), + |v| v.try_into().unwrap(), + ) + }, + ) + } else { + (-crate::errno::EINVAL).try_into().unwrap() + } } #[hermit_macro::system] @@ -781,26 +976,34 @@ pub unsafe extern "C" fn sys_recvfrom( |e| -num::ToPrimitive::to_isize(&e).unwrap(), |(len, endpoint)| { if !addr.is_null() && !addrlen.is_null() { + #[allow(unused_variables)] let addrlen = unsafe { &mut *addrlen }; - match endpoint.addr { - IpAddress::Ipv4(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; - *addr = sockaddr_in::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); - } else { - return (-crate::errno::EINVAL).try_into().unwrap(); + match endpoint { + #[cfg(any(feature = "tcp", feature = "udp"))] + Endpoint::Ip(endpoint) => match endpoint.addr { + IpAddress::Ipv4(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in) }; + *addr = sockaddr_in::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return (-crate::errno::EINVAL).try_into().unwrap(); + } } - } - IpAddress::Ipv6(_) => { - if *addrlen >= size_of::().try_into().unwrap() { - let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; - *addr = sockaddr_in6::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); - } else { - return (-crate::errno::EINVAL).try_into().unwrap(); + IpAddress::Ipv6(_) => { + if *addrlen >= size_of::().try_into().unwrap() { + let addr = unsafe { &mut *(addr as *mut sockaddr_in6) }; + *addr = sockaddr_in6::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return (-crate::errno::EINVAL).try_into().unwrap(); + } } + }, + #[cfg(feature = "vsock")] + _ => { + return (-crate::errno::EINVAL).try_into().unwrap(); } } }