@@ -6,7 +6,9 @@ use std::path::PathBuf;
6
6
use std:: pin:: Pin ;
7
7
use std:: task:: { Context , Poll } ;
8
8
9
- use sqlx_rt:: { AsyncRead , AsyncWrite , TlsStream } ;
9
+ #[ cfg( not( feature = "_tls-notls" ) ) ]
10
+ use sqlx_rt:: TlsStream ;
11
+ use sqlx_rt:: { AsyncRead , AsyncWrite } ;
10
12
11
13
use crate :: error:: Error ;
12
14
use std:: mem:: replace;
@@ -56,6 +58,9 @@ impl std::fmt::Display for CertificateInput {
56
58
#[ cfg( feature = "_tls-rustls" ) ]
57
59
mod rustls;
58
60
61
+ #[ cfg( feature = "_tls-notls" ) ]
62
+ pub struct MaybeTlsStream < S > ( S ) ;
63
+ #[ cfg( not( feature = "_tls-notls" ) ) ]
59
64
pub enum MaybeTlsStream < S >
60
65
where
61
66
S : AsyncRead + AsyncWrite + Unpin ,
@@ -69,11 +74,28 @@ impl<S> MaybeTlsStream<S>
69
74
where
70
75
S : AsyncRead + AsyncWrite + Unpin ,
71
76
{
77
+ #[ cfg( feature = "_tls-notls" ) ]
78
+ #[ inline]
79
+ pub fn is_tls ( & self ) -> bool {
80
+ false
81
+ }
82
+ #[ cfg( not( feature = "_tls-notls" ) ) ]
72
83
#[ inline]
73
84
pub fn is_tls ( & self ) -> bool {
74
85
matches ! ( self , Self :: Tls ( _) )
75
86
}
76
87
88
+ #[ cfg( feature = "_tls-notls" ) ]
89
+ pub async fn upgrade (
90
+ & mut self ,
91
+ host : & str ,
92
+ accept_invalid_certs : bool ,
93
+ accept_invalid_hostnames : bool ,
94
+ root_cert_path : Option < & CertificateInput > ,
95
+ ) -> Result < ( ) , Error > {
96
+ Ok ( ( ) )
97
+ }
98
+ #[ cfg( not( feature = "_tls-notls" ) ) ]
77
99
pub async fn upgrade (
78
100
& mut self ,
79
101
host : & str ,
@@ -112,6 +134,24 @@ where
112
134
}
113
135
}
114
136
137
+ #[ cfg( feature = "_tls-notls" ) ]
138
+ macro_rules! exec_on_stream {
139
+ ( $stream: ident, $fn_name: ident, $( $arg: ident) ,* ) => (
140
+ Pin :: new( & mut $stream. 0 ) . $fn_name( $( $arg, ) * )
141
+ )
142
+ }
143
+ #[ cfg( not( feature = "_tls-notls" ) ) ]
144
+ macro_rules! exec_on_stream {
145
+ ( $stream: ident, $fn_name: ident, $( $arg: ident) ,* ) => (
146
+ match & mut * $stream {
147
+ MaybeTlsStream :: Raw ( s) => Pin :: new( s) . $fn_name( $( $arg, ) * ) ,
148
+ MaybeTlsStream :: Tls ( s) => Pin :: new( s) . $fn_name( $( $arg, ) * ) ,
149
+
150
+ MaybeTlsStream :: Upgrading => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into( ) ) ) ,
151
+ }
152
+ )
153
+ }
154
+
115
155
#[ cfg( feature = "_tls-native-tls" ) ]
116
156
async fn configure_tls_connector (
117
157
accept_invalid_certs : bool ,
@@ -155,12 +195,7 @@ where
155
195
cx : & mut Context < ' _ > ,
156
196
buf : & mut super :: PollReadBuf < ' _ > ,
157
197
) -> Poll < io:: Result < super :: PollReadOut > > {
158
- match & mut * self {
159
- MaybeTlsStream :: Raw ( s) => Pin :: new ( s) . poll_read ( cx, buf) ,
160
- MaybeTlsStream :: Tls ( s) => Pin :: new ( s) . poll_read ( cx, buf) ,
161
-
162
- MaybeTlsStream :: Upgrading => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) ,
163
- }
198
+ exec_on_stream ! ( self , poll_read, cx, buf)
164
199
}
165
200
}
166
201
@@ -173,41 +208,21 @@ where
173
208
cx : & mut Context < ' _ > ,
174
209
buf : & [ u8 ] ,
175
210
) -> Poll < io:: Result < usize > > {
176
- match & mut * self {
177
- MaybeTlsStream :: Raw ( s) => Pin :: new ( s) . poll_write ( cx, buf) ,
178
- MaybeTlsStream :: Tls ( s) => Pin :: new ( s) . poll_write ( cx, buf) ,
179
-
180
- MaybeTlsStream :: Upgrading => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) ,
181
- }
211
+ exec_on_stream ! ( self , poll_write, cx, buf)
182
212
}
183
213
184
214
fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
185
- match & mut * self {
186
- MaybeTlsStream :: Raw ( s) => Pin :: new ( s) . poll_flush ( cx) ,
187
- MaybeTlsStream :: Tls ( s) => Pin :: new ( s) . poll_flush ( cx) ,
188
-
189
- MaybeTlsStream :: Upgrading => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) ,
190
- }
215
+ exec_on_stream ! ( self , poll_flush, cx)
191
216
}
192
217
193
218
#[ cfg( feature = "_rt-tokio" ) ]
194
219
fn poll_shutdown ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
195
- match & mut * self {
196
- MaybeTlsStream :: Raw ( s) => Pin :: new ( s) . poll_shutdown ( cx) ,
197
- MaybeTlsStream :: Tls ( s) => Pin :: new ( s) . poll_shutdown ( cx) ,
198
-
199
- MaybeTlsStream :: Upgrading => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) ,
200
- }
220
+ exec_on_stream ! ( self , poll_shutdown, cx)
201
221
}
202
222
203
223
#[ cfg( feature = "_rt-async-std" ) ]
204
224
fn poll_close ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
205
- match & mut * self {
206
- MaybeTlsStream :: Raw ( s) => Pin :: new ( s) . poll_close ( cx) ,
207
- MaybeTlsStream :: Tls ( s) => Pin :: new ( s) . poll_close ( cx) ,
208
-
209
- MaybeTlsStream :: Upgrading => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) ,
210
- }
225
+ exec_on_stream ! ( self , poll_close, cx)
211
226
}
212
227
}
213
228
@@ -218,6 +233,11 @@ where
218
233
type Target = S ;
219
234
220
235
fn deref ( & self ) -> & Self :: Target {
236
+ #[ cfg( feature = "_tls-notls" ) ]
237
+ {
238
+ & self . 0
239
+ }
240
+ #[ cfg( not( feature = "_tls-notls" ) ) ]
221
241
match self {
222
242
MaybeTlsStream :: Raw ( s) => s,
223
243
@@ -242,6 +262,11 @@ where
242
262
S : Unpin + AsyncWrite + AsyncRead ,
243
263
{
244
264
fn deref_mut ( & mut self ) -> & mut Self :: Target {
265
+ #[ cfg( feature = "_tls-notls" ) ]
266
+ {
267
+ & mut self . 0
268
+ }
269
+ #[ cfg( not( feature = "_tls-notls" ) ) ]
245
270
match self {
246
271
MaybeTlsStream :: Raw ( s) => s,
247
272
0 commit comments