Fix typo in scope check for 'websocket' (#335)

* Fix typo in scope check for websocket

* Allow for websocket requests

* Let Requests be http specific

* Use HTTPConnection in place of Requests

* Update signature to work with websocket requests

* Formatter per black
This commit is contained in:
tchan09 2019-01-25 08:06:09 -05:00 committed by Tom Christie
parent 25cbba15c0
commit 88bd2a2ce9
2 changed files with 12 additions and 10 deletions

View File

@ -4,7 +4,7 @@ import inspect
import typing
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.requests import Request, HTTPConnection
from starlette.responses import RedirectResponse, Response
@ -68,7 +68,7 @@ class AuthenticationError(Exception):
class AuthenticationBackend:
async def authenticate(
self, request: Request
self, conn: HTTPConnection
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
raise NotImplemented() # pragma: no cover

View File

@ -7,7 +7,7 @@ from starlette.authentication import (
AuthenticationError,
UnauthenticatedUser,
)
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send
@ -17,25 +17,27 @@ class AuthenticationMiddleware:
self,
app: ASGIApp,
backend: AuthenticationBackend,
on_error: typing.Callable[[Request, AuthenticationError], Response] = None,
on_error: typing.Callable[
[HTTPConnection, AuthenticationError], Response
] = None,
) -> None:
self.app = app
self.backend = backend
self.on_error = (
on_error if on_error is not None else self.default_on_error
) # type: typing.Callable[[Request, AuthenticationError], Response]
) # type: typing.Callable[[HTTPConnection, AuthenticationError], Response]
def __call__(self, scope: Scope) -> ASGIInstance:
if scope["type"] in ["http", "websockets"]:
if scope["type"] in ["http", "websocket"]:
return functools.partial(self.asgi, scope=scope)
return self.app(scope)
async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None:
request = Request(scope, receive=receive)
conn = HTTPConnection(scope, receive=receive)
try:
auth_result = await self.backend.authenticate(request)
auth_result = await self.backend.authenticate(conn)
except AuthenticationError as exc:
response = self.on_error(request, exc)
response = self.on_error(conn, exc)
await response(receive, send)
return
@ -46,5 +48,5 @@ class AuthenticationMiddleware:
await inner(receive, send)
@staticmethod
def default_on_error(request: Request, exc: Exception) -> Response:
def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
return PlainTextResponse(str(exc), status_code=400)