mirror of https://github.com/encode/starlette.git
Domain wildcards for trusted host (#151)
* Support domain wildcards with TrustedHostMiddleware * Support domain wildcards with TrustedHostMiddleware * Include domain wildcards in TrustedHostMiddleware docs
This commit is contained in:
parent
cdb08bc644
commit
f53b172f8a
|
@ -13,9 +13,9 @@ from starlette.middleware.trustedhost import TrustedHostMiddleware
|
|||
|
||||
app = Starlette()
|
||||
|
||||
# Ensure that all requests include an 'example.com' host header,
|
||||
# Ensure that all requests include an 'example.com' or '*.example.com' host header,
|
||||
# and strictly enforce https-only access.
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=['example.com'])
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=['example.com', '*.example.com'])
|
||||
app.add_middleware(HTTPSRedirectMiddleware)
|
||||
```
|
||||
|
||||
|
@ -88,12 +88,14 @@ from starlette.middleware.trustedhost import TrustedHostMiddleware
|
|||
|
||||
|
||||
app = Starlette()
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=['example.com'])
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=['example.com', '*.example.com'])
|
||||
```
|
||||
|
||||
The following arguments are supported:
|
||||
|
||||
* `allowed_hosts` - A list of domain names that should be allowed as hostnames.
|
||||
* `allowed_hosts` - A list of domain names that should be allowed as hostnames. Wildcard
|
||||
domains such as `*.example.com` are supported for matching subdomains. To allow any
|
||||
hostname either use `allowed_hosts=["*"]` or omit the middleware.
|
||||
|
||||
If an incoming request does not validate correctly then a 400 response will be sent.
|
||||
|
||||
|
|
|
@ -4,10 +4,17 @@ from starlette.types import ASGIApp, ASGIInstance, Scope
|
|||
import typing
|
||||
|
||||
|
||||
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
|
||||
|
||||
|
||||
class TrustedHostMiddleware:
|
||||
def __init__(
|
||||
self, app: ASGIApp, allowed_hosts: typing.Sequence[str] = ["*"]
|
||||
) -> None:
|
||||
for pattern in allowed_hosts:
|
||||
assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
|
||||
if pattern.startswith("*") and pattern != "*":
|
||||
assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD
|
||||
self.app = app
|
||||
self.allowed_hosts = allowed_hosts
|
||||
self.allow_any = "*" in allowed_hosts
|
||||
|
@ -15,8 +22,15 @@ class TrustedHostMiddleware:
|
|||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
if scope["type"] in ("http", "websocket") and not self.allow_any:
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host")
|
||||
if host not in self.allowed_hosts:
|
||||
host = headers.get("host", "").split(":")[0]
|
||||
for pattern in self.allowed_hosts:
|
||||
if (
|
||||
host == pattern
|
||||
or pattern.startswith("*")
|
||||
and host.endswith(pattern[1:])
|
||||
):
|
||||
break
|
||||
else:
|
||||
return PlainTextResponse("Invalid host header", status_code=400)
|
||||
|
||||
return self.app(scope)
|
||||
|
|
|
@ -7,7 +7,9 @@ from starlette.testclient import TestClient
|
|||
def test_trusted_host_middleware():
|
||||
app = Starlette()
|
||||
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"])
|
||||
app.add_middleware(
|
||||
TrustedHostMiddleware, allowed_hosts=["testserver", "*.testserver"]
|
||||
)
|
||||
|
||||
@app.route("/")
|
||||
def homepage(request):
|
||||
|
@ -17,6 +19,10 @@ def test_trusted_host_middleware():
|
|||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
|
||||
client = TestClient(app, base_url="http://subdomain.testserver")
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
|
||||
client = TestClient(app, base_url="http://invalidhost")
|
||||
response = client.get("/")
|
||||
assert response.status_code == 400
|
||||
|
|
Loading…
Reference in New Issue