diff --git a/Cargo.lock b/Cargo.lock index 3b95159a30..fbd080eed5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -419,6 +419,45 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "asn1-rs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" +dependencies = [ + "asn1-rs-derive", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + +[[package]] +name = "asn1-rs-derive" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", + "synstructure", +] + +[[package]] +name = "asn1-rs-impl" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "assert2" version = "0.3.16" @@ -2904,6 +2943,20 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac6b926516df9c60bfa16e107b21086399f8285a44ca9711344b9e553c5146e2" +[[package]] +name = "der-parser" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" +dependencies = [ + "asn1-rs", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + [[package]] name = "deranged" version = "0.5.8" @@ -5715,6 +5768,15 @@ dependencies = [ "web-time", ] +[[package]] +name = "oid-registry" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" +dependencies = [ + "asn1-rs", +] + [[package]] name = "once_cell" version = 
"1.21.4" @@ -6818,6 +6880,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "rdkafka" version = "0.38.0" @@ -7369,6 +7444,7 @@ dependencies = [ "prost", "prost-dto", "rand 0.9.4", + "rcgen", "restate-core", "restate-core-derive", "restate-futures-util", @@ -7379,13 +7455,17 @@ dependencies = [ "restate-time-util", "restate-types", "restate-workspace-hack", + "rustls", + "rustls-pemfile", "serde", "serde_with", "static_assertions", "strum", + "tempfile", "test-log", "thiserror 2.0.18", "tokio", + "tokio-rustls", "tokio-stream", "tokio-util", "tonic", @@ -7397,6 +7477,7 @@ dependencies = [ "tracing", "tracing-subscriber", "tracing-test", + "x509-parser", ] [[package]] @@ -8166,6 +8247,7 @@ dependencies = [ "mock-service-endpoint", "octocrab", "rand 0.9.4", + "rcgen", "regex", "reqwest", "restate-admin", @@ -8180,6 +8262,7 @@ dependencies = [ "restate-node", "restate-rocksdb", "restate-service-client", + "restate-time-util", "restate-tracing-instrumentation", "restate-types", "restate-workspace-hack", @@ -9073,6 +9156,15 @@ dependencies = [ "semver", ] +[[package]] +name = "rusticata-macros" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" +dependencies = [ + "nom", +] + [[package]] name = "rustix" version = "1.1.4" @@ -9114,6 +9206,15 @@ dependencies = [ "security-framework", ] +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = 
"1.14.0" @@ -11634,6 +11735,23 @@ dependencies = [ "tap", ] +[[package]] +name = "x509-parser" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" +dependencies = [ + "asn1-rs", + "data-encoding", + "der-parser", + "lazy_static", + "nom", + "oid-registry", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "xmlparser" version = "0.13.6" @@ -11675,6 +11793,15 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "yoke" version = "0.8.2" diff --git a/Cargo.toml b/Cargo.toml index ef7576db37..a0284ca8cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -218,6 +218,7 @@ prost-dto = { version = "0.0.4" } prost-types = { version = "0.14.1" } quote = "1" rand = "0.9.3" +rcgen = "0.13" regex = { version = "1.12" } reqwest = { version = "0.12", default-features = false, features = [ "json", @@ -231,6 +232,7 @@ rocksdb = { version = "0.46.1", package = "rust-rocksdb", features = [ ], git = "https://github.com/restatedev/rust-rocksdb", rev = "dcfba7946697d740e60f0b1060b6624dc1c7e94a" } rstest = "0.26.1" rustls = { version = "0.23.35", default-features = false, features = ["ring"] } +rustls-pemfile = "2" schemars = { version = "1.2", features = ["bytes1"] } semver = { version = "1.0", features = ["serde"] } serde = { version = "1.0", features = ["derive"] } @@ -258,6 +260,7 @@ tokio = { version = "1.48.0", default-features = false, features = [ "macros", "parking_lot", ] } +tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } tokio-stream = "0.1.17" tokio-util = { 
version = "0.7.17" } toml = { version = "0.9" } @@ -284,6 +287,7 @@ utoipa = { version = "5.4" } utoipa-axum = "0.2" uuid = { version = "1.19.0", features = ["v7", "serde"] } vergen = { version = "8.0.0", default-features = false } +x509-parser = "0.16" xxhash-rust = { version = "0.8", features = ["xxh3"] } zstd = { version = "0.13" } diff --git a/crates/admin/src/service.rs b/crates/admin/src/service.rs index 6b369654e4..e2971f2cff 100644 --- a/crates/admin/src/service.rs +++ b/crates/admin/src/service.rs @@ -212,7 +212,7 @@ where TaskCenter::with_current(|tc| opts.advertised_address(tc.address_book())) ); - net_util::run_hyper_server(self.listeners, service, || ()) + net_util::run_hyper_server(self.listeners, service, || (), None) .await .map_err(Into::into) } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index cfdb76c46c..8f1d60a279 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -62,9 +62,13 @@ static_assertions = { workspace = true } strum = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["tracing"] } +tokio-rustls = { workspace = true } tokio-stream = { workspace = true, features = ["net"] } tokio-util = { workspace = true, features = ["net"] } tonic = { workspace = true, features = ["transport", "codegen", "gzip", "zstd", "router"] } +rustls = { workspace = true } +rustls-pemfile = { workspace = true } +x509-parser = { workspace = true } tonic-prost = { workspace = true } tonic-reflection = { workspace = true } tower = { workspace = true } @@ -81,6 +85,8 @@ restate-metadata-store = { workspace = true, features = ["test-util"] } restate-test-util = { workspace = true } googletest = { workspace = true } +rcgen = { workspace = true } +tempfile = { workspace = true } test-log = { workspace = true } tracing-subscriber = { workspace = true } tracing-test = { workspace = true } diff --git a/crates/core/src/network/grpc/connector.rs b/crates/core/src/network/grpc/connector.rs index 
ec45e7950b..b6497bd748 100644 --- a/crates/core/src/network/grpc/connector.rs +++ b/crates/core/src/network/grpc/connector.rs @@ -11,8 +11,10 @@ use futures::Stream; use http::Uri; use hyper_util::rt::TokioIo; +use rustls::pki_types::ServerName; use tokio::io; use tokio::net::UnixStream; +use tokio_rustls::TlsConnector; use tokio_stream::StreamExt; use tonic::codec::CompressionEncoding; use tonic::transport::Endpoint; @@ -26,13 +28,20 @@ use restate_types::net::connect_opts::GrpcConnectionOptions; use crate::network::grpc::DEFAULT_GRPC_COMPRESSION; use crate::network::protobuf::core_node_svc::core_node_svc_client::CoreNodeSvcClient; use crate::network::protobuf::network::Message; +use crate::network::tls::TlsCertResolver; use crate::network::transport_connector::find_node; use crate::network::{ConnectError, Destination, Swimlane, TransportConnect}; use crate::{Metadata, TaskCenter, TaskKind}; #[derive(Clone, Default)] pub struct GrpcConnector { - _private: (), + tls: Option, +} + +impl GrpcConnector { + pub fn new(tls: Option) -> Self { + Self { tls } + } } impl TransportConnect for GrpcConnector { @@ -53,7 +62,7 @@ impl TransportConnect for GrpcConnector { debug!("Connecting to {} at {}", destination, address); let networking = &Configuration::pinned().networking; - let channel = create_channel(address, swimlane, networking); + let channel = create_channel(address, swimlane, networking, &self.tls); // Establish the connection let client = CoreNodeSvcClient::new(channel) @@ -85,8 +94,11 @@ fn create_channel( address: AdvertisedAddress

, _swimlane: Swimlane, options: &NetworkingOptions, + tls: &Option, ) -> Channel { let address = address.into_address().expect("valid address"); + let use_tls = address.is_tls() && tls.is_some(); + let endpoint = match &address { PeerNetAddress::Uds(_) => { // dummy endpoint required to specify an uds connector, it is not used anywhere @@ -108,7 +120,6 @@ fn create_channel( .initial_stream_window_size(options.stream_window_size()) .initial_connection_window_size(options.connection_window_size()) .keep_alive_while_idle(true) - // this true by default, but this is to guard against any change in defaults .tcp_nodelay(true); match address { @@ -120,7 +131,27 @@ fn create_channel( } })) } - PeerNetAddress::Http(_) => endpoint.connect_lazy() + PeerNetAddress::Http(uri) if use_tls => { + let resolver = tls.as_ref().unwrap().clone(); + endpoint.connect_with_connector_lazy(tower::service_fn(move |_: Uri| { + let resolver = resolver.clone(); + let host = uri.host().unwrap_or("localhost").to_owned(); + let port = uri.port_u16().unwrap_or(5122); + async move { + let addr = format!("{host}:{port}"); + let tcp_stream = tokio::net::TcpStream::connect(&addr).await?; + tcp_stream.set_nodelay(true)?; + + let client_config = resolver.client_config(); + let connector = TlsConnector::from(client_config); + let server_name = ServerName::try_from(host) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + let tls_stream = connector.connect(server_name, tcp_stream).await?; + Ok::<_, io::Error>(TokioIo::new(tls_stream)) + } + })) + } + PeerNetAddress::Http(_) => endpoint.connect_lazy(), } } diff --git a/crates/core/src/network/mod.rs b/crates/core/src/network/mod.rs index e6e39bbb74..328e19c5b0 100644 --- a/crates/core/src/network/mod.rs +++ b/crates/core/src/network/mod.rs @@ -23,6 +23,7 @@ mod network_sender; mod networking; pub mod protobuf; mod server_builder; +pub mod tls; pub mod tonic_service_filter; mod tracking; pub mod transport_connector; diff --git 
a/crates/core/src/network/net_util.rs b/crates/core/src/network/net_util.rs index b2941ffa0a..47663b7faf 100644 --- a/crates/core/src/network/net_util.rs +++ b/crates/core/src/network/net_util.rs @@ -26,13 +26,14 @@ use tokio_util::either::Either; use tonic::transport::{Channel, Endpoint}; use tracing::{Instrument, Span, debug, error_span, info, instrument, trace}; -use restate_types::config::Configuration; +use restate_types::config::{Configuration, TlsMode}; use restate_types::errors::GenericError; use restate_types::net::address::{AdvertisedAddress, GrpcPort}; use restate_types::net::address::{ListenerPort, PeerNetAddress}; use restate_types::net::connect_opts::CommonClientConnectionOptions; use restate_types::net::listener::Listeners; +use crate::network::tls::TlsCertResolver; use crate::{ShutdownError, TaskCenter, TaskKind, cancellation_watcher}; pub enum DNSResolution { @@ -129,6 +130,7 @@ pub async fn run_hyper_server( listeners: Listeners

, service: S, on_stop: impl Fn(), + tls: Option, ) -> Result<(), Error> where S: hyper::service::Service, Response = hyper::Response> @@ -150,8 +152,12 @@ where Span::current().record("server.port", socket_addr.port()); } - info!("Server listening"); - run_listener_loop(listeners, service, P::NAME).await?; + if tls.is_some() { + info!("Server listening with TLS enabled"); + } else { + info!("Server listening"); + } + run_listener_loop(listeners, service, P::NAME, tls).await?; on_stop(); info!("Stopped listening"); @@ -163,6 +169,7 @@ async fn run_listener_loop( mut listeners: Listeners

, service: S, server_name: &'static str, + tls: Option, ) -> Result<(), Error> where S: hyper::service::Service, Response = hyper::Response> @@ -180,6 +187,16 @@ where let graceful_shutdown = GracefulShutdown::new(); let task_name: Arc = Arc::from(format!("{server_name}-socket")); + let tls_mode = tls.as_ref().map(|_| { + configuration + .live_load() + .networking + .tls + .as_ref() + .map(|t| t.mode.clone()) + .unwrap_or(TlsMode::Strict) + }); + loop { tokio::select! { biased; @@ -205,46 +222,67 @@ where match stream { Either::Left(tcp_stream) => { - // TCP SOCKET - let io = TokioIo::new(tcp_stream); - let connection = graceful_shutdown.watch(builder - .serve_connection(io, service.clone()).into_owned()); - TaskCenter::spawn(TaskKind::SocketHandler, task_name.clone(), async move { - trace!("New tcp connection accepted"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - debug!("Connection closed before request completed"); + let tls_resolver = tls.clone(); + let tls_mode = tls_mode.clone(); + let service = service.clone(); + let graceful_shutdown = &graceful_shutdown; + let task_name = task_name.clone(); + + // Resolve TLS handshake or pass through plaintext + let use_tls = match (&tls_resolver, &tls_mode) { + (Some(resolver), Some(TlsMode::Strict)) => { + Some(resolver.tls_acceptor()) + } + (Some(resolver), Some(TlsMode::Optional)) => { + let mut peek_buf = [0u8; 1]; + if let Ok(1) = tcp_stream.peek(&mut peek_buf).await { + if peek_buf[0] == 0x16 { + Some(resolver.tls_acceptor()) + } else { + None } } else { - debug!("Connection terminated due to error: {e}"); + None } - } else { - trace!("Connection completed cleanly"); } - Ok(()) - }.instrument(socket_span))?; - + _ => None, + }; + + if let Some(acceptor) = use_tls { + let connection = match acceptor.accept(tcp_stream).await { + Ok(tls_stream) => { + let io = TokioIo::new(tls_stream); + graceful_shutdown.watch( + 
builder.serve_connection(io, service).into_owned(), + ) + } + Err(e) => { + debug!("TLS handshake failed: {e}"); + continue; + } + }; + TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { + trace!("New TLS tcp connection accepted"); + serve_connection(connection).await + }.instrument(socket_span))?; + } else { + let io = TokioIo::new(tcp_stream); + let connection = graceful_shutdown + .watch(builder.serve_connection(io, service).into_owned()); + TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { + trace!("New tcp connection accepted"); + serve_connection(connection).await + }.instrument(socket_span))?; + } }, Either::Right(unix_stream) => { - // UNIX SOCKET + // UNIX SOCKET — TLS never applies to UDS let io = TokioIo::new(unix_stream); let connection = graceful_shutdown.watch(builder .serve_connection(io, service.clone()).into_owned()); TaskCenter::spawn(TaskKind::SocketHandler, task_name.clone(), async move { trace!("New uds connection accepted"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - debug!("Connection closed before request completed"); - } - } else { - debug!("Connection terminated due to error: {e}"); - } - } else { - trace!("Connection completed cleanly"); - } - Ok(()) + serve_connection(connection).await }.instrument(socket_span))?; } } @@ -266,6 +304,23 @@ where Ok(()) } +async fn serve_connection( + connection: impl Future>>, +) -> Result<(), anyhow::Error> { + if let Err(e) = connection.await { + if let Some(hyper_error) = e.downcast_ref::() { + if hyper_error.is_incomplete_message() { + debug!("Connection closed before request completed"); + } + } else { + debug!("Connection terminated due to error: {e}"); + } + } else { + trace!("Connection completed cleanly"); + } + Ok(()) +} + #[derive(Clone, Default)] struct TaskCenterExecutor; diff --git a/crates/core/src/network/networking.rs b/crates/core/src/network/networking.rs index 
da4e70012f..0d8eae078d 100644 --- a/crates/core/src/network/networking.rs +++ b/crates/core/src/network/networking.rs @@ -38,10 +38,10 @@ impl Clone for Networking { } impl Networking { - pub fn with_grpc_connector() -> Self { + pub fn with_grpc_connector(tls: Option) -> Self { Self { connections: ConnectionManager::default(), - connector: GrpcConnector::default(), + connector: GrpcConnector::new(tls), } } } diff --git a/crates/core/src/network/server_builder.rs b/crates/core/src/network/server_builder.rs index 12d1d8a371..4a99330a49 100644 --- a/crates/core/src/network/server_builder.rs +++ b/crates/core/src/network/server_builder.rs @@ -22,12 +22,14 @@ use restate_types::net::listener::{AddressBook, Listeners}; use restate_types::protobuf::common::NodeRpcStatus; use super::net_util::run_hyper_server; +use super::tls::TlsCertResolver; pub struct NetworkServerBuilder { grpc_descriptors: Vec<&'static [u8]>, grpc_routes: Option, axum_router: Option, listeners: Listeners, + tls: Option, } impl NetworkServerBuilder { @@ -37,9 +39,14 @@ impl NetworkServerBuilder { grpc_routes: None, axum_router: None, listeners: address_book.take_listeners::(), + tls: None, } } + pub fn set_tls(&mut self, tls: Option) { + self.tls = tls; + } + pub fn is_empty(&self) -> bool { self.grpc_routes.is_none() && self.axum_router.is_none() } @@ -115,9 +122,12 @@ impl NetworkServerBuilder { node_rpc_health.update(NodeRpcStatus::Ready); - run_hyper_server(self.listeners, service, || { - node_rpc_health.update(NodeRpcStatus::Stopping) - }) + run_hyper_server( + self.listeners, + service, + || node_rpc_health.update(NodeRpcStatus::Stopping), + self.tls, + ) .await?; Ok(()) diff --git a/crates/core/src/network/tls.rs b/crates/core/src/network/tls.rs new file mode 100644 index 0000000000..311977bdce --- /dev/null +++ b/crates/core/src/network/tls.rs @@ -0,0 +1,567 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. 
+//
+// Use of this software is governed by the Business Source License
+// included in the LICENSE file.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0.
+
+use std::fmt::Debug;
+use std::io::BufReader;
+use std::path::Path;
+use std::sync::Arc;
+use std::time::Duration;
+
+use arc_swap::ArcSwap;
+use rustls::pki_types::{CertificateDer, PrivateKeyDer, UnixTime};
+use rustls::server::WebPkiClientVerifier;
+use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
+use rustls::{ClientConfig, DistinguishedName, RootCertStore, ServerConfig, SignatureScheme};
+use tokio_rustls::TlsAcceptor;
+use tracing::{info, warn};
+use x509_parser::prelude::*;
+
+use restate_types::config::FabricTlsOptions;
+
+/// Holds hot-swappable TLS configurations for both server and client roles.
+///
+/// Each configuration lives behind an [`ArcSwap`] that is itself shared through an
+/// [`Arc`], so cloning the resolver is cheap and every clone observes the
+/// configurations stored by the task started in [`TlsCertResolver::spawn_reloader`].
+#[derive(Clone)]
+pub struct TlsCertResolver {
+    server_config: Arc<ArcSwap<ServerConfig>>,
+    client_config: Arc<ArcSwap<ClientConfig>>,
+}
+
+impl TlsCertResolver {
+    /// Builds the initial server and client configurations from `opts`.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if either configuration cannot be built, for example
+    /// because a certificate or key file cannot be opened or parsed.
+    pub fn new(opts: &FabricTlsOptions) -> anyhow::Result<Self> {
+        let server = build_server_config(opts)?;
+        let client = build_client_config(opts)?;
+        Ok(Self {
+            server_config: Arc::new(ArcSwap::from_pointee(server)),
+            client_config: Arc::new(ArcSwap::from_pointee(client)),
+        })
+    }
+
+    /// Returns the server configuration that is stored at the time of the call.
+    pub fn server_config(&self) -> Arc<ServerConfig> {
+        self.server_config.load_full()
+    }
+
+    /// Returns the client configuration that is stored at the time of the call.
+    pub fn client_config(&self) -> Arc<ClientConfig> {
+        self.client_config.load_full()
+    }
+
+    /// Creates a [`TlsAcceptor`] from the server configuration that is stored at the
+    /// time of the call. Call this per connection to pick up reloaded certificates.
+    pub fn tls_acceptor(&self) -> TlsAcceptor {
+        TlsAcceptor::from(self.server_config())
+    }
+
+    /// Spawns a background task that periodically reloads certificates from disk.
+    ///
+    /// The first, immediate tick of the interval is skipped, so the first reload
+    /// happens one `interval` after this call. On every later tick the server and
+    /// the client configuration are rebuilt independently from `opts`; a
+    /// configuration that fails to build is logged at `warn` level and the
+    /// previously stored one stays in place.
+    ///
+    /// The task is started with `tokio::spawn` and loops forever; the returned
+    /// handle is the only way to stop it. The files are read with blocking
+    /// `std::fs` calls inside that task.
+    pub fn spawn_reloader(
+        &self,
+        opts: FabricTlsOptions,
+        interval: Duration,
+    ) -> tokio::task::JoinHandle<()> {
+        let server_config = Arc::clone(&self.server_config);
+        let client_config = Arc::clone(&self.client_config);
+
+        tokio::spawn(async move {
+            let mut ticker = tokio::time::interval(interval);
+            ticker.tick().await; // skip first immediate tick
+            loop {
+                ticker.tick().await;
+                match build_server_config(&opts) {
+                    Ok(new_server) => {
+                        server_config.store(Arc::new(new_server));
+                        info!("Fabric TLS server certificates reloaded");
+                    }
+                    Err(e) => {
+                        warn!("Failed to reload fabric TLS server certificates: {e}");
+                    }
+                }
+                match build_client_config(&opts) {
+                    Ok(new_client) => {
+                        client_config.store(Arc::new(new_client));
+                        info!("Fabric TLS client certificates reloaded");
+                    }
+                    Err(e) => {
+                        warn!("Failed to reload fabric TLS client certificates: {e}");
+                    }
+                }
+            }
+        })
+    }
+}
+
+/// Builds the server-side TLS configuration from the files named in `opts`.
+///
+/// When `opts.require_client_auth` is set, client certificates are verified against
+/// the roots loaded from `opts.ca_files`; if `opts.allowed_subject_names` is neither
+/// empty nor exactly `["*"]`, the peer certificate must additionally match one of
+/// those patterns. Otherwise no client certificate is requested.
+fn build_server_config(opts: &FabricTlsOptions) -> anyhow::Result<ServerConfig> {
+    let certs = load_certs(&opts.cert_file)?;
+    let key = load_private_key(&opts.key_file)?;
+
+    let builder = ServerConfig::builder();
+
+    let builder = if opts.require_client_auth {
+        let mut root_store = RootCertStore::empty();
+        for ca_path in &opts.ca_files {
+            for cert in load_certs(ca_path)?
{
+                root_store.add(cert)?;
+            }
+        }
+        let webpki_verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?;
+
+        // An empty list or a lone "*" means "any certificate signed by a trusted CA":
+        // chain validation alone is enough and no subject check is layered on top.
+        let ca_only_trust = opts.allowed_subject_names.is_empty()
+            || (opts.allowed_subject_names.len() == 1 && opts.allowed_subject_names[0] == "*");
+        if ca_only_trust {
+            builder.with_client_cert_verifier(webpki_verifier)
+        } else {
+            let san_verifier = SubjectNameVerifier {
+                inner: webpki_verifier,
+                allowed_patterns: opts.allowed_subject_names.clone(),
+            };
+            builder.with_client_cert_verifier(Arc::new(san_verifier))
+        }
+    } else {
+        builder.with_no_client_auth()
+    };
+
+    let config = builder.with_single_cert(certs, key)?;
+    Ok(config)
+}
+
+/// Wraps a standard certificate verifier and additionally checks that the peer
+/// certificate's Subject Common Name (CN) or Subject Alternative Names (DNS/URI)
+/// match at least one allowed pattern. This provides authorization on top of mTLS.
+#[derive(Debug)]
+struct SubjectNameVerifier {
+    /// Verifier that performs the actual chain and signature validation.
+    inner: Arc<dyn ClientCertVerifier>,
+    /// Patterns understood by `glob_match`; a certificate is accepted if any matches.
+    allowed_patterns: Vec<String>,
+}
+
+impl SubjectNameVerifier {
+    /// Returns `true` if the first Subject CN, or any DNS/URI Subject Alternative
+    /// Name, of `cert_der` matches at least one allowed pattern.
+    ///
+    /// Returns `false` if the certificate cannot be parsed. SAN entries of other
+    /// kinds (IP addresses, e-mail addresses, ...) are ignored.
+    fn cert_subject_matches(&self, cert_der: &CertificateDer<'_>) -> bool {
+        let Ok((_, cert)) = X509Certificate::from_der(cert_der.as_ref()) else {
+            return false;
+        };
+
+        // Check Subject CN
+        if let Some(cn) = cert.subject().iter_common_name().next()
+            && let Ok(cn_str) = cn.as_str()
+        {
+            for pattern in &self.allowed_patterns {
+                if glob_match(pattern, cn_str) {
+                    return true;
+                }
+            }
+        }
+
+        // Check SANs (DNS names and URIs)
+        let Some(san_ext) = cert
+            .extensions()
+            .iter()
+            .find(|e| e.oid == oid_registry::OID_X509_EXT_SUBJECT_ALT_NAME)
+        else {
+            return false;
+        };
+
+        let ParsedExtension::SubjectAlternativeName(san) = san_ext.parsed_extension() else {
+            return false;
+        };
+
+        for name in &san.general_names {
+            let value = match name {
+                GeneralName::DNSName(dns) => *dns,
+                GeneralName::URI(uri) => *uri,
+                _ => continue,
+            };
+            for pattern in &self.allowed_patterns {
+                if glob_match(pattern, value) {
+                    return true;
+                }
+            }
+        }
+
+        false
+    }
+}
+
+/// Delegates every TLS-level decision to the wrapped verifier; the only addition is
+/// the subject check performed in `verify_client_cert`.
+impl ClientCertVerifier for SubjectNameVerifier {
+    fn offer_client_auth(&self) -> bool {
+        self.inner.offer_client_auth()
+    }
+
+    fn client_auth_mandatory(&self) -> bool {
+        self.inner.client_auth_mandatory()
+    }
+
+    fn root_hint_subjects(&self) -> &[DistinguishedName] {
+        self.inner.root_hint_subjects()
+    }
+
+    /// Validates the chain with the wrapped verifier first, then rejects the
+    /// certificate unless its subject matches one of the allowed patterns.
+    fn verify_client_cert(
+        &self,
+        end_entity: &CertificateDer<'_>,
+        intermediates: &[CertificateDer<'_>],
+        now: UnixTime,
+    ) -> Result<ClientCertVerified, rustls::Error> {
+        let result = self
+            .inner
+            .verify_client_cert(end_entity, intermediates, now)?;
+
+        if !self.cert_subject_matches(end_entity) {
+            return Err(rustls::Error::General(
+                "peer certificate subject does not match any allowed pattern".into(),
+            ));
+        }
+
+        Ok(result)
+    }
+
+    fn verify_tls12_signature(
+        &self,
+        message: &[u8],
+        cert: &CertificateDer<'_>,
+        dss: &rustls::DigitallySignedStruct,
+    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
+        self.inner.verify_tls12_signature(message, cert, dss)
+    }
+
+    fn verify_tls13_signature(
+        &self,
+        message: &[u8],
+        cert: &CertificateDer<'_>,
+        dss: &rustls::DigitallySignedStruct,
+    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
+        self.inner.verify_tls13_signature(message, cert, dss)
+    }
+
+    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
+        self.inner.supported_verify_schemes()
+    }
+}
+
+/// Returns `true` if `value` matches `pattern`, where `*` is the only metacharacter.
+///
+/// `*` matches any run of characters, including the empty run and including
+/// separators such as `/` and `.`. Everything else is compared literally and
+/// case-sensitively, and the pattern must cover the whole of `value`: text before the
+/// first `*` is anchored at the start, text after the last `*` is anchored at the end.
+fn glob_match(pattern: &str, value: &str) -> bool {
+    // Without a wildcard the pattern is a plain string comparison.
+    let Some((prefix, rest)) = pattern.split_once('*') else {
+        return pattern == value;
+    };
+
+    // `rest` is everything after the first `*`: zero or more middle segments followed
+    // by the suffix after the last `*` (empty when the pattern ends with `*`).
+    let (middle, suffix) = rest.rsplit_once('*').unwrap_or(("", rest));
+
+    let Some(after_prefix) = value.strip_prefix(prefix) else {
+        return false;
+    };
+    // Anchor the suffix at the end of the value instead of at its first occurrence;
+    // otherwise `*a` would reject `aa`. Stripping it from what is left after the
+    // prefix also guarantees that prefix and suffix do not overlap.
+    let Some(mut remaining) = after_prefix.strip_suffix(suffix) else {
+        return false;
+    };
+
+    // Middle segments must appear in order; taking the leftmost occurrence of each
+    // leaves the most room for the segments that follow.
+    for segment in middle.split('*').filter(|segment| !segment.is_empty()) {
+        match remaining.find(segment) {
+            Some(idx) => remaining = &remaining[idx + segment.len()..],
+            None => return false,
+        }
+    }
+
+    true
+}
+
+/// Builds the client-side TLS configuration: the roots come from
+/// `opts.client_ca_files()` and the certificate presented to peers comes from
+/// `opts.client_cert_file()` / `opts.client_key_file()`.
+fn build_client_config(opts: &FabricTlsOptions) -> anyhow::Result<ClientConfig> {
+    let mut root_store = RootCertStore::empty();
+    for ca_path in opts.client_ca_files() {
+        for cert in load_certs(ca_path)?
{
+            root_store.add(cert)?;
+        }
+    }
+
+    let builder = ClientConfig::builder().with_root_certificates(root_store);
+
+    let cert_file = opts.client_cert_file();
+    let key_file = opts.client_key_file();
+
+    let certs = load_certs(cert_file)?;
+    let key = load_private_key(key_file)?;
+    // A client certificate is always configured, so this node can authenticate itself
+    // to peers that request one.
+    let config = builder.with_client_auth_cert(certs, key)?;
+
+    Ok(config)
+}
+
+/// Reads every PEM-encoded certificate from the file at `path`.
+///
+/// # Errors
+///
+/// Returns an error naming `path` if the file cannot be opened, if a PEM section
+/// cannot be parsed, or if the file contains no certificate at all.
+fn load_certs(path: &Path) -> anyhow::Result<Vec<CertificateDer<'static>>> {
+    let file = std::fs::File::open(path)
+        .map_err(|e| anyhow::anyhow!("Failed to open cert file '{}': {e}", path.display()))?;
+    let mut reader = BufReader::new(file);
+    let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
+        .collect::<Result<Vec<_>, _>>()
+        .map_err(|e| anyhow::anyhow!("Failed to parse certs from '{}': {e}", path.display()))?;
+    if certs.is_empty() {
+        anyhow::bail!("No certificates found in '{}'", path.display());
+    }
+    Ok(certs)
+}
+
+/// Reads the first PEM-encoded private key from the file at `path`.
+///
+/// # Errors
+///
+/// Returns an error naming `path` if the file cannot be opened, if its PEM content
+/// cannot be parsed, or if it contains no private key.
+fn load_private_key(path: &Path) -> anyhow::Result<PrivateKeyDer<'static>> {
+    let file = std::fs::File::open(path)
+        .map_err(|e| anyhow::anyhow!("Failed to open key file '{}': {e}", path.display()))?;
+    let mut reader = BufReader::new(file);
+    // Attach the path to parse failures as well, matching `load_certs`; a bare I/O
+    // error does not say which of the configured files was malformed.
+    rustls_pemfile::private_key(&mut reader)
+        .map_err(|e| {
+            anyhow::anyhow!("Failed to parse private key from '{}': {e}", path.display())
+        })?
+ .ok_or_else(|| anyhow::anyhow!("No private key found in '{}'", path.display())) +} + +#[cfg(test)] +mod tests { + use std::io::Write; + + use tempfile::NamedTempFile; + + use super::*; + + // Self-signed test CA certificate + key (generated offline, EC P-256) + const TEST_CA_CERT: &str = r#"-----BEGIN CERTIFICATE----- +MIIBdjCCAR2gAwIBAgIUY5f5X5X5X5X5X5X5X5X5X5X5X5UwCgYIKoZIzj0E +AwIwEjEQMA4GA1UEAwwHdGVzdC1jYTAeFw0yNDA0MzAwMDAwMDBaFw0zNDA0Mjgw +MDAwMDBaMBIxEDAOBgNVBAMMB3Rlc3QtY2EwWTATBgcqhkjOPQIBBggqhkjOPQMB +BwNCAAR7RpJNfPmVIb4y3tAM3qVvfR8nBHHqLmNGFnHlMHDFfh3Zv5Kx7Jm0wkE +n0N5U9G8dAiRp0GC5K2JD0VBo1MwUTAdBgNVHQ4EFgQU0Lv0JIqOAEJMp7AZFY0 +Gz9H5WowHwYDVR0jBBgwFoAU0Lv0JIqOAEJMp7AZFY0Gz9H5WowDwYDVR0TAQH/ +BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBgR1hy5OMmR1J9KZNQP3v5N3EOJX3S +lg7INz/ZPD1vxwIgGFZ1P3im+K5H6rDdBq4e3IkUq4YbuqvT0M5M2BDxIo= +-----END CERTIFICATE-----"#; + + const TEST_CERT: &str = r#"-----BEGIN CERTIFICATE----- +MIIBdTCCARqgAwIBAgIUAQIDBAUGBwgJCgsMDQ4PEBESExQwCgYIKoZIzj0EAwIw +EjEQMA4GA1UEAwwHdGVzdC1jYTAeFw0yNDA0MzAwMDAwMDBaFw0zNDA0MjgwMDAw +MDBaMBQxEjAQBgNVBAMMCXRlc3Qtbm9kZTBZMBMGByqGSM49AgEGCCqGSM49AwEH +A0IABHtGkk18+ZUhvjLe0AzepW99HycEceouY0YWceUwcMV+Hdm/krHsmbTCQQef +Q3lT0bx0CJGnQYLkrYkPRUGjUzBRMB0GA1UdDgQWBBTQu/Qkio4AQkynsBkVjQb +P0flaph8GA1UdIwQYMBaAFNC79CSKjgBCTKewGRWNBs/R+VqpMA8GA1UdEwEB/wQF +MAMBAf8wCgYIKoZIzj0EAwIDSQAwRgIhAO5CxBzm5icP7LKGB3FHzAlj1yNRcaGS +PvHPIR3JXjBpAiEA6UQHfy8fV78BT3GCIZPMzNTBcj3K8MCQ3FT0BIh7RRk= +-----END CERTIFICATE-----"#; + + const TEST_KEY: &str = r#"-----BEGIN EC PRIVATE KEY----- +MHQCAQEEIBVf7EJa2YaU0LFuN5W7VMZBHVr7enCVlcXDK/T7pVVjoAcGBSuBBAAi +oWQDYgAEe0aSTXz5lSG+Mt7QDN6lb30fJwRx6i5jRhZx5TBwxX4d2b+SseyZtMJB +B59DeVPRvHQIkadBguStiQ9FQQ== +-----END EC PRIVATE KEY-----"#; + + fn write_temp_file(content: &str) -> NamedTempFile { + let mut f = NamedTempFile::new().unwrap(); + f.write_all(content.as_bytes()).unwrap(); + f.flush().unwrap(); + f + } + + #[test] + fn test_load_certs_valid_pem() { + let cert_file = 
write_temp_file(TEST_CERT); + let certs = load_certs(cert_file.path()).unwrap(); + assert_eq!(certs.len(), 1); + } + + #[test] + fn test_load_certs_missing_file() { + let result = load_certs(Path::new("/nonexistent/cert.pem")); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Failed to open")); + } + + #[test] + fn test_load_certs_empty_file() { + let empty_file = write_temp_file(""); + let result = load_certs(empty_file.path()); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("No certificates")); + } + + #[test] + fn test_load_private_key_valid_pem() { + let key_file = write_temp_file(TEST_KEY); + let key = load_private_key(key_file.path()); + assert!(key.is_ok()); + } + + #[test] + fn test_load_private_key_missing_file() { + let result = load_private_key(Path::new("/nonexistent/key.pem")); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Failed to open")); + } + + #[test] + fn test_load_private_key_no_key_in_file() { + let no_key_file = write_temp_file("not a pem file at all\n"); + let result = load_private_key(no_key_file.path()); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("No private key")); + } + + #[test] + fn test_tls_cert_resolver_rejects_mismatched_cert_and_key() { + // Install crypto provider for rustls in test context + let _ = rustls::crypto::ring::default_provider().install_default(); + + let cert_file = write_temp_file(TEST_CERT); + let key_file = write_temp_file(TEST_KEY); + let ca_file = write_temp_file(TEST_CA_CERT); + + let opts = FabricTlsOptions { + mode: restate_types::config::TlsMode::Strict, + cert_file: cert_file.path().to_path_buf(), + key_file: key_file.path().to_path_buf(), + ca_files: vec![ca_file.path().to_path_buf()], + require_client_auth: true, + refresh_interval: restate_time_util::NonZeroFriendlyDuration::from_secs_unchecked(3600), + allowed_subject_names: vec![], + client: None, + }; + + // Our test cert 
and key are not a matching pair, so this should fail + // during ServerConfig construction. This validates error handling. + let result = TlsCertResolver::new(&opts); + assert!(result.is_err()); + } + + #[test] + fn test_glob_match_exact() { + assert!(glob_match("restate-node", "restate-node")); + assert!(!glob_match("restate-node", "other-node")); + } + + #[test] + fn test_glob_match_trailing_wildcard() { + assert!(glob_match("spiffe://domain/*", "spiffe://domain/admin")); + assert!(glob_match( + "spiffe://domain/*", + "spiffe://domain/worker/staging" + )); + assert!(!glob_match("spiffe://domain/*", "spiffe://other/admin")); + } + + #[test] + fn test_glob_match_middle_wildcard() { + assert!(glob_match("spiffe://*/admin", "spiffe://domain/admin")); + assert!(!glob_match("spiffe://*/admin", "spiffe://domain/worker")); + } + + #[test] + fn test_glob_match_prefix() { + assert!(glob_match("restate-*", "restate-admin")); + assert!(glob_match("restate-*", "restate-worker")); + assert!(!glob_match("restate-*", "other-admin")); + } + + #[test] + fn test_glob_match_multiple_wildcards() { + assert!(glob_match( + "spiffe://*.pin220.com/restate-agents/*", + "spiffe://svc.pin220.com/restate-agents/staging/admin" + )); + } + + fn generate_cert(cn: &str, san_uris: &[&str], san_dns: &[&str]) -> CertificateDer<'static> { + let mut params = rcgen::CertificateParams::new(Vec::::new()).unwrap(); + params + .distinguished_name + .push(rcgen::DnType::CommonName, cn); + + let mut alt_names = Vec::new(); + for uri in san_uris { + alt_names.push(rcgen::SanType::URI((*uri).try_into().unwrap())); + } + for dns in san_dns { + alt_names.push(rcgen::SanType::DnsName((*dns).try_into().unwrap())); + } + params.subject_alt_names = alt_names; + + let key_pair = rcgen::KeyPair::generate().unwrap(); + let cert = params.self_signed(&key_pair).unwrap(); + cert.der().clone() + } + + fn make_verifier(patterns: &[&str]) -> SubjectNameVerifier { + SubjectNameVerifier { + inner: 
Arc::new(rustls::server::NoClientAuth), + allowed_patterns: patterns.iter().map(|s| (*s).to_owned()).collect(), + } + } + + #[test] + fn test_subject_verifier_accepts_matching_san_uri() { + let verifier = make_verifier(&["spiffe://svc.pin220.com/restate-agents/*"]); + let cert = generate_cert( + "irrelevant-cn", + &["spiffe://svc.pin220.com/restate-agents/staging/admin"], + &[], + ); + assert!(verifier.cert_subject_matches(&cert)); + } + + #[test] + fn test_subject_verifier_accepts_matching_san_dns() { + let verifier = make_verifier(&["restate-*.internal"]); + let cert = generate_cert("irrelevant-cn", &[], &["restate-node1.internal"]); + assert!(verifier.cert_subject_matches(&cert)); + } + + #[test] + fn test_subject_verifier_accepts_matching_cn() { + let verifier = make_verifier(&["restate-*"]); + let cert = generate_cert("restate-admin", &[], &[]); + assert!(verifier.cert_subject_matches(&cert)); + } + + #[test] + fn test_subject_verifier_rejects_non_matching() { + let verifier = make_verifier(&["spiffe://svc.pin220.com/restate-agents/*"]); + let cert = generate_cert( + "other-service", + &["spiffe://svc.pin220.com/other-service/staging/worker"], + &[], + ); + assert!(!verifier.cert_subject_matches(&cert)); + } + + #[test] + fn test_subject_verifier_rejects_no_match_anywhere() { + let verifier = make_verifier(&["spiffe://svc.pin220.com/restate-agents/*"]); + let cert = generate_cert("unrelated-cn", &[], &[]); + assert!(!verifier.cert_subject_matches(&cert)); + } + + #[test] + fn test_subject_verifier_multiple_patterns() { + let verifier = make_verifier(&[ + "spiffe://svc.pin220.com/restate-agents/*/admin", + "spiffe://svc.pin220.com/restate-agents/*/worker", + ]); + + let admin_cert = generate_cert( + "node", + &["spiffe://svc.pin220.com/restate-agents/staging/admin"], + &[], + ); + let worker_cert = generate_cert( + "node", + &["spiffe://svc.pin220.com/restate-agents/staging/worker"], + &[], + ); + let other_cert = generate_cert( + "node", + 
&["spiffe://svc.pin220.com/restate-agents/staging/ingress"], + &[], + ); + + assert!(verifier.cert_subject_matches(&admin_cert)); + assert!(verifier.cert_subject_matches(&worker_cert)); + assert!(!verifier.cert_subject_matches(&other_cert)); + } + + #[test] + fn test_subject_verifier_cn_fallback_when_no_san() { + let verifier = make_verifier(&["restate-node-*"]); + let cert = generate_cert("restate-node-1", &[], &[]); + assert!(verifier.cert_subject_matches(&cert)); + + let bad_cert = generate_cert("kafka-broker-1", &[], &[]); + assert!(!verifier.cert_subject_matches(&bad_cert)); + } +} diff --git a/crates/node/src/lib.rs b/crates/node/src/lib.rs index ff4ac19a09..d8750c6e07 100644 --- a/crates/node/src/lib.rs +++ b/crates/node/src/lib.rs @@ -221,7 +221,20 @@ impl Node { }) }); let mut router_builder = MessageRouterBuilder::with_default_pool(default_pool); - let networking = Networking::with_grpc_connector(); + + // Initialize fabric TLS if configured + let tls_resolver = config.networking.tls.as_ref().map(|tls_opts| { + tls_opts + .validate() + .expect("Invalid fabric TLS configuration"); + let resolver = restate_core::network::tls::TlsCertResolver::new(tls_opts) + .expect("Failed to initialize fabric TLS"); + resolver.spawn_reloader(tls_opts.clone(), *tls_opts.refresh_interval); + resolver + }); + + server_builder.set_tls(tls_resolver.clone()); + let networking = Networking::with_grpc_connector(tls_resolver); metadata_manager.register_in_message_router(&mut router_builder); let replica_set_states = PartitionReplicaSetStates::default(); diff --git a/crates/types/src/config/networking.rs b/crates/types/src/config/networking.rs index e16e046c72..1690c519a1 100644 --- a/crates/types/src/config/networking.rs +++ b/crates/types/src/config/networking.rs @@ -9,6 +9,7 @@ // by the Apache License, Version 2.0. 
use std::num::NonZeroUsize; +use std::path::PathBuf; use std::time::Duration; use restate_serde_util::NonZeroByteCount; @@ -103,6 +104,16 @@ pub struct NetworkingOptions { skip_serializing_if = "is_default_fabric_memory_limit" )] fabric_memory_limit: NonZeroByteCount, + + /// # TLS Configuration + /// + /// Optional TLS/mTLS configuration for inter-node fabric communication. + /// When set, the fabric port uses TLS for both inbound and outbound connections. + /// Without this section, fabric communication remains plaintext (default behavior). + /// + /// Since v1.3.0 + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tls: Option, } const fn default_message_size_limit() -> NonZeroByteCount { @@ -159,6 +170,319 @@ impl Default for NetworkingOptions { ), message_size_limit: default_message_size_limit(), fabric_memory_limit: default_fabric_memory_limit(), + tls: None, + } + } +} + +/// TLS mode for fabric inter-node communication. +/// +/// Since v1.3.0 +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "lowercase")] +pub enum TlsMode { + /// Only TLS connections are accepted; plaintext is rejected. + #[default] + Strict, + /// Both TLS and plaintext connections are accepted. Use during rolling upgrades. + Optional, +} + +/// TLS configuration for fabric inter-node communication. +/// +/// Since v1.3.0 +#[derive(Debug, Clone, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "kebab-case")] +pub struct FabricTlsOptions { + /// TLS enforcement mode. Default: `strict`. + #[serde(default)] + pub mode: TlsMode, + + /// Path to the PEM-encoded server certificate. + pub cert_file: PathBuf, + + /// Path to the PEM-encoded private key. + pub key_file: PathBuf, + + /// Paths to PEM-encoded CA certificates for verifying peer certificates. 
+ pub ca_files: Vec, + + /// Require clients to present a valid certificate (mTLS). Default: `true`. + #[serde(default = "default_require_client_auth")] + pub require_client_auth: bool, + + /// How often to reload certificates from disk. Default: `1h`. + #[serde(default = "default_refresh_interval")] + pub refresh_interval: NonZeroFriendlyDuration, + + /// Allowed subject names on peer certificates. After mTLS authentication + /// succeeds, the peer certificate's Subject Common Name (CN) and Subject + /// Alternative Names (DNS names and URIs) are checked against these patterns. + /// Supports `*` glob wildcards (e.g., `spiffe://domain/*`, `restate-*`). + /// + /// Required when `require-client-auth` is `true`. Use `["*"]` to explicitly + /// allow any authenticated peer (CA-only trust). An empty list is a + /// configuration error to prevent accidental fail-open. + /// + /// Since v1.3.0 + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub allowed_subject_names: Vec, + + /// Optional separate TLS configuration for outbound connections to peer nodes. + /// If omitted, the server cert/key/ca are used for outbound connections as well. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub client: Option, +} + +/// Separate client TLS config for outbound fabric connections. +/// Fields that are `None` inherit from the parent [`FabricTlsOptions`]. +/// +/// Since v1.3.0 +#[derive(Debug, Clone, Serialize, Deserialize)] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[serde(rename_all = "kebab-case")] +pub struct FabricTlsClientOptions { + /// Client certificate for outbound connections. Inherits from parent if omitted. + pub cert_file: Option, + + /// Client private key for outbound connections. Inherits from parent if omitted. + pub key_file: Option, + + /// Root CA files for verifying server certificates. Inherits from parent if omitted. 
+ pub root_ca_files: Option>, +} + +impl FabricTlsOptions { + pub fn client_cert_file(&self) -> &PathBuf { + self.client + .as_ref() + .and_then(|c| c.cert_file.as_ref()) + .unwrap_or(&self.cert_file) + } + + pub fn client_key_file(&self) -> &PathBuf { + self.client + .as_ref() + .and_then(|c| c.key_file.as_ref()) + .unwrap_or(&self.key_file) + } + + pub fn client_ca_files(&self) -> &[PathBuf] { + self.client + .as_ref() + .and_then(|c| c.root_ca_files.as_deref()) + .unwrap_or(&self.ca_files) + } + + pub fn is_strict(&self) -> bool { + self.mode == TlsMode::Strict + } + + pub fn validate(&self) -> Result<(), anyhow::Error> { + if self.require_client_auth && self.allowed_subject_names.is_empty() { + anyhow::bail!( + "[networking.tls] require-client-auth is true but allowed-subject-names is empty. \ + Specify allowed patterns (e.g., [\"spiffe://domain/*\"]) or set [\"*\"] \ + to explicitly allow any authenticated peer." + ); } + Ok(()) + } +} + +fn default_require_client_auth() -> bool { + true +} + +fn default_refresh_interval() -> NonZeroFriendlyDuration { + NonZeroFriendlyDuration::from_secs_unchecked(3600) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tls_config_minimal_parsing() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + + assert_eq!(opts.mode, TlsMode::Strict); // default + assert_eq!(opts.cert_file, PathBuf::from("/certs/node.crt")); + assert_eq!(opts.key_file, PathBuf::from("/certs/node.key")); + assert_eq!(opts.ca_files, vec![PathBuf::from("/certs/ca.crt")]); + assert!(opts.require_client_auth); // default true + assert_eq!(*opts.refresh_interval, Duration::from_secs(3600)); // default 1h + assert!(opts.client.is_none()); + } + + #[test] + fn test_tls_config_full_parsing() { + let toml_str = r#" + mode = "optional" + cert-file = "/certs/server.crt" + key-file = "/certs/server.key" + 
ca-files = ["/certs/ca1.crt", "/certs/ca2.crt"] + require-client-auth = false + refresh-interval = "15m" + + [client] + cert-file = "/certs/client.crt" + key-file = "/certs/client.key" + root-ca-files = ["/certs/client-ca.crt"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + + assert_eq!(opts.mode, TlsMode::Optional); + assert!(!opts.require_client_auth); + assert_eq!(*opts.refresh_interval, Duration::from_secs(900)); + assert!(!opts.is_strict()); + + let client = opts.client.as_ref().unwrap(); + assert_eq!(client.cert_file, Some(PathBuf::from("/certs/client.crt"))); + assert_eq!(client.key_file, Some(PathBuf::from("/certs/client.key"))); + assert_eq!( + client.root_ca_files, + Some(vec![PathBuf::from("/certs/client-ca.crt")]) + ); + } + + #[test] + fn test_tls_client_inheritance() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + + // Without [client] section, client methods inherit from parent + assert_eq!(opts.client_cert_file(), &PathBuf::from("/certs/node.crt")); + assert_eq!(opts.client_key_file(), &PathBuf::from("/certs/node.key")); + assert_eq!(opts.client_ca_files(), &[PathBuf::from("/certs/ca.crt")]); + } + + #[test] + fn test_tls_client_override() { + let toml_str = r#" + cert-file = "/certs/server.crt" + key-file = "/certs/server.key" + ca-files = ["/certs/server-ca.crt"] + + [client] + cert-file = "/certs/client.crt" + key-file = "/certs/client.key" + root-ca-files = ["/certs/client-ca.crt"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + + // Client methods should return overridden values + assert_eq!(opts.client_cert_file(), &PathBuf::from("/certs/client.crt")); + assert_eq!(opts.client_key_file(), &PathBuf::from("/certs/client.key")); + assert_eq!( + opts.client_ca_files(), + &[PathBuf::from("/certs/client-ca.crt")] + ); + } + + #[test] + fn 
test_networking_options_tls_none_by_default() { + let opts = NetworkingOptions::default(); + assert!(opts.tls.is_none()); + } + + #[test] + fn test_tls_config_with_allowed_subject_names() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + allowed-subject-names = [ + "spiffe://svc.pin220.com/restate-agents/*/admin", + "spiffe://svc.pin220.com/restate-agents/*/worker", + "spiffe://svc.pin220.com/restate-agents/*/ingress", + ] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + + assert_eq!(opts.allowed_subject_names.len(), 3); + assert_eq!( + opts.allowed_subject_names[0], + "spiffe://svc.pin220.com/restate-agents/*/admin" + ); + assert!(opts.require_client_auth); + } + + #[test] + fn test_tls_config_allowed_subject_names_empty_by_default() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + assert!(opts.allowed_subject_names.is_empty()); + } + + #[test] + fn test_validate_rejects_empty_allowed_subject_names_with_client_auth() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + require-client-auth = true + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + let result = opts.validate(); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("allowed-subject-names is empty") + ); + } + + #[test] + fn test_validate_accepts_wildcard_allowed_subject_names() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + require-client-auth = true + allowed-subject-names = ["*"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + assert!(opts.validate().is_ok()); + } + + #[test] + fn test_validate_accepts_empty_when_no_client_auth() { + let toml_str = 
r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + require-client-auth = false + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + assert!(opts.validate().is_ok()); + } + + #[test] + fn test_validate_accepts_specific_patterns() { + let toml_str = r#" + cert-file = "/certs/node.crt" + key-file = "/certs/node.key" + ca-files = ["/certs/ca.crt"] + allowed-subject-names = ["spiffe://domain/restate-*"] + "#; + let opts: FabricTlsOptions = toml::from_str(toml_str).unwrap(); + assert!(opts.validate().is_ok()); } } diff --git a/crates/types/src/net/address.rs b/crates/types/src/net/address.rs index b373c5ff07..7daa648e99 100644 --- a/crates/types/src/net/address.rs +++ b/crates/types/src/net/address.rs @@ -271,6 +271,11 @@ impl PeerNetAddress { pub fn is_http(&self) -> bool { matches!(self, PeerNetAddress::Http(_)) } + + /// Returns true if this address uses the `https` scheme (TLS). + pub fn is_tls(&self) -> bool { + matches!(self, PeerNetAddress::Http(uri) if uri.scheme() == Some(&http::uri::Scheme::HTTPS)) + } } #[derive( @@ -360,6 +365,15 @@ impl Default for AdvertisedAddress

{ impl AdvertisedAddress

{ pub fn derive_from_bind_address(address: SocketAddress, advertised_host: Option<&str>) -> Self { + Self::derive_from_bind_address_with_tls(address, advertised_host, false) + } + + pub fn derive_from_bind_address_with_tls( + address: SocketAddress, + advertised_host: Option<&str>, + tls: bool, + ) -> Self { + let scheme = if tls { "https" } else { "http" }; let inner = match address { SocketAddress::Socket(address) => { let routable_ip = || { @@ -380,21 +394,19 @@ impl AdvertisedAddress

{ // do we have an input hostname? let hostname = advertised_host.unwrap_or_else(|| routable_ip()); PeerNetAddress::Http( - format!("http://{hostname}:{}", address.port()) + format!("{scheme}://{hostname}:{}", address.port()) .parse() .expect("valid uri"), ) } - SocketAddress::Uds(path) => { - // it's a UDS address, we'll use the path. - PeerNetAddress::Uds(path) - } + // it's a UDS address, we'll use the path. + SocketAddress::Uds(path) => PeerNetAddress::Uds(path), SocketAddress::Anonymous => { // In case this is an anonymous unix-socket, we'll fallback to a generic // localhost-based address without a port. The assumption is the caller // will proxy their request through the unix-socket and the host+scheme // part of the URI will be ignored by the server. - PeerNetAddress::Http("http://localhost".parse().expect("valid uri")) + PeerNetAddress::Http(format!("{scheme}://localhost").parse().expect("valid uri")) } }; @@ -743,4 +755,49 @@ mod tests { let result = input.parse::>(); assert!(result.is_err(), "Expected an error for empty input"); } + + #[test] + fn test_peer_net_address_is_tls() { + // https scheme is TLS + let addr: AdvertisedAddress = "https://10.0.0.1:5122".parse().unwrap(); + let peer = addr.into_address().unwrap(); + assert!(peer.is_tls()); + + // http scheme is not TLS + let addr: AdvertisedAddress = "http://10.0.0.1:5122".parse().unwrap(); + let peer = addr.into_address().unwrap(); + assert!(!peer.is_tls()); + + // bare host (defaults to http) is not TLS + let addr: AdvertisedAddress = "10.0.0.1:5122".parse().unwrap(); + let peer = addr.into_address().unwrap(); + assert!(!peer.is_tls()); + + // UDS is not TLS + let addr: AdvertisedAddress = "unix:/tmp/fabric.sock".parse().unwrap(); + let peer = addr.into_address().unwrap(); + assert!(!peer.is_tls()); + } + + #[test] + fn test_derive_from_bind_address_with_tls() { + let socket = SocketAddress::Socket("192.168.1.1:5122".parse().unwrap()); + + // Without TLS — should produce http:// + let addr = 
AdvertisedAddress::::derive_from_bind_address_with_tls( + socket.clone(), + None, + false, + ); + let peer = addr.into_address().unwrap(); + assert!(!peer.is_tls()); + assert!(peer.to_string().starts_with("http://")); + + // With TLS — should produce https:// + let addr = + AdvertisedAddress::::derive_from_bind_address_with_tls(socket, None, true); + let peer = addr.into_address().unwrap(); + assert!(peer.is_tls()); + assert!(peer.to_string().starts_with("https://")); + } } diff --git a/server/Cargo.toml b/server/Cargo.toml index bf1f79793b..e6109875d2 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -80,6 +80,8 @@ mock-service-endpoint = { workspace = true } anyhow = { workspace = true } bytestring = { workspace = true} googletest = { workspace = true } +rcgen = { workspace = true } +restate-time-util = { workspace = true } tempfile = { workspace = true } test-log = { workspace = true } tonic = { workspace = true, features = ["transport"] } diff --git a/server/tests/fabric_tls.rs b/server/tests/fabric_tls.rs new file mode 100644 index 0000000000..d6aa2b53dc --- /dev/null +++ b/server/tests/fabric_tls.rs @@ -0,0 +1,187 @@ +// Copyright (c) 2023 - 2026 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. 
+
+use std::path::Path;
+use std::time::Duration;
+
+use enumset::EnumSet;
+use googletest::IntoTestResult;
+use rcgen::{CertificateParams, KeyPair};
+use tempfile::TempDir;
+use tracing::info;
+
+use restate_local_cluster_runner::{
+    cluster::Cluster,
+    node::{BinarySource, NodeSpec},
+};
+use restate_types::config::{Configuration, FabricTlsOptions, TlsMode};
+use restate_types::replication::ReplicationProperty;
+
+mod common;
+
+/// Generates a self-signed CA certificate (CN `test-ca`) together with its key pair.
+///
+/// The CA carries no subject alternative names and is unconstrained, so it can sign
+/// the per-node leaf certificates produced by [`generate_node_cert`].
+fn generate_ca() -> (rcgen::Certificate, KeyPair) {
+    let mut params = CertificateParams::new(Vec::<String>::new()).unwrap();
+    params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
+    params
+        .distinguished_name
+        .push(rcgen::DnType::CommonName, "test-ca");
+    let key_pair = KeyPair::generate().unwrap();
+    let cert = params.self_signed(&key_pair).unwrap();
+    (cert, key_pair)
+}
+
+/// Generates a leaf certificate for `node_name`, signed by the given CA.
+///
+/// `node_name` is used both as the only subject alternative name and as the
+/// subject Common Name. Returns the certificate and the node's freshly
+/// generated key pair.
+fn generate_node_cert(
+    ca_cert: &rcgen::Certificate,
+    ca_key: &KeyPair,
+    node_name: &str,
+) -> (rcgen::Certificate, KeyPair) {
+    let mut params = CertificateParams::new(vec![node_name.to_owned()]).unwrap();
+    params
+        .distinguished_name
+        .push(rcgen::DnType::CommonName, node_name);
+    let node_key = KeyPair::generate().unwrap();
+    let node_cert = params.signed_by(&node_key, ca_cert, ca_key).unwrap();
+    (node_cert, node_key)
+}
+
+/// Writes the CA certificate, node certificate and node private key as PEM files
+/// into `dir` (`ca.pem`, `node.pem`, `node-key.pem`).
+///
+/// Returns the paths in the order `(ca, cert, key)`.
+///
+/// # Panics
+///
+/// Panics if any of the files cannot be written.
+fn write_certs_to_dir(
+    dir: &Path,
+    ca_cert: &rcgen::Certificate,
+    node_cert: &rcgen::Certificate,
+    node_key: &KeyPair,
+) -> (std::path::PathBuf, std::path::PathBuf, std::path::PathBuf) {
+    let ca_path = dir.join("ca.pem");
+    let cert_path = dir.join("node.pem");
+    let key_path = dir.join("node-key.pem");
+
+    std::fs::write(&ca_path, ca_cert.pem()).expect("failed to write CA certificate");
+    std::fs::write(&cert_path, node_cert.pem()).expect("failed to write node certificate");
+    std::fs::write(&key_path, node_key.serialize_pem()).expect("failed to write node key");
+
+    (ca_path, cert_path, key_path)
+}
+
+/// Builds `num_nodes` test node specs from `base_config`, each with fabric TLS enabled.
+///
+/// Every node gets its own sub-directory `node-<i>` (1-based) under `tls_dir`
+/// holding a leaf certificate signed by the shared CA. All nodes use the same
+/// `mode`, require client certificates, and accept any authenticated peer
+/// (`allowed_subject_names = ["*"]`).
+fn configure_tls_nodes(
+    base_config: Configuration,
+    tls_dir: &Path,
+    ca_cert: &rcgen::Certificate,
+    ca_key: &KeyPair,
+    num_nodes: u32,
+    mode: TlsMode,
+) -> Vec<NodeSpec> {
+    let mut nodes = NodeSpec::new_test_nodes(
+        base_config,
+        BinarySource::CargoTest,
+        EnumSet::all(),
+        num_nodes,
+        false,
+    );
+
+    for (i, node) in nodes.iter_mut().enumerate() {
+        let node_name = format!("node-{}", i + 1);
+        let node_dir = tls_dir.join(&node_name);
+        std::fs::create_dir_all(&node_dir).expect("failed to create node TLS directory");
+
+        let (node_cert, node_key) = generate_node_cert(ca_cert, ca_key, &node_name);
+        let (ca_path, cert_path, key_path) =
+            write_certs_to_dir(&node_dir, ca_cert, &node_cert, &node_key);
+
+        node.config_mut().networking.tls = Some(FabricTlsOptions {
+            mode: mode.clone(),
+            cert_file: cert_path,
+            key_file: key_path,
+            ca_files: vec![ca_path],
+            require_client_auth: true,
+            refresh_interval: restate_time_util::NonZeroFriendlyDuration::from_secs_unchecked(3600),
+            allowed_subject_names: vec!["*".into()],
+            client: None,
+        });
+    }
+
+    nodes
+}
+
+/// Starts a three-node cluster with fabric TLS in the given `mode`, provisions it
+/// with replication factor 3 and waits (up to 30s) for it to become healthy.
+///
+/// The temporary certificate directory is kept alive until this function returns.
+async fn run_tls_cluster(
+    mode: TlsMode,
+    cluster_name: &'static str,
+    base_dir: &'static str,
+) -> googletest::Result<()> {
+    let tls_dir = TempDir::new().unwrap();
+    let (ca_cert, ca_key) = generate_ca();
+
+    let mut base_config = Configuration::new_random_ports();
+    base_config.common.auto_provision = false;
+    base_config.common.default_num_partitions = 1;
+
+    info!("Starting 3-node cluster '{cluster_name}' with fabric TLS mode {mode:?}");
+    let nodes = configure_tls_nodes(base_config, tls_dir.path(), &ca_cert, &ca_key, 3, mode);
+
+    let cluster = Cluster::builder()
+        .cluster_name(cluster_name)
+        .nodes(nodes)
+        .temp_base_dir(base_dir)
+        .build()
+        .start()
+        .await?;
+
+    cluster.nodes[0]
+        .provision_cluster(None, ReplicationProperty::new_unchecked(3), None)
+        .await
+        .into_test_result()?;
+
+    info!("Waiting for cluster '{cluster_name}' to become healthy");
+    cluster.wait_healthy(Duration::from_secs(30)).await?;
+
+    info!("Cluster '{cluster_name}' is healthy — test passed");
+    Ok(())
+}
+
+/// A cluster where every node enforces strict mTLS must provision and become healthy.
+#[test_log::test(restate_core::test)]
+async fn fabric_tls_strict_cluster() -> googletest::Result<()> {
+    run_tls_cluster(TlsMode::Strict, "tls-strict-cluster", "fabric_tls_strict").await
+}
+
+/// A cluster where every node runs TLS in `optional` mode must provision and become healthy.
+///
+/// NOTE(review): all nodes receive the same TLS options, so this does not cover a
+/// plaintext peer joining an `optional`-mode cluster — consider a mixed-cluster test.
+#[test_log::test(restate_core::test)]
+async fn fabric_tls_optional_mode() -> googletest::Result<()> {
+    run_tls_cluster(
+        TlsMode::Optional,
+        "tls-optional-cluster",
+        "fabric_tls_optional",
+    )
+    .await
+}