Add GZip middleware (#111)

This commit is contained in:
Tom Christie 2018-10-15 12:08:10 +01:00 committed by GitHub
parent d05ef2328d
commit 097152be5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 245 additions and 43 deletions

View File

@ -97,6 +97,25 @@ The following arguments are supported:
If an incoming request does not validate correctly then a 400 response will be sent.
## GZipMiddleware
Handles GZip responses for any request that includes `"gzip"` in the `Accept-Encoding` header.
The middleware will handle both standard and streaming responses.
```python
from starlette.applications import Starlette
from starlette.middleware.trustedhost import TrustedHostMiddleware
app = Starlette()
app.add_middleware(GZipMiddleware, minimum_size=1000)
```
The following arguments are supported:
* `minimum_size` - Do not GZip responses that are smaller than this minimum size in bytes. Defaults to `500`.
## Using ASGI middleware without Starlette
To wrap ASGI middleware around other ASGI applications, you should use the

View File

@ -285,3 +285,9 @@ class MutableHeaders(Headers):
def update(self, other: dict):
for key, val in other.items():
self[key] = val
def add_vary_header(self, vary):
existing = self.get("vary")
if existing is not None:
vary = ", ".join([existing, vary])
self["vary"] = vary

View File

@ -32,8 +32,6 @@ class CORSMiddleware:
simple_headers = {}
if "*" in allow_origins:
simple_headers["Access-Control-Allow-Origin"] = "*"
else:
simple_headers["Vary"] = "Origin"
if allow_credentials:
simple_headers["Access-Control-Allow-Credentials"] = "true"
if expose_headers:
@ -142,6 +140,7 @@ class CORSMiddleware:
await send(message)
return
print(message)
message.setdefault("headers", [])
headers = MutableHeaders(message["headers"])
origin = request_headers["Origin"]
@ -156,7 +155,7 @@ class CORSMiddleware:
# the Origin header in the response.
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
headers["Access-Control-Allow-Origin"] = origin
if "vary" in headers:
self.simple_headers["Vary"] = f"{headers.get('vary')}, Origin"
headers.add_vary_header("Origin")
print(headers)
headers.update(self.simple_headers)
await send(message)

View File

@ -0,0 +1,95 @@
from starlette.datastructures import Headers, MutableHeaders
from starlette.types import ASGIApp, ASGIInstance, Scope, Receive, Send, Message
import gzip
import io
import typing
class GZipMiddleware:
def __init__(self, app: ASGIApp, minimum_size: int = 500) -> None:
self.app = app
self.minimum_size = minimum_size
def __call__(self, scope: Scope) -> ASGIInstance:
if scope["type"] == "http":
headers = Headers(scope["headers"])
if "gzip" in headers.get("Accept-Encoding", ""):
return GZipResponder(self.app, scope, self.minimum_size)
return self.app(scope)
class GZipResponder:
def __init__(self, app: ASGIApp, scope: Scope, minimum_size: int) -> None:
self.inner = app(scope)
self.minimum_size = minimum_size
self.send = unattached_send # type: Send
self.initial_message = {} # type: Message
self.started = False
self.gzip_buffer = io.BytesIO()
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer)
async def __call__(self, receive: Receive, send: Send) -> None:
self.send = send
await self.inner(receive, self.send_with_gzip)
async def send_with_gzip(self, message: Message) -> None:
message_type = message["type"]
if message_type == "http.response.start":
# Don't send the initial message until we've determined how to
# modify the ougoging headers correctly.
self.initial_message = message
elif message_type == "http.response.body" and not self.started:
self.started = True
body = message.get("body", b"")
more_body = message.get("more_body", False)
if len(body) < self.minimum_size and not more_body:
# Don't apply GZip to small outgoing responses.
await self.send(self.initial_message)
await self.send(message)
elif not more_body:
# Standard GZip response.
self.gzip_file.write(body)
self.gzip_file.close()
body = self.gzip_buffer.getvalue()
headers = MutableHeaders(self.initial_message["headers"])
headers["Content-Encoding"] = "gzip"
headers["Content-Length"] = str(len(body))
headers.add_vary_header("Accept-Encoding")
message["body"] = body
await self.send(self.initial_message)
await self.send(message)
else:
# Initial body in streaming GZip response.
headers = MutableHeaders(self.initial_message["headers"])
headers["Content-Encoding"] = "gzip"
headers.add_vary_header("Accept-Encoding")
del headers["Content-Length"]
self.gzip_file.write(body)
message["body"] = self.gzip_buffer.getvalue()
self.gzip_buffer.seek(0)
self.gzip_buffer.truncate()
await self.send(self.initial_message)
await self.send(message)
elif message_type == "http.response.body":
# Remaining body in streaming GZip response.
body = message.get("body", b"")
more_body = message.get("more_body", False)
self.gzip_file.write(body)
if not more_body:
self.gzip_file.close()
message["body"] = self.gzip_buffer.getvalue()
self.gzip_buffer.seek(0)
self.gzip_buffer.truncate()
await self.send(message)
async def unattached_send(message: Message):
raise RuntimeError("send awaitable not set") # pragma: no cover

View File

@ -1,48 +1,9 @@
from starlette.applications import Starlette
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import PlainTextResponse
from starlette.testclient import TestClient
def test_trusted_host_middleware():
app = Starlette()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"])
@app.route("/")
def homepage(request):
return PlainTextResponse("OK", status_code=200)
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
client = TestClient(app, base_url="http://invalidhost")
response = client.get("/")
assert response.status_code == 400
def test_https_redirect_middleware():
app = Starlette()
app.add_middleware(HTTPSRedirectMiddleware)
@app.route("/")
def homepage(request):
return PlainTextResponse("OK", status_code=200)
client = TestClient(app, base_url="https://testserver")
response = client.get("/")
assert response.status_code == 200
client = TestClient(app)
response = client.get("/", allow_redirects=False)
assert response.status_code == 301
assert response.headers["location"] == "https://testserver/"
def test_cors_allow_all():
app = Starlette()

View File

@ -0,0 +1,77 @@
from starlette.applications import Starlette
from starlette.middleware.gzip import GZipMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.testclient import TestClient
def test_gzip_responses():
app = Starlette()
app.add_middleware(GZipMiddleware)
@app.route("/")
def homepage(request):
return PlainTextResponse("x" * 4000, status_code=200)
client = TestClient(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "gzip"
assert int(response.headers["Content-Length"]) < 4000
def test_gzip_not_in_accept_encoding():
app = Starlette()
app.add_middleware(GZipMiddleware)
@app.route("/")
def homepage(request):
return PlainTextResponse("x" * 4000, status_code=200)
client = TestClient(app)
response = client.get("/", headers={"accept-encoding": "identity"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert "Content-Encoding" not in response.headers
assert int(response.headers["Content-Length"]) == 4000
def test_gzip_ignored_for_small_responses():
app = Starlette()
app.add_middleware(GZipMiddleware)
@app.route("/")
def homepage(request):
return PlainTextResponse("OK", status_code=200)
client = TestClient(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "OK"
assert "Content-Encoding" not in response.headers
assert int(response.headers["Content-Length"]) == 2
def test_gzip_streaming_response():
app = Starlette()
app.add_middleware(GZipMiddleware)
@app.route("/")
def homepage(request):
async def generator(bytes, count):
for index in range(count):
yield bytes
streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200)
client = TestClient(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert response.headers["Content-Encoding"] == "gzip"
assert "Content-Length" not in response.headers

View File

@ -0,0 +1,23 @@
from starlette.applications import Starlette
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.responses import PlainTextResponse
from starlette.testclient import TestClient
def test_https_redirect_middleware():
app = Starlette()
app.add_middleware(HTTPSRedirectMiddleware)
@app.route("/")
def homepage(request):
return PlainTextResponse("OK", status_code=200)
client = TestClient(app, base_url="https://testserver")
response = client.get("/")
assert response.status_code == 200
client = TestClient(app)
response = client.get("/", allow_redirects=False)
assert response.status_code == 301
assert response.headers["location"] == "https://testserver/"

View File

@ -0,0 +1,22 @@
from starlette.applications import Starlette
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import PlainTextResponse
from starlette.testclient import TestClient
def test_trusted_host_middleware():
app = Starlette()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"])
@app.route("/")
def homepage(request):
return PlainTextResponse("OK", status_code=200)
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
client = TestClient(app, base_url="http://invalidhost")
response = client.get("/")
assert response.status_code == 400