Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE settings DROP COLUMN openid_username_handling;
DROP TYPE openid_username_handling;
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
CREATE TYPE openid_username_handling AS ENUM (
'remove_forbidden',
'replace_forbidden',
'prune_email_domain'
);
ALTER TABLE settings ADD COLUMN openid_username_handling openid_username_handling NOT NULL DEFAULT 'remove_forbidden';
20 changes: 18 additions & 2 deletions src/db/models/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ pub enum SmtpEncryption {
ImplicitTls,
}

#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Type, Debug, Default, Copy)]
#[sqlx(type_name = "openid_username_handling", rename_all = "snake_case")]
pub enum OpenidUsernameHandling {
#[default]
/// Removes all forbidden characters
RemoveForbidden,
/// Replaces all forbidden characters with `_`
ReplaceForbidden,
/// Removes the email domain, replaces all other forbidden characters with `_`
PruneEmailDomain,
}

#[derive(Clone, Debug, Deserialize, PartialEq, Patch, Serialize, Default)]
#[patch(attribute(derive(Deserialize, Serialize, Debug)))]
pub struct Settings {
Expand Down Expand Up @@ -107,6 +119,7 @@ pub struct Settings {
pub ldap_sync_groups: Vec<String>,
// Whether to create a new account when users try to log in with external OpenID
pub openid_create_account: bool,
pub openid_username_handling: OpenidUsernameHandling,
pub license: Option<String>,
// Gateway disconnect notifications
pub gateway_disconnect_notifications_enabled: bool,
Expand Down Expand Up @@ -138,7 +151,8 @@ impl Settings {
ldap_sync_status \"ldap_sync_status: SyncStatus\", \
ldap_enabled, ldap_sync_enabled, ldap_is_authoritative, \
ldap_sync_interval, ldap_user_auxiliary_obj_classes, ldap_uses_ad, \
ldap_user_rdn_attr, ldap_sync_groups \
ldap_user_rdn_attr, ldap_sync_groups, \
openid_username_handling \"openid_username_handling: OpenidUsernameHandling\" \
FROM \"settings\" WHERE id = 1",
)
.fetch_optional(executor)
Expand Down Expand Up @@ -209,7 +223,8 @@ impl Settings {
ldap_user_auxiliary_obj_classes = $44, \
ldap_uses_ad = $45, \
ldap_user_rdn_attr = $46, \
ldap_sync_groups = $47 \
ldap_sync_groups = $47, \
openid_username_handling = $48 \
WHERE id = 1",
self.openid_enabled,
self.wireguard_enabled,
Expand Down Expand Up @@ -258,6 +273,7 @@ impl Settings {
self.ldap_uses_ad,
self.ldap_user_rdn_attr,
&self.ldap_sync_groups as &Vec<String>,
&self.openid_username_handling as &OpenidUsernameHandling,
)
.execute(executor)
.await?;
Expand Down
129 changes: 123 additions & 6 deletions src/enterprise/handlers/openid_login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,65 @@ static NONCE_COOKIE_NAME: &str = "nonce";
use super::LicenseInfo;
use crate::{
appstate::AppState,
db::{Id, Settings, User},
db::{models::settings::OpenidUsernameHandling, Id, Settings, User},
enterprise::{
db::models::openid_provider::OpenIdProvider,
directory_sync::sync_user_groups_if_configured, ldap::utils::ldap_update_user_state,
limits::update_counts,
},
error::WebError,
handlers::{
auth::create_session,
user::{check_username, prune_username},
ApiResponse, AuthResponse, SESSION_COOKIE_NAME, SIGN_IN_COOKIE_NAME,
auth::create_session, user::check_username, ApiResponse, AuthResponse, SESSION_COOKIE_NAME,
SIGN_IN_COOKIE_NAME,
},
server_config,
};

/// Prune the given username from illegal characters in accordance with the following rules:
///
/// To enable LDAP sync usernames need to avoid reserved characters.
/// Username requirements:
/// - 64 characters long
/// - only lowercase or uppercase latin alphabet letters (A-Z, a-z) and digits (0-9)
/// - starts with non-special character
/// - only special characters allowed: . - _
/// - no whitespaces
pub fn prune_username(username: &str, handling: OpenidUsernameHandling) -> String {
let mut result = username.to_string();

// Go through the string and remove any non-alphanumeric characters at the beginning
result = result
.trim_start_matches(|c: char| !c.is_ascii_alphanumeric())
.to_string();

let is_char_valid = |c: char| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_';

match handling {
OpenidUsernameHandling::RemoveForbidden => {
result.retain(&is_char_valid);
}
OpenidUsernameHandling::ReplaceForbidden => {
result = result
.chars()
.map(|c| if is_char_valid(c) { c } else { '_' })
.collect();
}
OpenidUsernameHandling::PruneEmailDomain => {
if let Some(at_index) = result.find('@') {
result.truncate(at_index);
}
result = result
.chars()
.map(|c| if is_char_valid(c) { c } else { '_' })
.collect();
}
}

result.truncate(64);

result
}

/// Create HTTP client and prevent following redirects
async fn get_async_http_client() -> Result<reqwest::Client, WebError> {
reqwest::Client::builder()
Expand Down Expand Up @@ -207,15 +251,16 @@ pub(crate) async fn user_from_claims(
debug!("Username extracted from email ({email:?}): {username})");
username
};
let username = prune_username(username);
let settings = Settings::get_current_settings();

let username = prune_username(username, settings.openid_username_handling);
// Check if the username is valid just in case, not everything can be handled by the pruning.
check_username(&username)?;

// Get the *sub* claim from the token.
let sub = token_claims.subject().to_string();

// Handle logging in or creating user.
let settings = Settings::get_current_settings();
let user = match User::find_by_sub(pool, &sub)
.await
.map_err(|err| WebError::Authorization(err.to_string()))?
Expand Down Expand Up @@ -557,3 +602,75 @@ pub(crate) async fn auth_callback(
unimplemented!("Impossible to get here");
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_prune_username() {
// Test RemoveForbidden handling
let handling_remove = OpenidUsernameHandling::RemoveForbidden;
assert_eq!(prune_username("zenek", handling_remove), "zenek");
assert_eq!(prune_username("zenek34", handling_remove), "zenek34");
assert_eq!(prune_username("zenek@34", handling_remove), "zenek34");
assert_eq!(prune_username("first.last", handling_remove), "first.last");
assert_eq!(prune_username("__zenek__", handling_remove), "zenek__");
assert_eq!(prune_username("zenek?", handling_remove), "zenek");
assert_eq!(prune_username("zenek!", handling_remove), "zenek");
assert_eq!(
prune_username(
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
handling_remove
),
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
);
assert_eq!(prune_username("", handling_remove), "");
assert_eq!(prune_username("!@#$%^&*()", handling_remove), "");
assert_eq!(prune_username("!zenek", handling_remove), "zenek");
assert_eq!(prune_username("...zenek", handling_remove), "zenek");

// Test ReplaceForbidden handling
let handling_replace = OpenidUsernameHandling::ReplaceForbidden;
assert_eq!(prune_username("zenek", handling_replace), "zenek");
assert_eq!(prune_username("zenek34", handling_replace), "zenek34");
assert_eq!(prune_username("zenek@34", handling_replace), "zenek_34");
assert_eq!(prune_username("first.last", handling_replace), "first.last");
assert_eq!(prune_username("__zenek__", handling_replace), "zenek__");
assert_eq!(prune_username("zenek?", handling_replace), "zenek_");
assert_eq!(prune_username("zenek!", handling_replace), "zenek_");
assert_eq!(
prune_username(
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
handling_replace
),
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
);

// Test PruneEmailDomain handling
let handling_prune_email = OpenidUsernameHandling::PruneEmailDomain;
assert_eq!(
prune_username("[email protected]", handling_prune_email),
"zenek"
);
assert_eq!(
prune_username("[email protected]", handling_prune_email),
"user.name"
);
assert_eq!(
prune_username("[email protected]", handling_prune_email),
"invalid_chars_"
);
assert_eq!(
prune_username("multiple@[email protected]", handling_prune_email),
"multiple"
);
assert_eq!(
prune_username(
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee@domain.com",
handling_prune_email
),
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
);
}
}
9 changes: 7 additions & 2 deletions src/enterprise/handlers/openid_providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use super::LicenseInfo;
use crate::{
appstate::AppState,
auth::{AdminRole, SessionInfo},
db::{models::settings::update_current_settings, Settings},
db::{
models::settings::{update_current_settings, OpenidUsernameHandling},
Settings,
},
enterprise::{
db::models::openid_provider::OpenIdProvider, directory_sync::test_directory_sync_connection,
},
Expand All @@ -36,6 +39,7 @@ pub struct AddProviderData {
pub okta_private_jwk: Option<String>,
pub okta_dirsync_client_id: Option<String>,
pub directory_sync_group_match: Option<String>,
pub username_handling: OpenidUsernameHandling,
}

#[derive(Debug, Deserialize, Serialize)]
Expand Down Expand Up @@ -107,6 +111,7 @@ pub async fn add_openid_provider(

let mut settings = Settings::get_current_settings();
settings.openid_create_account = provider_data.create_account;
settings.openid_username_handling = provider_data.username_handling;
update_current_settings(&appstate.pool, settings).await?;

let group_match = if let Some(group_match) = provider_data.directory_sync_group_match {
Expand Down Expand Up @@ -173,7 +178,7 @@ pub async fn get_current_openid_provider(
Ok(ApiResponse {
json: json!({
"provider": json!(provider),
"settings": json!({ "create_account": create_account }),
"settings": json!({ "create_account": create_account, "username_handling": settings.openid_username_handling}),
}),
status: StatusCode::OK,
})
Expand Down
43 changes: 0 additions & 43 deletions src/handlers/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,32 +74,6 @@ pub fn check_username(username: &str) -> Result<(), WebError> {
Ok(())
}

/// Prune the given username from illegal characters in accordance with the following rules:
///
/// To enable LDAP sync usernames need to avoid reserved characters.
/// Username requirements:
/// - 64 characters long
/// - only lowercase or uppercase latin alphabet letters (A-Z, a-z) and digits (0-9)
/// - starts with non-special character
/// - only special characters allowed: . - _
/// - no whitespaces
pub fn prune_username(username: &str) -> String {
let mut result = username.to_string();

if result.len() > 64 {
result.truncate(64);
}

// Go through the string and remove any non-alphanumeric characters at the beginning
result = result
.trim_start_matches(|c: char| !c.is_ascii_alphanumeric())
.to_string();

result.retain(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_');

result
}

pub(crate) fn check_password_strength(password: &str) -> Result<(), WebError> {
if !(8..=128).contains(&password.len()) {
return Err(WebError::Serialization("Incorrect password length".into()));
Expand Down Expand Up @@ -1248,23 +1222,6 @@ mod test {

use super::*;

#[test]
fn test_username_prune() {
assert_eq!(prune_username("zenek"), "zenek");
assert_eq!(prune_username("zenek34"), "zenek34");
assert_eq!(prune_username("zenek@34"), "zenek34");
assert_eq!(prune_username("first.last"), "first.last");
assert_eq!(prune_username("__zenek__"), "zenek__");
assert_eq!(prune_username("zenek?"), "zenek");
assert_eq!(prune_username("zenek!"), "zenek");
assert_eq!(
prune_username(
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
),
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
);
}

#[test]
fn test_username_validation() {
// valid usernames
Expand Down
Loading
Loading