Add `request.is_disconnected()` (#320)

* Add request.is_disconnected()

* Add request.is_disconnected
This commit is contained in:
Tom Christie 2019-01-15 09:59:59 +00:00 committed by GitHub
parent af105b23d5
commit c220994eb1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 56 additions and 0 deletions

View File

@ -109,3 +109,7 @@ class App:
If you access `.stream()` then the byte chunks are provided without storing
the entire body to memory. Any subsequent calls to `.body()`, `.form()`, or `.json()`
will raise an error.
In some cases such as long-polling, or streaming responses you might need to
determine if the client has dropped the connection. You can determine this
state with `disconnected = await request.is_disconnected()`.

View File

@ -1,3 +1,4 @@
import asyncio
import http.cookies
import json
import typing
@ -120,6 +121,7 @@ class Request(HTTPConnection):
assert scope["type"] == "http"
self._receive = receive
self._stream_consumed = False
self._is_disconnected = False
@property
def method(self) -> str:
@ -148,6 +150,7 @@ class Request(HTTPConnection):
if not message.get("more_body", False):
break
elif message["type"] == "http.disconnect":
self._is_disconnected = True
raise ClientDisconnect()
yield b""
@ -187,3 +190,15 @@ class Request(HTTPConnection):
for item in self._form.values():
if hasattr(item, "close"):
await item.close() # type: ignore
async def is_disconnected(self) -> bool:
if not self._is_disconnected:
try:
message = await asyncio.wait_for(self._receive(), timeout=0.0000001)
except asyncio.TimeoutError as exc:
message = {}
if message.get("type") == "http.disconnect":
self._is_disconnected = True
return self._is_disconnected

View File

@ -129,6 +129,13 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
}
async def receive() -> Message:
nonlocal request_complete, response_complete
if request_complete:
while not response_complete:
await asyncio.sleep(0.0001)
return {"type": "http.disconnect"}
body = request.body
if isinstance(body, str):
body_bytes = body.encode("utf-8") # type: bytes
@ -141,9 +148,12 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
chunk = chunk.encode("utf-8")
return {"type": "http.request", "body": chunk, "more_body": True}
except StopIteration:
request_complete = True
return {"type": "http.request", "body": b""}
else:
body_bytes = body
request_complete = True
return {"type": "http.request", "body": body_bytes}
async def send(message: Message) -> None:
@ -182,6 +192,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
template = message["template"]
context = message["context"]
request_complete = False
response_started = False
response_complete = False
raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any]

View File

@ -254,6 +254,32 @@ def test_request_disconnect():
loop.run_until_complete(asgi_callable(receiver, None))
def test_request_is_disconnected():
"""
If a client disconnect occurs while reading request body
then ClientDisconnect should be raised.
"""
disconnected_after_response = None
def app(scope):
async def asgi(receive, send):
nonlocal disconnected_after_response
request = Request(scope, receive)
await request.body()
disconnected = await request.is_disconnected()
response = JSONResponse({"disconnected": disconnected})
await response(receive, send)
disconnected_after_response = await request.is_disconnected()
return asgi
client = TestClient(app)
response = client.get("/")
assert response.json() == {"disconnected": False}
assert disconnected_after_response
def test_request_cookies():
def app(scope):
async def asgi(receive, send):