Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 91 additions & 26 deletions src/infrastructure/services/login_lockout_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,28 @@ impl LoginLockoutService {
}
}

/// Check whether the account is currently locked.
/// Build the cache key from the (lowercased) username and the client IP.
///
/// The IP is part of the key so that an attacker flooding bad passwords
/// from one address cannot lock a legitimate user out of the same account
/// from a different address (issue #323). When the caller cannot resolve
/// a real IP — e.g. `OXICLOUD_TRUST_PROXY_HEADERS=false` and the peer
/// address isn't available — `client_ip` should be a non-empty constant
/// like `"unknown"`; in that pathological case we fall back to
/// account-scoped lockout, which is no worse than the previous
/// behaviour.
fn key(username: &str, client_ip: &str) -> String {
// `|` is not valid in either a username or an IP literal so it makes
// the username/ip boundary unambiguous.
format!("{}|{}", username.to_lowercase(), client_ip)
}

/// Check whether the (account, IP) pair is currently locked.
///
/// Returns `Ok(())` if the user may attempt login, or
/// `Err(remaining_secs)` with the *approximate* remaining lockout time.
pub fn check(&self, username: &str) -> Result<(), u64> {
if let Some(rec) = self.cache.get(&username.to_lowercase())
pub fn check(&self, username: &str, client_ip: &str) -> Result<(), u64> {
if let Some(rec) = self.cache.get(&Self::key(username, client_ip))
&& rec.count >= self.max_failures
{
// The entry exists and is over the threshold. Because moka
Expand All @@ -71,27 +87,30 @@ impl LoginLockoutService {
}

/// Record a failed login attempt. Returns the new failure count.
pub fn record_failure(&self, username: &str) -> u32 {
let key = username.to_lowercase();
pub fn record_failure(&self, username: &str, client_ip: &str) -> u32 {
let key = Self::key(username, client_ip);
let new_count = self.cache.get(&key).map(|r| r.count + 1).unwrap_or(1);
self.cache
.insert(key.clone(), FailureRecord { count: new_count });

if new_count >= self.max_failures {
tracing::warn!(
username = %username,
client_ip = %client_ip,
attempts = new_count,
lockout_secs = self.lockout_secs,
"Account temporarily locked after {} consecutive failed login attempts",
"Account temporarily locked after {} consecutive failed login attempts from this IP",
new_count,
);
}
new_count
}

/// Record a successful login — resets the failure counter.
pub fn record_success(&self, username: &str) {
self.cache.invalidate(&username.to_lowercase());
/// Record a successful login — resets the failure counter for this
/// (account, IP) pair so the user isn't penalised for stray earlier
/// failures from the same address.
pub fn record_success(&self, username: &str, client_ip: &str) {
self.cache.invalidate(&Self::key(username, client_ip));
}

/// Maximum failures before lockout (used to inform callers / error messages).
Expand All @@ -109,42 +128,88 @@ impl LoginLockoutService {
mod tests {
use super::*;

const IP1: &str = "1.1.1.1";
const IP2: &str = "2.2.2.2";

#[test]
fn allows_login_under_threshold() {
let svc = LoginLockoutService::new(3, 60, 100);
assert!(svc.check("alice").is_ok());
svc.record_failure("alice");
svc.record_failure("alice");
assert!(svc.check("alice", IP1).is_ok());
svc.record_failure("alice", IP1);
svc.record_failure("alice", IP1);
// 2 failures — still under threshold
assert!(svc.check("alice").is_ok());
assert!(svc.check("alice", IP1).is_ok());
}

#[test]
fn locks_after_threshold() {
let svc = LoginLockoutService::new(3, 60, 100);
svc.record_failure("bob");
svc.record_failure("bob");
svc.record_failure("bob");
assert!(svc.check("bob").is_err());
svc.record_failure("bob", IP1);
svc.record_failure("bob", IP1);
svc.record_failure("bob", IP1);
assert!(svc.check("bob", IP1).is_err());
}

#[test]
fn resets_on_success() {
let svc = LoginLockoutService::new(3, 60, 100);
svc.record_failure("carol");
svc.record_failure("carol");
svc.record_success("carol");
svc.record_failure("carol", IP1);
svc.record_failure("carol", IP1);
svc.record_success("carol", IP1);
// Counter reset — should be allowed again
assert!(svc.check("carol").is_ok());
svc.record_failure("carol"); // starts over at 1
assert!(svc.check("carol").is_ok());
assert!(svc.check("carol", IP1).is_ok());
svc.record_failure("carol", IP1); // starts over at 1
assert!(svc.check("carol", IP1).is_ok());
}

#[test]
fn case_insensitive() {
let svc = LoginLockoutService::new(2, 60, 100);
svc.record_failure("Dave");
svc.record_failure("dave");
assert!(svc.check("DAVE").is_err());
svc.record_failure("Dave", IP1);
svc.record_failure("dave", IP1);
assert!(svc.check("DAVE", IP1).is_err());
}

/// Regression test for #323: flooding bad passwords from one IP must
/// NOT lock the account out for legitimate users coming from a
/// different IP.
#[test]
fn does_not_lock_out_other_ips_for_same_account() {
let svc = LoginLockoutService::new(3, 60, 100);

// Attacker hammers the account from IP1 until it locks for that IP.
for _ in 0..3 {
svc.record_failure("admin", IP1);
}
assert!(
svc.check("admin", IP1).is_err(),
"attacker IP must be locked"
);

// A legitimate user coming from IP2 must still be allowed to try.
assert!(
svc.check("admin", IP2).is_ok(),
"second IP must not inherit the lockout — that's the #323 DOS"
);
}

/// A successful login on one IP must clear *that* IP's counter only —
/// it should NOT silently absolve a separate, ongoing brute-force from
/// a different IP against the same account.
#[test]
fn success_resets_only_the_acting_ip() {
let svc = LoginLockoutService::new(3, 60, 100);

for _ in 0..3 {
svc.record_failure("admin", IP1);
}
// Genuine login from IP2 succeeds; should reset IP2 counter (which
// is already 0 here) but leave IP1's lockout intact.
svc.record_success("admin", IP2);

assert!(
svc.check("admin", IP1).is_err(),
"IP1 must remain locked after IP2's success"
);
}
}
25 changes: 19 additions & 6 deletions src/interfaces/api/handlers/auth_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::common::di::AppState;
use crate::interfaces::api::cookie_auth;
use crate::interfaces::errors::AppError;
use crate::interfaces::middleware::auth::CurrentUserId;
use crate::interfaces::middleware::rate_limit::extract_client_ip_from_parts;

/// Public auth routes — no authentication required.
pub fn auth_public_routes() -> Router<Arc<AppState>> {
Expand Down Expand Up @@ -142,13 +143,21 @@ async fn login(
};

// ── Account lockout check ──────────────────────────────────────────
// Reject immediately if the account has too many consecutive failures.
// This runs BEFORE Argon2 to save CPU under brute-force attacks.
if let Err(lockout_secs) = auth_service.login_lockout.check(&dto.username) {
// Reject immediately if (this account, this IP) has too many consecutive
// failures. The IP is part of the key so an attacker flooding bad
// passwords from one address cannot lock a legitimate user out of the
// same account from a different address (issue #323). The check runs
// BEFORE Argon2 to save CPU under brute-force attacks.
let client_ip = extract_client_ip_from_parts(&headers, None);
if let Err(lockout_secs) = auth_service
.login_lockout
.check(&dto.username, &client_ip)
{
tracing::warn!(
username = %dto.username,
client_ip = %client_ip,
lockout_secs = lockout_secs,
"Login rejected — account temporarily locked"
"Login rejected — account temporarily locked for this IP"
);
return Err(AppError::new(
StatusCode::TOO_MANY_REQUESTS,
Expand Down Expand Up @@ -178,7 +187,9 @@ async fn login(
{
Ok(auth_response) => {
// ── Successful login — reset lockout counter ──
auth_service.login_lockout.record_success(&dto.username);
auth_service
.login_lockout
.record_success(&dto.username, &client_ip);

tracing::info!("Login successful for user: {}", dto.username);
// Log the response structure for debugging
Expand Down Expand Up @@ -228,7 +239,9 @@ async fn login(
}
Err(err) => {
// ── Record failed attempt for lockout tracking ──
auth_service.login_lockout.record_failure(&dto.username);
auth_service
.login_lockout
.record_failure(&dto.username, &client_ip);
tracing::error!("Login failed for user {}: {}", dto.username, err);
Err(err.into())
}
Expand Down
22 changes: 18 additions & 4 deletions src/interfaces/middleware/rate_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,28 @@ impl RateLimiter {
/// proxy in front of the app, an attacker can spoof these headers to bypass
/// rate limiting.
pub fn extract_client_ip<B>(req: &Request<B>) -> String {
extract_client_ip_from_parts(
req.headers(),
req.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|c| &c.0),
)
}

/// Same as [`extract_client_ip`], but operates on already-extracted axum parts
/// (a `&HeaderMap` plus an optional TCP peer). Handlers that don't take a full
/// `Request<B>` (e.g. those that consume the body via `Json<…>`) can still
/// derive a stable client identifier with this entry point.
pub fn extract_client_ip_from_parts(
headers: &axum::http::HeaderMap,
peer: Option<&SocketAddr>,
) -> String {
let trust_proxy = *TRUST_PROXY.get_or_init(|| {
std::env::var("OXICLOUD_TRUST_PROXY_HEADERS")
.map(|v| v == "true" || v == "1")
.unwrap_or(false)
});

let headers = req.headers();

if trust_proxy {
// 1. X-Forwarded-For (first entry — closest to the client)
if let Some(xff) = headers.get("x-forwarded-for").and_then(|v| v.to_str().ok())
Expand All @@ -128,8 +142,8 @@ pub fn extract_client_ip<B>(req: &Request<B>) -> String {
}

// 3. TCP peer (ConnectInfo extension set by axum::serve)
if let Some(addr) = req.extensions().get::<ConnectInfo<SocketAddr>>() {
return addr.0.ip().to_string();
if let Some(addr) = peer {
return addr.ip().to_string();
}

// Fallback — should never happen behind axum::serve
Expand Down
14 changes: 9 additions & 5 deletions src/interfaces/nextcloud/basic_auth_middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,18 @@ pub async fn basic_auth_middleware(
let (username, password) =
parse_basic_auth(auth_header).ok_or(NextcloudAuthError::Unauthorized)?;

// Check account lockout before attempting password verification (saves CPU)
// Check account lockout before attempting password verification (saves CPU).
// The lockout is per (account, IP) — see #323 for rationale.
let client_ip =
crate::interfaces::middleware::rate_limit::extract_client_ip(&request);
if let Some(auth_svc) = state.auth_service.as_ref()
&& let Err(secs) = auth_svc.login_lockout.check(&username)
&& let Err(secs) = auth_svc.login_lockout.check(&username, &client_ip)
{
tracing::warn!(
username = %username,
client_ip = %client_ip,
lockout_remaining_secs = secs,
"[NC] Account locked — too many failed attempts"
"[NC] Account locked — too many failed attempts from this IP"
);
return Err(NextcloudAuthError::Unauthorized);
}
Expand All @@ -87,7 +91,7 @@ pub async fn basic_auth_middleware(
Ok((user_id, uname, email, role)) => {
// Reset lockout counter on success
if let Some(auth_svc) = state.auth_service.as_ref() {
auth_svc.login_lockout.record_success(&username);
auth_svc.login_lockout.record_success(&username, &client_ip);
}
request.extensions_mut().insert(Arc::new(CurrentUser {
id: user_id,
Expand All @@ -100,7 +104,7 @@ pub async fn basic_auth_middleware(
Err(_) => {
// Record failed attempt for lockout tracking
if let Some(auth_svc) = state.auth_service.as_ref() {
auth_svc.login_lockout.record_failure(&username);
auth_svc.login_lockout.record_failure(&username, &client_ip);
}
Err(NextcloudAuthError::Unauthorized)
}
Expand Down