mirror of https://github.com/encode/starlette.git
Add `request.is_disconnected()` (#320)
* Add request.is_disconnected() * Add request.is_disconnected
This commit is contained in:
parent
af105b23d5
commit
c220994eb1
|
@ -109,3 +109,7 @@ class App:
|
|||
If you access `.stream()` then the byte chunks are provided without storing
|
||||
the entire body to memory. Any subsequent calls to `.body()`, `.form()`, or `.json()`
|
||||
will raise an error.
|
||||
|
||||
In some cases such as long-polling, or streaming responses you might need to
|
||||
determine if the client has dropped the connection. You can determine this
|
||||
state with `disconnected = await request.is_disconnected()`.
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import http.cookies
|
||||
import json
|
||||
import typing
|
||||
|
@ -120,6 +121,7 @@ class Request(HTTPConnection):
|
|||
assert scope["type"] == "http"
|
||||
self._receive = receive
|
||||
self._stream_consumed = False
|
||||
self._is_disconnected = False
|
||||
|
||||
@property
|
||||
def method(self) -> str:
|
||||
|
@ -148,6 +150,7 @@ class Request(HTTPConnection):
|
|||
if not message.get("more_body", False):
|
||||
break
|
||||
elif message["type"] == "http.disconnect":
|
||||
self._is_disconnected = True
|
||||
raise ClientDisconnect()
|
||||
yield b""
|
||||
|
||||
|
@ -187,3 +190,15 @@ class Request(HTTPConnection):
|
|||
for item in self._form.values():
|
||||
if hasattr(item, "close"):
|
||||
await item.close() # type: ignore
|
||||
|
||||
async def is_disconnected(self) -> bool:
|
||||
if not self._is_disconnected:
|
||||
try:
|
||||
message = await asyncio.wait_for(self._receive(), timeout=0.0000001)
|
||||
except asyncio.TimeoutError as exc:
|
||||
message = {}
|
||||
|
||||
if message.get("type") == "http.disconnect":
|
||||
self._is_disconnected = True
|
||||
|
||||
return self._is_disconnected
|
||||
|
|
|
@ -129,6 +129,13 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
|||
}
|
||||
|
||||
async def receive() -> Message:
|
||||
nonlocal request_complete, response_complete
|
||||
|
||||
if request_complete:
|
||||
while not response_complete:
|
||||
await asyncio.sleep(0.0001)
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
body = request.body
|
||||
if isinstance(body, str):
|
||||
body_bytes = body.encode("utf-8") # type: bytes
|
||||
|
@ -141,9 +148,12 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
|||
chunk = chunk.encode("utf-8")
|
||||
return {"type": "http.request", "body": chunk, "more_body": True}
|
||||
except StopIteration:
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": b""}
|
||||
else:
|
||||
body_bytes = body
|
||||
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": body_bytes}
|
||||
|
||||
async def send(message: Message) -> None:
|
||||
|
@ -182,6 +192,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
|||
template = message["template"]
|
||||
context = message["context"]
|
||||
|
||||
request_complete = False
|
||||
response_started = False
|
||||
response_complete = False
|
||||
raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any]
|
||||
|
|
|
@ -254,6 +254,32 @@ def test_request_disconnect():
|
|||
loop.run_until_complete(asgi_callable(receiver, None))
|
||||
|
||||
|
||||
def test_request_is_disconnected():
|
||||
"""
|
||||
If a client disconnect occurs while reading request body
|
||||
then ClientDisconnect should be raised.
|
||||
"""
|
||||
disconnected_after_response = None
|
||||
|
||||
def app(scope):
|
||||
async def asgi(receive, send):
|
||||
nonlocal disconnected_after_response
|
||||
|
||||
request = Request(scope, receive)
|
||||
await request.body()
|
||||
disconnected = await request.is_disconnected()
|
||||
response = JSONResponse({"disconnected": disconnected})
|
||||
await response(receive, send)
|
||||
disconnected_after_response = await request.is_disconnected()
|
||||
|
||||
return asgi
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.json() == {"disconnected": False}
|
||||
assert disconnected_after_response
|
||||
|
||||
|
||||
def test_request_cookies():
|
||||
def app(scope):
|
||||
async def asgi(receive, send):
|
||||
|
|
Loading…
Reference in New Issue