2022-04-24 06:17:07 +00:00
|
|
|
import pytest
|
|
|
|
|
2018-12-14 14:56:31 +00:00
|
|
|
from starlette.background import BackgroundTask, BackgroundTasks
|
2018-10-29 14:46:42 +00:00
|
|
|
from starlette.responses import Response
|
2024-02-04 16:48:26 +00:00
|
|
|
from starlette.types import Receive, Scope, Send
|
2024-07-27 09:31:16 +00:00
|
|
|
from tests.types import TestClientFactory
|
2018-10-02 10:40:08 +00:00
|
|
|
|
2024-02-04 16:48:26 +00:00
|
|
|
|
|
|
|
def test_async_task(test_client_factory: TestClientFactory) -> None:
    """A coroutine-based ``BackgroundTask`` attached to a response gets executed.

    The app returns a plain-text response carrying the task; after the request
    the test checks both the body and that the coroutine flipped the flag.
    """
    task_complete = False

    async def mark_complete() -> None:
        nonlocal task_complete
        task_complete = True

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        await Response(
            "task initiated",
            media_type="text/plain",
            background=BackgroundTask(mark_complete),
        )(scope, receive, send)

    response = test_client_factory(app).get("/")
    assert response.text == "task initiated"
    # The assertion below relies on the client call not returning before the
    # background work has finished.
    assert task_complete
|
2018-10-02 10:40:08 +00:00
|
|
|
|
|
|
|
|
2024-02-04 16:48:26 +00:00
|
|
|
def test_sync_task(test_client_factory: TestClientFactory) -> None:
    """A plain (synchronous) ``BackgroundTask`` attached to a response gets executed.

    Mirrors the coroutine variant, but the task is an ordinary function.
    """
    task_complete = False

    def mark_complete() -> None:
        nonlocal task_complete
        task_complete = True

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        await Response(
            "task initiated",
            media_type="text/plain",
            background=BackgroundTask(mark_complete),
        )(scope, receive, send)

    response = test_client_factory(app).get("/")
    assert response.text == "task initiated"
    # The assertion below relies on the client call not returning before the
    # background work has finished.
    assert task_complete
|
2018-12-14 14:56:31 +00:00
|
|
|
|
|
|
|
|
2024-02-04 16:48:26 +00:00
|
|
|
def test_multiple_tasks(test_client_factory: TestClientFactory) -> None:
    """Every task queued on a ``BackgroundTasks`` container is executed.

    Three increments (1, 2 and 3) are queued; their combined effect on the
    shared counter proves that each one ran.
    """
    task_counter = 0

    def increment(amount: int) -> None:
        nonlocal task_counter
        task_counter += amount

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        background = BackgroundTasks()
        for step in (1, 2, 3):
            background.add_task(increment, amount=step)
        reply = Response("tasks initiated", media_type="text/plain", background=background)
        await reply(scope, receive, send)

    client = test_client_factory(app)
    response = client.get("/")
    assert response.text == "tasks initiated"
    assert task_counter == 1 + 2 + 3
|
2022-04-24 06:17:07 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_multi_tasks_failure_avoids_next_execution(
    test_client_factory: TestClientFactory,
) -> None:
    """A failing task in ``BackgroundTasks`` prevents the later tasks from running.

    Two ``increment`` tasks are queued; the first one raises. The test checks
    that this specific error reaches the client and that the counter was only
    bumped once, i.e. the second queued task never executed.
    """
    task_counter = 0

    def increment() -> None:
        nonlocal task_counter
        task_counter += 1
        if task_counter == 1:
            raise Exception("task failed")

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        tasks = BackgroundTasks()
        tasks.add_task(increment)
        tasks.add_task(increment)
        response = Response("tasks initiated", media_type="text/plain", background=tasks)
        await response(scope, receive, send)

    client = test_client_factory(app)
    # Match the message: a bare ``pytest.raises(Exception)`` would also accept
    # an unrelated error from the app or client plumbing, letting the test pass
    # without the background task ever failing.
    with pytest.raises(Exception, match="task failed"):
        client.get("/")
    assert task_counter == 1
|