Skip to content

Commit 83f322e

Browse files
committed
Add new features runtime-{runtime}-notls to avoid tls dependency
1 parent 76ae286 commit 83f322e

File tree

6 files changed

+109
-33
lines changed

6 files changed

+109
-33
lines changed

Cargo.toml

+12
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,18 @@ runtime-tokio-rustls = [
100100
"_rt-tokio",
101101
]
102102

103+
runtime-actix-notls = ["runtime-tokio-notls"]
104+
runtime-async-std-notls = [
105+
"sqlx-core/runtime-async-std-notls",
106+
"sqlx-macros/runtime-async-std-notls",
107+
"_rt-async-std",
108+
]
109+
runtime-tokio-notls = [
110+
"sqlx-core/runtime-tokio-notls",
111+
"sqlx-macros/runtime-tokio-notls",
112+
"_rt-tokio",
113+
]
114+
103115
# for conditional compilation
104116
_rt-async-std = []
105117
_rt-tokio = []

sqlx-core/Cargo.toml

+15
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,25 @@ runtime-tokio-rustls = [
9393
"_rt-tokio"
9494
]
9595

96+
runtime-actix-notls = ['runtime-tokio-notls']
97+
runtime-async-std-notls = [
98+
"sqlx-rt/runtime-async-std-notls",
99+
"sqlx/runtime-async-std-notls",
100+
"_tls-notls",
101+
"_rt-async-std",
102+
]
103+
runtime-tokio-notls = [
104+
"sqlx-rt/runtime-tokio-notls",
105+
"sqlx/runtime-tokio-notls",
106+
"_tls-notls",
107+
"_rt-tokio"
108+
]
109+
96110
# for conditional compilation
97111
_rt-async-std = []
98112
_rt-tokio = ["tokio-stream"]
99113
_tls-native-tls = []
114+
_tls-notls = []
100115
_tls-rustls = ["rustls", "rustls-pemfile", "webpki-roots"]
101116

102117
# support offline/decoupled building (enables serialization of `Describe`)

sqlx-core/src/net/tls/mod.rs

+56-31
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ use std::path::PathBuf;
66
use std::pin::Pin;
77
use std::task::{Context, Poll};
88

9-
use sqlx_rt::{AsyncRead, AsyncWrite, TlsStream};
9+
#[cfg(not(feature = "_tls-notls"))]
10+
use sqlx_rt::TlsStream;
11+
use sqlx_rt::{AsyncRead, AsyncWrite};
1012

1113
use crate::error::Error;
1214
use std::mem::replace;
@@ -56,6 +58,9 @@ impl std::fmt::Display for CertificateInput {
5658
#[cfg(feature = "_tls-rustls")]
5759
mod rustls;
5860

61+
#[cfg(feature = "_tls-notls")]
62+
pub struct MaybeTlsStream<S>(S);
63+
#[cfg(not(feature = "_tls-notls"))]
5964
pub enum MaybeTlsStream<S>
6065
where
6166
S: AsyncRead + AsyncWrite + Unpin,
@@ -69,11 +74,28 @@ impl<S> MaybeTlsStream<S>
6974
where
7075
S: AsyncRead + AsyncWrite + Unpin,
7176
{
77+
#[cfg(feature = "_tls-notls")]
78+
#[inline]
79+
pub fn is_tls(&self) -> bool {
80+
false
81+
}
82+
#[cfg(not(feature = "_tls-notls"))]
7283
#[inline]
7384
pub fn is_tls(&self) -> bool {
7485
matches!(self, Self::Tls(_))
7586
}
7687

88+
#[cfg(feature = "_tls-notls")]
89+
pub async fn upgrade(
90+
&mut self,
91+
host: &str,
92+
accept_invalid_certs: bool,
93+
accept_invalid_hostnames: bool,
94+
root_cert_path: Option<&CertificateInput>,
95+
) -> Result<(), Error> {
96+
Ok(())
97+
}
98+
#[cfg(not(feature = "_tls-notls"))]
7799
pub async fn upgrade(
78100
&mut self,
79101
host: &str,
@@ -112,6 +134,24 @@ where
112134
}
113135
}
114136

137+
#[cfg(feature = "_tls-notls")]
138+
macro_rules! exec_on_stream {
139+
($stream:ident, $fn_name:ident, $($arg:ident),*) => (
140+
Pin::new(&mut $stream.0).$fn_name($($arg,)*)
141+
)
142+
}
143+
#[cfg(not(feature = "_tls-notls"))]
144+
macro_rules! exec_on_stream {
145+
($stream:ident, $fn_name:ident, $($arg:ident),*) => (
146+
match &mut *$stream {
147+
MaybeTlsStream::Raw(s) => Pin::new(s).$fn_name($($arg,)*),
148+
MaybeTlsStream::Tls(s) => Pin::new(s).$fn_name($($arg,)*),
149+
150+
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
151+
}
152+
)
153+
}
154+
115155
#[cfg(feature = "_tls-native-tls")]
116156
async fn configure_tls_connector(
117157
accept_invalid_certs: bool,
@@ -155,12 +195,7 @@ where
155195
cx: &mut Context<'_>,
156196
buf: &mut super::PollReadBuf<'_>,
157197
) -> Poll<io::Result<super::PollReadOut>> {
158-
match &mut *self {
159-
MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
160-
MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
161-
162-
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
163-
}
198+
exec_on_stream!(self, poll_read, cx, buf)
164199
}
165200
}
166201

@@ -173,41 +208,21 @@ where
173208
cx: &mut Context<'_>,
174209
buf: &[u8],
175210
) -> Poll<io::Result<usize>> {
176-
match &mut *self {
177-
MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
178-
MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
179-
180-
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
181-
}
211+
exec_on_stream!(self, poll_write, cx, buf)
182212
}
183213

184214
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
185-
match &mut *self {
186-
MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
187-
MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
188-
189-
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
190-
}
215+
exec_on_stream!(self, poll_flush, cx)
191216
}
192217

193218
#[cfg(feature = "_rt-tokio")]
194219
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
195-
match &mut *self {
196-
MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
197-
MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
198-
199-
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
200-
}
220+
exec_on_stream!(self, poll_shutdown, cx)
201221
}
202222

203223
#[cfg(feature = "_rt-async-std")]
204224
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
205-
match &mut *self {
206-
MaybeTlsStream::Raw(s) => Pin::new(s).poll_close(cx),
207-
MaybeTlsStream::Tls(s) => Pin::new(s).poll_close(cx),
208-
209-
MaybeTlsStream::Upgrading => Poll::Ready(Err(io::ErrorKind::ConnectionAborted.into())),
210-
}
225+
exec_on_stream!(self, poll_close, cx)
211226
}
212227
}
213228

@@ -218,6 +233,11 @@ where
218233
type Target = S;
219234

220235
fn deref(&self) -> &Self::Target {
236+
#[cfg(feature = "_tls-notls")]
237+
{
238+
&self.0
239+
}
240+
#[cfg(not(feature = "_tls-notls"))]
221241
match self {
222242
MaybeTlsStream::Raw(s) => s,
223243

@@ -242,6 +262,11 @@ where
242262
S: Unpin + AsyncWrite + AsyncRead,
243263
{
244264
fn deref_mut(&mut self) -> &mut Self::Target {
265+
#[cfg(feature = "_tls-notls")]
266+
{
267+
&mut self.0
268+
}
269+
#[cfg(not(feature = "_tls-notls"))]
245270
match self {
246271
MaybeTlsStream::Raw(s) => s,
247272

sqlx-macros/Cargo.toml

+12
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ runtime-tokio-rustls = [
4444
"_rt-tokio",
4545
]
4646

47+
runtime-actix-notls = ["runtime-tokio-notls"]
48+
runtime-async-std-notls = [
49+
"sqlx-core/runtime-async-std-notls",
50+
"sqlx-rt/runtime-async-std-notls",
51+
"_rt-async-std",
52+
]
53+
runtime-tokio-notls = [
54+
"sqlx-core/runtime-tokio-notls",
55+
"sqlx-rt/runtime-tokio-notls",
56+
"_rt-tokio",
57+
]
58+
4759
# for conditional compilation
4860
_rt-async-std = []
4961
_rt-tokio = []

sqlx-rt/Cargo.toml

+5
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,15 @@ runtime-actix-rustls = ["runtime-tokio-rustls"]
2323
runtime-async-std-rustls = ["_rt-async-std", "_tls-rustls", "futures-rustls"]
2424
runtime-tokio-rustls = ["_rt-tokio", "_tls-rustls", "tokio-rustls"]
2525

26+
runtime-actix-notls = ["runtime-tokio-notls"]
27+
runtime-async-std-notls = ["_rt-async-std", "_tls-notls"]
28+
runtime-tokio-notls = ["_rt-tokio", "_tls-notls"]
29+
2630
# Not used directly and not re-exported from sqlx
2731
_rt-async-std = ["async-std"]
2832
_rt-tokio = ["tokio", "once_cell"]
2933
_tls-native-tls = ["native-tls"]
34+
_tls-notls = []
3035
_tls-rustls = []
3136

3237
[dependencies]

sqlx-rt/src/lib.rs

+9-2
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,30 @@
77
feature = "runtime-actix-rustls",
88
feature = "runtime-async-std-rustls",
99
feature = "runtime-tokio-rustls",
10+
feature = "runtime-actix-notls",
11+
feature = "runtime-async-std-notls",
12+
feature = "runtime-tokio-notls",
1013
)))]
1114
compile_error!(
1215
"one of the features ['runtime-actix-native-tls', 'runtime-async-std-native-tls', \
1316
'runtime-tokio-native-tls', 'runtime-actix-rustls', 'runtime-async-std-rustls', \
14-
'runtime-tokio-rustls'] must be enabled"
17+
'runtime-tokio-rustls', 'runtime-actix-notls', 'runtime-async-std-notls', \
18+
'runtime-tokio-notls'] must be enabled"
1519
);
1620

1721
#[cfg(any(
1822
all(feature = "_rt-actix", feature = "_rt-async-std"),
1923
all(feature = "_rt-actix", feature = "_rt-tokio"),
2024
all(feature = "_rt-async-std", feature = "_rt-tokio"),
2125
all(feature = "_tls-native-tls", feature = "_tls-rustls"),
26+
all(feature = "_tls-native-tls", feature = "_tls-notls"),
27+
all(feature = "_tls-rustls", feature = "_tls-notls"),
2228
))]
2329
compile_error!(
2430
"only one of ['runtime-actix-native-tls', 'runtime-async-std-native-tls', \
2531
'runtime-tokio-native-tls', 'runtime-actix-rustls', 'runtime-async-std-rustls', \
26-
'runtime-tokio-rustls'] can be enabled"
32+
'runtime-tokio-rustls', 'runtime-actix-notls', 'runtime-async-std-notls', \
33+
'runtime-tokio-notls'] can be enabled"
2734
);
2835

2936
#[cfg(feature = "_rt-async-std")]

0 commit comments

Comments
 (0)