mirror of https://github.com/encode/starlette.git
Add www_redirect behavior to TrustedHostsMiddleware (#181)
This commit is contained in:
parent
552e0f6f2d
commit
3c7573715d
|
@ -1,7 +1,7 @@
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from starlette.datastructures import Headers
|
from starlette.datastructures import URL, Headers
|
||||||
from starlette.responses import PlainTextResponse
|
from starlette.responses import PlainTextResponse, RedirectResponse
|
||||||
from starlette.types import ASGIApp, ASGIInstance, Scope
|
from starlette.types import ASGIApp, ASGIInstance, Scope
|
||||||
|
|
||||||
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
|
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
|
||||||
|
@ -9,28 +9,38 @@ ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'
|
||||||
|
|
||||||
class TrustedHostMiddleware:
|
class TrustedHostMiddleware:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, app: ASGIApp, allowed_hosts: typing.Sequence[str] = ["*"]
|
self,
|
||||||
|
app: ASGIApp,
|
||||||
|
allowed_hosts: typing.Sequence[str] = ["*"],
|
||||||
|
www_redirect: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
for pattern in allowed_hosts:
|
for pattern in allowed_hosts:
|
||||||
assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
|
assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
|
||||||
if pattern.startswith("*") and pattern != "*":
|
if pattern.startswith("*") and pattern != "*":
|
||||||
assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
|
assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
|
||||||
self.app = app
|
self.app = app
|
||||||
self.allowed_hosts = allowed_hosts
|
self.allowed_hosts = list(allowed_hosts)
|
||||||
self.allow_any = "*" in allowed_hosts
|
self.allow_any = "*" in allowed_hosts
|
||||||
|
self.www_redirect = www_redirect
|
||||||
|
|
||||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||||
if scope["type"] in ("http", "websocket") and not self.allow_any:
|
if scope["type"] in ("http", "websocket") and not self.allow_any:
|
||||||
headers = Headers(scope=scope)
|
headers = Headers(scope=scope)
|
||||||
host = headers.get("host", "").split(":")[0]
|
host = headers.get("host", "").split(":")[0]
|
||||||
|
found_www_redirect = False
|
||||||
for pattern in self.allowed_hosts:
|
for pattern in self.allowed_hosts:
|
||||||
if (
|
if host == pattern or (
|
||||||
host == pattern
|
pattern.startswith("*") and host.endswith(pattern[1:])
|
||||||
or pattern.startswith("*")
|
|
||||||
and host.endswith(pattern[1:])
|
|
||||||
):
|
):
|
||||||
break
|
break
|
||||||
|
elif "www." + host == pattern:
|
||||||
|
found_www_redirect = True
|
||||||
else:
|
else:
|
||||||
|
if found_www_redirect and self.www_redirect:
|
||||||
|
url = URL(scope=scope)
|
||||||
|
redirect_url = url.replace(netloc="www." + url.netloc)
|
||||||
|
print(redirect_url)
|
||||||
|
return RedirectResponse(url=str(redirect_url))
|
||||||
return PlainTextResponse("Invalid host header", status_code=400)
|
return PlainTextResponse("Invalid host header", status_code=400)
|
||||||
|
|
||||||
return self.app(scope)
|
return self.app(scope)
|
||||||
|
|
|
@ -26,3 +26,18 @@ def test_trusted_host_middleware():
|
||||||
client = TestClient(app, base_url="http://invalidhost")
|
client = TestClient(app, base_url="http://invalidhost")
|
||||||
response = client.get("/")
|
response = client.get("/")
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_www_redirect():
|
||||||
|
app = Starlette()
|
||||||
|
|
||||||
|
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])
|
||||||
|
|
||||||
|
@app.route("/")
|
||||||
|
def homepage(request):
|
||||||
|
return PlainTextResponse("OK", status_code=200)
|
||||||
|
|
||||||
|
client = TestClient(app, base_url="https://example.com")
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.url == "https://www.example.com/"
|
||||||
|
|
Loading…
Reference in New Issue