2024-07-27 09:31:16 +00:00
|
|
|
from typing import Any
|
2024-02-04 20:54:16 +00:00
|
|
|
|
2018-10-29 14:46:42 +00:00
|
|
|
import pytest
|
|
|
|
|
2022-02-01 14:08:24 +00:00
|
|
|
from starlette.applications import Starlette
|
|
|
|
from starlette.background import BackgroundTask
|
2018-11-08 11:59:15 +00:00
|
|
|
from starlette.middleware.errors import ServerErrorMiddleware
|
2024-02-04 20:54:16 +00:00
|
|
|
from starlette.requests import Request
|
2018-11-08 11:59:15 +00:00
|
|
|
from starlette.responses import JSONResponse, Response
|
2022-02-01 14:08:24 +00:00
|
|
|
from starlette.routing import Route
|
2024-02-04 20:54:16 +00:00
|
|
|
from starlette.types import Receive, Scope, Send
|
2024-07-27 09:31:16 +00:00
|
|
|
from tests.types import TestClientFactory
|
2018-07-18 12:04:14 +00:00
|
|
|
|
|
|
|
|
2024-02-04 20:54:16 +00:00
|
|
|
def test_handler(
    test_client_factory: TestClientFactory,
) -> None:
    """A custom ``handler`` passed to ServerErrorMiddleware builds the 500 response."""

    def error_500(request: Request, exc: Exception) -> JSONResponse:
        # The exception details are ignored; every failure gets the same JSON body.
        return JSONResponse({"detail": "Server Error"}, status_code=500)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        raise RuntimeError("Something went wrong")

    wrapped = ServerErrorMiddleware(app, handler=error_500)
    # Don't re-raise in the client so the rendered error response can be inspected.
    client = test_client_factory(wrapped, raise_server_exceptions=False)

    response = client.get("/")

    assert response.status_code == 500
    assert response.json() == {"detail": "Server Error"}
|
|
|
|
|
|
|
|
|
2024-02-04 20:54:16 +00:00
|
|
|
def test_debug_text(test_client_factory: TestClientFactory) -> None:
    """In debug mode, a request that does not ask for HTML gets a plain-text traceback."""

    async def failing_app(scope: Scope, receive: Receive, send: Send) -> None:
        raise RuntimeError("Something went wrong")

    client = test_client_factory(
        ServerErrorMiddleware(failing_app, debug=True),
        raise_server_exceptions=False,
    )
    response = client.get("/")

    assert response.status_code == 500
    content_type = response.headers["content-type"]
    assert content_type.startswith("text/plain")
    assert "RuntimeError: Something went wrong" in response.text
|
2018-07-18 12:04:14 +00:00
|
|
|
|
|
|
|
|
2024-02-04 20:54:16 +00:00
|
|
|
def test_debug_html(test_client_factory: TestClientFactory) -> None:
    """In debug mode, a request that accepts ``text/html`` gets an HTML error page."""

    async def failing_app(scope: Scope, receive: Receive, send: Send) -> None:
        raise RuntimeError("Something went wrong")

    client = test_client_factory(
        ServerErrorMiddleware(failing_app, debug=True),
        raise_server_exceptions=False,
    )
    # Advertise HTML support so the middleware picks the HTML renderer.
    html_accept = {"Accept": "text/html, */*"}
    response = client.get("/", headers=html_accept)

    assert response.status_code == 500
    content_type = response.headers["content-type"]
    assert content_type.startswith("text/html")
    assert "RuntimeError" in response.text
|
2018-07-18 12:04:14 +00:00
|
|
|
|
|
|
|
|
2024-02-04 20:54:16 +00:00
|
|
|
def test_debug_after_response_sent(test_client_factory: TestClientFactory) -> None:
    """An error raised after a response was already sent still reaches the client, even with debug on."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        # Finish sending an empty 204 response first, then fail.
        await Response(b"", status_code=204)(scope, receive, send)
        raise RuntimeError("Something went wrong")

    # Default client settings: server exceptions are re-raised in the test.
    client = test_client_factory(ServerErrorMiddleware(app, debug=True))

    with pytest.raises(RuntimeError):
        client.get("/")
|
2018-07-18 14:34:13 +00:00
|
|
|
|
|
|
|
|
2024-02-04 20:54:16 +00:00
|
|
|
def test_debug_not_http(test_client_factory: TestClientFactory) -> None:
    """
    ServerErrorMiddleware should just pass through any non-http scopes as-is.

    For a websocket connection, the exception raised by the wrapped app must
    propagate to the caller rather than being turned into an error response.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        raise RuntimeError("Something went wrong")

    app = ServerErrorMiddleware(app)
    # Build the client outside ``pytest.raises`` so that only the websocket
    # connection itself is expected to raise.
    client = test_client_factory(app)

    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: no cover
|
2022-02-01 14:08:24 +00:00
|
|
|
|
|
|
|
|
2024-02-04 20:54:16 +00:00
|
|
|
def test_background_task(test_client_factory: TestClientFactory) -> None:
    """
    An exception raised by a background task is passed to the application's
    ``Exception`` handler, while the client still receives the original 204.
    """
    seen_errors: list[Exception] = []

    def error_handler(request: Request, exc: Exception) -> Any:
        # Record that the handler ran; it returns None.
        seen_errors.append(exc)

    def raise_exception() -> None:
        raise Exception("Something went wrong")

    async def endpoint(request: Request) -> Response:
        # The task only runs after the 204 response has been produced.
        return Response(
            status_code=204,
            background=BackgroundTask(raise_exception),
        )

    app = Starlette(
        routes=[Route("/", endpoint=endpoint)],
        exception_handlers={Exception: error_handler},
    )
    client = test_client_factory(app, raise_server_exceptions=False)

    response = client.get("/")

    assert response.status_code == 204
    assert seen_errors
|