diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 61502691a..f0ea120f5 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -152,9 +152,9 @@ async def send(self, message: Message, send: Send, request_headers: Headers) -> headers = MutableHeaders(scope=message) headers.update(self.simple_headers) origin = request_headers["Origin"] - has_cookie = "cookie" in request_headers + has_cookie = "cookie" in request_headers or "set-cookie" in headers - # If request includes any cookie headers, then we must respond + # If request or response includes any cookie headers, then we must respond # with the specific origin instead of '*'. if self.allow_all_origins and has_cookie: self.allow_explicit_origin(headers, origin) diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 0d987263e..d3a2bdee6 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -425,6 +425,29 @@ def homepage(request: Request) -> PlainTextResponse: assert "access-control-allow-credentials" not in response.headers +def test_cors_credentialed_requests_return_specific_origin_without_initial_cookie( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: + response = PlainTextResponse("Homepage", status_code=200) + response.set_cookie("mycookie", "myvalue", path=None) + return response + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=["*"])], + ) + client = test_client_factory(app) + + # Test credentialed request + headers = {"Origin": "https://example.org"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert "access-control-allow-credentials" in response.headers + + def test_cors_vary_header_defaults_to_origin( test_client_factory: TestClientFactory, ) -> None: