From 25f51b34836f47cbbc240cb713354be23874070c Mon Sep 17 00:00:00 2001 From: Daniel Schneider <daniel.schneider@eramux.com> Date: Tue, 14 Jan 2025 20:21:19 +0100 Subject: [PATCH 1/4] feat: impl OptionalFromRequestParts for Host extractor --- axum-extra/src/extract/host.rs | 62 ++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 3 deletions(-) diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs index a6828d3004..d1b4f78b25 100644 --- a/axum-extra/src/extract/host.rs +++ b/axum-extra/src/extract/host.rs @@ -1,10 +1,11 @@ use super::rejection::{FailedToResolveHost, HostRejection}; -use axum::extract::FromRequestParts; +use axum::extract::{FromRequestParts, OptionalFromRequestParts}; use http::{ header::{HeaderMap, FORWARDED}, request::Parts, uri::Authority, }; +use std::convert::Infallible; const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; @@ -59,6 +60,24 @@ where } } +impl<S> OptionalFromRequestParts<S> for Host +where + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result<Option<Self>, Self::Rejection> { + Ok( + <Self as FromRequestParts<S>>::from_request_parts(parts, _state) + .await + .ok(), + ) + } +} + #[allow(warnings)] fn parse_forwarded(headers: &HeaderMap) -> Option<&str> { // if there are multiple `Forwarded` `HeaderMap::get` will return the first one @@ -148,7 +167,10 @@ mod tests { async fn ip4_uri_host() { let mut parts = Request::new(()).into_parts().0; parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap(); - let host = Host::from_request_parts(&mut parts, &()).await.unwrap(); + let host = + <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &()) + .await + .unwrap(); assert_eq!(host.0, "127.0.0.1:1234"); } @@ -156,10 +178,44 @@ mod tests { async fn ip6_uri_host() { let mut parts = Request::new(()).into_parts().0; parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap(); - let host = Host::from_request_parts(&mut parts, &()).await.unwrap(); + let host = + <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &()) + .await + .unwrap(); assert_eq!(host.0, "[::1]:456"); } + #[crate::test] + async fn missing_host() { + let mut parts = Request::new(()).into_parts().0; + let host = + <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &()) + .await + .unwrap_err(); + assert!(matches!(host, HostRejection::FailedToResolveHost(_))); + } + + #[crate::test] + async fn optional_extractor() { + let mut parts = Request::new(()).into_parts().0; + parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap(); + let host = Option::<Host>::from_request_parts(&mut parts, &()) + .await + .unwrap(); + + assert!(matches!(host, Some(Host(_)))); + } + + #[crate::test] + async fn optional_extractor_none() { + let mut parts = Request::new(()).into_parts().0; + let host = Option::<Host>::from_request_parts(&mut parts, &()) + .await + .unwrap(); + + assert!(matches!(host, None)); + } + #[test] fn forwarded_parsing() { // the basic case From 3b7a812cc311b6c8511629a96089cbed1f032c5b Mon Sep 17 00:00:00 2001 From: Daniel Schneider <daniel.schneider@eramux.com> Date: Tue, 14 Jan 2025 20:35:24 +0100 Subject: [PATCH 2/4] fix: linting issues --- axum-extra/src/extract/host.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs index d1b4f78b25..366a189299 100644 --- a/axum-extra/src/extract/host.rs +++ b/axum-extra/src/extract/host.rs @@ -203,7 +203,7 @@ mod tests { .await .unwrap(); - assert!(matches!(host, Some(Host(_)))); + assert!(host.is_some()); } #[crate::test] @@ -213,7 +213,7 @@ mod tests { .await .unwrap(); - assert!(matches!(host, None)); + assert!(host.is_none()); } #[test] From c3c1735faef6062359dedff0023d5d723eddd297 Mon Sep 17 00:00:00 2001 From: Daniel Schneider <daniel.schneider@eramux.com> Date: Wed, 15 Jan 2025 08:06:46 +0100 Subject: [PATCH 3/4] feat: implement feedback --- axum-extra/src/extract/host.rs | 53 +++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs index 366a189299..bc9c19e508 100644 --- a/axum-extra/src/extract/host.rs +++ b/axum-extra/src/extract/host.rs @@ -1,5 +1,8 @@ use super::rejection::{FailedToResolveHost, HostRejection}; -use axum::extract::{FromRequestParts, OptionalFromRequestParts}; +use axum::{ + extract::{FromRequestParts, OptionalFromRequestParts}, + RequestPartsExt, +}; use http::{ header::{HeaderMap, FORWARDED}, request::Parts, @@ -32,8 +35,27 @@ where type Rejection = HostRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> { + parts + .extract::<Option<Host>>() + .await + .ok() + .flatten() + .ok_or(HostRejection::FailedToResolveHost(FailedToResolveHost)) + } +} + +impl<S> OptionalFromRequestParts<S> for Host +where + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts( + parts: &mut Parts, + _state: &S, + ) -> Result<Option<Self>, Self::Rejection> { if let Some(host) = parse_forwarded(&parts.headers) { - return Ok(Host(host.to_owned())); + return Ok(Some(Host(host.to_owned()))); } if let Some(host) = parts @@ -41,7 +63,7 @@ where .get(X_FORWARDED_HOST_HEADER_KEY) .and_then(|host| host.to_str().ok()) { - return Ok(Host(host.to_owned())); + return Ok(Some(Host(host.to_owned()))); } if let Some(host) = parts @@ -49,32 +71,14 @@ where .get(http::header::HOST) .and_then(|host| host.to_str().ok()) { - return Ok(Host(host.to_owned())); + return Ok(Some(Host(host.to_owned()))); } if let Some(authority) = parts.uri.authority() { - return Ok(Host(parse_authority(authority).to_owned())); + return Ok(Some(Host(parse_authority(authority).to_owned()))); } - Err(HostRejection::FailedToResolveHost(FailedToResolveHost)) - } -} - -impl<S> OptionalFromRequestParts<S> for Host -where - S: Send + Sync, -{ - type Rejection = Infallible; - - async fn from_request_parts( - parts: &mut Parts, - _state: &S, - ) -> Result<Option<Self>, Self::Rejection> { - Ok( - <Self as FromRequestParts<S>>::from_request_parts(parts, _state) - .await - .ok(), - ) + Ok(None) } } @@ -182,6 +186,7 @@ mod tests { <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &()) .await .unwrap(); + assert_eq!(host.0, "[::1]:456"); } From cb02d2deaecae9eb405f2677766f929328988fef Mon Sep 17 00:00:00 2001 From: Daniel Schneider <daniel.schneider@eramux.com> Date: Wed, 15 Jan 2025 08:24:21 +0100 Subject: [PATCH 4/4] tests: cleanup/simplify tests --- axum-extra/src/extract/host.rs | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/axum-extra/src/extract/host.rs b/axum-extra/src/extract/host.rs index bc9c19e508..e9eb91c5be 100644 --- a/axum-extra/src/extract/host.rs +++ b/axum-extra/src/extract/host.rs @@ -171,10 +171,7 @@ mod tests { async fn ip4_uri_host() { let mut parts = Request::new(()).into_parts().0; parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap(); - let host = - <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &()) - .await - .unwrap(); + let host = parts.extract::<Host>().await.unwrap(); assert_eq!(host.0, "127.0.0.1:1234"); } @@ -182,21 +179,14 @@ mod tests { async fn ip6_uri_host() { let mut parts = Request::new(()).into_parts().0; parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap(); - let host = - <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &()) - .await - .unwrap(); - + let host = parts.extract::<Host>().await.unwrap(); assert_eq!(host.0, "[::1]:456"); } #[crate::test] async fn missing_host() { let mut parts = Request::new(()).into_parts().0; - let host = - <Host as axum::extract::FromRequestParts<_>>::from_request_parts(&mut parts, &()) - .await - .unwrap_err(); + let host = parts.extract::<Host>().await.unwrap_err(); assert!(matches!(host, HostRejection::FailedToResolveHost(_))); } @@ -204,20 +194,14 @@ mod tests { async fn optional_extractor() { let mut parts = Request::new(()).into_parts().0; parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap(); - let host = Option::<Host>::from_request_parts(&mut parts, &()) - .await - .unwrap(); - + let host = parts.extract::<Option<Host>>().await.unwrap(); assert!(host.is_some()); } #[crate::test] async fn optional_extractor_none() { let mut parts = Request::new(()).into_parts().0; - let host = Option::<Host>::from_request_parts(&mut parts, &()) - .await - .unwrap(); - + let host = parts.extract::<Option<Host>>().await.unwrap(); assert!(host.is_none()); }