diff --git a/docs/middleware.md b/docs/middleware.md index be4cb23c..a814b7e6 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -97,6 +97,25 @@ The following arguments are supported: If an incoming request does not validate correctly then a 400 response will be sent. +## GZipMiddleware + +Handles GZip responses for any request that includes `"gzip"` in the `Accept-Encoding` header. + +The middleware will handle both standard and streaming responses. + +```python +from starlette.applications import Starlette +from starlette.middleware.gzip import GZipMiddleware + + +app = Starlette() +app.add_middleware(GZipMiddleware, minimum_size=1000) +``` + +The following arguments are supported: + +* `minimum_size` - Do not GZip responses that are smaller than this minimum size in bytes. Defaults to `500`. + ## Using ASGI middleware without Starlette To wrap ASGI middleware around other ASGI applications, you should use the diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 2705fd31..649ff346 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -285,3 +285,9 @@ class MutableHeaders(Headers): def update(self, other: dict): for key, val in other.items(): self[key] = val + + def add_vary_header(self, vary: str) -> None: + existing = self.get("vary") + if existing is not None: + vary = ", ".join([existing, vary]) + self["vary"] = vary diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 8bc33804..8b0abeae 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -32,8 +32,6 @@ class CORSMiddleware: simple_headers = {} if "*" in allow_origins: simple_headers["Access-Control-Allow-Origin"] = "*" - else: - simple_headers["Vary"] = "Origin" if allow_credentials: simple_headers["Access-Control-Allow-Credentials"] = "true" if expose_headers:
@@ -156,7 +154,6 @@ class CORSMiddleware: # the Origin header in the response. elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): headers["Access-Control-Allow-Origin"] = origin - if "vary" in headers: - self.simple_headers["Vary"] = f"{headers.get('vary')}, Origin" + headers.add_vary_header("Origin") headers.update(self.simple_headers) await send(message) diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py new file mode 100644 index 00000000..091c96e3 --- /dev/null +++ b/starlette/middleware/gzip.py @@ -0,0 +1,95 @@ +from starlette.datastructures import Headers, MutableHeaders +from starlette.types import ASGIApp, ASGIInstance, Scope, Receive, Send, Message +import gzip +import io +import typing + + +class GZipMiddleware: + def __init__(self, app: ASGIApp, minimum_size: int = 500) -> None: + self.app = app + self.minimum_size = minimum_size + + def __call__(self, scope: Scope) -> ASGIInstance: + if scope["type"] == "http": + headers = Headers(scope["headers"]) + if "gzip" in headers.get("Accept-Encoding", ""): + return GZipResponder(self.app, scope, self.minimum_size) + return self.app(scope) + + +class GZipResponder: + def __init__(self, app: ASGIApp, scope: Scope, minimum_size: int) -> None: + self.inner = app(scope) + self.minimum_size = minimum_size + self.send = unattached_send # type: Send + self.initial_message = {} # type: Message + self.started = False + self.gzip_buffer = io.BytesIO() + self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer) + + async def __call__(self, receive: Receive, send: Send) -> None: + self.send = send + await self.inner(receive, self.send_with_gzip) + + async def send_with_gzip(self, message: Message) -> None: + message_type = message["type"] + if message_type == "http.response.start": + # Don't send the initial message until we've determined how to + # modify the outgoing headers 
correctly. + self.initial_message = message + elif message_type == "http.response.body" and not self.started: + self.started = True + body = message.get("body", b"") + more_body = message.get("more_body", False) + if len(body) < self.minimum_size and not more_body: + # Don't apply GZip to small outgoing responses. + await self.send(self.initial_message) + await self.send(message) + elif not more_body: + # Standard GZip response. + self.gzip_file.write(body) + self.gzip_file.close() + body = self.gzip_buffer.getvalue() + + headers = MutableHeaders(self.initial_message["headers"]) + headers["Content-Encoding"] = "gzip" + headers["Content-Length"] = str(len(body)) + headers.add_vary_header("Accept-Encoding") + message["body"] = body + + await self.send(self.initial_message) + await self.send(message) + else: + # Initial body in streaming GZip response. + headers = MutableHeaders(self.initial_message["headers"]) + headers["Content-Encoding"] = "gzip" + headers.add_vary_header("Accept-Encoding") + del headers["Content-Length"] + + self.gzip_file.write(body) + message["body"] = self.gzip_buffer.getvalue() + self.gzip_buffer.seek(0) + self.gzip_buffer.truncate() + + await self.send(self.initial_message) + await self.send(message) + + elif message_type == "http.response.body": + # Remaining body in streaming GZip response. 
+ body = message.get("body", b"") + more_body = message.get("more_body", False) + + self.gzip_file.write(body) + if not more_body: + self.gzip_file.close() + + message["body"] = self.gzip_buffer.getvalue() + self.gzip_buffer.seek(0) + self.gzip_buffer.truncate() + + await self.send(message) + + +async def unattached_send(message: Message): + raise RuntimeError("send awaitable not set") # pragma: no cover diff --git a/tests/test_middleware.py b/tests/middleware/test_cors.py similarity index 86% rename from tests/test_middleware.py rename to tests/middleware/test_cors.py index ab3a0772..ddd00287 100644 --- a/tests/test_middleware.py +++ b/tests/middleware/test_cors.py @@ -1,48 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware -from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware -from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.responses import PlainTextResponse from starlette.testclient import TestClient -def test_trusted_host_middleware(): - app = Starlette() - - app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"]) - - @app.route("/") - def homepage(request): - return PlainTextResponse("OK", status_code=200) - - client = TestClient(app) - response = client.get("/") - assert response.status_code == 200 - - client = TestClient(app, base_url="http://invalidhost") - response = client.get("/") - assert response.status_code == 400 - - -def test_https_redirect_middleware(): - app = Starlette() - - app.add_middleware(HTTPSRedirectMiddleware) - - @app.route("/") - def homepage(request): - return PlainTextResponse("OK", status_code=200) - - client = TestClient(app, base_url="https://testserver") - response = client.get("/") - assert response.status_code == 200 - - client = TestClient(app) - response = client.get("/", allow_redirects=False) - assert response.status_code == 301 - assert response.headers["location"] == "https://testserver/" - - def 
test_cors_allow_all(): app = Starlette() diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py new file mode 100644 index 00000000..cd989b8c --- /dev/null +++ b/tests/middleware/test_gzip.py @@ -0,0 +1,77 @@ +from starlette.applications import Starlette +from starlette.middleware.gzip import GZipMiddleware +from starlette.responses import PlainTextResponse, StreamingResponse +from starlette.testclient import TestClient + + +def test_gzip_responses(): + app = Starlette() + + app.add_middleware(GZipMiddleware) + + @app.route("/") + def homepage(request): + return PlainTextResponse("x" * 4000, status_code=200) + + client = TestClient(app) + response = client.get("/", headers={"accept-encoding": "gzip"}) + assert response.status_code == 200 + assert response.text == "x" * 4000 + assert response.headers["Content-Encoding"] == "gzip" + assert int(response.headers["Content-Length"]) < 4000 + + +def test_gzip_not_in_accept_encoding(): + app = Starlette() + + app.add_middleware(GZipMiddleware) + + @app.route("/") + def homepage(request): + return PlainTextResponse("x" * 4000, status_code=200) + + client = TestClient(app) + response = client.get("/", headers={"accept-encoding": "identity"}) + assert response.status_code == 200 + assert response.text == "x" * 4000 + assert "Content-Encoding" not in response.headers + assert int(response.headers["Content-Length"]) == 4000 + + +def test_gzip_ignored_for_small_responses(): + app = Starlette() + + app.add_middleware(GZipMiddleware) + + @app.route("/") + def homepage(request): + return PlainTextResponse("OK", status_code=200) + + client = TestClient(app) + response = client.get("/", headers={"accept-encoding": "gzip"}) + assert response.status_code == 200 + assert response.text == "OK" + assert "Content-Encoding" not in response.headers + assert int(response.headers["Content-Length"]) == 2 + + +def test_gzip_streaming_response(): + app = Starlette() + + app.add_middleware(GZipMiddleware) + + @app.route("/") + 
def homepage(request): + async def generator(bytes, count): + for index in range(count): + yield bytes + + streaming = generator(bytes=b"x" * 400, count=10) + return StreamingResponse(streaming, status_code=200) + + client = TestClient(app) + response = client.get("/", headers={"accept-encoding": "gzip"}) + assert response.status_code == 200 + assert response.text == "x" * 4000 + assert response.headers["Content-Encoding"] == "gzip" + assert "Content-Length" not in response.headers diff --git a/tests/middleware/test_https_redirect.py b/tests/middleware/test_https_redirect.py new file mode 100644 index 00000000..c36f24e4 --- /dev/null +++ b/tests/middleware/test_https_redirect.py @@ -0,0 +1,23 @@ +from starlette.applications import Starlette +from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware +from starlette.responses import PlainTextResponse +from starlette.testclient import TestClient + + +def test_https_redirect_middleware(): + app = Starlette() + + app.add_middleware(HTTPSRedirectMiddleware) + + @app.route("/") + def homepage(request): + return PlainTextResponse("OK", status_code=200) + + client = TestClient(app, base_url="https://testserver") + response = client.get("/") + assert response.status_code == 200 + + client = TestClient(app) + response = client.get("/", allow_redirects=False) + assert response.status_code == 301 + assert response.headers["location"] == "https://testserver/" diff --git a/tests/middleware/test_trusted_host.py b/tests/middleware/test_trusted_host.py new file mode 100644 index 00000000..4b4e86ee --- /dev/null +++ b/tests/middleware/test_trusted_host.py @@ -0,0 +1,22 @@ +from starlette.applications import Starlette +from starlette.middleware.trustedhost import TrustedHostMiddleware +from starlette.responses import PlainTextResponse +from starlette.testclient import TestClient + + +def test_trusted_host_middleware(): + app = Starlette() + + app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver"]) + + 
@app.route("/") + def homepage(request): + return PlainTextResponse("OK", status_code=200) + + client = TestClient(app) + response = client.get("/") + assert response.status_code == 200 + + client = TestClient(app, base_url="http://invalidhost") + response = client.get("/") + assert response.status_code == 400