diff --git a/scripts/lint b/scripts/lint index eceead71..00a21868 100755 --- a/scripts/lint +++ b/scripts/lint @@ -11,3 +11,4 @@ ${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs ${PREFIX}autoflake --in-place --recursive starlette tests ${PREFIX}black starlette tests ${PREFIX}isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --apply starlette tests +${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs diff --git a/starlette/applications.py b/starlette/applications.py index 7d558705..a6175d99 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -23,7 +23,6 @@ class Starlette: self.error_middleware = ServerErrorMiddleware( self.exception_middleware, debug=debug ) - self.lifespan_middleware = LifespanMiddleware(self.error_middleware) self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator] if template_directory is not None: from starlette.templating import Jinja2Templates @@ -53,7 +52,7 @@ class Starlette: return self.schema_generator.get_schema(self.routes) def on_event(self, event_type: str) -> typing.Callable: - return self.lifespan_middleware.on_event(event_type) + return self.router.lifespan.on_event(event_type) def mount(self, path: str, app: ASGIApp, name: str = None) -> None: self.router.mount(path, app=app, name=name) @@ -79,7 +78,7 @@ class Starlette: ) def add_event_handler(self, event_type: str, func: typing.Callable) -> None: - self.lifespan_middleware.add_event_handler(event_type, func) + self.router.lifespan.add_event_handler(event_type, func) def add_route( self, @@ -149,4 +148,4 @@ class Starlette: def __call__(self, scope: Scope) -> ASGIInstance: scope["app"] = self - return self.lifespan_middleware(scope) + return self.error_middleware(scope) diff --git a/starlette/middleware/lifespan.py b/starlette/middleware/lifespan.py index 3c7faae9..631588ad 100644 --- a/starlette/middleware/lifespan.py +++ b/starlette/middleware/lifespan.py @@ -38,7 +38,7 @@ class LifespanMiddleware: return LifespanHandler( self.app, scope, self.startup_handlers, self.shutdown_handlers ) - return self.app(scope) + return self.app(scope) # pragma: no cover class LifespanHandler: diff --git a/starlette/routing.py b/starlette/routing.py index 9fa29794..a05b206c 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -286,11 +286,10 @@ class Mount(BaseRoute): assert path == "" or path.startswith("/"), "Routed paths must start with '/'" assert ( app is not None or routes is not None - ), "Either 'app', or 'routes' must be specified" + ), "Either 'app=...', or 'routes=' must be specified" self.path = path.rstrip("/") - if routes is None: - assert app is not None - self.app = app + if app is not None: + self.app = app # type: ASGIApp else: self.app = Router(routes=routes) self.name = name @@ -303,23 +302,24 @@ class Mount(BaseRoute): return getattr(self.app, "routes", None) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: - path = scope["path"] - match = self.path_regex.match(path) - if match: - matched_params = match.groupdict() - for key, value in matched_params.items(): - matched_params[key] = self.param_convertors[key].convert(value) - remaining_path = "/" + matched_params.pop("path") - matched_path = path[: -len(remaining_path)] - path_params = dict(scope.get("path_params", {})) - path_params.update(matched_params) - child_scope = { - "path_params": path_params, - "root_path": scope.get("root_path", "") + matched_path, - "path": remaining_path, - "endpoint": self.app, - } - return Match.FULL, child_scope + if scope["type"] in ("http", "websocket"): + path = scope["path"] + match = self.path_regex.match(path) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + remaining_path = "/" + matched_params.pop("path") + matched_path = path[: -len(remaining_path)] + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + child_scope = { + "path_params": path_params, + "root_path": scope.get("root_path", "") + matched_path, + "path": remaining_path, + "endpoint": self.app, + } + return Match.FULL, child_scope return Match.NONE, {} def url_path_for(self, name: str, **path_params: str) -> URLPath: @@ -375,17 +375,18 @@ class Host(BaseRoute): return getattr(self.app, "routes", None) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: - headers = Headers(scope=scope) - host = headers.get("host", "").split(":")[0] - match = self.host_regex.match(host) - if match: - matched_params = match.groupdict() - for key, value in matched_params.items(): - matched_params[key] = self.param_convertors[key].convert(value) - path_params = dict(scope.get("path_params", {})) - path_params.update(matched_params) - child_scope = {"path_params": path_params, "endpoint": self.app} - return Match.FULL, child_scope + if scope["type"] in ("http", "websocket"): + headers = Headers(scope=scope) + host = headers.get("host", "").split(":")[0] + match = self.host_regex.match(host) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + child_scope = {"path_params": path_params, "endpoint": self.app} + return Match.FULL, child_scope return Match.NONE, {} def url_path_for(self, name: str, **path_params: str) -> URLPath: @@ -426,6 +427,63 @@ class Host(BaseRoute): ) +class Lifespan(BaseRoute): + def __init__( + self, on_startup: typing.Callable = None, on_shutdown: typing.Callable = None + ): + self.startup_handlers = [] if on_startup is None else [on_startup] + self.shutdown_handlers = [] if on_shutdown is None else [on_shutdown] + + def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: + if scope["type"] == "lifespan": + return Match.FULL, {} + return Match.NONE, {} + + def add_event_handler(self, event_type: str, func: typing.Callable) -> None: + assert event_type in ("startup", "shutdown") + + if event_type == "startup": + self.startup_handlers.append(func) + else: + assert event_type == "shutdown" + self.shutdown_handlers.append(func) + + def on_event(self, event_type: str) -> typing.Callable: + def decorator(func: typing.Callable) -> typing.Callable: + self.add_event_handler(event_type, func) + return func + + return decorator + + async def startup(self) -> None: + for handler in self.startup_handlers: + if asyncio.iscoroutinefunction(handler): + await handler() + else: + handler() + + async def shutdown(self) -> None: + for handler in self.shutdown_handlers: + if asyncio.iscoroutinefunction(handler): + await handler() + else: + handler() + + def __call__(self, scope: Scope) -> ASGIInstance: + return self.asgi + + async def asgi(self, receive: Receive, send: Send) -> None: + message = await receive() + assert message["type"] == "lifespan.startup" + await self.startup() + await send({"type": "lifespan.startup.complete"}) + + message = await receive() + assert message["type"] == "lifespan.shutdown" + await self.shutdown() + await send({"type": "lifespan.shutdown.complete"}) + + class Router: def __init__( self, @@ -436,6 +494,7 @@ class Router: self.routes = [] if routes is None else list(routes) self.redirect_slashes = redirect_slashes self.default = self.not_found if default is None else default + self.lifespan = Lifespan() def mount(self, path: str, app: ASGIApp, name: str = None) -> None: route = Mount(path, app=app, name=name) @@ -516,9 +575,6 @@ class Router: def __call__(self, scope: Scope) -> ASGIInstance: assert scope["type"] in ("http", "websocket", "lifespan") - if scope["type"] == "lifespan": - return LifespanHandler(scope) - if "router" not in scope: scope["router"] = self @@ -537,31 +593,20 @@ class Router: scope.update(partial_scope) return partial(scope) - if self.redirect_slashes and not scope["path"].endswith("/"): - redirect_scope = dict(scope) - redirect_scope["path"] += "/" + if scope["type"] == "http" and self.redirect_slashes: + if not scope["path"].endswith("/"): + redirect_scope = dict(scope) + redirect_scope["path"] += "/" - for route in self.routes: - match, child_scope = route.matches(redirect_scope) - if match != Match.NONE: - redirect_url = URL(scope=redirect_scope) - return RedirectResponse(url=str(redirect_url)) + for route in self.routes: + match, child_scope = route.matches(redirect_scope) + if match != Match.NONE: + redirect_url = URL(scope=redirect_scope) + return RedirectResponse(url=str(redirect_url)) + if scope["type"] == "lifespan": + return self.lifespan(scope) return self.default(scope) def __eq__(self, other: typing.Any) -> bool: return isinstance(other, Router) and self.routes == other.routes - - -class LifespanHandler: - def __init__(self, scope: Scope) -> None: - pass - - async def __call__(self, receive: Receive, send: Send) -> None: - message = await receive() - assert message["type"] == "lifespan.startup" - await send({"type": "lifespan.startup.complete"}) - - message = await receive() - assert message["type"] == "lifespan.shutdown" - await send({"type": "lifespan.shutdown.complete"}) diff --git a/tests/middleware/test_lifespan.py b/tests/middleware/test_lifespan.py index 475c1f5b..0c37af9f 100644 --- a/tests/middleware/test_lifespan.py +++ b/tests/middleware/test_lifespan.py @@ -2,6 +2,8 @@ import pytest from starlette.applications import Starlette from starlette.middleware.lifespan import LifespanMiddleware +from starlette.responses import PlainTextResponse +from starlette.routing import Lifespan, Route, Router from starlette.testclient import TestClient @@ -98,6 +100,38 @@ def test_raise_on_shutdown(): pass +def test_routed_lifespan(): + startup_complete = False + shutdown_complete = False + + def hello_world(request): + return PlainTextResponse("hello, world") + + def run_startup(): + nonlocal startup_complete + startup_complete = True + + def run_shutdown(): + nonlocal shutdown_complete + shutdown_complete = True + + app = Router( + routes=[ + Lifespan(on_startup=run_startup, on_shutdown=run_shutdown), + Route("/", hello_world), + ] + ) + + assert not startup_complete + assert not shutdown_complete + with TestClient(app) as client: + assert startup_complete + assert not shutdown_complete + client.get("/") + assert startup_complete + assert shutdown_complete + + def test_app_lifespan(): startup_complete = False cleanup_complete = False @@ -120,3 +154,27 @@ def test_app_lifespan(): assert not cleanup_complete assert startup_complete assert cleanup_complete + + +def test_app_async_lifespan(): + startup_complete = False + cleanup_complete = False + app = Starlette() + + @app.on_event("startup") + async def run_startup(): + nonlocal startup_complete + startup_complete = True + + @app.on_event("shutdown") + async def run_cleanup(): + nonlocal cleanup_complete + cleanup_complete = True + + assert not startup_complete + assert not cleanup_complete + with TestClient(app): + assert startup_complete + assert not cleanup_complete + assert startup_complete + assert cleanup_complete