Skip to content

Commit 87fc7ef

Browse files
committed
Allow the server_name_resolver to be overriden
1 parent 2fde244 commit 87fc7ef

File tree

3 files changed

+28
-7
lines changed

3 files changed

+28
-7
lines changed

src/async_impl/client.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ struct Config {
185185
tls_built_in_certs_native: bool,
186186
#[cfg(feature = "__rustls")]
187187
crls: Vec<CertificateRevocationList>,
188+
#[cfg(feature = "__rustls")]
189+
server_name_resolver: Arc<dyn hyper_rustls::ResolveServerName + Send + Sync>,
188190
#[cfg(feature = "__tls")]
189191
min_tls_version: Option<tls::Version>,
190192
#[cfg(feature = "__tls")]
@@ -308,6 +310,8 @@ impl ClientBuilder {
308310
identity: None,
309311
#[cfg(feature = "__rustls")]
310312
crls: vec![],
313+
#[cfg(feature = "__rustls")]
314+
server_name_resolver: Arc::new(hyper_rustls::DefaultServerNameResolver::default()),
311315
#[cfg(feature = "__tls")]
312316
min_tls_version: None,
313317
#[cfg(feature = "__tls")]
@@ -657,6 +661,7 @@ impl ClientBuilder {
657661
config.interface.as_deref(),
658662
config.nodelay,
659663
config.tls_info,
664+
config.server_name_resolver,
660665
)
661666
}
662667
#[cfg(feature = "__rustls")]
@@ -862,6 +867,7 @@ impl ClientBuilder {
862867
config.interface.as_deref(),
863868
config.nodelay,
864869
config.tls_info,
870+
config.server_name_resolver,
865871
)
866872
}
867873
#[cfg(any(feature = "native-tls", feature = "__rustls",))]
@@ -1691,6 +1697,16 @@ impl ClientBuilder {
16911697
self
16921698
}
16931699

1700+
/// Sets the server name resolver. Defaults to `hyper_rustls::DefaultServerNameResolver`
1701+
///
1702+
/// This requires the `rustls-tls(-...)` Cargo feature enabled.
1703+
#[cfg(feature = "__rustls")]
1704+
#[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))]
1705+
pub fn with_server_name_resolver(mut self, server_name_resolver: Arc<dyn hyper_rustls::ResolveServerName + Send + Sync>) -> ClientBuilder {
1706+
self.config.server_name_resolver = server_name_resolver;
1707+
self
1708+
}
1709+
16941710
// TLS options
16951711

16961712
/// Add a custom root certificate.

src/connect.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ where {
330330
interface: Option<&str>,
331331
nodelay: bool,
332332
tls_info: bool,
333+
server_name_resolver: Arc<dyn hyper_rustls::ResolveServerName + Send + Sync>,
333334
) -> ConnectorBuilder
334335
where
335336
T: Into<Option<IpAddr>>,
@@ -367,6 +368,7 @@ where {
367368
http,
368369
tls,
369370
tls_proxy,
371+
server_name_resolver,
370372
},
371373
proxies,
372374
verbose: verbose::OFF,
@@ -458,6 +460,7 @@ enum Inner {
458460
http: HttpConnector,
459461
tls: Arc<rustls::ClientConfig>,
460462
tls_proxy: Arc<rustls::ClientConfig>,
463+
server_name_resolver: Arc<dyn hyper_rustls::ResolveServerName + Send + Sync>,
461464
},
462465
}
463466

@@ -579,7 +582,7 @@ impl ConnectorService {
579582
}
580583
}
581584
#[cfg(feature = "__rustls")]
582-
Inner::RustlsTls { http, tls, .. } => {
585+
Inner::RustlsTls { http, tls, server_name_resolver, .. } => {
583586
let mut http = http.clone();
584587

585588
// Disable Nagle's algorithm for TLS handshake
@@ -589,7 +592,7 @@ impl ConnectorService {
589592
http.set_nodelay(true);
590593
}
591594

592-
let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
595+
let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone(), server_name_resolver));
593596
let io = http.call(dst).await?;
594597

595598
if let hyper_rustls::MaybeHttpsStream::Https(stream) = io {
@@ -675,10 +678,9 @@ impl ConnectorService {
675678
http,
676679
tls,
677680
tls_proxy,
681+
server_name_resolver: name_resolver,
678682
} => {
679683
if dst.scheme() == Some(&Scheme::HTTPS) {
680-
use rustls_pki_types::ServerName;
681-
use std::convert::TryFrom;
682684
use tokio_rustls::TlsConnector as RustlsConnector;
683685

684686
log::trace!("tunneling HTTPS over proxy");
@@ -701,9 +703,8 @@ impl ConnectorService {
701703
// We don't wrap this again in an HttpsConnector since that uses Maybe,
702704
// and we know this is definitely HTTPS.
703705
let tunneled = tunnel.call(dst.clone()).await?;
704-
let host = dst.host().ok_or("no host in url")?.to_string();
705-
let server_name = ServerName::try_from(host.as_str().to_owned())
706-
.map_err(|_| "Invalid Server Name")?;
706+
707+
let server_name = name_resolver.resolve(&dst)?;
707708
let io = RustlsConnector::from(tls.clone())
708709
.connect(server_name, TokioIo::new(tunneled))
709710
.await?;

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ pub use self::error::{Error, Result};
287287
pub use self::into_url::IntoUrl;
288288
pub use self::response::ResponseBuilderExt;
289289

290+
#[cfg(feature = "__rustls")]
291+
#[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))]
292+
pub use hyper_rustls::ResolveServerName;
293+
290294
/// Shortcut method to quickly make a `GET` request.
291295
///
292296
/// See also the methods on the [`reqwest::Response`](./struct.Response.html)

0 commit comments

Comments
 (0)