diff --git a/starlette/authentication.py b/starlette/authentication.py index f7c221e5..9aa4ca5e 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -1,5 +1,6 @@ import asyncio import functools +import inspect import typing from starlette.exceptions import HTTPException @@ -22,35 +23,39 @@ def requires( scopes_list = [scopes] if isinstance(scopes, str) else list(scopes) def decorator(func: typing.Callable) -> typing.Callable: + sig = inspect.signature(func) + for idx, parameter in enumerate(sig.parameters.values()): + if parameter.name == "request": + break + else: + raise Exception('No "request" argument on function "%s"' % func) + if asyncio.iscoroutinefunction(func): @functools.wraps(func) - async def wrapper(*args: typing.Any) -> Response: - # Support either `func(request)`` or `func(self, request)`` - assert len(args) in (1, 2) - request = args[-1] + async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: + request = kwargs.get("request", args[idx]) assert isinstance(request, Request) if not has_required_scope(request, scopes_list): if redirect is not None: return RedirectResponse(url=request.url_for(redirect)) raise HTTPException(status_code=status_code) - return await func(*args) + return await func(*args, **kwargs) return wrapper @functools.wraps(func) - def sync_wrapper(*args: typing.Any) -> Response: - # Support either `func(request)`` or `func(self, request)`` - assert len(args) in (1, 2) - request = args[-1] + def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response: + # Support either `func(request)` or `func(self, request)` + request = kwargs.get("request", args[idx]) assert isinstance(request, Request) if not has_required_scope(request, scopes_list): if redirect is not None: return RedirectResponse(url=request.url_for(redirect)) raise HTTPException(status_code=status_code) - return func(*args) + return func(*args, **kwargs) return sync_wrapper diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 829d9802..038a0978 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,6 +1,8 @@ import base64 import binascii +import pytest + from starlette.applications import Starlette from starlette.authentication import ( AuthCredentials, @@ -102,6 +104,14 @@ def admin(request): ) +def test_invalid_decorator_usage(): + with pytest.raises(Exception): + + @requires("authenticated") + def foo(): + pass # pragma: nocover + + def test_user_interface(): with TestClient(app) as client: response = client.get("/")