diff --git a/app/core/config.py b/app/core/config.py index f4bede6..f0e49fd 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -42,9 +42,10 @@ class Settings(BaseSettings): @computed_field # type: ignore[prop-decorator] @property def all_cors_origins(self) -> list[str]: - return [str(origin).rstrip("/") for origin in self.BACKEND_CORS_ORIGINS] + [ - self.FRONTEND_HOST - ] + origins = [str(origin).rstrip("/") for origin in self.BACKEND_CORS_ORIGINS] + if self.FRONTEND_HOST: + origins.append(str(self.FRONTEND_HOST).rstrip("/")) + return origins PROJECT_NAME: str SENTRY_DSN: str | None = None diff --git a/app/main.py b/app/main.py index 2b78964..7f943cf 100644 --- a/app/main.py +++ b/app/main.py @@ -1,6 +1,7 @@ import sentry_sdk -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse from fastapi.routing import APIRoute from slowapi.errors import RateLimitExceeded from starlette.exceptions import HTTPException @@ -14,6 +15,7 @@ ) from app.api.main import api_router from app.core.config import settings +from app.core.messages.error_message import ErrorMessages from app.core.rate_limit import limiter @@ -38,6 +40,40 @@ def custom_generate_unique_id(route: APIRoute) -> str: app.add_exception_handler(RateLimitExceeded, rate_limit_exception_handler) app.state.limiter = limiter + +@app.middleware("http") +async def origin_check_middleware(request: Request, call_next): + """ + Strict origin check. Returns 404 for unauthorized origins. + Allows same-origin requests (from the API's own origin). + Hides error details and ensures always 404 for these cases. + """ + origin = request.headers.get("origin") + if origin: + origin = origin.rstrip("/") + + # Allow same-origin requests + if origin: + # Reconstruct the request's own origin from scheme and host header + scheme = request.url.scheme + host = request.headers.get("host", "").rstrip("/") + request_origin = f"{scheme}://{host}".rstrip("/") + + # Allow if origins match (same-origin request) + if origin == request_origin: + return await call_next(request) + + allowed_origins = settings.all_cors_origins + + if origin and allowed_origins and "*" not in allowed_origins: + if origin not in allowed_origins: + return JSONResponse( + status_code=404, + content={"success": False, "error": ErrorMessages.RESOURCE_NOT_FOUND}, + ) + return await call_next(request) + + # Set all CORS enabled origins if settings.all_cors_origins: app.add_middleware( diff --git a/app/tests/test_origin_restriction.py b/app/tests/test_origin_restriction.py new file mode 100644 index 0000000..45a1c1f --- /dev/null +++ b/app/tests/test_origin_restriction.py @@ -0,0 +1,81 @@ +import pytest +from httpx import AsyncClient + +from app.core.config import settings + + +@pytest.mark.asyncio +async def test_allowed_origin(client: AsyncClient): + # Use an origin from settings.all_cors_origins + # settings.all_cors_origins includes FRONTEND_HOST + allowed_origin = settings.FRONTEND_HOST + response = await client.get( + "/users/me", # Correct path + headers={"Origin": allowed_origin}, + ) + # It should not be 404 due to origin (might be 401 if not logged in, but not 404) + assert response.status_code != 404 + + +@pytest.mark.asyncio +async def test_unauthorized_origin(client: AsyncClient): + response = await client.get( + "/auth/login", # Use an endpoint that exists + headers={"Origin": "http://malicious.com"}, + ) + assert response.status_code == 404 + assert response.json() == {"success": False, "error": "RESOURCE_NOT_FOUND"} + + +@pytest.mark.asyncio +async def test_no_origin(client: AsyncClient): + response = await client.get("/auth/login") + # Should not be 404 due to missing origin + assert response.status_code != 404 + # If /health doesn't exist, this might fail, but let's check a known endpoint + pass + + +@pytest.mark.asyncio +async def test_same_origin_post(client: AsyncClient): + """Test that same-origin POST requests are allowed (not blocked with 404).""" + # The client fixture uses http://test as the base origin (from conftest base_url) + same_origin = "http://test" + + response = await client.post( + "/auth/logout", # Correct path (base_url already includes /api/v1) + headers={"Origin": same_origin}, + ) + + # Should NOT be 404 due to origin check (may fail with 401 if not authenticated, but not 404) + assert response.status_code != 404, ( + "Same-origin POST request should not be blocked with 404" + ) + + +@pytest.mark.asyncio +async def test_same_origin_put(client: AsyncClient): + """Test that same-origin PUT requests are allowed (not blocked with 404).""" + same_origin = "http://test" + + response = await client.put( + "/users/me", + headers={"Origin": same_origin}, + ) + + # Should NOT be 404 due to origin check (may fail with 401/422, but not 404) + assert response.status_code != 404, ( + "Same-origin PUT request should not be blocked with 404" + ) + + +@pytest.mark.asyncio +async def test_cross_origin_post_blocked(client: AsyncClient): + """Test that cross-origin POST requests are still blocked (404).""" + response = await client.post( + "/auth/logout", + headers={"Origin": "http://malicious.com"}, + ) + + assert response.status_code == 404 + assert response.json() == {"success": False, "error": "RESOURCE_NOT_FOUND"}