diff --git a/docs/authentication.md b/docs/authentication.md index c1b035c8..7d3f293a 100644 --- a/docs/authentication.md +++ b/docs/authentication.md @@ -130,3 +130,15 @@ async def homepage(request): async def dashboard(request): ... ``` + +## Custom authentication error responses + +You can customise the error response sent when a `AuthenticationError` is +raised by an auth backend: + +```python +def on_auth_error(request: Request, exc: Exception): + return JSONResponse({"error": str(exc)}, status_code=401) + +app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend(), on_error=on_auth_error) +``` \ No newline at end of file diff --git a/starlette/middleware/authentication.py b/starlette/middleware/authentication.py index b9d90e8b..62b09ec6 100644 --- a/starlette/middleware/authentication.py +++ b/starlette/middleware/authentication.py @@ -1,4 +1,5 @@ import functools +import typing from starlette.authentication import ( AuthCredentials, @@ -7,14 +8,22 @@ from starlette.authentication import ( UnauthenticatedUser, ) from starlette.requests import Request -from starlette.responses import PlainTextResponse +from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send class AuthenticationMiddleware: - def __init__(self, app: ASGIApp, backend: AuthenticationBackend) -> None: + def __init__( + self, + app: ASGIApp, + backend: AuthenticationBackend, + on_error: typing.Callable[[Request, AuthenticationError], Response] = None, + ) -> None: self.app = app self.backend = backend + self.on_error = ( + on_error if on_error is not None else self.default_on_error + ) # type: typing.Callable[[Request, AuthenticationError], Response] def __call__(self, scope: Scope) -> ASGIInstance: if scope["type"] in ["http", "websockets"]: @@ -26,7 +35,7 @@ class AuthenticationMiddleware: try: auth_result = await self.backend.authenticate(request) except AuthenticationError as exc: - response = PlainTextResponse(str(exc), status_code=400) + response = self.on_error(request, exc) await response(receive, send) return @@ -35,3 +44,7 @@ class AuthenticationMiddleware: scope["auth"], scope["user"] = auth_result inner = self.app(scope) await inner(receive, send) + + @staticmethod + def default_on_error(request: Request, exc: Exception) -> Response: + return PlainTextResponse(str(exc), status_code=400) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 9341c71f..da22fc88 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -7,10 +7,10 @@ from starlette.authentication import ( AuthenticationBackend, AuthenticationError, SimpleUser, - UnauthenticatedUser, requires, ) from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.requests import Request from starlette.responses import JSONResponse from starlette.testclient import TestClient @@ -138,3 +138,37 @@ def test_authentication_redirect(): response = client.get("/admin/sync", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} + + +def on_auth_error(request: Request, exc: Exception): + return JSONResponse({"error": str(exc)}, status_code=401) + + +other_app = Starlette() +other_app.add_middleware( + AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error +) + + +@other_app.route("/control-panel") +@requires("authenticated") +def control_panel(request): + return JSONResponse( + { + "authenticated": request.user.is_authenticated, + "user": request.user.display_name, + } + ) + + +def test_custom_on_error(): + with TestClient(other_app) as client: + response = client.get("/control-panel", auth=("tomchristie", "example")) + assert response.status_code == 200 + assert response.json() == {"authenticated": True, "user": "tomchristie"} + + response = client.get( + "/control-panel", headers={"Authorization": "basic foobar"} + ) + assert response.status_code == 401 + assert response.json() == {"error": "Invalid basic auth credentials"}