diff --git a/src/attestation/mod.rs b/src/attestation/mod.rs index bc78624..4f8cd75 100644 --- a/src/attestation/mod.rs +++ b/src/attestation/mod.rs @@ -115,7 +115,7 @@ impl Display for AttestationType { } /// Can generate a local attestation based on attestation type -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct AttestationGenerator { pub attestation_type: AttestationType, dummy_dcap_url: Option, diff --git a/src/attested_tls.rs b/src/attested_tls.rs new file mode 100644 index 0000000..6460371 --- /dev/null +++ b/src/attested_tls.rs @@ -0,0 +1,503 @@ +use crate::{ + attestation::{ + measurements::MultiMeasurements, AttestationError, AttestationExchangeMessage, + AttestationGenerator, AttestationType, AttestationVerifier, + }, + host_to_host_with_port, +}; +use parity_scale_codec::{Decode, Encode}; +use sha2::{Digest, Sha256}; +use thiserror::Error; +use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; +use x509_parser::parse_x509_certificate; + +use std::num::TryFromIntError; +use std::{net::SocketAddr, sync::Arc}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; +use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; +use tokio_rustls::rustls::RootCertStore; +use tokio_rustls::{ + rustls::{ClientConfig, ServerConfig}, + TlsAcceptor, TlsConnector, +}; + +/// This makes it possible to add breaking protocol changes and provide backwards compatibility. +/// When adding more supported versions, note that ordering is important. ALPN will pick the first +/// protocol which both parties support - so newer supported versions should come first. +pub const SUPPORTED_ALPN_PROTOCOL_VERSIONS: [&[u8]; 1] = [b"flashbots-ratls/1"]; + +/// The label used when exporting key material from a TLS session +pub(crate) const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding"; + +/// TLS Credentials +pub struct TlsCertAndKey { + /// Der-encoded TLS certificate chain + pub cert_chain: Vec>, + /// Der-encoded TLS private key + pub key: PrivateKeyDer<'static>, +} + +/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address +#[derive(Clone)] +pub struct AttestedTlsServer { + /// The underlying TCP listener + pub listener: Arc, + /// Quote generation type to use (including none) + attestation_generator: AttestationGenerator, + /// Verifier for remote attestation (including none) + attestation_verifier: AttestationVerifier, + /// The certificate chain + cert_chain: Vec>, + /// For accepting TLS connections + acceptor: TlsAcceptor, +} + +impl AttestedTlsServer { + pub async fn new( + cert_and_key: TlsCertAndKey, + local: impl ToSocketAddrs, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + client_auth: bool, + ) -> Result { + let mut server_config = if client_auth { + let root_store = + RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; + + ServerConfig::builder() + .with_client_cert_verifier(verifier) + .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)? + } else { + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)? + }; + + server_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS + .into_iter() + .map(|p| p.to_vec()) + .collect(); + + Self::new_with_tls_config( + cert_and_key.cert_chain, + server_config.into(), + local, + attestation_generator, + attestation_verifier, + ) + .await + } + + /// Start with preconfigured TLS + /// + /// This is not fully public as it allows dangerous configuration + pub(crate) async fn new_with_tls_config( + cert_chain: Vec>, + server_config: Arc, + local: impl ToSocketAddrs, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + ) -> Result { + let acceptor = tokio_rustls::TlsAcceptor::from(server_config); + let listener = TcpListener::bind(local).await?; + + Ok(Self { + listener: listener.into(), + attestation_generator, + attestation_verifier, + acceptor, + cert_chain, + }) + } + + /// Accept an incoming connection and do an attestation exchange + pub async fn accept( + &self, + ) -> Result< + ( + tokio_rustls::server::TlsStream, + Option, + AttestationType, + ), + AttestedTlsError, + > { + let (inbound, _client_addr) = self.listener.accept().await?; + + self.handle_connection(inbound).await + } + + /// Helper to get the socket address of the underlying TCP listener + pub fn local_addr(&self) -> std::io::Result { + self.listener.local_addr() + } + + /// Handle an incoming connection from a proxy-client + pub async fn handle_connection( + &self, + inbound: TcpStream, + ) -> Result< + ( + tokio_rustls::server::TlsStream, + Option, + AttestationType, + ), + AttestedTlsError, + > { + tracing::debug!("attested-tls-server accepted connection"); + + // Do TLS handshake + let mut tls_stream = self.acceptor.accept(inbound).await?; + let (_io, connection) = tls_stream.get_ref(); + + // Ensure that we agreed a protocol + let _negotiated_protocol = connection + .alpn_protocol() + .ok_or(AttestedTlsError::AlpnFailed)?; + + // Compute an exporter unique to the session + let mut exporter = [0u8; 32]; + connection.export_keying_material( + &mut exporter, + EXPORTER_LABEL, + None, // context + )?; + + let input_data = compute_report_input(Some(&self.cert_chain), exporter)?; + + // Get the TLS certficate chain of the client, if there is one + let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned()); + + // If we are in a CVM, generate an attestation + let attestation = self + .attestation_generator + .generate_attestation(input_data) + .await? + .encode(); + + // Write our attestation to the channel, with length prefix + let attestation_length_prefix = length_prefix(&attestation); + tls_stream.write_all(&attestation_length_prefix).await?; + tls_stream.write_all(&attestation).await?; + + // Now read a length-prefixed attestation from the remote peer + // In the case of no client attestation this will be zero bytes + let mut length_bytes = [0; 4]; + tls_stream.read_exact(&mut length_bytes).await?; + let length: usize = u32::from_be_bytes(length_bytes).try_into()?; + + let mut buf = vec![0; length]; + tls_stream.read_exact(&mut buf).await?; + + let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?; + let remote_attestation_type = remote_attestation_message.attestation_type; + + // If we expect an attestaion from the client, verify it and get measurements + let measurements = if self.attestation_verifier.has_remote_attestion() { + let remote_input_data = compute_report_input(remote_cert_chain.as_deref(), exporter)?; + + self.attestation_verifier + .verify_attestation(remote_attestation_message, remote_input_data) + .await? + } else { + None + }; + + Ok((tls_stream, measurements, remote_attestation_type)) + } +} + +/// A proxy client which forwards http traffic to a proxy-server +#[derive(Clone)] +pub struct AttestedTlsClient { + /// The connector for making TLS connections with out configuration + connector: TlsConnector, + /// Quote generation type to use (including none) + attestation_generator: AttestationGenerator, + /// Verifier for remote attestation (including none) + attestation_verifier: AttestationVerifier, + /// The certificate chain for client auth + cert_chain: Option>>, +} + +impl std::fmt::Debug for AttestedTlsClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AttestedTlsClient") + .field("attestation_verifier", &self.attestation_verifier) + .field("attestation_generator", &self.attestation_generator) + .field("cert_chain", &self.cert_chain) + .finish() + } +} + +impl AttestedTlsClient { + /// Start with optional TLS client auth + pub async fn new( + cert_and_key: Option, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + remote_certificate: Option>, + ) -> Result { + // If a remote CA cert was given, use it as the root store, otherwise use webpki_roots + let root_store = match remote_certificate { + Some(remote_certificate) => { + let mut root_store = RootCertStore::empty(); + root_store.add(remote_certificate)?; + root_store + } + None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), + }; + + // Setup TLS client configuration, with or without client auth + let mut client_config = if let Some(ref cert_and_key) = cert_and_key { + ClientConfig::builder() + .with_root_certificates(root_store) + .with_client_auth_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + } else { + ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth() + }; + + client_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS + .into_iter() + .map(|p| p.to_vec()) + .collect(); + + Self::new_with_tls_config( + client_config.into(), + attestation_generator, + attestation_verifier, + cert_and_key.map(|c| c.cert_chain), + ) + .await + } + + /// Create a new proxy client with given TLS configuration + /// + /// This not fully public as it allows dangerous configuration but is used in tests + pub(crate) async fn new_with_tls_config( + client_config: Arc, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + cert_chain: Option>>, + ) -> Result { + let connector = TlsConnector::from(client_config.clone()); + + Ok(Self { + connector, + attestation_generator, + attestation_verifier, + cert_chain, + }) + } + + /// Connect to an attested-tls-server, do TLS handshake and attestation exchange + pub async fn connect( + &self, + target: &str, + ) -> Result< + ( + tokio_rustls::client::TlsStream, + Option, + AttestationType, + ), + AttestedTlsError, + > { + // Make a TCP client connection and TLS handshake + let out = TcpStream::connect(&target).await?; + let mut tls_stream = self + .connector + .connect(server_name_from_host(target)?, out) + .await?; + + let (_io, server_connection) = tls_stream.get_ref(); + + // Ensure that we agreed a protocol + let _negotiated_protocol = server_connection + .alpn_protocol() + .ok_or(AttestedTlsError::AlpnFailed)?; + + // Compute an exporter unique to the channel + let mut exporter = [0u8; 32]; + server_connection.export_keying_material( + &mut exporter, + EXPORTER_LABEL, + None, // context + )?; + + // Get the TLS certificate chain of the server + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(AttestedTlsError::NoCertificate)? + .to_owned(); + + let remote_input_data = compute_report_input(Some(&remote_cert_chain), exporter)?; + + // Read a length prefixed attestation from the proxy-server + let mut length_bytes = [0; 4]; + tls_stream.read_exact(&mut length_bytes).await?; + let length: usize = u32::from_be_bytes(length_bytes).try_into()?; + + let mut buf = vec![0; length]; + tls_stream.read_exact(&mut buf).await?; + + let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?; + let remote_attestation_type = remote_attestation_message.attestation_type; + + // Verify the remote attestation against our accepted measurements + let measurements = self + .attestation_verifier + .verify_attestation(remote_attestation_message, remote_input_data) + .await?; + + // If we are in a CVM, provide an attestation + let attestation = if self.attestation_generator.attestation_type != AttestationType::None { + let local_input_data = compute_report_input(self.cert_chain.as_deref(), exporter)?; + self.attestation_generator + .generate_attestation(local_input_data) + .await? + .encode() + } else { + AttestationExchangeMessage::without_attestation().encode() + }; + + // Send our attestation (or zero bytes) prefixed with length + let attestation_length_prefix = length_prefix(&attestation); + tls_stream.write_all(&attestation_length_prefix).await?; + tls_stream.write_all(&attestation).await?; + + Ok((tls_stream, measurements, remote_attestation_type)) + } + + /// Connect to an attested TLS server, retrieve the remote TLS certificate and return it + pub async fn get_tls_cert( + &self, + server_name: &str, + ) -> Result>, AttestedTlsError> { + let (mut tls_stream, _, _) = self.connect(server_name).await?; + + let (_io, server_connection) = tls_stream.get_ref(); + + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(AttestedTlsError::NoCertificate)? + .to_owned(); + + tls_stream.shutdown().await?; + + Ok(remote_cert_chain) + } +} + +/// A client which just gets the attested remote certificate, with no client authentication +pub async fn get_tls_cert( + server_name: String, + attestation_verifier: AttestationVerifier, + remote_certificate: Option>, +) -> Result>, AttestedTlsError> { + tracing::debug!("Getting remote TLS cert"); + let attested_tls_client = AttestedTlsClient::new( + None, + AttestationGenerator::with_no_attestation(), + attestation_verifier, + remote_certificate, + ) + .await?; + attested_tls_client + .get_tls_cert(&host_to_host_with_port(&server_name)) + .await +} + +/// Helper for testing getting remote certificate +#[cfg(test)] +pub(crate) async fn get_tls_cert_with_config( + server_name: &str, + attestation_verifier: AttestationVerifier, + client_config: Arc, +) -> Result>, AttestedTlsError> { + let attested_tls_client = AttestedTlsClient::new_with_tls_config( + client_config, + AttestationGenerator::with_no_attestation(), + attestation_verifier, + None, + ) + .await?; + attested_tls_client.get_tls_cert(server_name).await +} + +/// Given a certificate chain and an exporter (session key material), build the quote input value +/// SHA256(pki) || exporter +pub fn compute_report_input( + cert_chain: Option<&[CertificateDer<'_>]>, + exporter: [u8; 32], +) -> Result<[u8; 64], AttestationError> { + let mut quote_input = [0u8; 64]; + if let Some(cert_chain) = cert_chain { + let pki_hash = get_pki_hash_from_certificate_chain(cert_chain)?; + quote_input[..32].copy_from_slice(&pki_hash); + } + quote_input[32..].copy_from_slice(&exporter); + Ok(quote_input) +} + +/// Given a certificate chain, get the [Sha256] hash of the public key of the leaf certificate +fn get_pki_hash_from_certificate_chain( + cert_chain: &[CertificateDer<'_>], +) -> Result<[u8; 32], AttestationError> { + let leaf_certificate = cert_chain.first().ok_or(AttestationError::NoCertificate)?; + let (_, cert) = parse_x509_certificate(leaf_certificate.as_ref())?; + let public_key = &cert.tbs_certificate.subject_pki; + let key_bytes = public_key.subject_public_key.as_ref(); + + let mut hasher = Sha256::new(); + hasher.update(key_bytes); + Ok(hasher.finalize().into()) +} + +/// An error when running an attested TLS client or server +#[derive(Error, Debug)] +pub enum AttestedTlsError { + #[error("Failed to get server ceritifcate")] + NoCertificate, + #[error("TLS: {0}")] + Rustls(#[from] tokio_rustls::rustls::Error), + #[error("Verifier builder: {0}")] + VerifierBuilder(#[from] VerifierBuilderError), + #[error("IO: {0}")] + Io(#[from] std::io::Error), + #[error("Attestation: {0}")] + Attestation(#[from] AttestationError), + #[error("Integer conversion: {0}")] + IntConversion(#[from] TryFromIntError), + #[error("Bad host name: {0}")] + BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError), + #[error("Serialization: {0}")] + Serialization(#[from] parity_scale_codec::Error), + #[error("Protocol negotiation failed - remote peer does not support this protocol")] + AlpnFailed, +} + +/// Given a byte array, encode its length as a 4 byte big endian u32 +fn length_prefix(input: &[u8]) -> [u8; 4] { + let len = input.len() as u32; + len.to_be_bytes() +} + +/// Given a hostname with or without port number, create a TLS [ServerName] with just the host part +fn server_name_from_host( + host: &str, +) -> Result, tokio_rustls::rustls::pki_types::InvalidDnsNameError> { + // If host contains ':', try to split off the port. + let host_part = host.rsplit_once(':').map(|(h, _)| h).unwrap_or(host); + + // If the host is an IPv6 literal in brackets like "[::1]:443", + // remove the brackets for SNI (SNI allows bare IPv6 too). + let host_part = host_part.trim_matches(|c| c == '[' || c == ']'); + + ServerName::try_from(host_part.to_string()) +} diff --git a/src/lib.rs b/src/lib.rs index e319fe9..36a738c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,47 +1,42 @@ pub mod attestation; pub mod attested_get; +pub mod attested_tls; pub mod file_server; pub mod health_check; pub use attestation::AttestationGenerator; -use attestation::{measurements::MultiMeasurements, AttestationError, AttestationType}; + use bytes::Bytes; use http::HeaderValue; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{service::service_fn, Response}; use hyper_util::rt::TokioIo; -use parity_scale_codec::{Decode, Encode}; -use sha2::{Digest, Sha256}; use thiserror::Error; use tokio::sync::{mpsc, oneshot}; -use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; +use tokio_rustls::rustls::server::VerifierBuilderError; use tracing::{error, warn}; -use x509_parser::parse_x509_certificate; #[cfg(test)] mod test_helpers; +use std::net::SocketAddr; use std::num::TryFromIntError; use std::time::Duration; -use std::{net::SocketAddr, sync::Arc}; -use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; +use tokio::io; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; -use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; -use tokio_rustls::rustls::RootCertStore; -use tokio_rustls::{ - rustls::{ClientConfig, ServerConfig}, - TlsAcceptor, TlsConnector, -}; +use tokio_rustls::rustls::pki_types::CertificateDer; -use crate::attestation::{AttestationExchangeMessage, AttestationVerifier}; - -/// This makes it possible to add breaking protocol changes and provide backwards compatibility. -/// When adding more supported versions, note that ordering is important. ALPN will pick the first -/// protocol which both parties support - so newer supported versions should come first. -pub const SUPPORTED_ALPN_PROTOCOL_VERSIONS: [&[u8]; 1] = [b"flashbots-ratls/1"]; +#[cfg(test)] +use std::sync::Arc; +#[cfg(test)] +use tokio_rustls::rustls::{ClientConfig, ServerConfig}; -/// The label used when exporting key material from a TLS session -const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding"; +use crate::{ + attestation::{ + measurements::MultiMeasurements, AttestationError, AttestationType, AttestationVerifier, + }, + attested_tls::{AttestedTlsClient, AttestedTlsError, AttestedTlsServer, TlsCertAndKey}, +}; /// The header name for giving attestation type const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; @@ -58,26 +53,10 @@ type RequestWithResponseSender = ( ); type Http2Sender = hyper::client::conn::http2::SendRequest; -/// TLS Credentials -pub struct TlsCertAndKey { - /// Der-encoded TLS certificate chain - pub cert_chain: Vec>, - /// Der-encoded TLS private key - pub key: PrivateKeyDer<'static>, -} - /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address pub struct ProxyServer { - /// The underlying TCP listener - listener: TcpListener, - /// Quote generation type to use (including none) - attestation_generator: AttestationGenerator, - /// Verifier for remote attestation (including none) - attestation_verifier: AttestationVerifier, - /// The certificate chain - cert_chain: Vec>, - /// For accepting TLS connections - acceptor: TlsAcceptor, + /// The underlying attested TLS server + attested_tls_server: AttestedTlsServer, /// The address of the target service we are proxying to target: SocketAddr, } @@ -91,39 +70,25 @@ impl ProxyServer { attestation_verifier: AttestationVerifier, client_auth: bool, ) -> Result { - let mut server_config = if client_auth { - let root_store = - RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; - - ServerConfig::builder() - .with_client_cert_verifier(verifier) - .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)? - } else { - ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)? - }; - - server_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS - .into_iter() - .map(|p| p.to_vec()) - .collect(); - - Self::new_with_tls_config( - cert_and_key.cert_chain, - server_config.into(), + let attested_tls_server = AttestedTlsServer::new( + cert_and_key, local, - target, attestation_generator, attestation_verifier, + client_auth, ) - .await + .await?; + + Ok(Self { + attested_tls_server, + target, + }) } /// Start with preconfigured TLS /// /// This is not public as it allows dangerous configuration + #[cfg(test)] async fn new_with_tls_config( cert_chain: Vec>, server_config: Arc, @@ -132,40 +97,40 @@ impl ProxyServer { attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, ) -> Result { - let acceptor = tokio_rustls::TlsAcceptor::from(server_config); - let listener = TcpListener::bind(local).await?; - - Ok(Self { - listener, + let attested_tls_server = AttestedTlsServer::new_with_tls_config( + cert_chain, + server_config, + local, attestation_generator, attestation_verifier, - acceptor, + ) + .await?; + + Ok(Self { + attested_tls_server, target, - cert_chain, }) } /// Accept an incoming connection and handle it in a seperate task pub async fn accept(&self) -> Result<(), ProxyError> { - let (inbound, _client_addr) = self.listener.accept().await?; - - let acceptor = self.acceptor.clone(); let target = self.target; - let cert_chain = self.cert_chain.clone(); - let attestation_generator = self.attestation_generator.clone(); - let attestation_verifier = self.attestation_verifier.clone(); + let (inbound, _client_addr) = self.attested_tls_server.listener.accept().await?; + let attested_tls_server = self.attested_tls_server.clone(); + tokio::spawn(async move { - if let Err(err) = Self::handle_connection( - inbound, - acceptor, - target, - cert_chain, - attestation_generator, - attestation_verifier, - ) - .await - { - warn!("Failed to handle connection: {err}"); + match attested_tls_server.handle_connection(inbound).await { + Ok((tls_stream, measurements, attestation_type)) => { + if let Err(err) = + Self::handle_connection(tls_stream, measurements, attestation_type, target) + .await + { + warn!("Failed to handle connection: {err}"); + } + } + Err(err) => { + warn!("Attestation exchange failed: {err}"); + } } }); @@ -174,74 +139,18 @@ impl ProxyServer { /// Helper to get the socket address of the underlying TCP listener pub fn local_addr(&self) -> std::io::Result { - self.listener.local_addr() + self.attested_tls_server.local_addr() } /// Handle an incoming connection from a proxy-client async fn handle_connection( - inbound: TcpStream, - acceptor: TlsAcceptor, + tls_stream: tokio_rustls::server::TlsStream, + measurements: Option, + remote_attestation_type: AttestationType, target: SocketAddr, - cert_chain: Vec>, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, ) -> Result<(), ProxyError> { tracing::debug!("proxy-server accepted connection"); - // Do TLS handshake - let mut tls_stream = acceptor.accept(inbound).await?; - let (_io, connection) = tls_stream.get_ref(); - - // Ensure that we agreed a protocol - let _negotiated_protocol = connection.alpn_protocol().ok_or(ProxyError::AlpnFailed)?; - - // Compute an exporter unique to the session - let mut exporter = [0u8; 32]; - connection.export_keying_material( - &mut exporter, - EXPORTER_LABEL, - None, // context - )?; - - let input_data = compute_report_input(Some(&cert_chain), exporter)?; - - // Get the TLS certficate chain of the client, if there is one - let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned()); - - // If we are in a CVM, generate an attestation - let attestation = attestation_generator - .generate_attestation(input_data) - .await? - .encode(); - - // Write our attestation to the channel, with length prefix - let attestation_length_prefix = length_prefix(&attestation); - tls_stream.write_all(&attestation_length_prefix).await?; - tls_stream.write_all(&attestation).await?; - - // Now read a length-prefixed attestation from the remote peer - // In the case of no client attestation this will be zero bytes - let mut length_bytes = [0; 4]; - tls_stream.read_exact(&mut length_bytes).await?; - let length: usize = u32::from_be_bytes(length_bytes).try_into()?; - - let mut buf = vec![0; length]; - tls_stream.read_exact(&mut buf).await?; - - let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?; - let remote_attestation_type = remote_attestation_message.attestation_type; - - // If we expect an attestaion from the client, verify it and get measurements - let measurements = if attestation_verifier.has_remote_attestion() { - let remote_input_data = compute_report_input(remote_cert_chain.as_deref(), exporter)?; - - attestation_verifier - .verify_attestation(remote_attestation_message, remote_input_data) - .await? - } else { - None - }; - // Setup an HTTP server let http = hyper::server::conn::http2::Builder::new(TokioExecutor); @@ -348,63 +257,52 @@ impl ProxyClient { attestation_verifier: AttestationVerifier, remote_certificate: Option>, ) -> Result { - // If a remote CA cert was given, use it as the root store, otherwise use webpki_roots - let root_store = match remote_certificate { - Some(remote_certificate) => { - let mut root_store = RootCertStore::empty(); - root_store.add(remote_certificate)?; - root_store - } - None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), - }; - - // Setup TLS client configuration, with or without client auth - let mut client_config = if let Some(ref cert_and_key) = cert_and_key { - ClientConfig::builder() - .with_root_certificates(root_store) - .with_client_auth_cert( - cert_and_key.cert_chain.clone(), - cert_and_key.key.clone_key(), - )? - } else { - ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth() - }; - - client_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS - .into_iter() - .map(|p| p.to_vec()) - .collect(); - - Self::new_with_tls_config( - client_config.into(), - address, - server_name, + let attested_tls_client = AttestedTlsClient::new( + cert_and_key, attestation_generator, attestation_verifier, - cert_and_key.map(|c| c.cert_chain), + remote_certificate, ) - .await + .await?; + + Self::new_with_inner(address, attested_tls_client, &server_name).await } /// Create a new proxy client with given TLS configuration /// /// This is private as it allows dangerous configuration but is used in tests + #[cfg(test)] async fn new_with_tls_config( client_config: Arc, - local: impl ToSocketAddrs, + address: impl ToSocketAddrs, target_name: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, cert_chain: Option>>, ) -> Result { - // Setup TCP server and TLS client - let listener = TcpListener::bind(local).await?; - let connector = TlsConnector::from(client_config.clone()); + let attested_tls_client = AttestedTlsClient::new_with_tls_config( + client_config, + attestation_generator, + attestation_verifier, + cert_chain, + ) + .await?; + + Self::new_with_inner(address, attested_tls_client, &target_name).await + } + + /// Create a new proxy client with given TLS configuration + /// + /// This is private as it allows dangerous configuration but is used in tests + async fn new_with_inner( + address: impl ToSocketAddrs, + attested_tls_client: AttestedTlsClient, + target_name: &str, + ) -> Result { + let listener = TcpListener::bind(address).await?; // Process the hostname / port provided by the user - let target = host_to_host_with_port(&target_name); + let target = host_to_host_with_port(target_name); // Channel for getting incoming requests from the source client let (requests_tx, mut requests_rx) = mpsc::channel::<( @@ -416,16 +314,9 @@ impl ProxyClient { // Connect to the proxy server and provide / verify attestation let (mut sender, mut measurements, mut remote_attestation_type) = - Self::setup_connection_with_backoff( - connector.clone(), - target.clone(), - cert_chain.clone(), - attestation_generator.clone(), - attestation_verifier.clone(), - true, - ) - .await?; + Self::setup_connection_with_backoff(&target, &attested_tls_client, true).await?; + let attested_tls_client_clone = attested_tls_client.clone(); tokio::spawn(async move { // Read an incoming request from the channel (from the source client) while let Some((req, response_tx)) = requests_rx.recv().await { @@ -474,11 +365,8 @@ impl ProxyClient { // Reconnect to the server - retrying indefinately with a backoff (sender, measurements, remote_attestation_type) = Self::setup_connection_with_backoff( - connector.clone(), - target.clone(), - cert_chain.clone(), - attestation_generator.clone(), - attestation_verifier.clone(), + &target, + &attested_tls_client_clone, false, ) .await @@ -548,26 +436,15 @@ impl ProxyClient { // Attempt connection and handshake with the proxy-server // If it fails retry with a backoff (indefinately) async fn setup_connection_with_backoff( - connector: TlsConnector, - target: String, - cert_chain: Option>>, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, + target: &str, + attested_tls_client: &AttestedTlsClient, should_bail: bool, ) -> Result<(Http2Sender, Option, AttestationType), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); loop { - match Self::setup_connection( - connector.clone(), - target.clone(), - cert_chain.clone(), - attestation_generator.clone(), - attestation_verifier.clone(), - ) - .await - { + match Self::setup_connection(attested_tls_client, target).await { Ok(output) => { return Ok(output); } @@ -589,74 +466,12 @@ impl ProxyClient { /// Connect to the proxy-server, do TLS handshake and remote attestation async fn setup_connection( - connector: TlsConnector, - target: String, - cert_chain: Option>>, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, + inner: &AttestedTlsClient, + target: &str, ) -> Result<(Http2Sender, Option, AttestationType), ProxyError> { - // Make a TCP client connection and TLS handshake - let out = TcpStream::connect(&target).await?; - let mut tls_stream = connector - .connect(server_name_from_host(&target)?, out) - .await?; + let (tls_stream, measurements, remote_attestation_type) = inner.connect(target).await?; - let (_io, server_connection) = tls_stream.get_ref(); - - // Ensure that we agreed a protocol - let _negotiated_protocol = server_connection - .alpn_protocol() - .ok_or(ProxyError::AlpnFailed)?; - - // Compute an exporter unique to the channel - let mut exporter = [0u8; 32]; - server_connection.export_keying_material( - &mut exporter, - EXPORTER_LABEL, - None, // context - )?; - - // Get the TLS certificate chain of the server - let remote_cert_chain = server_connection - .peer_certificates() - .ok_or(ProxyError::NoCertificate)? - .to_owned(); - - let remote_input_data = compute_report_input(Some(&remote_cert_chain), exporter)?; - - // Read a length prefixed attestation from the proxy-server - let mut length_bytes = [0; 4]; - tls_stream.read_exact(&mut length_bytes).await?; - let length: usize = u32::from_be_bytes(length_bytes).try_into()?; - - let mut buf = vec![0; length]; - tls_stream.read_exact(&mut buf).await?; - - let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?; - let remote_attestation_type = remote_attestation_message.attestation_type; - - // Verify the remote attestation against our accepted measurements - let measurements = attestation_verifier - .verify_attestation(remote_attestation_message, remote_input_data) - .await?; - - // If we are in a CVM, provide an attestation - let attestation = if attestation_generator.attestation_type != AttestationType::None { - let local_input_data = compute_report_input(cert_chain.as_deref(), exporter)?; - attestation_generator - .generate_attestation(local_input_data) - .await? - .encode() - } else { - AttestationExchangeMessage::without_attestation().encode() - }; - - // Send our attestation (or zero bytes) prefixed with length - let attestation_length_prefix = length_prefix(&attestation); - tls_stream.write_all(&attestation_length_prefix).await?; - tls_stream.write_all(&attestation).await?; - - // The attestation exchange is now complete - now setup an HTTP client + // The attestation exchange is now complete - setup an HTTP client let outbound_io = TokioIo::new(tls_stream); let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) @@ -685,110 +500,6 @@ impl ProxyClient { } } -/// Just get the attested remote certificate, with no client authentication -pub async fn get_tls_cert( - server_name: String, - attestation_verifier: AttestationVerifier, - remote_certificate: Option>, -) -> Result>, ProxyError> { - tracing::debug!("Getting remote TLS cert"); - // If a remote CA cert was given, use it as the root store, otherwise use webpki_roots - let root_store = match remote_certificate { - Some(remote_certificate) => { - let mut root_store = RootCertStore::empty(); - root_store.add(remote_certificate)?; - root_store - } - None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), - }; - - let mut client_config = ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - - client_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS - .into_iter() - .map(|p| p.to_vec()) - .collect(); - - get_tls_cert_with_config(server_name, attestation_verifier, client_config.into()).await -} - -async fn get_tls_cert_with_config( - server_name: String, - attestation_verifier: AttestationVerifier, - client_config: Arc, -) -> Result>, ProxyError> { - let connector = TlsConnector::from(client_config); - - let out = TcpStream::connect(host_to_host_with_port(&server_name)).await?; - let mut tls_stream = connector - .connect(server_name_from_host(&server_name)?, out) - .await?; - - let (_io, server_connection) = tls_stream.get_ref(); - - let mut exporter = [0u8; 32]; - server_connection.export_keying_material( - &mut exporter, - EXPORTER_LABEL, - None, // context - )?; - - let remote_cert_chain = server_connection - .peer_certificates() - .ok_or(ProxyError::NoCertificate)? - .to_owned(); - - let mut length_bytes = [0; 4]; - tls_stream.read_exact(&mut length_bytes).await?; - let length: usize = u32::from_be_bytes(length_bytes).try_into()?; - - let mut buf = vec![0; length]; - tls_stream.read_exact(&mut buf).await?; - - let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?; - - let remote_input_data = compute_report_input(Some(&remote_cert_chain), exporter)?; - - let _measurements = attestation_verifier - .verify_attestation(remote_attestation_message, remote_input_data) - .await?; - - tls_stream.shutdown().await?; - - Ok(remote_cert_chain) -} - -/// Given a certificate chain and an exporter (session key material), build the quote input value -/// SHA256(pki) || exporter -pub fn compute_report_input( - cert_chain: Option<&[CertificateDer<'_>]>, - exporter: [u8; 32], -) -> Result<[u8; 64], AttestationError> { - let mut quote_input = [0u8; 64]; - if let Some(cert_chain) = cert_chain { - let pki_hash = get_pki_hash_from_certificate_chain(cert_chain)?; - quote_input[..32].copy_from_slice(&pki_hash); - } - quote_input[32..].copy_from_slice(&exporter); - Ok(quote_input) -} - -/// Given a certificate chain, get the [Sha256] hash of the public key of the leaf certificate -fn get_pki_hash_from_certificate_chain( - cert_chain: &[CertificateDer<'_>], -) -> Result<[u8; 32], AttestationError> { - let leaf_certificate = cert_chain.first().ok_or(AttestationError::NoCertificate)?; - let (_, cert) = parse_x509_certificate(leaf_certificate.as_ref())?; - let public_key = &cert.tbs_certificate.subject_pki; - let key_bytes = public_key.subject_public_key.as_ref(); - - let mut hasher = Sha256::new(); - hasher.update(key_bytes); - Ok(hasher.finalize().into()) -} - /// An error when running a proxy client or server #[derive(Error, Debug)] pub enum ProxyError { @@ -814,10 +525,8 @@ pub enum ProxyError { OneShotRecv(#[from] oneshot::error::RecvError), #[error("Failed to send request, connection to proxy-server dropped")] MpscSend, - #[error("Serialization: {0}")] - Serialization(#[from] parity_scale_codec::Error), - #[error("Protocol negotiation failed - remote peer does not support this protocol")] - AlpnFailed, + #[error("Attested TLS: {0}")] + AttestedTls(#[from] AttestedTlsError), } impl From> for ProxyError { @@ -826,14 +535,8 @@ impl From> for ProxyError { } } -/// Given a byte array, encode its length as a 4 byte big endian u32 -fn length_prefix(input: &[u8]) -> [u8; 4] { - let len = input.len() as u32; - len.to_be_bytes() -} - /// If no port was provided, default to 443 -fn host_to_host_with_port(host: &str) -> String { +pub(crate) fn host_to_host_with_port(host: &str) -> String { if host.contains(':') { host.to_string() } else { @@ -841,20 +544,6 @@ fn host_to_host_with_port(host: &str) -> String { } } -/// Given a hostname with or without port number, create a TLS [ServerName] with just the host part -fn server_name_from_host( - host: &str, -) -> Result, tokio_rustls::rustls::pki_types::InvalidDnsNameError> { - // If host contains ':', try to split off the port. - let host_part = host.rsplit_once(':').map(|(h, _)| h).unwrap_or(host); - - // If the host is an IPv6 literal in brackets like "[::1]:443", - // remove the brackets for SNI (SNI allows bare IPv6 too). - let host_part = host_part.trim_matches(|c| c == '[' || c == ']'); - - ServerName::try_from(host_part.to_string()) -} - /// An Executor for hyper that uses the tokio runtime #[derive(Clone)] struct TokioExecutor; @@ -875,8 +564,11 @@ where mod tests { use std::collections::HashMap; - use crate::attestation::measurements::{ - DcapMeasurementRegister, MeasurementPolicy, MeasurementRecord, MultiMeasurements, + use crate::{ + attestation::measurements::{ + DcapMeasurementRegister, MeasurementPolicy, MeasurementRecord, MultiMeasurements, + }, + attested_tls::get_tls_cert_with_config, }; use super::*; @@ -1239,7 +931,7 @@ mod tests { }); let retrieved_chain = get_tls_cert_with_config( - proxy_server_addr.to_string(), + &proxy_server_addr.to_string(), AttestationVerifier::mock(), client_config, ) @@ -1287,7 +979,9 @@ mod tests { assert!(matches!( proxy_client_result.unwrap_err(), - ProxyError::Attestation(AttestationError::AttestationTypeNotAccepted) + ProxyError::AttestedTls(AttestedTlsError::Attestation( + AttestationError::AttestationTypeNotAccepted + )) )); } @@ -1346,7 +1040,9 @@ mod tests { assert!(matches!( proxy_client_result.unwrap_err(), - ProxyError::Attestation(AttestationError::MeasurementsNotAccepted) + ProxyError::AttestedTls(AttestedTlsError::Attestation( + AttestationError::MeasurementsNotAccepted + )) )); } } diff --git a/src/main.rs b/src/main.rs index d9cb5b2..9bdff8f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,8 +8,9 @@ use tracing::level_filters::LevelFilter; use attested_tls_proxy::{ attestation::{measurements::MeasurementPolicy, AttestationType, AttestationVerifier}, attested_get::attested_get, + attested_tls::{get_tls_cert, TlsCertAndKey}, file_server::attested_file_server, - get_tls_cert, health_check, AttestationGenerator, ProxyClient, ProxyServer, TlsCertAndKey, + health_check, AttestationGenerator, ProxyClient, ProxyServer, }; #[derive(Parser, Debug, Clone)] diff --git a/src/test_helpers.rs b/src/test_helpers.rs index b783dff..c7df30e 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -13,7 +13,8 @@ use tokio_rustls::rustls::{ use crate::{ attestation::measurements::{DcapMeasurementRegister, MultiMeasurements}, - MEASUREMENT_HEADER, SUPPORTED_ALPN_PROTOCOL_VERSIONS, + attested_tls::SUPPORTED_ALPN_PROTOCOL_VERSIONS, + MEASUREMENT_HEADER, }; /// Helper to generate a self-signed certificate for testing