|
| 1 | +use futures::future::{BoxFuture, FutureExt, TryFutureExt}; |
| 2 | +use http::status::StatusCode; |
| 3 | +use std::{ |
| 4 | + any::Any, |
| 5 | + panic::{AssertUnwindSafe, RefUnwindSafe}, |
| 6 | +}; |
| 7 | +use tide_core::{ |
| 8 | + middleware::{Middleware, Next}, |
| 9 | + response::IntoResponse, |
| 10 | + Context, Response, |
| 11 | +}; |
| 12 | + |
| 13 | +/// A [`Middleware`] that will catch any panics from later middleware or handlers and return a |
| 14 | +/// response to the client. |
| 15 | +/// |
| 16 | +/// It is **not** recommended to use this middleware for a general try/catch mechanism. The |
| 17 | +/// [`Result`] type is more appropriate to use for middleware/handlers that can fail on a regular |
| 18 | +/// basis. Additionally, this middleware is not guaranteed to catch all panics, see the "Notes" |
| 19 | +/// section in the [`std::panic::catch_unwind`] docs. |
| 20 | +pub struct CatchUnwind { |
| 21 | + f: Box<dyn Fn(Box<dyn Any + Send + 'static>) -> Response + Send + Sync>, |
| 22 | +} |
| 23 | + |
| 24 | +impl CatchUnwind { |
| 25 | + /// Create a [`CatchUnwind`] which will respond with [`StatusCode::INTERNAL_SERVER_ERROR`] when |
| 26 | + /// any panic is caught. |
| 27 | + pub fn new() -> Self { |
| 28 | + Self::with_response(|_| { |
| 29 | + "Internal server error" |
| 30 | + .with_status(StatusCode::INTERNAL_SERVER_ERROR) |
| 31 | + .into_response() |
| 32 | + }) |
| 33 | + } |
| 34 | + |
| 35 | + /// Create a [`CatchUnwind`] with a custom function to generate the response, the function will |
| 36 | + /// be passed the caught panic. |
| 37 | + pub fn with_response( |
| 38 | + response: impl Fn(Box<dyn Any + Send + 'static>) -> Response + Send + Sync + 'static, |
| 39 | + ) -> Self { |
| 40 | + Self { |
| 41 | + f: Box::new(response), |
| 42 | + } |
| 43 | + } |
| 44 | +} |
| 45 | + |
| 46 | +impl<State: RefUnwindSafe + 'static> Middleware<State> for CatchUnwind { |
| 47 | + fn handle<'a>(&'a self, cx: Context<State>, next: Next<'a, State>) -> BoxFuture<'a, Response> { |
| 48 | + AssertUnwindSafe(next.run(cx)) |
| 49 | + .catch_unwind() |
| 50 | + .unwrap_or_else(move |err| (self.f)(err)) |
| 51 | + .boxed() |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | +impl Default for CatchUnwind { |
| 56 | + fn default() -> Self { |
| 57 | + Self::new() |
| 58 | + } |
| 59 | +} |
| 60 | + |
| 61 | +impl std::fmt::Debug for CatchUnwind { |
| 62 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 63 | + f.debug_struct("CatchUnwind").finish() |
| 64 | + } |
| 65 | +} |
0 commit comments