From 3c7573715d4fa81a5e5c921527a466d8f78530c5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 5 Nov 2018 11:38:20 +0000 Subject: [PATCH] Add www_redirect behavior to TrustedHostsMiddleware (#181) --- starlette/middleware/trustedhost.py | 26 ++++++++++++++++++-------- tests/middleware/test_trusted_host.py | 15 +++++++++++++++ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/starlette/middleware/trustedhost.py b/starlette/middleware/trustedhost.py index 5f733dc1..536a1968 100644 --- a/starlette/middleware/trustedhost.py +++ b/starlette/middleware/trustedhost.py @@ -1,7 +1,7 @@ import typing -from starlette.datastructures import Headers -from starlette.responses import PlainTextResponse +from starlette.datastructures import URL, Headers +from starlette.responses import PlainTextResponse, RedirectResponse from starlette.types import ASGIApp, ASGIInstance, Scope ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'." @@ -9,28 +9,38 @@ ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com' class TrustedHostMiddleware: def __init__( - self, app: ASGIApp, allowed_hosts: typing.Sequence[str] = ["*"] + self, + app: ASGIApp, + allowed_hosts: typing.Sequence[str] = ["*"], + www_redirect: bool = True, ) -> None: for pattern in allowed_hosts: assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD if pattern.startswith("*") and pattern != "*": assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD self.app = app - self.allowed_hosts = allowed_hosts + self.allowed_hosts = list(allowed_hosts) self.allow_any = "*" in allowed_hosts + self.www_redirect = www_redirect def __call__(self, scope: Scope) -> ASGIInstance: if scope["type"] in ("http", "websocket") and not self.allow_any: headers = Headers(scope=scope) host = headers.get("host", "").split(":")[0] + found_www_redirect = False for pattern in self.allowed_hosts: - if ( - host == pattern - or pattern.startswith("*") - and host.endswith(pattern[1:]) + if host == pattern or ( + pattern.startswith("*") and host.endswith(pattern[1:]) ): break + elif "www." + host == pattern: + found_www_redirect = True else: + if found_www_redirect and self.www_redirect: + url = URL(scope=scope) + redirect_url = url.replace(netloc="www." + url.netloc) + print(redirect_url) + return RedirectResponse(url=str(redirect_url)) return PlainTextResponse("Invalid host header", status_code=400) return self.app(scope) diff --git a/tests/middleware/test_trusted_host.py b/tests/middleware/test_trusted_host.py index d8127d3e..54b2a0c1 100644 --- a/tests/middleware/test_trusted_host.py +++ b/tests/middleware/test_trusted_host.py @@ -26,3 +26,18 @@ def test_trusted_host_middleware(): client = TestClient(app, base_url="http://invalidhost") response = client.get("/") assert response.status_code == 400 + + +def test_www_redirect(): + app = Starlette() + + app.add_middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"]) + + @app.route("/") + def homepage(request): + return PlainTextResponse("OK", status_code=200) + + client = TestClient(app, base_url="https://example.com") + response = client.get("/") + assert response.status_code == 200 + assert response.url == "https://www.example.com/"