diff --git a/crates/core/src/network/net_util.rs b/crates/core/src/network/net_util.rs index 3a77fb4bcf..47663b7faf 100644 --- a/crates/core/src/network/net_util.rs +++ b/crates/core/src/network/net_util.rs @@ -228,113 +228,51 @@ where let graceful_shutdown = &graceful_shutdown; let task_name = task_name.clone(); - match (&tls_resolver, &tls_mode) { + // Resolve TLS handshake or pass through plaintext + let use_tls = match (&tls_resolver, &tls_mode) { (Some(resolver), Some(TlsMode::Strict)) => { - // TLS strict: all connections must be TLS - let acceptor = resolver.tls_acceptor(); - let connection = match acceptor.accept(tcp_stream).await { - Ok(tls_stream) => { - let io = TokioIo::new(tls_stream); - graceful_shutdown.watch(builder.serve_connection(io, service).into_owned()) - } - Err(e) => { - debug!("TLS handshake failed: {e}"); - continue; - } - }; - TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { - trace!("New TLS tcp connection accepted"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - debug!("Connection closed before request completed"); - } - } else { - debug!("Connection terminated due to error: {e}"); - } - } else { - trace!("Connection completed cleanly"); - } - Ok(()) - }.instrument(socket_span))?; + Some(resolver.tls_acceptor()) } (Some(resolver), Some(TlsMode::Optional)) => { - // TLS optional: peek first byte to detect TLS ClientHello - let tcp_stream = tcp_stream; let mut peek_buf = [0u8; 1]; - match tcp_stream.peek(&mut peek_buf).await { - Ok(1) if peek_buf[0] == 0x16 => { - // TLS ClientHello detected - let acceptor = resolver.tls_acceptor(); - let connection = match acceptor.accept(tcp_stream).await { - Ok(tls_stream) => { - let io = TokioIo::new(tls_stream); - graceful_shutdown.watch(builder.serve_connection(io, service).into_owned()) - } - Err(e) => { - debug!("TLS handshake failed: {e}"); - continue; - } - }; - TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { - trace!("New TLS tcp connection accepted (optional mode)"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - debug!("Connection closed before request completed"); - } - } else { - debug!("Connection terminated due to error: {e}"); - } - } else { - trace!("Connection completed cleanly"); - } - Ok(()) - }.instrument(socket_span))?; - } - _ => { - // Plaintext connection - let io = TokioIo::new(tcp_stream); - let connection = graceful_shutdown.watch(builder.serve_connection(io, service).into_owned()); - TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { - trace!("New plaintext tcp connection accepted (optional mode)"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - debug!("Connection closed before request completed"); - } - } else { - debug!("Connection terminated due to error: {e}"); - } - } else { - trace!("Connection completed cleanly"); - } - Ok(()) - }.instrument(socket_span))?; - } - } - } - _ => { - // No TLS: plaintext (current behavior) - let io = TokioIo::new(tcp_stream); - let connection = graceful_shutdown.watch(builder - .serve_connection(io, service).into_owned()); - TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { - trace!("New tcp connection accepted"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - debug!("Connection closed before request completed"); - } - } else { - debug!("Connection terminated due to error: {e}"); - } + if let Ok(1) = tcp_stream.peek(&mut peek_buf).await { + if peek_buf[0] == 0x16 { + Some(resolver.tls_acceptor()) } else { - trace!("Connection completed cleanly"); + None } - Ok(()) - }.instrument(socket_span))?; + } else { + None + } } + _ => None, + }; + + if let Some(acceptor) = use_tls { + let connection = match acceptor.accept(tcp_stream).await { + Ok(tls_stream) => { + let io = TokioIo::new(tls_stream); + graceful_shutdown.watch( + builder.serve_connection(io, service).into_owned(), + ) + } + Err(e) => { + debug!("TLS handshake failed: {e}"); + continue; + } + }; + TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { + trace!("New TLS tcp connection accepted"); + serve_connection(connection).await + }.instrument(socket_span))?; + } else { + let io = TokioIo::new(tcp_stream); + let connection = graceful_shutdown + .watch(builder.serve_connection(io, service).into_owned()); + TaskCenter::spawn(TaskKind::SocketHandler, task_name, async move { + trace!("New tcp connection accepted"); + serve_connection(connection).await + }.instrument(socket_span))?; } }, Either::Right(unix_stream) => { @@ -344,18 +282,7 @@ where .serve_connection(io, service.clone()).into_owned()); TaskCenter::spawn(TaskKind::SocketHandler, task_name.clone(), async move { trace!("New uds connection accepted"); - if let Err(e) = connection.await { - if let Some(hyper_error) = e.downcast_ref::() { - if hyper_error.is_incomplete_message() { - debug!("Connection closed before request completed"); - } - } else { - debug!("Connection terminated due to error: {e}"); - } - } else { - trace!("Connection completed cleanly"); - } - Ok(()) + serve_connection(connection).await }.instrument(socket_span))?; } } @@ -377,6 +304,23 @@ where Ok(()) } +async fn serve_connection( + connection: impl Future>>, +) -> Result<(), anyhow::Error> { + if let Err(e) = connection.await { + if let Some(hyper_error) = e.downcast_ref::() { + if hyper_error.is_incomplete_message() { + debug!("Connection closed before request completed"); + } + } else { + debug!("Connection terminated due to error: {e}"); + } + } else { + trace!("Connection completed cleanly"); + } + Ok(()) +} + #[derive(Clone, Default)] struct TaskCenterExecutor;