Skip to content

Commit bd78086

Browse files
committed
Rust 1.48.0 + sentry Middlewares
1 parent 6bed25c commit bd78086

File tree

9 files changed

+192
-124
lines changed

9 files changed

+192
-124
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust-toolchain

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.45.2
1+
1.48.0

sentry/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ edition = "2018"
88
# Futures
99
futures = "0.3.1"
1010
async-std = "1.4.0"
11+
async-trait = "^0.1"
1112
# Primitives
1213
primitives = { path = "../primitives", features = ["postgres"] }
1314
adapter = { version = "0.1", path = "../adapter" }

sentry/src/chain.rs

-40
This file was deleted.

sentry/src/event_aggregator.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ impl EventAggregator {
8585
// fetch channel
8686
let channel = get_channel_by_id(&app.pool, &channel_id)
8787
.await?
88-
.ok_or_else(|| ResponseError::NotFound)?;
88+
.ok_or(ResponseError::NotFound)?;
8989

9090
let withdraw_period_start = channel.spec.withdraw_period_start;
9191
let channel_id = channel.id;

sentry/src/lib.rs

+46-76
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
#![deny(clippy::all)]
22
#![deny(rust_2018_idioms)]
33

4-
use crate::chain::chain;
54
use crate::db::DbPool;
65
use crate::event_aggregator::EventAggregator;
7-
use crate::middleware::auth;
8-
use crate::middleware::channel::{channel_load, get_channel_id};
9-
use crate::middleware::cors::{cors, Cors};
106
use crate::routes::channel::channel_status;
117
use crate::routes::event_aggregate::list_channel_event_aggregates;
128
use crate::routes::validator_message::{extract_params, list_validator_messages};
139
use chrono::Utc;
14-
use futures::future::{BoxFuture, FutureExt};
1510
use hyper::{Body, Method, Request, Response, StatusCode};
1611
use lazy_static::lazy_static;
12+
use middleware::{
13+
auth::{AuthRequired, Authenticate},
14+
channel::{ChannelLoad, GetChannelId},
15+
cors::{cors, Cors},
16+
};
17+
use middleware::{Chain, Middleware};
1718
use primitives::adapter::Adapter;
1819
use primitives::sentry::ValidationErrorResponse;
1920
use primitives::{Config, ValidatorId};
@@ -25,15 +26,10 @@ use routes::channel::{
2526
channel_list, channel_validate, create_channel, create_validator_messages, insert_events,
2627
last_approved,
2728
};
28-
use slog::{error, Logger};
29+
use slog::Logger;
2930
use std::collections::HashMap;
3031

31-
pub mod middleware {
32-
pub mod auth;
33-
pub mod channel;
34-
pub mod cors;
35-
}
36-
32+
pub mod middleware;
3733
pub mod routes {
3834
pub mod analytics;
3935
pub mod cfg;
@@ -44,7 +40,6 @@ pub mod routes {
4440

4541
pub mod access;
4642
pub mod analytics_recorder;
47-
mod chain;
4843
pub mod db;
4944
pub mod event_aggregator;
5045
pub mod event_reducer;
@@ -64,20 +59,6 @@ lazy_static! {
6459
static ref CREATE_EVENTS_BY_CHANNEL_ID: Regex = Regex::new(r"^/channel/0x([a-zA-Z0-9]{64})/events/?$").expect("The regex should be valid");
6560
}
6661

67-
fn auth_required_middleware<'a, A: Adapter>(
68-
req: Request<Body>,
69-
_: &Application<A>,
70-
) -> BoxFuture<'a, Result<Request<Body>, ResponseError>> {
71-
async move {
72-
if req.extensions().get::<Session>().is_some() {
73-
Ok(req)
74-
} else {
75-
Err(ResponseError::Unauthorized)
76-
}
77-
}
78-
.boxed()
79-
}
80-
8162
#[derive(Debug)]
8263
pub struct RouteParams(Vec<String>);
8364

@@ -127,12 +108,9 @@ impl<A: Adapter + 'static> Application<A> {
127108
None => Default::default(),
128109
};
129110

130-
let req = match auth::for_request(req, &self.adapter, self.redis.clone()).await {
111+
let req = match Authenticate.call(req, &self).await {
131112
Ok(req) => req,
132-
Err(error) => {
133-
error!(&self.logger, "{}", &error; "module" => "middleware-auth");
134-
return map_response_error(ResponseError::Unauthorized);
135-
}
113+
Err(error) => return map_response_error(error),
136114
};
137115

138116
let mut response = match (req.uri().path(), req.method()) {
@@ -143,7 +121,7 @@ impl<A: Adapter + 'static> Application<A> {
143121

144122
("/analytics", &Method::GET) => analytics(req, &self).await,
145123
("/analytics/advanced", &Method::GET) => {
146-
let req = match chain(req, &self, vec![Box::new(auth_required_middleware)]).await {
124+
let req = match AuthRequired.call(req, &self).await {
147125
Ok(req) => req,
148126
Err(error) => {
149127
return map_response_error(error);
@@ -153,7 +131,7 @@ impl<A: Adapter + 'static> Application<A> {
153131
advanced_analytics(req, &self).await
154132
}
155133
("/analytics/for-advertiser", &Method::GET) => {
156-
let req = match chain(req, &self, vec![Box::new(auth_required_middleware)]).await {
134+
let req = match AuthRequired.call(req, &self).await {
157135
Ok(req) => req,
158136
Err(error) => {
159137
return map_response_error(error);
@@ -162,7 +140,7 @@ impl<A: Adapter + 'static> Application<A> {
162140
advertiser_analytics(req, &self).await
163141
}
164142
("/analytics/for-publisher", &Method::GET) => {
165-
let req = match chain(req, &self, vec![Box::new(auth_required_middleware)]).await {
143+
let req = match AuthRequired.call(req, &self).await {
166144
Ok(req) => req,
167145
Err(error) => {
168146
return map_response_error(error);
@@ -199,12 +177,12 @@ async fn analytics_router<A: Adapter + 'static>(
199177
.map_or("".to_string(), |m| m.as_str().to_string())]);
200178
req.extensions_mut().insert(param);
201179

202-
let req = chain(
203-
req,
204-
app,
205-
vec![Box::new(channel_load), Box::new(get_channel_id)],
206-
)
207-
.await?;
180+
// apply middlewares
181+
req = Chain::new()
182+
.chain(ChannelLoad)
183+
.chain(GetChannelId)
184+
.apply(req, app)
185+
.await?;
208186

209187
analytics(req, app).await
210188
} else if let Some(caps) = ADVERTISER_ANALYTICS_BY_CHANNEL_ID.captures(route) {
@@ -213,12 +191,12 @@ async fn analytics_router<A: Adapter + 'static>(
213191
.map_or("".to_string(), |m| m.as_str().to_string())]);
214192
req.extensions_mut().insert(param);
215193

216-
let req = chain(
217-
req,
218-
app,
219-
vec![Box::new(auth_required_middleware), Box::new(get_channel_id)],
220-
)
221-
.await?;
194+
// apply middlewares
195+
req = Chain::new()
196+
.chain(AuthRequired)
197+
.chain(GetChannelId)
198+
.apply(req, app)
199+
.await?;
222200

223201
advertiser_analytics(req, app).await
224202
} else if let Some(caps) = PUBLISHER_ANALYTICS_BY_CHANNEL_ID.captures(route) {
@@ -227,12 +205,12 @@ async fn analytics_router<A: Adapter + 'static>(
227205
.map_or("".to_string(), |m| m.as_str().to_string())]);
228206
req.extensions_mut().insert(param);
229207

230-
let req = chain(
231-
req,
232-
app,
233-
vec![Box::new(auth_required_middleware), Box::new(get_channel_id)],
234-
)
235-
.await?;
208+
// apply middlewares
209+
req = Chain::new()
210+
.chain(AuthRequired)
211+
.chain(GetChannelId)
212+
.apply(req, app)
213+
.await?;
236214

237215
publisher_analytics(req, app).await
238216
} else {
@@ -274,7 +252,7 @@ async fn channels_router<A: Adapter + 'static>(
274252
.map_or("".to_string(), |m| m.as_str().to_string())]);
275253
req.extensions_mut().insert(param);
276254

277-
let req = channel_load(req, app).await?;
255+
req = ChannelLoad.call(req, app).await?;
278256
channel_status(req, app).await
279257
} else if let (Some(caps), &Method::GET) = (CHANNEL_VALIDATOR_MESSAGES.captures(&path), method)
280258
{
@@ -284,12 +262,7 @@ async fn channels_router<A: Adapter + 'static>(
284262

285263
req.extensions_mut().insert(param);
286264

287-
let req = match chain(req, app, vec![Box::new(channel_load)]).await {
288-
Ok(req) => req,
289-
Err(error) => {
290-
return Err(error);
291-
}
292-
};
265+
req = ChannelLoad.call(req, app).await?;
293266

294267
// @TODO: Move this to a middleware?!
295268
let extract_params = match extract_params(caps.get(2).map_or("", |m| m.as_str())) {
@@ -308,24 +281,15 @@ async fn channels_router<A: Adapter + 'static>(
308281

309282
req.extensions_mut().insert(param);
310283

311-
let req = match chain(
312-
req,
313-
app,
314-
vec![Box::new(auth_required_middleware), Box::new(channel_load)],
315-
)
316-
.await
317-
{
318-
Ok(req) => req,
319-
Err(error) => {
320-
return Err(error);
321-
}
322-
};
284+
let req = Chain::new()
285+
.chain(AuthRequired)
286+
.chain(ChannelLoad)
287+
.apply(req, app)
288+
.await?;
323289

324290
create_validator_messages(req, &app).await
325291
} else if let (Some(caps), &Method::GET) = (CHANNEL_EVENTS_AGGREGATES.captures(&path), method) {
326-
if req.extensions().get::<Session>().is_none() {
327-
return Err(ResponseError::Unauthorized);
328-
}
292+
req = AuthRequired.call(req, app).await?;
329293

330294
let param = RouteParams(vec![
331295
caps.get(1)
@@ -335,7 +299,7 @@ async fn channels_router<A: Adapter + 'static>(
335299
]);
336300
req.extensions_mut().insert(param);
337301

338-
let req = chain(req, app, vec![Box::new(channel_load)]).await?;
302+
req = ChannelLoad.call(req, app).await?;
339303

340304
list_channel_event_aggregates(req, app).await
341305
} else {
@@ -365,6 +329,12 @@ where
365329
}
366330
}
367331

332+
impl Into<Response<Body>> for ResponseError {
333+
fn into(self) -> Response<Body> {
334+
map_response_error(self)
335+
}
336+
}
337+
368338
pub fn map_response_error(error: ResponseError) -> Response<Body> {
369339
match error {
370340
ResponseError::NotFound => not_found(),

sentry/src/middleware.rs

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
use std::fmt::Debug;
2+
3+
use crate::{Application, ResponseError};
4+
use hyper::{Body, Request};
5+
use primitives::adapter::Adapter;
6+
7+
use async_trait::async_trait;
8+
9+
pub mod auth;
10+
pub mod channel;
11+
pub mod cors;
12+
13+
#[async_trait]
14+
pub trait Middleware<A: Adapter + 'static>: Send + Sync + Debug {
15+
async fn call<'a>(
16+
&self,
17+
request: Request<Body>,
18+
application: &'a Application<A>,
19+
) -> Result<Request<Body>, ResponseError>;
20+
}
21+
22+
#[derive(Debug, Default)]
23+
/// `Chain` allows chaining multiple middleware to be applied on the Request of the application
24+
/// Chained middlewares are applied in the order they were chained
25+
pub struct Chain<A: Adapter + 'static>(Vec<Box<dyn Middleware<A>>>);
26+
27+
impl<A: Adapter + 'static> Chain<A> {
28+
pub fn new() -> Self {
29+
Chain(vec![])
30+
}
31+
32+
pub fn chain<M: Middleware<A> + 'static>(mut self, middleware: M) -> Self {
33+
self.0.push(Box::new(middleware));
34+
35+
self
36+
}
37+
38+
/// Applies chained middlewares in the order they were chained
39+
pub async fn apply<'a>(
40+
&self,
41+
mut request: Request<Body>,
42+
application: &'a Application<A>,
43+
) -> Result<Request<Body>, ResponseError> {
44+
for middleware in self.0.iter() {
45+
request = middleware.call(request, application).await?;
46+
}
47+
48+
Ok(request)
49+
}
50+
}

0 commit comments

Comments
 (0)