Flexible requires decorator (#314)

* Remove unused imports, with 'autoflake'

* Flexible 'requires' decorator

* Linting

* Exclude coverage on uncalled line

* Merge master
This commit is contained in:
Tom Christie 2019-01-10 15:37:50 +00:00 committed by GitHub
parent 6cd0b2b787
commit a59b5295ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 10 deletions

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import functools import functools
import inspect
import typing import typing
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
@ -22,35 +23,39 @@ def requires(
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes) scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
def decorator(func: typing.Callable) -> typing.Callable: def decorator(func: typing.Callable) -> typing.Callable:
sig = inspect.signature(func)
for idx, parameter in enumerate(sig.parameters.values()):
if parameter.name == "request":
break
else:
raise Exception('No "request" argument on function "%s"' % func)
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(func):
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args: typing.Any) -> Response: async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
# Support either `func(request)`` or `func(self, request)`` request = kwargs.get("request", args[idx])
assert len(args) in (1, 2)
request = args[-1]
assert isinstance(request, Request) assert isinstance(request, Request)
if not has_required_scope(request, scopes_list): if not has_required_scope(request, scopes_list):
if redirect is not None: if redirect is not None:
return RedirectResponse(url=request.url_for(redirect)) return RedirectResponse(url=request.url_for(redirect))
raise HTTPException(status_code=status_code) raise HTTPException(status_code=status_code)
return await func(*args) return await func(*args, **kwargs)
return wrapper return wrapper
@functools.wraps(func) @functools.wraps(func)
def sync_wrapper(*args: typing.Any) -> Response: def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
# Support either `func(request)`` or `func(self, request)`` # Support either `func(request)` or `func(self, request)`
assert len(args) in (1, 2) request = kwargs.get("request", args[idx])
request = args[-1]
assert isinstance(request, Request) assert isinstance(request, Request)
if not has_required_scope(request, scopes_list): if not has_required_scope(request, scopes_list):
if redirect is not None: if redirect is not None:
return RedirectResponse(url=request.url_for(redirect)) return RedirectResponse(url=request.url_for(redirect))
raise HTTPException(status_code=status_code) raise HTTPException(status_code=status_code)
return func(*args) return func(*args, **kwargs)
return sync_wrapper return sync_wrapper

View File

@ -1,6 +1,8 @@
import base64 import base64
import binascii import binascii
import pytest
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.authentication import ( from starlette.authentication import (
AuthCredentials, AuthCredentials,
@ -102,6 +104,14 @@ def admin(request):
) )
def test_invalid_decorator_usage():
with pytest.raises(Exception):
@requires("authenticated")
def foo():
pass # pragma: nocover
def test_user_interface(): def test_user_interface():
with TestClient(app) as client: with TestClient(app) as client:
response = client.get("/") response = client.get("/")