Add www_redirect behavior to TrustedHostsMiddleware (#181)

This commit is contained in:
Tom Christie 2018-11-05 11:38:20 +00:00 committed by GitHub
parent 552e0f6f2d
commit 3c7573715d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 8 deletions

View File

@ -1,7 +1,7 @@
import typing import typing
from starlette.datastructures import Headers from starlette.datastructures import URL, Headers
from starlette.responses import PlainTextResponse from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, ASGIInstance, Scope from starlette.types import ASGIApp, ASGIInstance, Scope
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'." ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
@ -9,28 +9,38 @@ ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'
class TrustedHostMiddleware: class TrustedHostMiddleware:
def __init__( def __init__(
self, app: ASGIApp, allowed_hosts: typing.Sequence[str] = ["*"] self,
app: ASGIApp,
allowed_hosts: typing.Sequence[str] = ["*"],
www_redirect: bool = True,
) -> None: ) -> None:
for pattern in allowed_hosts: for pattern in allowed_hosts:
assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
if pattern.startswith("*") and pattern != "*": if pattern.startswith("*") and pattern != "*":
assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
self.app = app self.app = app
self.allowed_hosts = allowed_hosts self.allowed_hosts = list(allowed_hosts)
self.allow_any = "*" in allowed_hosts self.allow_any = "*" in allowed_hosts
self.www_redirect = www_redirect
def __call__(self, scope: Scope) -> ASGIInstance: def __call__(self, scope: Scope) -> ASGIInstance:
if scope["type"] in ("http", "websocket") and not self.allow_any: if scope["type"] in ("http", "websocket") and not self.allow_any:
headers = Headers(scope=scope) headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0] host = headers.get("host", "").split(":")[0]
found_www_redirect = False
for pattern in self.allowed_hosts: for pattern in self.allowed_hosts:
if ( if host == pattern or (
host == pattern pattern.startswith("*") and host.endswith(pattern[1:])
or pattern.startswith("*")
and host.endswith(pattern[1:])
): ):
break break
elif "www." + host == pattern:
found_www_redirect = True
else: else:
if found_www_redirect and self.www_redirect:
url = URL(scope=scope)
redirect_url = url.replace(netloc="www." + url.netloc)
print(redirect_url)
return RedirectResponse(url=str(redirect_url))
return PlainTextResponse("Invalid host header", status_code=400) return PlainTextResponse("Invalid host header", status_code=400)
return self.app(scope) return self.app(scope)

View File

@ -26,3 +26,18 @@ def test_trusted_host_middleware():
client = TestClient(app, base_url="http://invalidhost") client = TestClient(app, base_url="http://invalidhost")
response = client.get("/") response = client.get("/")
assert response.status_code == 400 assert response.status_code == 400
def test_www_redirect():
app = Starlette()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])
@app.route("/")
def homepage(request):
return PlainTextResponse("OK", status_code=200)
client = TestClient(app, base_url="https://example.com")
response = client.get("/")
assert response.status_code == 200
assert response.url == "https://www.example.com/"