diff --git a/starlette/authentication.py b/starlette/authentication.py index 5a768368..c1726773 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -4,13 +4,14 @@ import inspect import typing from starlette.exceptions import HTTPException -from starlette.requests import Request, HTTPConnection +from starlette.requests import HTTPConnection, Request from starlette.responses import RedirectResponse, Response +from starlette.websockets import WebSocket -def has_required_scope(request: Request, scopes: typing.Sequence[str]) -> bool: +def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool: for scope in scopes: - if scope not in request.auth.scopes: + if scope not in conn.auth.scopes: return False return True @@ -23,17 +24,39 @@ def requires( scopes_list = [scopes] if isinstance(scopes, str) else list(scopes) def decorator(func: typing.Callable) -> typing.Callable: + type = None sig = inspect.signature(func) for idx, parameter in enumerate(sig.parameters.values()): - if parameter.name == "request": + if parameter.name == "request" or parameter.name == "websocket": + type = parameter.name break else: - raise Exception(f'No "request" argument on function "{func}"') - - if asyncio.iscoroutinefunction(func): + raise Exception( + f'No "request" or "websocket" argument on function "{func}"' + ) + if type == "websocket": + # Handle websocket functions. (Always async) @functools.wraps(func) - async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: + async def websocket_wrapper( + *args: typing.Any, **kwargs: typing.Any + ) -> None: + websocket = kwargs.get("websocket", args[idx]) + assert isinstance(websocket, WebSocket) + + if not has_required_scope(websocket, scopes_list): + await websocket.close() + else: + await func(*args, **kwargs) + + return websocket_wrapper + + elif asyncio.iscoroutinefunction(func): + # Handle async request/response functions. + @functools.wraps(func) + async def async_wrapper( + *args: typing.Any, **kwargs: typing.Any + ) -> Response: request = kwargs.get("request", args[idx]) assert isinstance(request, Request) @@ -43,21 +66,22 @@ def requires( raise HTTPException(status_code=status_code) return await func(*args, **kwargs) - return wrapper + return async_wrapper - @functools.wraps(func) - def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: - # Support either `func(request)` or `func(self, request)` - request = kwargs.get("request", args[idx]) - assert isinstance(request, Request) + else: + # Handle sync request/response functions. + @functools.wraps(func) + def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: + request = kwargs.get("request", args[idx]) + assert isinstance(request, Request) - if not has_required_scope(request, scopes_list): - if redirect is not None: - return RedirectResponse(url=request.url_for(redirect)) - raise HTTPException(status_code=status_code) - return func(*args, **kwargs) + if not has_required_scope(request, scopes_list): + if redirect is not None: + return RedirectResponse(url=request.url_for(redirect)) + raise HTTPException(status_code=status_code) + return func(*args, **kwargs) - return sync_wrapper + return sync_wrapper return decorator diff --git a/starlette/middleware/authentication.py b/starlette/middleware/authentication.py index 6f8d95cd..b9bceb13 100644 --- a/starlette/middleware/authentication.py +++ b/starlette/middleware/authentication.py @@ -33,12 +33,15 @@ class AuthenticationMiddleware: return self.app(scope) async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None: - conn = HTTPConnection(scope, receive=receive) + conn = HTTPConnection(scope) try: auth_result = await self.backend.authenticate(conn) except AuthenticationError as exc: response = self.on_error(conn, exc) - await response(receive, send) + if scope["type"] == "websocket": + await send({"type": "websocket.close", "code": 1000}) + else: + await response(receive, send) return if auth_result is None: diff --git a/starlette/websockets.py b/starlette/websockets.py index 6240f98c..72f22d1f 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -27,42 +27,6 @@ class WebSocket(HTTPConnection): self.client_state = WebSocketState.CONNECTING self.application_state = WebSocketState.CONNECTING - # def __getitem__(self, key: str) -> str: - # return self._scope[key] - # - # def __iter__(self) -> typing.Iterator: - # return iter(self._scope) - # - # def __len__(self) -> int: - # return len(self._scope) - # - # @property - # def url(self) -> URL: - # if not hasattr(self, "_url"): - # self._url = URL(scope=self._scope) - # return self._url - # - # @property - # def headers(self) -> Headers: - # if not hasattr(self, "_headers"): - # self._headers = Headers(scope=self._scope) - # return self._headers - # - # @property - # def query_params(self) -> QueryParams: - # if not hasattr(self, "_query_params"): - # self._query_params = QueryParams(scope=self._scope) - # return self._query_params - # - # @property - # def path_params(self) -> dict: - # return self._scope.get("path_params", {}) - # - # def url_for(self, name: str, **path_params: typing.Any) -> str: - # router = self._scope["router"] - # url_path = router.url_path_for(name, **path_params) - # return url_path.make_absolute_url(base_url=self.url) - async def receive(self) -> Message: """ Receive ASGI websocket messages, ensuring valid state transitions. diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 038a0978..d83e141e 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -16,6 +16,7 @@ from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.responses import JSONResponse from starlette.testclient import TestClient +from starlette.websockets import WebSocketDisconnect class BasicAuth(AuthenticationBackend): @@ -104,6 +105,18 @@ def admin(request): ) +@app.websocket_route("/ws") +@requires("authenticated") +async def websocket_endpoint(websocket): + await websocket.accept() + await websocket.send_json( + { + "authenticated": websocket.user.is_authenticated, + "user": websocket.user.display_name, + } + ) + + def test_invalid_decorator_usage(): with pytest.raises(Exception): @@ -151,6 +164,21 @@ def test_authentication_required(): assert response.text == "Invalid basic auth credentials" +def test_websocket_authentication_required(): + with TestClient(app) as client: + with pytest.raises(WebSocketDisconnect): + client.websocket_connect("/ws") + + with pytest.raises(WebSocketDisconnect): + client.websocket_connect("/ws", headers={"Authorization": "basic foobar"}) + + with client.websocket_connect( + "/ws", auth=("tomchristie", "example") + ) as websocket: + data = websocket.receive_json() + assert data == {"authenticated": True, "user": "tomchristie"} + + def test_authentication_redirect(): with TestClient(app) as client: response = client.get("/admin")