mirror of https://github.com/encode/starlette.git
Add on_error parameter to AuthenticationMiddleware (#281)
* Add on_error parameter to AuthenticationMiddleware to customise responses when auth fails * Fine-tuning, type hints
This commit is contained in:
parent
5fcf4945e6
commit
155c8dd625
|
@ -130,3 +130,15 @@ async def homepage(request):
|
|||
async def dashboard(request):
|
||||
...
|
||||
```
|
||||
|
||||
## Custom authentication error responses
|
||||
|
||||
You can customise the error response sent when a `AuthenticationError` is
|
||||
raised by an auth backend:
|
||||
|
||||
```python
|
||||
def on_auth_error(request: Request, exc: Exception):
|
||||
return JSONResponse({"error": str(exc)}, status_code=401)
|
||||
|
||||
app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend(), on_error=on_auth_error)
|
||||
```
|
|
@ -1,4 +1,5 @@
|
|||
import functools
|
||||
import typing
|
||||
|
||||
from starlette.authentication import (
|
||||
AuthCredentials,
|
||||
|
@ -7,14 +8,22 @@ from starlette.authentication import (
|
|||
UnauthenticatedUser,
|
||||
)
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
def __init__(self, app: ASGIApp, backend: AuthenticationBackend) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
backend: AuthenticationBackend,
|
||||
on_error: typing.Callable[[Request, AuthenticationError], Response] = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.backend = backend
|
||||
self.on_error = (
|
||||
on_error if on_error is not None else self.default_on_error
|
||||
) # type: typing.Callable[[Request, AuthenticationError], Response]
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
if scope["type"] in ["http", "websockets"]:
|
||||
|
@ -26,7 +35,7 @@ class AuthenticationMiddleware:
|
|||
try:
|
||||
auth_result = await self.backend.authenticate(request)
|
||||
except AuthenticationError as exc:
|
||||
response = PlainTextResponse(str(exc), status_code=400)
|
||||
response = self.on_error(request, exc)
|
||||
await response(receive, send)
|
||||
return
|
||||
|
||||
|
@ -35,3 +44,7 @@ class AuthenticationMiddleware:
|
|||
scope["auth"], scope["user"] = auth_result
|
||||
inner = self.app(scope)
|
||||
await inner(receive, send)
|
||||
|
||||
@staticmethod
|
||||
def default_on_error(request: Request, exc: Exception) -> Response:
|
||||
return PlainTextResponse(str(exc), status_code=400)
|
||||
|
|
|
@ -7,10 +7,10 @@ from starlette.authentication import (
|
|||
AuthenticationBackend,
|
||||
AuthenticationError,
|
||||
SimpleUser,
|
||||
UnauthenticatedUser,
|
||||
requires,
|
||||
)
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
|
@ -138,3 +138,37 @@ def test_authentication_redirect():
|
|||
response = client.get("/admin/sync", auth=("tomchristie", "example"))
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"authenticated": True, "user": "tomchristie"}
|
||||
|
||||
|
||||
def on_auth_error(request: Request, exc: Exception):
|
||||
return JSONResponse({"error": str(exc)}, status_code=401)
|
||||
|
||||
|
||||
other_app = Starlette()
|
||||
other_app.add_middleware(
|
||||
AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error
|
||||
)
|
||||
|
||||
|
||||
@other_app.route("/control-panel")
|
||||
@requires("authenticated")
|
||||
def control_panel(request):
|
||||
return JSONResponse(
|
||||
{
|
||||
"authenticated": request.user.is_authenticated,
|
||||
"user": request.user.display_name,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_custom_on_error():
|
||||
with TestClient(other_app) as client:
|
||||
response = client.get("/control-panel", auth=("tomchristie", "example"))
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"authenticated": True, "user": "tomchristie"}
|
||||
|
||||
response = client.get(
|
||||
"/control-panel", headers={"Authorization": "basic foobar"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {"error": "Invalid basic auth credentials"}
|
||||
|
|
Loading…
Reference in New Issue