Add on_error parameter to AuthenticationMiddleware (#281)

* Add on_error parameter to AuthenticationMiddleware to customise responses when auth fails

* Fine-tuning, type hints
This commit is contained in:
Pierre Vanliefland 2018-12-18 13:32:28 +01:00 committed by Tom Christie
parent 5fcf4945e6
commit 155c8dd625
3 changed files with 63 additions and 4 deletions

View File

@ -130,3 +130,15 @@ async def homepage(request):
async def dashboard(request): async def dashboard(request):
... ...
``` ```
## Custom authentication error responses
You can customise the error response sent when a `AuthenticationError` is
raised by an auth backend:
```python
def on_auth_error(request: Request, exc: Exception):
return JSONResponse({"error": str(exc)}, status_code=401)
app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend(), on_error=on_auth_error)
```

View File

@ -1,4 +1,5 @@
import functools import functools
import typing
from starlette.authentication import ( from starlette.authentication import (
AuthCredentials, AuthCredentials,
@ -7,14 +8,22 @@ from starlette.authentication import (
UnauthenticatedUser, UnauthenticatedUser,
) )
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import PlainTextResponse from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send
class AuthenticationMiddleware: class AuthenticationMiddleware:
def __init__(self, app: ASGIApp, backend: AuthenticationBackend) -> None: def __init__(
self,
app: ASGIApp,
backend: AuthenticationBackend,
on_error: typing.Callable[[Request, AuthenticationError], Response] = None,
) -> None:
self.app = app self.app = app
self.backend = backend self.backend = backend
self.on_error = (
on_error if on_error is not None else self.default_on_error
) # type: typing.Callable[[Request, AuthenticationError], Response]
def __call__(self, scope: Scope) -> ASGIInstance: def __call__(self, scope: Scope) -> ASGIInstance:
if scope["type"] in ["http", "websockets"]: if scope["type"] in ["http", "websockets"]:
@ -26,7 +35,7 @@ class AuthenticationMiddleware:
try: try:
auth_result = await self.backend.authenticate(request) auth_result = await self.backend.authenticate(request)
except AuthenticationError as exc: except AuthenticationError as exc:
response = PlainTextResponse(str(exc), status_code=400) response = self.on_error(request, exc)
await response(receive, send) await response(receive, send)
return return
@ -35,3 +44,7 @@ class AuthenticationMiddleware:
scope["auth"], scope["user"] = auth_result scope["auth"], scope["user"] = auth_result
inner = self.app(scope) inner = self.app(scope)
await inner(receive, send) await inner(receive, send)
@staticmethod
def default_on_error(request: Request, exc: Exception) -> Response:
return PlainTextResponse(str(exc), status_code=400)

View File

@ -7,10 +7,10 @@ from starlette.authentication import (
AuthenticationBackend, AuthenticationBackend,
AuthenticationError, AuthenticationError,
SimpleUser, SimpleUser,
UnauthenticatedUser,
requires, requires,
) )
from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from starlette.testclient import TestClient from starlette.testclient import TestClient
@ -138,3 +138,37 @@ def test_authentication_redirect():
response = client.get("/admin/sync", auth=("tomchristie", "example")) response = client.get("/admin/sync", auth=("tomchristie", "example"))
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"} assert response.json() == {"authenticated": True, "user": "tomchristie"}
def on_auth_error(request: Request, exc: Exception):
return JSONResponse({"error": str(exc)}, status_code=401)
other_app = Starlette()
other_app.add_middleware(
AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error
)
@other_app.route("/control-panel")
@requires("authenticated")
def control_panel(request):
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
"user": request.user.display_name,
}
)
def test_custom_on_error():
with TestClient(other_app) as client:
response = client.get("/control-panel", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
response = client.get(
"/control-panel", headers={"Authorization": "basic foobar"}
)
assert response.status_code == 401
assert response.json() == {"error": "Invalid basic auth credentials"}