|
| 1 | +use crate::{ |
| 2 | + middleware::{Middleware, Next}, |
| 3 | + response::IntoResponse, |
| 4 | + Context, Response, |
| 5 | +}; |
| 6 | +use futures::future::{BoxFuture, FutureExt, TryFutureExt}; |
| 7 | +use http::status::StatusCode; |
| 8 | +use std::any::Any; |
| 9 | + |
| 10 | +/// A [`Middleware`] that will catch any panics from later middleware or handlers and return a |
| 11 | +/// response to the client. |
| 12 | +pub struct CatchUnwind { |
| 13 | + f: Box<dyn Fn(Box<dyn Any + Send + 'static>) -> Response + Send + Sync>, |
| 14 | +} |
| 15 | + |
| 16 | +impl CatchUnwind { |
| 17 | + /// Create a [`CatchUnwind`] which will respond with [`StatusCode::INTERNAL_SERVER_ERROR`] when |
| 18 | + /// any panic is caught. |
| 19 | + pub fn new() -> Self { |
| 20 | + Self::with_response(|_| { |
| 21 | + "Internal server error" |
| 22 | + .with_status(StatusCode::INTERNAL_SERVER_ERROR) |
| 23 | + .into_response() |
| 24 | + }) |
| 25 | + } |
| 26 | + |
| 27 | + /// Create a [`CatchUnwind`] with a custom function to generate the response, the function will |
| 28 | + /// be passed the caught panic. |
| 29 | + pub fn with_response( |
| 30 | + response: impl Fn(Box<dyn Any + Send + 'static>) -> Response + Send + Sync + 'static, |
| 31 | + ) -> Self { |
| 32 | + Self { |
| 33 | + f: Box::new(response), |
| 34 | + } |
| 35 | + } |
| 36 | +} |
| 37 | + |
| 38 | +impl<State: 'static> Middleware<State> for CatchUnwind { |
| 39 | + fn handle<'a>(&'a self, cx: Context<State>, next: Next<'a, State>) -> BoxFuture<'a, Response> { |
| 40 | + std::panic::AssertUnwindSafe(next.run(cx)) |
| 41 | + .catch_unwind() |
| 42 | + .unwrap_or_else(move |err| (self.f)(err)) |
| 43 | + .boxed() |
| 44 | + } |
| 45 | +} |
| 46 | + |
| 47 | +impl Default for CatchUnwind { |
| 48 | + fn default() -> Self { |
| 49 | + Self::new() |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +impl std::fmt::Debug for CatchUnwind { |
| 54 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 55 | + f.debug_struct("CatchUnwind").finish() |
| 56 | + } |
| 57 | +} |
0 commit comments