mirror of https://github.com/encode/starlette.git
Fix typo in scope check for 'websocket' (#335)
* Fix typo in scope check for websocket * Allow for websocket requests * Let Requests be http specific * Use HTTPConnection in place of Requests * Update signature to work with websocket requests * Formatter per black
This commit is contained in:
parent
25cbba15c0
commit
88bd2a2ce9
|
@ -4,7 +4,7 @@ import inspect
|
|||
import typing
|
||||
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.requests import Request, HTTPConnection
|
||||
from starlette.responses import RedirectResponse, Response
|
||||
|
||||
|
||||
|
@ -68,7 +68,7 @@ class AuthenticationError(Exception):
|
|||
|
||||
class AuthenticationBackend:
|
||||
async def authenticate(
|
||||
self, request: Request
|
||||
self, conn: HTTPConnection
|
||||
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
|
||||
raise NotImplemented() # pragma: no cover
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from starlette.authentication import (
|
|||
AuthenticationError,
|
||||
UnauthenticatedUser,
|
||||
)
|
||||
from starlette.requests import Request
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send
|
||||
|
||||
|
@ -17,25 +17,27 @@ class AuthenticationMiddleware:
|
|||
self,
|
||||
app: ASGIApp,
|
||||
backend: AuthenticationBackend,
|
||||
on_error: typing.Callable[[Request, AuthenticationError], Response] = None,
|
||||
on_error: typing.Callable[
|
||||
[HTTPConnection, AuthenticationError], Response
|
||||
] = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.backend = backend
|
||||
self.on_error = (
|
||||
on_error if on_error is not None else self.default_on_error
|
||||
) # type: typing.Callable[[Request, AuthenticationError], Response]
|
||||
) # type: typing.Callable[[HTTPConnection, AuthenticationError], Response]
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
if scope["type"] in ["http", "websockets"]:
|
||||
if scope["type"] in ["http", "websocket"]:
|
||||
return functools.partial(self.asgi, scope=scope)
|
||||
return self.app(scope)
|
||||
|
||||
async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None:
|
||||
request = Request(scope, receive=receive)
|
||||
conn = HTTPConnection(scope, receive=receive)
|
||||
try:
|
||||
auth_result = await self.backend.authenticate(request)
|
||||
auth_result = await self.backend.authenticate(conn)
|
||||
except AuthenticationError as exc:
|
||||
response = self.on_error(request, exc)
|
||||
response = self.on_error(conn, exc)
|
||||
await response(receive, send)
|
||||
return
|
||||
|
||||
|
@ -46,5 +48,5 @@ class AuthenticationMiddleware:
|
|||
await inner(receive, send)
|
||||
|
||||
@staticmethod
|
||||
def default_on_error(request: Request, exc: Exception) -> Response:
|
||||
def default_on_error(conn: HTTPConnection, exc: Exception) -> Response:
|
||||
return PlainTextResponse(str(exc), status_code=400)
|
||||
|
|
Loading…
Reference in New Issue