mirror of https://github.com/encode/starlette.git
Add type hints to `test_concurency.py` (#2474)
Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
This commit is contained in:
parent
88331bd5f8
commit
801e73e4d1
|
@ -1,4 +1,5 @@
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
from typing import Callable, Iterator
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -8,17 +9,20 @@ from starlette.concurrency import iterate_in_threadpool, run_until_first_complet
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
from starlette.routing import Route
|
from starlette.routing import Route
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
TestClientFactory = Callable[..., TestClient]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_run_until_first_complete():
|
async def test_run_until_first_complete() -> None:
|
||||||
task1_finished = anyio.Event()
|
task1_finished = anyio.Event()
|
||||||
task2_finished = anyio.Event()
|
task2_finished = anyio.Event()
|
||||||
|
|
||||||
async def task1():
|
async def task1() -> None:
|
||||||
task1_finished.set()
|
task1_finished.set()
|
||||||
|
|
||||||
async def task2():
|
async def task2() -> None:
|
||||||
await task1_finished.wait()
|
await task1_finished.wait()
|
||||||
await anyio.sleep(0) # pragma: nocover
|
await anyio.sleep(0) # pragma: nocover
|
||||||
task2_finished.set() # pragma: nocover
|
task2_finished.set() # pragma: nocover
|
||||||
|
@ -28,7 +32,9 @@ async def test_run_until_first_complete():
|
||||||
assert not task2_finished.is_set()
|
assert not task2_finished.is_set()
|
||||||
|
|
||||||
|
|
||||||
def test_accessing_context_from_threaded_sync_endpoint(test_client_factory) -> None:
|
def test_accessing_context_from_threaded_sync_endpoint(
|
||||||
|
test_client_factory: TestClientFactory,
|
||||||
|
) -> None:
|
||||||
ctxvar: ContextVar[bytes] = ContextVar("ctxvar")
|
ctxvar: ContextVar[bytes] = ContextVar("ctxvar")
|
||||||
ctxvar.set(b"data")
|
ctxvar.set(b"data")
|
||||||
|
|
||||||
|
@ -45,7 +51,7 @@ def test_accessing_context_from_threaded_sync_endpoint(test_client_factory) -> N
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_iterate_in_threadpool() -> None:
|
async def test_iterate_in_threadpool() -> None:
|
||||||
class CustomIterable:
|
class CustomIterable:
|
||||||
def __iter__(self):
|
def __iter__(self) -> Iterator[int]:
|
||||||
yield from range(3)
|
yield from range(3)
|
||||||
|
|
||||||
assert [v async for v in iterate_in_threadpool(CustomIterable())] == [0, 1, 2]
|
assert [v async for v in iterate_in_threadpool(CustomIterable())] == [0, 1, 2]
|
||||||
|
|
Loading…
Reference in New Issue