Respond to credentialed requests with specific origin (#105)

* Respond with specific origin instead of wildcard for credentialed requests

* Add test case for credentialed standard request

* Add tests for setting vary header
This commit is contained in:
Alex Botello 2018-10-12 03:46:10 -05:00 committed by Tom Christie
parent cc09042c1c
commit 7a0f89abb8
2 changed files with 73 additions and 5 deletions

View File

@ -32,6 +32,8 @@ class CORSMiddleware:
simple_headers = {} simple_headers = {}
if "*" in allow_origins: if "*" in allow_origins:
simple_headers["Access-Control-Allow-Origin"] = "*" simple_headers["Access-Control-Allow-Origin"] = "*"
else:
simple_headers["Vary"] = "Origin"
if allow_credentials: if allow_credentials:
simple_headers["Access-Control-Allow-Credentials"] = "true" simple_headers["Access-Control-Allow-Credentials"] = "true"
if expose_headers: if expose_headers:
@ -74,7 +76,7 @@ class CORSMiddleware:
return self.preflight_response(request_headers=headers) return self.preflight_response(request_headers=headers)
else: else:
return functools.partial( return functools.partial(
self.simple_response, scope=scope, origin=origin self.simple_response, scope=scope, request_headers=headers
) )
return self.app(scope) return self.app(scope)
@ -130,22 +132,31 @@ class CORSMiddleware:
return PlainTextResponse("OK", status_code=200, headers=headers) return PlainTextResponse("OK", status_code=200, headers=headers)
async def simple_response(self, receive, send, scope=None, origin=None): async def simple_response(self, receive, send, scope=None, request_headers=None):
inner = self.app(scope) inner = self.app(scope)
send = functools.partial(self.send, send=send, origin=origin) send = functools.partial(self.send, send=send, request_headers=request_headers)
await inner(receive, send) await inner(receive, send)
async def send(self, message, send=None, origin=None): async def send(self, message, send=None, request_headers=None):
if message["type"] != "http.response.start": if message["type"] != "http.response.start":
await send(message) await send(message)
return return
message.setdefault("headers", []) message.setdefault("headers", [])
headers = MutableHeaders(message["headers"]) headers = MutableHeaders(message["headers"])
origin = request_headers["Origin"]
has_cookie = "cookie" in request_headers
# If request includes any cookie headers, then we must respond
# with the specific origin instead of '*'.
if self.allow_all_origins and has_cookie:
self.simple_headers["Access-Control-Allow-Origin"] = origin
# If we only allow specific origins, then we have to mirror back # If we only allow specific origins, then we have to mirror back
# the Origin header in the response. # the Origin header in the response.
if not self.allow_all_origins and self.is_allowed_origin(origin=origin): elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
headers["Access-Control-Allow-Origin"] = origin headers["Access-Control-Allow-Origin"] = origin
if "vary" in headers:
self.simple_headers["Vary"] = f"{headers.get('vary')}, Origin"
headers.update(self.simple_headers) headers.update(self.simple_headers)
await send(message) await send(message)

View File

@ -206,3 +206,60 @@ def test_cors_allow_origin_regex():
assert response.status_code == 400 assert response.status_code == 400
assert response.text == "Disallowed CORS origin" assert response.text == "Disallowed CORS origin"
assert "access-control-allow-origin" not in response.headers assert "access-control-allow-origin" not in response.headers
def test_cors_credentialed_requests_return_specific_origin():
app = Starlette()
app.add_middleware(CORSMiddleware, allow_origins=["*"])
@app.route("/")
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
client = TestClient(app)
# Test credentialed request
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
def test_cors_vary_header_defaults_to_origin():
app = Starlette()
app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"])
headers = {"Origin": "https://example.org"}
@app.route("/")
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
client = TestClient(app)
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.headers["vary"] == "Origin"
def test_cors_vary_header_is_properly_set():
app = Starlette()
app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"])
headers = {"Origin": "https://example.org"}
@app.route("/")
def homepage(request):
return PlainTextResponse(
"Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
)
client = TestClient(app)
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.headers["vary"] == "Accept-Encoding, Origin"