From 93a124805fc8724e312cf6716ea526428962cddf Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 14 Jan 2019 11:07:23 +0000 Subject: [PATCH] Drop body from responses on HEAD requests (#317) * Drop body from responses on HEAD requests * Linting * Endpoints supporting HEAD should automatically support GET --- starlette/routing.py | 8 +++++++- starlette/testclient.py | 3 ++- tests/test_applications.py | 4 ++++ tests/test_responses.py | 13 +++++++++++++ 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index bf55e239..9ea092b6 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -162,7 +162,13 @@ class Route(BaseRoute): # Endpoint is a class. Treat it as ASGI. self.app = endpoint - self.methods = methods + if methods is None: + self.methods = None + else: + self.methods = set([method.upper() for method in methods]) + if "GET" in self.methods: + self.methods |= set(["HEAD"]) + self.path_regex, self.path_format, self.param_convertors = self.compile_path( path ) diff --git a/starlette/testclient.py b/starlette/testclient.py index 762250e9..01b843aa 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -173,7 +173,8 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter): ), 'Received "http.response.body" after response completed.' body = message.get("body", b"") more_body = message.get("more_body", False) - raw_kwargs["body"].write(body) + if request.method != "HEAD": + raw_kwargs["body"].write(body) if not more_body: raw_kwargs["body"].seek(0) response_complete = True diff --git a/tests/test_applications.py b/tests/test_applications.py index 9a79b3ae..925bbd4f 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -100,6 +100,10 @@ def test_func_route(): assert response.status_code == 200 assert response.text == "Hello, world!" + response = client.head("/func") + assert response.status_code == 200 + assert response.text == "" + def test_async_route(): response = client.get("/async") diff --git a/tests/test_responses.py b/tests/test_responses.py index d4bafaf7..2c056ed5 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -287,3 +287,16 @@ def test_populate_headers(): assert response.text == "hi" assert response.headers["content-length"] == "2" assert response.headers["content-type"] == "text/html; charset=utf-8" + + +def test_head_method(): + def app(scope): + async def asgi(receive, send): + response = Response("hello, world", media_type="text/plain") + await response(receive, send) + + return asgi + + client = TestClient(app) + response = client.head("/") + assert response.text == ""