From 25f51b34836f47cbbc240cb713354be23874070c Mon Sep 17 00:00:00 2001
From: Daniel Schneider <daniel.schneider@eramux.com>
Date: Tue, 14 Jan 2025 20:21:19 +0100
Subject: [PATCH 1/4] feat: impl OptionalFromRequestParts for Host extractor

---
 axum-extra/src/extract/host.rs | 62 ++++++++++++++++++++++++++++++++--
 1 file changed, 59 insertions(+), 3 deletions(-)

diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs
index a6828d3004..d1b4f78b25 100644
--- a/axum-extra/src/extract/host.rs
+++ b/axum-extra/src/extract/host.rs
@@ -1,10 +1,11 @@
 use super::rejection::{FailedToResolveHost, HostRejection};
-use axum::extract::FromRequestParts;
+use axum::extract::{FromRequestParts, OptionalFromRequestParts};
 use http::{
     header::{HeaderMap, FORWARDED},
     request::Parts,
     uri::Authority,
 };
+use std::convert::Infallible;
 
 const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";
 
@@ -59,6 +60,24 @@ where
     }
 }
 
+impl<S> OptionalFromRequestParts<S> for Host
+where
+    S: Send + Sync,
+{
+    type Rejection = Infallible;
+
+    async fn from_request_parts(
+        parts: &mut Parts,
+        _state: &S,
+    ) -> Result<Option<Self>, Self::Rejection> {
+        Ok(
+            <Self as FromRequestParts<S>>::from_request_parts(parts, _state)
+                .await
+                .ok(),
+        )
+    }
+}
+
 #[allow(warnings)]
 fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
     // if there are multiple `Forwarded` `HeaderMap::get` will return the first one
@@ -148,7 +167,10 @@ mod tests {
     async fn ip4_uri_host() {
         let mut parts = Request::new(()).into_parts().0;
         parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
-        let host = Host::from_request_parts(&mut parts, &()).await.unwrap();
+        let host =
+            <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &())
+                .await
+                .unwrap();
         assert_eq!(host.0, "127.0.0.1:1234");
     }
 
@@ -156,10 +178,44 @@ mod tests {
     async fn ip6_uri_host() {
         let mut parts = Request::new(()).into_parts().0;
         parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap();
-        let host = Host::from_request_parts(&mut parts, &()).await.unwrap();
+        let host =
+            <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &())
+                .await
+                .unwrap();
         assert_eq!(host.0, "[::1]:456");
     }
 
+    #[crate::test]
+    async fn missing_host() {
+        let mut parts = Request::new(()).into_parts().0;
+        let host =
+            <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &())
+                .await
+                .unwrap_err();
+        assert!(matches!(host, HostRejection::FailedToResolveHost(_)));
+    }
+
+    #[crate::test]
+    async fn optional_extractor() {
+        let mut parts = Request::new(()).into_parts().0;
+        parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
+        let host = Option::<Host>::from_request_parts(&mut parts, &())
+            .await
+            .unwrap();
+
+        assert!(matches!(host, Some(Host(_))));
+    }
+
+    #[crate::test]
+    async fn optional_extractor_none() {
+        let mut parts = Request::new(()).into_parts().0;
+        let host = Option::<Host>::from_request_parts(&mut parts, &())
+            .await
+            .unwrap();
+
+        assert!(matches!(host, None));
+    }
+
     #[test]
     fn forwarded_parsing() {
         // the basic case

From 3b7a812cc311b6c8511629a96089cbed1f032c5b Mon Sep 17 00:00:00 2001
From: Daniel Schneider <daniel.schneider@eramux.com>
Date: Tue, 14 Jan 2025 20:35:24 +0100
Subject: [PATCH 2/4] fix: linting issues

---
 axum-extra/src/extract/host.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs
index d1b4f78b25..366a189299 100644
--- a/axum-extra/src/extract/host.rs
+++ b/axum-extra/src/extract/host.rs
@@ -203,7 +203,7 @@ mod tests {
             .await
             .unwrap();
 
-        assert!(matches!(host, Some(Host(_))));
+        assert!(host.is_some());
     }
 
     #[crate::test]
@@ -213,7 +213,7 @@ mod tests {
             .await
             .unwrap();
 
-        assert!(matches!(host, None));
+        assert!(host.is_none());
     }
 
     #[test]

From c3c1735faef6062359dedff0023d5d723eddd297 Mon Sep 17 00:00:00 2001
From: Daniel Schneider <daniel.schneider@eramux.com>
Date: Wed, 15 Jan 2025 08:06:46 +0100
Subject: [PATCH 3/4] feat: implement feedback

---
 axum-extra/src/extract/host.rs | 53 +++++++++++++++++++---------------
 1 file changed, 29 insertions(+), 24 deletions(-)

diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs
index 366a189299..bc9c19e508 100644
--- a/axum-extra/src/extract/host.rs
+++ b/axum-extra/src/extract/host.rs
@@ -1,5 +1,8 @@
 use super::rejection::{FailedToResolveHost, HostRejection};
-use axum::extract::{FromRequestParts, OptionalFromRequestParts};
+use axum::{
+    extract::{FromRequestParts, OptionalFromRequestParts},
+    RequestPartsExt,
+};
 use http::{
     header::{HeaderMap, FORWARDED},
     request::Parts,
@@ -32,8 +35,27 @@ where
     type Rejection = HostRejection;
 
     async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
+        parts
+            .extract::<Option<Host>>()
+            .await
+            .ok()
+            .flatten()
+            .ok_or(HostRejection::FailedToResolveHost(FailedToResolveHost))
+    }
+}
+
+impl<S> OptionalFromRequestParts<S> for Host
+where
+    S: Send + Sync,
+{
+    type Rejection = Infallible;
+
+    async fn from_request_parts(
+        parts: &mut Parts,
+        _state: &S,
+    ) -> Result<Option<Self>, Self::Rejection> {
         if let Some(host) = parse_forwarded(&parts.headers) {
-            return Ok(Host(host.to_owned()));
+            return Ok(Some(Host(host.to_owned())));
         }
 
         if let Some(host) = parts
@@ -41,7 +63,7 @@ where
             .get(X_FORWARDED_HOST_HEADER_KEY)
             .and_then(|host| host.to_str().ok())
         {
-            return Ok(Host(host.to_owned()));
+            return Ok(Some(Host(host.to_owned())));
         }
 
         if let Some(host) = parts
@@ -49,32 +71,14 @@ where
             .get(http::header::HOST)
             .and_then(|host| host.to_str().ok())
         {
-            return Ok(Host(host.to_owned()));
+            return Ok(Some(Host(host.to_owned())));
         }
 
         if let Some(authority) = parts.uri.authority() {
-            return Ok(Host(parse_authority(authority).to_owned()));
+            return Ok(Some(Host(parse_authority(authority).to_owned())));
         }
 
-        Err(HostRejection::FailedToResolveHost(FailedToResolveHost))
-    }
-}
-
-impl<S> OptionalFromRequestParts<S> for Host
-where
-    S: Send + Sync,
-{
-    type Rejection = Infallible;
-
-    async fn from_request_parts(
-        parts: &mut Parts,
-        _state: &S,
-    ) -> Result<Option<Self>, Self::Rejection> {
-        Ok(
-            <Self as FromRequestParts<S>>::from_request_parts(parts, _state)
-                .await
-                .ok(),
-        )
+        Ok(None)
     }
 }
 
@@ -182,6 +186,7 @@ mod tests {
             <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &())
                 .await
                 .unwrap();
+
         assert_eq!(host.0, "[::1]:456");
     }
 

From cb02d2deaecae9eb405f2677766f929328988fef Mon Sep 17 00:00:00 2001
From: Daniel Schneider <daniel.schneider@eramux.com>
Date: Wed, 15 Jan 2025 08:24:21 +0100
Subject: [PATCH 4/4] tests: cleanup/simplify tests

---
 axum-extra/src/extract/host.rs | 26 +++++---------------------
 1 file changed, 5 insertions(+), 21 deletions(-)

diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs
index bc9c19e508..e9eb91c5be 100644
--- a/axum-extra/src/extract/host.rs
+++ b/axum-extra/src/extract/host.rs
@@ -171,10 +171,7 @@ mod tests {
     async fn ip4_uri_host() {
         let mut parts = Request::new(()).into_parts().0;
         parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
-        let host =
-            <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &())
-                .await
-                .unwrap();
+        let host = parts.extract::<Host>().await.unwrap();
         assert_eq!(host.0, "127.0.0.1:1234");
     }
 
@@ -182,21 +179,14 @@ mod tests {
     async fn ip6_uri_host() {
         let mut parts = Request::new(()).into_parts().0;
         parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap();
-        let host =
-            <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &())
-                .await
-                .unwrap();
-
+        let host = parts.extract::<Host>().await.unwrap();
         assert_eq!(host.0, "[::1]:456");
     }
 
     #[crate::test]
     async fn missing_host() {
         let mut parts = Request::new(()).into_parts().0;
-        let host =
-            <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &())
-                .await
-                .unwrap_err();
+        let host = parts.extract::<Host>().await.unwrap_err();
         assert!(matches!(host, HostRejection::FailedToResolveHost(_)));
     }
 
@@ -204,20 +194,14 @@ mod tests {
     async fn optional_extractor() {
         let mut parts = Request::new(()).into_parts().0;
         parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap();
-        let host = Option::<Host>::from_request_parts(&mut parts, &())
-            .await
-            .unwrap();
-
+        let host = parts.extract::<Option<Host>>().await.unwrap();
         assert!(host.is_some());
     }
 
     #[crate::test]
     async fn optional_extractor_none() {
         let mut parts = Request::new(()).into_parts().0;
-        let host = Option::<Host>::from_request_parts(&mut parts, &())
-            .await
-            .unwrap();
-
+        let host = parts.extract::<Option<Host>>().await.unwrap();
         assert!(host.is_none());
     }