diff --git a/Cargo.toml b/Cargo.toml index 3379f58..34aecd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ chrono = "0.4.13" colored = "2.0.0" flate2 = "1.0.16" futures = "0.3.5" -hyper = {version = "0.14", features = ["server", "http1", "http2", "tcp"]} +hyper = {version = "0.14", features = ["server", "stream", "http1", "http2", "tcp"]} log = "0.4.11" regex = "1.3.9" serde = "1.0.114" @@ -34,3 +34,8 @@ tokio = {version = "1", features = ["full"]} handlebars = {version = "3.5.1", optional = true} multer = {version = "1.2.2", optional = true} rust-s3 = {version = "0.26.4", optional = true} +tokio-util = { version = "0.7.1", features = ["codec"] } + +[dev-dependencies] +serde = { version = "1.0", features = ["derive"] } +uuid = { version = "0.8", features = ["serde", "v4"] } diff --git a/examples/hello.rs b/examples/hello.rs new file mode 100644 index 0000000..6cc0074 --- /dev/null +++ b/examples/hello.rs @@ -0,0 +1,63 @@ +use std::{ + net::{Ipv4Addr, SocketAddr}, + pin::Pin, +}; + +use eve_rs::{ + listen, + App, + Context, + DefaultContext, + DefaultMiddleware, + Error, + NextHandler, + Request, +}; +use futures::Future; + +async fn say_hello( + mut context: DefaultContext, + _: NextHandler, +) -> Result> { + let val = "Hello 👋"; + context.body(val); + context.content_type("application/html"); + Ok(context) +} + +#[tokio::main] +async fn main() { + println!("Starting server ..."); + + let mut app = App::, (), ()>::create( + |request: Request, _state: &()| DefaultContext::new(request), + (), + ); + + app.get( + "/hello", + [DefaultMiddleware::new(|context, next| { + Box::pin(async move { say_hello(context, next).await }) + })], + ); + + let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 8080)); + println!("Listening for connection on port {bind_addr}"); + + listen( + app, + bind_addr, + None::>>>, + ) + .await; +} + +#[allow(dead_code)] +async fn shutdown_signal() { + tokio::signal::ctrl_c() + .await + .expect("Expected installing Ctrl+C signal handler"); + + // Handle cleanup tasks + println!("Shutting down the server !!!"); +} diff --git a/examples/static_server.rs b/examples/static_server.rs new file mode 100644 index 0000000..5a13537 --- /dev/null +++ b/examples/static_server.rs @@ -0,0 +1,48 @@ +use std::{ + env, + net::{Ipv4Addr, SocketAddr}, + pin::Pin, +}; + +use eve_rs::{ + default_middlewares::static_file_server::StaticFileServer, + listen, + App, + DefaultContext, + DefaultMiddleware, + Request, +}; +use futures::Future; + +#[tokio::main] +async fn main() { + println!("Starting server ..."); + + let mut app = App::, (), ()>::create( + |request: Request, _state: &()| DefaultContext::new(request), + (), + ); + + app.get( + "**", + [DefaultMiddleware::new(move |context, next| { + Box::pin(async move { + StaticFileServer::create( + env::current_dir().unwrap().to_str().unwrap(), + ) + .serve(context, next) + .await + }) + })], + ); + + let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 8080)); + println!("Listening for connection on port {bind_addr}"); + + listen( + app, + bind_addr, + None::>>>, + ) + .await; +} diff --git a/examples/todos/app.rs b/examples/todos/app.rs new file mode 100644 index 0000000..f0fd113 --- /dev/null +++ b/examples/todos/app.rs @@ -0,0 +1,118 @@ +use std::{ + collections::HashMap, + pin::Pin, + sync::{Arc, RwLock}, +}; + +use eve_rs::{Context, Error, Middleware, NextHandler, Request, Response}; +use futures::Future; +use serde::Serialize; +use uuid::Uuid; + +// Middleware used by Eve +#[derive(Debug, Clone)] +pub struct Middler { + handler: MiddlewareHandler, +} + +type MiddlewareHandler = + fn( + Ctx, + NextHandler, + ) -> Pin>> + Send>>; + +impl Middler { + pub fn new(handler: MiddlewareHandler) -> Self { + Middler { handler } + } +} + +#[async_trait::async_trait] +impl Middleware for Middler { + async fn run_middleware( + &self, + context: Ctx, + next: NextHandler, + ) -> Result> { + (self.handler)(context, next).await + } +} + +// Context used by every middleware +#[derive(Debug)] +pub struct Ctx { + // basic data + req: Request, + res: Response, + + // buffering streaming data + buffered_request_body: Option>, + + // additional data + pub state: AppState, +} + +impl Ctx { + pub fn new(req: Request, state: &AppState) -> Self { + Ctx { + req, + res: Default::default(), + state: state.clone(), + buffered_request_body: None, + } + } +} + +#[async_trait::async_trait] +impl Context for Ctx { + fn get_request(&self) -> &Request { + &self.req + } + + fn get_request_mut(&mut self) -> &mut Request { + &mut self.req + } + + fn get_response(&self) -> &Response { + &self.res + } + + fn take_response(self) -> Response { + self.res + } + + fn get_response_mut(&mut self) -> &mut Response { + &mut self.res + } + + type ResBodyBuffer = Vec; + async fn get_buffered_request_body(&mut self) -> &Self::ResBodyBuffer { + match self.buffered_request_body { + Some(ref body) => body, + None => { + let body = hyper::body::to_bytes(self.req.take_body()) + .await + .unwrap() + .to_vec(); + self.buffered_request_body.get_or_insert(body) + } + } + } +} + +// AppState for sharing and persistance +#[derive(Debug, Default, Clone)] +pub struct AppState { + // Since state is cloned for every request, + // size of struct fields should be small + pub db: Db, +} + +type Db = Arc>>; + +#[derive(Debug, Clone, Serialize)] +pub struct Todo { + pub id: Uuid, + pub text: String, + pub completed: bool, +} diff --git a/examples/todos/main.rs b/examples/todos/main.rs new file mode 100644 index 0000000..074ea21 --- /dev/null +++ b/examples/todos/main.rs @@ -0,0 +1,52 @@ +mod app; +mod routes; + +use std::{ + net::{Ipv4Addr, SocketAddr}, + pin::Pin, +}; + +use eve_rs::{listen, App}; +use futures::Future; + +use crate::{ + app::{AppState, Ctx, Middler}, + routes::{create_todos, delete_todos, get_todos, update_todos}, +}; + +macro_rules! middleware { + ( $ ($func: ident),+ ) => { + [ + $ ( + Middler::new(|ctx, next| { + Box::pin(async move { $func(ctx, next).await }) + }) + ),+ + ] + }; +} + +#[tokio::main] +async fn main() { + println!("Starting server ..."); + + let mut app = App::::create( + Ctx::new, + AppState::default(), + ); + + app.get("/todos", middleware![get_todos]); + app.post("/todos", middleware![create_todos]); + app.patch("/todos/:id", middleware![update_todos]); + app.delete("/todos/:id", middleware![delete_todos]); + + let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 8080)); + println!("Listening on {bind_addr}"); + + listen( + app, + bind_addr, + None::>>>, + ) + .await; +} diff --git a/examples/todos/routes.rs b/examples/todos/routes.rs new file mode 100644 index 0000000..0f6016f --- /dev/null +++ b/examples/todos/routes.rs @@ -0,0 +1,118 @@ +use std::str::FromStr; + +use eve_rs::{Context, Error, NextHandler}; +use uuid::Uuid; + +use crate::app::{Ctx, Todo}; + +pub async fn get_todos( + mut ctx: Ctx, + _next: NextHandler, +) -> Result> { + let todos = ctx + .state + .db + .read() + .expect("Expected to read DB data") + .values() + .cloned() + .collect::>(); + + ctx.body( + serde_json::to_string_pretty(&todos) + .expect("Expected to serialize the todos to json string"), + ); + ctx.content_type("application/json"); + ctx.status(200); + + Ok(ctx) +} + +pub async fn create_todos( + mut ctx: Ctx, + _next: NextHandler, +) -> Result> { + let text = ctx + .get_request() + .get_query() + .get("text") + .cloned() + .expect("Expected text in query params"); + + let todo = Todo { + id: Uuid::new_v4(), + text, + completed: false, + }; + + ctx.state + .db + .write() + .expect("Expected to write todo in DB") + .insert(todo.id, todo.clone()); + + ctx.body( + serde_json::to_string_pretty(&todo) + .expect("Expected to serialize the todos to json string"), + ); + ctx.content_type("application/json"); + ctx.status(201); + + Ok(ctx) +} + +pub async fn update_todos( + mut ctx: Ctx, + _next: NextHandler, +) -> Result> { + let id = ctx + .get_request() + .get_params() + .get("id") + .map(|id| Uuid::from_str(id).unwrap()) + .expect("Expected id in path param"); + + let mut todo = ctx.state.db.read().unwrap().get(&id).cloned().unwrap(); + + if let Some(text) = ctx.get_request().get_query().get("text").cloned() { + todo.text = text; + }; + + if let Some(completed) = + ctx.get_request().get_query().get("completed").cloned() + { + todo.completed = bool::from_str(&completed).unwrap(); + }; + + ctx.state + .db + .write() + .expect("Expected to write todo in DB") + .insert(todo.id, todo.clone()); + + ctx.body( + serde_json::to_string_pretty(&todo) + .expect("Expected to serialize the todos to json string"), + ); + ctx.content_type("application/json"); + ctx.status(201); + + Ok(ctx) +} + +pub async fn delete_todos( + mut ctx: Ctx, + _next: NextHandler, +) -> Result> { + let id = ctx + .get_request() + .get_params() + .get("id") + .map(|id| Uuid::from_str(id).unwrap()) + .expect("Expected id in path param"); + + ctx.state.db.write().unwrap().remove(&id); + ctx.status(200); + + Ok(ctx) +} diff --git a/src/app.rs b/src/app.rs index d204834..de52b33 100644 --- a/src/app.rs +++ b/src/app.rs @@ -64,7 +64,7 @@ where let path = context.get_path(); context .status(404) - .body(&format!("Cannot {} route {}", method, path)); + .body(format!("Cannot {} route {}", method, path)); Ok(context) } }) diff --git a/src/context.rs b/src/context.rs index 15a41c7..af5d515 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,13 +1,13 @@ -use std::{ - net::IpAddr, - str::{self, Utf8Error}, -}; +use std::{net::IpAddr, str}; +use async_trait::async_trait; +use hyper::Body; use serde::Serialize; use serde_json::Value; use crate::{cookie::Cookie, request::Request, response::Response, HttpMethod}; +#[async_trait] pub trait Context { fn get_request(&self) -> &Request; fn get_request_mut(&mut self) -> &mut Request; @@ -15,24 +15,22 @@ pub trait Context { fn take_response(self) -> Response; fn get_response_mut(&mut self) -> &mut Response; - fn get_body(&self) -> Result { - self.get_request().get_body() - } + // Buffered response body type + type ResBodyBuffer; + async fn get_buffered_request_body(&mut self) -> &Self::ResBodyBuffer; + fn json(&mut self, body: TBody) -> &mut Self where TBody: Serialize, { self.content_type("application/json").body( - &serde_json::to_string(&body) + serde_json::to_string(&body) .expect("unable to serialize body into JSON"), ) } - fn body(&mut self, string: &str) -> &mut Self { - self.get_response_mut().set_body(string); - self - } - fn body_bytes(&mut self, bytes: &[u8]) -> &mut Self { - self.get_response_mut().set_body_bytes(bytes); + + fn body>(&mut self, data: T) -> &mut Self { + self.get_response_mut().set_body(data); self } @@ -164,27 +162,30 @@ pub trait Context { pub struct DefaultContext { request: Request, response: Response, - body: Option, + buffered_request_body: Option>, + parsed_json: Option, } impl DefaultContext { pub fn get_body_object(&self) -> Option<&Value> { - self.body.as_ref() + self.parsed_json.as_ref() } pub fn set_body_object(&mut self, body: Value) { - self.body = Some(body); + self.parsed_json = Some(body); } pub fn new(request: Request) -> Self { DefaultContext { request, response: Default::default(), - body: None, + buffered_request_body: None, + parsed_json: None, } } } +#[async_trait] impl Context for DefaultContext { fn get_request(&self) -> &Request { &self.request @@ -205,6 +206,20 @@ impl Context for DefaultContext { fn get_response_mut(&mut self) -> &mut Response { &mut self.response } + + type ResBodyBuffer = Vec; + async fn get_buffered_request_body(&mut self) -> &Self::ResBodyBuffer { + match self.buffered_request_body { + Some(ref body) => body, + None => { + let body = hyper::body::to_bytes(self.request.take_body()) + .await + .unwrap() + .to_vec(); + self.buffered_request_body.get_or_insert(body) + } + } + } } pub fn default_context_generator( diff --git a/src/default_middlewares/compression.rs b/src/default_middlewares/compression.rs index 514bed2..79a54a1 100644 --- a/src/default_middlewares/compression.rs +++ b/src/default_middlewares/compression.rs @@ -20,7 +20,8 @@ impl CompressionHandler { } } - pub fn compress(&mut self, context: &mut TContext) + // TODO: Need to use async-compression crate for compressing on-demand + pub async fn compress(&mut self, context: &mut TContext) where TContext: Context + Debug + Send + Sync, { @@ -34,27 +35,35 @@ impl CompressionHandler { .collect::>(); if allowed_encodings.contains(&"gzip") { - let data = context.get_response().get_body(); + let data = + hyper::body::to_bytes(context.get_response_mut().take_body()) + .await + .unwrap(); let mut output = vec![]; let result = GzBuilder::new() .buf_read(data.as_ref(), self.compression_level) .read_to_end(&mut output); if result.is_ok() { context - .body_bytes(&output) - .header("Content-Encoding", "gzip"); + .header("Content-Encoding", "gzip") + .get_response_mut() + .set_body(output); } } else if allowed_encodings.contains(&"deflate") { - let data = context.get_response().get_body(); - let mut output = []; + let data = + hyper::body::to_bytes(context.get_response_mut().take_body()) + .await + .unwrap(); + let mut output = vec![]; if let Ok(Status::Ok) = self.zlib_compressor.compress( - data, + data.as_ref(), &mut output, FlushCompress::None, ) { context - .body_bytes(&output) - .header("Content-Encoding", "deflate"); + .header("Content-Encoding", "deflate") + .get_response_mut() + .set_body(output); } } } @@ -78,7 +87,7 @@ where context = next(context).await?; - compressor.compress(&mut context); + compressor.compress(&mut context).await; Ok(context) }) diff --git a/src/default_middlewares/json.rs b/src/default_middlewares/json.rs index a4b6055..aefeeb6 100644 --- a/src/default_middlewares/json.rs +++ b/src/default_middlewares/json.rs @@ -4,19 +4,17 @@ use serde_json::Value; use crate::{AsError, Context, DefaultMiddleware, Error}; -pub fn parser( - context: &TContext, +pub async fn parser( + context: &mut TContext, ) -> Result, Error> where TContext: 'static + Context + Debug + Send + Sync, TErrorData: Default + Send + Sync, + TContext::ResBodyBuffer: AsRef<[u8]>, { if context.is(&["application/json"]) { - let body = context - .get_request() - .get_body() - .unwrap_or_else(|_| "None".to_string()); - let value = serde_json::from_str(&body) + let body = context.get_buffered_request_body().await; + let value = serde_json::from_slice(body.as_ref()) .status(400) .body("Bad request")?; Ok(Some(value)) @@ -31,7 +29,7 @@ where { DefaultMiddleware::new(|mut context, next| { Box::pin(async move { - let json = parser(&context)?; + let json = parser(&mut context).await?; if let Some(json) = json { context.set_body_object(json); diff --git a/src/default_middlewares/static_file_server.rs b/src/default_middlewares/static_file_server.rs index 4653feb..0cb68d8 100644 --- a/src/default_middlewares/static_file_server.rs +++ b/src/default_middlewares/static_file_server.rs @@ -1,8 +1,11 @@ use std::fmt::Debug; -use tokio::fs; +use futures::TryFutureExt; +use hyper::Body; +use tokio::fs::{self, File}; +use tokio_util::codec::{BytesCodec, FramedRead}; -use crate::{AsError, Context, Error, Middleware, NextHandler}; +use crate::{Context, Error, Middleware, NextHandler}; #[derive(Clone)] pub struct StaticFileServer { @@ -33,11 +36,14 @@ impl StaticFileServer { let file_location = format!("{}{}", self.folder_path, context.get_path()); if is_file(&file_location).await { - let content = fs::read(file_location) - .await - .status(500) - .body("Internal server error")?; - context.body_bytes(&content); + let stream = File::open(file_location) + .map_ok(|file| FramedRead::new(file, BytesCodec::new())) + .await.expect("File open should succeed as it has already met precondition"); + + context + .get_response_mut() + .set_body(Body::wrap_stream(stream)); + Ok(context) } else { next(context).await diff --git a/src/default_middlewares/url_encoded.rs b/src/default_middlewares/url_encoded.rs index e937811..8d367e0 100644 --- a/src/default_middlewares/url_encoded.rs +++ b/src/default_middlewares/url_encoded.rs @@ -4,20 +4,18 @@ use serde_json::Value; use crate::{AsError, Context, DefaultMiddleware, Error}; -pub fn parser( - context: &TContext, +pub async fn parser( + context: &mut TContext, ) -> Result, Error> where TContext: 'static + Context + Debug + Send + Sync, TErrorData: Default + Send + Sync, + TContext::ResBodyBuffer: AsRef<[u8]> { if context.is(&["application/x-www-form-urlencoded"]) { - let body = context - .get_request() - .get_body() - .unwrap_or_else(|_| "None".to_string()); + let body = context.get_buffered_request_body().await; Ok(Some( - serde_urlencoded::from_bytes(body.as_bytes()) + serde_urlencoded::from_bytes(body.as_ref()) .status(500) .body("Internal server error")?, )) @@ -32,7 +30,7 @@ where { DefaultMiddleware::new(|mut context, next| { Box::pin(async move { - let json = parser(&context)?; + let json = parser(&mut context).await?; if let Some(json) = json { context.set_body_object(json); diff --git a/src/lib.rs b/src/lib.rs index fdbec09..4728d24 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,7 +70,7 @@ pub async fn listen< let app = app.clone(); async move { let request = - Request::from_hyper(remote_addr, req).await; + Request::from_hyper(remote_addr, req); let mut context = app.generate_context(request); context.header("Server", "Eve"); diff --git a/src/middleware_handler.rs b/src/middleware_handler.rs index 5e52a8b..f8c24f5 100644 --- a/src/middleware_handler.rs +++ b/src/middleware_handler.rs @@ -82,8 +82,8 @@ where .replace('^', "\\^") .replace('$', "\\$") .replace('.', "\\.") // Specifically, match the dot. This ain't a regex character - .replace('*', "([^\\/].)+") // Match anything that's not a /, but at least 1 character - .replace("**", "(.)+"); //Match anything + .replace("**", "(.)+") //Match anything [ NOTE: first replace `**` and then replace remaining `*` ] + .replace('*', "([^/].)+"); // Match anything that's not a /, but at least 1 character // Make a variable out of anything that begins with a : and has a-z, // A-Z, 0-9, '_' diff --git a/src/renderer.rs b/src/renderer.rs index abbe8e6..e02f367 100644 --- a/src/renderer.rs +++ b/src/renderer.rs @@ -18,7 +18,7 @@ pub trait RenderEngine: Context { TParams: Serialize, { let rendered = self.get_register().render(template_name, data)?; - self.content_type("text/html").body(&rendered); + self.content_type("text/html").body(rendered); Ok(self) } } diff --git a/src/request.rs b/src/request.rs index d2a651d..9754e5a 100644 --- a/src/request.rs +++ b/src/request.rs @@ -2,10 +2,10 @@ use std::{ collections::HashMap, fmt::{Debug, Formatter, Result as FmtResult}, net::{IpAddr, SocketAddr}, - str::{self, Utf8Error}, + str, }; -use hyper::{body, Body, Request as HyperRequestInternal, Uri, Version}; +use hyper::{Body, Request as HyperRequestInternal, Uri, Version}; use crate::{cookie::Cookie, HttpMethod}; @@ -13,7 +13,7 @@ pub type HyperRequest = HyperRequestInternal; pub struct Request { pub(crate) socket_addr: SocketAddr, - pub(crate) body: Vec, + pub(crate) body: Option, pub(crate) method: HttpMethod, pub(crate) uri: Uri, pub(crate) version: (u8, u8), @@ -21,11 +21,10 @@ pub struct Request { pub(crate) query: HashMap, pub(crate) params: HashMap, pub(crate) cookies: Vec, - pub(crate) hyper_request: HyperRequest, } impl Request { - pub async fn from_hyper( + pub fn from_hyper( socket_addr: SocketAddr, req: HyperRequest, ) -> Self { @@ -46,10 +45,9 @@ impl Request { headers.insert(key.to_string(), vec![value]); } }); - let body = body::to_bytes(hyper_body).await.unwrap().to_vec(); Request { socket_addr, - body: body.clone(), + body: Some(hyper_body), method: HttpMethod::from(parts.method.clone()), uri: parts.uri.clone(), version: match parts.version { @@ -69,33 +67,34 @@ impl Request { }, params: HashMap::new(), cookies: vec![], - hyper_request: HyperRequest::from_parts(parts, Body::from(body)), } } - pub fn get_body_bytes(&self) -> &[u8] { - &self.body - } - - pub fn get_body(&self) -> Result { - Ok(str::from_utf8(&self.body)?.to_string()) + /// Returns the original body stream sent from client + /// + /// For inspecting the body contents inside different middlewares, + /// it needs to be buffered somewhere in the context + /// + /// + /// # Panics + /// + /// This method should be called only once during the entire lifecycle of a Request. + /// If this method is called more than once, then it will panic + pub fn take_body(&mut self) -> Body { + if let Some(body) = self.body.take() { + body + } else { + panic!( + "Body stream is already extracted from Request. \n\ + If the body has to be consumed inside more than one middleware, it needs to be buffered somewhere in the context." + ) + } } pub fn get_method(&self) -> &HttpMethod { &self.method } - pub fn get_length(&self) -> u128 { - if let Some(length) = self.headers.get("Content-Length") { - if let Some(value) = length.get(0) { - if let Ok(value) = value.parse::() { - return value; - } - } - } - self.body.len() as u128 - } - pub fn get_path(&self) -> String { self.uri.path().to_string() } @@ -227,14 +226,6 @@ impl Request { pub fn get_cookie(&self, name: &str) -> Option<&Cookie> { self.cookies.iter().find(|cookie| cookie.key == name) } - - pub fn get_hyper_request(&self) -> &HyperRequest { - &self.hyper_request - } - - pub fn get_hyper_request_mut(&mut self) -> &mut HyperRequest { - &mut self.hyper_request - } } #[cfg(debug_assertions)] diff --git a/src/response.rs b/src/response.rs index c515c3a..1d03fce 100644 --- a/src/response.rs +++ b/src/response.rs @@ -3,13 +3,12 @@ use std::{ fmt::{Debug, Formatter, Result as FmtResult}, }; -use chrono::Local; +use hyper::Body; use crate::Cookie; -#[derive(Clone)] pub struct Response { - pub(crate) body: Vec, + pub(crate) body: Body, pub(crate) status: u16, pub(crate) headers: HashMap>, } @@ -17,7 +16,7 @@ pub struct Response { impl Response { pub fn new() -> Self { Response { - body: vec![], + body: Body::empty(), status: 200, headers: HashMap::new(), } @@ -167,16 +166,12 @@ impl Response { self.set_header("ETag", etag); } - pub fn get_body(&self) -> &Vec { - &self.body + pub fn take_body(&mut self) -> Body { + std::mem::take(&mut self.body) } - pub fn set_body(&mut self, data: &str) { - self.set_body_bytes(data.as_bytes()); - } - pub fn set_body_bytes(&mut self, data: &[u8]) { - self.body = data.to_vec(); - self.set_content_length(data.len()); - self.set_header("date", &Local::now().to_rfc2822()); + + pub fn set_body>(&mut self, data: T) { + self.body = data.into(); } pub fn set_cookie(&mut self, cookie: Cookie) { @@ -201,7 +196,7 @@ impl Debug for Response { impl Default for Response { fn default() -> Self { Response { - body: vec![], + body: Body::default(), status: 200, headers: HashMap::new(), }