2018-10-29 14:46:42 +00:00
|
|
|
import pytest
|
|
|
|
|
2018-09-04 10:52:29 +00:00
|
|
|
from starlette.exceptions import ExceptionMiddleware, HTTPException
|
2018-09-05 09:29:04 +00:00
|
|
|
from starlette.responses import PlainTextResponse
|
2018-10-29 14:46:42 +00:00
|
|
|
from starlette.routing import Route, Router, WebSocketRoute
|
2018-09-04 10:52:29 +00:00
|
|
|
from starlette.testclient import TestClient
|
|
|
|
|
|
|
|
|
2018-10-29 09:22:45 +00:00
|
|
|
def raise_runtime_error(request):
    """Endpoint that always fails with an unhandled ``RuntimeError("Yikes")``.

    It is registered for both an HTTP route and a WebSocket route in this
    module, so ``request`` may be either connection object; it is ignored.
    """
    message = "Yikes"
    raise RuntimeError(message)
|
2018-09-04 10:52:29 +00:00
|
|
|
|
|
|
|
|
2018-10-29 09:22:45 +00:00
|
|
|
def not_acceptable(request):
    """Endpoint that always raises a 406 ``HTTPException``; ``request`` is ignored."""
    status = 406
    raise HTTPException(status_code=status)
|
2018-09-04 10:52:29 +00:00
|
|
|
|
|
|
|
|
2018-10-29 09:22:45 +00:00
|
|
|
def not_modified(request):
    """Endpoint that always raises a 304 ``HTTPException``; ``request`` is ignored."""
    status = 304
    raise HTTPException(status_code=status)
|
2018-09-04 10:52:29 +00:00
|
|
|
|
|
|
|
|
2018-10-29 09:22:45 +00:00
|
|
|
class HandledExcAfterResponse:
    """Endpoint that sends a complete ``200 OK`` plain-text response and only
    then raises a 406 ``HTTPException``.
    """

    def __init__(self, scope):
        """Accept and ignore ``scope``; this endpoint behaves the same for every request."""

    async def __call__(self, receive, send):
        await PlainTextResponse("OK", status_code=200)(receive, send)
        # The response has already been sent at this point, so the exception
        # can no longer be converted into an error response.
        raise HTTPException(status_code=406)
|
|
|
|
|
|
|
|
|
|
|
|
# HTTP routes (path -> endpoint), each failing in a different way, followed by
# a single WebSocket route that reuses the RuntimeError endpoint. Route order
# is preserved from the explicit listing.
router = Router(
    routes=[
        Route(path, endpoint=endpoint)
        for path, endpoint in [
            ("/runtime_error", raise_runtime_error),
            ("/not_acceptable", not_acceptable),
            ("/not_modified", not_modified),
            ("/handled_exc_after_response", HandledExcAfterResponse),
        ]
    ]
    + [WebSocketRoute("/runtime_error", endpoint=raise_runtime_error)]
)

# Application under test: the router wrapped in the exception middleware.
app = ExceptionMiddleware(router)

# Shared test client (default settings, so server exceptions are re-raised).
client = TestClient(app)
|
|
|
|
|
|
|
|
|
|
|
|
def test_not_acceptable():
    """A raised 406 HTTPException is rendered as a "Not Acceptable" response."""
    res = client.get("/not_acceptable")
    assert res.status_code == 406
    assert res.text == "Not Acceptable"
|
|
|
|
|
|
|
|
|
|
|
|
def test_not_modified():
    """A raised 304 HTTPException is rendered as a response with an empty body."""
    res = client.get("/not_modified")
    assert res.status_code == 304
    assert res.text == ""
|
|
|
|
|
|
|
|
|
|
|
|
def test_websockets_should_raise():
    """A RuntimeError raised by a WebSocket endpoint surfaces in the test client."""
    with pytest.raises(RuntimeError):
        # Enter the session as a context manager so the handshake is always
        # driven. If connecting unexpectedly succeeds, the session is also
        # closed on exit instead of being left open.
        with client.websocket_connect("/runtime_error"):
            pass
|
|
|
|
|
|
|
|
|
|
|
|
def test_handled_exc_after_response():
    """An HTTPException raised after the response was sent cannot be handled."""
    # The endpoint sends "200 OK" and only then raises a 406 HTTPException.
    # The exception middleware is expected to surface this as a RuntimeError.
    with pytest.raises(RuntimeError):
        client.get("/handled_exc_after_response")

    # With server exceptions suppressed, the test client instead reports the
    # response exactly as a real client would have received it.
    lenient_client = TestClient(app, raise_server_exceptions=False)
    resp = lenient_client.get("/handled_exc_after_response")
    assert resp.status_code == 200
    assert resp.text == "OK"
|
|
|
|
|
|
|
|
|
|
|
|
def test_force_500_response():
    """An app that raises yields an empty 500 response when the test client
    is told not to re-raise server exceptions.
    """

    # Distinct name so the module-level ``app`` is not shadowed.
    def raising_app(scope):
        raise RuntimeError()

    quiet_client = TestClient(raising_app, raise_server_exceptions=False)
    res = quiet_client.get("/")
    assert res.status_code == 500
    assert res.text == ""
|