mirror of https://github.com/encode/starlette.git
Flexible requires decorator (#314)
* Remove unused imports, with 'autoflake' * Flexible 'requires' decorator * Linting * Exclude coverage on uncalled line * Merge master
This commit is contained in:
parent
6cd0b2b787
commit
a59b5295ef
|
@ -1,5 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
|
import inspect
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
|
@ -22,35 +23,39 @@ def requires(
|
||||||
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
|
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
|
||||||
|
|
||||||
def decorator(func: typing.Callable) -> typing.Callable:
|
def decorator(func: typing.Callable) -> typing.Callable:
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
for idx, parameter in enumerate(sig.parameters.values()):
|
||||||
|
if parameter.name == "request":
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise Exception('No "request" argument on function "%s"' % func)
|
||||||
|
|
||||||
if asyncio.iscoroutinefunction(func):
|
if asyncio.iscoroutinefunction(func):
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: typing.Any) -> Response:
|
async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
|
||||||
# Support either `func(request)`` or `func(self, request)``
|
request = kwargs.get("request", args[idx])
|
||||||
assert len(args) in (1, 2)
|
|
||||||
request = args[-1]
|
|
||||||
assert isinstance(request, Request)
|
assert isinstance(request, Request)
|
||||||
|
|
||||||
if not has_required_scope(request, scopes_list):
|
if not has_required_scope(request, scopes_list):
|
||||||
if redirect is not None:
|
if redirect is not None:
|
||||||
return RedirectResponse(url=request.url_for(redirect))
|
return RedirectResponse(url=request.url_for(redirect))
|
||||||
raise HTTPException(status_code=status_code)
|
raise HTTPException(status_code=status_code)
|
||||||
return await func(*args)
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def sync_wrapper(*args: typing.Any) -> Response:
|
def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
|
||||||
# Support either `func(request)`` or `func(self, request)``
|
# Support either `func(request)` or `func(self, request)`
|
||||||
assert len(args) in (1, 2)
|
request = kwargs.get("request", args[idx])
|
||||||
request = args[-1]
|
|
||||||
assert isinstance(request, Request)
|
assert isinstance(request, Request)
|
||||||
|
|
||||||
if not has_required_scope(request, scopes_list):
|
if not has_required_scope(request, scopes_list):
|
||||||
if redirect is not None:
|
if redirect is not None:
|
||||||
return RedirectResponse(url=request.url_for(redirect))
|
return RedirectResponse(url=request.url_for(redirect))
|
||||||
raise HTTPException(status_code=status_code)
|
raise HTTPException(status_code=status_code)
|
||||||
return func(*args)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
return sync_wrapper
|
return sync_wrapper
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.authentication import (
|
from starlette.authentication import (
|
||||||
AuthCredentials,
|
AuthCredentials,
|
||||||
|
@ -102,6 +104,14 @@ def admin(request):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_decorator_usage():
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
|
||||||
|
@requires("authenticated")
|
||||||
|
def foo():
|
||||||
|
pass # pragma: nocover
|
||||||
|
|
||||||
|
|
||||||
def test_user_interface():
|
def test_user_interface():
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
response = client.get("/")
|
response = client.get("/")
|
||||||
|
|
Loading…
Reference in New Issue