Add on_error parameter to AuthenticationMiddleware (#281)

* Add on_error parameter to AuthenticationMiddleware to customise responses when auth fails

* Fine-tuning, type hints
This commit is contained in:
Pierre Vanliefland 2018-12-18 13:32:28 +01:00 committed by Tom Christie
parent 5fcf4945e6
commit 155c8dd625
3 changed files with 63 additions and 4 deletions

View File

@ -130,3 +130,15 @@ async def homepage(request):
async def dashboard(request):
...
```
## Custom authentication error responses
You can customise the error response sent when a `AuthenticationError` is
raised by an auth backend:
```python
def on_auth_error(request: Request, exc: Exception):
return JSONResponse({"error": str(exc)}, status_code=401)
app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend(), on_error=on_auth_error)
```

View File

@ -1,4 +1,5 @@
import functools
import typing
from starlette.authentication import (
AuthCredentials,
@ -7,14 +8,22 @@ from starlette.authentication import (
UnauthenticatedUser,
)
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send
class AuthenticationMiddleware:
def __init__(self, app: ASGIApp, backend: AuthenticationBackend) -> None:
def __init__(
self,
app: ASGIApp,
backend: AuthenticationBackend,
on_error: typing.Callable[[Request, AuthenticationError], Response] = None,
) -> None:
self.app = app
self.backend = backend
self.on_error = (
on_error if on_error is not None else self.default_on_error
) # type: typing.Callable[[Request, AuthenticationError], Response]
def __call__(self, scope: Scope) -> ASGIInstance:
if scope["type"] in ["http", "websockets"]:
@ -26,7 +35,7 @@ class AuthenticationMiddleware:
try:
auth_result = await self.backend.authenticate(request)
except AuthenticationError as exc:
response = PlainTextResponse(str(exc), status_code=400)
response = self.on_error(request, exc)
await response(receive, send)
return
@ -35,3 +44,7 @@ class AuthenticationMiddleware:
scope["auth"], scope["user"] = auth_result
inner = self.app(scope)
await inner(receive, send)
@staticmethod
def default_on_error(request: Request, exc: Exception) -> Response:
return PlainTextResponse(str(exc), status_code=400)

View File

@ -7,10 +7,10 @@ from starlette.authentication import (
AuthenticationBackend,
AuthenticationError,
SimpleUser,
UnauthenticatedUser,
requires,
)
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.testclient import TestClient
@ -138,3 +138,37 @@ def test_authentication_redirect():
response = client.get("/admin/sync", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
def on_auth_error(request: Request, exc: Exception):
return JSONResponse({"error": str(exc)}, status_code=401)
other_app = Starlette()
other_app.add_middleware(
AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error
)
@other_app.route("/control-panel")
@requires("authenticated")
def control_panel(request):
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
"user": request.user.display_name,
}
)
def test_custom_on_error():
with TestClient(other_app) as client:
response = client.get("/control-panel", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
response = client.get(
"/control-panel", headers={"Authorization": "basic foobar"}
)
assert response.status_code == 401
assert response.json() == {"error": "Invalid basic auth credentials"}