Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/cli-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ mod os_env;
pub mod ui;

pub use context::CliContext;
pub use opts::CommonOpts;
pub use opts::{CommonOpts, NetworkOpts};
pub use os_env::OsEnv;

// Re-export comfy-table for console c_* macros (used internally by macros)
Expand Down
21 changes: 21 additions & 0 deletions crates/cli-util/src/opts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ pub struct NetworkOpts {
/// Sets the maximum size of a network messages.
#[arg[long, default_value_t = DEFAULT_MESSAGE_SIZE_LIMIT, global = true, hide = true]]
pub message_size_limit: NonZeroUsize,

/// Path to PEM-encoded CA certificate for verifying the server's TLS certificate.
#[arg(long, global = true, env = "RESTATECTL_TLS_CA")]
pub tls_ca: Option<std::path::PathBuf>,

/// Path to PEM-encoded client certificate for mTLS authentication.
#[arg(long, global = true, env = "RESTATECTL_TLS_CERT", requires = "tls_key")]
pub tls_cert: Option<std::path::PathBuf>,

/// Path to PEM-encoded client private key for mTLS authentication.
#[arg(long, global = true, env = "RESTATECTL_TLS_KEY", requires = "tls_cert")]
pub tls_key: Option<std::path::PathBuf>,
}

impl Default for NetworkOpts {
Expand All @@ -107,10 +119,19 @@ impl Default for NetworkOpts {
request_timeout: DEFAULT_REQUEST_TIMEOUT,
insecure_skip_tls_verify: false,
message_size_limit: DEFAULT_MESSAGE_SIZE_LIMIT,
tls_ca: None,
tls_cert: None,
tls_key: None,
}
}
}

impl NetworkOpts {
pub fn tls_enabled(&self) -> bool {
self.tls_ca.is_some() || self.tls_cert.is_some()
}
}

impl GrpcConnectionOptions for NetworkOpts {
fn message_size_limit(&self) -> NonZeroUsize {
self.message_size_limit
Expand Down
6 changes: 5 additions & 1 deletion tools/restatectl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ diff = "0.1.13"
futures = { workspace = true }
itertools = { workspace = true }
json-patch = "2.0.0"
http = { workspace = true }
humantime = { workspace = true }
prost-types = { workspace = true }
rand = { workspace = true }
Expand All @@ -73,9 +74,12 @@ serde = { workspace = true }
serde_json = { workspace = true }
test-log = { workspace = true }
thiserror = { workspace = true }
toml = { workspace = true }
tokio = { workspace = true }
tokio-rustls = { workspace = true }
toml = { workspace = true }
tonic = { workspace = true, features = ["transport", "zstd", "gzip"] }
hyper-util = { workspace = true }
tower = { workspace = true }
tracing = { workspace = true }

[build-dependencies]
Expand Down
66 changes: 64 additions & 2 deletions tools/restatectl/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,81 @@
use std::{
fmt::{self, Display},
str::FromStr,
sync::Arc,
};

use rustls::pki_types::ServerName;
use tokio::io;
use tokio_rustls::TlsConnector;
use tonic::transport::Channel;

use restate_cli_util::CliContext;
use restate_core::network::net_util::{DNSResolution, create_tonic_channel};
use restate_core::network::tls::TlsCertResolver;
use restate_types::{
config::FabricTlsOptions,
logs::metadata::ProviderConfiguration,
net::address::{AdvertisedAddress, GrpcPort, ListenerPort},
net::address::{AdvertisedAddress, GrpcPort, ListenerPort, PeerNetAddress},
};

pub fn grpc_channel<P: ListenerPort + GrpcPort>(address: AdvertisedAddress<P>) -> Channel {
let ctx = CliContext::get();
create_tonic_channel(address, &ctx.network, DNSResolution::Gai)

if ctx.network.tls_enabled() {
grpc_channel_with_tls(address, &ctx.network)
} else {
create_tonic_channel(address, &ctx.network, DNSResolution::Gai)
}
}

fn grpc_channel_with_tls<P: ListenerPort + GrpcPort>(
address: AdvertisedAddress<P>,
opts: &restate_cli_util::NetworkOpts,
) -> Channel {
use hyper_util::rt::TokioIo;
use tonic::transport::Endpoint;

let address = address.into_address().expect("valid address");
let PeerNetAddress::Http(uri) = &address else {
panic!("TLS is not supported for Unix domain sockets");
};

let endpoint = Endpoint::from(uri.clone())
.connect_timeout(std::time::Duration::from_millis(opts.connect_timeout))
.keep_alive_while_idle(true)
.tcp_nodelay(true);

let tls_opts = FabricTlsOptions {
mode: restate_types::config::TlsMode::Strict,
cert_file: opts.tls_cert.clone().unwrap_or_default(),
key_file: opts.tls_key.clone().unwrap_or_default(),
ca_files: opts.tls_ca.iter().cloned().collect(),
require_client_auth: false,
refresh_interval: restate_time_util::NonZeroFriendlyDuration::from_secs_unchecked(3600),
allowed_subject_names: vec![],
client: None,
};

let resolver =
TlsCertResolver::new(&tls_opts).expect("Failed to initialize TLS for restatectl");
let client_config = resolver.client_config();

let host = uri.host().unwrap_or("localhost").to_owned();
let port = uri.port_u16().unwrap_or(P::DEFAULT_PORT);

endpoint.connect_with_connector_lazy(tower::service_fn(move |_: http::Uri| {
let client_config = Arc::clone(&client_config);
let host = host.clone();
async move {
let tcp_stream = tokio::net::TcpStream::connect(format!("{host}:{port}")).await?;
tcp_stream.set_nodelay(true)?;
let connector = TlsConnector::from(client_config);
let server_name = ServerName::try_from(host)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
let tls_stream = connector.connect(server_name, tcp_stream).await?;
Ok::<_, io::Error>(TokioIo::new(tls_stream))
}
}))
}

pub fn write_default_provider<W: fmt::Write>(
Expand Down
Loading