Domain wildcards for trusted host (#151)

* Support domain wildcards with TrustedHostMiddleware

* Support domain wildcards with TrustedHostMiddleware

* Include domain wildcards in TrustedHostMiddleware docs
This commit is contained in:
Tom Christie 2018-10-29 09:22:13 +00:00 committed by GitHub
parent cdb08bc644
commit f53b172f8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 7 deletions

View File

@ -13,9 +13,9 @@ from starlette.middleware.trustedhost import TrustedHostMiddleware
app = Starlette()
# Ensure that all requests include an 'example.com' host header,
# Ensure that all requests include an 'example.com' or '*.example.com' host header,
# and strictly enforce https-only access.
app.add_middleware(TrustedHostMiddleware, allowed_hosts=['example.com'])
app.add_middleware(TrustedHostMiddleware, allowed_hosts=['example.com', '*.example.com'])
app.add_middleware(HTTPSRedirectMiddleware)
```
@ -88,12 +88,14 @@ from starlette.middleware.trustedhost import TrustedHostMiddleware
app = Starlette()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=['example.com'])
app.add_middleware(TrustedHostMiddleware, allowed_hosts=['example.com', '*.example.com'])
```
The following arguments are supported:
* `allowed_hosts` - A list of domain names that should be allowed as hostnames.
* `allowed_hosts` - A list of domain names that should be allowed as hostnames. Wildcard
domains such as `*.example.com` are supported for matching subdomains. To allow any
hostname either use `allowed_hosts=["*"]` or omit the middleware.
If an incoming request does not validate correctly then a 400 response will be sent.

View File

@ -4,10 +4,17 @@ from starlette.types import ASGIApp, ASGIInstance, Scope
import typing
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
class TrustedHostMiddleware:
def __init__(
self, app: ASGIApp, allowed_hosts: typing.Sequence[str] = ["*"]
) -> None:
for pattern in allowed_hosts:
assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
if pattern.startswith("*") and pattern != "*":
assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
self.app = app
self.allowed_hosts = allowed_hosts
self.allow_any = "*" in allowed_hosts
@ -15,8 +22,15 @@ class TrustedHostMiddleware:
def __call__(self, scope: Scope) -> ASGIInstance:
if scope["type"] in ("http", "websocket") and not self.allow_any:
headers = Headers(scope=scope)
host = headers.get("host")
if host not in self.allowed_hosts:
host = headers.get("host", "").split(":")[0]
for pattern in self.allowed_hosts:
if (
host == pattern
or pattern.startswith("*")
and host.endswith(pattern[1:])
):
break
else:
return PlainTextResponse("Invalid host header", status_code=400)
return self.app(scope)

View File

@ -7,7 +7,9 @@ from starlette.testclient import TestClient
def test_trusted_host_middleware():
app = Starlette()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"])
app.add_middleware(
TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"]
)
@app.route("/")
def homepage(request):
@ -17,6 +19,10 @@ def test_trusted_host_middleware():
response = client.get("/")
assert response.status_code == 200
client = TestClient(app, base_url="http://subdomain.testserver")
response = client.get("/")
assert response.status_code == 200
client = TestClient(app, base_url="http://invalidhost")
response = client.get("/")
assert response.status_code == 400