mirror of https://github.com/encode/starlette.git
Add GZip middleware (#111)
This commit is contained in:
parent
d05ef2328d
commit
097152be5a
|
@ -97,6 +97,25 @@ The following arguments are supported:
|
|||
|
||||
If an incoming request does not validate correctly then a 400 response will be sent.
|
||||
|
||||
## GZipMiddleware
|
||||
|
||||
Handles GZip responses for any request that includes `"gzip"` in the `Accept-Encoding` header.
|
||||
|
||||
The middleware will handle both standard and streaming responses.
|
||||
|
||||
```python
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
|
||||
app = Starlette()
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
```
|
||||
|
||||
The following arguments are supported:
|
||||
|
||||
* `minimum_size` - Do not GZip responses that are smaller than this minimum size in bytes. Defaults to `500`.
|
||||
|
||||
## Using ASGI middleware without Starlette
|
||||
|
||||
To wrap ASGI middleware around other ASGI applications, you should use the
|
||||
|
|
|
@ -285,3 +285,9 @@ class MutableHeaders(Headers):
|
|||
def update(self, other: dict):
|
||||
for key, val in other.items():
|
||||
self[key] = val
|
||||
|
||||
def add_vary_header(self, vary):
|
||||
existing = self.get("vary")
|
||||
if existing is not None:
|
||||
vary = ", ".join([existing, vary])
|
||||
self["vary"] = vary
|
||||
|
|
|
@ -32,8 +32,6 @@ class CORSMiddleware:
|
|||
simple_headers = {}
|
||||
if "*" in allow_origins:
|
||||
simple_headers["Access-Control-Allow-Origin"] = "*"
|
||||
else:
|
||||
simple_headers["Vary"] = "Origin"
|
||||
if allow_credentials:
|
||||
simple_headers["Access-Control-Allow-Credentials"] = "true"
|
||||
if expose_headers:
|
||||
|
@ -142,6 +140,7 @@ class CORSMiddleware:
|
|||
await send(message)
|
||||
return
|
||||
|
||||
print(message)
|
||||
message.setdefault("headers", [])
|
||||
headers = MutableHeaders(message["headers"])
|
||||
origin = request_headers["Origin"]
|
||||
|
@ -156,7 +155,7 @@ class CORSMiddleware:
|
|||
# the Origin header in the response.
|
||||
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
if "vary" in headers:
|
||||
self.simple_headers["Vary"] = f"{headers.get('vary')}, Origin"
|
||||
headers.add_vary_header("Origin")
|
||||
print(headers)
|
||||
headers.update(self.simple_headers)
|
||||
await send(message)
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
from starlette.datastructures import Headers, MutableHeaders
|
||||
from starlette.types import ASGIApp, ASGIInstance, Scope, Receive, Send, Message
|
||||
import gzip
|
||||
import io
|
||||
import typing
|
||||
|
||||
|
||||
class GZipMiddleware:
|
||||
def __init__(self, app: ASGIApp, minimum_size: int = 500) -> None:
|
||||
self.app = app
|
||||
self.minimum_size = minimum_size
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
if scope["type"] == "http":
|
||||
headers = Headers(scope["headers"])
|
||||
if "gzip" in headers.get("Accept-Encoding", ""):
|
||||
return GZipResponder(self.app, scope, self.minimum_size)
|
||||
return self.app(scope)
|
||||
|
||||
|
||||
class GZipResponder:
|
||||
def __init__(self, app: ASGIApp, scope: Scope, minimum_size: int) -> None:
|
||||
self.inner = app(scope)
|
||||
self.minimum_size = minimum_size
|
||||
self.send = unattached_send # type: Send
|
||||
self.initial_message = {} # type: Message
|
||||
self.started = False
|
||||
self.gzip_buffer = io.BytesIO()
|
||||
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer)
|
||||
|
||||
async def __call__(self, receive: Receive, send: Send) -> None:
|
||||
self.send = send
|
||||
await self.inner(receive, self.send_with_gzip)
|
||||
|
||||
async def send_with_gzip(self, message: Message) -> None:
|
||||
message_type = message["type"]
|
||||
if message_type == "http.response.start":
|
||||
# Don't send the initial message until we've determined how to
|
||||
# modify the ougoging headers correctly.
|
||||
self.initial_message = message
|
||||
elif message_type == "http.response.body" and not self.started:
|
||||
self.started = True
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
if len(body) < self.minimum_size and not more_body:
|
||||
# Don't apply GZip to small outgoing responses.
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
elif not more_body:
|
||||
# Standard GZip response.
|
||||
self.gzip_file.write(body)
|
||||
self.gzip_file.close()
|
||||
body = self.gzip_buffer.getvalue()
|
||||
|
||||
headers = MutableHeaders(self.initial_message["headers"])
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
headers["Content-Length"] = str(len(body))
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
message["body"] = body
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
else:
|
||||
# Initial body in streaming GZip response.
|
||||
headers = MutableHeaders(self.initial_message["headers"])
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
headers.add_vary_header("Accept-Encoding")
|
||||
del headers["Content-Length"]
|
||||
|
||||
self.gzip_file.write(body)
|
||||
message["body"] = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
|
||||
await self.send(self.initial_message)
|
||||
await self.send(message)
|
||||
|
||||
elif message_type == "http.response.body":
|
||||
# Remaining body in streaming GZip response.
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
self.gzip_file.write(body)
|
||||
if not more_body:
|
||||
self.gzip_file.close()
|
||||
|
||||
message["body"] = self.gzip_buffer.getvalue()
|
||||
self.gzip_buffer.seek(0)
|
||||
self.gzip_buffer.truncate()
|
||||
|
||||
await self.send(message)
|
||||
|
||||
|
||||
async def unattached_send(message: Message):
|
||||
raise RuntimeError("send awaitable not set") # pragma: no cover
|
|
@ -1,48 +1,9 @@
|
|||
from starlette.applications import Starlette
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
from starlette.responses import PlainTextResponse
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
|
||||
def test_trusted_host_middleware():
|
||||
app = Starlette()
|
||||
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"])
|
||||
|
||||
@app.route("/")
|
||||
def homepage(request):
|
||||
return PlainTextResponse("OK", status_code=200)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
|
||||
client = TestClient(app, base_url="http://invalidhost")
|
||||
response = client.get("/")
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
def test_https_redirect_middleware():
|
||||
app = Starlette()
|
||||
|
||||
app.add_middleware(HTTPSRedirectMiddleware)
|
||||
|
||||
@app.route("/")
|
||||
def homepage(request):
|
||||
return PlainTextResponse("OK", status_code=200)
|
||||
|
||||
client = TestClient(app, base_url="https://testserver")
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/", allow_redirects=False)
|
||||
assert response.status_code == 301
|
||||
assert response.headers["location"] == "https://testserver/"
|
||||
|
||||
|
||||
def test_cors_allow_all():
|
||||
app = Starlette()
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
from starlette.applications import Starlette
|
||||
from starlette.middleware.gzip import GZipMiddleware
|
||||
from starlette.responses import PlainTextResponse, StreamingResponse
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
|
||||
def test_gzip_responses():
|
||||
app = Starlette()
|
||||
|
||||
app.add_middleware(GZipMiddleware)
|
||||
|
||||
@app.route("/")
|
||||
def homepage(request):
|
||||
return PlainTextResponse("x" * 4000, status_code=200)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/", headers={"accept-encoding": "gzip"})
|
||||
assert response.status_code == 200
|
||||
assert response.text == "x" * 4000
|
||||
assert response.headers["Content-Encoding"] == "gzip"
|
||||
assert int(response.headers["Content-Length"]) < 4000
|
||||
|
||||
|
||||
def test_gzip_not_in_accept_encoding():
|
||||
app = Starlette()
|
||||
|
||||
app.add_middleware(GZipMiddleware)
|
||||
|
||||
@app.route("/")
|
||||
def homepage(request):
|
||||
return PlainTextResponse("x" * 4000, status_code=200)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/", headers={"accept-encoding": "identity"})
|
||||
assert response.status_code == 200
|
||||
assert response.text == "x" * 4000
|
||||
assert "Content-Encoding" not in response.headers
|
||||
assert int(response.headers["Content-Length"]) == 4000
|
||||
|
||||
|
||||
def test_gzip_ignored_for_small_responses():
|
||||
app = Starlette()
|
||||
|
||||
app.add_middleware(GZipMiddleware)
|
||||
|
||||
@app.route("/")
|
||||
def homepage(request):
|
||||
return PlainTextResponse("OK", status_code=200)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/", headers={"accept-encoding": "gzip"})
|
||||
assert response.status_code == 200
|
||||
assert response.text == "OK"
|
||||
assert "Content-Encoding" not in response.headers
|
||||
assert int(response.headers["Content-Length"]) == 2
|
||||
|
||||
|
||||
def test_gzip_streaming_response():
|
||||
app = Starlette()
|
||||
|
||||
app.add_middleware(GZipMiddleware)
|
||||
|
||||
@app.route("/")
|
||||
def homepage(request):
|
||||
async def generator(bytes, count):
|
||||
for index in range(count):
|
||||
yield bytes
|
||||
|
||||
streaming = generator(bytes=b"x" * 400, count=10)
|
||||
return StreamingResponse(streaming, status_code=200)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/", headers={"accept-encoding": "gzip"})
|
||||
assert response.status_code == 200
|
||||
assert response.text == "x" * 4000
|
||||
assert response.headers["Content-Encoding"] == "gzip"
|
||||
assert "Content-Length" not in response.headers
|
|
@ -0,0 +1,23 @@
|
|||
from starlette.applications import Starlette
|
||||
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
|
||||
from starlette.responses import PlainTextResponse
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
|
||||
def test_https_redirect_middleware():
|
||||
app = Starlette()
|
||||
|
||||
app.add_middleware(HTTPSRedirectMiddleware)
|
||||
|
||||
@app.route("/")
|
||||
def homepage(request):
|
||||
return PlainTextResponse("OK", status_code=200)
|
||||
|
||||
client = TestClient(app, base_url="https://testserver")
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/", allow_redirects=False)
|
||||
assert response.status_code == 301
|
||||
assert response.headers["location"] == "https://testserver/"
|
|
@ -0,0 +1,22 @@
|
|||
from starlette.applications import Starlette
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
from starlette.responses import PlainTextResponse
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
|
||||
def test_trusted_host_middleware():
|
||||
app = Starlette()
|
||||
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"])
|
||||
|
||||
@app.route("/")
|
||||
def homepage(request):
|
||||
return PlainTextResponse("OK", status_code=200)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
|
||||
client = TestClient(app, base_url="http://invalidhost")
|
||||
response = client.get("/")
|
||||
assert response.status_code == 400
|
Loading…
Reference in New Issue