Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ serde_json = "1.0.149"
futures = "0.3.32"
async-stream = "0.3.6"
mime_guess = "2.0.5"
uuid = { version = "1.23.0", features = ["v4", "serde"] }
uuid = { version = "1.23.0", features = ["v4", "v7", "serde"] }
thiserror = "2.0.18"

mockall = { version = "0.14.0", optional = true }
Expand Down
13 changes: 13 additions & 0 deletions example.env
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,16 @@ MIMALLOC_PURGE_DELAY=0
# When enabled with Linux Transparent Huge Pages (THP), partially-used 2 MiB
# pages inflate the reported RSS by up to 20-30 MiB.
MIMALLOC_ALLOW_LARGE_OS_PAGES=0

# -----------------------------------------------------------------------------
# PROXY
# -----------------------------------------------------------------------------

# Use this section if you are running OxiCloud behind a proxy

# Trusted Proxy IPs. Format: coma separated list of CIDR
# (default not defined = server without proxy)
# if defined and proxy's IPs match, client_ip will be defined from
# `X-Forwarded-For` / `X-Real-Ip`
#OXICLOUD_TRUST_PROXY_CIDR=192.168.0.1/32,10.1.2.0/24

6 changes: 5 additions & 1 deletion src/interfaces/api/handlers/favorites_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ pub async fn get_favorites(

match favorites_service.get_favorites(user_id).await {
Ok(favorites) => {
info!("Retrieved {} favorites for user", favorites.len());
info!(
"Retrieved {} favorites for user {}",
favorites.len(),
auth_user.id
);
(StatusCode::OK, Json(serde_json::json!(favorites))).into_response()
}
Err(err) => {
Expand Down
3 changes: 3 additions & 0 deletions src/interfaces/middleware/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ pub async fn auth_middleware(
role: claims.role,
});
request.extensions_mut().insert(current_user);
tracing::Span::current().record("user_id", user_id.to_string());
return Ok(next.run(request).await);
}
Err(e) => {
Expand Down Expand Up @@ -234,6 +235,7 @@ pub async fn auth_middleware(
role,
});
request.extensions_mut().insert(current_user);
tracing::Span::current().record("user_id", user_id.to_string());
return Ok(next.run(request).await);
}
Err(e) => {
Expand Down Expand Up @@ -291,6 +293,7 @@ pub async fn auth_middleware(
});
request.extensions_mut().insert(current_user);
request.extensions_mut().insert(CookieAuthenticated);
tracing::Span::current().record("user_id", user_id.to_string());
return Ok(next.run(request).await);
}
Err(e) => {
Expand Down
2 changes: 2 additions & 0 deletions src/interfaces/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub mod auth;
pub mod csrf;
pub mod rate_limit;
pub mod trace_span;
pub mod trusted_proxy;
57 changes: 7 additions & 50 deletions src/interfaces/middleware/rate_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,21 @@
//! counts per client IP. Each protected endpoint group gets its own
//! [`RateLimiter`] instance with independently tuneable limits.
//!
//! The middleware extracts the client IP from (in order):
//! 1. `X-Forwarded-For` header (first entry — set by reverse proxies)
//! 2. `X-Real-Ip` header
//! 3. The TCP peer address from the connection info
//! Client IP resolution is delegated to [`super::trusted_proxy::client_ip`],
//! which honours `OXICLOUD_TRUST_PROXY_CIDR` for proxy-header forwarding.
//!
//! When the limit is exceeded a `429 Too Many Requests` response is returned
//! with a `Retry-After` header indicating how many seconds to wait.

use axum::{
extract::ConnectInfo,
http::{HeaderValue, Request, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use moka::sync::Cache;
use std::net::SocketAddr;
use std::sync::{Arc, OnceLock};
use std::sync::Arc;
use std::time::Duration;

/// Cached value of `OXICLOUD_TRUST_PROXY_HEADERS` env var.
/// Read once on first access, never again — avoids a syscall per request.
static TRUST_PROXY: OnceLock<bool> = OnceLock::new();

/// A simple sliding-window counter keyed by IP address.
///
/// Each key lives for `window` seconds; every request increments the counter.
Expand Down Expand Up @@ -94,46 +86,11 @@ impl RateLimiter {

/// Extract the most-likely real client IP from headers / connection info.
///
/// Proxy headers (`X-Forwarded-For`, `X-Real-Ip`) are only trusted when
/// `OXICLOUD_TRUST_PROXY_HEADERS=true` is set. Without a trusted reverse
/// proxy in front of the app, an attacker can spoof these headers to bypass
/// rate limiting.
/// Proxy headers (`X-Forwarded-For`, `X-Real-Ip`) are only trusted when the
/// TCP peer address falls within `OXICLOUD_TRUST_PROXY_CIDR`. Without a
/// configured CIDR list an attacker could spoof headers to bypass rate limiting.
pub fn extract_client_ip<B>(req: &Request<B>) -> String {
let trust_proxy = *TRUST_PROXY.get_or_init(|| {
std::env::var("OXICLOUD_TRUST_PROXY_HEADERS")
.map(|v| v == "true" || v == "1")
.unwrap_or(false)
});

let headers = req.headers();

if trust_proxy {
// 1. X-Forwarded-For (first entry — closest to the client)
if let Some(xff) = headers.get("x-forwarded-for").and_then(|v| v.to_str().ok())
&& let Some(first) = xff.split(',').next()
{
let ip = first.trim();
if !ip.is_empty() {
return ip.to_string();
}
}

// 2. X-Real-Ip
if let Some(xri) = headers.get("x-real-ip").and_then(|v| v.to_str().ok()) {
let ip = xri.trim();
if !ip.is_empty() {
return ip.to_string();
}
}
}

// 3. TCP peer (ConnectInfo extension set by axum::serve)
if let Some(addr) = req.extensions().get::<ConnectInfo<SocketAddr>>() {
return addr.0.ip().to_string();
}

// Fallback — should never happen behind axum::serve
"unknown".to_string()
super::trusted_proxy::client_ip(req, false)
}

/// Build a rate-limit response with the standard `Retry-After` header.
Expand Down
81 changes: 81 additions & 0 deletions src/interfaces/middleware/trace_span.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//! Custom [`MakeSpan`], [`OnResponse`], and [`MakeRequestId`] for request tracing.
//!
//! [`UuidRequestId`] — generates a UUID v4 per request for `SetRequestIdLayer`.
//!
//! [`ClientIpMakeSpan`] — records `request_id`, `client_ip`, `method`, `uri`,
//! and a placeholder `user_id` (filled by auth middleware) on every request span.
//!
//! [`LogBadRequest`] — emits a WARN for every HTTP 400 response, inheriting
//! all span fields so the log line includes request ID, IP, user, method, URI.

use axum::http::{HeaderValue, Request, Response, StatusCode};
use std::time::Duration;
use tower_http::request_id::{MakeRequestId, RequestId};
use tower_http::trace::{MakeSpan, OnResponse};
use tracing::Span;
use uuid::Uuid;

// ─── Request ID generator ────────────────────────────────────────────────────

/// Generates a UUID v7 (fast, timed, sortable) for each request.
///
/// Used with [`tower_http::request_id::SetRequestIdLayer`]:
/// ```ignore
/// .layer(SetRequestIdLayer::x_request_id(UuidRequestId))
/// ```
#[derive(Clone, Debug, Default)]
pub struct UuidRequestId;

impl MakeRequestId for UuidRequestId {
fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
let id = Uuid::now_v7().to_string();
HeaderValue::from_str(&id).ok().map(RequestId::new)
}
}

// ─── Span factory ────────────────────────────────────────────────────────────

/// Implements [`MakeSpan`] so that every HTTP request span carries
/// `request_id`, `client_ip`, `method`, `uri`, and a deferred `user_id`.
///
/// `request_id` is read from the `x-request-id` header set by
/// [`tower_http::request_id::SetRequestIdLayer`] (which must wrap this layer).
#[derive(Clone, Debug, Default)]
pub struct ClientIpMakeSpan;

impl<B> MakeSpan<B> for ClientIpMakeSpan {
fn make_span(&mut self, request: &Request<B>) -> Span {
let ip = super::trusted_proxy::client_ip(request, true);
let request_id = request
.headers()
.get("x-request-id")
.and_then(|v| v.to_str().ok())
.unwrap_or("-");
tracing::info_span!(
"req",
request_id = request_id,
client_ip = %ip,
method = %request.method(),
uri = %request.uri().path(),
user_id = tracing::field::Empty,
)
}
}

// ─── Response observer ───────────────────────────────────────────────────────

/// Implements [`OnResponse`]: emits a WARN log for every HTTP 400 response.
#[derive(Clone, Debug, Default)]
pub struct LogBadRequest;

impl<B> OnResponse<B> for LogBadRequest {
fn on_response(self, response: &Response<B>, latency: Duration, _span: &Span) {
if response.status() == StatusCode::BAD_REQUEST {
tracing::warn!(
status = 400,
latency_ms = latency.as_millis(),
"bad request",
);
}
}
}
Loading
Loading