Update type annotations for BackgroundTask and utils (#1383)

* Update type annotations for BackgroundTask and utils

* Add type annotation to handler

* Update setup.py

* Fix import issue

* Fix missed import

* Fix coverage

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
Yurii Karabas 2022-01-08 13:12:49 +02:00 committed by GitHub
parent f1c5049643
commit 165592fb89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 10 deletions

View File

@ -39,7 +39,7 @@ setup(
include_package_data=True,
install_requires=[
"anyio>=3.0.0,<4",
"typing_extensions; python_version < '3.8'",
"typing_extensions; python_version < '3.10'",
"contextlib2 >= 21.6.0; python_version < '3.7'",
],
extras_require={

View File

@ -1,12 +1,20 @@
import asyncio
import sys
import typing
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
from starlette.concurrency import run_in_threadpool
P = ParamSpec("P")
class BackgroundTask:
def __init__(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
) -> None:
self.func = func
self.args = args
@ -25,7 +33,7 @@ class BackgroundTasks(BackgroundTask):
self.tasks = list(tasks) if tasks else []
def add_task(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs
) -> None:
task = BackgroundTask(func, *args, **kwargs)
self.tasks.append(task)

View File

@ -1,9 +1,14 @@
import functools
import sys
import typing
from typing import Any, AsyncGenerator, Iterator
import anyio
if sys.version_info >= (3, 10): # pragma: no cover
from typing import ParamSpec
else: # pragma: no cover
from typing_extensions import ParamSpec
try:
import contextvars # Python 3.7+ only or via contextvars backport.
except ImportError: # pragma: no cover
@ -11,6 +16,7 @@ except ImportError: # pragma: no cover
T = typing.TypeVar("T")
P = ParamSpec("P")
async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
@ -25,14 +31,14 @@ async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -
async def run_in_threadpool(
func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
if contextvars is not None: # pragma: no cover
# Ensure we run in the same context
child = functools.partial(func, *args, **kwargs)
context = contextvars.copy_context()
func = context.run
args = (child,)
func = context.run # type: ignore[assignment]
args = (child,) # type: ignore[assignment]
elif kwargs: # pragma: no cover
# run_sync doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
@ -43,7 +49,7 @@ class _StopIteration(Exception):
pass
def _next(iterator: Iterator) -> Any:
def _next(iterator: typing.Iterator[T]) -> T:
# We can't raise `StopIteration` from within the threadpool iterator
# and catch it outside that context, so we coerce them into a different
# exception type.
@ -53,7 +59,9 @@ def _next(iterator: Iterator) -> Any:
raise _StopIteration
async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator:
async def iterate_in_threadpool(
iterator: typing.Iterator[T],
) -> typing.AsyncIterator[T]:
while True:
try:
yield await anyio.to_thread.run_sync(_next, iterator)

View File

@ -29,7 +29,9 @@ class HTTPEndpoint:
else request.method.lower()
)
handler = getattr(self, handler_name, self.method_not_allowed)
handler: typing.Callable[[Request], typing.Any] = getattr(
self, handler_name, self.method_not_allowed
)
is_async = asyncio.iscoroutinefunction(handler)
if is_async:
response = await handler(request)