WebSocket support for auth (#347)

This commit is contained in:
Tom Christie 2019-01-25 14:00:35 +00:00 committed by GitHub
parent 88bd2a2ce9
commit 16db18f519
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 77 additions and 58 deletions

View File

@ -4,13 +4,14 @@ import inspect
import typing
from starlette.exceptions import HTTPException
from starlette.requests import Request, HTTPConnection
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
from starlette.websockets import WebSocket
def has_required_scope(request: Request, scopes: typing.Sequence[str]) -> bool:
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
for scope in scopes:
if scope not in request.auth.scopes:
if scope not in conn.auth.scopes:
return False
return True
@ -23,17 +24,39 @@ def requires(
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
def decorator(func: typing.Callable) -> typing.Callable:
type = None
sig = inspect.signature(func)
for idx, parameter in enumerate(sig.parameters.values()):
if parameter.name == "request":
if parameter.name == "request" or parameter.name == "websocket":
type = parameter.name
break
else:
raise Exception(f'No "request" argument on function "{func}"')
if asyncio.iscoroutinefunction(func):
raise Exception(
f'No "request" or "websocket" argument on function "{func}"'
)
if type == "websocket":
# Handle websocket functions. (Always async)
@functools.wraps(func)
async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
async def websocket_wrapper(
*args: typing.Any, **kwargs: typing.Any
) -> None:
websocket = kwargs.get("websocket", args[idx])
assert isinstance(websocket, WebSocket)
if not has_required_scope(websocket, scopes_list):
await websocket.close()
else:
await func(*args, **kwargs)
return websocket_wrapper
elif asyncio.iscoroutinefunction(func):
# Handle async request/response functions.
@functools.wraps(func)
async def async_wrapper(
*args: typing.Any, **kwargs: typing.Any
) -> Response:
request = kwargs.get("request", args[idx])
assert isinstance(request, Request)
@ -43,21 +66,22 @@ def requires(
raise HTTPException(status_code=status_code)
return await func(*args, **kwargs)
return wrapper
return async_wrapper
@functools.wraps(func)
def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
# Support either `func(request)` or `func(self, request)`
request = kwargs.get("request", args[idx])
assert isinstance(request, Request)
else:
# Handle sync request/response functions.
@functools.wraps(func)
def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
request = kwargs.get("request", args[idx])
assert isinstance(request, Request)
if not has_required_scope(request, scopes_list):
if redirect is not None:
return RedirectResponse(url=request.url_for(redirect))
raise HTTPException(status_code=status_code)
return func(*args, **kwargs)
if not has_required_scope(request, scopes_list):
if redirect is not None:
return RedirectResponse(url=request.url_for(redirect))
raise HTTPException(status_code=status_code)
return func(*args, **kwargs)
return sync_wrapper
return sync_wrapper
return decorator

View File

@ -33,12 +33,15 @@ class AuthenticationMiddleware:
return self.app(scope)
async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None:
conn = HTTPConnection(scope, receive=receive)
conn = HTTPConnection(scope)
try:
auth_result = await self.backend.authenticate(conn)
except AuthenticationError as exc:
response = self.on_error(conn, exc)
await response(receive, send)
if scope["type"] == "websocket":
await send({"type": "websocket.close", "code": 1000})
else:
await response(receive, send)
return
if auth_result is None:

View File

@ -27,42 +27,6 @@ class WebSocket(HTTPConnection):
self.client_state = WebSocketState.CONNECTING
self.application_state = WebSocketState.CONNECTING
# def __getitem__(self, key: str) -> str:
# return self._scope[key]
#
# def __iter__(self) -> typing.Iterator:
# return iter(self._scope)
#
# def __len__(self) -> int:
# return len(self._scope)
#
# @property
# def url(self) -> URL:
# if not hasattr(self, "_url"):
# self._url = URL(scope=self._scope)
# return self._url
#
# @property
# def headers(self) -> Headers:
# if not hasattr(self, "_headers"):
# self._headers = Headers(scope=self._scope)
# return self._headers
#
# @property
# def query_params(self) -> QueryParams:
# if not hasattr(self, "_query_params"):
# self._query_params = QueryParams(scope=self._scope)
# return self._query_params
#
# @property
# def path_params(self) -> dict:
# return self._scope.get("path_params", {})
#
# def url_for(self, name: str, **path_params: typing.Any) -> str:
# router = self._scope["router"]
# url_path = router.url_path_for(name, **path_params)
# return url_path.make_absolute_url(base_url=self.url)
async def receive(self) -> Message:
"""
Receive ASGI websocket messages, ensuring valid state transitions.

View File

@ -16,6 +16,7 @@ from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
class BasicAuth(AuthenticationBackend):
@ -104,6 +105,18 @@ def admin(request):
)
@app.websocket_route("/ws")
@requires("authenticated")
async def websocket_endpoint(websocket):
await websocket.accept()
await websocket.send_json(
{
"authenticated": websocket.user.is_authenticated,
"user": websocket.user.display_name,
}
)
def test_invalid_decorator_usage():
with pytest.raises(Exception):
@ -151,6 +164,21 @@ def test_authentication_required():
assert response.text == "Invalid basic auth credentials"
def test_websocket_authentication_required():
with TestClient(app) as client:
with pytest.raises(WebSocketDisconnect):
client.websocket_connect("/ws")
with pytest.raises(WebSocketDisconnect):
client.websocket_connect("/ws", headers={"Authorization": "basic foobar"})
with client.websocket_connect(
"/ws", auth=("tomchristie", "example")
) as websocket:
data = websocket.receive_json()
assert data == {"authenticated": True, "user": "tomchristie"}
def test_authentication_redirect():
with TestClient(app) as client:
response = client.get("/admin")