Skip to content

Commit

Permalink
tls feature improvements (#228)
Browse files Browse the repository at this point in the history
* tls improvements

reexport tokio_rustls since TlsAcceptor type must come from same version

breaking change: when tls not enabled still require None to be passed,
thus code written without tls feature won't be broken when tls enabled

* remove Arc

* Update src/tokio/server.rs

---------

Co-authored-by: Ning Sun <[email protected]>
  • Loading branch information
serprex and sunng87 authored Dec 18, 2024
1 parent b179036 commit 5374828
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 18 deletions.
2 changes: 1 addition & 1 deletion examples/scram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ pub async fn main() {
});

let server_addr = "127.0.0.1:5432";
let tls_acceptor = Arc::new(setup_tls().unwrap());
let tls_acceptor = setup_tls().unwrap();
let listener = TcpListener::bind(server_addr).await.unwrap();
println!("Listening to {}", server_addr);
loop {
Expand Down
2 changes: 1 addition & 1 deletion examples/secure_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ pub async fn main() {
});

let server_addr = "127.0.0.1:5433";
let tls_acceptor = Arc::new(setup_tls().unwrap());
let tls_acceptor = setup_tls().unwrap();
let listener = TcpListener::bind(server_addr).await.unwrap();

println!("Listening to {}", server_addr);
Expand Down
7 changes: 7 additions & 0 deletions src/tokio/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
mod server;

pub use server::process_socket;
#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
pub use tokio_rustls;
#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
pub type TlsAcceptor = tokio_rustls::TlsAcceptor;

#[cfg(not(any(feature = "_ring", feature = "_aws-lc-rs")))]
pub enum TlsAcceptor {}
27 changes: 12 additions & 15 deletions src/tokio/server.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::io::Error as IOError;
use std::io;
use std::sync::Arc;

use bytes::Buf;
use futures::{SinkExt, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
use tokio_rustls::{server::TlsStream, TlsAcceptor};
use tokio_rustls::server::TlsStream;
use tokio_util::codec::{Decoder, Encoder, Framed};

use crate::api::auth::StartupHandler;
Expand Down Expand Up @@ -66,7 +66,7 @@ impl<S> Decoder for PgWireMessageServerCodec<S> {
}

impl<S> Encoder<PgWireBackendMessage> for PgWireMessageServerCodec<S> {
type Error = IOError;
type Error = io::Error;

fn encode(
&mut self,
Expand Down Expand Up @@ -237,7 +237,7 @@ async fn process_error<S, ST>(
socket: &mut Framed<S, PgWireMessageServerCodec<ST>>,
error: PgWireError,
wait_for_sync: bool,
) -> Result<(), IOError>
) -> Result<(), io::Error>
where
S: AsyncRead + AsyncWrite + Unpin + Send + Sync,
{
Expand Down Expand Up @@ -289,7 +289,7 @@ enum SslNegotiationType {
None,
}

async fn check_ssl_direct_negotiation(tcp_socket: &TcpStream) -> Result<bool, IOError> {
async fn check_ssl_direct_negotiation(tcp_socket: &TcpStream) -> Result<bool, io::Error> {
let mut buf = [0u8; 1];
let n = tcp_socket.peek(&mut buf).await?;

Expand All @@ -299,7 +299,7 @@ async fn check_ssl_direct_negotiation(tcp_socket: &TcpStream) -> Result<bool, IO
async fn peek_for_sslrequest<ST>(
socket: &mut Framed<TcpStream, PgWireMessageServerCodec<ST>>,
ssl_supported: bool,
) -> Result<SslNegotiationType, IOError> {
) -> Result<SslNegotiationType, io::Error> {
if check_ssl_direct_negotiation(socket.get_ref()).await? {
Ok(SslNegotiationType::Direct)
} else if let Some(Ok(PgWireFrontendMessage::SslRequest(Some(_)))) = socket.next().await {
Expand All @@ -326,7 +326,7 @@ async fn do_process_socket<S, A, Q, EQ, C, E>(
extended_query_handler: Arc<EQ>,
copy_handler: Arc<C>,
error_handler: Arc<E>,
) -> Result<(), IOError>
) -> Result<(), io::Error>
where
S: AsyncRead + AsyncWrite + Unpin + Send + Sync,
A: StartupHandler,
Expand Down Expand Up @@ -359,7 +359,7 @@ where
}

#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
fn check_alpn_for_direct_ssl<IO>(tls_socket: &TlsStream<IO>) -> Result<(), IOError> {
fn check_alpn_for_direct_ssl<IO>(tls_socket: &TlsStream<IO>) -> Result<(), io::Error> {
let (_, the_conn) = tls_socket.get_ref();
let mut accept = false;

Expand All @@ -370,8 +370,8 @@ fn check_alpn_for_direct_ssl<IO>(tls_socket: &TlsStream<IO>) -> Result<(), IOErr
}

if !accept {
Err(IOError::new(
std::io::ErrorKind::InvalidData,
Err(io::Error::new(
io::ErrorKind::InvalidData,
"received direct SSL connection request without ALPN protocol negotiation extension",
))
} else {
Expand All @@ -381,9 +381,9 @@ fn check_alpn_for_direct_ssl<IO>(tls_socket: &TlsStream<IO>) -> Result<(), IOErr

pub async fn process_socket<H>(
tcp_socket: TcpStream,
#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))] tls_acceptor: Option<Arc<TlsAcceptor>>,
tls_acceptor: Option<crate::tokio::TlsAcceptor>,
handlers: H,
) -> Result<(), IOError>
) -> Result<(), io::Error>
where
H: PgWireServerHandlers,
{
Expand All @@ -393,10 +393,7 @@ where
let client_info = DefaultClient::new(addr, false);
let mut tcp_socket = Framed::new(tcp_socket, PgWireMessageServerCodec::new(client_info));

#[cfg(any(feature = "_ring", feature = "_aws-lc-rs"))]
let ssl = peek_for_sslrequest(&mut tcp_socket, tls_acceptor.is_some()).await?;
#[cfg(not(any(feature = "_ring", feature = "_aws-lc-rs")))]
let ssl = peek_for_sslrequest(&mut tcp_socket, false).await?;

let startup_handler = handlers.startup_handler();
let simple_query_handler = handlers.simple_query_handler();
Expand Down
2 changes: 1 addition & 1 deletion tests-integration/test-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ pub async fn main() {
let factory = Arc::new(DummyDatabaseFactory(Arc::new(DummyDatabase::default())));

let server_addr = "127.0.0.1:5432";
let tls_acceptor = Arc::new(setup_tls().unwrap());
let tls_acceptor = setup_tls().unwrap();
let listener = TcpListener::bind(server_addr).await.unwrap();
println!("Listening to {}", server_addr);
loop {
Expand Down

0 comments on commit 5374828

Please sign in to comment.