use an async context manager factory for lifespan (#1227)

This commit is contained in:
Thomas Grainger 2021-07-03 18:43:24 +01:00 committed by GitHub
parent 254d0d97e4
commit 537ab6afd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 159 additions and 37 deletions

View File

@ -40,6 +40,7 @@ setup(
install_requires=[
"anyio>=3.0.0,<4",
"typing_extensions; python_version < '3.8'",
"contextlib2 >= 21.6.0; python_version < '3.7'",
],
extras_require={
"full": [

View File

@ -46,7 +46,7 @@ class Starlette:
] = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
lifespan: typing.Callable[["Starlette"], typing.AsyncGenerator] = None,
lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None,
) -> None:
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.

View File

@ -1,9 +1,13 @@
import asyncio
import contextlib
import functools
import inspect
import re
import sys
import traceback
import types
import typing
import warnings
from enum import Enum
from starlette.concurrency import run_in_threadpool
@ -15,6 +19,11 @@ from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose
if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager # pragma: no cover
else:
from contextlib2 import asynccontextmanager # pragma: no cover
class NoMatchFound(Exception):
"""
@ -470,6 +479,51 @@ class Host(BaseRoute):
)
_T = typing.TypeVar("_T")
class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
def __init__(self, cm: typing.ContextManager[_T]):
self._cm = cm
async def __aenter__(self) -> _T:
return self._cm.__enter__()
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc_value: typing.Optional[BaseException],
traceback: typing.Optional[types.TracebackType],
) -> typing.Optional[bool]:
return self._cm.__exit__(exc_type, exc_value, traceback)
def _wrap_gen_lifespan_context(
lifespan_context: typing.Callable[[typing.Any], typing.Generator]
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
cmgr = contextlib.contextmanager(lifespan_context)
@functools.wraps(cmgr)
def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
return _AsyncLiftContextManager(cmgr(app))
return wrapper
class _DefaultLifespan:
def __init__(self, router: "Router"):
self._router = router
async def __aenter__(self) -> None:
await self._router.startup()
async def __aexit__(self, *exc_info: object) -> None:
await self._router.shutdown()
def __call__(self: _T, app: object) -> _T:
return self
class Router:
def __init__(
self,
@ -478,7 +532,7 @@ class Router:
default: ASGIApp = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None,
lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
@ -486,12 +540,31 @@ class Router:
self.on_startup = [] if on_startup is None else list(on_startup)
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)
async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator:
await self.startup()
yield
await self.shutdown()
if lifespan is None:
self.lifespan_context: typing.Callable[
[typing.Any], typing.AsyncContextManager
] = _DefaultLifespan(self)
self.lifespan_context = default_lifespan if lifespan is None else lifespan
elif inspect.isasyncgenfunction(lifespan):
warnings.warn(
"async generator function lifespans are deprecated, "
"use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
self.lifespan_context = asynccontextmanager(
lifespan, # type: ignore[arg-type]
)
elif inspect.isgeneratorfunction(lifespan):
warnings.warn(
"generator function lifespans are deprecated, "
"use an @contextlib.asynccontextmanager function instead",
DeprecationWarning,
)
self.lifespan_context = _wrap_gen_lifespan_context(
lifespan, # type: ignore[arg-type]
)
else:
self.lifespan_context = lifespan
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
@ -541,25 +614,19 @@ class Router:
Handle ASGI lifespan messages, which allows us to manage application
startup and shutdown events.
"""
first = True
started = False
app = scope.get("app")
await receive()
try:
if inspect.isasyncgenfunction(self.lifespan_context):
async for item in self.lifespan_context(app):
assert first, "Lifespan context yielded multiple times."
first = False
await send({"type": "lifespan.startup.complete"})
await receive()
else:
for item in self.lifespan_context(app): # type: ignore
assert first, "Lifespan context yielded multiple times."
first = False
await send({"type": "lifespan.startup.complete"})
await receive()
async with self.lifespan_context(app):
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
except BaseException:
if first:
exc_text = traceback.format_exc()
exc_text = traceback.format_exc()
if started:
await send({"type": "lifespan.shutdown.failed", "message": exc_text})
else:
await send({"type": "lifespan.startup.failed", "message": exc_text})
raise
else:

View File

@ -543,22 +543,34 @@ class TestClient(requests.Session):
async def wait_startup(self) -> None:
await self.stream_receive.send({"type": "lifespan.startup"})
message = await self.stream_send.receive()
if message is None:
self.task.result()
async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
return message
message = await receive()
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
)
if message["type"] == "lifespan.startup.failed":
message = await self.stream_send.receive()
if message is None:
self.task.result()
await receive()
async def wait_shutdown(self) -> None:
async with self.stream_send:
await self.stream_receive.send({"type": "lifespan.shutdown"})
async def receive() -> typing.Any:
message = await self.stream_send.receive()
if message is None:
self.task.result()
assert message["type"] == "lifespan.shutdown.complete"
return message
async with self.stream_send:
await self.stream_receive.send({"type": "lifespan.shutdown"})
message = await receive()
assert message["type"] in (
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
if message["type"] == "lifespan.shutdown.failed":
await receive()

View File

@ -1,4 +1,5 @@
import os
import sys
import pytest
@ -10,6 +11,11 @@ from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager # pragma: no cover
else:
from contextlib2 import asynccontextmanager # pragma: no cover
app = Starlette()
@ -286,7 +292,39 @@ def test_app_add_event_handler(test_client_factory):
assert cleanup_complete
def test_app_async_lifespan(test_client_factory):
def test_app_async_cm_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False
@asynccontextmanager
async def lifespan(app):
nonlocal startup_complete, cleanup_complete
startup_complete = True
yield
cleanup_complete = True
app = Starlette(lifespan=lifespan)
assert not startup_complete
assert not cleanup_complete
with test_client_factory(app):
assert startup_complete
assert not cleanup_complete
assert startup_complete
assert cleanup_complete
deprecated_lifespan = pytest.mark.filterwarnings(
r"ignore"
r":(async )?generator function lifespans are deprecated, use an "
r"@contextlib\.asynccontextmanager function instead"
r":DeprecationWarning"
r":starlette.routing"
)
@deprecated_lifespan
def test_app_async_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False
@ -307,7 +345,8 @@ def test_app_async_lifespan(test_client_factory):
assert cleanup_complete
def test_app_sync_lifespan(test_client_factory):
@deprecated_lifespan
def test_app_sync_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False

View File

@ -12,10 +12,12 @@ from starlette.middleware import Middleware
from starlette.responses import JSONResponse
from starlette.websockets import WebSocket, WebSocketDisconnect
if sys.version_info >= (3, 7):
from asyncio import current_task as asyncio_current_task # pragma: no cover
else:
asyncio_current_task = asyncio.Task.current_task # pragma: no cover
if sys.version_info >= (3, 7): # pragma: no cover
from asyncio import current_task as asyncio_current_task
from contextlib import asynccontextmanager
else: # pragma: no cover
asyncio_current_task = asyncio.Task.current_task
from contextlib2 import asynccontextmanager
mock_service = Starlette()
@ -90,6 +92,7 @@ def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_nam
shutdown_task = object()
shutdown_loop = None
@asynccontextmanager
async def lifespan_context(app):
nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop