mirror of https://github.com/encode/starlette.git
209 lines
6.5 KiB
Python
209 lines
6.5 KiB
Python
from starlette.applications import Starlette
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
|
|
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
|
from starlette.responses import PlainTextResponse
|
|
from starlette.testclient import TestClient
|
|
|
|
|
|
def test_trusted_host_middleware():
    """Check that TrustedHostMiddleware accepts an allowed host and rejects others.

    The app only allows the host ``testserver``. A default ``TestClient`` is
    expected to get a 200, while a client whose base URL points at
    ``invalidhost`` is expected to get a 400.
    """
    app = Starlette()
    app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"])

    @app.route("/")
    def homepage(request):
        return PlainTextResponse("OK", status_code=200)

    # Each case is (extra TestClient keyword arguments, expected status code).
    # Clients are built lazily inside the loop so that construction and the
    # request happen in the same order as a straight-line test would do them.
    cases = (
        ({}, 200),
        ({"base_url": "http://invalidhost"}, 400),
    )
    for client_kwargs, expected_status in cases:
        client = TestClient(app, **client_kwargs)
        response = client.get("/")
        assert response.status_code == expected_status
|
|
|
|
|
|
def test_https_redirect_middleware():
    """Check that HTTPSRedirectMiddleware redirects plain-HTTP requests.

    A request made over ``https://testserver`` must be served normally (200).
    A request made with the default client must be answered with a 301 whose
    ``location`` header is the HTTPS form of the URL.
    """
    app = Starlette()
    app.add_middleware(HTTPSRedirectMiddleware)

    @app.route("/")
    def homepage(request):
        return PlainTextResponse("OK", status_code=200)

    # Already-secure request: passes straight through to the endpoint.
    secure_client = TestClient(app, base_url="https://testserver")
    secure_response = secure_client.get("/")
    assert secure_response.status_code == 200

    # Default client: redirects are not followed so the redirect response
    # itself can be inspected.
    plain_client = TestClient(app)
    redirect_response = plain_client.get("/", allow_redirects=False)
    assert redirect_response.status_code == 301
    assert redirect_response.headers["location"] == "https://testserver/"
|
|
|
|
|
|
def test_cors_allow_all():
    """Exercise CORSMiddleware configured with wildcard origins, headers and methods.

    Covers three request shapes against the same app: a pre-flight OPTIONS
    request, a simple cross-origin GET, and a GET without an ``Origin`` header.
    """
    app = Starlette()
    cors_options = {
        "allow_origins": ["*"],
        "allow_headers": ["*"],
        "allow_methods": ["*"],
        "expose_headers": ["X-Status"],
        "allow_credentials": True,
    }
    app.add_middleware(CORSMiddleware, **cors_options)

    @app.route("/")
    def homepage(request):
        return PlainTextResponse("Homepage", status_code=200)

    client = TestClient(app)

    # Pre-flight: the middleware answers with "OK" and echoes the requested
    # header name back in access-control-allow-headers.
    preflight_headers = {
        "Origin": "https://example.org",
        "Access-Control-Request-Method": "GET",
        "Access-Control-Request-Headers": "X-Example",
    }
    preflight = client.options("/", headers=preflight_headers)
    assert preflight.status_code == 200
    assert preflight.text == "OK"
    assert preflight.headers["access-control-allow-origin"] == "*"
    assert preflight.headers["access-control-allow-headers"] == "X-Example"

    # Simple cross-origin request: the endpoint body comes back with the CORS
    # headers attached.
    cross_origin = client.get("/", headers={"Origin": "https://example.org"})
    assert cross_origin.status_code == 200
    assert cross_origin.text == "Homepage"
    assert cross_origin.headers["access-control-allow-origin"] == "*"
    assert cross_origin.headers["access-control-expose-headers"] == "X-Status"

    # No Origin header: the response must carry no CORS allow-origin header.
    same_origin = client.get("/")
    assert same_origin.status_code == 200
    assert same_origin.text == "Homepage"
    assert "access-control-allow-origin" not in same_origin.headers
|
|
|
|
|
|
def test_cors_allow_specific_origin():
    """Exercise CORSMiddleware restricted to one origin and one request header.

    The allowed origin is ``https://example.org`` and the allowed header is
    ``X-Example``. Pre-flight and simple requests from that origin must have
    the origin echoed back; a request without ``Origin`` gets no CORS header.
    """
    allowed_origin = "https://example.org"

    app = Starlette()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[allowed_origin],
        allow_headers=["X-Example"],
    )

    @app.route("/")
    def homepage(request):
        return PlainTextResponse("Homepage", status_code=200)

    client = TestClient(app)

    # Pre-flight from the allowed origin asking for the allowed header.
    preflight_headers = {
        "Origin": allowed_origin,
        "Access-Control-Request-Method": "GET",
        "Access-Control-Request-Headers": "X-Example",
    }
    preflight = client.options("/", headers=preflight_headers)
    assert preflight.status_code == 200
    assert preflight.text == "OK"
    assert preflight.headers["access-control-allow-origin"] == allowed_origin
    assert preflight.headers["access-control-allow-headers"] == "X-Example"

    # Simple request from the allowed origin.
    cross_origin = client.get("/", headers={"Origin": allowed_origin})
    assert cross_origin.status_code == 200
    assert cross_origin.text == "Homepage"
    assert cross_origin.headers["access-control-allow-origin"] == allowed_origin

    # Request without an Origin header: no CORS allow-origin header expected.
    same_origin = client.get("/")
    assert same_origin.status_code == 200
    assert same_origin.text == "Homepage"
    assert "access-control-allow-origin" not in same_origin.headers
|
|
|
|
|
|
def test_cors_disallowed_preflight():
    """Check the pre-flight rejection when origin, method and headers all mismatch.

    The middleware allows only ``https://example.org`` and the ``X-Example``
    header; the pre-flight below asks for a different origin, POST, and
    ``X-Nope``. The expected answer is a 400 naming all three failures.
    """
    app = Starlette()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["https://example.org"],
        allow_headers=["X-Example"],
    )

    @app.route("/")
    def homepage(request):
        # Never reached: the only request in this test is a rejected pre-flight.
        pass  # pragma: no cover

    client = TestClient(app)

    # Every CORS-relevant field of this pre-flight is outside the configuration.
    rejected_preflight_headers = {
        "Origin": "https://another.org",
        "Access-Control-Request-Method": "POST",
        "Access-Control-Request-Headers": "X-Nope",
    }
    rejection = client.options("/", headers=rejected_preflight_headers)
    assert rejection.status_code == 400
    assert rejection.text == "Disallowed CORS origin, method, headers"
|
|
|
|
|
|
def test_cors_allow_origin_regex():
    """Exercise CORSMiddleware when allowed origins are given as a regex.

    Any ``https://`` origin is allowed and any ``http://`` origin is not. The
    test covers simple and pre-flight requests for both an allowed and a
    disallowed origin.
    """
    app = Starlette()

    # The pattern is a regular expression, not a glob. "https://*" would mean
    # "https:/" followed by zero or more "/" characters, which only accepts
    # "https://example.org" if the middleware does prefix matching. The
    # matching method is not visible from this test, so ".*" is spelled out
    # to make the pattern correct under both prefix and full matching.
    app.add_middleware(
        CORSMiddleware, allow_headers=["X-Example"], allow_origin_regex="https://.*"
    )

    @app.route("/")
    def homepage(request):
        return PlainTextResponse("Homepage", status_code=200)

    client = TestClient(app)

    # Test standard response
    headers = {"Origin": "https://example.org"}
    response = client.get("/", headers=headers)
    assert response.status_code == 200
    assert response.text == "Homepage"
    assert response.headers["access-control-allow-origin"] == "https://example.org"

    # Test disallowed standard response
    # Note that enforcement is a browser concern. The disallowed-ness is reflected
    # in the lack of an "access-control-allow-origin" header in the response.
    headers = {"Origin": "http://example.org"}
    response = client.get("/", headers=headers)
    assert response.status_code == 200
    assert response.text == "Homepage"
    assert "access-control-allow-origin" not in response.headers

    # Test pre-flight response
    headers = {
        "Origin": "https://another.com",
        "Access-Control-Request-Method": "GET",
        "Access-Control-Request-Headers": "X-Example",
    }
    response = client.options("/", headers=headers)
    assert response.status_code == 200
    assert response.text == "OK"
    assert response.headers["access-control-allow-origin"] == "https://another.com"
    assert response.headers["access-control-allow-headers"] == "X-Example"

    # Test disallowed pre-flight response
    headers = {
        "Origin": "http://another.com",
        "Access-Control-Request-Method": "GET",
        "Access-Control-Request-Headers": "X-Example",
    }
    response = client.options("/", headers=headers)
    assert response.status_code == 400
    assert response.text == "Disallowed CORS origin"
    assert "access-control-allow-origin" not in response.headers
|