Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge fallbacks with the rest of the router #3158

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions axum/src/docs/routing/route.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ documentation for more details.
It is not possible to create segments that only match some types like numbers or
regular expression. You must handle that manually in your handlers.

[`MatchedPath`](crate::extract::MatchedPath) can be used to extract the matched
path rather than the actual path.
[`MatchedPath`] can be used to extract the matched path rather than the actual path.

# Wildcards

Expand Down
87 changes: 53 additions & 34 deletions axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter};
#[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
#[cfg(feature = "matched-path")]
use crate::extract::MatchedPath;
use crate::{
body::{Body, HttpBody},
boxed::BoxedIntoRoute,
Expand All @@ -20,7 +22,8 @@ use std::{
sync::Arc,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower::service_fn;
use tower_layer::{layer_fn, Layer};
use tower_service::Service;

pub mod future;
Expand Down Expand Up @@ -72,8 +75,7 @@ impl<S> Clone for Router<S> {
}

struct RouterInner<S> {
path_router: PathRouter<S, false>,
fallback_router: PathRouter<S, true>,
path_router: PathRouter<S>,
default_fallback: bool,
catch_all_fallback: Fallback<S>,
}
Expand All @@ -91,7 +93,6 @@ impl<S> fmt::Debug for Router<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Router")
.field("path_router", &self.inner.path_router)
.field("fallback_router", &self.inner.fallback_router)
.field("default_fallback", &self.inner.default_fallback)
.field("catch_all_fallback", &self.inner.catch_all_fallback)
.finish()
Expand Down Expand Up @@ -141,7 +142,6 @@ where
Self {
inner: Arc::new(RouterInner {
path_router: Default::default(),
fallback_router: PathRouter::new_fallback(),
default_fallback: true,
catch_all_fallback: Fallback::Default(Route::new(NotFound)),
}),
Expand All @@ -153,7 +153,6 @@ where
Ok(inner) => inner,
Err(arc) => RouterInner {
path_router: arc.path_router.clone(),
fallback_router: arc.fallback_router.clone(),
default_fallback: arc.default_fallback,
catch_all_fallback: arc.catch_all_fallback.clone(),
},
Expand Down Expand Up @@ -207,8 +206,7 @@ where

let RouterInner {
path_router,
fallback_router,
default_fallback,
default_fallback: _,
// we don't need to inherit the catch-all fallback. It is only used for CONNECT
// requests with an empty path. If we were to inherit the catch-all fallback
// it would end up matching `/{path}/*` which doesn't match empty paths.
Expand All @@ -217,10 +215,6 @@ where

tap_inner!(self, mut this => {
panic_on_err!(this.path_router.nest(path, path_router));

if !default_fallback {
panic_on_err!(this.fallback_router.nest(path, fallback_router));
}
})
}

Expand All @@ -247,43 +241,33 @@ where
where
R: Into<Router<S>>,
{
const PANIC_MSG: &str =
"Failed to merge fallbacks. This is a bug in axum. Please file an issue";

let other: Router<S> = other.into();
let RouterInner {
path_router,
fallback_router: mut other_fallback,
default_fallback,
catch_all_fallback,
} = other.into_inner();

map_inner!(self, mut this => {
panic_on_err!(this.path_router.merge(path_router));

match (this.default_fallback, default_fallback) {
// both have the default fallback
// use the one from other
(true, true) => {
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
}
(true, true) => {}
// this has default fallback, other has a custom fallback
(true, false) => {
this.fallback_router.merge(other_fallback).expect(PANIC_MSG);
this.default_fallback = false;
}
// this has a custom fallback, other has a default
(false, true) => {
let fallback_router = std::mem::take(&mut this.fallback_router);
other_fallback.merge(fallback_router).expect(PANIC_MSG);
this.fallback_router = other_fallback;
}
// both have a custom fallback, not allowed
(false, false) => {
panic!("Cannot merge two `Router`s that both have a fallback")
}
};

panic_on_err!(this.path_router.merge(path_router));

this.catch_all_fallback = this
.catch_all_fallback
.merge(catch_all_fallback)
Expand All @@ -304,7 +288,6 @@ where
{
map_inner!(self, this => RouterInner {
path_router: this.path_router.layer(layer.clone()),
fallback_router: this.fallback_router.layer(layer.clone()),
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)),
})
Expand All @@ -322,7 +305,6 @@ where
{
map_inner!(self, this => RouterInner {
path_router: this.path_router.route_layer(layer),
fallback_router: this.fallback_router,
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback,
})
Expand Down Expand Up @@ -376,8 +358,51 @@ where
}

fn fallback_endpoint(self, endpoint: Endpoint<S>) -> Self {
// TODO make this better, get rid of the `unwrap`s.
// We need the returned `Service` to be `Clone` and the function inside `service_fn` to be
// `FnMut` so instead of just using the owned service, we do this trick with `Option`. We
// know this will be called just once so it's fine. We're doing that so that we avoid one
// clone inside `oneshot_inner` so that the `Router` and subsequently the `State` is not
// cloned too much.
tap_inner!(self, mut this => {
this.fallback_router.set_fallback(endpoint);
_ = this.path_router.route_endpoint(
"/",
endpoint.clone().layer(
layer_fn(
|service: Route| {
let mut service = Some(service);
service_fn(
#[cfg_attr(not(feature = "matched-path"), allow(unused_mut))]
move |mut request: Request| {
#[cfg(feature = "matched-path")]
request.extensions_mut().remove::<MatchedPath>();
service.take().unwrap().oneshot_inner_owned(request)
}
)
}
)
)
);

_ = this.path_router.route_endpoint(
FALLBACK_PARAM_PATH,
endpoint.layer(
layer_fn(
|service: Route| {
let mut service = Some(service);
service_fn(
#[cfg_attr(not(feature = "matched-path"), allow(unused_mut))]
move |mut request: Request| {
#[cfg(feature = "matched-path")]
request.extensions_mut().remove::<MatchedPath>();
service.take().unwrap().oneshot_inner_owned(request)
}
)
}
)
)
);

this.default_fallback = false;
})
}
Expand All @@ -386,7 +411,6 @@ where
pub fn with_state<S2>(self, state: S) -> Router<S2> {
map_inner!(self, this => RouterInner {
path_router: this.path_router.with_state(state.clone()),
fallback_router: this.fallback_router.with_state(state.clone()),
default_fallback: this.default_fallback,
catch_all_fallback: this.catch_all_fallback.with_state(state),
})
Expand All @@ -398,11 +422,6 @@ where
Err((req, state)) => (req, state),
};

let (req, state) = match self.inner.fallback_router.call_with_state(req, state) {
Ok(future) => return future,
Err((req, state)) => (req, state),
};

self.inner
.catch_all_fallback
.clone()
Expand Down
88 changes: 20 additions & 68 deletions axum/src/routing/path_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,17 @@ use tower_layer::Layer;
use tower_service::Service;

use super::{
future::RouteFuture, not_found::NotFound, strip_prefix::StripPrefix, url_params, Endpoint,
MethodRouter, Route, RouteId, FALLBACK_PARAM_PATH, NEST_TAIL_PARAM,
future::RouteFuture, strip_prefix::StripPrefix, url_params, Endpoint, MethodRouter, Route,
RouteId, NEST_TAIL_PARAM,
};

pub(super) struct PathRouter<S, const IS_FALLBACK: bool> {
pub(super) struct PathRouter<S> {
routes: HashMap<RouteId, Endpoint<S>>,
node: Arc<Node>,
prev_route_id: RouteId,
v7_checks: bool,
}

impl<S> PathRouter<S, true>
where
S: Clone + Send + Sync + 'static,
{
pub(super) fn new_fallback() -> Self {
let mut this = Self::default();
this.set_fallback(Endpoint::Route(Route::new(NotFound)));
this
}

pub(super) fn set_fallback(&mut self, endpoint: Endpoint<S>) {
self.replace_endpoint("/", endpoint.clone());
self.replace_endpoint(FALLBACK_PARAM_PATH, endpoint);
}
}

fn validate_path(v7_checks: bool, path: &str) -> Result<(), &'static str> {
if path.is_empty() {
return Err("Paths must start with a `/`. Use \"/\" for root routes");
Expand Down Expand Up @@ -72,7 +56,7 @@ fn validate_v07_paths(path: &str) -> Result<(), &'static str> {
.unwrap_or(Ok(()))
}

impl<S, const IS_FALLBACK: bool> PathRouter<S, IS_FALLBACK>
impl<S> PathRouter<S>
where
S: Clone + Send + Sync + 'static,
{
Expand Down Expand Up @@ -159,10 +143,7 @@ where
.map_err(|err| format!("Invalid route {path:?}: {err}"))
}

pub(super) fn merge(
&mut self,
other: PathRouter<S, IS_FALLBACK>,
) -> Result<(), Cow<'static, str>> {
pub(super) fn merge(&mut self, other: PathRouter<S>) -> Result<(), Cow<'static, str>> {
let PathRouter {
routes,
node,
Expand All @@ -179,24 +160,9 @@ where
.get(&id)
.expect("no path for route id. This is a bug in axum. Please file an issue");

if IS_FALLBACK && (&**path == "/" || &**path == FALLBACK_PARAM_PATH) {
// when merging two routers it doesn't matter if you do `a.merge(b)` or
// `b.merge(a)`. This must also be true for fallbacks.
//
// However all fallback routers will have routes for `/` and `/*` so when merging
// we have to ignore the top level fallbacks on one side otherwise we get
// conflicts.
//
// `Router::merge` makes sure that when merging fallbacks `other` always has the
// fallback we want to keep. It panics if both routers have a custom fallback. Thus
// it is always okay to ignore one fallback and `Router::merge` also makes sure the
// one we can ignore is that of `self`.
self.replace_endpoint(path, route);
} else {
match route {
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
Endpoint::Route(route) => self.route_service(path, route)?,
}
match route {
Endpoint::MethodRouter(method_router) => self.route(path, method_router)?,
Endpoint::Route(route) => self.route_service(path, route)?,
}
}

Expand All @@ -206,7 +172,7 @@ where
pub(super) fn nest(
&mut self,
path_to_nest_at: &str,
router: PathRouter<S, IS_FALLBACK>,
router: PathRouter<S>,
) -> Result<(), Cow<'static, str>> {
let prefix = validate_nest_path(self.v7_checks, path_to_nest_at);

Expand Down Expand Up @@ -282,7 +248,7 @@ where
Ok(())
}

pub(super) fn layer<L>(self, layer: L) -> PathRouter<S, IS_FALLBACK>
pub(super) fn layer<L>(self, layer: L) -> PathRouter<S>
where
L: Layer<Route> + Clone + Send + Sync + 'static,
L::Service: Service<Request> + Clone + Send + Sync + 'static,
Expand Down Expand Up @@ -344,7 +310,7 @@ where
!self.routes.is_empty()
}

pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2, IS_FALLBACK> {
pub(super) fn with_state<S2>(self, state: S) -> PathRouter<S2> {
let routes = self
.routes
.into_iter()
Expand Down Expand Up @@ -388,14 +354,12 @@ where
Ok(match_) => {
let id = *match_.value;

if !IS_FALLBACK {
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
&mut parts.extensions,
);
}
#[cfg(feature = "matched-path")]
crate::extract::matched_path::set_matched_path_for_request(
id,
&self.node.route_id_to_path,
&mut parts.extensions,
);

url_params::insert_url_params(&mut parts.extensions, match_.params);

Expand All @@ -418,18 +382,6 @@ where
}
}

pub(super) fn replace_endpoint(&mut self, path: &str, endpoint: Endpoint<S>) {
match self.node.at(path) {
Ok(match_) => {
let id = *match_.value;
self.routes.insert(id, endpoint);
}
Err(_) => self
.route_endpoint(path, endpoint)
.expect("path wasn't matched so endpoint shouldn't exist"),
}
}

fn next_route_id(&mut self) -> RouteId {
let next_id = self
.prev_route_id
Expand All @@ -441,7 +393,7 @@ where
}
}

impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
impl<S> Default for PathRouter<S> {
fn default() -> Self {
Self {
routes: Default::default(),
Expand All @@ -452,7 +404,7 @@ impl<S, const IS_FALLBACK: bool> Default for PathRouter<S, IS_FALLBACK> {
}
}

impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
impl<S> fmt::Debug for PathRouter<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PathRouter")
.field("routes", &self.routes)
Expand All @@ -461,7 +413,7 @@ impl<S, const IS_FALLBACK: bool> fmt::Debug for PathRouter<S, IS_FALLBACK> {
}
}

impl<S, const IS_FALLBACK: bool> Clone for PathRouter<S, IS_FALLBACK> {
impl<S> Clone for PathRouter<S> {
fn clone(&self) -> Self {
Self {
routes: self.routes.clone(),
Expand Down
Loading