mirror of https://github.com/encode/starlette.git
Respond to credentialed requests with specific origin (#105)
* Respond with specific origin instead of wildcard for credentialed requests * Add test case for credentialed standard request * Add tests for setting vary header
This commit is contained in:
parent
cc09042c1c
commit
7a0f89abb8
|
@ -32,6 +32,8 @@ class CORSMiddleware:
|
|||
simple_headers = {}
|
||||
if "*" in allow_origins:
|
||||
simple_headers["Access-Control-Allow-Origin"] = "*"
|
||||
else:
|
||||
simple_headers["Vary"] = "Origin"
|
||||
if allow_credentials:
|
||||
simple_headers["Access-Control-Allow-Credentials"] = "true"
|
||||
if expose_headers:
|
||||
|
@ -74,7 +76,7 @@ class CORSMiddleware:
|
|||
return self.preflight_response(request_headers=headers)
|
||||
else:
|
||||
return functools.partial(
|
||||
self.simple_response, scope=scope, origin=origin
|
||||
self.simple_response, scope=scope, request_headers=headers
|
||||
)
|
||||
|
||||
return self.app(scope)
|
||||
|
@ -130,22 +132,31 @@ class CORSMiddleware:
|
|||
|
||||
return PlainTextResponse("OK", status_code=200, headers=headers)
|
||||
|
||||
async def simple_response(self, receive, send, scope=None, origin=None):
|
||||
async def simple_response(self, receive, send, scope=None, request_headers=None):
|
||||
inner = self.app(scope)
|
||||
send = functools.partial(self.send, send=send, origin=origin)
|
||||
send = functools.partial(self.send, send=send, request_headers=request_headers)
|
||||
await inner(receive, send)
|
||||
|
||||
async def send(self, message, send=None, origin=None):
|
||||
async def send(self, message, send=None, request_headers=None):
|
||||
if message["type"] != "http.response.start":
|
||||
await send(message)
|
||||
return
|
||||
|
||||
message.setdefault("headers", [])
|
||||
headers = MutableHeaders(message["headers"])
|
||||
origin = request_headers["Origin"]
|
||||
has_cookie = "cookie" in request_headers
|
||||
|
||||
# If request includes any cookie headers, then we must respond
|
||||
# with the specific origin instead of '*'.
|
||||
if self.allow_all_origins and has_cookie:
|
||||
self.simple_headers["Access-Control-Allow-Origin"] = origin
|
||||
|
||||
# If we only allow specific origins, then we have to mirror back
|
||||
# the Origin header in the response.
|
||||
if not self.allow_all_origins and self.is_allowed_origin(origin=origin):
|
||||
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
if "vary" in headers:
|
||||
self.simple_headers["Vary"] = f"{headers.get('vary')}, Origin"
|
||||
headers.update(self.simple_headers)
|
||||
await send(message)
|
||||
|
|
|
@ -206,3 +206,60 @@ def test_cors_allow_origin_regex():
|
|||
assert response.status_code == 400
|
||||
assert response.text == "Disallowed CORS origin"
|
||||
assert "access-control-allow-origin" not in response.headers
|
||||
|
||||
|
||||
def test_cors_credentialed_requests_return_specific_origin():
|
||||
app = Starlette()
|
||||
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"])
|
||||
|
||||
@app.route("/")
|
||||
def homepage(request):
|
||||
return PlainTextResponse("Homepage", status_code=200)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Test credentialed request
|
||||
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
|
||||
response = client.get("/", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Homepage"
|
||||
assert response.headers["access-control-allow-origin"] == "https://example.org"
|
||||
|
||||
|
||||
def test_cors_vary_header_defaults_to_origin():
|
||||
app = Starlette()
|
||||
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"])
|
||||
|
||||
headers = {"Origin": "https://example.org"}
|
||||
|
||||
@app.route("/")
|
||||
def homepage(request):
|
||||
return PlainTextResponse("Homepage", status_code=200)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["vary"] == "Origin"
|
||||
|
||||
|
||||
def test_cors_vary_header_is_properly_set():
|
||||
app = Starlette()
|
||||
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"])
|
||||
|
||||
headers = {"Origin": "https://example.org"}
|
||||
|
||||
@app.route("/")
|
||||
def homepage(request):
|
||||
return PlainTextResponse(
|
||||
"Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/", headers=headers)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["vary"] == "Accept-Encoding, Origin"
|
||||
|
|
Loading…
Reference in New Issue