mirror of https://github.com/encode/starlette.git
Respond to credentialed requests with specific origin (#105)
* Respond with specific origin instead of wildcard for credentialed requests * Add test case for credentialed standard request * Add tests for setting vary header
This commit is contained in:
parent
cc09042c1c
commit
7a0f89abb8
|
@ -32,6 +32,8 @@ class CORSMiddleware:
|
||||||
simple_headers = {}
|
simple_headers = {}
|
||||||
if "*" in allow_origins:
|
if "*" in allow_origins:
|
||||||
simple_headers["Access-Control-Allow-Origin"] = "*"
|
simple_headers["Access-Control-Allow-Origin"] = "*"
|
||||||
|
else:
|
||||||
|
simple_headers["Vary"] = "Origin"
|
||||||
if allow_credentials:
|
if allow_credentials:
|
||||||
simple_headers["Access-Control-Allow-Credentials"] = "true"
|
simple_headers["Access-Control-Allow-Credentials"] = "true"
|
||||||
if expose_headers:
|
if expose_headers:
|
||||||
|
@ -74,7 +76,7 @@ class CORSMiddleware:
|
||||||
return self.preflight_response(request_headers=headers)
|
return self.preflight_response(request_headers=headers)
|
||||||
else:
|
else:
|
||||||
return functools.partial(
|
return functools.partial(
|
||||||
self.simple_response, scope=scope, origin=origin
|
self.simple_response, scope=scope, request_headers=headers
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.app(scope)
|
return self.app(scope)
|
||||||
|
@ -130,22 +132,31 @@ class CORSMiddleware:
|
||||||
|
|
||||||
return PlainTextResponse("OK", status_code=200, headers=headers)
|
return PlainTextResponse("OK", status_code=200, headers=headers)
|
||||||
|
|
||||||
async def simple_response(self, receive, send, scope=None, origin=None):
|
async def simple_response(self, receive, send, scope=None, request_headers=None):
|
||||||
inner = self.app(scope)
|
inner = self.app(scope)
|
||||||
send = functools.partial(self.send, send=send, origin=origin)
|
send = functools.partial(self.send, send=send, request_headers=request_headers)
|
||||||
await inner(receive, send)
|
await inner(receive, send)
|
||||||
|
|
||||||
async def send(self, message, send=None, origin=None):
|
async def send(self, message, send=None, request_headers=None):
|
||||||
if message["type"] != "http.response.start":
|
if message["type"] != "http.response.start":
|
||||||
await send(message)
|
await send(message)
|
||||||
return
|
return
|
||||||
|
|
||||||
message.setdefault("headers", [])
|
message.setdefault("headers", [])
|
||||||
headers = MutableHeaders(message["headers"])
|
headers = MutableHeaders(message["headers"])
|
||||||
|
origin = request_headers["Origin"]
|
||||||
|
has_cookie = "cookie" in request_headers
|
||||||
|
|
||||||
|
# If request includes any cookie headers, then we must respond
|
||||||
|
# with the specific origin instead of '*'.
|
||||||
|
if self.allow_all_origins and has_cookie:
|
||||||
|
self.simple_headers["Access-Control-Allow-Origin"] = origin
|
||||||
|
|
||||||
# If we only allow specific origins, then we have to mirror back
|
# If we only allow specific origins, then we have to mirror back
|
||||||
# the Origin header in the response.
|
# the Origin header in the response.
|
||||||
if not self.allow_all_origins and self.is_allowed_origin(origin=origin):
|
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
|
||||||
headers["Access-Control-Allow-Origin"] = origin
|
headers["Access-Control-Allow-Origin"] = origin
|
||||||
|
if "vary" in headers:
|
||||||
|
self.simple_headers["Vary"] = f"{headers.get('vary')}, Origin"
|
||||||
headers.update(self.simple_headers)
|
headers.update(self.simple_headers)
|
||||||
await send(message)
|
await send(message)
|
||||||
|
|
|
@ -206,3 +206,60 @@ def test_cors_allow_origin_regex():
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert response.text == "Disallowed CORS origin"
|
assert response.text == "Disallowed CORS origin"
|
||||||
assert "access-control-allow-origin" not in response.headers
|
assert "access-control-allow-origin" not in response.headers
|
||||||
|
|
||||||
|
|
||||||
|
def test_cors_credentialed_requests_return_specific_origin():
|
||||||
|
app = Starlette()
|
||||||
|
|
||||||
|
app.add_middleware(CORSMiddleware, allow_origins=["*"])
|
||||||
|
|
||||||
|
@app.route("/")
|
||||||
|
def homepage(request):
|
||||||
|
return PlainTextResponse("Homepage", status_code=200)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Test credentialed request
|
||||||
|
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
|
||||||
|
response = client.get("/", headers=headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.text == "Homepage"
|
||||||
|
assert response.headers["access-control-allow-origin"] == "https://example.org"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cors_vary_header_defaults_to_origin():
|
||||||
|
app = Starlette()
|
||||||
|
|
||||||
|
app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"])
|
||||||
|
|
||||||
|
headers = {"Origin": "https://example.org"}
|
||||||
|
|
||||||
|
@app.route("/")
|
||||||
|
def homepage(request):
|
||||||
|
return PlainTextResponse("Homepage", status_code=200)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get("/", headers=headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers["vary"] == "Origin"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cors_vary_header_is_properly_set():
|
||||||
|
app = Starlette()
|
||||||
|
|
||||||
|
app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"])
|
||||||
|
|
||||||
|
headers = {"Origin": "https://example.org"}
|
||||||
|
|
||||||
|
@app.route("/")
|
||||||
|
def homepage(request):
|
||||||
|
return PlainTextResponse(
|
||||||
|
"Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get("/", headers=headers)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers["vary"] == "Accept-Encoding, Origin"
|
||||||
|
|
Loading…
Reference in New Issue