From 88bd2a2ce94d3dbbd15712fe7ee792d976aa6f3e Mon Sep 17 00:00:00 2001 From: tchan09 <46059271+tchan09@users.noreply.github.com> Date: Fri, 25 Jan 2019 08:06:09 -0500 Subject: [PATCH] Fix typo in scope check for 'websocket' (#335) * Fix typo in scope check for websocket * Allow for websocket requests * Let Requests be http specific * Use HTTPConnection in place of Requests * Update signature to work with websocket requests * Formatter per black --- starlette/authentication.py | 4 ++-- starlette/middleware/authentication.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/starlette/authentication.py b/starlette/authentication.py index 95a65473..5a768368 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -4,7 +4,7 @@ import inspect import typing from starlette.exceptions import HTTPException -from starlette.requests import Request +from starlette.requests import Request, HTTPConnection from starlette.responses import RedirectResponse, Response @@ -68,7 +68,7 @@ class AuthenticationError(Exception): class AuthenticationBackend: async def authenticate( - self, request: Request + self, conn: HTTPConnection ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]: raise NotImplemented() # pragma: no cover diff --git a/starlette/middleware/authentication.py b/starlette/middleware/authentication.py index 62b09ec6..6f8d95cd 100644 --- a/starlette/middleware/authentication.py +++ b/starlette/middleware/authentication.py @@ -7,7 +7,7 @@ from starlette.authentication import ( AuthenticationError, UnauthenticatedUser, ) -from starlette.requests import Request +from starlette.requests import HTTPConnection from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send @@ -17,25 +17,27 @@ class AuthenticationMiddleware: self, app: ASGIApp, backend: AuthenticationBackend, - on_error: typing.Callable[[Request, AuthenticationError], Response] = None, + on_error: typing.Callable[ + [HTTPConnection, AuthenticationError], Response + ] = None, ) -> None: self.app = app self.backend = backend self.on_error = ( on_error if on_error is not None else self.default_on_error - ) # type: typing.Callable[[Request, AuthenticationError], Response] + ) # type: typing.Callable[[HTTPConnection, AuthenticationError], Response] def __call__(self, scope: Scope) -> ASGIInstance: - if scope["type"] in ["http", "websockets"]: + if scope["type"] in ["http", "websocket"]: return functools.partial(self.asgi, scope=scope) return self.app(scope) async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None: - request = Request(scope, receive=receive) + conn = HTTPConnection(scope, receive=receive) try: - auth_result = await self.backend.authenticate(request) + auth_result = await self.backend.authenticate(conn) except AuthenticationError as exc: - response = self.on_error(request, exc) + response = self.on_error(conn, exc) await response(receive, send) return @@ -46,5 +48,5 @@ class AuthenticationMiddleware: await inner(receive, send) @staticmethod - def default_on_error(request: Request, exc: Exception) -> Response: + def default_on_error(conn: HTTPConnection, exc: Exception) -> Response: return PlainTextResponse(str(exc), status_code=400)