Add www_redirect behavior to TrustedHostsMiddleware (#181)

This commit is contained in:
Tom Christie 2018-11-05 11:38:20 +00:00 committed by GitHub
parent 552e0f6f2d
commit 3c7573715d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 8 deletions

View File

@ -1,7 +1,7 @@
import typing
from starlette.datastructures import Headers
from starlette.responses import PlainTextResponse
from starlette.datastructures import URL, Headers
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, ASGIInstance, Scope
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
@ -9,28 +9,38 @@ ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'
class TrustedHostMiddleware:
def __init__(
self, app: ASGIApp, allowed_hosts: typing.Sequence[str] = ["*"]
self,
app: ASGIApp,
allowed_hosts: typing.Sequence[str] = ["*"],
www_redirect: bool = True,
) -> None:
for pattern in allowed_hosts:
assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
if pattern.startswith("*") and pattern != "*":
assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
self.app = app
self.allowed_hosts = allowed_hosts
self.allowed_hosts = list(allowed_hosts)
self.allow_any = "*" in allowed_hosts
self.www_redirect = www_redirect
def __call__(self, scope: Scope) -> ASGIInstance:
if scope["type"] in ("http", "websocket") and not self.allow_any:
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
found_www_redirect = False
for pattern in self.allowed_hosts:
if (
host == pattern
or pattern.startswith("*")
and host.endswith(pattern[1:])
if host == pattern or (
pattern.startswith("*") and host.endswith(pattern[1:])
):
break
elif "www." + host == pattern:
found_www_redirect = True
else:
if found_www_redirect and self.www_redirect:
url = URL(scope=scope)
redirect_url = url.replace(netloc="www." + url.netloc)
print(redirect_url)
return RedirectResponse(url=str(redirect_url))
return PlainTextResponse("Invalid host header", status_code=400)
return self.app(scope)

View File

@ -26,3 +26,18 @@ def test_trusted_host_middleware():
client = TestClient(app, base_url="http://invalidhost")
response = client.get("/")
assert response.status_code == 400
def test_www_redirect():
app = Starlette()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])
@app.route("/")
def homepage(request):
return PlainTextResponse("OK", status_code=200)
client = TestClient(app, base_url="https://example.com")
response = client.get("/")
assert response.status_code == 200
assert response.url == "https://www.example.com/"