mirror of https://github.com/encode/starlette.git
use an async context manager factory for lifespan (#1227)
This commit is contained in:
parent
254d0d97e4
commit
537ab6afd1
1
setup.py
1
setup.py
|
@ -40,6 +40,7 @@ setup(
|
|||
install_requires=[
|
||||
"anyio>=3.0.0,<4",
|
||||
"typing_extensions; python_version < '3.8'",
|
||||
"contextlib2 >= 21.6.0; python_version < '3.7'",
|
||||
],
|
||||
extras_require={
|
||||
"full": [
|
||||
|
|
|
@ -46,7 +46,7 @@ class Starlette:
|
|||
] = None,
|
||||
on_startup: typing.Sequence[typing.Callable] = None,
|
||||
on_shutdown: typing.Sequence[typing.Callable] = None,
|
||||
lifespan: typing.Callable[["Starlette"], typing.AsyncGenerator] = None,
|
||||
lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None,
|
||||
) -> None:
|
||||
# The lifespan context function is a newer style that replaces
|
||||
# on_startup / on_shutdown handlers. Use one or the other, not both.
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
import asyncio
|
||||
import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
import typing
|
||||
import warnings
|
||||
from enum import Enum
|
||||
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
@ -15,6 +19,11 @@ from starlette.responses import PlainTextResponse, RedirectResponse
|
|||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket, WebSocketClose
|
||||
|
||||
if sys.version_info >= (3, 7):
|
||||
from contextlib import asynccontextmanager # pragma: no cover
|
||||
else:
|
||||
from contextlib2 import asynccontextmanager # pragma: no cover
|
||||
|
||||
|
||||
class NoMatchFound(Exception):
|
||||
"""
|
||||
|
@ -470,6 +479,51 @@ class Host(BaseRoute):
|
|||
)
|
||||
|
||||
|
||||
_T = typing.TypeVar("_T")
|
||||
|
||||
|
||||
class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
|
||||
def __init__(self, cm: typing.ContextManager[_T]):
|
||||
self._cm = cm
|
||||
|
||||
async def __aenter__(self) -> _T:
|
||||
return self._cm.__enter__()
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Optional[typing.Type[BaseException]],
|
||||
exc_value: typing.Optional[BaseException],
|
||||
traceback: typing.Optional[types.TracebackType],
|
||||
) -> typing.Optional[bool]:
|
||||
return self._cm.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
|
||||
def _wrap_gen_lifespan_context(
|
||||
lifespan_context: typing.Callable[[typing.Any], typing.Generator]
|
||||
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
|
||||
cmgr = contextlib.contextmanager(lifespan_context)
|
||||
|
||||
@functools.wraps(cmgr)
|
||||
def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
|
||||
return _AsyncLiftContextManager(cmgr(app))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class _DefaultLifespan:
|
||||
def __init__(self, router: "Router"):
|
||||
self._router = router
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self._router.startup()
|
||||
|
||||
async def __aexit__(self, *exc_info: object) -> None:
|
||||
await self._router.shutdown()
|
||||
|
||||
def __call__(self: _T, app: object) -> _T:
|
||||
return self
|
||||
|
||||
|
||||
class Router:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -478,7 +532,7 @@ class Router:
|
|||
default: ASGIApp = None,
|
||||
on_startup: typing.Sequence[typing.Callable] = None,
|
||||
on_shutdown: typing.Sequence[typing.Callable] = None,
|
||||
lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None,
|
||||
lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None,
|
||||
) -> None:
|
||||
self.routes = [] if routes is None else list(routes)
|
||||
self.redirect_slashes = redirect_slashes
|
||||
|
@ -486,12 +540,31 @@ class Router:
|
|||
self.on_startup = [] if on_startup is None else list(on_startup)
|
||||
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)
|
||||
|
||||
async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator:
|
||||
await self.startup()
|
||||
yield
|
||||
await self.shutdown()
|
||||
if lifespan is None:
|
||||
self.lifespan_context: typing.Callable[
|
||||
[typing.Any], typing.AsyncContextManager
|
||||
] = _DefaultLifespan(self)
|
||||
|
||||
self.lifespan_context = default_lifespan if lifespan is None else lifespan
|
||||
elif inspect.isasyncgenfunction(lifespan):
|
||||
warnings.warn(
|
||||
"async generator function lifespans are deprecated, "
|
||||
"use an @contextlib.asynccontextmanager function instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.lifespan_context = asynccontextmanager(
|
||||
lifespan, # type: ignore[arg-type]
|
||||
)
|
||||
elif inspect.isgeneratorfunction(lifespan):
|
||||
warnings.warn(
|
||||
"generator function lifespans are deprecated, "
|
||||
"use an @contextlib.asynccontextmanager function instead",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.lifespan_context = _wrap_gen_lifespan_context(
|
||||
lifespan, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
self.lifespan_context = lifespan
|
||||
|
||||
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] == "websocket":
|
||||
|
@ -541,25 +614,19 @@ class Router:
|
|||
Handle ASGI lifespan messages, which allows us to manage application
|
||||
startup and shutdown events.
|
||||
"""
|
||||
first = True
|
||||
started = False
|
||||
app = scope.get("app")
|
||||
await receive()
|
||||
try:
|
||||
if inspect.isasyncgenfunction(self.lifespan_context):
|
||||
async for item in self.lifespan_context(app):
|
||||
assert first, "Lifespan context yielded multiple times."
|
||||
first = False
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
await receive()
|
||||
else:
|
||||
for item in self.lifespan_context(app): # type: ignore
|
||||
assert first, "Lifespan context yielded multiple times."
|
||||
first = False
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
await receive()
|
||||
async with self.lifespan_context(app):
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
started = True
|
||||
await receive()
|
||||
except BaseException:
|
||||
if first:
|
||||
exc_text = traceback.format_exc()
|
||||
exc_text = traceback.format_exc()
|
||||
if started:
|
||||
await send({"type": "lifespan.shutdown.failed", "message": exc_text})
|
||||
else:
|
||||
await send({"type": "lifespan.startup.failed", "message": exc_text})
|
||||
raise
|
||||
else:
|
||||
|
|
|
@ -543,22 +543,34 @@ class TestClient(requests.Session):
|
|||
|
||||
async def wait_startup(self) -> None:
|
||||
await self.stream_receive.send({"type": "lifespan.startup"})
|
||||
message = await self.stream_send.receive()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
|
||||
async def receive() -> typing.Any:
|
||||
message = await self.stream_send.receive()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
return message
|
||||
|
||||
message = await receive()
|
||||
assert message["type"] in (
|
||||
"lifespan.startup.complete",
|
||||
"lifespan.startup.failed",
|
||||
)
|
||||
if message["type"] == "lifespan.startup.failed":
|
||||
message = await self.stream_send.receive()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
await receive()
|
||||
|
||||
async def wait_shutdown(self) -> None:
|
||||
async with self.stream_send:
|
||||
await self.stream_receive.send({"type": "lifespan.shutdown"})
|
||||
async def receive() -> typing.Any:
|
||||
message = await self.stream_send.receive()
|
||||
if message is None:
|
||||
self.task.result()
|
||||
assert message["type"] == "lifespan.shutdown.complete"
|
||||
return message
|
||||
|
||||
async with self.stream_send:
|
||||
await self.stream_receive.send({"type": "lifespan.shutdown"})
|
||||
message = await receive()
|
||||
assert message["type"] in (
|
||||
"lifespan.shutdown.complete",
|
||||
"lifespan.shutdown.failed",
|
||||
)
|
||||
if message["type"] == "lifespan.shutdown.failed":
|
||||
await receive()
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -10,6 +11,11 @@ from starlette.responses import JSONResponse, PlainTextResponse
|
|||
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
|
||||
from starlette.staticfiles import StaticFiles
|
||||
|
||||
if sys.version_info >= (3, 7):
|
||||
from contextlib import asynccontextmanager # pragma: no cover
|
||||
else:
|
||||
from contextlib2 import asynccontextmanager # pragma: no cover
|
||||
|
||||
app = Starlette()
|
||||
|
||||
|
||||
|
@ -286,7 +292,39 @@ def test_app_add_event_handler(test_client_factory):
|
|||
assert cleanup_complete
|
||||
|
||||
|
||||
def test_app_async_lifespan(test_client_factory):
|
||||
def test_app_async_cm_lifespan(test_client_factory):
|
||||
startup_complete = False
|
||||
cleanup_complete = False
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app):
|
||||
nonlocal startup_complete, cleanup_complete
|
||||
startup_complete = True
|
||||
yield
|
||||
cleanup_complete = True
|
||||
|
||||
app = Starlette(lifespan=lifespan)
|
||||
|
||||
assert not startup_complete
|
||||
assert not cleanup_complete
|
||||
with test_client_factory(app):
|
||||
assert startup_complete
|
||||
assert not cleanup_complete
|
||||
assert startup_complete
|
||||
assert cleanup_complete
|
||||
|
||||
|
||||
deprecated_lifespan = pytest.mark.filterwarnings(
|
||||
r"ignore"
|
||||
r":(async )?generator function lifespans are deprecated, use an "
|
||||
r"@contextlib\.asynccontextmanager function instead"
|
||||
r":DeprecationWarning"
|
||||
r":starlette.routing"
|
||||
)
|
||||
|
||||
|
||||
@deprecated_lifespan
|
||||
def test_app_async_gen_lifespan(test_client_factory):
|
||||
startup_complete = False
|
||||
cleanup_complete = False
|
||||
|
||||
|
@ -307,7 +345,8 @@ def test_app_async_lifespan(test_client_factory):
|
|||
assert cleanup_complete
|
||||
|
||||
|
||||
def test_app_sync_lifespan(test_client_factory):
|
||||
@deprecated_lifespan
|
||||
def test_app_sync_gen_lifespan(test_client_factory):
|
||||
startup_complete = False
|
||||
cleanup_complete = False
|
||||
|
||||
|
|
|
@ -12,10 +12,12 @@ from starlette.middleware import Middleware
|
|||
from starlette.responses import JSONResponse
|
||||
from starlette.websockets import WebSocket, WebSocketDisconnect
|
||||
|
||||
if sys.version_info >= (3, 7):
|
||||
from asyncio import current_task as asyncio_current_task # pragma: no cover
|
||||
else:
|
||||
asyncio_current_task = asyncio.Task.current_task # pragma: no cover
|
||||
if sys.version_info >= (3, 7): # pragma: no cover
|
||||
from asyncio import current_task as asyncio_current_task
|
||||
from contextlib import asynccontextmanager
|
||||
else: # pragma: no cover
|
||||
asyncio_current_task = asyncio.Task.current_task
|
||||
from contextlib2 import asynccontextmanager
|
||||
|
||||
mock_service = Starlette()
|
||||
|
||||
|
@ -90,6 +92,7 @@ def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_nam
|
|||
shutdown_task = object()
|
||||
shutdown_loop = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan_context(app):
|
||||
nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop
|
||||
|
||||
|
|
Loading…
Reference in New Issue