diff --git a/Cargo.toml b/Cargo.toml index 5b55757f1d..09ff762d11 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -100,6 +100,18 @@ runtime-tokio-rustls = [ "_rt-tokio", ] +runtime-actix-notls = ["runtime-tokio-notls"] +runtime-async-std-notls = [ + "sqlx-core/runtime-async-std-notls", + "sqlx-macros/runtime-async-std-notls", + "_rt-async-std", +] +runtime-tokio-notls = [ + "sqlx-core/runtime-tokio-notls", + "sqlx-macros/runtime-tokio-notls", + "_rt-tokio", +] + # for conditional compilation _rt-async-std = [] _rt-tokio = [] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 70f93ef4ae..d2a6b413cb 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -93,10 +93,25 @@ runtime-tokio-rustls = [ "_rt-tokio" ] +runtime-actix-notls = ['runtime-tokio-notls'] +runtime-async-std-notls = [ + "sqlx-rt/runtime-async-std-notls", + "sqlx/runtime-async-std-notls", + "_tls-notls", + "_rt-async-std", +] +runtime-tokio-notls = [ + "sqlx-rt/runtime-tokio-notls", + "sqlx/runtime-tokio-notls", + "_tls-notls", + "_rt-tokio" +] + # for conditional compilation _rt-async-std = [] _rt-tokio = ["tokio-stream"] _tls-native-tls = [] +_tls-notls = [] _tls-rustls = ["rustls", "rustls-pemfile", "webpki-roots"] # support offline/decoupled building (enables serialization of `Describe`) diff --git a/sqlx-core/src/net/tls/mod.rs b/sqlx-core/src/net/tls/mod.rs index 85e5dda7c1..f73e70a1e1 100644 --- a/sqlx-core/src/net/tls/mod.rs +++ b/sqlx-core/src/net/tls/mod.rs @@ -6,7 +6,9 @@ use std::path::PathBuf; use std::pin::Pin; use std::task::{Context, Poll}; -use sqlx_rt::{AsyncRead, AsyncWrite, TlsStream}; +#[cfg(not(feature = "_tls-notls"))] +use sqlx_rt::TlsStream; +use sqlx_rt::{AsyncRead, AsyncWrite}; use crate::error::Error; use std::mem::replace; @@ -56,6 +58,9 @@ impl std::fmt::Display for CertificateInput { #[cfg(feature = "_tls-rustls")] mod rustls; +#[cfg(feature = "_tls-notls")] +pub struct MaybeTlsStream(S); +#[cfg(not(feature = "_tls-notls"))] pub enum MaybeTlsStream where S: AsyncRead + AsyncWrite + Unpin, @@ -69,11 +74,28 @@ impl MaybeTlsStream where S: AsyncRead + AsyncWrite + Unpin, { + #[cfg(feature = "_tls-notls")] + #[inline] + pub fn is_tls(&self) -> bool { + false + } + #[cfg(not(feature = "_tls-notls"))] #[inline] pub fn is_tls(&self) -> bool { matches!(self, Self::Tls(_)) } + #[cfg(feature = "_tls-notls")] + pub async fn upgrade( + &mut self, + host: &str, + accept_invalid_certs: bool, + accept_invalid_hostnames: bool, + root_cert_path: Option<&CertificateInput>, + ) -> Result<(), Error> { + Ok(()) + } + #[cfg(not(feature = "_tls-notls"))] pub async fn upgrade( &mut self, host: &str, @@ -112,6 +134,24 @@ where } } +#[cfg(feature = "_tls-notls")] +macro_rules! exec_on_stream { + ($stream:ident, $fn_name:ident, $($arg:ident),*) => ( + Pin::new(&mut $stream.0).$fn_name($($arg,)*) + ) +} +#[cfg(not(feature = "_tls-notls"))] +macro_rules! exec_on_stream { + ($stream:ident, $fn_name:ident, $($arg:ident),*) => ( + match &mut *$stream { + MaybeTlsStream::Raw(s) => Pin::new(s).$fn_name($($arg,)*), + MaybeTlsStream::Tls(s) => Pin::new(s).$fn_name($($arg,)*), + + MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), + } + ) +} + #[cfg(feature = "_tls-native-tls")] async fn configure_tls_connector( accept_invalid_certs: bool, @@ -155,12 +195,7 @@ where cx: &mut Context<'_>, buf: &mut super::PollReadBuf<'_>, ) -> Poll> { - match &mut *self { - MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf), - MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf), - - MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), - } + exec_on_stream!(self, poll_read, cx, buf) } } @@ -173,41 +208,21 @@ where cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - match &mut *self { - MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf), - MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf), - - MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), - } + exec_on_stream!(self, poll_write, cx, buf) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx), - MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx), - - MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), - } + exec_on_stream!(self, poll_flush, cx) } #[cfg(feature = "_rt-tokio")] fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx), - MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx), - - MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), - } + exec_on_stream!(self, poll_shutdown, cx) } #[cfg(feature = "_rt-async-std")] fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - MaybeTlsStream::Raw(s) => Pin::new(s).poll_close(cx), - MaybeTlsStream::Tls(s) => Pin::new(s).poll_close(cx), - - MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())), - } + exec_on_stream!(self, poll_close, cx) } } @@ -218,6 +233,11 @@ where type Target = S; fn deref(&self) -> &Self::Target { + #[cfg(feature = "_tls-notls")] + { + &self.0 + } + #[cfg(not(feature = "_tls-notls"))] match self { MaybeTlsStream::Raw(s) => s, @@ -242,6 +262,11 @@ where S: Unpin + AsyncWrite + AsyncRead, { fn deref_mut(&mut self) -> &mut Self::Target { + #[cfg(feature = "_tls-notls")] + { + &mut self.0 + } + #[cfg(not(feature = "_tls-notls"))] match self { MaybeTlsStream::Raw(s) => s, diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 14ecbc7735..8175ffb70d 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -44,6 +44,18 @@ runtime-tokio-rustls = [ "_rt-tokio", ] +runtime-actix-notls = ["runtime-tokio-notls"] +runtime-async-std-notls = [ + "sqlx-core/runtime-async-std-notls", + "sqlx-rt/runtime-async-std-notls", + "_rt-async-std", +] +runtime-tokio-notls = [ + "sqlx-core/runtime-tokio-notls", + "sqlx-rt/runtime-tokio-notls", + "_rt-tokio", +] + # for conditional compilation _rt-async-std = [] _rt-tokio = [] diff --git a/sqlx-rt/Cargo.toml b/sqlx-rt/Cargo.toml index 5db022a190..d8d1b65c7b 100644 --- a/sqlx-rt/Cargo.toml +++ b/sqlx-rt/Cargo.toml @@ -23,10 +23,15 @@ runtime-actix-rustls = ["runtime-tokio-rustls"] runtime-async-std-rustls = ["_rt-async-std", "_tls-rustls", "futures-rustls"] runtime-tokio-rustls = ["_rt-tokio", "_tls-rustls", "tokio-rustls"] +runtime-actix-notls = ["runtime-tokio-notls"] +runtime-async-std-notls = ["_rt-async-std", "_tls-notls"] +runtime-tokio-notls = ["_rt-tokio", "_tls-notls"] + # Not used directly and not re-exported from sqlx _rt-async-std = ["async-std"] _rt-tokio = ["tokio", "once_cell"] _tls-native-tls = ["native-tls"] +_tls-notls = [] _tls-rustls = [] [dependencies] diff --git a/sqlx-rt/src/lib.rs b/sqlx-rt/src/lib.rs index a0aac5b8ea..a977cfd6d4 100644 --- a/sqlx-rt/src/lib.rs +++ b/sqlx-rt/src/lib.rs @@ -7,11 +7,15 @@ feature = "runtime-actix-rustls", feature = "runtime-async-std-rustls", feature = "runtime-tokio-rustls", + feature = "runtime-actix-notls", + feature = "runtime-async-std-notls", + feature = "runtime-tokio-notls", )))] compile_error!( "one of the features ['runtime-actix-native-tls', 'runtime-async-std-native-tls', \ 'runtime-tokio-native-tls', 'runtime-actix-rustls', 'runtime-async-std-rustls', \ - 'runtime-tokio-rustls'] must be enabled" + 'runtime-tokio-rustls', 'runtime-actix-notls', 'runtime-async-std-notls', \ + 'runtime-tokio-notls'] must be enabled" ); #[cfg(any( @@ -19,11 +23,14 @@ compile_error!( all(feature = "_rt-actix", feature = "_rt-tokio"), all(feature = "_rt-async-std", feature = "_rt-tokio"), all(feature = "_tls-native-tls", feature = "_tls-rustls"), + all(feature = "_tls-native-tls", feature = "_tls-notls"), + all(feature = "_tls-rustls", feature = "_tls-notls"), ))] compile_error!( "only one of ['runtime-actix-native-tls', 'runtime-async-std-native-tls', \ 'runtime-tokio-native-tls', 'runtime-actix-rustls', 'runtime-async-std-rustls', \ - 'runtime-tokio-rustls'] can be enabled" + 'runtime-tokio-rustls', 'runtime-actix-notls', 'runtime-async-std-notls', \ + 'runtime-tokio-notls'] can be enabled" ); #[cfg(feature = "_rt-async-std")]