mirror of https://github.com/encode/starlette.git
51 lines
1.8 KiB
Python
51 lines
1.8 KiB
Python
from starlette.applications import Starlette
|
|
from starlette.middleware import Middleware
|
|
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import PlainTextResponse
|
|
from starlette.routing import Route
|
|
from tests.types import TestClientFactory
|
|
|
|
|
|
def test_trusted_host_middleware(test_client_factory: TestClientFactory) -> None:
|
|
def homepage(request: Request) -> PlainTextResponse:
|
|
return PlainTextResponse("OK", status_code=200)
|
|
|
|
app = Starlette(
|
|
routes=[Route("/", endpoint=homepage)],
|
|
middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"])],
|
|
)
|
|
|
|
client = test_client_factory(app)
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
|
|
client = test_client_factory(app, base_url="http://subdomain.testserver")
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
|
|
client = test_client_factory(app, base_url="http://invalidhost")
|
|
response = client.get("/")
|
|
assert response.status_code == 400
|
|
|
|
|
|
def test_default_allowed_hosts() -> None:
|
|
app = Starlette()
|
|
middleware = TrustedHostMiddleware(app)
|
|
assert middleware.allowed_hosts == ["*"]
|
|
|
|
|
|
def test_www_redirect(test_client_factory: TestClientFactory) -> None:
|
|
def homepage(request: Request) -> PlainTextResponse:
|
|
return PlainTextResponse("OK", status_code=200)
|
|
|
|
app = Starlette(
|
|
routes=[Route("/", endpoint=homepage)],
|
|
middleware=[Middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])],
|
|
)
|
|
|
|
client = test_client_factory(app, base_url="https://example.com")
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
assert response.url == "https://www.example.com/"
|