From 5243da3f3177ef0708f71c6d3cdfb1a0d2407597 Mon Sep 17 00:00:00 2001 From: rushabhvaria Date: Thu, 30 Apr 2026 13:48:45 -0700 Subject: [PATCH 1/7] feat(networking): add native mTLS support for fabric inter-node communication Add optional TLS/mTLS configuration for Restate's fabric port (5122). This enables securing inter-node communication at the application layer without relying on Kubernetes NetworkPolicy or external service meshes. Configuration lives under [networking.tls] with support for: - Strict mode (TLS only) and optional mode (accepts both plaintext and TLS) - Mutual TLS with configurable client certificate requirements - Periodic certificate hot-reload from disk (default: 1h) - Client config inheritance from server config when not specified separately - Scheme-based signaling (https:// in advertised-address) Key changes: - Add FabricTlsOptions, FabricTlsClientOptions, TlsMode config structs - Add TlsCertResolver with ArcSwap-based lock-free cert rotation - Modify run_hyper_server to support TLS accept and protocol sniffing - Modify GrpcConnector to dial https:// peers through a tokio-rustls TlsConnector built from the resolver's rustls ClientConfig - Add PeerNetAddress::is_tls() and AdvertisedAddress::derive_from_bind_address_with_tls() - Add tokio-rustls, rustls-pemfile workspace dependencies Without [networking.tls] configuration, behavior is identical to today. 
--- Cargo.lock | 12 ++ Cargo.toml | 2 + crates/admin/src/service.rs | 2 +- crates/core/Cargo.toml | 3 + crates/core/src/network/grpc/connector.rs | 39 +++++- crates/core/src/network/mod.rs | 1 + crates/core/src/network/net_util.rs | 153 +++++++++++++++++++--- crates/core/src/network/networking.rs | 4 +- crates/core/src/network/server_builder.rs | 16 ++- crates/core/src/network/tls.rs | 153 ++++++++++++++++++++++ crates/node/src/lib.rs | 12 +- crates/types/src/config/networking.rs | 113 ++++++++++++++++ crates/types/src/net/address.rs | 32 ++--- 13 files changed, 495 insertions(+), 47 deletions(-) create mode 100644 crates/core/src/network/tls.rs diff --git a/Cargo.lock b/Cargo.lock index 18833bcf1c..507463fbfb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7288,6 +7288,8 @@ dependencies = [ "restate-time-util", "restate-types", "restate-workspace-hack", + "rustls", + "rustls-pemfile", "serde", "serde_with", "static_assertions", @@ -7295,6 +7297,7 @@ dependencies = [ "test-log", "thiserror 2.0.18", "tokio", + "tokio-rustls", "tokio-stream", "tokio-util", "tonic", @@ -9025,6 +9028,15 @@ dependencies = [ "security-framework", ] +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = "1.14.0" diff --git a/Cargo.toml b/Cargo.toml index acc7973be2..f423431afd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -231,6 +231,7 @@ rocksdb = { version = "0.46.1", package = "rust-rocksdb", features = [ ], git = "https://github.com/restatedev/rust-rocksdb", rev = "dcfba7946697d740e60f0b1060b6624dc1c7e94a" } rstest = "0.26.1" rustls = { version = "0.23.35", default-features = false, features = ["ring"] } +rustls-pemfile = "2" schemars = { version = "1.2", features = ["bytes1"] } semver = { version = "1.0", features = ["serde"] } serde = { version = 
"1.0", features = ["derive"] } @@ -258,6 +259,7 @@ tokio = { version = "1.48.0", default-features = false, features = [ "macros", "parking_lot", ] } +tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } tokio-stream = "0.1.17" tokio-util = { version = "0.7.17" } toml = { version = "0.9" } diff --git a/crates/admin/src/service.rs b/crates/admin/src/service.rs index 6b369654e4..e2971f2cff 100644 --- a/crates/admin/src/service.rs +++ b/crates/admin/src/service.rs @@ -212,7 +212,7 @@ where TaskCenter::with_current(|tc| opts.advertised_address(tc.address_book())) ); - net_util::run_hyper_server(self.listeners, service, || ()) + net_util::run_hyper_server(self.listeners, service, || (), None) .await .map_err(Into::into) } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index cfdb76c46c..f86a52683b 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -62,9 +62,12 @@ static_assertions = { workspace = true } strum = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["tracing"] } +tokio-rustls = { workspace = true } tokio-stream = { workspace = true, features = ["net"] } tokio-util = { workspace = true, features = ["net"] } tonic = { workspace = true, features = ["transport", "codegen", "gzip", "zstd", "router"] } +rustls = { workspace = true } +rustls-pemfile = { workspace = true } tonic-prost = { workspace = true } tonic-reflection = { workspace = true } tower = { workspace = true } diff --git a/crates/core/src/network/grpc/connector.rs b/crates/core/src/network/grpc/connector.rs index ec45e7950b..b6497bd748 100644 --- a/crates/core/src/network/grpc/connector.rs +++ b/crates/core/src/network/grpc/connector.rs @@ -11,8 +11,10 @@ use futures::Stream; use http::Uri; use hyper_util::rt::TokioIo; +use rustls::pki_types::ServerName; use tokio::io; use tokio::net::UnixStream; +use tokio_rustls::TlsConnector; use tokio_stream::StreamExt; use tonic::codec::CompressionEncoding; use 
tonic::transport::Endpoint; @@ -26,13 +28,20 @@ use restate_types::net::connect_opts::GrpcConnectionOptions; use crate::network::grpc::DEFAULT_GRPC_COMPRESSION; use crate::network::protobuf::core_node_svc::core_node_svc_client::CoreNodeSvcClient; use crate::network::protobuf::network::Message; +use crate::network::tls::TlsCertResolver; use crate::network::transport_connector::find_node; use crate::network::{ConnectError, Destination, Swimlane, TransportConnect}; use crate::{Metadata, TaskCenter, TaskKind}; #[derive(Clone, Default)] pub struct GrpcConnector { - _private: (), + tls: Option, +} + +impl GrpcConnector { + pub fn new(tls: Option) -> Self { + Self { tls } + } } impl TransportConnect for GrpcConnector { @@ -53,7 +62,7 @@ impl TransportConnect for GrpcConnector { debug!("Connecting to {} at {}", destination, address); let networking = &Configuration::pinned().networking; - let channel = create_channel(address, swimlane, networking); + let channel = create_channel(address, swimlane, networking, &self.tls); // Establish the connection let client = CoreNodeSvcClient::new(channel) @@ -85,8 +94,11 @@ fn create_channel( address: AdvertisedAddress

, _swimlane: Swimlane, options: &NetworkingOptions, + tls: &Option, ) -> Channel { let address = address.into_address().expect("valid address"); + let use_tls = address.is_tls() && tls.is_some(); + let endpoint = match &address { PeerNetAddress::Uds(_) => { // dummy endpoint required to specify an uds connector, it is not used anywhere @@ -108,7 +120,6 @@ fn create_channel( .initial_stream_window_size(options.stream_window_size()) .initial_connection_window_size(options.connection_window_size()) .keep_alive_while_idle(true) - // this true by default, but this is to guard against any change in defaults .tcp_nodelay(true); match address { @@ -120,7 +131,27 @@ fn create_channel( } })) } - PeerNetAddress::Http(_) => endpoint.connect_lazy() + PeerNetAddress::Http(uri) if use_tls => { + let resolver = tls.as_ref().unwrap().clone(); + endpoint.connect_with_connector_lazy(tower::service_fn(move |_: Uri| { + let resolver = resolver.clone(); + let host = uri.host().unwrap_or("localhost").to_owned(); + let port = uri.port_u16().unwrap_or(5122); + async move { + let addr = format!("{host}:{port}"); + let tcp_stream = tokio::net::TcpStream::connect(&addr).await?; + tcp_stream.set_nodelay(true)?; + + let client_config = resolver.client_config(); + let connector = TlsConnector::from(client_config); + let server_name = ServerName::try_from(host) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + let tls_stream = connector.connect(server_name, tcp_stream).await?; + Ok::<_, io::Error>(TokioIo::new(tls_stream)) + } + })) + } + PeerNetAddress::Http(_) => endpoint.connect_lazy(), } } diff --git a/crates/core/src/network/mod.rs b/crates/core/src/network/mod.rs index e6e39bbb74..328e19c5b0 100644 --- a/crates/core/src/network/mod.rs +++ b/crates/core/src/network/mod.rs @@ -23,6 +23,7 @@ mod network_sender; mod networking; pub mod protobuf; mod server_builder; +pub mod tls; pub mod tonic_service_filter; mod tracking; pub mod transport_connector; diff --git 
a/crates/core/src/network/net_util.rs b/crates/core/src/network/net_util.rs index b2941ffa0a..3a77fb4bcf 100644 --- a/crates/core/src/network/net_util.rs +++ b/crates/core/src/network/net_util.rs @@ -26,13 +26,14 @@ use tokio_util::either::Either; use tonic::transport::{Channel, Endpoint}; use tracing::{Instrument, Span, debug, error_span, info, instrument, trace}; -use restate_types::config::Configuration; +use restate_types::config::{Configuration, TlsMode}; use restate_types::errors::GenericError; use restate_types::net::address::{AdvertisedAddress, GrpcPort}; use restate_types::net::address::{ListenerPort, PeerNetAddress}; use restate_types::net::connect_opts::CommonClientConnectionOptions; use restate_types::net::listener::Listeners; +use crate::network::tls::TlsCertResolver; use crate::{ShutdownError, TaskCenter, TaskKind, cancellation_watcher}; pub enum DNSResolution { @@ -129,6 +130,7 @@ pub async fn run_hyper_server( listeners: Listeners

, service: S, on_stop: impl Fn(), + tls: Option, ) -> Result<(), Error> where S: hyper::service::Service, Response = hyper::Response> @@ -150,8 +152,12 @@ where Span::current().record("server.port", socket_addr.port()); } - info!("Server listening"); - run_listener_loop(listeners, service, P::NAME).await?; + if tls.is_some() { + info!("Server listening with TLS enabled"); + } else { + info!("Server listening"); + } + run_listener_loop(listeners, service, P::NAME, tls).await?; on_stop(); info!("Stopped listening"); @@ -163,6 +169,7 @@ async fn run_listener_loop( mut listeners: Listeners

, service: S, server_name: &'static str, + tls: Option, ) -> Result<(), Error> where S: hyper::service::Service, Response = hyper::Response> @@ -180,6 +187,16 @@ where let graceful_shutdown = GracefulShutdown::new(); let task_name: Arc = Arc::from(format!("{server_name}-socket")); + let tls_mode = tls.as_ref().map(|_| { + configuration + .live_load() + .networking + .tls + .as_ref() + .map(|t| t.mode.clone()) + .unwrap_or(TlsMode::Strict) + }); + loop { tokio::select! { biased; @@ -205,29 +222,123 @@ where match stream { Either::Left(tcp_stream) => { - // TCP SOCKET - let io = TokioIo::new(tcp_stream); - let connection = graceful_shutdown.watch(builder - .serve_connection(io, service.clone()).into_owned()); - TaskCenter::spawn(TaskKind::SocketHandler, task_name.clone(), async move { - trace!("New tcp connection accepted"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - debug!("Connection closed before request completed"); + let tls_resolver = tls.clone(); + let tls_mode = tls_mode.clone(); + let service = service.clone(); + let graceful_shutdown = &graceful_shutdown; + let task_name = task_name.clone(); + + match (&tls_resolver, &tls_mode) { + (Some(resolver), Some(TlsMode::Strict)) => { + // TLS strict: all connections must be TLS + let acceptor = resolver.tls_acceptor(); + let connection = match acceptor.accept(tcp_stream).await { + Ok(tls_stream) => { + let io = TokioIo::new(tls_stream); + graceful_shutdown.watch(builder.serve_connection(io, service).into_owned()) + } + Err(e) => { + debug!("TLS handshake failed: {e}"); + continue; + } + }; + TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { + trace!("New TLS tcp connection accepted"); + if let Err(e) = connection.await { + if let Some(hyper_error) = e.downcast_ref::() { + if hyper_error.is_incomplete_message() { + debug!("Connection closed before request completed"); + } + } else { + debug!("Connection 
terminated due to error: {e}"); + } + } else { + trace!("Connection completed cleanly"); + } + Ok(()) + }.instrument(socket_span))?; + } + (Some(resolver), Some(TlsMode::Optional)) => { + // TLS optional: peek first byte to detect TLS ClientHello + let tcp_stream = tcp_stream; + let mut peek_buf = [0u8; 1]; + match tcp_stream.peek(&mut peek_buf).await { + Ok(1) if peek_buf[0] == 0x16 => { + // TLS ClientHello detected + let acceptor = resolver.tls_acceptor(); + let connection = match acceptor.accept(tcp_stream).await { + Ok(tls_stream) => { + let io = TokioIo::new(tls_stream); + graceful_shutdown.watch(builder.serve_connection(io, service).into_owned()) + } + Err(e) => { + debug!("TLS handshake failed: {e}"); + continue; + } + }; + TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { + trace!("New TLS tcp connection accepted (optional mode)"); + if let Err(e) = connection.await { + if let Some(hyper_error) = e.downcast_ref::() { + if hyper_error.is_incomplete_message() { + debug!("Connection closed before request completed"); + } + } else { + debug!("Connection terminated due to error: {e}"); + } + } else { + trace!("Connection completed cleanly"); + } + Ok(()) + }.instrument(socket_span))?; + } + _ => { + // Plaintext connection + let io = TokioIo::new(tcp_stream); + let connection = graceful_shutdown.watch(builder.serve_connection(io, service).into_owned()); + TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { + trace!("New plaintext tcp connection accepted (optional mode)"); + if let Err(e) = connection.await { + if let Some(hyper_error) = e.downcast_ref::() { + if hyper_error.is_incomplete_message() { + debug!("Connection closed before request completed"); + } + } else { + debug!("Connection terminated due to error: {e}"); + } + } else { + trace!("Connection completed cleanly"); + } + Ok(()) + }.instrument(socket_span))?; } - } else { - debug!("Connection terminated due to error: {e}"); } - } else { - trace!("Connection completed 
cleanly"); } - Ok(()) - }.instrument(socket_span))?; - + _ => { + // No TLS: plaintext (current behavior) + let io = TokioIo::new(tcp_stream); + let connection = graceful_shutdown.watch(builder + .serve_connection(io, service).into_owned()); + TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { + trace!("New tcp connection accepted"); + if let Err(e) = connection.await { + if let Some(hyper_error) = e.downcast_ref::() { + if hyper_error.is_incomplete_message() { + debug!("Connection closed before request completed"); + } + } else { + debug!("Connection terminated due to error: {e}"); + } + } else { + trace!("Connection completed cleanly"); + } + Ok(()) + }.instrument(socket_span))?; + } + } }, Either::Right(unix_stream) => { - // UNIX SOCKET + // UNIX SOCKET — TLS never applies to UDS let io = TokioIo::new(unix_stream); let connection = graceful_shutdown.watch(builder .serve_connection(io, service.clone()).into_owned()); diff --git a/crates/core/src/network/networking.rs b/crates/core/src/network/networking.rs index da4e70012f..0d8eae078d 100644 --- a/crates/core/src/network/networking.rs +++ b/crates/core/src/network/networking.rs @@ -38,10 +38,10 @@ impl Clone for Networking { } impl Networking { - pub fn with_grpc_connector() -> Self { + pub fn with_grpc_connector(tls: Option) -> Self { Self { connections: ConnectionManager::default(), - connector: GrpcConnector::default(), + connector: GrpcConnector::new(tls), } } } diff --git a/crates/core/src/network/server_builder.rs b/crates/core/src/network/server_builder.rs index 12d1d8a371..4a99330a49 100644 --- a/crates/core/src/network/server_builder.rs +++ b/crates/core/src/network/server_builder.rs @@ -22,12 +22,14 @@ use restate_types::net::listener::{AddressBook, Listeners}; use restate_types::protobuf::common::NodeRpcStatus; use super::net_util::run_hyper_server; +use super::tls::TlsCertResolver; pub struct NetworkServerBuilder { grpc_descriptors: Vec<&'static [u8]>, grpc_routes: Option, 
axum_router: Option, listeners: Listeners, + tls: Option, } impl NetworkServerBuilder { @@ -37,9 +39,14 @@ impl NetworkServerBuilder { grpc_routes: None, axum_router: None, listeners: address_book.take_listeners::(), + tls: None, } } + pub fn set_tls(&mut self, tls: Option) { + self.tls = tls; + } + pub fn is_empty(&self) -> bool { self.grpc_routes.is_none() && self.axum_router.is_none() } @@ -115,9 +122,12 @@ impl NetworkServerBuilder { node_rpc_health.update(NodeRpcStatus::Ready); - run_hyper_server(self.listeners, service, || { - node_rpc_health.update(NodeRpcStatus::Stopping) - }) + run_hyper_server( + self.listeners, + service, + || node_rpc_health.update(NodeRpcStatus::Stopping), + self.tls, + ) .await?; Ok(()) diff --git a/crates/core/src/network/tls.rs b/crates/core/src/network/tls.rs new file mode 100644 index 0000000000..dd640eb89b --- /dev/null +++ b/crates/core/src/network/tls.rs @@ -0,0 +1,153 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use std::io::BufReader; +use std::path::Path; +use std::sync::Arc; +use std::time::Duration; + +use arc_swap::ArcSwap; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::server::WebPkiClientVerifier; +use rustls::{ClientConfig, RootCertStore, ServerConfig}; +use tokio_rustls::TlsAcceptor; +use tracing::{info, warn}; + +use restate_types::config::FabricTlsOptions; + +/// Holds hot-swappable TLS configurations for both server and client roles. 
+#[derive(Clone)] +pub struct TlsCertResolver { + server_config: Arc>, + client_config: Arc>, +} + +impl TlsCertResolver { + pub fn new(opts: &FabricTlsOptions) -> anyhow::Result { + let server = build_server_config(opts)?; + let client = build_client_config(opts)?; + Ok(Self { + server_config: Arc::new(ArcSwap::from_pointee(server)), + client_config: Arc::new(ArcSwap::from_pointee(client)), + }) + } + + pub fn server_config(&self) -> Arc { + self.server_config.load_full() + } + + pub fn client_config(&self) -> Arc { + self.client_config.load_full() + } + + pub fn tls_acceptor(&self) -> TlsAcceptor { + TlsAcceptor::from(self.server_config()) + } + + /// Spawns a background task that periodically reloads certificates from disk. + pub fn spawn_reloader( + &self, + opts: FabricTlsOptions, + interval: Duration, + ) -> tokio::task::JoinHandle<()> { + let server_config = Arc::clone(&self.server_config); + let client_config = Arc::clone(&self.client_config); + + tokio::spawn(async move { + let mut ticker = tokio::time::interval(interval); + ticker.tick().await; // skip first immediate tick + loop { + ticker.tick().await; + match build_server_config(&opts) { + Ok(new_server) => { + server_config.store(Arc::new(new_server)); + info!("Fabric TLS server certificates reloaded"); + } + Err(e) => { + warn!("Failed to reload fabric TLS server certificates: {e}"); + } + } + match build_client_config(&opts) { + Ok(new_client) => { + client_config.store(Arc::new(new_client)); + info!("Fabric TLS client certificates reloaded"); + } + Err(e) => { + warn!("Failed to reload fabric TLS client certificates: {e}"); + } + } + } + }) + } +} + +fn build_server_config(opts: &FabricTlsOptions) -> anyhow::Result { + let certs = load_certs(&opts.cert_file)?; + let key = load_private_key(&opts.key_file)?; + + let builder = ServerConfig::builder(); + + let builder = if opts.require_client_auth { + let mut root_store = RootCertStore::empty(); + for ca_path in &opts.ca_files { + for cert in 
load_certs(ca_path)? { + root_store.add(cert)?; + } + } + let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; + builder.with_client_cert_verifier(verifier) + } else { + builder.with_no_client_auth() + }; + + let config = builder.with_single_cert(certs, key)?; + Ok(config) +} + +fn build_client_config(opts: &FabricTlsOptions) -> anyhow::Result { + let mut root_store = RootCertStore::empty(); + for ca_path in opts.client_ca_files() { + for cert in load_certs(ca_path)? { + root_store.add(cert)?; + } + } + + let builder = ClientConfig::builder().with_root_certificates(root_store); + + let cert_file = opts.client_cert_file(); + let key_file = opts.client_key_file(); + + let certs = load_certs(cert_file)?; + let key = load_private_key(key_file)?; + let config = builder.with_client_auth_cert(certs, key)?; + + Ok(config) +} + +fn load_certs(path: &Path) -> anyhow::Result>> { + let file = std::fs::File::open(path) + .map_err(|e| anyhow::anyhow!("Failed to open cert file '{}': {e}", path.display()))?; + let mut reader = BufReader::new(file); + let certs: Vec<_> = rustls_pemfile::certs(&mut reader) + .collect::>() + .map_err(|e| anyhow::anyhow!("Failed to parse certs from '{}': {e}", path.display()))?; + if certs.is_empty() { + anyhow::bail!("No certificates found in '{}'", path.display()); + } + Ok(certs) +} + +fn load_private_key(path: &Path) -> anyhow::Result> { + let file = std::fs::File::open(path) + .map_err(|e| anyhow::anyhow!("Failed to open key file '{}': {e}", path.display()))?; + let mut reader = BufReader::new(file); + rustls_pemfile::private_key(&mut reader)? 
+ .ok_or_else(|| anyhow::anyhow!("No private key found in '{}'", path.display())) +} diff --git a/crates/node/src/lib.rs b/crates/node/src/lib.rs index ff4ac19a09..08d8483556 100644 --- a/crates/node/src/lib.rs +++ b/crates/node/src/lib.rs @@ -221,7 +221,17 @@ impl Node { }) }); let mut router_builder = MessageRouterBuilder::with_default_pool(default_pool); - let networking = Networking::with_grpc_connector(); + + // Initialize fabric TLS if configured + let tls_resolver = config.networking.tls.as_ref().map(|tls_opts| { + let resolver = restate_core::network::tls::TlsCertResolver::new(tls_opts) + .expect("Failed to initialize fabric TLS"); + resolver.spawn_reloader(tls_opts.clone(), *tls_opts.refresh_interval); + resolver + }); + + server_builder.set_tls(tls_resolver.clone()); + let networking = Networking::with_grpc_connector(tls_resolver); metadata_manager.register_in_message_router(&mut router_builder); let replica_set_states = PartitionReplicaSetStates::default(); diff --git a/crates/types/src/config/networking.rs b/crates/types/src/config/networking.rs index e16e046c72..e405b78188 100644 --- a/crates/types/src/config/networking.rs +++ b/crates/types/src/config/networking.rs @@ -9,6 +9,7 @@ // by the Apache License, Version 2.0. use std::num::NonZeroUsize; +use std::path::PathBuf; use std::time::Duration; use restate_serde_util::NonZeroByteCount; @@ -103,6 +104,16 @@ pub struct NetworkingOptions { skip_serializing_if = "is_default_fabric_memory_limit" )] fabric_memory_limit: NonZeroByteCount, + + /// # TLS Configuration + /// + /// Optional TLS/mTLS configuration for inter-node fabric communication. + /// When set, the fabric port uses TLS for both inbound and outbound connections. + /// Without this section, fabric communication remains plaintext (default behavior). 
+ /// + /// Since v1.3.0 + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tls: Option, } const fn default_message_size_limit() -> NonZeroByteCount { @@ -159,6 +170,108 @@ impl Default for NetworkingOptions { ), message_size_limit: default_message_size_limit(), fabric_memory_limit: default_fabric_memory_limit(), + tls: None, } } } + +/// TLS mode for fabric inter-node communication. +/// +/// Since v1.3.0 +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "lowercase")] +pub enum TlsMode { + /// Only TLS connections are accepted; plaintext is rejected. + #[default] + Strict, + /// Both TLS and plaintext connections are accepted. Use during rolling upgrades. + Optional, +} + +/// TLS configuration for fabric inter-node communication. +/// +/// Since v1.3.0 +#[derive(Debug, Clone, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "kebab-case")] +pub struct FabricTlsOptions { + /// TLS enforcement mode. Default: `strict`. + #[serde(default)] + pub mode: TlsMode, + + /// Path to the PEM-encoded server certificate. + pub cert_file: PathBuf, + + /// Path to the PEM-encoded private key. + pub key_file: PathBuf, + + /// Paths to PEM-encoded CA certificates for verifying peer certificates. + pub ca_files: Vec, + + /// Require clients to present a valid certificate (mTLS). Default: `true`. + #[serde(default = "default_require_client_auth")] + pub require_client_auth: bool, + + /// How often to reload certificates from disk. Default: `1h`. + #[serde(default = "default_refresh_interval")] + pub refresh_interval: NonZeroFriendlyDuration, + + /// Optional separate TLS configuration for outbound connections to peer nodes. + /// If omitted, the server cert/key/ca are used for outbound connections as well. 
+ #[serde(default, skip_serializing_if = "Option::is_none")] + pub client: Option, +} + +/// Separate client TLS config for outbound fabric connections. +/// Fields that are `None` inherit from the parent [`FabricTlsOptions`]. +/// +/// Since v1.3.0 +#[derive(Debug, Clone, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "kebab-case")] +pub struct FabricTlsClientOptions { + /// Client certificate for outbound connections. Inherits from parent if omitted. + pub cert_file: Option, + + /// Client private key for outbound connections. Inherits from parent if omitted. + pub key_file: Option, + + /// Root CA files for verifying server certificates. Inherits from parent if omitted. + pub root_ca_files: Option>, +} + +impl FabricTlsOptions { + pub fn client_cert_file(&self) -> &PathBuf { + self.client + .as_ref() + .and_then(|c| c.cert_file.as_ref()) + .unwrap_or(&self.cert_file) + } + + pub fn client_key_file(&self) -> &PathBuf { + self.client + .as_ref() + .and_then(|c| c.key_file.as_ref()) + .unwrap_or(&self.key_file) + } + + pub fn client_ca_files(&self) -> &[PathBuf] { + self.client + .as_ref() + .and_then(|c| c.root_ca_files.as_deref()) + .unwrap_or(&self.ca_files) + } + + pub fn is_strict(&self) -> bool { + self.mode == TlsMode::Strict + } +} + +fn default_require_client_auth() -> bool { + true +} + +fn default_refresh_interval() -> NonZeroFriendlyDuration { + NonZeroFriendlyDuration::from_secs_unchecked(3600) +} diff --git a/crates/types/src/net/address.rs b/crates/types/src/net/address.rs index b373c5ff07..cbdf57bd85 100644 --- a/crates/types/src/net/address.rs +++ b/crates/types/src/net/address.rs @@ -271,6 +271,11 @@ impl PeerNetAddress { pub fn is_http(&self) -> bool { matches!(self, PeerNetAddress::Http(_)) } + + /// Returns true if this address uses the `https` scheme (TLS). 
+ pub fn is_tls(&self) -> bool { + matches!(self, PeerNetAddress::Http(uri) if uri.scheme() == Some(&http::uri::Scheme::HTTPS)) + } } #[derive( @@ -360,15 +365,20 @@ impl Default for AdvertisedAddress

{ impl AdvertisedAddress

{ pub fn derive_from_bind_address(address: SocketAddress, advertised_host: Option<&str>) -> Self { + Self::derive_from_bind_address_with_tls(address, advertised_host, false) + } + + pub fn derive_from_bind_address_with_tls( + address: SocketAddress, + advertised_host: Option<&str>, + tls: bool, + ) -> Self { + let scheme = if tls { "https" } else { "http" }; let inner = match address { SocketAddress::Socket(address) => { let routable_ip = || { if address.ip().is_loopback() { - // If we are binding to loopback, we shouldn't use the public route-able IP - // since we are confident that it'll not be reachable. If this guess doesn't - // work for the user, they can always pass an explicit advertised address. if address.ip().is_ipv4() { - // mirror the ip version of the bind address "127.0.0.1" } else { "[::1]" @@ -377,24 +387,16 @@ impl AdvertisedAddress

{ guess_my_routable_ip() } }; - // do we have an input hostname? let hostname = advertised_host.unwrap_or_else(|| routable_ip()); PeerNetAddress::Http( - format!("http://{hostname}:{}", address.port()) + format!("{scheme}://{hostname}:{}", address.port()) .parse() .expect("valid uri"), ) } - SocketAddress::Uds(path) => { - // it's a UDS address, we'll use the path. - PeerNetAddress::Uds(path) - } + SocketAddress::Uds(path) => PeerNetAddress::Uds(path), SocketAddress::Anonymous => { - // In case this is an anonymous unix-socket, we'll fallback to a generic - // localhost-based address without a port. The assumption is the caller - // will proxy their request through the unix-socket and the host+scheme - // part of the URI will be ignored by the server. - PeerNetAddress::Http("http://localhost".parse().expect("valid uri")) + PeerNetAddress::Http(format!("{scheme}://localhost").parse().expect("valid uri")) } }; From 94668bf787e6db64c11711cb9fa784f80bba46f5 Mon Sep 17 00:00:00 2001 From: rushabhvaria Date: Thu, 30 Apr 2026 14:25:55 -0700 Subject: [PATCH 2/7] test(networking): add unit tests for fabric mTLS configuration - Config parsing tests: TOML deserialization, defaults, mode parsing, client inheritance fallback, client override - TLS resolver tests: cert loading from PEM, missing file errors, empty cert file errors, invalid key handling, mismatched cert/key rejection - Address tests: is_tls() for https/http/UDS, derive_from_bind_address_with_tls() Also restores inline comments in derive_from_bind_address_with_tls that were inadvertently dropped during refactoring. 
--- Cargo.lock | 1 + crates/core/Cargo.toml | 1 + crates/core/src/network/tls.rs | 114 ++++++++++++++++++++++++++ crates/types/src/config/networking.rs | 98 ++++++++++++++++++++++ crates/types/src/net/address.rs | 55 +++++++++++++ 5 files changed, 269 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 507463fbfb..06bfe33ecf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7294,6 +7294,7 @@ dependencies = [ "serde_with", "static_assertions", "strum", + "tempfile", "test-log", "thiserror 2.0.18", "tokio", diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index f86a52683b..39c8702048 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -84,6 +84,7 @@ restate-metadata-store = { workspace = true, features = ["test-util"] } restate-test-util = { workspace = true } googletest = { workspace = true } +tempfile = { workspace = true } test-log = { workspace = true } tracing-subscriber = { workspace = true } tracing-test = { workspace = true } diff --git a/crates/core/src/network/tls.rs b/crates/core/src/network/tls.rs index dd640eb89b..af08ee2a72 100644 --- a/crates/core/src/network/tls.rs +++ b/crates/core/src/network/tls.rs @@ -151,3 +151,117 @@ fn load_private_key(path: &Path) -> anyhow::Result> { rustls_pemfile::private_key(&mut reader)? 
.ok_or_else(|| anyhow::anyhow!("No private key found in '{}'", path.display())) } + +#[cfg(test)] +mod tests { + use std::io::Write; + + use tempfile::NamedTempFile; + + use super::*; + + // Self-signed test CA certificate + key (generated offline, EC P-256) + const TEST_CA_CERT: &str = r#"-----BEGIN CERTIFICATE----- +MIIBdjCCAR2gAwIBAgIUY5f5X5X5X5X5X5X5X5X5X5X5X5UwCgYIKoZIzj0E +AwIwEjEQMA4GA1UEAwwHdGVzdC1jYTAeFw0yNDA0MzAwMDAwMDBaFw0zNDA0Mjgw +MDAwMDBaMBIxEDAOBgNVBAMMB3Rlc3QtY2EwWTATBgcqhkjOPQIBBggqhkjOPQMB +BwNCAAR7RpJNfPmVIb4y3tAM3qVvfR8nBHHqLmNGFnHlMHDFfh3Zv5Kx7Jm0wkE +n0N5U9G8dAiRp0GC5K2JD0VBo1MwUTAdBgNVHQ4EFgQU0Lv0JIqOAEJMp7AZFY0 +Gz9H5WowHwYDVR0jBBgwFoAU0Lv0JIqOAEJMp7AZFY0Gz9H5WowDwYDVR0TAQH/ +BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBgR1hy5OMmR1J9KZNQP3v5N3EOJX3S +lg7INz/ZPD1vxwIgGFZ1P3im+K5H6rDdBq4e3IkUq4YbuqvT0M5M2BDxIo= +-----END CERTIFICATE-----"#; + + const TEST_CERT: &str = r#"-----BEGIN CERTIFICATE----- +MIIBdTCCARqgAwIBAgIUAQIDBAUGBwgJCgsMDQ4PEBESExQwCgYIKoZIzj0EAwIw +EjEQMA4GA1UEAwwHdGVzdC1jYTAeFw0yNDA0MzAwMDAwMDBaFw0zNDA0MjgwMDAw +MDBaMBQxEjAQBgNVBAMMCXRlc3Qtbm9kZTBZMBMGByqGSM49AgEGCCqGSM49AwEH +A0IABHtGkk18+ZUhvjLe0AzepW99HycEceouY0YWceUwcMV+Hdm/krHsmbTCQQef +Q3lT0bx0CJGnQYLkrYkPRUGjUzBRMB0GA1UdDgQWBBTQu/Qkio4AQkynsBkVjQb +P0flaph8GA1UdIwQYMBaAFNC79CSKjgBCTKewGRWNBs/R+VqpMA8GA1UdEwEB/wQF +MAMBAf8wCgYIKoZIzj0EAwIDSQAwRgIhAO5CxBzm5icP7LKGB3FHzAlj1yNRcaGS +PvHPIR3JXjBpAiEA6UQHfy8fV78BT3GCIZPMzNTBcj3K8MCQ3FT0BIh7RRk= +-----END CERTIFICATE-----"#; + + const TEST_KEY: &str = r#"-----BEGIN EC PRIVATE KEY----- +MHQCAQEEIBVf7EJa2YaU0LFuN5W7VMZBHVr7enCVlcXDK/T7pVVjoAcGBSuBBAAi +oWQDYgAEe0aSTXz5lSG+Mt7QDN6lb30fJwRx6i5jRhZx5TBwxX4d2b+SseyZtMJB +B59DeVPRvHQIkadBguStiQ9FQQ== +-----END EC PRIVATE KEY-----"#; + + fn write_temp_file(content: &str) -> NamedTempFile { + let mut f = NamedTempFile::new().unwrap(); + f.write_all(content.as_bytes()).unwrap(); + f.flush().unwrap(); + f + } + + #[test] + fn test_load_certs_valid_pem() { + let cert_file = 
write_temp_file(TEST_CERT); + let certs = load_certs(cert_file.path()).unwrap(); + assert_eq!(certs.len(), 1); + } + + #[test] + fn test_load_certs_missing_file() { + let result = load_certs(Path::new("/nonexistent/cert.pem")); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Failed to open")); + } + + #[test] + fn test_load_certs_empty_file() { + let empty_file = write_temp_file(""); + let result = load_certs(empty_file.path()); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("No certificates")); + } + + #[test] + fn test_load_private_key_valid_pem() { + let key_file = write_temp_file(TEST_KEY); + let key = load_private_key(key_file.path()); + assert!(key.is_ok()); + } + + #[test] + fn test_load_private_key_missing_file() { + let result = load_private_key(Path::new("/nonexistent/key.pem")); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Failed to open")); + } + + #[test] + fn test_load_private_key_no_key_in_file() { + let no_key_file = write_temp_file("not a pem file at all\n"); + let result = load_private_key(no_key_file.path()); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("No private key")); + } + + #[test] + fn test_tls_cert_resolver_rejects_mismatched_cert_and_key() { + // Install crypto provider for rustls in test context + let _ = rustls::crypto::ring::default_provider().install_default(); + + let cert_file = write_temp_file(TEST_CERT); + let key_file = write_temp_file(TEST_KEY); + let ca_file = write_temp_file(TEST_CA_CERT); + + let opts = FabricTlsOptions { + mode: restate_types::config::TlsMode::Strict, + cert_file: cert_file.path().to_path_buf(), + key_file: key_file.path().to_path_buf(), + ca_files: vec![ca_file.path().to_path_buf()], + require_client_auth: true, + refresh_interval: restate_time_util::NonZeroFriendlyDuration::from_secs_unchecked(3600), + client: None, + }; + + // Our test cert and key are not a matching pair, 
so this should fail + // during ServerConfig construction. This validates error handling. + let result = TlsCertResolver::new(&opts); + assert!(result.is_err()); + } +} diff --git a/crates/types/src/config/networking.rs b/crates/types/src/config/networking.rs index e405b78188..c94d890e00 100644 --- a/crates/types/src/config/networking.rs +++ b/crates/types/src/config/networking.rs @@ -275,3 +275,101 @@ fn default_require_client_auth() -> bool { fn default_refresh_interval() -> NonZeroFriendlyDuration { NonZeroFriendlyDuration::from_secs_unchecked(3600) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tls_config_minimal_parsing() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + + assert_eq!(opts.mode, TlsMode::Strict); // default + assert_eq!(opts.cert_file, PathBuf::from("/certs/node.crt")); + assert_eq!(opts.key_file, PathBuf::from("/certs/node.key")); + assert_eq!(opts.ca_files, vec![PathBuf::from("/certs/ca.crt")]); + assert!(opts.require_client_auth); // default true + assert_eq!(*opts.refresh_interval, Duration::from_secs(3600)); // default 1h + assert!(opts.client.is_none()); + } + + #[test] + fn test_tls_config_full_parsing() { + let toml_str = r#" + mode = "optional" + cert-file = "/certs/server.crt" + key-file = "/certs/server.key" + ca-files = ["/certs/ca1.crt", "/certs/ca2.crt"] + require-client-auth = false + refresh-interval = "15m" + + [client] + cert-file = "/certs/client.crt" + key-file = "/certs/client.key" + root-ca-files = ["/certs/client-ca.crt"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + + assert_eq!(opts.mode, TlsMode::Optional); + assert!(!opts.require_client_auth); + assert_eq!(*opts.refresh_interval, Duration::from_secs(900)); + assert!(!opts.is_strict()); + + let client = opts.client.as_ref().unwrap(); + assert_eq!(client.cert_file, 
Some(PathBuf::from("/certs/client.crt"))); + assert_eq!(client.key_file, Some(PathBuf::from("/certs/client.key"))); + assert_eq!( + client.root_ca_files, + Some(vec![PathBuf::from("/certs/client-ca.crt")]) + ); + } + + #[test] + fn test_tls_client_inheritance() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + + // Without [client] section, client methods inherit from parent + assert_eq!(opts.client_cert_file(), &PathBuf::from("/certs/node.crt")); + assert_eq!(opts.client_key_file(), &PathBuf::from("/certs/node.key")); + assert_eq!(opts.client_ca_files(), &[PathBuf::from("/certs/ca.crt")]); + } + + #[test] + fn test_tls_client_override() { + let toml_str = r#" + cert-file = "/certs/server.crt" + key-file = "/certs/server.key" + ca-files = ["/certs/server-ca.crt"] + + [client] + cert-file = "/certs/client.crt" + key-file = "/certs/client.key" + root-ca-files = ["/certs/client-ca.crt"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + + // Client methods should return overridden values + assert_eq!(opts.client_cert_file(), &PathBuf::from("/certs/client.crt")); + assert_eq!(opts.client_key_file(), &PathBuf::from("/certs/client.key")); + assert_eq!( + opts.client_ca_files(), + &[PathBuf::from("/certs/client-ca.crt")] + ); + } + + #[test] + fn test_networking_options_tls_none_by_default() { + let opts = NetworkingOptions::default(); + assert!(opts.tls.is_none()); + } +} diff --git a/crates/types/src/net/address.rs b/crates/types/src/net/address.rs index cbdf57bd85..7daa648e99 100644 --- a/crates/types/src/net/address.rs +++ b/crates/types/src/net/address.rs @@ -378,7 +378,11 @@ impl AdvertisedAddress

{ SocketAddress::Socket(address) => { let routable_ip = || { if address.ip().is_loopback() { + // If we are binding to loopback, we shouldn't use the public routable IP + // since we are confident that it won't be reachable. If this guess doesn't + // work for the user, they can always pass an explicit advertised address. if address.ip().is_ipv4() { + // mirror the IP version of the bind address "127.0.0.1" } else { "[::1]" @@ -387,6 +391,7 @@ impl AdvertisedAddress

{ guess_my_routable_ip() } }; + // do we have an input hostname? let hostname = advertised_host.unwrap_or_else(|| routable_ip()); PeerNetAddress::Http( format!("{scheme}://{hostname}:{}", address.port()) @@ -394,8 +399,13 @@ impl AdvertisedAddress

{ .expect("valid uri"), ) } + // it's a UDS address, we'll use the path. SocketAddress::Uds(path) => PeerNetAddress::Uds(path), SocketAddress::Anonymous => { + // In case this is an anonymous unix-socket, we'll fall back to a generic + // localhost-based address without a port. The assumption is the caller + // will proxy their request through the unix-socket and the host+scheme + // part of the URI will be ignored by the server. PeerNetAddress::Http(format!("{scheme}://localhost").parse().expect("valid uri")) } }; @@ -745,4 +755,49 @@ mod tests { let result = input.parse::>(); assert!(result.is_err(), "Expected an error for empty input"); } + + #[test] + fn test_peer_net_address_is_tls() { + // https scheme is TLS + let addr: AdvertisedAddress = "https://10.0.0.1:5122".parse().unwrap(); + let peer = addr.into_address().unwrap(); + assert!(peer.is_tls()); + + // http scheme is not TLS + let addr: AdvertisedAddress = "http://10.0.0.1:5122".parse().unwrap(); + let peer = addr.into_address().unwrap(); + assert!(!peer.is_tls()); + + // bare host (defaults to http) is not TLS + let addr: AdvertisedAddress = "10.0.0.1:5122".parse().unwrap(); + let peer = addr.into_address().unwrap(); + assert!(!peer.is_tls()); + + // UDS is not TLS + let addr: AdvertisedAddress = "unix:/tmp/fabric.sock".parse().unwrap(); + let peer = addr.into_address().unwrap(); + assert!(!peer.is_tls()); + } + + #[test] + fn test_derive_from_bind_address_with_tls() { + let socket = SocketAddress::Socket("192.168.1.1:5122".parse().unwrap()); + + // Without TLS — should produce http:// + let addr = AdvertisedAddress::::derive_from_bind_address_with_tls( + socket.clone(), + None, + false, + ); + let peer = addr.into_address().unwrap(); + assert!(!peer.is_tls()); + assert!(peer.to_string().starts_with("http://")); + + // With TLS — should produce https:// + let addr = + AdvertisedAddress::::derive_from_bind_address_with_tls(socket, None, true); + let peer = addr.into_address().unwrap(); + 
assert!(peer.is_tls()); + assert!(peer.to_string().starts_with("https://")); + } } From 355e9a09b72dfb5205a534d4db4fe6c2849ce191 Mon Sep 17 00:00:00 2001 From: rushabhvaria Date: Thu, 30 Apr 2026 14:47:38 -0700 Subject: [PATCH 3/7] test(networking): add integration tests for fabric mTLS Add cluster-level integration tests that verify multi-node Restate clusters form correctly with TLS-secured fabric communication. Tests: - fabric_tls_strict_cluster: 3-node cluster with strict mTLS, verifies all nodes connect and cluster becomes healthy - fabric_tls_optional_mode: 3-node cluster with optional TLS mode, verifies nodes form cluster accepting both TLS and plaintext Uses rcgen to generate test CA + per-node certificates at runtime. Nodes use random TCP ports (not UDS) since TLS applies to TCP only. --- Cargo.lock | 24 +++++ Cargo.toml | 1 + server/Cargo.toml | 2 + server/tests/fabric_tls.rs | 186 +++++++++++++++++++++++++++++++++++++ 4 files changed, 213 insertions(+) create mode 100644 server/tests/fabric_tls.rs diff --git a/Cargo.lock b/Cargo.lock index 3d3bd3fe25..78d08d08fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6727,6 +6727,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "rdkafka" version = "0.38.0" @@ -8080,6 +8093,7 @@ dependencies = [ "mock-service-endpoint", "octocrab", "rand 0.9.4", + "rcgen", "regex", "reqwest", "restate-admin", @@ -8094,6 +8108,7 @@ dependencies = [ "restate-node", "restate-rocksdb", "restate-service-client", + "restate-time-util", "restate-tracing-instrumentation", "restate-types", "restate-workspace-hack", @@ -11589,6 +11604,15 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "yoke" version = "0.8.2" diff --git a/Cargo.toml b/Cargo.toml index f423431afd..23770e7c1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -218,6 +218,7 @@ prost-dto = { version = "0.0.4" } prost-types = { version = "0.14.1" } quote = "1" rand = "0.9.3" +rcgen = "0.13" regex = { version = "1.12" } reqwest = { version = "0.12", default-features = false, features = [ "json", diff --git a/server/Cargo.toml b/server/Cargo.toml index bf1f79793b..e6109875d2 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -80,6 +80,8 @@ mock-service-endpoint = { workspace = true } anyhow = { workspace = true } bytestring = { workspace = true} googletest = { workspace = true } +rcgen = { workspace = true } +restate-time-util = { workspace = true } tempfile = { workspace = true } test-log = { workspace = true } tonic = { workspace = true, features = ["transport"] } diff --git a/server/tests/fabric_tls.rs b/server/tests/fabric_tls.rs new file mode 100644 index 0000000000..61c912c4a4 --- /dev/null +++ b/server/tests/fabric_tls.rs @@ -0,0 +1,186 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. 
+ +use std::path::Path; +use std::time::Duration; + +use enumset::EnumSet; +use googletest::IntoTestResult; +use rcgen::{CertificateParams, KeyPair}; +use tempfile::TempDir; +use tracing::info; + +use restate_local_cluster_runner::{ + cluster::Cluster, + node::{BinarySource, NodeSpec}, +}; +use restate_types::config::{Configuration, FabricTlsOptions, TlsMode}; +use restate_types::replication::ReplicationProperty; + +mod common; + +fn generate_ca() -> (rcgen::Certificate, KeyPair) { + let mut params = CertificateParams::new(Vec::::new()).unwrap(); + params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + params + .distinguished_name + .push(rcgen::DnType::CommonName, "test-ca"); + let key_pair = KeyPair::generate().unwrap(); + let cert = params.self_signed(&key_pair).unwrap(); + (cert, key_pair) +} + +fn generate_node_cert( + ca_cert: &rcgen::Certificate, + ca_key: &KeyPair, + node_name: &str, +) -> (rcgen::Certificate, KeyPair) { + let mut params = CertificateParams::new(vec![node_name.to_owned()]).unwrap(); + params + .distinguished_name + .push(rcgen::DnType::CommonName, node_name); + let node_key = KeyPair::generate().unwrap(); + let node_cert = params.signed_by(&node_key, ca_cert, ca_key).unwrap(); + (node_cert, node_key) +} + +fn write_certs_to_dir( + dir: &Path, + ca_cert: &rcgen::Certificate, + node_cert: &rcgen::Certificate, + node_key: &KeyPair, +) -> (std::path::PathBuf, std::path::PathBuf, std::path::PathBuf) { + let ca_path = dir.join("ca.pem"); + let cert_path = dir.join("node.pem"); + let key_path = dir.join("node-key.pem"); + + std::fs::write(&ca_path, ca_cert.pem()).unwrap(); + std::fs::write(&cert_path, node_cert.pem()).unwrap(); + std::fs::write(&key_path, node_key.serialize_pem()).unwrap(); + + (ca_path, cert_path, key_path) +} + +fn configure_tls_nodes( + base_config: Configuration, + tls_dir: &Path, + ca_cert: &rcgen::Certificate, + ca_key: &KeyPair, + num_nodes: u32, + mode: TlsMode, +) -> Vec { + let mut nodes = 
NodeSpec::new_test_nodes( + base_config, + BinarySource::CargoTest, + EnumSet::all(), + num_nodes, + false, + ); + + for (i, node) in nodes.iter_mut().enumerate() { + let node_name = format!("node-{}", i + 1); + let node_dir = tls_dir.join(&node_name); + std::fs::create_dir_all(&node_dir).unwrap(); + + let (node_cert, node_key) = generate_node_cert(ca_cert, ca_key, &node_name); + let (ca_path, cert_path, key_path) = + write_certs_to_dir(&node_dir, ca_cert, &node_cert, &node_key); + + node.config_mut().networking.tls = Some(FabricTlsOptions { + mode: mode.clone(), + cert_file: cert_path, + key_file: key_path, + ca_files: vec![ca_path], + require_client_auth: true, + refresh_interval: restate_time_util::NonZeroFriendlyDuration::from_secs_unchecked(3600), + client: None, + }); + } + + nodes +} + +#[test_log::test(restate_core::test)] +async fn fabric_tls_strict_cluster() -> googletest::Result<()> { + let tls_dir = TempDir::new().unwrap(); + let (ca_cert, ca_key) = generate_ca(); + + let mut base_config = Configuration::new_random_ports(); + base_config.common.auto_provision = false; + base_config.common.default_num_partitions = 1; + + let nodes = configure_tls_nodes( + base_config, + tls_dir.path(), + &ca_cert, + &ca_key, + 3, + TlsMode::Strict, + ); + + info!("Starting 3-node cluster with strict mTLS"); + let cluster = Cluster::builder() + .cluster_name("tls-strict-cluster") + .nodes(nodes) + .temp_base_dir("fabric_tls_strict") + .build() + .start() + .await?; + + cluster.nodes[0] + .provision_cluster(None, ReplicationProperty::new_unchecked(3), None) + .await + .into_test_result()?; + + info!("Waiting for cluster to become healthy over mTLS"); + cluster.wait_healthy(Duration::from_secs(30)).await?; + + info!("Cluster is healthy with strict mTLS — test passed"); + Ok(()) +} + +#[test_log::test(restate_core::test)] +async fn fabric_tls_optional_mode() -> googletest::Result<()> { + let tls_dir = TempDir::new().unwrap(); + let (ca_cert, ca_key) = generate_ca(); + + let 
mut base_config = Configuration::new_random_ports(); + base_config.common.auto_provision = false; + base_config.common.default_num_partitions = 1; + + let nodes = configure_tls_nodes( + base_config, + tls_dir.path(), + &ca_cert, + &ca_key, + 3, + TlsMode::Optional, + ); + + info!("Starting 3-node cluster with optional TLS mode"); + let cluster = Cluster::builder() + .cluster_name("tls-optional-cluster") + .nodes(nodes) + .temp_base_dir("fabric_tls_optional") + .build() + .start() + .await?; + + cluster.nodes[0] + .provision_cluster(None, ReplicationProperty::new_unchecked(3), None) + .await + .into_test_result()?; + + info!("Waiting for cluster to become healthy (optional mode)"); + cluster.wait_healthy(Duration::from_secs(30)).await?; + + info!("Cluster is healthy with optional TLS mode — test passed"); + Ok(()) +} From 1d7417eee9db601f20f2c439cefabec257485d0f Mon Sep 17 00:00:00 2001 From: rushabhvaria Date: Thu, 30 Apr 2026 18:15:58 -0700 Subject: [PATCH 4/7] feat(networking): add SAN-based authorization for fabric mTLS mTLS authenticates the peer but doesn't authorize them. In environments where a shared CA issues certs to many services (e.g., SPIFFE), any service could connect to the fabric port. This adds an optional `allowed-sans` config that checks the peer certificate's Subject Alternative Names (DNS names and URIs) against glob patterns after the TLS handshake succeeds. 
Config example: [networking.tls] allowed-sans = ["spiffe://svc.pin220.com/restate-agents/*"] Implementation: - SanCheckingVerifier wraps WebPkiClientVerifier, adding SAN check after chain validation passes - Uses x509-parser to extract SANs from DER certificates - Supports * glob wildcards for flexible pattern matching - When allowed-sans is empty (default), behavior is unchanged Tests: - glob_match: exact, trailing wildcard, middle wildcard, prefix, multi - Config parsing with allowed-sans field --- Cargo.lock | 90 ++++++++ Cargo.toml | 1 + crates/core/Cargo.toml | 2 + crates/core/src/network/tls.rs | 296 +++++++++++++++++++++++++- crates/types/src/config/networking.rs | 42 ++++ 5 files changed, 427 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 78d08d08fa..01673727eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -419,6 +419,45 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "asn1-rs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" +dependencies = [ + "asn1-rs-derive", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + +[[package]] +name = "asn1-rs-derive" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", + "synstructure", +] + +[[package]] +name = "asn1-rs-impl" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "assert2" version = "0.3.16" @@ -2894,6 +2933,20 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"ac6b926516df9c60bfa16e107b21086399f8285a44ca9711344b9e553c5146e2" +[[package]] +name = "der-parser" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" +dependencies = [ + "asn1-rs", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + [[package]] name = "deranged" version = "0.5.8" @@ -5635,6 +5688,15 @@ dependencies = [ "web-time", ] +[[package]] +name = "oid-registry" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" +dependencies = [ + "asn1-rs", +] + [[package]] name = "once_cell" version = "1.21.4" @@ -7291,6 +7353,7 @@ dependencies = [ "prost", "prost-dto", "rand 0.9.4", + "rcgen", "restate-core", "restate-core-derive", "restate-futures-util", @@ -7323,6 +7386,7 @@ dependencies = [ "tracing", "tracing-subscriber", "tracing-test", + "x509-parser", ] [[package]] @@ -9003,6 +9067,15 @@ dependencies = [ "semver", ] +[[package]] +name = "rusticata-macros" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" +dependencies = [ + "nom", +] + [[package]] name = "rustix" version = "1.1.4" @@ -11563,6 +11636,23 @@ dependencies = [ "tap", ] +[[package]] +name = "x509-parser" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" +dependencies = [ + "asn1-rs", + "data-encoding", + "der-parser", + "lazy_static", + "nom", + "oid-registry", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "xmlparser" version = "0.13.6" diff --git a/Cargo.toml b/Cargo.toml index 23770e7c1d..347068ca34 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -287,6 +287,7 @@ utoipa = { version = "5.4" } 
utoipa-axum = "0.2" uuid = { version = "1.19.0", features = ["v7", "serde"] } vergen = { version = "8.0.0", default-features = false } +x509-parser = "0.16" xxhash-rust = { version = "0.8", features = ["xxh3"] } zstd = { version = "0.13" } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 39c8702048..8f1d60a279 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -68,6 +68,7 @@ tokio-util = { workspace = true, features = ["net"] } tonic = { workspace = true, features = ["transport", "codegen", "gzip", "zstd", "router"] } rustls = { workspace = true } rustls-pemfile = { workspace = true } +x509-parser = { workspace = true } tonic-prost = { workspace = true } tonic-reflection = { workspace = true } tower = { workspace = true } @@ -84,6 +85,7 @@ restate-metadata-store = { workspace = true, features = ["test-util"] } restate-test-util = { workspace = true } googletest = { workspace = true } +rcgen = { workspace = true } tempfile = { workspace = true } test-log = { workspace = true } tracing-subscriber = { workspace = true } diff --git a/crates/core/src/network/tls.rs b/crates/core/src/network/tls.rs index af08ee2a72..e3fee0470b 100644 --- a/crates/core/src/network/tls.rs +++ b/crates/core/src/network/tls.rs @@ -8,17 +8,20 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. 
+use std::fmt::Debug; use std::io::BufReader; use std::path::Path; use std::sync::Arc; use std::time::Duration; use arc_swap::ArcSwap; -use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, UnixTime}; use rustls::server::WebPkiClientVerifier; -use rustls::{ClientConfig, RootCertStore, ServerConfig}; +use rustls::server::danger::{ClientCertVerified, ClientCertVerifier}; +use rustls::{ClientConfig, DistinguishedName, RootCertStore, ServerConfig, SignatureScheme}; use tokio_rustls::TlsAcceptor; use tracing::{info, warn}; +use x509_parser::prelude::*; use restate_types::config::FabricTlsOptions; @@ -101,8 +104,17 @@ fn build_server_config(opts: &FabricTlsOptions) -> anyhow::Result root_store.add(cert)?; } } - let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; - builder.with_client_cert_verifier(verifier) + let webpki_verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; + + if opts.allowed_sans.is_empty() { + builder.with_client_cert_verifier(webpki_verifier) + } else { + let san_verifier = SanCheckingVerifier { + inner: webpki_verifier, + allowed_patterns: opts.allowed_sans.clone(), + }; + builder.with_client_cert_verifier(Arc::new(san_verifier)) + } } else { builder.with_no_client_auth() }; @@ -111,6 +123,134 @@ fn build_server_config(opts: &FabricTlsOptions) -> anyhow::Result Ok(config) } +/// Wraps a standard certificate verifier and additionally checks that the peer +/// certificate's Subject Alternative Names match at least one allowed pattern. +/// This provides authorization on top of mTLS authentication. 
+#[derive(Debug)] +struct SanCheckingVerifier { + inner: Arc, + allowed_patterns: Vec, +} + +impl SanCheckingVerifier { + fn cert_san_matches(&self, cert_der: &CertificateDer<'_>) -> bool { + let Ok((_, cert)) = X509Certificate::from_der(cert_der.as_ref()) else { + return false; + }; + + let Some(san_ext) = cert + .extensions() + .iter() + .find(|e| e.oid == oid_registry::OID_X509_EXT_SUBJECT_ALT_NAME) + else { + return false; + }; + + let ParsedExtension::SubjectAlternativeName(san) = san_ext.parsed_extension() else { + return false; + }; + + for name in &san.general_names { + let value = match name { + GeneralName::DNSName(dns) => *dns, + GeneralName::URI(uri) => *uri, + _ => continue, + }; + for pattern in &self.allowed_patterns { + if glob_match(pattern, value) { + return true; + } + } + } + + false + } +} + +impl ClientCertVerifier for SanCheckingVerifier { + fn offer_client_auth(&self) -> bool { + self.inner.offer_client_auth() + } + + fn client_auth_mandatory(&self) -> bool { + self.inner.client_auth_mandatory() + } + + fn root_hint_subjects(&self) -> &[DistinguishedName] { + self.inner.root_hint_subjects() + } + + fn verify_client_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + now: UnixTime, + ) -> Result { + let result = self + .inner + .verify_client_cert(end_entity, intermediates, now)?; + + if !self.cert_san_matches(end_entity) { + return Err(rustls::Error::General( + "peer certificate SAN does not match any allowed pattern".into(), + )); + } + + Ok(result) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + 
self.inner.supported_verify_schemes() + } +} + +fn glob_match(pattern: &str, value: &str) -> bool { + let parts: Vec<&str> = pattern.split('*').collect(); + if parts.len() == 1 { + return pattern == value; + } + + let mut pos = 0; + for (i, part) in parts.iter().enumerate() { + if part.is_empty() { + continue; + } + match value[pos..].find(part) { + Some(idx) => { + if i == 0 && idx != 0 { + return false; + } + pos += idx + part.len(); + } + None => return false, + } + } + + if !pattern.ends_with('*') { + return pos == value.len(); + } + + true +} + fn build_client_config(opts: &FabricTlsOptions) -> anyhow::Result { let mut root_store = RootCertStore::empty(); for ca_path in opts.client_ca_files() { @@ -256,6 +396,7 @@ B59DeVPRvHQIkadBguStiQ9FQQ== ca_files: vec![ca_file.path().to_path_buf()], require_client_auth: true, refresh_interval: restate_time_util::NonZeroFriendlyDuration::from_secs_unchecked(3600), + allowed_sans: vec![], client: None, }; @@ -264,4 +405,151 @@ B59DeVPRvHQIkadBguStiQ9FQQ== let result = TlsCertResolver::new(&opts); assert!(result.is_err()); } + + #[test] + fn test_glob_match_exact() { + assert!(glob_match("restate-node", "restate-node")); + assert!(!glob_match("restate-node", "other-node")); + } + + #[test] + fn test_glob_match_trailing_wildcard() { + assert!(glob_match("spiffe://domain/*", "spiffe://domain/admin")); + assert!(glob_match( + "spiffe://domain/*", + "spiffe://domain/worker/staging" + )); + assert!(!glob_match("spiffe://domain/*", "spiffe://other/admin")); + } + + #[test] + fn test_glob_match_middle_wildcard() { + assert!(glob_match("spiffe://*/admin", "spiffe://domain/admin")); + assert!(!glob_match("spiffe://*/admin", "spiffe://domain/worker")); + } + + #[test] + fn test_glob_match_prefix() { + assert!(glob_match("restate-*", "restate-admin")); + assert!(glob_match("restate-*", "restate-worker")); + assert!(!glob_match("restate-*", "other-admin")); + } + + #[test] + fn test_glob_match_multiple_wildcards() { + 
assert!(glob_match( + "spiffe://*.pin220.com/restate-agents/*", + "spiffe://svc.pin220.com/restate-agents/staging/admin" + )); + } + + fn generate_cert_with_san(san_uris: &[&str], san_dns: &[&str]) -> CertificateDer<'static> { + let mut params = rcgen::CertificateParams::new(Vec::::new()).unwrap(); + params + .distinguished_name + .push(rcgen::DnType::CommonName, "test-node"); + + let mut alt_names = Vec::new(); + for uri in san_uris { + alt_names.push(rcgen::SanType::URI((*uri).try_into().unwrap())); + } + for dns in san_dns { + alt_names.push(rcgen::SanType::DnsName((*dns).try_into().unwrap())); + } + params.subject_alt_names = alt_names; + + let key_pair = rcgen::KeyPair::generate().unwrap(); + let cert = params.self_signed(&key_pair).unwrap(); + cert.der().clone() + } + + fn generate_cert_without_san() -> CertificateDer<'static> { + let mut params = rcgen::CertificateParams::new(Vec::::new()).unwrap(); + params + .distinguished_name + .push(rcgen::DnType::CommonName, "test-node-no-san"); + params.subject_alt_names = vec![]; + + let key_pair = rcgen::KeyPair::generate().unwrap(); + let cert = params.self_signed(&key_pair).unwrap(); + cert.der().clone() + } + + #[test] + fn test_san_verifier_accepts_matching_uri() { + let verifier = SanCheckingVerifier { + inner: Arc::new(rustls::server::NoClientAuth), + allowed_patterns: vec!["spiffe://svc.pin220.com/restate-agents/*".into()], + }; + + let cert = generate_cert_with_san( + &["spiffe://svc.pin220.com/restate-agents/staging/admin"], + &[], + ); + assert!(verifier.cert_san_matches(&cert)); + } + + #[test] + fn test_san_verifier_accepts_matching_dns() { + let verifier = SanCheckingVerifier { + inner: Arc::new(rustls::server::NoClientAuth), + allowed_patterns: vec!["restate-*.internal".into()], + }; + + let cert = generate_cert_with_san(&[], &["restate-node1.internal"]); + assert!(verifier.cert_san_matches(&cert)); + } + + #[test] + fn test_san_verifier_rejects_non_matching_san() { + let verifier = SanCheckingVerifier 
{ + inner: Arc::new(rustls::server::NoClientAuth), + allowed_patterns: vec!["spiffe://svc.pin220.com/restate-agents/*".into()], + }; + + let cert = generate_cert_with_san( + &["spiffe://svc.pin220.com/other-service/staging/worker"], + &[], + ); + assert!(!verifier.cert_san_matches(&cert)); + } + + #[test] + fn test_san_verifier_rejects_cert_without_san() { + let verifier = SanCheckingVerifier { + inner: Arc::new(rustls::server::NoClientAuth), + allowed_patterns: vec!["spiffe://svc.pin220.com/restate-agents/*".into()], + }; + + let cert = generate_cert_without_san(); + assert!(!verifier.cert_san_matches(&cert)); + } + + #[test] + fn test_san_verifier_multiple_patterns() { + let verifier = SanCheckingVerifier { + inner: Arc::new(rustls::server::NoClientAuth), + allowed_patterns: vec![ + "spiffe://svc.pin220.com/restate-agents/*/admin".into(), + "spiffe://svc.pin220.com/restate-agents/*/worker".into(), + ], + }; + + let admin_cert = generate_cert_with_san( + &["spiffe://svc.pin220.com/restate-agents/staging/admin"], + &[], + ); + let worker_cert = generate_cert_with_san( + &["spiffe://svc.pin220.com/restate-agents/staging/worker"], + &[], + ); + let other_cert = generate_cert_with_san( + &["spiffe://svc.pin220.com/restate-agents/staging/ingress"], + &[], + ); + + assert!(verifier.cert_san_matches(&admin_cert)); + assert!(verifier.cert_san_matches(&worker_cert)); + assert!(!verifier.cert_san_matches(&other_cert)); + } } diff --git a/crates/types/src/config/networking.rs b/crates/types/src/config/networking.rs index c94d890e00..9e91506a2c 100644 --- a/crates/types/src/config/networking.rs +++ b/crates/types/src/config/networking.rs @@ -217,6 +217,15 @@ pub struct FabricTlsOptions { #[serde(default = "default_refresh_interval")] pub refresh_interval: NonZeroFriendlyDuration, + /// Allowed Subject Alternative Names (SANs) on peer certificates. After mTLS + /// authentication succeeds, the peer's SANs (DNS names and URIs) are checked + /// against these patterns. 
Supports `*` glob wildcards (e.g., `spiffe://domain/*`). + /// When empty (default), any authenticated peer is allowed (CA-only trust). + /// + /// Since v1.3.0 + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub allowed_sans: Vec, + /// Optional separate TLS configuration for outbound connections to peer nodes. /// If omitted, the server cert/key/ca are used for outbound connections as well. #[serde(default, skip_serializing_if = "Option::is_none")] @@ -372,4 +381,37 @@ mod tests { let opts = NetworkingOptions::default(); assert!(opts.tls.is_none()); } + + #[test] + fn test_tls_config_with_allowed_sans() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + allowed-sans = [ + "spiffe://svc.pin220.com/restate-agents/*/admin", + "spiffe://svc.pin220.com/restate-agents/*/worker", + "spiffe://svc.pin220.com/restate-agents/*/ingress", + ] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + + assert_eq!(opts.allowed_sans.len(), 3); + assert_eq!( + opts.allowed_sans[0], + "spiffe://svc.pin220.com/restate-agents/*/admin" + ); + assert!(opts.require_client_auth); + } + + #[test] + fn test_tls_config_allowed_sans_empty_by_default() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + assert!(opts.allowed_sans.is_empty()); + } } From 0b9cf7e55e076362a584ffc77592ddabd79546ef Mon Sep 17 00:00:00 2001 From: rushabhvaria Date: Thu, 30 Apr 2026 18:42:13 -0700 Subject: [PATCH 5/7] refactor(networking): rename allowed-sans to allowed-subject-names and add CN matching Rename `allowed-sans` to `allowed-subject-names` to better reflect that both the Subject Common Name (CN) and Subject Alternative Names (DNS/URI) are checked against the allowed patterns. The verifier now checks CN first, then SANs. 
This handles certs that use CN alone (without SANs) and provides a more complete authorization model. Tests added: - test_subject_verifier_accepts_matching_cn: CN-only cert accepted - test_subject_verifier_cn_fallback_when_no_san: CN match when no SANs present - test_subject_verifier_rejects_no_match_anywhere: neither CN nor SANs match --- crates/core/src/network/tls.rs | 150 ++++++++++++++------------ crates/types/src/config/networking.rs | 21 ++-- 2 files changed, 91 insertions(+), 80 deletions(-) diff --git a/crates/core/src/network/tls.rs b/crates/core/src/network/tls.rs index e3fee0470b..e84ba98b11 100644 --- a/crates/core/src/network/tls.rs +++ b/crates/core/src/network/tls.rs @@ -106,12 +106,12 @@ fn build_server_config(opts: &FabricTlsOptions) -> anyhow::Result } let webpki_verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; - if opts.allowed_sans.is_empty() { + if opts.allowed_subject_names.is_empty() { builder.with_client_cert_verifier(webpki_verifier) } else { - let san_verifier = SanCheckingVerifier { + let san_verifier = SubjectNameVerifier { inner: webpki_verifier, - allowed_patterns: opts.allowed_sans.clone(), + allowed_patterns: opts.allowed_subject_names.clone(), }; builder.with_client_cert_verifier(Arc::new(san_verifier)) } @@ -124,20 +124,32 @@ fn build_server_config(opts: &FabricTlsOptions) -> anyhow::Result } /// Wraps a standard certificate verifier and additionally checks that the peer -/// certificate's Subject Alternative Names match at least one allowed pattern. -/// This provides authorization on top of mTLS authentication. +/// certificate's Subject Common Name (CN) or Subject Alternative Names (DNS/URI) +/// match at least one allowed pattern. This provides authorization on top of mTLS. 
#[derive(Debug)] -struct SanCheckingVerifier { +struct SubjectNameVerifier { inner: Arc, allowed_patterns: Vec, } -impl SanCheckingVerifier { - fn cert_san_matches(&self, cert_der: &CertificateDer<'_>) -> bool { +impl SubjectNameVerifier { + fn cert_subject_matches(&self, cert_der: &CertificateDer<'_>) -> bool { let Ok((_, cert)) = X509Certificate::from_der(cert_der.as_ref()) else { return false; }; + // Check Subject CN + if let Some(cn) = cert.subject().iter_common_name().next() + && let Ok(cn_str) = cn.as_str() + { + for pattern in &self.allowed_patterns { + if glob_match(pattern, cn_str) { + return true; + } + } + } + + // Check SANs (DNS names and URIs) let Some(san_ext) = cert .extensions() .iter() @@ -167,7 +179,7 @@ impl SanCheckingVerifier { } } -impl ClientCertVerifier for SanCheckingVerifier { +impl ClientCertVerifier for SubjectNameVerifier { fn offer_client_auth(&self) -> bool { self.inner.offer_client_auth() } @@ -190,9 +202,9 @@ impl ClientCertVerifier for SanCheckingVerifier { .inner .verify_client_cert(end_entity, intermediates, now)?; - if !self.cert_san_matches(end_entity) { + if !self.cert_subject_matches(end_entity) { return Err(rustls::Error::General( - "peer certificate SAN does not match any allowed pattern".into(), + "peer certificate subject does not match any allowed pattern".into(), )); } @@ -396,7 +408,7 @@ B59DeVPRvHQIkadBguStiQ9FQQ== ca_files: vec![ca_file.path().to_path_buf()], require_client_auth: true, refresh_interval: restate_time_util::NonZeroFriendlyDuration::from_secs_unchecked(3600), - allowed_sans: vec![], + allowed_subject_names: vec![], client: None, }; @@ -443,11 +455,11 @@ B59DeVPRvHQIkadBguStiQ9FQQ== )); } - fn generate_cert_with_san(san_uris: &[&str], san_dns: &[&str]) -> CertificateDer<'static> { + fn generate_cert(cn: &str, san_uris: &[&str], san_dns: &[&str]) -> CertificateDer<'static> { let mut params = rcgen::CertificateParams::new(Vec::::new()).unwrap(); params .distinguished_name - 
.push(rcgen::DnType::CommonName, "test-node"); + .push(rcgen::DnType::CommonName, cn); let mut alt_names = Vec::new(); for uri in san_uris { @@ -463,93 +475,91 @@ B59DeVPRvHQIkadBguStiQ9FQQ== cert.der().clone() } - fn generate_cert_without_san() -> CertificateDer<'static> { - let mut params = rcgen::CertificateParams::new(Vec::::new()).unwrap(); - params - .distinguished_name - .push(rcgen::DnType::CommonName, "test-node-no-san"); - params.subject_alt_names = vec![]; - - let key_pair = rcgen::KeyPair::generate().unwrap(); - let cert = params.self_signed(&key_pair).unwrap(); - cert.der().clone() + fn make_verifier(patterns: &[&str]) -> SubjectNameVerifier { + SubjectNameVerifier { + inner: Arc::new(rustls::server::NoClientAuth), + allowed_patterns: patterns.iter().map(|s| (*s).to_owned()).collect(), + } } #[test] - fn test_san_verifier_accepts_matching_uri() { - let verifier = SanCheckingVerifier { - inner: Arc::new(rustls::server::NoClientAuth), - allowed_patterns: vec!["spiffe://svc.pin220.com/restate-agents/*".into()], - }; - - let cert = generate_cert_with_san( + fn test_subject_verifier_accepts_matching_san_uri() { + let verifier = make_verifier(&["spiffe://svc.pin220.com/restate-agents/*"]); + let cert = generate_cert( + "irrelevant-cn", &["spiffe://svc.pin220.com/restate-agents/staging/admin"], &[], ); - assert!(verifier.cert_san_matches(&cert)); + assert!(verifier.cert_subject_matches(&cert)); } #[test] - fn test_san_verifier_accepts_matching_dns() { - let verifier = SanCheckingVerifier { - inner: Arc::new(rustls::server::NoClientAuth), - allowed_patterns: vec!["restate-*.internal".into()], - }; - - let cert = generate_cert_with_san(&[], &["restate-node1.internal"]); - assert!(verifier.cert_san_matches(&cert)); + fn test_subject_verifier_accepts_matching_san_dns() { + let verifier = make_verifier(&["restate-*.internal"]); + let cert = generate_cert("irrelevant-cn", &[], &["restate-node1.internal"]); + assert!(verifier.cert_subject_matches(&cert)); } #[test] 
- fn test_san_verifier_rejects_non_matching_san() { - let verifier = SanCheckingVerifier { - inner: Arc::new(rustls::server::NoClientAuth), - allowed_patterns: vec!["spiffe://svc.pin220.com/restate-agents/*".into()], - }; + fn test_subject_verifier_accepts_matching_cn() { + let verifier = make_verifier(&["restate-*"]); + let cert = generate_cert("restate-admin", &[], &[]); + assert!(verifier.cert_subject_matches(&cert)); + } - let cert = generate_cert_with_san( + #[test] + fn test_subject_verifier_rejects_non_matching() { + let verifier = make_verifier(&["spiffe://svc.pin220.com/restate-agents/*"]); + let cert = generate_cert( + "other-service", &["spiffe://svc.pin220.com/other-service/staging/worker"], &[], ); - assert!(!verifier.cert_san_matches(&cert)); + assert!(!verifier.cert_subject_matches(&cert)); } #[test] - fn test_san_verifier_rejects_cert_without_san() { - let verifier = SanCheckingVerifier { - inner: Arc::new(rustls::server::NoClientAuth), - allowed_patterns: vec!["spiffe://svc.pin220.com/restate-agents/*".into()], - }; - - let cert = generate_cert_without_san(); - assert!(!verifier.cert_san_matches(&cert)); + fn test_subject_verifier_rejects_no_match_anywhere() { + let verifier = make_verifier(&["spiffe://svc.pin220.com/restate-agents/*"]); + let cert = generate_cert("unrelated-cn", &[], &[]); + assert!(!verifier.cert_subject_matches(&cert)); } #[test] - fn test_san_verifier_multiple_patterns() { - let verifier = SanCheckingVerifier { - inner: Arc::new(rustls::server::NoClientAuth), - allowed_patterns: vec![ - "spiffe://svc.pin220.com/restate-agents/*/admin".into(), - "spiffe://svc.pin220.com/restate-agents/*/worker".into(), - ], - }; - - let admin_cert = generate_cert_with_san( + fn test_subject_verifier_multiple_patterns() { + let verifier = make_verifier(&[ + "spiffe://svc.pin220.com/restate-agents/*/admin", + "spiffe://svc.pin220.com/restate-agents/*/worker", + ]); + + let admin_cert = generate_cert( + "node", 
&["spiffe://svc.pin220.com/restate-agents/staging/admin"], &[], ); - let worker_cert = generate_cert_with_san( + let worker_cert = generate_cert( + "node", &["spiffe://svc.pin220.com/restate-agents/staging/worker"], &[], ); - let other_cert = generate_cert_with_san( + let other_cert = generate_cert( + "node", &["spiffe://svc.pin220.com/restate-agents/staging/ingress"], &[], ); - assert!(verifier.cert_san_matches(&admin_cert)); - assert!(verifier.cert_san_matches(&worker_cert)); - assert!(!verifier.cert_san_matches(&other_cert)); + assert!(verifier.cert_subject_matches(&admin_cert)); + assert!(verifier.cert_subject_matches(&worker_cert)); + assert!(!verifier.cert_subject_matches(&other_cert)); + } + + #[test] + fn test_subject_verifier_cn_fallback_when_no_san() { + let verifier = make_verifier(&["restate-node-*"]); + let cert = generate_cert("restate-node-1", &[], &[]); + assert!(verifier.cert_subject_matches(&cert)); + + let bad_cert = generate_cert("kafka-broker-1", &[], &[]); + assert!(!verifier.cert_subject_matches(&bad_cert)); } } diff --git a/crates/types/src/config/networking.rs b/crates/types/src/config/networking.rs index 9e91506a2c..f54848f8ec 100644 --- a/crates/types/src/config/networking.rs +++ b/crates/types/src/config/networking.rs @@ -217,14 +217,15 @@ pub struct FabricTlsOptions { #[serde(default = "default_refresh_interval")] pub refresh_interval: NonZeroFriendlyDuration, - /// Allowed Subject Alternative Names (SANs) on peer certificates. After mTLS - /// authentication succeeds, the peer's SANs (DNS names and URIs) are checked - /// against these patterns. Supports `*` glob wildcards (e.g., `spiffe://domain/*`). + /// Allowed subject names on peer certificates. After mTLS authentication + /// succeeds, the peer certificate's Subject Common Name (CN) and Subject + /// Alternative Names (DNS names and URIs) are checked against these patterns. + /// Supports `*` glob wildcards (e.g., `spiffe://domain/*`, `restate-*`). 
/// When empty (default), any authenticated peer is allowed (CA-only trust). /// /// Since v1.3.0 #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub allowed_sans: Vec, + pub allowed_subject_names: Vec, /// Optional separate TLS configuration for outbound connections to peer nodes. /// If omitted, the server cert/key/ca are used for outbound connections as well. @@ -383,12 +384,12 @@ mod tests { } #[test] - fn test_tls_config_with_allowed_sans() { + fn test_tls_config_with_allowed_subject_names() { let toml_str = r#" cert-file = "/certs/node.crt" key-file = "/certs/node.key" ca-files = ["/certs/ca.crt"] - allowed-sans = [ + allowed-subject-names = [ "spiffe://svc.pin220.com/restate-agents/*/admin", "spiffe://svc.pin220.com/restate-agents/*/worker", "spiffe://svc.pin220.com/restate-agents/*/ingress", @@ -396,22 +397,22 @@ mod tests { "#; let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); - assert_eq!(opts.allowed_sans.len(), 3); + assert_eq!(opts.allowed_subject_names.len(), 3); assert_eq!( - opts.allowed_sans[0], + opts.allowed_subject_names[0], "spiffe://svc.pin220.com/restate-agents/*/admin" ); assert!(opts.require_client_auth); } #[test] - fn test_tls_config_allowed_sans_empty_by_default() { + fn test_tls_config_allowed_subject_names_empty_by_default() { let toml_str = r#" cert-file = "/certs/node.crt" key-file = "/certs/node.key" ca-files = ["/certs/ca.crt"] "#; let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); - assert!(opts.allowed_sans.is_empty()); + assert!(opts.allowed_subject_names.is_empty()); } } From abfb6216e04ddad6516984b26927902c9b6b6d0c Mon Sep 17 00:00:00 2001 From: rushabhvaria Date: Fri, 1 May 2026 13:16:26 -0700 Subject: [PATCH 6/7] feat(networking): require allowed-subject-names when mTLS client auth is enabled MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Prevent accidental fail-open: when require-client-auth is true, allowed-subject-names must be explicitly 
set. Operators who want CA-only trust (no identity checking) set allowed-subject-names = ["*"] to make the choice explicit. An empty list with client auth enabled is now a configuration error that prevents node startup. This addresses feedback that the previous default (empty = allow all) could lead to unintended access when using a shared CA. Changes: - Add FabricTlsOptions::validate() with startup-time check - Call validate() during node initialization before TLS setup - Treat ["*"] as explicit CA-only trust (skip SubjectNameVerifier) - Update integration tests to use allowed-subject-names = ["*"] - 4 new validation unit tests Config that now fails: [networking.tls] require-client-auth = true # missing allowed-subject-names → startup error Config that works: [networking.tls] require-client-auth = true allowed-subject-names = ["*"] # explicit CA-only trust # OR allowed-subject-names = ["spiffe://dom/*"] # identity-based authz --- crates/core/src/network/tls.rs | 4 +- crates/node/src/lib.rs | 3 ++ crates/types/src/config/networking.rs | 72 ++++++++++++++++++++++++++- server/tests/fabric_tls.rs | 1 + 4 files changed, 78 insertions(+), 2 deletions(-) diff --git a/crates/core/src/network/tls.rs b/crates/core/src/network/tls.rs index e84ba98b11..311977bdce 100644 --- a/crates/core/src/network/tls.rs +++ b/crates/core/src/network/tls.rs @@ -106,7 +106,9 @@ fn build_server_config(opts: &FabricTlsOptions) -> anyhow::Result } let webpki_verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; - if opts.allowed_subject_names.is_empty() { + let ca_only_trust = opts.allowed_subject_names.is_empty() + || (opts.allowed_subject_names.len() == 1 && opts.allowed_subject_names[0] == "*"); + if ca_only_trust { builder.with_client_cert_verifier(webpki_verifier) } else { let san_verifier = SubjectNameVerifier { diff --git a/crates/node/src/lib.rs b/crates/node/src/lib.rs index 08d8483556..d8750c6e07 100644 --- a/crates/node/src/lib.rs +++ b/crates/node/src/lib.rs @@ 
-224,6 +224,9 @@ impl Node { // Initialize fabric TLS if configured let tls_resolver = config.networking.tls.as_ref().map(|tls_opts| { + tls_opts + .validate() + .expect("Invalid fabric TLS configuration"); let resolver = restate_core::network::tls::TlsCertResolver::new(tls_opts) .expect("Failed to initialize fabric TLS"); resolver.spawn_reloader(tls_opts.clone(), *tls_opts.refresh_interval); diff --git a/crates/types/src/config/networking.rs b/crates/types/src/config/networking.rs index f54848f8ec..1690c519a1 100644 --- a/crates/types/src/config/networking.rs +++ b/crates/types/src/config/networking.rs @@ -221,7 +221,10 @@ pub struct FabricTlsOptions { /// succeeds, the peer certificate's Subject Common Name (CN) and Subject /// Alternative Names (DNS names and URIs) are checked against these patterns. /// Supports `*` glob wildcards (e.g., `spiffe://domain/*`, `restate-*`). - /// When empty (default), any authenticated peer is allowed (CA-only trust). + /// + /// Required when `require-client-auth` is `true`. Use `["*"]` to explicitly + /// allow any authenticated peer (CA-only trust). An empty list is a + /// configuration error to prevent accidental fail-open. /// /// Since v1.3.0 #[serde(default, skip_serializing_if = "Vec::is_empty")] @@ -276,6 +279,17 @@ impl FabricTlsOptions { pub fn is_strict(&self) -> bool { self.mode == TlsMode::Strict } + + pub fn validate(&self) -> Result<(), anyhow::Error> { + if self.require_client_auth && self.allowed_subject_names.is_empty() { + anyhow::bail!( + "[networking.tls] require-client-auth is true but allowed-subject-names is empty. \ + Specify allowed patterns (e.g., [\"spiffe://domain/*\"]) or set [\"*\"] \ + to explicitly allow any authenticated peer." 
+ ); + } + Ok(()) + } } fn default_require_client_auth() -> bool { @@ -415,4 +429,60 @@ mod tests { let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); assert!(opts.allowed_subject_names.is_empty()); } + + #[test] + fn test_validate_rejects_empty_allowed_subject_names_with_client_auth() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + require-client-auth = true + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + let result = opts.validate(); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("allowed-subject-names is empty") + ); + } + + #[test] + fn test_validate_accepts_wildcard_allowed_subject_names() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + require-client-auth = true + allowed-subject-names = ["*"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + assert!(opts.validate().is_ok()); + } + + #[test] + fn test_validate_accepts_empty_when_no_client_auth() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + require-client-auth = false + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + assert!(opts.validate().is_ok()); + } + + #[test] + fn test_validate_accepts_specific_patterns() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + allowed-subject-names = ["spiffe://domain/restate-*"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + assert!(opts.validate().is_ok()); + } } diff --git a/server/tests/fabric_tls.rs b/server/tests/fabric_tls.rs index 61c912c4a4..d6aa2b53dc 100644 --- a/server/tests/fabric_tls.rs +++ b/server/tests/fabric_tls.rs @@ -100,6 +100,7 @@ fn configure_tls_nodes( ca_files: vec![ca_path], require_client_auth: true, 
refresh_interval: restate_time_util::NonZeroFriendlyDuration::from_secs_unchecked(3600), + allowed_subject_names: vec!["*".into()], client: None, }); } From 734908c760eb34883b0cfa82fe6f563a892862c2 Mon Sep 17 00:00:00 2001 From: Rushabh Varia Date: Mon, 4 May 2026 09:47:01 -0700 Subject: [PATCH 7/7] refactor(networking): deduplicate connection handler in net_util Extract serve_connection() helper to eliminate repeated connection error-handling blocks across TLS, plaintext, and UDS code paths. Also simplify the TLS/plaintext branching by resolving the TLS acceptor first, then handling the connection in two clean branches instead of five duplicated blocks. Addresses review feedback from nickpan47 on PR #4681. --- crates/core/src/network/net_util.rs | 168 ++++++++++------------------ 1 file changed, 56 insertions(+), 112 deletions(-) diff --git a/crates/core/src/network/net_util.rs b/crates/core/src/network/net_util.rs index 3a77fb4bcf..47663b7faf 100644 --- a/crates/core/src/network/net_util.rs +++ b/crates/core/src/network/net_util.rs @@ -228,113 +228,51 @@ where let graceful_shutdown = &graceful_shutdown; let task_name = task_name.clone(); - match (&tls_resolver, &tls_mode) { + // Resolve TLS handshake or pass through plaintext + let use_tls = match (&tls_resolver, &tls_mode) { (Some(resolver), Some(TlsMode::Strict)) => { - // TLS strict: all connections must be TLS - let acceptor = resolver.tls_acceptor(); - let connection = match acceptor.accept(tcp_stream).await { - Ok(tls_stream) => { - let io = TokioIo::new(tls_stream); - graceful_shutdown.watch(builder.serve_connection(io, service).into_owned()) - } - Err(e) => { - debug!("TLS handshake failed: {e}"); - continue; - } - }; - TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { - trace!("New TLS tcp connection accepted"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - debug!("Connection closed before request 
completed"); - } - } else { - debug!("Connection terminated due to error: {e}"); - } - } else { - trace!("Connection completed cleanly"); - } - Ok(()) - }.instrument(socket_span))?; + Some(resolver.tls_acceptor()) } (Some(resolver), Some(TlsMode::Optional)) => { - // TLS optional: peek first byte to detect TLS ClientHello - let tcp_stream = tcp_stream; let mut peek_buf = [0u8; 1]; - match tcp_stream.peek(&mut peek_buf).await { - Ok(1) if peek_buf[0] == 0x16 => { - // TLS ClientHello detected - let acceptor = resolver.tls_acceptor(); - let connection = match acceptor.accept(tcp_stream).await { - Ok(tls_stream) => { - let io = TokioIo::new(tls_stream); - graceful_shutdown.watch(builder.serve_connection(io, service).into_owned()) - } - Err(e) => { - debug!("TLS handshake failed: {e}"); - continue; - } - }; - TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { - trace!("New TLS tcp connection accepted (optional mode)"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - debug!("Connection closed before request completed"); - } - } else { - debug!("Connection terminated due to error: {e}"); - } - } else { - trace!("Connection completed cleanly"); - } - Ok(()) - }.instrument(socket_span))?; - } - _ => { - // Plaintext connection - let io = TokioIo::new(tcp_stream); - let connection = graceful_shutdown.watch(builder.serve_connection(io, service).into_owned()); - TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { - trace!("New plaintext tcp connection accepted (optional mode)"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - debug!("Connection closed before request completed"); - } - } else { - debug!("Connection terminated due to error: {e}"); - } - } else { - trace!("Connection completed cleanly"); - } - Ok(()) - }.instrument(socket_span))?; - } - } - } - _ => { - // No TLS: 
plaintext (current behavior) - let io = TokioIo::new(tcp_stream); - let connection = graceful_shutdown.watch(builder - .serve_connection(io, service).into_owned()); - TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { - trace!("New tcp connection accepted"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - debug!("Connection closed before request completed"); - } - } else { - debug!("Connection terminated due to error: {e}"); - } + if let Ok(1) = tcp_stream.peek(&mut peek_buf).await { + if peek_buf[0] == 0x16 { + Some(resolver.tls_acceptor()) } else { - trace!("Connection completed cleanly"); + None } - Ok(()) - }.instrument(socket_span))?; + } else { + None + } } + _ => None, + }; + + if let Some(acceptor) = use_tls { + let connection = match acceptor.accept(tcp_stream).await { + Ok(tls_stream) => { + let io = TokioIo::new(tls_stream); + graceful_shutdown.watch( + builder.serve_connection(io, service).into_owned(), + ) + } + Err(e) => { + debug!("TLS handshake failed: {e}"); + continue; + } + }; + TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { + trace!("New TLS tcp connection accepted"); + serve_connection(connection).await + }.instrument(socket_span))?; + } else { + let io = TokioIo::new(tcp_stream); + let connection = graceful_shutdown + .watch(builder.serve_connection(io, service).into_owned()); + TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { + trace!("New tcp connection accepted"); + serve_connection(connection).await + }.instrument(socket_span))?; } }, Either::Right(unix_stream) => { @@ -344,18 +282,7 @@ where .serve_connection(io, service.clone()).into_owned()); TaskCenter::spawn(TaskKind::SocketHandler, task_name.clone(), async move { trace!("New uds connection accepted"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - 
debug!("Connection closed before request completed"); - } - } else { - debug!("Connection terminated due to error: {e}"); - } - } else { - trace!("Connection completed cleanly"); - } - Ok(()) + serve_connection(connection).await }.instrument(socket_span))?; } } @@ -377,6 +304,23 @@ where Ok(()) } +async fn serve_connection( + connection: impl Future>>, +) -> Result<(), anyhow::Error> { + if let Err(e) = connection.await { + if let Some(hyper_error) = e.downcast_ref::() { + if hyper_error.is_incomplete_message() { + debug!("Connection closed before request completed"); + } + } else { + debug!("Connection terminated due to error: {e}"); + } + } else { + trace!("Connection completed cleanly"); + } + Ok(()) +} + #[derive(Clone, Default)] struct TaskCenterExecutor;