diff --git a/starlette/applications.py b/starlette/applications.py index 4542f73c..013364be 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -11,6 +11,8 @@ from starlette.responses import Response from starlette.routing import BaseRoute, Router from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send +AppType = typing.TypeVar("AppType", bound="Starlette") + class Starlette: """ @@ -43,7 +45,7 @@ class Starlette: """ def __init__( - self, + self: "AppType", debug: bool = False, routes: typing.Optional[typing.Sequence[BaseRoute]] = None, middleware: typing.Optional[typing.Sequence[Middleware]] = None, @@ -58,7 +60,7 @@ class Starlette: ] = None, on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, - lifespan: typing.Optional[Lifespan] = None, + lifespan: typing.Optional[Lifespan["AppType"]] = None, ) -> None: # The lifespan context function is a newer style that replaces # on_startup / on_shutdown handlers. Use one or the other, not both. diff --git a/starlette/routing.py b/starlette/routing.py index 024c771b..52cf174e 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -580,7 +580,9 @@ class Router: default: typing.Optional[ASGIApp] = None, on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, - lifespan: typing.Optional[Lifespan] = None, + # the generic to Lifespan[AppType] is the type of the top level application + # which the router cannot know statically, so we use typing.Any + lifespan: typing.Optional[Lifespan[typing.Any]] = None, ) -> None: self.routes = [] if routes is None else list(routes) self.redirect_slashes = redirect_slashes diff --git a/starlette/types.py b/starlette/types.py index 05f5446e..713d18a8 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -1,7 +1,6 @@ import typing -if typing.TYPE_CHECKING: - from starlette.applications import Starlette +AppType = typing.TypeVar("AppType") Scope = typing.MutableMapping[str, typing.Any] Message = typing.MutableMapping[str, typing.Any] @@ -11,8 +10,8 @@ Send = typing.Callable[[Message], typing.Awaitable[None]] ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] -StatelessLifespan = typing.Callable[["Starlette"], typing.AsyncContextManager[None]] +StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]] StatefulLifespan = typing.Callable[ - ["Starlette"], typing.AsyncContextManager[typing.Mapping[str, typing.Any]] + [AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]] ] -Lifespan = typing.Union[StatelessLifespan, StatefulLifespan] +Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]] diff --git a/tests/test_applications.py b/tests/test_applications.py index ef3f7900..e30ec929 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,6 +1,6 @@ import os from contextlib import asynccontextmanager -from typing import Any, Callable +from typing import Any, AsyncIterator, Callable import anyio import httpx @@ -534,3 +534,17 @@ def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Cl test_client_factory(app).get("/foo") assert SimpleInitializableMiddleware.counter == 2 + + +def test_lifespan_app_subclass(): + # This test exists to make sure that subclasses of Starlette + # (like FastAPI) are compatible with the types hints for Lifespan + + class App(Starlette): + pass + + @asynccontextmanager + async def lifespan(app: App) -> AsyncIterator[None]: # pragma: no cover + yield + + App(lifespan=lifespan)