@@ -6,7 +6,9 @@ use std::path::PathBuf;
66use  std:: pin:: Pin ; 
77use  std:: task:: { Context ,  Poll } ; 
88
9- use  sqlx_rt:: { AsyncRead ,  AsyncWrite ,  TlsStream } ; 
9+ #[ cfg( not( feature = "_tls-notls" ) ) ]  
10+ use  sqlx_rt:: TlsStream ; 
11+ use  sqlx_rt:: { AsyncRead ,  AsyncWrite } ; 
1012
1113use  crate :: error:: Error ; 
1214use  std:: mem:: replace; 
@@ -56,6 +58,9 @@ impl std::fmt::Display for CertificateInput {
5658#[ cfg( feature = "_tls-rustls" ) ]  
5759mod  rustls; 
5860
61+ #[ cfg( feature = "_tls-notls" ) ]  
62+ pub  struct  MaybeTlsStream < S > ( S ) ; 
63+ #[ cfg( not( feature = "_tls-notls" ) ) ]  
5964pub  enum  MaybeTlsStream < S > 
6065where 
6166    S :  AsyncRead  + AsyncWrite  + Unpin , 
@@ -69,11 +74,28 @@ impl<S> MaybeTlsStream<S>
6974where 
7075    S :  AsyncRead  + AsyncWrite  + Unpin , 
7176{ 
77+     #[ cfg( feature = "_tls-notls" ) ]  
78+     #[ inline]  
79+     pub  fn  is_tls ( & self )  -> bool  { 
80+         false 
81+     } 
82+     #[ cfg( not( feature = "_tls-notls" ) ) ]  
7283    #[ inline]  
7384    pub  fn  is_tls ( & self )  -> bool  { 
7485        matches ! ( self ,  Self :: Tls ( _) ) 
7586    } 
7687
88+     #[ cfg( feature = "_tls-notls" ) ]  
89+     pub  async  fn  upgrade ( 
90+         & mut  self , 
91+         host :  & str , 
92+         accept_invalid_certs :  bool , 
93+         accept_invalid_hostnames :  bool , 
94+         root_cert_path :  Option < & CertificateInput > , 
95+     )  -> Result < ( ) ,  Error >  { 
96+         Ok ( ( ) ) 
97+     } 
98+     #[ cfg( not( feature = "_tls-notls" ) ) ]  
7799    pub  async  fn  upgrade ( 
78100        & mut  self , 
79101        host :  & str , 
@@ -112,6 +134,24 @@ where
112134    } 
113135} 
114136
137+ #[ cfg( feature = "_tls-notls" ) ]  
138+ macro_rules!  exec_on_stream { 
139+     ( $stream: ident,  $fn_name: ident,  $( $arg: ident) ,* )  => ( 
140+         Pin :: new( & mut  $stream. 0 ) . $fn_name( $( $arg, ) * ) 
141+     ) 
142+ } 
143+ #[ cfg( not( feature = "_tls-notls" ) ) ]  
144+ macro_rules!  exec_on_stream { 
145+     ( $stream: ident,  $fn_name: ident,  $( $arg: ident) ,* )  => ( 
146+         match  & mut  * $stream { 
147+             MaybeTlsStream :: Raw ( s)  => Pin :: new( s) . $fn_name( $( $arg, ) * ) , 
148+             MaybeTlsStream :: Tls ( s)  => Pin :: new( s) . $fn_name( $( $arg, ) * ) , 
149+ 
150+             MaybeTlsStream :: Upgrading  => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into( ) ) ) , 
151+         } 
152+     ) 
153+ } 
154+ 
115155#[ cfg( feature = "_tls-native-tls" ) ]  
116156async  fn  configure_tls_connector ( 
117157    accept_invalid_certs :  bool , 
@@ -155,12 +195,7 @@ where
155195        cx :  & mut  Context < ' _ > , 
156196        buf :  & mut  super :: PollReadBuf < ' _ > , 
157197    )  -> Poll < io:: Result < super :: PollReadOut > >  { 
158-         match  & mut  * self  { 
159-             MaybeTlsStream :: Raw ( s)  => Pin :: new ( s) . poll_read ( cx,  buf) , 
160-             MaybeTlsStream :: Tls ( s)  => Pin :: new ( s) . poll_read ( cx,  buf) , 
161- 
162-             MaybeTlsStream :: Upgrading  => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) , 
163-         } 
198+         exec_on_stream ! ( self ,  poll_read,  cx,  buf) 
164199    } 
165200} 
166201
@@ -173,41 +208,21 @@ where
173208        cx :  & mut  Context < ' _ > , 
174209        buf :  & [ u8 ] , 
175210    )  -> Poll < io:: Result < usize > >  { 
176-         match  & mut  * self  { 
177-             MaybeTlsStream :: Raw ( s)  => Pin :: new ( s) . poll_write ( cx,  buf) , 
178-             MaybeTlsStream :: Tls ( s)  => Pin :: new ( s) . poll_write ( cx,  buf) , 
179- 
180-             MaybeTlsStream :: Upgrading  => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) , 
181-         } 
211+         exec_on_stream ! ( self ,  poll_write,  cx,  buf) 
182212    } 
183213
184214    fn  poll_flush ( mut  self :  Pin < & mut  Self > ,  cx :  & mut  Context < ' _ > )  -> Poll < io:: Result < ( ) > >  { 
185-         match  & mut  * self  { 
186-             MaybeTlsStream :: Raw ( s)  => Pin :: new ( s) . poll_flush ( cx) , 
187-             MaybeTlsStream :: Tls ( s)  => Pin :: new ( s) . poll_flush ( cx) , 
188- 
189-             MaybeTlsStream :: Upgrading  => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) , 
190-         } 
215+         exec_on_stream ! ( self ,  poll_flush,  cx) 
191216    } 
192217
193218    #[ cfg( feature = "_rt-tokio" ) ]  
194219    fn  poll_shutdown ( mut  self :  Pin < & mut  Self > ,  cx :  & mut  Context < ' _ > )  -> Poll < io:: Result < ( ) > >  { 
195-         match  & mut  * self  { 
196-             MaybeTlsStream :: Raw ( s)  => Pin :: new ( s) . poll_shutdown ( cx) , 
197-             MaybeTlsStream :: Tls ( s)  => Pin :: new ( s) . poll_shutdown ( cx) , 
198- 
199-             MaybeTlsStream :: Upgrading  => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) , 
200-         } 
220+         exec_on_stream ! ( self ,  poll_shutdown,  cx) 
201221    } 
202222
203223    #[ cfg( feature = "_rt-async-std" ) ]  
204224    fn  poll_close ( mut  self :  Pin < & mut  Self > ,  cx :  & mut  Context < ' _ > )  -> Poll < io:: Result < ( ) > >  { 
205-         match  & mut  * self  { 
206-             MaybeTlsStream :: Raw ( s)  => Pin :: new ( s) . poll_close ( cx) , 
207-             MaybeTlsStream :: Tls ( s)  => Pin :: new ( s) . poll_close ( cx) , 
208- 
209-             MaybeTlsStream :: Upgrading  => Poll :: Ready ( Err ( io:: ErrorKind :: ConnectionAborted . into ( ) ) ) , 
210-         } 
225+         exec_on_stream ! ( self ,  poll_close,  cx) 
211226    } 
212227} 
213228
@@ -218,6 +233,11 @@ where
218233    type  Target  = S ; 
219234
220235    fn  deref ( & self )  -> & Self :: Target  { 
236+         #[ cfg( feature = "_tls-notls" ) ]  
237+         { 
238+             & self . 0 
239+         } 
240+         #[ cfg( not( feature = "_tls-notls" ) ) ]  
221241        match  self  { 
222242            MaybeTlsStream :: Raw ( s)  => s, 
223243
@@ -242,6 +262,11 @@ where
242262    S :  Unpin  + AsyncWrite  + AsyncRead , 
243263{ 
244264    fn  deref_mut ( & mut  self )  -> & mut  Self :: Target  { 
265+         #[ cfg( feature = "_tls-notls" ) ]  
266+         { 
267+             & mut  self . 0 
268+         } 
269+         #[ cfg( not( feature = "_tls-notls" ) ) ]  
245270        match  self  { 
246271            MaybeTlsStream :: Raw ( s)  => s, 
247272
0 commit comments