From 537ab6afd110b79dca95f3c8ecc6980710b1de1c Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 3 Jul 2021 18:43:24 +0100 Subject: [PATCH] use an async context manager factory for lifespan (#1227) --- setup.py | 1 + starlette/applications.py | 2 +- starlette/routing.py | 109 ++++++++++++++++++++++++++++++------- starlette/testclient.py | 30 +++++++--- tests/test_applications.py | 43 ++++++++++++++- tests/test_testclient.py | 11 ++-- 6 files changed, 159 insertions(+), 37 deletions(-) diff --git a/setup.py b/setup.py index ac647974..31789fe0 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ setup( install_requires=[ "anyio>=3.0.0,<4", "typing_extensions; python_version < '3.8'", + "contextlib2 >= 21.6.0; python_version < '3.7'", ], extras_require={ "full": [ diff --git a/starlette/applications.py b/starlette/applications.py index 34c3e38b..ea52ee70 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -46,7 +46,7 @@ class Starlette: ] = None, on_startup: typing.Sequence[typing.Callable] = None, on_shutdown: typing.Sequence[typing.Callable] = None, - lifespan: typing.Callable[["Starlette"], typing.AsyncGenerator] = None, + lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None, ) -> None: # The lifespan context function is a newer style that replaces # on_startup / on_shutdown handlers. Use one or the other, not both. diff --git a/starlette/routing.py b/starlette/routing.py index cef1ef48..9a1a5e12 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,9 +1,13 @@ import asyncio +import contextlib import functools import inspect import re +import sys import traceback +import types import typing +import warnings from enum import Enum from starlette.concurrency import run_in_threadpool @@ -15,6 +19,11 @@ from starlette.responses import PlainTextResponse, RedirectResponse from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose +if sys.version_info >= (3, 7): + from contextlib import asynccontextmanager # pragma: no cover +else: + from contextlib2 import asynccontextmanager # pragma: no cover + class NoMatchFound(Exception): """ @@ -470,6 +479,51 @@ class Host(BaseRoute): ) +_T = typing.TypeVar("_T") + + +class _AsyncLiftContextManager(typing.AsyncContextManager[_T]): + def __init__(self, cm: typing.ContextManager[_T]): + self._cm = cm + + async def __aenter__(self) -> _T: + return self._cm.__enter__() + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]], + exc_value: typing.Optional[BaseException], + traceback: typing.Optional[types.TracebackType], + ) -> typing.Optional[bool]: + return self._cm.__exit__(exc_type, exc_value, traceback) + + +def _wrap_gen_lifespan_context( + lifespan_context: typing.Callable[[typing.Any], typing.Generator] +) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: + cmgr = contextlib.contextmanager(lifespan_context) + + @functools.wraps(cmgr) + def wrapper(app: typing.Any) -> _AsyncLiftContextManager: + return _AsyncLiftContextManager(cmgr(app)) + + return wrapper + + +class _DefaultLifespan: + def __init__(self, router: "Router"): + self._router = router + + async def __aenter__(self) -> None: + await self._router.startup() + + async def __aexit__(self, *exc_info: object) -> None: + await self._router.shutdown() + + def __call__(self: _T, app: object) -> _T: + return self + + class Router: def __init__( self, @@ -478,7 +532,7 @@ class Router: default: ASGIApp = None, on_startup: typing.Sequence[typing.Callable] = None, on_shutdown: typing.Sequence[typing.Callable] = None, - lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None, + lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None, ) -> None: self.routes = [] if routes is None else list(routes) self.redirect_slashes = redirect_slashes @@ -486,12 +540,31 @@ class Router: self.on_startup = [] if on_startup is None else list(on_startup) self.on_shutdown = [] if on_shutdown is None else list(on_shutdown) - async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator: - await self.startup() - yield - await self.shutdown() + if lifespan is None: + self.lifespan_context: typing.Callable[ + [typing.Any], typing.AsyncContextManager + ] = _DefaultLifespan(self) - self.lifespan_context = default_lifespan if lifespan is None else lifespan + elif inspect.isasyncgenfunction(lifespan): + warnings.warn( + "async generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, + ) + self.lifespan_context = asynccontextmanager( + lifespan, # type: ignore[arg-type] + ) + elif inspect.isgeneratorfunction(lifespan): + warnings.warn( + "generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, + ) + self.lifespan_context = _wrap_gen_lifespan_context( + lifespan, # type: ignore[arg-type] + ) + else: + self.lifespan_context = lifespan async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "websocket": @@ -541,25 +614,19 @@ class Router: Handle ASGI lifespan messages, which allows us to manage application startup and shutdown events. """ - first = True + started = False app = scope.get("app") await receive() try: - if inspect.isasyncgenfunction(self.lifespan_context): - async for item in self.lifespan_context(app): - assert first, "Lifespan context yielded multiple times." - first = False - await send({"type": "lifespan.startup.complete"}) - await receive() - else: - for item in self.lifespan_context(app): # type: ignore - assert first, "Lifespan context yielded multiple times." - first = False - await send({"type": "lifespan.startup.complete"}) - await receive() + async with self.lifespan_context(app): + await send({"type": "lifespan.startup.complete"}) + started = True + await receive() except BaseException: - if first: - exc_text = traceback.format_exc() + exc_text = traceback.format_exc() + if started: + await send({"type": "lifespan.shutdown.failed", "message": exc_text}) + else: await send({"type": "lifespan.startup.failed", "message": exc_text}) raise else: diff --git a/starlette/testclient.py b/starlette/testclient.py index 7aa59fb9..08d03fa5 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -543,22 +543,34 @@ class TestClient(requests.Session): async def wait_startup(self) -> None: await self.stream_receive.send({"type": "lifespan.startup"}) - message = await self.stream_send.receive() - if message is None: - self.task.result() + + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + self.task.result() + return message + + message = await receive() assert message["type"] in ( "lifespan.startup.complete", "lifespan.startup.failed", ) if message["type"] == "lifespan.startup.failed": - message = await self.stream_send.receive() - if message is None: - self.task.result() + await receive() async def wait_shutdown(self) -> None: - async with self.stream_send: - await self.stream_receive.send({"type": "lifespan.shutdown"}) + async def receive() -> typing.Any: message = await self.stream_send.receive() if message is None: self.task.result() - assert message["type"] == "lifespan.shutdown.complete" + return message + + async with self.stream_send: + await self.stream_receive.send({"type": "lifespan.shutdown"}) + message = await receive() + assert message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ) + if message["type"] == "lifespan.shutdown.failed": + await receive() diff --git a/tests/test_applications.py b/tests/test_applications.py index 6cb49069..f5f4c7fb 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,4 +1,5 @@ import os +import sys import pytest @@ -10,6 +11,11 @@ from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles +if sys.version_info >= (3, 7): + from contextlib import asynccontextmanager # pragma: no cover +else: + from contextlib2 import asynccontextmanager # pragma: no cover + app = Starlette() @@ -286,7 +292,39 @@ def test_app_add_event_handler(test_client_factory): assert cleanup_complete -def test_app_async_lifespan(test_client_factory): +def test_app_async_cm_lifespan(test_client_factory): + startup_complete = False + cleanup_complete = False + + @asynccontextmanager + async def lifespan(app): + nonlocal startup_complete, cleanup_complete + startup_complete = True + yield + cleanup_complete = True + + app = Starlette(lifespan=lifespan) + + assert not startup_complete + assert not cleanup_complete + with test_client_factory(app): + assert startup_complete + assert not cleanup_complete + assert startup_complete + assert cleanup_complete + + +deprecated_lifespan = pytest.mark.filterwarnings( + r"ignore" + r":(async )?generator function lifespans are deprecated, use an " + r"@contextlib\.asynccontextmanager function instead" + r":DeprecationWarning" + r":starlette.routing" +) + + +@deprecated_lifespan +def test_app_async_gen_lifespan(test_client_factory): startup_complete = False cleanup_complete = False @@ -307,7 +345,8 @@ def test_app_async_lifespan(test_client_factory): assert cleanup_complete -def test_app_sync_lifespan(test_client_factory): +@deprecated_lifespan +def test_app_sync_gen_lifespan(test_client_factory): startup_complete = False cleanup_complete = False diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 57ea1c3d..8c066678 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -12,10 +12,12 @@ from starlette.middleware import Middleware from starlette.responses import JSONResponse from starlette.websockets import WebSocket, WebSocketDisconnect -if sys.version_info >= (3, 7): - from asyncio import current_task as asyncio_current_task # pragma: no cover -else: - asyncio_current_task = asyncio.Task.current_task # pragma: no cover +if sys.version_info >= (3, 7): # pragma: no cover + from asyncio import current_task as asyncio_current_task + from contextlib import asynccontextmanager +else: # pragma: no cover + asyncio_current_task = asyncio.Task.current_task + from contextlib2 import asynccontextmanager mock_service = Starlette() @@ -90,6 +92,7 @@ def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_nam shutdown_task = object() shutdown_loop = None + @asynccontextmanager async def lifespan_context(app): nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop