@@ -6,7 +6,9 @@ use std::path::PathBuf;
66use std:: pin:: Pin ;
77use std:: task:: { Context , Poll } ;
88
9- use sqlx_rt:: { AsyncRead , AsyncWrite , TlsStream } ;
9+ #[ cfg( not( feature = "_tls-notls" ) ) ]
10+ use sqlx_rt:: TlsStream ;
11+ use sqlx_rt:: { AsyncRead , AsyncWrite } ;
1012
1113use crate :: error:: Error ;
1214use std:: mem:: replace;
@@ -56,6 +58,9 @@ impl std::fmt::Display for CertificateInput {
5658#[ cfg( feature = "_tls-rustls" ) ]
5759mod rustls;
5860
61+ #[ cfg( feature = "_tls-notls" ) ]
62+ pub struct MaybeTlsStream < S > ( S ) ;
63+ #[ cfg( not( feature = "_tls-notls" ) ) ]
5964pub enum MaybeTlsStream < S >
6065where
6166 S : AsyncRead + AsyncWrite + Unpin ,
@@ -69,11 +74,28 @@ impl<S> MaybeTlsStream<S>
6974where
7075 S : AsyncRead + AsyncWrite + Unpin ,
7176{
77+ #[ cfg( feature = "_tls-notls" ) ]
78+ #[ inline]
79+ pub fn is_tls ( & self ) -> bool {
80+ false
81+ }
82+ #[ cfg( not( feature = "_tls-notls" ) ) ]
7283 #[ inline]
7384 pub fn is_tls ( & self ) -> bool {
7485 matches ! ( self , Self :: Tls ( _) )
7586 }
7687
88+ #[ cfg( feature = "_tls-notls" ) ]
89+ pub async fn upgrade (
90+ & mut self ,
91+ host : & str ,
92+ accept_invalid_certs : bool ,
93+ accept_invalid_hostnames : bool ,
94+ root_cert_path : Option < & CertificateInput > ,
95+ ) -> Result < ( ) , Error > {
96+ Ok ( ( ) )
97+ }
98+ #[ cfg( not( feature = "_tls-notls" ) ) ]
7799 pub async fn upgrade (
78100 & mut self ,
79101 host : & str ,
@@ -112,6 +134,24 @@ where
112134 }
113135}
114136
137+ #[ cfg( feature = "_tls-notls" ) ]
138+ macro_rules! exec_on_stream {
139+ ( $stream: ident, $fn_name: ident, $( $arg: ident) ,* ) => (
140+ Pin :: new( & mut $stream. 0 ) . $fn_name( $( $arg, ) * )
141+ )
142+ }
143+ #[ cfg( not( feature = "_tls-notls" ) ) ]
144+ macro_rules! exec_on_stream {
145+ ( $stream: ident, $fn_name: ident, $( $arg: ident) ,* ) => (
146+ match & mut * $stream {
147+ MaybeTlsStream :: Raw ( s) => Pin :: new( s) . $fn_name( $( $arg, ) * ) ,
148+ MaybeTlsStream :: Tls ( s) => Pin :: new( s) . $fn_name( $( $arg, ) * ) ,
149+
150+ MaybeTlsStream :: Upgrading => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into( ) ) ) ,
151+ }
152+ )
153+ }
154+
115155#[ cfg( feature = "_tls-native-tls" ) ]
116156async fn configure_tls_connector (
117157 accept_invalid_certs : bool ,
@@ -155,12 +195,7 @@ where
155195 cx : & mut Context < ' _ > ,
156196 buf : & mut super :: PollReadBuf < ' _ > ,
157197 ) -> Poll < io:: Result < super :: PollReadOut > > {
158- match & mut * self {
159- MaybeTlsStream :: Raw ( s) => Pin :: new ( s) . poll_read ( cx, buf) ,
160- MaybeTlsStream :: Tls ( s) => Pin :: new ( s) . poll_read ( cx, buf) ,
161-
162- MaybeTlsStream :: Upgrading => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) ,
163- }
198+ exec_on_stream ! ( self , poll_read, cx, buf)
164199 }
165200}
166201
@@ -173,41 +208,21 @@ where
173208 cx : & mut Context < ' _ > ,
174209 buf : & [ u8 ] ,
175210 ) -> Poll < io:: Result < usize > > {
176- match & mut * self {
177- MaybeTlsStream :: Raw ( s) => Pin :: new ( s) . poll_write ( cx, buf) ,
178- MaybeTlsStream :: Tls ( s) => Pin :: new ( s) . poll_write ( cx, buf) ,
179-
180- MaybeTlsStream :: Upgrading => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) ,
181- }
211+ exec_on_stream ! ( self , poll_write, cx, buf)
182212 }
183213
184214 fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
185- match & mut * self {
186- MaybeTlsStream :: Raw ( s) => Pin :: new ( s) . poll_flush ( cx) ,
187- MaybeTlsStream :: Tls ( s) => Pin :: new ( s) . poll_flush ( cx) ,
188-
189- MaybeTlsStream :: Upgrading => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) ,
190- }
215+ exec_on_stream ! ( self , poll_flush, cx)
191216 }
192217
193218 #[ cfg( feature = "_rt-tokio" ) ]
194219 fn poll_shutdown ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
195- match & mut * self {
196- MaybeTlsStream :: Raw ( s) => Pin :: new ( s) . poll_shutdown ( cx) ,
197- MaybeTlsStream :: Tls ( s) => Pin :: new ( s) . poll_shutdown ( cx) ,
198-
199- MaybeTlsStream :: Upgrading => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) ,
200- }
220+ exec_on_stream ! ( self , poll_shutdown, cx)
201221 }
202222
203223 #[ cfg( feature = "_rt-async-std" ) ]
204224 fn poll_close ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
205- match & mut * self {
206- MaybeTlsStream :: Raw ( s) => Pin :: new ( s) . poll_close ( cx) ,
207- MaybeTlsStream :: Tls ( s) => Pin :: new ( s) . poll_close ( cx) ,
208-
209- MaybeTlsStream :: Upgrading => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) ,
210- }
225+ exec_on_stream ! ( self , poll_close, cx)
211226 }
212227}
213228
@@ -218,6 +233,11 @@ where
218233 type Target = S ;
219234
220235 fn deref ( & self ) -> & Self :: Target {
236+ #[ cfg( feature = "_tls-notls" ) ]
237+ {
238+ & self . 0
239+ }
240+ #[ cfg( not( feature = "_tls-notls" ) ) ]
221241 match self {
222242 MaybeTlsStream :: Raw ( s) => s,
223243
@@ -242,6 +262,11 @@ where
242262 S : Unpin + AsyncWrite + AsyncRead ,
243263{
244264 fn deref_mut ( & mut self ) -> & mut Self :: Target {
265+ #[ cfg( feature = "_tls-notls" ) ]
266+ {
267+ & mut self . 0
268+ }
269+ #[ cfg( not( feature = "_tls-notls" ) ) ]
245270 match self {
246271 MaybeTlsStream :: Raw ( s) => s,
247272
0 commit comments