diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 61fe5ff7..aba3ceb1 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -1,4 +1,5 @@ from contextvars import ContextVar +from typing import Callable, Iterator import anyio import pytest @@ -8,17 +9,20 @@ from starlette.concurrency import iterate_in_threadpool, run_until_first_complet from starlette.requests import Request from starlette.responses import Response from starlette.routing import Route +from starlette.testclient import TestClient + +TestClientFactory = Callable[..., TestClient] @pytest.mark.anyio -async def test_run_until_first_complete(): +async def test_run_until_first_complete() -> None: task1_finished = anyio.Event() task2_finished = anyio.Event() - async def task1(): + async def task1() -> None: task1_finished.set() - async def task2(): + async def task2() -> None: await task1_finished.wait() await anyio.sleep(0) # pragma: nocover task2_finished.set() # pragma: nocover @@ -28,7 +32,9 @@ async def test_run_until_first_complete(): assert not task2_finished.is_set() -def test_accessing_context_from_threaded_sync_endpoint(test_client_factory) -> None: +def test_accessing_context_from_threaded_sync_endpoint( + test_client_factory: TestClientFactory, +) -> None: ctxvar: ContextVar[bytes] = ContextVar("ctxvar") ctxvar.set(b"data") @@ -45,7 +51,7 @@ def test_accessing_context_from_threaded_sync_endpoint(test_client_factory) -> N @pytest.mark.anyio async def test_iterate_in_threadpool() -> None: class CustomIterable: - def __iter__(self): + def __iter__(self) -> Iterator[int]: yield from range(3) assert [v async for v in iterate_in_threadpool(CustomIterable())] == [0, 1, 2]