mirror of https://github.com/encode/starlette.git
Add `request.is_disconnected()` (#320)
* Add request.is_disconnected() * Add request.is_disconnected
This commit is contained in:
parent
af105b23d5
commit
c220994eb1
|
@ -109,3 +109,7 @@ class App:
|
||||||
If you access `.stream()` then the byte chunks are provided without storing
|
If you access `.stream()` then the byte chunks are provided without storing
|
||||||
the entire body to memory. Any subsequent calls to `.body()`, `.form()`, or `.json()`
|
the entire body to memory. Any subsequent calls to `.body()`, `.form()`, or `.json()`
|
||||||
will raise an error.
|
will raise an error.
|
||||||
|
|
||||||
|
In some cases such as long-polling, or streaming responses you might need to
|
||||||
|
determine if the client has dropped the connection. You can determine this
|
||||||
|
state with `disconnected = await request.is_disconnected()`.
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import http.cookies
|
import http.cookies
|
||||||
import json
|
import json
|
||||||
import typing
|
import typing
|
||||||
|
@ -120,6 +121,7 @@ class Request(HTTPConnection):
|
||||||
assert scope["type"] == "http"
|
assert scope["type"] == "http"
|
||||||
self._receive = receive
|
self._receive = receive
|
||||||
self._stream_consumed = False
|
self._stream_consumed = False
|
||||||
|
self._is_disconnected = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def method(self) -> str:
|
def method(self) -> str:
|
||||||
|
@ -148,6 +150,7 @@ class Request(HTTPConnection):
|
||||||
if not message.get("more_body", False):
|
if not message.get("more_body", False):
|
||||||
break
|
break
|
||||||
elif message["type"] == "http.disconnect":
|
elif message["type"] == "http.disconnect":
|
||||||
|
self._is_disconnected = True
|
||||||
raise ClientDisconnect()
|
raise ClientDisconnect()
|
||||||
yield b""
|
yield b""
|
||||||
|
|
||||||
|
@ -187,3 +190,15 @@ class Request(HTTPConnection):
|
||||||
for item in self._form.values():
|
for item in self._form.values():
|
||||||
if hasattr(item, "close"):
|
if hasattr(item, "close"):
|
||||||
await item.close() # type: ignore
|
await item.close() # type: ignore
|
||||||
|
|
||||||
|
async def is_disconnected(self) -> bool:
|
||||||
|
if not self._is_disconnected:
|
||||||
|
try:
|
||||||
|
message = await asyncio.wait_for(self._receive(), timeout=0.0000001)
|
||||||
|
except asyncio.TimeoutError as exc:
|
||||||
|
message = {}
|
||||||
|
|
||||||
|
if message.get("type") == "http.disconnect":
|
||||||
|
self._is_disconnected = True
|
||||||
|
|
||||||
|
return self._is_disconnected
|
||||||
|
|
|
@ -129,6 +129,13 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def receive() -> Message:
|
async def receive() -> Message:
|
||||||
|
nonlocal request_complete, response_complete
|
||||||
|
|
||||||
|
if request_complete:
|
||||||
|
while not response_complete:
|
||||||
|
await asyncio.sleep(0.0001)
|
||||||
|
return {"type": "http.disconnect"}
|
||||||
|
|
||||||
body = request.body
|
body = request.body
|
||||||
if isinstance(body, str):
|
if isinstance(body, str):
|
||||||
body_bytes = body.encode("utf-8") # type: bytes
|
body_bytes = body.encode("utf-8") # type: bytes
|
||||||
|
@ -141,9 +148,12 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
||||||
chunk = chunk.encode("utf-8")
|
chunk = chunk.encode("utf-8")
|
||||||
return {"type": "http.request", "body": chunk, "more_body": True}
|
return {"type": "http.request", "body": chunk, "more_body": True}
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
|
request_complete = True
|
||||||
return {"type": "http.request", "body": b""}
|
return {"type": "http.request", "body": b""}
|
||||||
else:
|
else:
|
||||||
body_bytes = body
|
body_bytes = body
|
||||||
|
|
||||||
|
request_complete = True
|
||||||
return {"type": "http.request", "body": body_bytes}
|
return {"type": "http.request", "body": body_bytes}
|
||||||
|
|
||||||
async def send(message: Message) -> None:
|
async def send(message: Message) -> None:
|
||||||
|
@ -182,6 +192,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
||||||
template = message["template"]
|
template = message["template"]
|
||||||
context = message["context"]
|
context = message["context"]
|
||||||
|
|
||||||
|
request_complete = False
|
||||||
response_started = False
|
response_started = False
|
||||||
response_complete = False
|
response_complete = False
|
||||||
raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any]
|
raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any]
|
||||||
|
|
|
@ -254,6 +254,32 @@ def test_request_disconnect():
|
||||||
loop.run_until_complete(asgi_callable(receiver, None))
|
loop.run_until_complete(asgi_callable(receiver, None))
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_is_disconnected():
|
||||||
|
"""
|
||||||
|
If a client disconnect occurs while reading request body
|
||||||
|
then ClientDisconnect should be raised.
|
||||||
|
"""
|
||||||
|
disconnected_after_response = None
|
||||||
|
|
||||||
|
def app(scope):
|
||||||
|
async def asgi(receive, send):
|
||||||
|
nonlocal disconnected_after_response
|
||||||
|
|
||||||
|
request = Request(scope, receive)
|
||||||
|
await request.body()
|
||||||
|
disconnected = await request.is_disconnected()
|
||||||
|
response = JSONResponse({"disconnected": disconnected})
|
||||||
|
await response(receive, send)
|
||||||
|
disconnected_after_response = await request.is_disconnected()
|
||||||
|
|
||||||
|
return asgi
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.json() == {"disconnected": False}
|
||||||
|
assert disconnected_after_response
|
||||||
|
|
||||||
|
|
||||||
def test_request_cookies():
|
def test_request_cookies():
|
||||||
def app(scope):
|
def app(scope):
|
||||||
async def asgi(receive, send):
|
async def asgi(receive, send):
|
||||||
|
|
Loading…
Reference in New Issue