diff --git a/docs/requests.md b/docs/requests.md index 295fb854..d0cbfe77 100644 --- a/docs/requests.md +++ b/docs/requests.md @@ -109,3 +109,7 @@ class App: If you access `.stream()` then the byte chunks are provided without storing the entire body to memory. Any subsequent calls to `.body()`, `.form()`, or `.json()` will raise an error. + +In some cases such as long-polling, or streaming responses you might need to +determine if the client has dropped the connection. You can determine this +state with `disconnected = await request.is_disconnected()`. diff --git a/starlette/requests.py b/starlette/requests.py index 1029d977..5b703cea 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -1,3 +1,4 @@ +import asyncio import http.cookies import json import typing @@ -120,6 +121,7 @@ class Request(HTTPConnection): assert scope["type"] == "http" self._receive = receive self._stream_consumed = False + self._is_disconnected = False @property def method(self) -> str: @@ -148,6 +150,7 @@ class Request(HTTPConnection): if not message.get("more_body", False): break elif message["type"] == "http.disconnect": + self._is_disconnected = True raise ClientDisconnect() yield b"" @@ -187,3 +190,15 @@ class Request(HTTPConnection): for item in self._form.values(): if hasattr(item, "close"): await item.close() # type: ignore + + async def is_disconnected(self) -> bool: + if not self._is_disconnected: + try: + message = await asyncio.wait_for(self._receive(), timeout=0.0000001) + except asyncio.TimeoutError as exc: + message = {} + + if message.get("type") == "http.disconnect": + self._is_disconnected = True + + return self._is_disconnected diff --git a/starlette/testclient.py b/starlette/testclient.py index 01b843aa..dd3f2976 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -129,6 +129,13 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter): } async def receive() -> Message: + nonlocal request_complete, response_complete + + if request_complete: + while not response_complete: + await asyncio.sleep(0.0001) + return {"type": "http.disconnect"} + body = request.body if isinstance(body, str): body_bytes = body.encode("utf-8") # type: bytes @@ -141,9 +148,12 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter): chunk = chunk.encode("utf-8") return {"type": "http.request", "body": chunk, "more_body": True} except StopIteration: + request_complete = True return {"type": "http.request", "body": b""} else: body_bytes = body + + request_complete = True return {"type": "http.request", "body": body_bytes} async def send(message: Message) -> None: @@ -182,6 +192,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter): template = message["template"] context = message["context"] + request_complete = False response_started = False response_complete = False raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any] diff --git a/tests/test_requests.py b/tests/test_requests.py index 9e02c9ff..32d0ab6b 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -254,6 +254,32 @@ def test_request_disconnect(): loop.run_until_complete(asgi_callable(receiver, None)) +def test_request_is_disconnected(): + """ + If a client disconnect occurs while reading request body + then ClientDisconnect should be raised. + """ + disconnected_after_response = None + + def app(scope): + async def asgi(receive, send): + nonlocal disconnected_after_response + + request = Request(scope, receive) + await request.body() + disconnected = await request.is_disconnected() + response = JSONResponse({"disconnected": disconnected}) + await response(receive, send) + disconnected_after_response = await request.is_disconnected() + + return asgi + + client = TestClient(app) + response = client.get("/") + assert response.json() == {"disconnected": False} + assert disconnected_after_response + + def test_request_cookies(): def app(scope): async def asgi(receive, send):