mirror of https://github.com/encode/starlette.git
227 lines
7.0 KiB
Python
227 lines
7.0 KiB
Python
from starlette.applications import Starlette
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
from starlette.responses import PlainTextResponse
|
|
from starlette.testclient import TestClient
|
|
|
|
|
|
def test_cors_allow_all():
    """Exercise a fully permissive (wildcard) CORS configuration.

    With ``allow_origins=["*"]`` the responses carry a literal ``*`` as the
    allowed origin, the pre-flight reports the requested header as allowed,
    and a request without an ``Origin`` header gets no CORS headers.
    """
    app = Starlette()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_headers=["*"],
        allow_methods=["*"],
        expose_headers=["X-Status"],
        allow_credentials=True,
    )

    @app.route("/")
    def homepage(request):
        return PlainTextResponse("Homepage", status_code=200)

    test_client = TestClient(app)
    origin = "https://example.org"

    # Pre-flight: an OPTIONS request carrying the Access-Control-Request-* headers.
    preflight = test_client.options(
        "/",
        headers={
            "Origin": origin,
            "Access-Control-Request-Method": "GET",
            "Access-Control-Request-Headers": "X-Example",
        },
    )
    assert (preflight.status_code, preflight.text) == (200, "OK")
    assert preflight.headers["access-control-allow-origin"] == "*"
    assert preflight.headers["access-control-allow-headers"] == "X-Example"

    # Simple cross-origin request: wildcard origin plus the exposed header list.
    simple = test_client.get("/", headers={"Origin": origin})
    assert (simple.status_code, simple.text) == (200, "Homepage")
    assert simple.headers["access-control-allow-origin"] == "*"
    assert simple.headers["access-control-expose-headers"] == "X-Status"

    # No Origin header at all: not a CORS request, so no allow-origin header.
    plain = test_client.get("/")
    assert (plain.status_code, plain.text) == (200, "Homepage")
    assert "access-control-allow-origin" not in plain.headers
|
|
|
|
|
|
def test_cors_allow_specific_origin():
    """An explicitly listed origin is reflected verbatim in CORS responses.

    Covers a pre-flight for the allowed origin/header pair, a simple
    cross-origin GET, and a GET without an ``Origin`` header, which must not
    receive an ``Access-Control-Allow-Origin`` header.
    """
    allowed_origin = "https://example.org"

    app = Starlette()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[allowed_origin],
        allow_headers=["X-Example"],
    )

    @app.route("/")
    def homepage(request):
        return PlainTextResponse("Homepage", status_code=200)

    test_client = TestClient(app)

    # Pre-flight for GET with the single whitelisted custom header.
    preflight_headers = {
        "Origin": allowed_origin,
        "Access-Control-Request-Method": "GET",
        "Access-Control-Request-Headers": "X-Example",
    }
    preflight = test_client.options("/", headers=preflight_headers)
    assert (preflight.status_code, preflight.text) == (200, "OK")
    assert preflight.headers["access-control-allow-origin"] == allowed_origin
    assert preflight.headers["access-control-allow-headers"] == "X-Example"

    # Simple cross-origin request coming from the allowed origin.
    simple = test_client.get("/", headers={"Origin": allowed_origin})
    assert (simple.status_code, simple.text) == (200, "Homepage")
    assert simple.headers["access-control-allow-origin"] == allowed_origin

    # Request without an Origin header: no allow-origin header is expected.
    plain = test_client.get("/")
    assert (plain.status_code, plain.text) == (200, "Homepage")
    assert "access-control-allow-origin" not in plain.headers
|
|
|
|
|
|
def test_cors_disallowed_preflight():
    """A pre-flight that fails every CORS check is rejected with HTTP 400.

    The requesting origin, method and header are all outside what the
    middleware permits here, so the plain-text body lists all three failures.
    """
    app = Starlette()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["https://example.org"],
        allow_headers=["X-Example"],
    )

    @app.route("/")
    def homepage(request):
        # Never expected to run: the middleware answers the pre-flight itself.
        pass  # pragma: no cover

    rejected = TestClient(app).options(
        "/",
        headers={
            "Origin": "https://another.org",
            "Access-Control-Request-Method": "POST",
            "Access-Control-Request-Headers": "X-Nope",
        },
    )

    assert (rejected.status_code, rejected.text) == (
        400,
        "Disallowed CORS origin, method, headers",
    )
|
|
|
|
|
|
def test_cors_allow_origin_regex():
    """Origins are admitted when they match ``allow_origin_regex``.

    ``allow_origin_regex`` is a regular expression, not a glob: ``https://.*``
    admits any ``https`` origin and rejects plain ``http`` ones, for both
    simple requests and pre-flights.
    """
    app = Starlette()

    # ``.*`` is deliberate. A glob-style "https://*" would be parsed as the
    # regex "https:" followed by zero or more "/", which only matches these
    # origins under prefix (re.match) semantics and fails under re.fullmatch.
    app.add_middleware(
        CORSMiddleware, allow_headers=["X-Example"], allow_origin_regex="https://.*"
    )

    @app.route("/")
    def homepage(request):
        return PlainTextResponse("Homepage", status_code=200)

    client = TestClient(app)

    # Test standard response
    headers = {"Origin": "https://example.org"}
    response = client.get("/", headers=headers)
    assert response.status_code == 200
    assert response.text == "Homepage"
    assert response.headers["access-control-allow-origin"] == "https://example.org"

    # Test disallowed standard response
    # Note that enforcement is a browser concern. The disallowed-ness is reflected
    # in the lack of an "access-control-allow-origin" header in the response.
    headers = {"Origin": "http://example.org"}
    response = client.get("/", headers=headers)
    assert response.status_code == 200
    assert response.text == "Homepage"
    assert "access-control-allow-origin" not in response.headers

    # Test pre-flight response
    headers = {
        "Origin": "https://another.com",
        "Access-Control-Request-Method": "GET",
        "Access-Control-Request-Headers": "X-Example",
    }
    response = client.options("/", headers=headers)
    assert response.status_code == 200
    assert response.text == "OK"
    assert response.headers["access-control-allow-origin"] == "https://another.com"
    assert response.headers["access-control-allow-headers"] == "X-Example"

    # Test disallowed pre-flight response
    headers = {
        "Origin": "http://another.com",
        "Access-Control-Request-Method": "GET",
        "Access-Control-Request-Headers": "X-Example",
    }
    response = client.options("/", headers=headers)
    assert response.status_code == 400
    assert response.text == "Disallowed CORS origin"
    assert "access-control-allow-origin" not in response.headers
|
|
|
|
|
|
def test_cors_credentialed_requests_return_specific_origin():
    """Under a wildcard origin, a credentialed request gets its own origin back.

    The request carries a ``Cookie`` header. The test expects the concrete
    requesting origin in ``Access-Control-Allow-Origin`` rather than ``*``,
    even though ``allow_origins`` is ``["*"]``.
    """
    app = Starlette()
    app.add_middleware(CORSMiddleware, allow_origins=["*"])

    @app.route("/")
    def homepage(request):
        return PlainTextResponse("Homepage", status_code=200)

    origin = "https://example.org"
    credentialed = TestClient(app).get(
        "/", headers={"Origin": origin, "Cookie": "star_cookie=sugar"}
    )

    assert (credentialed.status_code, credentialed.text) == (200, "Homepage")
    assert credentialed.headers["access-control-allow-origin"] == origin
|
|
|
|
|
|
def test_cors_vary_header_defaults_to_origin():
    """With explicitly listed origins, a plain response gets ``Vary: Origin``.

    The endpoint sets no ``Vary`` header of its own, so ``Origin`` is the
    whole value.
    """
    allowed_origin = "https://example.org"

    app = Starlette()
    app.add_middleware(CORSMiddleware, allow_origins=[allowed_origin])

    @app.route("/")
    def homepage(request):
        return PlainTextResponse("Homepage", status_code=200)

    result = TestClient(app).get("/", headers={"Origin": allowed_origin})

    assert result.status_code == 200
    assert result.headers["vary"] == "Origin"
|
|
|
|
|
|
def test_cors_vary_header_is_properly_set():
    """An endpoint's own ``Vary`` header is extended with ``Origin``.

    The existing ``Accept-Encoding`` value must survive, with ``Origin``
    appended after it.
    """
    allowed_origin = "https://example.org"

    app = Starlette()
    app.add_middleware(CORSMiddleware, allow_origins=[allowed_origin])

    @app.route("/")
    def homepage(request):
        # Pre-existing Vary value that the CORS middleware should append to.
        existing_vary = {"Vary": "Accept-Encoding"}
        return PlainTextResponse("Homepage", status_code=200, headers=existing_vary)

    result = TestClient(app).get("/", headers={"Origin": allowed_origin})

    assert result.status_code == 200
    assert result.headers["vary"] == "Accept-Encoding, Origin"
|