mirror of https://github.com/encode/starlette.git
Add on_error parameter to AuthenticationMiddleware (#281)
* Add on_error parameter to AuthenticationMiddleware to customise responses when auth fails * Fine-tuning, type hints
This commit is contained in:
parent
5fcf4945e6
commit
155c8dd625
|
@ -130,3 +130,15 @@ async def homepage(request):
|
||||||
async def dashboard(request):
|
async def dashboard(request):
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Custom authentication error responses
|
||||||
|
|
||||||
|
You can customise the error response sent when a `AuthenticationError` is
|
||||||
|
raised by an auth backend:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def on_auth_error(request: Request, exc: Exception):
|
||||||
|
return JSONResponse({"error": str(exc)}, status_code=401)
|
||||||
|
|
||||||
|
app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend(), on_error=on_auth_error)
|
||||||
|
```
|
|
@ -1,4 +1,5 @@
|
||||||
import functools
|
import functools
|
||||||
|
import typing
|
||||||
|
|
||||||
from starlette.authentication import (
|
from starlette.authentication import (
|
||||||
AuthCredentials,
|
AuthCredentials,
|
||||||
|
@ -7,14 +8,22 @@ from starlette.authentication import (
|
||||||
UnauthenticatedUser,
|
UnauthenticatedUser,
|
||||||
)
|
)
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import PlainTextResponse
|
from starlette.responses import PlainTextResponse, Response
|
||||||
from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send
|
from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationMiddleware:
|
class AuthenticationMiddleware:
|
||||||
def __init__(self, app: ASGIApp, backend: AuthenticationBackend) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
app: ASGIApp,
|
||||||
|
backend: AuthenticationBackend,
|
||||||
|
on_error: typing.Callable[[Request, AuthenticationError], Response] = None,
|
||||||
|
) -> None:
|
||||||
self.app = app
|
self.app = app
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
|
self.on_error = (
|
||||||
|
on_error if on_error is not None else self.default_on_error
|
||||||
|
) # type: typing.Callable[[Request, AuthenticationError], Response]
|
||||||
|
|
||||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||||
if scope["type"] in ["http", "websockets"]:
|
if scope["type"] in ["http", "websockets"]:
|
||||||
|
@ -26,7 +35,7 @@ class AuthenticationMiddleware:
|
||||||
try:
|
try:
|
||||||
auth_result = await self.backend.authenticate(request)
|
auth_result = await self.backend.authenticate(request)
|
||||||
except AuthenticationError as exc:
|
except AuthenticationError as exc:
|
||||||
response = PlainTextResponse(str(exc), status_code=400)
|
response = self.on_error(request, exc)
|
||||||
await response(receive, send)
|
await response(receive, send)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -35,3 +44,7 @@ class AuthenticationMiddleware:
|
||||||
scope["auth"], scope["user"] = auth_result
|
scope["auth"], scope["user"] = auth_result
|
||||||
inner = self.app(scope)
|
inner = self.app(scope)
|
||||||
await inner(receive, send)
|
await inner(receive, send)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def default_on_error(request: Request, exc: Exception) -> Response:
|
||||||
|
return PlainTextResponse(str(exc), status_code=400)
|
||||||
|
|
|
@ -7,10 +7,10 @@ from starlette.authentication import (
|
||||||
AuthenticationBackend,
|
AuthenticationBackend,
|
||||||
AuthenticationError,
|
AuthenticationError,
|
||||||
SimpleUser,
|
SimpleUser,
|
||||||
UnauthenticatedUser,
|
|
||||||
requires,
|
requires,
|
||||||
)
|
)
|
||||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
@ -138,3 +138,37 @@ def test_authentication_redirect():
|
||||||
response = client.get("/admin/sync", auth=("tomchristie", "example"))
|
response = client.get("/admin/sync", auth=("tomchristie", "example"))
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"authenticated": True, "user": "tomchristie"}
|
assert response.json() == {"authenticated": True, "user": "tomchristie"}
|
||||||
|
|
||||||
|
|
||||||
|
def on_auth_error(request: Request, exc: Exception):
|
||||||
|
return JSONResponse({"error": str(exc)}, status_code=401)
|
||||||
|
|
||||||
|
|
||||||
|
other_app = Starlette()
|
||||||
|
other_app.add_middleware(
|
||||||
|
AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@other_app.route("/control-panel")
|
||||||
|
@requires("authenticated")
|
||||||
|
def control_panel(request):
|
||||||
|
return JSONResponse(
|
||||||
|
{
|
||||||
|
"authenticated": request.user.is_authenticated,
|
||||||
|
"user": request.user.display_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_on_error():
|
||||||
|
with TestClient(other_app) as client:
|
||||||
|
response = client.get("/control-panel", auth=("tomchristie", "example"))
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"authenticated": True, "user": "tomchristie"}
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/control-panel", headers={"Authorization": "basic foobar"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert response.json() == {"error": "Invalid basic auth credentials"}
|
||||||
|
|
Loading…
Reference in New Issue