mirror of https://github.com/encode/starlette.git
65 lines
1.7 KiB
Python
65 lines
1.7 KiB
Python
from starlette.background import BackgroundTask, BackgroundTasks
|
|
from starlette.responses import Response
|
|
from starlette.testclient import TestClient
|
|
|
|
|
|
def test_async_task():
    """Check that a coroutine function wrapped in ``BackgroundTask`` has run
    by the time the test client's request returns."""
    completed = False

    async def mark_completed():
        # Flip the enclosing flag so the test can observe that the task ran.
        nonlocal completed
        completed = True

    background = BackgroundTask(mark_completed)

    async def asgi_app(scope, receive, send):
        # Minimal ASGI callable: build a response carrying the background
        # task and hand the ASGI triple straight to it.
        reply = Response(
            "task initiated", media_type="text/plain", background=background
        )
        await reply(scope, receive, send)

    result = TestClient(asgi_app).get("/")

    assert result.text == "task initiated"
    assert completed
|
|
|
|
|
|
def test_sync_task():
    """Check that a plain (non-async) function wrapped in ``BackgroundTask``
    has run by the time the test client's request returns."""
    completed = False

    def mark_completed():
        # Flip the enclosing flag so the test can observe that the task ran.
        nonlocal completed
        completed = True

    background = BackgroundTask(mark_completed)

    async def asgi_app(scope, receive, send):
        # Minimal ASGI callable: build a response carrying the background
        # task and hand the ASGI triple straight to it.
        reply = Response(
            "task initiated", media_type="text/plain", background=background
        )
        await reply(scope, receive, send)

    result = TestClient(asgi_app).get("/")

    assert result.text == "task initiated"
    assert completed
|
|
|
|
|
|
def test_multiple_tasks():
    """Check that every task added to a ``BackgroundTasks`` collection has run
    by the time the test client's request returns."""
    amounts = (1, 2, 3)
    running_total = 0

    def increment(amount):
        # Accumulate into the enclosing counter so the test can verify that
        # each queued call was executed with its keyword argument.
        nonlocal running_total
        running_total += amount

    async def asgi_app(scope, receive, send):
        queued = BackgroundTasks()
        for step in amounts:
            queued.add_task(increment, amount=step)
        reply = Response(
            "tasks initiated", media_type="text/plain", background=queued
        )
        await reply(scope, receive, send)

    result = TestClient(asgi_app).get("/")

    assert result.text == "tasks initiated"
    assert running_total == sum(amounts)
|