From 5374828b0bbc8de959be5b3381b5b4a7d5c3d6f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 18 Dec 2024 03:11:04 +0000 Subject: [PATCH] tls feature improvements (#228) * tls improvements reexport tokio_rustls since TlsAcceptor type must come from same version breaking change: when tls not enabled still require None to be passed, thus code written without tls feature won't be broken when tls enabled * remove Arc * Update src/tokio/server.rs --------- Co-authored-by: Ning Sun --- examples/scram.rs | 2 +- examples/secure_server.rs | 2 +- src/tokio/mod.rs | 7 ++++++ src/tokio/server.rs | 27 ++++++++++------------- tests-integration/test-server/src/main.rs | 2 +- 5 files changed, 22 insertions(+), 18 deletions(-) diff --git a/examples/scram.rs b/examples/scram.rs index ea10950..1e47903 100644 --- a/examples/scram.rs +++ b/examples/scram.rs @@ -124,7 +124,7 @@ pub async fn main() { }); let server_addr = "127.0.0.1:5432"; - let tls_acceptor = Arc::new(setup_tls().unwrap()); + let tls_acceptor = setup_tls().unwrap(); let listener = TcpListener::bind(server_addr).await.unwrap(); println!("Listening to {}", server_addr); loop { diff --git a/examples/secure_server.rs b/examples/secure_server.rs index 07292dc..656638f 100644 --- a/examples/secure_server.rs +++ b/examples/secure_server.rs @@ -120,7 +120,7 @@ pub async fn main() { }); let server_addr = "127.0.0.1:5433"; - let tls_acceptor = Arc::new(setup_tls().unwrap()); + let tls_acceptor = setup_tls().unwrap(); let listener = TcpListener::bind(server_addr).await.unwrap(); println!("Listening to {}", server_addr); diff --git a/src/tokio/mod.rs b/src/tokio/mod.rs index 2092d7a..9f6176b 100644 --- a/src/tokio/mod.rs +++ b/src/tokio/mod.rs @@ -1,3 +1,10 @@ mod server; pub use server::process_socket; +#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] +pub use tokio_rustls; +#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] +pub type TlsAcceptor = tokio_rustls::TlsAcceptor; + +#[cfg(not(any(feature = "_ring", feature = "_aws-lc-rs")))] +pub enum TlsAcceptor {} diff --git a/src/tokio/server.rs b/src/tokio/server.rs index 078a6ec..6483068 100644 --- a/src/tokio/server.rs +++ b/src/tokio/server.rs @@ -1,4 +1,4 @@ -use std::io::Error as IOError; +use std::io; use std::sync::Arc; use bytes::Buf; @@ -6,7 +6,7 @@ use futures::{SinkExt, StreamExt}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] -use tokio_rustls::{server::TlsStream, TlsAcceptor}; +use tokio_rustls::server::TlsStream; use tokio_util::codec::{Decoder, Encoder, Framed}; use crate::api::auth::StartupHandler; @@ -66,7 +66,7 @@ impl Decoder for PgWireMessageServerCodec { } impl Encoder for PgWireMessageServerCodec { - type Error = IOError; + type Error = io::Error; fn encode( &mut self, @@ -237,7 +237,7 @@ async fn process_error( socket: &mut Framed>, error: PgWireError, wait_for_sync: bool, -) -> Result<(), IOError> +) -> Result<(), io::Error> where S: AsyncRead + AsyncWrite + Unpin + Send + Sync, { @@ -289,7 +289,7 @@ enum SslNegotiationType { None, } -async fn check_ssl_direct_negotiation(tcp_socket: &TcpStream) -> Result { +async fn check_ssl_direct_negotiation(tcp_socket: &TcpStream) -> Result { let mut buf = [0u8; 1]; let n = tcp_socket.peek(&mut buf).await?; @@ -299,7 +299,7 @@ async fn check_ssl_direct_negotiation(tcp_socket: &TcpStream) -> Result( socket: &mut Framed>, ssl_supported: bool, -) -> Result { +) -> Result { if check_ssl_direct_negotiation(socket.get_ref()).await? { Ok(SslNegotiationType::Direct) } else if let Some(Ok(PgWireFrontendMessage::SslRequest(Some(_)))) = socket.next().await { @@ -326,7 +326,7 @@ async fn do_process_socket( extended_query_handler: Arc, copy_handler: Arc, error_handler: Arc, -) -> Result<(), IOError> +) -> Result<(), io::Error> where S: AsyncRead + AsyncWrite + Unpin + Send + Sync, A: StartupHandler, @@ -359,7 +359,7 @@ where } #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] -fn check_alpn_for_direct_ssl(tls_socket: &TlsStream) -> Result<(), IOError> { +fn check_alpn_for_direct_ssl(tls_socket: &TlsStream) -> Result<(), io::Error> { let (_, the_conn) = tls_socket.get_ref(); let mut accept = false; @@ -370,8 +370,8 @@ fn check_alpn_for_direct_ssl(tls_socket: &TlsStream) -> Result<(), IOErr } if !accept { - Err(IOError::new( - std::io::ErrorKind::InvalidData, + Err(io::Error::new( + io::ErrorKind::InvalidData, "received direct SSL connection request without ALPN protocol negotiation extension", )) } else { @@ -381,9 +381,9 @@ fn check_alpn_for_direct_ssl(tls_socket: &TlsStream) -> Result<(), IOErr pub async fn process_socket( tcp_socket: TcpStream, - #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] tls_acceptor: Option>, + tls_acceptor: Option, handlers: H, -) -> Result<(), IOError> +) -> Result<(), io::Error> where H: PgWireServerHandlers, { @@ -393,10 +393,7 @@ where let client_info = DefaultClient::new(addr, false); let mut tcp_socket = Framed::new(tcp_socket, PgWireMessageServerCodec::new(client_info)); - #[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] let ssl = peek_for_sslrequest(&mut tcp_socket, tls_acceptor.is_some()).await?; - #[cfg(not(any(feature = "_ring", feature = "_aws-lc-rs")))] - let ssl = peek_for_sslrequest(&mut tcp_socket, false).await?; let startup_handler = handlers.startup_handler(); let simple_query_handler = handlers.simple_query_handler(); diff --git a/tests-integration/test-server/src/main.rs b/tests-integration/test-server/src/main.rs index 41d3f2a..c85a90a 100644 --- a/tests-integration/test-server/src/main.rs +++ b/tests-integration/test-server/src/main.rs @@ -280,7 +280,7 @@ pub async fn main() { let factory = Arc::new(DummyDatabaseFactory(Arc::new(DummyDatabase::default()))); let server_addr = "127.0.0.1:5432"; - let tls_acceptor = Arc::new(setup_tls().unwrap()); + let tls_acceptor = setup_tls().unwrap(); let listener = TcpListener::bind(server_addr).await.unwrap(); println!("Listening to {}", server_addr); loop {