Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 56 additions & 112 deletions crates/core/src/network/net_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,113 +228,51 @@ where
let graceful_shutdown = &graceful_shutdown;
let task_name = task_name.clone();

match (&tls_resolver, &tls_mode) {
// Resolve TLS handshake or pass through plaintext
let use_tls = match (&tls_resolver, &tls_mode) {
(Some(resolver), Some(TlsMode::Strict)) => {
// TLS strict: all connections must be TLS
let acceptor = resolver.tls_acceptor();
let connection = match acceptor.accept(tcp_stream).await {
Ok(tls_stream) => {
let io = TokioIo::new(tls_stream);
graceful_shutdown.watch(builder.serve_connection(io, service).into_owned())
}
Err(e) => {
debug!("TLS handshake failed: {e}");
continue;
}
};
TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move {
trace!("New TLS tcp connection accepted");
if let Err(e) = connection.await {
if let Some(hyper_error) = e.downcast_ref::<hyper::Error>() {
if hyper_error.is_incomplete_message() {
debug!("Connection closed before request completed");
}
} else {
debug!("Connection terminated due to error: {e}");
}
} else {
trace!("Connection completed cleanly");
}
Ok(())
}.instrument(socket_span))?;
Some(resolver.tls_acceptor())
}
(Some(resolver), Some(TlsMode::Optional)) => {
// TLS optional: peek first byte to detect TLS ClientHello
let tcp_stream = tcp_stream;
let mut peek_buf = [0u8; 1];
match tcp_stream.peek(&mut peek_buf).await {
Ok(1) if peek_buf[0] == 0x16 => {
// TLS ClientHello detected
let acceptor = resolver.tls_acceptor();
let connection = match acceptor.accept(tcp_stream).await {
Ok(tls_stream) => {
let io = TokioIo::new(tls_stream);
graceful_shutdown.watch(builder.serve_connection(io, service).into_owned())
}
Err(e) => {
debug!("TLS handshake failed: {e}");
continue;
}
};
TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move {
trace!("New TLS tcp connection accepted (optional mode)");
if let Err(e) = connection.await {
if let Some(hyper_error) = e.downcast_ref::<hyper::Error>() {
if hyper_error.is_incomplete_message() {
debug!("Connection closed before request completed");
}
} else {
debug!("Connection terminated due to error: {e}");
}
} else {
trace!("Connection completed cleanly");
}
Ok(())
}.instrument(socket_span))?;
}
_ => {
// Plaintext connection
let io = TokioIo::new(tcp_stream);
let connection = graceful_shutdown.watch(builder.serve_connection(io, service).into_owned());
TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move {
trace!("New plaintext tcp connection accepted (optional mode)");
if let Err(e) = connection.await {
if let Some(hyper_error) = e.downcast_ref::<hyper::Error>() {
if hyper_error.is_incomplete_message() {
debug!("Connection closed before request completed");
}
} else {
debug!("Connection terminated due to error: {e}");
}
} else {
trace!("Connection completed cleanly");
}
Ok(())
}.instrument(socket_span))?;
}
}
}
_ => {
// No TLS: plaintext (current behavior)
let io = TokioIo::new(tcp_stream);
let connection = graceful_shutdown.watch(builder
.serve_connection(io, service).into_owned());
TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move {
trace!("New tcp connection accepted");
if let Err(e) = connection.await {
if let Some(hyper_error) = e.downcast_ref::<hyper::Error>() {
if hyper_error.is_incomplete_message() {
debug!("Connection closed before request completed");
}
} else {
debug!("Connection terminated due to error: {e}");
}
if let Ok(1) = tcp_stream.peek(&mut peek_buf).await {
if peek_buf[0] == 0x16 {
Some(resolver.tls_acceptor())
} else {
trace!("Connection completed cleanly");
None
}
Ok(())
}.instrument(socket_span))?;
} else {
None
}
}
_ => None,
};

if let Some(acceptor) = use_tls {
let connection = match acceptor.accept(tcp_stream).await {
Ok(tls_stream) => {
let io = TokioIo::new(tls_stream);
graceful_shutdown.watch(
builder.serve_connection(io, service).into_owned(),
)
}
Err(e) => {
debug!("TLS handshake failed: {e}");
continue;
}
};
TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move {
trace!("New TLS tcp connection accepted");
serve_connection(connection).await
}.instrument(socket_span))?;
} else {
let io = TokioIo::new(tcp_stream);
let connection = graceful_shutdown
.watch(builder.serve_connection(io, service).into_owned());
TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move {
trace!("New tcp connection accepted");
serve_connection(connection).await
}.instrument(socket_span))?;
}
},
Either::Right(unix_stream) => {
Expand All @@ -344,18 +282,7 @@ where
.serve_connection(io, service.clone()).into_owned());
TaskCenter::spawn(TaskKind::SocketHandler, task_name.clone(), async move {
trace!("New uds connection accepted");
if let Err(e) = connection.await {
if let Some(hyper_error) = e.downcast_ref::<hyper::Error>() {
if hyper_error.is_incomplete_message() {
debug!("Connection closed before request completed");
}
} else {
debug!("Connection terminated due to error: {e}");
}
} else {
trace!("Connection completed cleanly");
}
Ok(())
serve_connection(connection).await
}.instrument(socket_span))?;
}
}
Expand All @@ -377,6 +304,23 @@ where
Ok(())
}

async fn serve_connection(
connection: impl Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>>,
) -> Result<(), anyhow::Error> {
if let Err(e) = connection.await {
if let Some(hyper_error) = e.downcast_ref::<hyper::Error>() {
if hyper_error.is_incomplete_message() {
debug!("Connection closed before request completed");
}
} else {
debug!("Connection terminated due to error: {e}");
}
} else {
trace!("Connection completed cleanly");
}
Ok(())
}

#[derive(Clone, Default)]
struct TaskCenterExecutor;

Expand Down
Loading