# Source: starlette/tests/middleware/test_errors.py
# (Listing metadata: 106 lines, 3.7 KiB, Python.)

from typing import Any
import pytest
from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route
from starlette.types import Receive, Scope, Send
from tests.types import TestClientFactory
def test_handler(test_client_factory: TestClientFactory) -> None:
    """A custom ``handler`` builds the 500 response when the wrapped app raises."""

    async def failing_app(scope: Scope, receive: Receive, send: Send) -> None:
        raise RuntimeError("Something went wrong")

    def error_500(request: Request, exc: Exception) -> JSONResponse:
        return JSONResponse({"detail": "Server Error"}, status_code=500)

    middleware = ServerErrorMiddleware(failing_app, handler=error_500)
    # Let the client return the 500 instead of re-raising the app's error.
    client = test_client_factory(middleware, raise_server_exceptions=False)

    response = client.get("/")

    assert response.status_code == 500
    assert response.json() == {"detail": "Server Error"}


def test_debug_text(test_client_factory: TestClientFactory) -> None:
    """In debug mode, a request without an explicit ``Accept`` header gets a plain-text traceback."""

    async def failing_app(scope: Scope, receive: Receive, send: Send) -> None:
        raise RuntimeError("Something went wrong")

    client = test_client_factory(
        ServerErrorMiddleware(failing_app, debug=True),
        raise_server_exceptions=False,
    )

    response = client.get("/")

    assert response.status_code == 500
    assert response.headers["content-type"].startswith("text/plain")
    # The traceback text should name the exception type and its message.
    assert "RuntimeError: Something went wrong" in response.text


def test_debug_html(test_client_factory: TestClientFactory) -> None:
    """In debug mode, a request that accepts ``text/html`` gets an HTML error page."""

    async def failing_app(scope: Scope, receive: Receive, send: Send) -> None:
        raise RuntimeError("Something went wrong")

    client = test_client_factory(
        ServerErrorMiddleware(failing_app, debug=True),
        raise_server_exceptions=False,
    )
    html_accept = {"Accept": "text/html, */*"}

    response = client.get("/", headers=html_accept)

    assert response.status_code == 500
    assert response.headers["content-type"].startswith("text/html")
    assert "RuntimeError" in response.text


def test_debug_after_response_sent(test_client_factory: TestClientFactory) -> None:
    """An error raised after the response was sent propagates instead of being rendered.

    The app finishes sending a 204 response before raising, so no debug page
    can replace it. The test expects the app's own ``RuntimeError`` to surface
    from the client call.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        response = Response(b"", status_code=204)
        await response(scope, receive, send)
        raise RuntimeError("Something went wrong")

    app = ServerErrorMiddleware(app, debug=True)
    client = test_client_factory(app)

    # Pin the message so an unrelated RuntimeError cannot satisfy the check.
    with pytest.raises(RuntimeError, match="Something went wrong"):
        client.get("/")


def test_debug_not_http(test_client_factory: TestClientFactory) -> None:
    """
    ServerErrorMiddleware should pass non-HTTP (here: websocket) connections
    through to the wrapped app as-is, so the app's exception propagates.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        raise RuntimeError("Something went wrong")

    app = ServerErrorMiddleware(app)
    # Build the client outside ``pytest.raises`` so only the websocket
    # connection itself is expected to raise.
    client = test_client_factory(app)

    with pytest.raises(RuntimeError, match="Something went wrong"):
        with client.websocket_connect("/"):
            pass  # pragma: no cover


def test_background_task(test_client_factory: TestClientFactory) -> None:
    """A background task that raises still triggers the app's ``Exception`` handler.

    The endpoint returns a 204 with a failing background task. The test
    expects the client to receive the 204 and the registered handler to be
    called.
    """
    seen_errors: list[Exception] = []

    def error_handler(request: Request, exc: Exception) -> Any:
        # Record the call; the handler's return value is not inspected here.
        seen_errors.append(exc)

    def raise_exception() -> None:
        raise Exception("Something went wrong")

    async def endpoint(request: Request) -> Response:
        return Response(status_code=204, background=BackgroundTask(raise_exception))

    app = Starlette(
        routes=[Route("/", endpoint=endpoint)],
        exception_handlers={Exception: error_handler},
    )
    client = test_client_factory(app, raise_server_exceptions=False)

    response = client.get("/")

    assert response.status_code == 204
    assert seen_errors