Fix typing of Lifespan to allow subclasses of Starlette (#2077)

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
Adrian Garcia Badaracco 2023-03-13 13:04:15 -05:00 committed by GitHub
parent ada845cc8a
commit f640241b87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 26 additions and 9 deletions

View File

@ -11,6 +11,8 @@ from starlette.responses import Response
from starlette.routing import BaseRoute, Router from starlette.routing import BaseRoute, Router
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
AppType = typing.TypeVar("AppType", bound="Starlette")
class Starlette: class Starlette:
""" """
@ -43,7 +45,7 @@ class Starlette:
""" """
def __init__( def __init__(
self, self: "AppType",
debug: bool = False, debug: bool = False,
routes: typing.Optional[typing.Sequence[BaseRoute]] = None, routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
middleware: typing.Optional[typing.Sequence[Middleware]] = None, middleware: typing.Optional[typing.Sequence[Middleware]] = None,
@ -58,7 +60,7 @@ class Starlette:
] = None, ] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
lifespan: typing.Optional[Lifespan] = None, lifespan: typing.Optional[Lifespan["AppType"]] = None,
) -> None: ) -> None:
# The lifespan context function is a newer style that replaces # The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both. # on_startup / on_shutdown handlers. Use one or the other, not both.

View File

@ -580,7 +580,9 @@ class Router:
default: typing.Optional[ASGIApp] = None, default: typing.Optional[ASGIApp] = None,
on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
lifespan: typing.Optional[Lifespan] = None, # the generic to Lifespan[AppType] is the type of the top level application
# which the router cannot know statically, so we use typing.Any
lifespan: typing.Optional[Lifespan[typing.Any]] = None,
) -> None: ) -> None:
self.routes = [] if routes is None else list(routes) self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes self.redirect_slashes = redirect_slashes

View File

@ -1,7 +1,6 @@
import typing import typing
if typing.TYPE_CHECKING: AppType = typing.TypeVar("AppType")
from starlette.applications import Starlette
Scope = typing.MutableMapping[str, typing.Any] Scope = typing.MutableMapping[str, typing.Any]
Message = typing.MutableMapping[str, typing.Any] Message = typing.MutableMapping[str, typing.Any]
@ -11,8 +10,8 @@ Send = typing.Callable[[Message], typing.Awaitable[None]]
ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
StatelessLifespan = typing.Callable[["Starlette"], typing.AsyncContextManager[None]] StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]]
StatefulLifespan = typing.Callable[ StatefulLifespan = typing.Callable[
["Starlette"], typing.AsyncContextManager[typing.Mapping[str, typing.Any]] [AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
] ]
Lifespan = typing.Union[StatelessLifespan, StatefulLifespan] Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]

View File

@ -1,6 +1,6 @@
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, Callable from typing import Any, AsyncIterator, Callable
import anyio import anyio
import httpx import httpx
@ -534,3 +534,17 @@ def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Cl
test_client_factory(app).get("/foo") test_client_factory(app).get("/foo")
assert SimpleInitializableMiddleware.counter == 2 assert SimpleInitializableMiddleware.counter == 2
def test_lifespan_app_subclass():
# This test exists to make sure that subclasses of Starlette
# (like FastAPI) are compatible with the types hints for Lifespan
class App(Starlette):
pass
@asynccontextmanager
async def lifespan(app: App) -> AsyncIterator[None]: # pragma: no cover
yield
App(lifespan=lifespan)