Skip to content

Commit

Permalink
Don't leak a tokio task when using serve without graceful shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
jplatte committed Dec 29, 2024
1 parent f8f3a03 commit 6aab90a
Showing 1 changed file with 105 additions and 80 deletions.
185 changes: 105 additions & 80 deletions axum/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,16 @@ where
type IntoFuture = private::ServeFuture;

fn into_future(self) -> Self::IntoFuture {
self.with_graceful_shutdown(std::future::pending())
.into_future()
let Self {
listener,
make_service,
_marker,
} = self;
private::ServeFuture(Box::pin(do_serve(
listener,
make_service,
None::<std::future::Pending<_>>,
)))
}
}

Expand Down Expand Up @@ -256,98 +264,115 @@ where

fn into_future(self) -> Self::IntoFuture {
let Self {
mut listener,
mut make_service,
listener,
make_service,
signal,
_marker: _,
_marker,
} = self;

private::ServeFuture(Box::pin(async move {
let (signal_tx, signal_rx) = watch::channel(());
let signal_tx = Arc::new(signal_tx);
tokio::spawn(async move {
signal.await;
trace!("received graceful shutdown signal. Telling tasks to shutdown");
drop(signal_rx);
});
private::ServeFuture(Box::pin(do_serve(listener, make_service, Some(signal))))
}
}

async fn do_serve<L, M, S>(
mut listener: L,
mut make_service: M,
signal: Option<impl Future<Output = ()> + Send + 'static>,
) -> io::Result<()>
where
L: Listener,
L::Addr: Debug,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
{
let (signal_tx, signal_rx) = watch::channel(());
let signal_tx = Arc::new(signal_tx);
if let Some(signal) = signal {
tokio::spawn(async move {
signal.await;
trace!("received graceful shutdown signal. Telling tasks to shutdown");
drop(signal_rx);
});
}

let (close_tx, close_rx) = watch::channel(());

loop {
let (io, remote_addr) = tokio::select! {
conn = listener.accept() => conn,
_ = signal_tx.closed() => {
trace!("signal received, not accepting new connections");
break;
}
};

let io = TokioIo::new(io);

trace!("connection {remote_addr:?} accepted");

poll_fn(|cx| make_service.poll_ready(cx))
.await
.unwrap_or_else(|err| match err {});

let tower_service = make_service
.call(IncomingStream {
io: &io,
remote_addr,
})
.await
.unwrap_or_else(|err| match err {})
.map_request(|req: Request<Incoming>| req.map(Body::new));

let hyper_service = TowerToHyperService::new(tower_service);

let signal_tx = Arc::clone(&signal_tx);

let close_rx = close_rx.clone();

let (close_tx, close_rx) = watch::channel(());
tokio::spawn(async move {
#[allow(unused_mut)]
let mut builder = Builder::new(TokioExecutor::new());
// CONNECT protocol needed for HTTP/2 websockets
#[cfg(feature = "http2")]
builder.http2().enable_connect_protocol();
let conn = builder.serve_connection_with_upgrades(io, hyper_service);
pin_mut!(conn);

let signal_closed = signal_tx.closed().fuse();
pin_mut!(signal_closed);

loop {
let (io, remote_addr) = tokio::select! {
conn = listener.accept() => conn,
_ = signal_tx.closed() => {
trace!("signal received, not accepting new connections");
tokio::select! {
result = conn.as_mut() => {
if let Err(_err) = result {
trace!("failed to serve connection: {_err:#}");
}
break;
}
};

let io = TokioIo::new(io);

trace!("connection {remote_addr:?} accepted");

poll_fn(|cx| make_service.poll_ready(cx))
.await
.unwrap_or_else(|err| match err {});

let tower_service = make_service
.call(IncomingStream {
io: &io,
remote_addr,
})
.await
.unwrap_or_else(|err| match err {})
.map_request(|req: Request<Incoming>| req.map(Body::new));

let hyper_service = TowerToHyperService::new(tower_service);

let signal_tx = Arc::clone(&signal_tx);

let close_rx = close_rx.clone();

tokio::spawn(async move {
#[allow(unused_mut)]
let mut builder = Builder::new(TokioExecutor::new());
// CONNECT protocol needed for HTTP/2 websockets
#[cfg(feature = "http2")]
builder.http2().enable_connect_protocol();
let conn = builder.serve_connection_with_upgrades(io, hyper_service);
pin_mut!(conn);

let signal_closed = signal_tx.closed().fuse();
pin_mut!(signal_closed);

loop {
tokio::select! {
result = conn.as_mut() => {
if let Err(_err) = result {
trace!("failed to serve connection: {_err:#}");
}
break;
}
_ = &mut signal_closed => {
trace!("signal received in task, starting graceful shutdown");
conn.as_mut().graceful_shutdown();
}
}
_ = &mut signal_closed => {
trace!("signal received in task, starting graceful shutdown");
conn.as_mut().graceful_shutdown();
}

drop(close_rx);
});
}
}

drop(close_rx);
drop(listener);
});
}

trace!(
"waiting for {} task(s) to finish",
close_tx.receiver_count()
);
close_tx.closed().await;
drop(close_rx);
drop(listener);

Ok(())
}))
}
trace!(
"waiting for {} task(s) to finish",
close_tx.receiver_count()
);
close_tx.closed().await;

Ok(())
}

/// An incoming stream.
Expand Down

0 comments on commit 6aab90a

Please sign in to comment.