Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ class Settings(BaseSettings):
@computed_field # type: ignore[prop-decorator]
@property
def all_cors_origins(self) -> list[str]:
return [str(origin).rstrip("/") for origin in self.BACKEND_CORS_ORIGINS] + [
self.FRONTEND_HOST
]
origins = [str(origin).rstrip("/") for origin in self.BACKEND_CORS_ORIGINS]
if self.FRONTEND_HOST:
origins.append(str(self.FRONTEND_HOST).rstrip("/"))
return origins

PROJECT_NAME: str
SENTRY_DSN: str | None = None
Expand Down
38 changes: 37 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sentry_sdk
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from slowapi.errors import RateLimitExceeded
from starlette.exceptions import HTTPException
Expand All @@ -14,6 +15,7 @@
)
from app.api.main import api_router
from app.core.config import settings
from app.core.messages.error_message import ErrorMessages
from app.core.rate_limit import limiter


Expand All @@ -38,6 +40,40 @@ def custom_generate_unique_id(route: APIRoute) -> str:
app.add_exception_handler(RateLimitExceeded, rate_limit_exception_handler)
app.state.limiter = limiter


@app.middleware("http")
async def origin_check_middleware(request: Request, call_next):
"""
Strict origin check. Returns 404 for unauthorized origins.
Allows same-origin requests (from the API's own origin).
Hides error details and ensures always 404 for these cases.
"""
origin = request.headers.get("origin")
if origin:
origin = origin.rstrip("/")

# Allow same-origin requests
if origin:
# Reconstruct the request's own origin from scheme and host header
scheme = request.url.scheme
host = request.headers.get("host", "").rstrip("/")
request_origin = f"{scheme}://{host}".rstrip("/")

# Allow if origins match (same-origin request)
if origin == request_origin:
return await call_next(request)

allowed_origins = settings.all_cors_origins

if origin and allowed_origins and "*" not in allowed_origins:
if origin not in allowed_origins:
return JSONResponse(
status_code=404,
content={"success": False, "error": ErrorMessages.RESOURCE_NOT_FOUND},
)
return await call_next(request)
Comment thread
cursor[bot] marked this conversation as resolved.


# Set all CORS enabled origins
if settings.all_cors_origins:
app.add_middleware(
Expand Down
81 changes: 81 additions & 0 deletions app/tests/test_origin_restriction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import pytest
from httpx import AsyncClient

from app.core.config import settings


@pytest.mark.asyncio
async def test_allowed_origin(client: AsyncClient):
# Use an origin from settings.all_cors_origins
# settings.all_cors_origins includes FRONTEND_HOST
allowed_origin = settings.FRONTEND_HOST
response = await client.get(
"/users/me", # Correct path
headers={"Origin": allowed_origin},
)
# It should not be 404 due to origin (might be 401 if not logged in, but not 404)
assert response.status_code != 404


@pytest.mark.asyncio
async def test_unauthorized_origin(client: AsyncClient):
response = await client.get(
"/auth/login", # Use an endpoint that exists
headers={"Origin": "http://malicious.com"},
)
assert response.status_code == 404
assert response.json() == {"success": False, "error": "RESOURCE_NOT_FOUND"}


@pytest.mark.asyncio
async def test_no_origin(client: AsyncClient):
response = await client.get("/auth/login")
# Should not be 404 due to missing origin
assert response.status_code != 404
# If /health doesn't exist, this might fail, but let's check a known endpoint
pass


@pytest.mark.asyncio
async def test_same_origin_post(client: AsyncClient):
"""Test that same-origin POST requests are allowed (not blocked with 404)."""
# The client fixture uses http://test as the base origin (from conftest base_url)
same_origin = "http://test"

response = await client.post(
"/auth/logout", # Correct path (base_url already includes /api/v1)
headers={"Origin": same_origin},
)

# Should NOT be 404 due to origin check (may fail with 401 if not authenticated, but not 404)
assert response.status_code != 404, (
"Same-origin POST request should not be blocked with 404"
)


@pytest.mark.asyncio
async def test_same_origin_put(client: AsyncClient):
"""Test that same-origin PUT requests are allowed (not blocked with 404)."""
same_origin = "http://test"

response = await client.put(
"/users/me",
headers={"Origin": same_origin},
)

# Should NOT be 404 due to origin check (may fail with 401/422, but not 404)
assert response.status_code != 404, (
"Same-origin PUT request should not be blocked with 404"
)


@pytest.mark.asyncio
async def test_cross_origin_post_blocked(client: AsyncClient):
"""Test that cross-origin POST requests are still blocked (404)."""
response = await client.post(
"/auth/logout",
headers={"Origin": "http://malicious.com"},
)

assert response.status_code == 404
assert response.json() == {"success": False, "error": "RESOURCE_NOT_FOUND"}
Loading