mirror of https://github.com/encode/starlette.git
Update type annotations for BackgroundTask and utils (#1383)
* Update type annotations for BackgroundTask and utils * Add type annotation to handler * Update setup.py * Fix import issue * Fix missed import * Fix coverage Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
parent
f1c5049643
commit
165592fb89
2
setup.py
2
setup.py
|
@ -39,7 +39,7 @@ setup(
|
|||
include_package_data=True,
|
||||
install_requires=[
|
||||
"anyio>=3.0.0,<4",
|
||||
"typing_extensions; python_version < '3.8'",
|
||||
"typing_extensions; python_version < '3.10'",
|
||||
"contextlib2 >= 21.6.0; python_version < '3.7'",
|
||||
],
|
||||
extras_require={
|
||||
|
|
|
@ -1,12 +1,20 @@
|
|||
import asyncio
|
||||
import sys
|
||||
import typing
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import ParamSpec
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class BackgroundTask:
|
||||
def __init__(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
|
||||
) -> None:
|
||||
self.func = func
|
||||
self.args = args
|
||||
|
@ -25,7 +33,7 @@ class BackgroundTasks(BackgroundTask):
|
|||
self.tasks = list(tasks) if tasks else []
|
||||
|
||||
def add_task(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
|
||||
) -> None:
|
||||
task = BackgroundTask(func, *args, **kwargs)
|
||||
self.tasks.append(task)
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
import functools
|
||||
import sys
|
||||
import typing
|
||||
from typing import Any, AsyncGenerator, Iterator
|
||||
|
||||
import anyio
|
||||
|
||||
if sys.version_info >= (3, 10): # pragma: no cover
|
||||
from typing import ParamSpec
|
||||
else: # pragma: no cover
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
try:
|
||||
import contextvars # Python 3.7+ only or via contextvars backport.
|
||||
except ImportError: # pragma: no cover
|
||||
|
@ -11,6 +16,7 @@ except ImportError: # pragma: no cover
|
|||
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
|
||||
|
@ -25,14 +31,14 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -
|
|||
|
||||
|
||||
async def run_in_threadpool(
|
||||
func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
|
||||
func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
|
||||
) -> T:
|
||||
if contextvars is not None: # pragma: no cover
|
||||
# Ensure we run in the same context
|
||||
child = functools.partial(func, *args, **kwargs)
|
||||
context = contextvars.copy_context()
|
||||
func = context.run
|
||||
args = (child,)
|
||||
func = context.run # type: ignore[assignment]
|
||||
args = (child,) # type: ignore[assignment]
|
||||
elif kwargs: # pragma: no cover
|
||||
# run_sync doesn't accept 'kwargs', so bind them in here
|
||||
func = functools.partial(func, **kwargs)
|
||||
|
@ -43,7 +49,7 @@ class _StopIteration(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def _next(iterator: Iterator) -> Any:
|
||||
def _next(iterator: typing.Iterator[T]) -> T:
|
||||
# We can't raise `StopIteration` from within the threadpool iterator
|
||||
# and catch it outside that context, so we coerce them into a different
|
||||
# exception type.
|
||||
|
@ -53,7 +59,9 @@ def _next(iterator: Iterator) -> Any:
|
|||
raise _StopIteration
|
||||
|
||||
|
||||
async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator:
|
||||
async def iterate_in_threadpool(
|
||||
iterator: typing.Iterator[T],
|
||||
) -> typing.AsyncIterator[T]:
|
||||
while True:
|
||||
try:
|
||||
yield await anyio.to_thread.run_sync(_next, iterator)
|
||||
|
|
|
@ -29,7 +29,9 @@ class HTTPEndpoint:
|
|||
else request.method.lower()
|
||||
)
|
||||
|
||||
handler = getattr(self, handler_name, self.method_not_allowed)
|
||||
handler: typing.Callable[[Request], typing.Any] = getattr(
|
||||
self, handler_name, self.method_not_allowed
|
||||
)
|
||||
is_async = asyncio.iscoroutinefunction(handler)
|
||||
if is_async:
|
||||
response = await handler(request)
|
||||
|
|
Loading…
Reference in New Issue