mirror of https://github.com/encode/starlette.git
Flexible requires decorator (#314)
* Remove unused imports, with 'autoflake' * Flexible 'requires' decorator * Linting * Exclude coverage on uncalled line * Merge master
This commit is contained in:
parent
6cd0b2b787
commit
a59b5295ef
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
from starlette.exceptions import HTTPException
|
||||
|
@ -22,35 +23,39 @@ def requires(
|
|||
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
sig = inspect.signature(func)
|
||||
for idx, parameter in enumerate(sig.parameters.values()):
|
||||
if parameter.name == "request":
|
||||
break
|
||||
else:
|
||||
raise Exception('No "request" argument on function "%s"' % func)
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: typing.Any) -> Response:
|
||||
# Support either `func(request)`` or `func(self, request)``
|
||||
assert len(args) in (1, 2)
|
||||
request = args[-1]
|
||||
async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
|
||||
request = kwargs.get("request", args[idx])
|
||||
assert isinstance(request, Request)
|
||||
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
return RedirectResponse(url=request.url_for(redirect))
|
||||
raise HTTPException(status_code=status_code)
|
||||
return await func(*args)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: typing.Any) -> Response:
|
||||
# Support either `func(request)`` or `func(self, request)``
|
||||
assert len(args) in (1, 2)
|
||||
request = args[-1]
|
||||
def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
|
||||
# Support either `func(request)` or `func(self, request)`
|
||||
request = kwargs.get("request", args[idx])
|
||||
assert isinstance(request, Request)
|
||||
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
return RedirectResponse(url=request.url_for(redirect))
|
||||
raise HTTPException(status_code=status_code)
|
||||
return func(*args)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import base64
|
||||
import binascii
|
||||
|
||||
import pytest
|
||||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.authentication import (
|
||||
AuthCredentials,
|
||||
|
@ -102,6 +104,14 @@ def admin(request):
|
|||
)
|
||||
|
||||
|
||||
def test_invalid_decorator_usage():
|
||||
with pytest.raises(Exception):
|
||||
|
||||
@requires("authenticated")
|
||||
def foo():
|
||||
pass # pragma: nocover
|
||||
|
||||
|
||||
def test_user_interface():
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/")
|
||||
|
|
Loading…
Reference in New Issue