diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 73455310..8bc33804 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -32,6 +32,8 @@ class CORSMiddleware: simple_headers = {} if "*" in allow_origins: simple_headers["Access-Control-Allow-Origin"] = "*" + else: + simple_headers["Vary"] = "Origin" if allow_credentials: simple_headers["Access-Control-Allow-Credentials"] = "true" if expose_headers: @@ -74,7 +76,7 @@ class CORSMiddleware: return self.preflight_response(request_headers=headers) else: return functools.partial( - self.simple_response, scope=scope, origin=origin + self.simple_response, scope=scope, request_headers=headers ) return self.app(scope) @@ -130,22 +132,31 @@ class CORSMiddleware: return PlainTextResponse("OK", status_code=200, headers=headers) - async def simple_response(self, receive, send, scope=None, origin=None): + async def simple_response(self, receive, send, scope=None, request_headers=None): inner = self.app(scope) - send = functools.partial(self.send, send=send, origin=origin) + send = functools.partial(self.send, send=send, request_headers=request_headers) await inner(receive, send) - async def send(self, message, send=None, origin=None): + async def send(self, message, send=None, request_headers=None): if message["type"] != "http.response.start": await send(message) return message.setdefault("headers", []) headers = MutableHeaders(message["headers"]) + origin = request_headers["Origin"] + has_cookie = "cookie" in request_headers + + # If request includes any cookie headers, then we must respond + # with the specific origin instead of '*'. + if self.allow_all_origins and has_cookie: + self.simple_headers["Access-Control-Allow-Origin"] = origin # If we only allow specific origins, then we have to mirror back # the Origin header in the response. - if not self.allow_all_origins and self.is_allowed_origin(origin=origin): + elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): headers["Access-Control-Allow-Origin"] = origin + if "vary" in headers: + self.simple_headers["Vary"] = f"{headers.get('vary')}, Origin" headers.update(self.simple_headers) await send(message) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 1b42a2ff..ab3a0772 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -206,3 +206,60 @@ def test_cors_allow_origin_regex(): assert response.status_code == 400 assert response.text == "Disallowed CORS origin" assert "access-control-allow-origin" not in response.headers + + +def test_cors_credentialed_requests_return_specific_origin(): + app = Starlette() + + app.add_middleware(CORSMiddleware, allow_origins=["*"]) + + @app.route("/") + def homepage(request): + return PlainTextResponse("Homepage", status_code=200) + + client = TestClient(app) + + # Test credentialed request + headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "https://example.org" + + +def test_cors_vary_header_defaults_to_origin(): + app = Starlette() + + app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"]) + + headers = {"Origin": "https://example.org"} + + @app.route("/") + def homepage(request): + return PlainTextResponse("Homepage", status_code=200) + + client = TestClient(app) + + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.headers["vary"] == "Origin" + + +def test_cors_vary_header_is_properly_set(): + app = Starlette() + + app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"]) + + headers = {"Origin": "https://example.org"} + + @app.route("/") + def homepage(request): + return PlainTextResponse( + "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} + ) + + client = TestClient(app) + + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.headers["vary"] == "Accept-Encoding, Origin"