Skip to content

Commit f8a96bf

Browse files
authored
Merge pull request #1138 from DefGuard/dev
Merge dev -> main (1.3.0 3rd merge)
2 parents 7e5f3c2 + e67e1cb commit f8a96bf

15 files changed

+325
-80
lines changed

.sqlx/query-d347ad2fe71dd67c0f07a5ae4114637bde1cef092ea64a1d535341a531247dc5.json renamed to .sqlx/query-3491725f35609e9b219c4d613cffd28a14cf37e546dfcabdfd78889dc1ef247f.json

Lines changed: 15 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.sqlx/query-5008e41c4dae86fe8825731a4c4202f5ddff52ef296af1471ea7226d52f85a6a.json renamed to .sqlx/query-7ddef79c85c3e85b979d5a8a5e50660bcae531c2b8342ae2feffea7454450f10.json

Lines changed: 19 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ALTER TABLE settings DROP COLUMN openid_username_handling;
2+
DROP TYPE openid_username_handling;
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
CREATE TYPE openid_username_handling AS ENUM (
2+
'remove_forbidden',
3+
'replace_forbidden',
4+
'prune_email_domain'
5+
);
6+
ALTER TABLE settings ADD COLUMN openid_username_handling openid_username_handling NOT NULL DEFAULT 'remove_forbidden';

src/db/models/settings.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ pub enum SmtpEncryption {
4848
ImplicitTls,
4949
}
5050

51+
#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Type, Debug, Default, Copy)]
52+
#[sqlx(type_name = "openid_username_handling", rename_all = "snake_case")]
53+
pub enum OpenidUsernameHandling {
54+
#[default]
55+
/// Removes all forbidden characters
56+
RemoveForbidden,
57+
/// Replaces all forbidden characters with `_`
58+
ReplaceForbidden,
59+
/// Removes the email domain, replaces all other forbidden characters with `_`
60+
PruneEmailDomain,
61+
}
62+
5163
#[derive(Clone, Debug, Deserialize, PartialEq, Patch, Serialize, Default)]
5264
#[patch(attribute(derive(Deserialize, Serialize, Debug)))]
5365
pub struct Settings {
@@ -107,6 +119,7 @@ pub struct Settings {
107119
pub ldap_sync_groups: Vec<String>,
108120
// Whether to create a new account when users try to log in with external OpenID
109121
pub openid_create_account: bool,
122+
pub openid_username_handling: OpenidUsernameHandling,
110123
pub license: Option<String>,
111124
// Gateway disconnect notifications
112125
pub gateway_disconnect_notifications_enabled: bool,
@@ -138,7 +151,8 @@ impl Settings {
138151
ldap_sync_status \"ldap_sync_status: SyncStatus\", \
139152
ldap_enabled, ldap_sync_enabled, ldap_is_authoritative, \
140153
ldap_sync_interval, ldap_user_auxiliary_obj_classes, ldap_uses_ad, \
141-
ldap_user_rdn_attr, ldap_sync_groups \
154+
ldap_user_rdn_attr, ldap_sync_groups, \
155+
openid_username_handling \"openid_username_handling: OpenidUsernameHandling\" \
142156
FROM \"settings\" WHERE id = 1",
143157
)
144158
.fetch_optional(executor)
@@ -209,7 +223,8 @@ impl Settings {
209223
ldap_user_auxiliary_obj_classes = $44, \
210224
ldap_uses_ad = $45, \
211225
ldap_user_rdn_attr = $46, \
212-
ldap_sync_groups = $47 \
226+
ldap_sync_groups = $47, \
227+
openid_username_handling = $48 \
213228
WHERE id = 1",
214229
self.openid_enabled,
215230
self.wireguard_enabled,
@@ -258,6 +273,7 @@ impl Settings {
258273
self.ldap_uses_ad,
259274
self.ldap_user_rdn_attr,
260275
&self.ldap_sync_groups as &Vec<String>,
276+
&self.openid_username_handling as &OpenidUsernameHandling,
261277
)
262278
.execute(executor)
263279
.await?;

src/enterprise/handlers/openid_login.rs

Lines changed: 123 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,65 @@ static NONCE_COOKIE_NAME: &str = "nonce";
2525
use super::LicenseInfo;
2626
use crate::{
2727
appstate::AppState,
28-
db::{Id, Settings, User},
28+
db::{models::settings::OpenidUsernameHandling, Id, Settings, User},
2929
enterprise::{
3030
db::models::openid_provider::OpenIdProvider,
3131
directory_sync::sync_user_groups_if_configured, ldap::utils::ldap_update_user_state,
3232
limits::update_counts,
3333
},
3434
error::WebError,
3535
handlers::{
36-
auth::create_session,
37-
user::{check_username, prune_username},
38-
ApiResponse, AuthResponse, SESSION_COOKIE_NAME, SIGN_IN_COOKIE_NAME,
36+
auth::create_session, user::check_username, ApiResponse, AuthResponse, SESSION_COOKIE_NAME,
37+
SIGN_IN_COOKIE_NAME,
3938
},
4039
server_config,
4140
};
4241

42+
/// Prune the given username from illegal characters in accordance with the following rules:
43+
///
44+
/// To enable LDAP sync usernames need to avoid reserved characters.
45+
/// Username requirements:
46+
/// - 64 characters long
47+
/// - only lowercase or uppercase latin alphabet letters (A-Z, a-z) and digits (0-9)
48+
/// - starts with non-special character
49+
/// - only special characters allowed: . - _
50+
/// - no whitespaces
51+
pub fn prune_username(username: &str, handling: OpenidUsernameHandling) -> String {
52+
let mut result = username.to_string();
53+
54+
// Go through the string and remove any non-alphanumeric characters at the beginning
55+
result = result
56+
.trim_start_matches(|c: char| !c.is_ascii_alphanumeric())
57+
.to_string();
58+
59+
let is_char_valid = |c: char| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_';
60+
61+
match handling {
62+
OpenidUsernameHandling::RemoveForbidden => {
63+
result.retain(&is_char_valid);
64+
}
65+
OpenidUsernameHandling::ReplaceForbidden => {
66+
result = result
67+
.chars()
68+
.map(|c| if is_char_valid(c) { c } else { '_' })
69+
.collect();
70+
}
71+
OpenidUsernameHandling::PruneEmailDomain => {
72+
if let Some(at_index) = result.find('@') {
73+
result.truncate(at_index);
74+
}
75+
result = result
76+
.chars()
77+
.map(|c| if is_char_valid(c) { c } else { '_' })
78+
.collect();
79+
}
80+
}
81+
82+
result.truncate(64);
83+
84+
result
85+
}
86+
4387
/// Create HTTP client and prevent following redirects
4488
async fn get_async_http_client() -> Result<reqwest::Client, WebError> {
4589
reqwest::Client::builder()
@@ -207,15 +251,16 @@ pub(crate) async fn user_from_claims(
207251
debug!("Username extracted from email ({email:?}): {username})");
208252
username
209253
};
210-
let username = prune_username(username);
254+
let settings = Settings::get_current_settings();
255+
256+
let username = prune_username(username, settings.openid_username_handling);
211257
// Check if the username is valid just in case, not everything can be handled by the pruning.
212258
check_username(&username)?;
213259

214260
// Get the *sub* claim from the token.
215261
let sub = token_claims.subject().to_string();
216262

217263
// Handle logging in or creating user.
218-
let settings = Settings::get_current_settings();
219264
let user = match User::find_by_sub(pool, &sub)
220265
.await
221266
.map_err(|err| WebError::Authorization(err.to_string()))?
@@ -557,3 +602,75 @@ pub(crate) async fn auth_callback(
557602
unimplemented!("Impossible to get here");
558603
}
559604
}
605+
606+
#[cfg(test)]
607+
mod test {
608+
use super::*;
609+
610+
#[test]
611+
fn test_prune_username() {
612+
// Test RemoveForbidden handling
613+
let handling_remove = OpenidUsernameHandling::RemoveForbidden;
614+
assert_eq!(prune_username("zenek", handling_remove), "zenek");
615+
assert_eq!(prune_username("zenek34", handling_remove), "zenek34");
616+
assert_eq!(prune_username("zenek@34", handling_remove), "zenek34");
617+
assert_eq!(prune_username("first.last", handling_remove), "first.last");
618+
assert_eq!(prune_username("__zenek__", handling_remove), "zenek__");
619+
assert_eq!(prune_username("zenek?", handling_remove), "zenek");
620+
assert_eq!(prune_username("zenek!", handling_remove), "zenek");
621+
assert_eq!(
622+
prune_username(
623+
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
624+
handling_remove
625+
),
626+
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
627+
);
628+
assert_eq!(prune_username("", handling_remove), "");
629+
assert_eq!(prune_username("!@#$%^&*()", handling_remove), "");
630+
assert_eq!(prune_username("!zenek", handling_remove), "zenek");
631+
assert_eq!(prune_username("...zenek", handling_remove), "zenek");
632+
633+
// Test ReplaceForbidden handling
634+
let handling_replace = OpenidUsernameHandling::ReplaceForbidden;
635+
assert_eq!(prune_username("zenek", handling_replace), "zenek");
636+
assert_eq!(prune_username("zenek34", handling_replace), "zenek34");
637+
assert_eq!(prune_username("zenek@34", handling_replace), "zenek_34");
638+
assert_eq!(prune_username("first.last", handling_replace), "first.last");
639+
assert_eq!(prune_username("__zenek__", handling_replace), "zenek__");
640+
assert_eq!(prune_username("zenek?", handling_replace), "zenek_");
641+
assert_eq!(prune_username("zenek!", handling_replace), "zenek_");
642+
assert_eq!(
643+
prune_username(
644+
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee",
645+
handling_replace
646+
),
647+
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
648+
);
649+
650+
// Test PruneEmailDomain handling
651+
let handling_prune_email = OpenidUsernameHandling::PruneEmailDomain;
652+
assert_eq!(
653+
prune_username("[email protected]", handling_prune_email),
654+
"zenek"
655+
);
656+
assert_eq!(
657+
prune_username("[email protected]", handling_prune_email),
658+
"user.name"
659+
);
660+
assert_eq!(
661+
prune_username("[email protected]", handling_prune_email),
662+
"invalid_chars_"
663+
);
664+
assert_eq!(
665+
prune_username("multiple@[email protected]", handling_prune_email),
666+
"multiple"
667+
);
668+
assert_eq!(
669+
prune_username(
670+
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee@domain.com",
671+
handling_prune_email
672+
),
673+
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
674+
);
675+
}
676+
}

src/enterprise/handlers/openid_providers.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@ use super::LicenseInfo;
1010
use crate::{
1111
appstate::AppState,
1212
auth::{AdminRole, SessionInfo},
13-
db::{models::settings::update_current_settings, Settings},
13+
db::{
14+
models::settings::{update_current_settings, OpenidUsernameHandling},
15+
Settings,
16+
},
1417
enterprise::{
1518
db::models::openid_provider::OpenIdProvider, directory_sync::test_directory_sync_connection,
1619
},
@@ -36,6 +39,7 @@ pub struct AddProviderData {
3639
pub okta_private_jwk: Option<String>,
3740
pub okta_dirsync_client_id: Option<String>,
3841
pub directory_sync_group_match: Option<String>,
42+
pub username_handling: OpenidUsernameHandling,
3943
}
4044

4145
#[derive(Debug, Deserialize, Serialize)]
@@ -107,6 +111,7 @@ pub async fn add_openid_provider(
107111

108112
let mut settings = Settings::get_current_settings();
109113
settings.openid_create_account = provider_data.create_account;
114+
settings.openid_username_handling = provider_data.username_handling;
110115
update_current_settings(&appstate.pool, settings).await?;
111116

112117
let group_match = if let Some(group_match) = provider_data.directory_sync_group_match {
@@ -173,7 +178,7 @@ pub async fn get_current_openid_provider(
173178
Ok(ApiResponse {
174179
json: json!({
175180
"provider": json!(provider),
176-
"settings": json!({ "create_account": create_account }),
181+
"settings": json!({ "create_account": create_account, "username_handling": settings.openid_username_handling}),
177182
}),
178183
status: StatusCode::OK,
179184
})

src/handlers/user.rs

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -74,32 +74,6 @@ pub fn check_username(username: &str) -> Result<(), WebError> {
7474
Ok(())
7575
}
7676

77-
/// Prune the given username from illegal characters in accordance with the following rules:
78-
///
79-
/// To enable LDAP sync usernames need to avoid reserved characters.
80-
/// Username requirements:
81-
/// - 64 characters long
82-
/// - only lowercase or uppercase latin alphabet letters (A-Z, a-z) and digits (0-9)
83-
/// - starts with non-special character
84-
/// - only special characters allowed: . - _
85-
/// - no whitespaces
86-
pub fn prune_username(username: &str) -> String {
87-
let mut result = username.to_string();
88-
89-
if result.len() > 64 {
90-
result.truncate(64);
91-
}
92-
93-
// Go through the string and remove any non-alphanumeric characters at the beginning
94-
result = result
95-
.trim_start_matches(|c: char| !c.is_ascii_alphanumeric())
96-
.to_string();
97-
98-
result.retain(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_');
99-
100-
result
101-
}
102-
10377
pub(crate) fn check_password_strength(password: &str) -> Result<(), WebError> {
10478
if !(8..=128).contains(&password.len()) {
10579
return Err(WebError::Serialization("Incorrect password length".into()));
@@ -1248,23 +1222,6 @@ mod test {
12481222

12491223
use super::*;
12501224

1251-
#[test]
1252-
fn test_username_prune() {
1253-
assert_eq!(prune_username("zenek"), "zenek");
1254-
assert_eq!(prune_username("zenek34"), "zenek34");
1255-
assert_eq!(prune_username("zenek@34"), "zenek34");
1256-
assert_eq!(prune_username("first.last"), "first.last");
1257-
assert_eq!(prune_username("__zenek__"), "zenek__");
1258-
assert_eq!(prune_username("zenek?"), "zenek");
1259-
assert_eq!(prune_username("zenek!"), "zenek");
1260-
assert_eq!(
1261-
prune_username(
1262-
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
1263-
),
1264-
"averylongnameeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
1265-
);
1266-
}
1267-
12681225
#[test]
12691226
fn test_username_validation() {
12701227
// valid usernames

0 commit comments

Comments
 (0)