Flexible requires decorator (#314)

* Remove unused imports, with 'autoflake'

* Flexible 'requires' decorator

* Linting

* Exclude coverage on uncalled line

* Merge master
This commit is contained in:
Tom Christie 2019-01-10 15:37:50 +00:00 committed by GitHub
parent 6cd0b2b787
commit a59b5295ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 10 deletions

View File

@ -1,5 +1,6 @@
import asyncio
import functools
import inspect
import typing
from starlette.exceptions import HTTPException
@ -22,35 +23,39 @@ def requires(
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
def decorator(func: typing.Callable) -> typing.Callable:
sig = inspect.signature(func)
for idx, parameter in enumerate(sig.parameters.values()):
if parameter.name == "request":
break
else:
raise Exception('No "request" argument on function "%s"' % func)
if asyncio.iscoroutinefunction(func):
@functools.wraps(func)
async def wrapper(*args: typing.Any) -> Response:
# Support either `func(request)`` or `func(self, request)``
assert len(args) in (1, 2)
request = args[-1]
async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
request = kwargs.get("request", args[idx])
assert isinstance(request, Request)
if not has_required_scope(request, scopes_list):
if redirect is not None:
return RedirectResponse(url=request.url_for(redirect))
raise HTTPException(status_code=status_code)
return await func(*args)
return await func(*args, **kwargs)
return wrapper
@functools.wraps(func)
def sync_wrapper(*args: typing.Any) -> Response:
# Support either `func(request)`` or `func(self, request)``
assert len(args) in (1, 2)
request = args[-1]
def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
# Support either `func(request)` or `func(self, request)`
request = kwargs.get("request", args[idx])
assert isinstance(request, Request)
if not has_required_scope(request, scopes_list):
if redirect is not None:
return RedirectResponse(url=request.url_for(redirect))
raise HTTPException(status_code=status_code)
return func(*args)
return func(*args, **kwargs)
return sync_wrapper

View File

@ -1,6 +1,8 @@
import base64
import binascii
import pytest
from starlette.applications import Starlette
from starlette.authentication import (
AuthCredentials,
@ -102,6 +104,14 @@ def admin(request):
)
def test_invalid_decorator_usage():
with pytest.raises(Exception):
@requires("authenticated")
def foo():
pass # pragma: nocover
def test_user_interface():
with TestClient(app) as client:
response = client.get("/")