mirror of https://github.com/encode/starlette.git
WebSocket support for auth (#347)
This commit is contained in:
parent
88bd2a2ce9
commit
16db18f519
|
@ -4,13 +4,14 @@ import inspect
|
|||
import typing
|
||||
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request, HTTPConnection
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import RedirectResponse, Response
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
def has_required_scope(request: Request, scopes: typing.Sequence[str]) -> bool:
|
||||
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
|
||||
for scope in scopes:
|
||||
if scope not in request.auth.scopes:
|
||||
if scope not in conn.auth.scopes:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
@ -23,17 +24,39 @@ def requires(
|
|||
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
type = None
|
||||
sig = inspect.signature(func)
|
||||
for idx, parameter in enumerate(sig.parameters.values()):
|
||||
if parameter.name == "request":
|
||||
if parameter.name == "request" or parameter.name == "websocket":
|
||||
type = parameter.name
|
||||
break
|
||||
else:
|
||||
raise Exception(f'No "request" argument on function "{func}"')
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
raise Exception(
|
||||
f'No "request" or "websocket" argument on function "{func}"'
|
||||
)
|
||||
|
||||
if type == "websocket":
|
||||
# Handle websocket functions. (Always async)
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
|
||||
async def websocket_wrapper(
|
||||
*args: typing.Any, **kwargs: typing.Any
|
||||
) -> None:
|
||||
websocket = kwargs.get("websocket", args[idx])
|
||||
assert isinstance(websocket, WebSocket)
|
||||
|
||||
if not has_required_scope(websocket, scopes_list):
|
||||
await websocket.close()
|
||||
else:
|
||||
await func(*args, **kwargs)
|
||||
|
||||
return websocket_wrapper
|
||||
|
||||
elif asyncio.iscoroutinefunction(func):
|
||||
# Handle async request/response functions.
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(
|
||||
*args: typing.Any, **kwargs: typing.Any
|
||||
) -> Response:
|
||||
request = kwargs.get("request", args[idx])
|
||||
assert isinstance(request, Request)
|
||||
|
||||
|
@ -43,21 +66,22 @@ def requires(
|
|||
raise HTTPException(status_code=status_code)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return async_wrapper
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
|
||||
# Support either `func(request)` or `func(self, request)`
|
||||
request = kwargs.get("request", args[idx])
|
||||
assert isinstance(request, Request)
|
||||
else:
|
||||
# Handle sync request/response functions.
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
|
||||
request = kwargs.get("request", args[idx])
|
||||
assert isinstance(request, Request)
|
||||
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
return RedirectResponse(url=request.url_for(redirect))
|
||||
raise HTTPException(status_code=status_code)
|
||||
return func(*args, **kwargs)
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
return RedirectResponse(url=request.url_for(redirect))
|
||||
raise HTTPException(status_code=status_code)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return sync_wrapper
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
|
|
@ -33,12 +33,15 @@ class AuthenticationMiddleware:
|
|||
return self.app(scope)
|
||||
|
||||
async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None:
|
||||
conn = HTTPConnection(scope, receive=receive)
|
||||
conn = HTTPConnection(scope)
|
||||
try:
|
||||
auth_result = await self.backend.authenticate(conn)
|
||||
except AuthenticationError as exc:
|
||||
response = self.on_error(conn, exc)
|
||||
await response(receive, send)
|
||||
if scope["type"] == "websocket":
|
||||
await send({"type": "websocket.close", "code": 1000})
|
||||
else:
|
||||
await response(receive, send)
|
||||
return
|
||||
|
||||
if auth_result is None:
|
||||
|
|
|
@ -27,42 +27,6 @@ class WebSocket(HTTPConnection):
|
|||
self.client_state = WebSocketState.CONNECTING
|
||||
self.application_state = WebSocketState.CONNECTING
|
||||
|
||||
# def __getitem__(self, key: str) -> str:
|
||||
# return self._scope[key]
|
||||
#
|
||||
# def __iter__(self) -> typing.Iterator:
|
||||
# return iter(self._scope)
|
||||
#
|
||||
# def __len__(self) -> int:
|
||||
# return len(self._scope)
|
||||
#
|
||||
# @property
|
||||
# def url(self) -> URL:
|
||||
# if not hasattr(self, "_url"):
|
||||
# self._url = URL(scope=self._scope)
|
||||
# return self._url
|
||||
#
|
||||
# @property
|
||||
# def headers(self) -> Headers:
|
||||
# if not hasattr(self, "_headers"):
|
||||
# self._headers = Headers(scope=self._scope)
|
||||
# return self._headers
|
||||
#
|
||||
# @property
|
||||
# def query_params(self) -> QueryParams:
|
||||
# if not hasattr(self, "_query_params"):
|
||||
# self._query_params = QueryParams(scope=self._scope)
|
||||
# return self._query_params
|
||||
#
|
||||
# @property
|
||||
# def path_params(self) -> dict:
|
||||
# return self._scope.get("path_params", {})
|
||||
#
|
||||
# def url_for(self, name: str, **path_params: typing.Any) -> str:
|
||||
# router = self._scope["router"]
|
||||
# url_path = router.url_path_for(name, **path_params)
|
||||
# return url_path.make_absolute_url(base_url=self.url)
|
||||
|
||||
async def receive(self) -> Message:
|
||||
"""
|
||||
Receive ASGI websocket messages, ensuring valid state transitions.
|
||||
|
|
|
@ -16,6 +16,7 @@ from starlette.middleware.authentication import AuthenticationMiddleware
|
|||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
|
||||
class BasicAuth(AuthenticationBackend):
|
||||
|
@ -104,6 +105,18 @@ def admin(request):
|
|||
)
|
||||
|
||||
|
||||
@app.websocket_route("/ws")
|
||||
@requires("authenticated")
|
||||
async def websocket_endpoint(websocket):
|
||||
await websocket.accept()
|
||||
await websocket.send_json(
|
||||
{
|
||||
"authenticated": websocket.user.is_authenticated,
|
||||
"user": websocket.user.display_name,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_decorator_usage():
|
||||
with pytest.raises(Exception):
|
||||
|
||||
|
@ -151,6 +164,21 @@ def test_authentication_required():
|
|||
assert response.text == "Invalid basic auth credentials"
|
||||
|
||||
|
||||
def test_websocket_authentication_required():
|
||||
with TestClient(app) as client:
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
client.websocket_connect("/ws")
|
||||
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
client.websocket_connect("/ws", headers={"Authorization": "basic foobar"})
|
||||
|
||||
with client.websocket_connect(
|
||||
"/ws", auth=("tomchristie", "example")
|
||||
) as websocket:
|
||||
data = websocket.receive_json()
|
||||
assert data == {"authenticated": True, "user": "tomchristie"}
|
||||
|
||||
|
||||
def test_authentication_redirect():
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/admin")
|
||||
|
|
Loading…
Reference in New Issue