diff --git a/src/migtd/src/migration/mod.rs b/src/migtd/src/migration/mod.rs index 3987fa99..59c5670f 100644 --- a/src/migtd/src/migration/mod.rs +++ b/src/migtd/src/migration/mod.rs @@ -5,8 +5,12 @@ pub mod data; pub mod event; pub mod logging; +#[cfg(feature = "policy_v2")] +pub mod pre_session_data; #[cfg(feature = "main")] pub mod session; +#[cfg(feature = "main")] +pub mod transport; use crate::driver::ticks::TimeoutError; use crate::ratls::RatlsError; diff --git a/src/migtd/src/migration/pre_session_data.rs b/src/migtd/src/migration/pre_session_data.rs new file mode 100644 index 00000000..94748abc --- /dev/null +++ b/src/migtd/src/migration/pre_session_data.rs @@ -0,0 +1,370 @@ +// Copyright (c) 2025 Intel Corporation +// +// SPDX-License-Identifier: BSD-2-Clause-Patent + +use super::MigrationResult; +use alloc::{vec, vec::Vec}; +use async_io::{AsyncRead, AsyncWrite}; + +type Result = core::result::Result; + +#[repr(C)] +pub(super) struct PreSessionMessage { + pub r#type: u8, + pub reserved: [u8; 3], + pub length: u32, // Length in bytes of the message payload +} + +impl PreSessionMessage { + const PRE_SESSION_DATA_TYPE: u8 = 1; + const START_SESSION_TYPE: u8 = 2; + const HELLO_PACKET_TYPE: u8 = 0xff; + + pub fn as_bytes(&self) -> &[u8] { + unsafe { core::slice::from_raw_parts(self as *const Self as *const u8, size_of::()) } + } + + pub fn read_from_bytes(bytes: &[u8]) -> Option { + if bytes.len() < size_of::() { + log::error!( + "PreSessionMessage: Insufficient bytes to read header bytes.len() = {}\n", + bytes.len() + ); + return None; + } + let header = PreSessionMessage { + r#type: bytes[0], + reserved: bytes[1..4].try_into().unwrap(), + length: u32::from_le_bytes(bytes[4..8].try_into().unwrap()), + }; + Some(header) + } +} + +pub(super) struct HelloPacketPayload { + magic_word: [u8; 4], + lowest_supported_version: u16, + highest_supported_version: u16, +} + +impl HelloPacketPayload { + const HELLO_PACKET_PAYLOAD_SIZE: usize = 8; + const HELLO_PACKET_MAGIC_WORD: [u8; 4] = [b'M', b'G', b'T', b'D']; + const LOWEST_VERSION: u16 = 0x0100; + const HIGHEST_VERSION: u16 = 0x0100; + + pub const fn new() -> Self { + Self { + magic_word: Self::HELLO_PACKET_MAGIC_WORD, + lowest_supported_version: Self::LOWEST_VERSION, + highest_supported_version: Self::HIGHEST_VERSION, + } + } + + pub fn as_bytes(&self) -> &[u8] { + unsafe { core::slice::from_raw_parts(self as *const Self as *const u8, size_of::()) } + } + + pub fn read_from_bytes(bytes: &[u8]) -> Option { + if bytes.len() < size_of::() { + log::error!( + "HelloPacketPayload: Insufficient bytes to read header bytes.len() = {}\n", + bytes.len() + ); + return None; + } + let payload = HelloPacketPayload { + magic_word: bytes[..4].try_into().unwrap(), + lowest_supported_version: u16::from_le_bytes(bytes[4..6].try_into().unwrap()), + highest_supported_version: u16::from_le_bytes(bytes[6..8].try_into().unwrap()), + }; + + if payload.magic_word != HelloPacketPayload::HELLO_PACKET_MAGIC_WORD { + log::error!("HelloPacketPayload: Invalid magic word in hello packet\n"); + return None; + } + Some(payload) + } + + fn negotiate_supported_version(&self) -> Option { + let low = core::cmp::max(Self::LOWEST_VERSION, self.lowest_supported_version); + let high = core::cmp::min(Self::HIGHEST_VERSION, self.highest_supported_version); + if low > high { + None + } else { + Some(high) + } + } +} + +pub(super) async fn send_pre_session_data( + transport: &mut T, + data: &[u8], +) -> Result<()> { + let mut sent = 0; + while sent < data.len() { + let n = transport.write(&data[sent..]).await.map_err(|e| { + log::error!("send_pre_session_data: Network error: {:?}\n", e); + MigrationResult::NetworkError + })?; + sent += n; + } + Ok(()) +} + +pub(super) async fn receive_pre_session_data( + transport: &mut T, + data: &mut [u8], +) -> Result<()> { + let mut recvd = 0; + while recvd < data.len() { + let n = transport.read(&mut data[recvd..]).await.map_err(|e| { + log::error!("receive_pre_session_data: Network error: {:?}\n", e); + MigrationResult::NetworkError + })?; + recvd += n; + } + Ok(()) +} + +pub(super) async fn send_pre_session_data_packet( + pre_session_data: &[u8], + transport: &mut T, +) -> Result<()> { + let header = PreSessionMessage { + r#type: PreSessionMessage::PRE_SESSION_DATA_TYPE, + reserved: [0u8; 3], + length: pre_session_data.len() as u32, + }; + + send_pre_session_data(transport, header.as_bytes()) + .await + .map_err(|e| { + log::error!("send_pre_session_data header: Network error: {:?}\n", e); + e + })?; + send_pre_session_data(transport, pre_session_data) + .await + .map_err(|e| { + log::error!( + "send_pre_session_data pre_session_data: Network error: {:?}\n", + e + ); + e + }) +} + +pub(super) async fn receive_pre_session_data_packet( + transport: &mut T, +) -> Result> { + let mut header_buffer = [0u8; size_of::()]; + receive_pre_session_data(transport, &mut header_buffer) + .await + .map_err(|e| { + log::error!("receive_pre_session_data header: Network error: {:?}\n", e); + e + })?; + + let header = PreSessionMessage::read_from_bytes(&header_buffer).ok_or_else(|| { + log::error!("receive_pre_session_data_packet: Failed to read PreSessionMessage header\n"); + MigrationResult::InvalidParameter + })?; + if header.r#type != PreSessionMessage::PRE_SESSION_DATA_TYPE { + log::error!("PreSessionMessage: Invalid type in pre-session data packet\n"); + return Err(MigrationResult::InvalidParameter); + } + + let pre_session_data_payload_size = header.length as usize; + let mut pre_session_data_payload = vec![0u8; pre_session_data_payload_size]; + receive_pre_session_data(transport, &mut pre_session_data_payload) + .await + .map_err(|e| { + log::error!("receive_pre_session_data payload: Network error: {:?}\n", e); + e + })?; + + Ok(pre_session_data_payload) +} + +pub(super) async fn send_start_session_packet( + transport: &mut T, +) -> Result<()> { + let header = PreSessionMessage { + r#type: PreSessionMessage::START_SESSION_TYPE, + reserved: [0u8; 3], + length: 0, + }; + + send_pre_session_data(transport, header.as_bytes()) + .await + .map_err(|e| { + log::error!("send_start_session_packet: Network error: {:?}\n", e); + e + }) +} + +pub(super) async fn receive_start_session_packet( + transport: &mut T, +) -> Result<()> { + let mut header_buffer = [0u8; size_of::()]; + receive_pre_session_data(transport, &mut header_buffer) + .await + .map_err(|e| { + log::error!("receive_start_session_packet: Network error: {:?}\n", e); + e + })?; + + let packet = PreSessionMessage::read_from_bytes(&header_buffer).ok_or_else(|| { + log::error!("receive_start_session_packet: Failed to read PreSessionMessage header\n"); + MigrationResult::InvalidParameter + })?; + + // Sanity checks + if packet.r#type != PreSessionMessage::START_SESSION_TYPE { + log::error!("PreSessionMessage: Invalid type in start session packet\n"); + return Err(MigrationResult::InvalidParameter); + } + if packet.length != 0 { + log::error!("PreSessionMessage: Invalid length in start session packet\n"); + return Err(MigrationResult::InvalidParameter); + } + + Ok(()) +} + +async fn send_hello_packet(transport: &mut T) -> Result<()> { + let header = PreSessionMessage { + r#type: PreSessionMessage::HELLO_PACKET_TYPE, + reserved: [0u8; 3], + length: 8, + }; + send_pre_session_data(transport, header.as_bytes()) + .await + .map_err(|e| { + log::error!("send_hello_packet: Network error: {:?}\n", e); + e + })?; + + let payload = HelloPacketPayload::new(); + send_pre_session_data(transport, payload.as_bytes()) + .await + .map_err(|e| { + log::error!("send_hello_packet: Network error: {:?}\n", e); + e + }) +} + +async fn receive_hello_packet( + transport: &mut T, +) -> Result { + let mut header_buffer = [0u8; size_of::()]; + receive_pre_session_data(transport, &mut header_buffer) + .await + .map_err(|e| { + log::error!("receive_hello_packet: Network error: {:?}\n", e); + e + })?; + + let header = PreSessionMessage::read_from_bytes(&header_buffer).ok_or_else(|| { + log::error!("receive_hello_packet: Failed to read PreSessionMessage header\n"); + MigrationResult::InvalidParameter + })?; + + // Sanity checks + if header.r#type != PreSessionMessage::HELLO_PACKET_TYPE { + log::error!("PreSessionMessage: Invalid type in hello packet\n"); + return Err(MigrationResult::InvalidParameter); + } + if header.length as usize != HelloPacketPayload::HELLO_PACKET_PAYLOAD_SIZE { + log::error!("PreSessionMessage: Invalid length in hello packet\n"); + return Err(MigrationResult::InvalidParameter); + } + + // Receive hello packet payload + let mut hello_payload = vec![0u8; HelloPacketPayload::HELLO_PACKET_PAYLOAD_SIZE]; + receive_pre_session_data(transport, &mut hello_payload) + .await + .map_err(|e| { + log::error!("receive_hello_packet payload: Network error: {:?}\n", e); + e + })?; + + HelloPacketPayload::read_from_bytes(&hello_payload) + .ok_or(MigrationResult::InvalidParameter) + .map_err(|_| { + log::error!("receive_hello_packet: Failed to read HelloPacketPayload\n"); + MigrationResult::InvalidParameter + }) +} + +// Exchange hello packet and negotiate a pre-session message version +pub(super) async fn exchange_hello_packet( + transport: &mut T, +) -> Result { + send_hello_packet(transport).await.map_err(|e| { + log::error!("exchange_hello_packet: send_hello_packet error: {:?}\n", e); + e + })?; + let remote = receive_hello_packet(transport).await.map_err(|e| { + log::error!( + "exchange_hello_packet: receive_hello_packet error: {:?}\n", + e + ); + e + })?; + + remote + .negotiate_supported_version() + .ok_or(MigrationResult::InvalidParameter) +} + +#[cfg(feature = "policy_v2")] +pub(super) async fn pre_session_data_exchange( + transport: &mut T, + pre_session_data: &[u8], +) -> Result> { + let version = exchange_hello_packet(transport).await.map_err(|e| { + log::error!( + "pre_session_data_exchange: exchange_hello_packet error: {:?}\n", + e + ); + e + })?; + log::info!("Pre-Session-Message Version: 0x{:04x}\n", version); + + send_pre_session_data_packet(pre_session_data, transport) + .await + .map_err(|e| { + log::error!( + "pre_session_data_exchange: send_pre_session_data_packet error: {:?}\n", + e + ); + e + })?; + let remote_policy = receive_pre_session_data_packet(transport) + .await + .map_err(|e| { + log::error!( + "pre_session_data_exchange: receive_pre_session_data_packet error: {:?}\n", + e + ); + e + })?; + + send_start_session_packet(transport).await.map_err(|e| { + log::error!( + "pre_session_data_exchange: send_start_session_packet error: {:?}\n", + e + ); + e + })?; + receive_start_session_packet(transport).await.map_err(|e| { + log::error!( + "pre_session_data_exchange: receive_start_session_packet error: {:?}\n", + e + ); + e + })?; + + Ok(remote_policy) +} diff --git a/src/migtd/src/migration/session.rs b/src/migtd/src/migration/session.rs index 8570b594..44d75a2f 100644 --- a/src/migtd/src/migration/session.rs +++ b/src/migtd/src/migration/session.rs @@ -5,10 +5,13 @@ #[cfg(feature = "vmcall-raw")] use crate::migration::event::VMCALL_MIG_REPORTSTATUS_FLAGS; #[cfg(feature = "policy_v2")] +use crate::migration::pre_session_data::pre_session_data_exchange; +use crate::migration::transport::setup_transport; +use crate::migration::transport::shutdown_transport; +use crate::migration::transport::TransportType; +#[cfg(feature = "policy_v2")] use alloc::boxed::Box; use alloc::collections::BTreeSet; -#[cfg(feature = "policy_v2")] -use async_io::{AsyncRead, AsyncWrite}; #[cfg(feature = "vmcall-raw")] use core::sync::atomic::AtomicBool; #[cfg(any(feature = "vmcall-interrupt", feature = "vmcall-raw"))] @@ -720,458 +723,252 @@ pub fn report_status(status: u8, request_id: u64) -> Result<()> { Ok(()) } -#[cfg(feature = "policy_v2")] -#[repr(C)] -struct PreSessionMessage { - pub r#type: u8, - pub reserved: [u8; 3], - pub length: u32, // Length in bytes of the message payload -} - -#[cfg(feature = "policy_v2")] -impl PreSessionMessage { - const PRE_SESSION_DATA_TYPE: u8 = 1; - const START_SESSION_TYPE: u8 = 2; - const HELLO_PACKET_TYPE: u8 = 0xff; - - pub fn as_bytes(&self) -> &[u8] { - unsafe { core::slice::from_raw_parts(self as *const Self as *const u8, size_of::()) } - } - - pub fn read_from_bytes(bytes: &[u8]) -> Option { - if bytes.len() < size_of::() { - log::error!( - "PreSessionMessage: Insufficient bytes to read header bytes.len() = {}\n", - bytes.len() - ); - return None; - } - let header = PreSessionMessage { - r#type: bytes[0], - reserved: bytes[1..4].try_into().unwrap(), - length: u32::from_le_bytes(bytes[4..8].try_into().unwrap()), - }; - Some(header) - } -} - -#[cfg(feature = "policy_v2")] -struct HelloPacketPayload { - magic_word: [u8; 4], - lowest_supported_version: u16, - highest_supported_version: u16, -} - -#[cfg(feature = "policy_v2")] -impl HelloPacketPayload { - const HELLO_PACKET_PAYLOAD_SIZE: usize = 8; - const HELLO_PACKET_MAGIC_WORD: [u8; 4] = [b'M', b'G', b'T', b'D']; - const LOWEST_VERSION: u16 = 0x0100; - const HIGHEST_VERSION: u16 = 0x0100; - - pub const fn new() -> Self { - Self { - magic_word: Self::HELLO_PACKET_MAGIC_WORD, - lowest_supported_version: Self::LOWEST_VERSION, - highest_supported_version: Self::HIGHEST_VERSION, - } - } - - pub fn as_bytes(&self) -> &[u8] { - unsafe { core::slice::from_raw_parts(self as *const Self as *const u8, size_of::()) } - } - - pub fn read_from_bytes(bytes: &[u8]) -> Option { - if bytes.len() < size_of::() { - log::error!( - "HelloPacketPayload: Insufficient bytes to read header bytes.len() = {}\n", - bytes.len() - ); - return None; - } - let payload = HelloPacketPayload { - magic_word: bytes[..4].try_into().unwrap(), - lowest_supported_version: u16::from_le_bytes(bytes[4..6].try_into().unwrap()), - highest_supported_version: u16::from_le_bytes(bytes[6..8].try_into().unwrap()), - }; - - if payload.magic_word != HelloPacketPayload::HELLO_PACKET_MAGIC_WORD { - log::error!("HelloPacketPayload: Invalid magic word in hello packet\n"); - return None; - } - Some(payload) - } - - fn negotiate_supported_version(&self) -> Option { - let low = core::cmp::max(Self::LOWEST_VERSION, self.lowest_supported_version); - let high = core::cmp::min(Self::HIGHEST_VERSION, self.highest_supported_version); - if low > high { - None - } else { - Some(high) - } - } -} - -#[cfg(feature = "policy_v2")] -async fn send_pre_session_data( - transport: &mut T, - data: &[u8], -) -> Result<()> { - let mut sent = 0; - while sent < data.len() { - let n = transport.write(&data[sent..]).await.map_err(|e| { - log::error!("send_pre_session_data: Network error: {:?}\n", e); - MigrationResult::NetworkError - })?; - sent += n; - } - Ok(()) -} - -#[cfg(feature = "policy_v2")] -async fn receive_pre_session_data( - transport: &mut T, - data: &mut [u8], -) -> Result<()> { - let mut recvd = 0; - while recvd < data.len() { - let n = transport.read(&mut data[recvd..]).await.map_err(|e| { - log::error!("receive_pre_session_data: Network error: {:?}\n", e); - MigrationResult::NetworkError - })?; - recvd += n; - } - Ok(()) -} - -#[cfg(feature = "policy_v2")] -async fn send_pre_session_data_packet( - pre_session_data: &[u8], - transport: &mut T, +#[cfg(not(feature = "spdm_attestation"))] +async fn migration_src_exchange_msk( + transport: TransportType, + info: &MigrationInformation, + data: &mut Vec, + exchange_information: &ExchangeInformation, + remote_information: &mut ExchangeInformation, + #[cfg(feature = "policy_v2")] remote_policy: Vec, ) -> Result<()> { - let header = PreSessionMessage { - r#type: PreSessionMessage::PRE_SESSION_DATA_TYPE, - reserved: [0u8; 3], - length: pre_session_data.len() as u32, - }; + const TLS_TIMEOUT: Duration = Duration::from_secs(60); // 60 seconds - send_pre_session_data(transport, header.as_bytes()) - .await - .map_err(|e| { - log::error!("send_pre_session_data header: Network error: {:?}\n", e); - e - })?; - send_pre_session_data(transport, pre_session_data) - .await - .map_err(|e| { - log::error!( - "send_pre_session_data pre_session_data: Network error: {:?}\n", - e - ); - e - }) -} - -#[cfg(feature = "policy_v2")] -async fn receive_pre_session_data_packet( - transport: &mut T, -) -> Result> { - let mut header_buffer = [0u8; size_of::()]; - receive_pre_session_data(transport, &mut header_buffer) - .await - .map_err(|e| { - log::error!("receive_pre_session_data header: Network error: {:?}\n", e); - e - })?; - - let header = PreSessionMessage::read_from_bytes(&header_buffer).ok_or_else(|| { - log::error!("receive_pre_session_data_packet: Failed to read PreSessionMessage header\n"); - MigrationResult::InvalidParameter + // TLS client + let mut ratls_client = ratls::client( + transport, + #[cfg(feature = "policy_v2")] + remote_policy, + #[cfg(feature = "vmcall-raw")] + data, + ) + .map_err(|_| { + #[cfg(feature = "vmcall-raw")] + data.extend_from_slice( + &format!( + "Error: exchange_msk(): Failed in ratls transport. Migration ID: {:x}\n", + info.mig_info.mig_request_id + ) + .into_bytes(), + ); + log::error!( + "exchange_msk(): Failed in ratls transport. Migration ID: {}\n", + info.mig_info.mig_request_id + ); + MigrationResult::SecureSessionError })?; - if header.r#type != PreSessionMessage::PRE_SESSION_DATA_TYPE { - log::error!("PreSessionMessage: Invalid type in pre-session data packet\n"); - return Err(MigrationResult::InvalidParameter); - } - - let pre_session_data_payload_size = header.length as usize; - let mut pre_session_data_payload = vec![0u8; pre_session_data_payload_size]; - receive_pre_session_data(transport, &mut pre_session_data_payload) - .await - .map_err(|e| { - log::error!("receive_pre_session_data payload: Network error: {:?}\n", e); - e - })?; - Ok(pre_session_data_payload) -} - -#[cfg(feature = "policy_v2")] -async fn send_start_session_packet( - transport: &mut T, -) -> Result<()> { - let header = PreSessionMessage { - r#type: PreSessionMessage::START_SESSION_TYPE, - reserved: [0u8; 3], - length: 0, - }; - - send_pre_session_data(transport, header.as_bytes()) - .await - .map_err(|e| { - log::error!("send_start_session_packet: Network error: {:?}\n", e); - e - }) -} - -#[cfg(feature = "policy_v2")] -async fn receive_start_session_packet( - transport: &mut T, -) -> Result<()> { - let mut header_buffer = [0u8; size_of::()]; - receive_pre_session_data(transport, &mut header_buffer) - .await - .map_err(|e| { - log::error!("receive_start_session_packet: Network error: {:?}\n", e); - e - })?; - - let packet = PreSessionMessage::read_from_bytes(&header_buffer).ok_or_else(|| { - log::error!("receive_start_session_packet: Failed to read PreSessionMessage header\n"); - MigrationResult::InvalidParameter + // MigTD-S send Migration Session Forward key to peer + with_timeout( + TLS_TIMEOUT, + ratls_client.write(exchange_information.as_bytes()), + ) + .await + .map_err(|e| { + log::error!("exchange_msk: ratls_client.write timeout error: {:?}\n", e); + e + })? + .map_err(|e| { + log::error!("exchange_msk: ratls_client.write error: {:?}\n", e); + e })?; - - // Sanity checks - if packet.r#type != PreSessionMessage::START_SESSION_TYPE { - log::error!("PreSessionMessage: Invalid type in start session packet\n"); - return Err(MigrationResult::InvalidParameter); - } - if packet.length != 0 { - log::error!("PreSessionMessage: Invalid length in start session packet\n"); - return Err(MigrationResult::InvalidParameter); + let size = with_timeout( + TLS_TIMEOUT, + ratls_client.read(remote_information.as_bytes_mut()), + ) + .await + .map_err(|e| { + log::error!("exchange_msk: ratls_client.read timeout error: {:?}\n", e); + e + })? + .map_err(|e| { + log::error!("exchange_msk: ratls_client.read error: {:?}\n", e); + e + })?; + if size < size_of::() { + #[cfg(feature = "vmcall-raw")] + data.extend_from_slice( + &format!( + "Error: exchange_msk(): Incorrect ExchangeInformation size Migration ID: {:x}. Size - Expected: {:x} Actual: {:x}\n", + info.mig_info.mig_request_id, + size_of::(), + size + ) + .into_bytes(), + ); + log::error!("exchange_msk(): Incorrect ExchangeInformation size Migration ID: {}. Size - Expected: {} Actual: {}\n", info.mig_info.mig_request_id, size_of::(), size); + return Err(MigrationResult::NetworkError); } - + shutdown_transport(ratls_client.transport_mut(), info, data).await?; Ok(()) } -#[cfg(feature = "policy_v2")] -async fn send_hello_packet(transport: &mut T) -> Result<()> { - let header = PreSessionMessage { - r#type: PreSessionMessage::HELLO_PACKET_TYPE, - reserved: [0u8; 3], - length: 8, - }; - send_pre_session_data(transport, header.as_bytes()) - .await - .map_err(|e| { - log::error!("send_hello_packet: Network error: {:?}\n", e); - e - })?; - - let payload = HelloPacketPayload::new(); - send_pre_session_data(transport, payload.as_bytes()) - .await - .map_err(|e| { - log::error!("send_hello_packet: Network error: {:?}\n", e); - e - }) -} - -#[cfg(feature = "policy_v2")] -async fn receive_hello_packet( - transport: &mut T, -) -> Result { - let mut header_buffer = [0u8; size_of::()]; - receive_pre_session_data(transport, &mut header_buffer) - .await - .map_err(|e| { - log::error!("receive_hello_packet: Network error: {:?}\n", e); - e - })?; +#[cfg(not(feature = "spdm_attestation"))] +async fn migration_dst_exchange_msk( + transport: TransportType, + info: &MigrationInformation, + data: &mut Vec, + exchange_information: &ExchangeInformation, + remote_information: &mut ExchangeInformation, + #[cfg(feature = "policy_v2")] remote_policy: Vec, +) -> Result<()> { + const TLS_TIMEOUT: Duration = Duration::from_secs(60); // 60 seconds - let header = PreSessionMessage::read_from_bytes(&header_buffer).ok_or_else(|| { - log::error!("receive_hello_packet: Failed to read PreSessionMessage header\n"); - MigrationResult::InvalidParameter + // TLS server + let mut ratls_server = ratls::server( + transport, + #[cfg(feature = "policy_v2")] + remote_policy, + ) + .map_err(|_| { + #[cfg(feature = "vmcall-raw")] + data.extend_from_slice( + &format!( + "Error: exchange_msk(): Failed in ratls transport. Migration ID: {:x}\n", + info.mig_info.mig_request_id + ) + .into_bytes(), + ); + log::error!( + "exchange_msk(): Failed in ratls transport. Migration ID: {}\n", + info.mig_info.mig_request_id + ); + MigrationResult::SecureSessionError })?; - // Sanity checks - if header.r#type != PreSessionMessage::HELLO_PACKET_TYPE { - log::error!("PreSessionMessage: Invalid type in hello packet\n"); - return Err(MigrationResult::InvalidParameter); - } - if header.length as usize != HelloPacketPayload::HELLO_PACKET_PAYLOAD_SIZE { - log::error!("PreSessionMessage: Invalid length in hello packet\n"); - return Err(MigrationResult::InvalidParameter); + with_timeout( + TLS_TIMEOUT, + ratls_server.write(exchange_information.as_bytes()), + ) + .await + .map_err(|e| { + log::error!("exchange_msk: ratls_server.write timeout error: {:?}\n", e); + e + })? + .map_err(|e| { + log::error!("exchange_msk: ratls_server.write error: {:?}\n", e); + e + })?; + let size = with_timeout( + TLS_TIMEOUT, + ratls_server.read(remote_information.as_bytes_mut()), + ) + .await + .map_err(|e| { + log::error!("exchange_msk: ratls_server.read timeout error: {:?}\n", e); + e + })? + .map_err(|e| { + log::error!("exchange_msk: ratls_server.read error: {:?}\n", e); + e + })?; + if size < size_of::() { + #[cfg(feature = "vmcall-raw")] + data.extend_from_slice(&format!("Error: exchange_msk(): Incorrect ExchangeInformation size Migration ID: {:x}. Size - Expected: {:x} Actual: {:x}\n", info.mig_info.mig_request_id, size_of::(), size).into_bytes()); + log::error!("exchange_msk(): Incorrect ExchangeInformation size Migration ID: {}. Size - Expected: {} Actual: {}\n", info.mig_info.mig_request_id, size_of::(), size); + return Err(MigrationResult::NetworkError); } - - // Receive hello packet payload - let mut hello_payload = vec![0u8; HelloPacketPayload::HELLO_PACKET_PAYLOAD_SIZE]; - receive_pre_session_data(transport, &mut hello_payload) - .await - .map_err(|e| { - log::error!("receive_hello_packet payload: Network error: {:?}\n", e); - e - })?; - - HelloPacketPayload::read_from_bytes(&hello_payload) - .ok_or(MigrationResult::InvalidParameter) - .map_err(|_| { - log::error!("receive_hello_packet: Failed to read HelloPacketPayload\n"); - MigrationResult::InvalidParameter - }) + shutdown_transport(ratls_server.transport_mut(), info, data).await?; + Ok(()) } -// Exchange hello packet and negotiate a pre-session message version -#[cfg(feature = "policy_v2")] -async fn exchange_hello_packet( - transport: &mut T, -) -> Result { - send_hello_packet(transport).await.map_err(|e| { - log::error!("exchange_hello_packet: send_hello_packet error: {:?}\n", e); - e +#[cfg(feature = "spdm_attestation")] +async fn migration_src_exchange_msk( + transport: TransportType, + info: &MigrationInformation, + #[cfg(feature = "policy_v2")] remote_policy: Vec, +) -> Result<()> { + const SPDM_TIMEOUT: Duration = Duration::from_secs(60); // 60 seconds + let mut spdm_requester = spdm::spdm_requester(transport).map_err(|_e| { + log::error!( + "exchange_msk(): Failed in spdm_requester transport. Migration ID: {}\n", + info.mig_info.mig_request_id + ); + MigrationResult::SecureSessionError })?; - let remote = receive_hello_packet(transport).await.map_err(|e| { + with_timeout( + SPDM_TIMEOUT, + spdm::spdm_requester_transfer_msk( + &mut spdm_requester, + &info.mig_info, + #[cfg(feature = "policy_v2")] + remote_policy, + ), + ) + .await + .map_err(|e| { log::error!( - "exchange_hello_packet: receive_hello_packet error: {:?}\n", + "exchange_msk: spdm_requester_transfer_msk timeout error: {:?}\n", e ); e + })? + .map_err(|e| { + log::error!("exchange_msk: spdm_requester_transfer_msk error: {:?}\n", e); + e })?; - - remote - .negotiate_supported_version() - .ok_or(MigrationResult::InvalidParameter) + log::info!("MSK exchange completed\n"); + Ok(()) } -#[cfg(feature = "policy_v2")] -async fn pre_session_data_exchange( - transport: &mut T, -) -> Result> { - use crate::config; - - let version = exchange_hello_packet(transport).await.map_err(|e| { +#[cfg(feature = "spdm_attestation")] +async fn migration_dst_exchange_msk( + transport: TransportType, + info: &MigrationInformation, + #[cfg(feature = "policy_v2")] remote_policy: Vec, +) -> Result<()> { + const SPDM_TIMEOUT: Duration = Duration::from_secs(60); // 60 seconds + let mut spdm_responder = spdm::spdm_responder(transport).map_err(|_e| { log::error!( - "pre_session_data_exchange: exchange_hello_packet error: {:?}\n", - e + "exchange_msk(): Failed in spdm_responder transport. Migration ID: {}\n", + info.mig_info.mig_request_id ); - e + MigrationResult::SecureSessionError })?; - log::info!("Pre-Session-Message Version: 0x{:04x}\n", version); - - let policy = config::get_policy() - .ok_or(MigrationResult::InvalidParameter) - .map_err(|e| { - log::error!("pre_session_data_exchange: get_policy error: {:?}\n", e); - e - })?; - send_pre_session_data_packet(policy, transport) - .await - .map_err(|e| { - log::error!( - "pre_session_data_exchange: send_pre_session_data_packet error: {:?}\n", - e - ); - e - })?; - let remote_policy = receive_pre_session_data_packet(transport) - .await - .map_err(|e| { - log::error!( - "pre_session_data_exchange: receive_pre_session_data_packet error: {:?}\n", - e - ); - e - })?; - send_start_session_packet(transport).await.map_err(|e| { + with_timeout( + SPDM_TIMEOUT, + spdm::spdm_responder_transfer_msk( + &mut spdm_responder, + &info.mig_info, + #[cfg(feature = "policy_v2")] + remote_policy, + ), + ) + .await + .map_err(|e| { log::error!( - "pre_session_data_exchange: send_start_session_packet error: {:?}\n", + "exchange_msk: spdm_responder_transfer_msk timeout error: {:?}\n", e ); e - })?; - receive_start_session_packet(transport).await.map_err(|e| { - log::error!( - "pre_session_data_exchange: receive_start_session_packet error: {:?}\n", - e - ); + })? + .map_err(|e| { + log::error!("exchange_msk: spdm_responder_transfer_msk error: {:?}\n", e); e })?; - - Ok(remote_policy) + log::info!("MSK exchange completed\n"); + Ok(()) } #[cfg(feature = "main")] pub async fn exchange_msk(info: &MigrationInformation, data: &mut Vec) -> Result<()> { - #[cfg(not(feature = "vmcall-raw"))] - let _ = data; - #[cfg(feature = "policy_v2")] - let mut transport; - #[cfg(not(feature = "policy_v2"))] - let transport; - - #[cfg(feature = "vmcall-raw")] - { - use vmcall_raw::stream::VmcallRaw; - let mut vmcall_raw_instance = VmcallRaw::new_with_mid(info.mig_info.mig_request_id) - .map_err(|e| { - data.extend_from_slice(&format!("Error: exchange_msk(): Failed to create vmcall_raw_instance with Migration ID: {:x} errorcode: {}\n", info.mig_info.mig_request_id, e).into_bytes()); - log::error!("exchange_msk: Failed to create vmcall_raw_instance with Migration ID: {} errorcode: {:?}\n", info.mig_info.mig_request_id, e); - MigrationResult::InvalidParameter - })?; - - vmcall_raw_instance - .connect() - .await - .map_err(|e| { - data.extend_from_slice(&format!("Error: exchange_msk(): Failed to connect vmcall_raw_instance with Migration ID: {:x} errorcode: {}\n", info.mig_info.mig_request_id, e).into_bytes()); - log::error!("exchange_msk: Failed to connect vmcall_raw_instance with Migration ID: {} errorcode: {:?}\n", info.mig_info.mig_request_id, e); - MigrationResult::InvalidParameter - })?; - transport = vmcall_raw_instance; - } - - #[cfg(feature = "virtio-serial")] - { - use virtio_serial::VirtioSerialPort; - const VIRTIO_SERIAL_PORT_ID: u32 = 1; - - let port = VirtioSerialPort::new(VIRTIO_SERIAL_PORT_ID); - port.open()?; - transport = port; - }; - - #[cfg(not(feature = "virtio-serial"))] - #[cfg(not(feature = "vmcall-raw"))] - { - use vsock::{stream::VsockStream, VsockAddr}; - - #[cfg(feature = "virtio-vsock")] - let mut vsock = VsockStream::new()?; - - #[cfg(feature = "vmcall-vsock")] - let mut vsock = VsockStream::new_with_cid( - info.mig_socket_info.mig_td_cid, - info.mig_info.mig_request_id, - )?; - - // Establish the vsock connection with host - vsock - .connect(&VsockAddr::new( - info.mig_socket_info.mig_td_cid as u32, - info.mig_socket_info.mig_channel_port, - )) - .await?; - transport = vsock; - }; + let mut transport = setup_transport(info, data).await?; // Exchange policy firstly because of the message size limitation of TLS protocol #[cfg(feature = "policy_v2")] const PRE_SESSION_TIMEOUT: Duration = Duration::from_secs(60); // 60 seconds #[cfg(feature = "policy_v2")] + let policy = crate::config::get_policy() + .ok_or(MigrationResult::InvalidParameter) + .map_err(|e| { + log::error!("pre_session_data_exchange: get_policy error: {:?}\n", e); + e + })?; + #[cfg(feature = "policy_v2")] let remote_policy = Box::pin(with_timeout( PRE_SESSION_TIMEOUT, - pre_session_data_exchange(&mut transport), + pre_session_data_exchange(&mut transport, policy), )) .await .map_err(|e| { @@ -1188,8 +985,6 @@ pub async fn exchange_msk(info: &MigrationInformation, data: &mut Vec) -> Re #[cfg(not(feature = "spdm_attestation"))] { - const TLS_TIMEOUT: Duration = Duration::from_secs(60); // 60 seconds - let mut remote_information = ExchangeInformation::default(); let mut exchange_information = exchange_info(&info.mig_info, info.is_src()).map_err(|e| { @@ -1199,169 +994,27 @@ pub async fn exchange_msk(info: &MigrationInformation, data: &mut Vec) -> Re // Establish TLS layer connection and negotiate the MSK if info.is_src() { - // TLS client - let mut ratls_client = ratls::client( + migration_src_exchange_msk( transport, + info, + data, + &exchange_information, + &mut remote_information, #[cfg(feature = "policy_v2")] remote_policy, - #[cfg(feature = "vmcall-raw")] - data, - ) - .map_err(|_| { - #[cfg(feature = "vmcall-raw")] - data.extend_from_slice( - &format!( - "Error: exchange_msk(): Failed in ratls transport. Migration ID: {:x}\n", - info.mig_info.mig_request_id - ) - .into_bytes(), - ); - log::error!( - "exchange_msk(): Failed in ratls transport. Migration ID: {}\n", - info.mig_info.mig_request_id - ); - MigrationResult::SecureSessionError - })?; - - // MigTD-S send Migration Session Forward key to peer - with_timeout( - TLS_TIMEOUT, - ratls_client.write(exchange_information.as_bytes()), - ) - .await - .map_err(|e| { - log::error!("exchange_msk: ratls_client.write timeout error: {:?}\n", e); - e - })? - .map_err(|e| { - log::error!("exchange_msk: ratls_client.write error: {:?}\n", e); - e - })?; - let size = with_timeout( - TLS_TIMEOUT, - ratls_client.read(remote_information.as_bytes_mut()), ) - .await - .map_err(|e| { - log::error!("exchange_msk: ratls_client.read timeout error: {:?}\n", e); - e - })? - .map_err(|e| { - log::error!("exchange_msk: ratls_client.read error: {:?}\n", e); - e - })?; - if size < size_of::() { - #[cfg(feature = "vmcall-raw")] - data.extend_from_slice( - &format!( - "Error: exchange_msk(): Incorrect ExchangeInformation size Migration ID: {:x}. Size - Expected: {:x} Actual: {:x}\n", - info.mig_info.mig_request_id, - size_of::(), - size - ) - .into_bytes(), - ); - log::error!("exchange_msk(): Incorrect ExchangeInformation size Migration ID: {}. Size - Expected: {} Actual: {}\n", info.mig_info.mig_request_id, size_of::(), size); - return Err(MigrationResult::NetworkError); - } - #[cfg(all(not(feature = "virtio-serial"), not(feature = "vmcall-raw")))] - ratls_client.transport_mut().shutdown().await.map_err(|e| { - log::error!( - "exchange_msk: ratls_client.transport_mut().shutdown() error: {:?}\n", - e - ); - e - })?; - - #[cfg(feature = "vmcall-raw")] - ratls_client - .transport_mut() - .shutdown() - .await - .map_err(|e| { - data.extend_from_slice( - &format!( - "Error: exchange_msk(): Failed to transport in vmcall_raw_instance with Migration ID: {:x} errorcode: {}\n", - info.mig_info.mig_request_id, - e - ) - .into_bytes(), - ); - log::error!( - "exchange_msk: Failed to transport in vmcall_raw_instance with Migration ID: {} errorcode: {}", - info.mig_info.mig_request_id, - e - ); - MigrationResult::InvalidParameter - })?; + .await?; } else { - // TLS server - let mut ratls_server = ratls::server( + migration_dst_exchange_msk( transport, + info, + data, + &exchange_information, + &mut remote_information, #[cfg(feature = "policy_v2")] remote_policy, ) - .map_err(|_| { - #[cfg(feature = "vmcall-raw")] - data.extend_from_slice( - &format!( - "Error: exchange_msk(): Failed in ratls transport. Migration ID: {:x}\n", - info.mig_info.mig_request_id - ) - .into_bytes(), - ); - log::error!( - "exchange_msk(): Failed in ratls transport. Migration ID: {}\n", - info.mig_info.mig_request_id - ); - MigrationResult::SecureSessionError - })?; - - with_timeout( - TLS_TIMEOUT, - ratls_server.write(exchange_information.as_bytes()), - ) - .await - .map_err(|e| { - log::error!("exchange_msk: ratls_server.write timeout error: {:?}\n", e); - e - })? - .map_err(|e| { - log::error!("exchange_msk: ratls_server.write error: {:?}\n", e); - e - })?; - let size = with_timeout( - TLS_TIMEOUT, - ratls_server.read(remote_information.as_bytes_mut()), - ) - .await - .map_err(|e| { - log::error!("exchange_msk: ratls_server.read timeout error: {:?}\n", e); - e - })? - .map_err(|e| { - log::error!("exchange_msk: ratls_server.read error: {:?}\n", e); - e - })?; - if size < size_of::() { - #[cfg(feature = "vmcall-raw")] - data.extend_from_slice(&format!("Error: exchange_msk(): Incorrect ExchangeInformation size Migration ID: {:x}. Size - Expected: {:x} Actual: {:x}\n", info.mig_info.mig_request_id, size_of::(), size).into_bytes()); - log::error!("exchange_msk(): Incorrect ExchangeInformation size Migration ID: {}. Size - Expected: {} Actual: {}\n", info.mig_info.mig_request_id, size_of::(), size); - return Err(MigrationResult::NetworkError); - } - #[cfg(all(not(feature = "virtio-serial"), not(feature = "vmcall-raw")))] - ratls_server.transport_mut().shutdown().await?; - - #[cfg(feature = "vmcall-raw")] - ratls_server - .transport_mut() - .shutdown() - .await - .map_err(|e| { - data.extend_from_slice(&format!("Error: exchange_msk(): Failed to transport in vmcall_raw_instance with Migration ID: {:x} errorcode: {}\n", info.mig_info.mig_request_id, e).into_bytes()); - log::error!("exchange_msk: Failed to transport in vmcall_raw_instance with Migration ID: {} errorcode: {}\n", info.mig_info.mig_request_id, e); - MigrationResult::InvalidParameter - })?; + .await?; } let mig_ver = cal_mig_version(info.is_src(), &exchange_information, &remote_information) @@ -1394,68 +1047,22 @@ pub async fn exchange_msk(info: &MigrationInformation, data: &mut Vec) -> Re #[cfg(feature = "spdm_attestation")] { - const SPDM_TIMEOUT: Duration = Duration::from_secs(60); // 60 seconds if info.is_src() { - let mut spdm_requester = spdm::spdm_requester(transport).map_err(|_e| { - log::error!( - "exchange_msk(): Failed in spdm_requester transport. Migration ID: {}\n", - info.mig_info.mig_request_id - ); - MigrationResult::SecureSessionError - })?; - with_timeout( - SPDM_TIMEOUT, - spdm::spdm_requester_transfer_msk( - &mut spdm_requester, - &info.mig_info, - #[cfg(feature = "policy_v2")] - remote_policy, - ), + migration_src_exchange_msk( + transport, + info, + #[cfg(feature = "policy_v2")] + remote_policy, ) - .await - .map_err(|e| { - log::error!( - "exchange_msk: spdm_requester_transfer_msk timeout error: {:?}\n", - e - ); - e - })? - .map_err(|e| { - log::error!("exchange_msk: spdm_requester_transfer_msk error: {:?}\n", e); - e - })?; - log::info!("MSK exchange completed\n"); + .await?; } else { - let mut spdm_responder = spdm::spdm_responder(transport).map_err(|_e| { - log::error!( - "exchange_msk(): Failed in spdm_responder transport. Migration ID: {}\n", - info.mig_info.mig_request_id - ); - MigrationResult::SecureSessionError - })?; - - with_timeout( - SPDM_TIMEOUT, - spdm::spdm_responder_transfer_msk( - &mut spdm_responder, - &info.mig_info, - #[cfg(feature = "policy_v2")] - remote_policy, - ), + migration_dst_exchange_msk( + transport, + info, + #[cfg(feature = "policy_v2")] + remote_policy, ) - .await - .map_err(|e| { - log::error!( - "exchange_msk: spdm_responder_transfer_msk timeout error: {:?}\n", - e - ); - e - })? - .map_err(|e| { - log::error!("exchange_msk: spdm_responder_transfer_msk error: {:?}\n", e); - e - })?; - log::info!("MSK exchange completed\n"); + .await?; } } diff --git a/src/migtd/src/migration/transport.rs b/src/migtd/src/migration/transport.rs new file mode 100644 index 00000000..1e6ca3ee --- /dev/null +++ b/src/migtd/src/migration/transport.rs @@ -0,0 +1,121 @@ +// Copyright (c) 2025 Intel Corporation +// +// SPDX-License-Identifier: BSD-2-Clause-Patent + +use super::MigrationResult; +use crate::migration::data::MigrationInformation; +use alloc::vec::Vec; + +type Result = core::result::Result; + +#[cfg(feature = "vmcall-raw")] +pub(super) type TransportType = vmcall_raw::stream::VmcallRaw; + +#[cfg(all(feature = "virtio-serial", not(feature = "vmcall-raw")))] +pub(super) type TransportType = virtio_serial::VirtioSerialPort; + +#[cfg(all(not(feature = "virtio-serial"), not(feature = "vmcall-raw")))] +pub(super) type TransportType = vsock::stream::VsockStream; + +pub(super) async fn setup_transport( + info: &MigrationInformation, + data: &mut Vec, +) -> Result { + #[cfg(not(feature = "vmcall-raw"))] + let _ = data; + + #[cfg(feature = "vmcall-raw")] + { + use vmcall_raw::stream::VmcallRaw; + let mut vmcall_raw_instance = VmcallRaw::new_with_mid(info.mig_info.mig_request_id) + .map_err(|e| { + data.extend_from_slice(&format!("Error: exchange_msk(): Failed to create vmcall_raw_instance with Migration ID: {:x} errorcode: {}\n", info.mig_info.mig_request_id, e).into_bytes()); + log::error!("exchange_msk: Failed to create vmcall_raw_instance with Migration ID: {} errorcode: {:?}\n", info.mig_info.mig_request_id, e); + MigrationResult::InvalidParameter + })?; + + vmcall_raw_instance + .connect() + .await + .map_err(|e| { + data.extend_from_slice(&format!("Error: exchange_msk(): Failed to connect vmcall_raw_instance with Migration ID: {:x} errorcode: {}\n", info.mig_info.mig_request_id, e).into_bytes()); + log::error!("exchange_msk: Failed to connect vmcall_raw_instance with Migration ID: {} errorcode: {:?}\n", info.mig_info.mig_request_id, e); + MigrationResult::InvalidParameter + })?; + return Ok(vmcall_raw_instance); + } + + #[cfg(all(feature = "virtio-serial", not(feature = "vmcall-raw")))] + { + use virtio_serial::VirtioSerialPort; + const VIRTIO_SERIAL_PORT_ID: u32 = 1; + + let port = VirtioSerialPort::new(VIRTIO_SERIAL_PORT_ID); + port.open()?; + return Ok(port); + } + + #[cfg(all(not(feature = "virtio-serial"), not(feature = "vmcall-raw")))] + { + use vsock::{stream::VsockStream, VsockAddr}; + + #[cfg(feature = "virtio-vsock")] + let mut vsock = VsockStream::new()?; + + #[cfg(feature = "vmcall-vsock")] + let mut vsock = VsockStream::new_with_cid( + info.mig_socket_info.mig_td_cid, + info.mig_info.mig_request_id, + )?; + + // Establish the vsock connection with host + vsock + .connect(&VsockAddr::new( + info.mig_socket_info.mig_td_cid as u32, + info.mig_socket_info.mig_channel_port, + )) + .await?; + return Ok(vsock); + } +} + +pub(super) async fn shutdown_transport( + transport: &mut TransportType, + info: &MigrationInformation, + data: &mut Vec, +) -> Result<()> { + #[cfg(not(feature = "vmcall-raw"))] + let _ = data; + + #[cfg(feature = "vmcall-raw")] + transport.shutdown().await.map_err(|e| { + data.extend_from_slice( + &format!( + "Error: shutdown_transport(): Failed to transport in vmcall_raw_instance with Migration ID: {:x} errorcode: {}\n", + info.mig_info.mig_request_id, + e + ) + .into_bytes(), + ); + log::error!( + "shutdown_transport: Failed to transport in vmcall_raw_instance with Migration ID: {} errorcode: {}", + info.mig_info.mig_request_id, + e + ); + MigrationResult::InvalidParameter + })?; + + #[cfg(all(feature = "virtio-serial", not(feature = "vmcall-raw")))] + transport.close().map_err(|e| { + log::error!("shutdown_transport: virtio_serial close error: {:?}\n", e); + MigrationResult::InvalidParameter + })?; + + #[cfg(all(not(feature = "virtio-serial"), not(feature = "vmcall-raw")))] + transport.shutdown().await.map_err(|e| { + log::error!("shutdown_transport: vsock shutdown error: {:?}\n", e); + MigrationResult::InvalidParameter + })?; + + Ok(()) +}