starlette/tests/test_authentication.py

361 lines
12 KiB
Python

from __future__ import annotations
import base64
import binascii
from typing import Any, Awaitable, Callable
from urllib.parse import urlencode
import pytest
from starlette.applications import Starlette
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
AuthenticationError,
SimpleUser,
requires,
)
from starlette.endpoints import HTTPEndpoint
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import HTTPConnection, Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route, WebSocketRoute
from starlette.websockets import WebSocket, WebSocketDisconnect
from tests.types import TestClientFactory
AsyncEndpoint = Callable[..., Awaitable[Response]]
SyncEndpoint = Callable[..., Response]
class BasicAuth(AuthenticationBackend):
async def authenticate(
self,
request: HTTPConnection,
) -> tuple[AuthCredentials, SimpleUser] | None:
if "Authorization" not in request.headers:
return None
auth = request.headers["Authorization"]
try:
scheme, credentials = auth.split()
decoded = base64.b64decode(credentials).decode("ascii")
except (ValueError, UnicodeDecodeError, binascii.Error):
raise AuthenticationError("Invalid basic auth credentials")
username, _, password = decoded.partition(":")
return AuthCredentials(["authenticated"]), SimpleUser(username)
def homepage(request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
"user": request.user.display_name,
}
)
@requires("authenticated")
async def dashboard(request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
"user": request.user.display_name,
}
)
@requires("authenticated", redirect="homepage")
async def admin(request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
"user": request.user.display_name,
}
)
@requires("authenticated")
def dashboard_sync(request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
"user": request.user.display_name,
}
)
class Dashboard(HTTPEndpoint):
@requires("authenticated")
def get(self, request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
"user": request.user.display_name,
}
)
@requires("authenticated", redirect="homepage")
def admin_sync(request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
"user": request.user.display_name,
}
)
@requires("authenticated")
async def websocket_endpoint(websocket: WebSocket) -> None:
await websocket.accept()
await websocket.send_json(
{
"authenticated": websocket.user.is_authenticated,
"user": websocket.user.display_name,
}
)
def async_inject_decorator(
**kwargs: Any,
) -> Callable[[AsyncEndpoint], Callable[..., Awaitable[Response]]]:
def wrapper(endpoint: AsyncEndpoint) -> Callable[..., Awaitable[Response]]:
async def app(request: Request) -> Response:
return await endpoint(request=request, **kwargs)
return app
return wrapper
@async_inject_decorator(additional="payload")
@requires("authenticated")
async def decorated_async(request: Request, additional: str) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
"user": request.user.display_name,
"additional": additional,
}
)
def sync_inject_decorator(
**kwargs: Any,
) -> Callable[[SyncEndpoint], Callable[..., Response]]:
def wrapper(endpoint: SyncEndpoint) -> Callable[..., Response]:
def app(request: Request) -> Response:
return endpoint(request=request, **kwargs)
return app
return wrapper
@sync_inject_decorator(additional="payload")
@requires("authenticated")
def decorated_sync(request: Request, additional: str) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
"user": request.user.display_name,
"additional": additional,
}
)
def ws_inject_decorator(**kwargs: Any) -> Callable[..., AsyncEndpoint]:
def wrapper(endpoint: AsyncEndpoint) -> AsyncEndpoint:
def app(websocket: WebSocket) -> Awaitable[Response]:
return endpoint(websocket=websocket, **kwargs)
return app
return wrapper
@ws_inject_decorator(additional="payload")
@requires("authenticated")
async def websocket_endpoint_decorated(websocket: WebSocket, additional: str) -> None:
await websocket.accept()
await websocket.send_json(
{
"authenticated": websocket.user.is_authenticated,
"user": websocket.user.display_name,
"additional": additional,
}
)
app = Starlette(
middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuth())],
routes=[
Route("/", endpoint=homepage),
Route("/dashboard", endpoint=dashboard),
Route("/admin", endpoint=admin),
Route("/dashboard/sync", endpoint=dashboard_sync),
Route("/dashboard/class", endpoint=Dashboard),
Route("/admin/sync", endpoint=admin_sync),
Route("/dashboard/decorated", endpoint=decorated_async),
Route("/dashboard/decorated/sync", endpoint=decorated_sync),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
WebSocketRoute("/ws/decorated", endpoint=websocket_endpoint_decorated),
],
)
def test_invalid_decorator_usage() -> None:
with pytest.raises(Exception):
@requires("authenticated")
def foo() -> None:
pass # pragma: no cover
def test_user_interface(test_client_factory: TestClientFactory) -> None:
with test_client_factory(app) as client:
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"authenticated": False, "user": ""}
response = client.get("/", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
def test_authentication_required(test_client_factory: TestClientFactory) -> None:
with test_client_factory(app) as client:
response = client.get("/dashboard")
assert response.status_code == 403
response = client.get("/dashboard", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
response = client.get("/dashboard/sync")
assert response.status_code == 403
response = client.get("/dashboard/sync", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
response = client.get("/dashboard/class")
assert response.status_code == 403
response = client.get("/dashboard/class", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
response = client.get("/dashboard/decorated", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {
"authenticated": True,
"user": "tomchristie",
"additional": "payload",
}
response = client.get("/dashboard/decorated")
assert response.status_code == 403
response = client.get("/dashboard/decorated/sync", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {
"authenticated": True,
"user": "tomchristie",
"additional": "payload",
}
response = client.get("/dashboard/decorated/sync")
assert response.status_code == 403
response = client.get("/dashboard", headers={"Authorization": "basic foobar"})
assert response.status_code == 400
assert response.text == "Invalid basic auth credentials"
def test_websocket_authentication_required(
test_client_factory: TestClientFactory,
) -> None:
with test_client_factory(app) as client:
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/ws"):
pass # pragma: no cover
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/ws", headers={"Authorization": "basic foobar"}):
pass # pragma: no cover
with client.websocket_connect("/ws", auth=("tomchristie", "example")) as websocket:
data = websocket.receive_json()
assert data == {"authenticated": True, "user": "tomchristie"}
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/ws/decorated"):
pass # pragma: no cover
with pytest.raises(WebSocketDisconnect):
with client.websocket_connect("/ws/decorated", headers={"Authorization": "basic foobar"}):
pass # pragma: no cover
with client.websocket_connect("/ws/decorated", auth=("tomchristie", "example")) as websocket:
data = websocket.receive_json()
assert data == {
"authenticated": True,
"user": "tomchristie",
"additional": "payload",
}
def test_authentication_redirect(test_client_factory: TestClientFactory) -> None:
with test_client_factory(app) as client:
response = client.get("/admin")
assert response.status_code == 200
url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin"}))
assert response.url == url
response = client.get("/admin", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
response = client.get("/admin/sync")
assert response.status_code == 200
url = "{}?{}".format("http://testserver/", urlencode({"next": "http://testserver/admin/sync"}))
assert response.url == url
response = client.get("/admin/sync", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
def on_auth_error(request: HTTPConnection, exc: AuthenticationError) -> JSONResponse:
return JSONResponse({"error": str(exc)}, status_code=401)
@requires("authenticated")
def control_panel(request: Request) -> JSONResponse:
return JSONResponse(
{
"authenticated": request.user.is_authenticated,
"user": request.user.display_name,
}
)
other_app = Starlette(
routes=[Route("/control-panel", control_panel)],
middleware=[Middleware(AuthenticationMiddleware, backend=BasicAuth(), on_error=on_auth_error)],
)
def test_custom_on_error(test_client_factory: TestClientFactory) -> None:
with test_client_factory(other_app) as client:
response = client.get("/control-panel", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
response = client.get("/control-panel", headers={"Authorization": "basic foobar"})
assert response.status_code == 401
assert response.json() == {"error": "Invalid basic auth credentials"}