-
Notifications
You must be signed in to change notification settings - Fork 0
[Auth] Cover the auth middleware with tests #25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,34 @@ | ||
use std::sync::Arc; | ||
|
||
use axum::extract::State; | ||
use axum::extract::{FromRef, State}; | ||
use axum::http::{self, HeaderMap, HeaderValue, Request}; | ||
use axum::middleware::Next; | ||
use axum::response::IntoResponse; | ||
use lib::prelude::*; | ||
|
||
use super::auth::{AuthError, SecretApiKey}; | ||
use super::auth::{AuthError, Authenticator, SecretApiKey}; | ||
use super::errors::ApiError; | ||
use super::AppState; | ||
|
||
const ON_BEHALF_OF_HEADER_NAME: &str = "X-On-Behalf-Of"; | ||
|
||
// Partial state from the main app state to facilitate writing tests for the | ||
// middleware. | ||
#[derive(Clone)] | ||
pub struct AuthenticationState { | ||
authenticator: Authenticator, | ||
config: super::config::ApiSvcConfig, | ||
} | ||
|
||
impl FromRef<Arc<AppState>> for AuthenticationState { | ||
fn from_ref(input: &Arc<AppState>) -> Self { | ||
Self { | ||
authenticator: input.authenticator.clone(), | ||
config: input.context.service_config(), | ||
} | ||
} | ||
} | ||
|
||
enum AuthenticationStatus { | ||
Unauthenticated, | ||
Authenticated(ValidShardedId<ProjectId>), | ||
|
@@ -61,14 +78,14 @@ fn get_auth_key( | |
} | ||
|
||
async fn get_auth_status<B>( | ||
state: &AppState, | ||
state: &AuthenticationState, | ||
req: &Request<B>, | ||
) -> Result<AuthenticationStatus, ApiError> { | ||
let auth_key = get_auth_key(req.headers())?; | ||
let Some(auth_key) = auth_key else { | ||
return Ok(AuthenticationStatus::Unauthenticated); | ||
}; | ||
let config = state.context.service_config(); | ||
let config = &state.config; | ||
let admin_keys = &config.admin_api_keys; | ||
if admin_keys.contains(&auth_key) { | ||
let project: Option<ValidShardedId<ProjectId>> = req | ||
|
@@ -98,7 +115,10 @@ async fn get_auth_status<B>( | |
return Ok(AuthenticationStatus::Unauthenticated); | ||
}; | ||
|
||
let project = state.authenicator.authenticate(&user_provided_secret).await; | ||
let project = state | ||
.authenticator | ||
.authenticate(&user_provided_secret) | ||
.await; | ||
match project { | ||
| Ok(project_id) => Ok(AuthenticationStatus::Authenticated(project_id)), | ||
| Err(AuthError::AuthFailed(_)) => { | ||
|
@@ -178,11 +198,11 @@ pub async fn ensure_admin<B>( | |
/// of the other "ensure_*" middlewares in this module to enforce the expected | ||
/// AuthenticationStatus for a certain route. | ||
pub async fn authenticate<B>( | ||
State(state): State<Arc<AppState>>, | ||
State(state): State<AuthenticationState>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very happy that you started to use partial states 🚀 |
||
mut req: Request<B>, | ||
next: Next<B>, | ||
) -> Result<impl IntoResponse, ApiError> { | ||
let auth_status = get_auth_status(state.as_ref(), &req).await?; | ||
let auth_status = get_auth_status(&state, &req).await?; | ||
|
||
let project_id = auth_status.project_id(); | ||
req.extensions_mut().insert(auth_status); | ||
|
@@ -200,3 +220,260 @@ pub async fn authenticate<B>( | |
|
||
Ok(resp) | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
|
||
use std::collections::HashSet; | ||
use std::fmt::Debug; | ||
|
||
use axum::routing::get; | ||
use axum::{middleware, Router}; | ||
use cronback_api_model::admin::CreateAPIkeyRequest; | ||
use hyper::{Body, StatusCode}; | ||
use tower::ServiceExt; | ||
|
||
use super::*; | ||
use crate::api::auth_store::AuthStore; | ||
use crate::api::config::ApiSvcConfig; | ||
use crate::api::ApiService; | ||
|
||
async fn make_state() -> AuthenticationState { | ||
let mut set = HashSet::new(); | ||
set.insert("adminkey1".to_string()); | ||
set.insert("adminkey2".to_string()); | ||
|
||
let config = ApiSvcConfig { | ||
address: String::new(), | ||
port: 123, | ||
database_uri: String::new(), | ||
admin_api_keys: set, | ||
log_request_body: false, | ||
log_response_body: false, | ||
}; | ||
|
||
let db = ApiService::in_memory_database().await.unwrap(); | ||
let auth_store = AuthStore::new(db); | ||
let authenticator = Authenticator::new(auth_store); | ||
|
||
AuthenticationState { | ||
authenticator, | ||
config, | ||
} | ||
} | ||
|
||
struct TestInput { | ||
app: Router, | ||
auth_header: Option<String>, | ||
on_behalf_on_header: Option<String>, | ||
expected_status: StatusCode, | ||
} | ||
|
||
impl Debug for TestInput { | ||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
f.debug_struct("TestInput") | ||
.field("auth_header", &self.auth_header) | ||
.field("on_behalf_on_header", &self.on_behalf_on_header) | ||
.field("expected_status", &self.expected_status) | ||
.finish() | ||
} | ||
} | ||
|
||
struct TestExpectations { | ||
unauthenticated: StatusCode, | ||
authenticated: StatusCode, | ||
admin_no_project: StatusCode, | ||
admin_with_project: StatusCode, | ||
unknown_secret_key: StatusCode, | ||
} | ||
|
||
async fn run_tests( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we put There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TIL. But it doesn't work with |
||
app: Router, | ||
state: AuthenticationState, | ||
expectations: TestExpectations, | ||
) -> anyhow::Result<()> { | ||
// Define one project and generate a key for it. | ||
let prj1 = ProjectId::generate(); | ||
let key = state | ||
.authenticator | ||
.gen_key( | ||
CreateAPIkeyRequest { | ||
key_name: "test".to_string(), | ||
metadata: Default::default(), | ||
}, | ||
&prj1, | ||
) | ||
.await?; | ||
|
||
let inputs = vec![ | ||
// Unauthenticated user | ||
TestInput { | ||
app: app.clone(), | ||
auth_header: None, | ||
on_behalf_on_header: None, | ||
expected_status: expectations.unauthenticated, | ||
}, | ||
// Authenticated user | ||
TestInput { | ||
app: app.clone(), | ||
auth_header: Some(format!("Bearer {}", key.unsafe_to_string())), | ||
on_behalf_on_header: None, | ||
expected_status: expectations.authenticated, | ||
}, | ||
// Admin without project | ||
TestInput { | ||
app: app.clone(), | ||
auth_header: Some("Bearer adminkey1".to_string()), | ||
on_behalf_on_header: None, | ||
expected_status: expectations.admin_no_project, | ||
}, | ||
// Admin with project | ||
TestInput { | ||
app: app.clone(), | ||
auth_header: Some("Bearer adminkey1".to_string()), | ||
on_behalf_on_header: Some(prj1.to_string()), | ||
expected_status: expectations.admin_with_project, | ||
}, | ||
// Unknown secret key | ||
TestInput { | ||
app: app.clone(), | ||
auth_header: Some(format!( | ||
"Bearer {}", | ||
SecretApiKey::generate().unsafe_to_string() | ||
)), | ||
on_behalf_on_header: Some(prj1.to_string()), | ||
expected_status: expectations.unknown_secret_key, | ||
}, | ||
// Malformed secret key should be treated as an unknown secret key | ||
TestInput { | ||
app: app.clone(), | ||
auth_header: Some("Bearer wrong key".to_string()), | ||
on_behalf_on_header: Some("wrong_project".to_string()), | ||
expected_status: expectations.unknown_secret_key, | ||
}, | ||
// Malformed authorization header | ||
TestInput { | ||
app: app.clone(), | ||
auth_header: Some(format!("Token {}", key.unsafe_to_string())), | ||
on_behalf_on_header: Some(prj1.to_string()), | ||
expected_status: StatusCode::BAD_REQUEST, | ||
}, | ||
// Malformed on-behalf-on project id | ||
TestInput { | ||
app: app.clone(), | ||
auth_header: Some("Bearer adminkey1".to_string()), | ||
on_behalf_on_header: Some("wrong_project".to_string()), | ||
expected_status: StatusCode::BAD_REQUEST, | ||
}, | ||
]; | ||
|
||
for input in inputs { | ||
let input_str = format!("{:?}", input); | ||
|
||
let mut req = Request::builder(); | ||
if let Some(v) = input.auth_header { | ||
req = req.header("Authorization", v); | ||
} | ||
if let Some(v) = input.on_behalf_on_header { | ||
req = req.header(ON_BEHALF_OF_HEADER_NAME, v); | ||
} | ||
|
||
let resp = input | ||
.app | ||
.oneshot(req.uri("/").body(Body::empty()).unwrap()) | ||
.await?; | ||
|
||
assert_eq!( | ||
resp.status(), | ||
input.expected_status, | ||
"Input: {}", | ||
input_str | ||
); | ||
} | ||
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_ensure_authenticated() -> anyhow::Result<()> { | ||
let state = make_state().await; | ||
|
||
let app = Router::new() | ||
.route("/", get(|| async { "Hello, World!" })) | ||
.layer(middleware::from_fn(super::ensure_authenticated)) | ||
.layer(middleware::from_fn_with_state( | ||
state.clone(), | ||
super::authenticate, | ||
)); | ||
|
||
run_tests( | ||
app, | ||
state, | ||
TestExpectations { | ||
unauthenticated: StatusCode::UNAUTHORIZED, | ||
authenticated: StatusCode::OK, | ||
admin_no_project: StatusCode::BAD_REQUEST, | ||
admin_with_project: StatusCode::OK, | ||
unknown_secret_key: StatusCode::UNAUTHORIZED, | ||
}, | ||
) | ||
.await?; | ||
|
||
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_ensure_admin() -> anyhow::Result<()> { | ||
let state = make_state().await; | ||
|
||
let app = Router::new() | ||
.route("/", get(|| async { "Hello, World!" })) | ||
.layer(middleware::from_fn(super::ensure_admin)) | ||
.layer(middleware::from_fn_with_state( | ||
state.clone(), | ||
super::authenticate, | ||
)); | ||
|
||
run_tests( | ||
app, | ||
state, | ||
TestExpectations { | ||
unauthenticated: StatusCode::UNAUTHORIZED, | ||
authenticated: StatusCode::FORBIDDEN, | ||
admin_no_project: StatusCode::OK, | ||
admin_with_project: StatusCode::OK, | ||
unknown_secret_key: StatusCode::UNAUTHORIZED, | ||
}, | ||
) | ||
.await?; | ||
|
||
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_ensure_admin_for_project() -> anyhow::Result<()> { | ||
let state = make_state().await; | ||
|
||
let app = Router::new() | ||
.route("/", get(|| async { "Hello, World!" })) | ||
.layer(middleware::from_fn(super::ensure_admin_for_project)) | ||
.layer(middleware::from_fn_with_state( | ||
state.clone(), | ||
super::authenticate, | ||
)); | ||
|
||
run_tests( | ||
app, | ||
state, | ||
TestExpectations { | ||
unauthenticated: StatusCode::UNAUTHORIZED, | ||
authenticated: StatusCode::FORBIDDEN, | ||
admin_no_project: StatusCode::BAD_REQUEST, | ||
admin_with_project: StatusCode::OK, | ||
unknown_secret_key: StatusCode::UNAUTHORIZED, | ||
}, | ||
) | ||
.await?; | ||
|
||
Ok(()) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought FromRef was implemented for all
T: Clone
already, no? I wonder why this wasn't the case forAuthenticationState
.Also, did you try to
#[derive(FromRef, Clone)]
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FromRef<T> for T
is implemented for all Ts whereT: Clone
but here I'm defining how to get anAuthenticationState
(custom struct) from anAppState
, there's no way the compiler can infer this on its own or even with a derive macro, right?I think the derive macro might work only if I'm getting a state member from the state struct, but here it's a completely custom new struct.