mirror of https://github.com/encode/starlette.git
Lifespan route instance (#401)
* Add Mount(routes=...) * Lifespan as a standard routing component * Linting
This commit is contained in:
parent
933d786627
commit
06adecd6ed
|
@ -11,3 +11,4 @@ ${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs
|
|||
${PREFIX}autoflake --in-place --recursive starlette tests
|
||||
${PREFIX}black starlette tests
|
||||
${PREFIX}isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --apply starlette tests
|
||||
${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs
|
||||
|
|
|
@ -23,7 +23,6 @@ class Starlette:
|
|||
self.error_middleware = ServerErrorMiddleware(
|
||||
self.exception_middleware, debug=debug
|
||||
)
|
||||
self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
|
||||
self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator]
|
||||
if template_directory is not None:
|
||||
from starlette.templating import Jinja2Templates
|
||||
|
@ -53,7 +52,7 @@ class Starlette:
|
|||
return self.schema_generator.get_schema(self.routes)
|
||||
|
||||
def on_event(self, event_type: str) -> typing.Callable:
|
||||
return self.lifespan_middleware.on_event(event_type)
|
||||
return self.router.lifespan.on_event(event_type)
|
||||
|
||||
def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
|
||||
self.router.mount(path, app=app, name=name)
|
||||
|
@ -79,7 +78,7 @@ class Starlette:
|
|||
)
|
||||
|
||||
def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
|
||||
self.lifespan_middleware.add_event_handler(event_type, func)
|
||||
self.router.lifespan.add_event_handler(event_type, func)
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
|
@ -149,4 +148,4 @@ class Starlette:
|
|||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
scope["app"] = self
|
||||
return self.lifespan_middleware(scope)
|
||||
return self.error_middleware(scope)
|
||||
|
|
|
@ -38,7 +38,7 @@ class LifespanMiddleware:
|
|||
return LifespanHandler(
|
||||
self.app, scope, self.startup_handlers, self.shutdown_handlers
|
||||
)
|
||||
return self.app(scope)
|
||||
return self.app(scope) # pragma: no cover
|
||||
|
||||
|
||||
class LifespanHandler:
|
||||
|
|
|
@ -286,11 +286,10 @@ class Mount(BaseRoute):
|
|||
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
|
||||
assert (
|
||||
app is not None or routes is not None
|
||||
), "Either 'app', or 'routes' must be specified"
|
||||
), "Either 'app=...', or 'routes=' must be specified"
|
||||
self.path = path.rstrip("/")
|
||||
if routes is None:
|
||||
assert app is not None
|
||||
self.app = app
|
||||
if app is not None:
|
||||
self.app = app # type: ASGIApp
|
||||
else:
|
||||
self.app = Router(routes=routes)
|
||||
self.name = name
|
||||
|
@ -303,23 +302,24 @@ class Mount(BaseRoute):
|
|||
return getattr(self.app, "routes", None)
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
path = scope["path"]
|
||||
match = self.path_regex.match(path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
remaining_path = "/" + matched_params.pop("path")
|
||||
matched_path = path[: -len(remaining_path)]
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
child_scope = {
|
||||
"path_params": path_params,
|
||||
"root_path": scope.get("root_path", "") + matched_path,
|
||||
"path": remaining_path,
|
||||
"endpoint": self.app,
|
||||
}
|
||||
return Match.FULL, child_scope
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
path = scope["path"]
|
||||
match = self.path_regex.match(path)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
remaining_path = "/" + matched_params.pop("path")
|
||||
matched_path = path[: -len(remaining_path)]
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
child_scope = {
|
||||
"path_params": path_params,
|
||||
"root_path": scope.get("root_path", "") + matched_path,
|
||||
"path": remaining_path,
|
||||
"endpoint": self.app,
|
||||
}
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
|
@ -375,17 +375,18 @@ class Host(BaseRoute):
|
|||
return getattr(self.app, "routes", None)
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host", "").split(":")[0]
|
||||
match = self.host_regex.match(host)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
child_scope = {"path_params": path_params, "endpoint": self.app}
|
||||
return Match.FULL, child_scope
|
||||
if scope["type"] in ("http", "websocket"):
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get("host", "").split(":")[0]
|
||||
match = self.host_regex.match(host)
|
||||
if match:
|
||||
matched_params = match.groupdict()
|
||||
for key, value in matched_params.items():
|
||||
matched_params[key] = self.param_convertors[key].convert(value)
|
||||
path_params = dict(scope.get("path_params", {}))
|
||||
path_params.update(matched_params)
|
||||
child_scope = {"path_params": path_params, "endpoint": self.app}
|
||||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
|
@ -426,6 +427,63 @@ class Host(BaseRoute):
|
|||
)
|
||||
|
||||
|
||||
class Lifespan(BaseRoute):
|
||||
def __init__(
|
||||
self, on_startup: typing.Callable = None, on_shutdown: typing.Callable = None
|
||||
):
|
||||
self.startup_handlers = [] if on_startup is None else [on_startup]
|
||||
self.shutdown_handlers = [] if on_shutdown is None else [on_shutdown]
|
||||
|
||||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
if scope["type"] == "lifespan":
|
||||
return Match.FULL, {}
|
||||
return Match.NONE, {}
|
||||
|
||||
def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
|
||||
assert event_type in ("startup", "shutdown")
|
||||
|
||||
if event_type == "startup":
|
||||
self.startup_handlers.append(func)
|
||||
else:
|
||||
assert event_type == "shutdown"
|
||||
self.shutdown_handlers.append(func)
|
||||
|
||||
def on_event(self, event_type: str) -> typing.Callable:
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
self.add_event_handler(event_type, func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def startup(self) -> None:
|
||||
for handler in self.startup_handlers:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler()
|
||||
else:
|
||||
handler()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for handler in self.shutdown_handlers:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler()
|
||||
else:
|
||||
handler()
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
return self.asgi
|
||||
|
||||
async def asgi(self, receive: Receive, send: Send) -> None:
|
||||
message = await receive()
|
||||
assert message["type"] == "lifespan.startup"
|
||||
await self.startup()
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
|
||||
message = await receive()
|
||||
assert message["type"] == "lifespan.shutdown"
|
||||
await self.shutdown()
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
||||
|
||||
class Router:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -436,6 +494,7 @@ class Router:
|
|||
self.routes = [] if routes is None else list(routes)
|
||||
self.redirect_slashes = redirect_slashes
|
||||
self.default = self.not_found if default is None else default
|
||||
self.lifespan = Lifespan()
|
||||
|
||||
def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
|
||||
route = Mount(path, app=app, name=name)
|
||||
|
@ -516,9 +575,6 @@ class Router:
|
|||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
assert scope["type"] in ("http", "websocket", "lifespan")
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
return LifespanHandler(scope)
|
||||
|
||||
if "router" not in scope:
|
||||
scope["router"] = self
|
||||
|
||||
|
@ -537,31 +593,20 @@ class Router:
|
|||
scope.update(partial_scope)
|
||||
return partial(scope)
|
||||
|
||||
if self.redirect_slashes and not scope["path"].endswith("/"):
|
||||
redirect_scope = dict(scope)
|
||||
redirect_scope["path"] += "/"
|
||||
if scope["type"] == "http" and self.redirect_slashes:
|
||||
if not scope["path"].endswith("/"):
|
||||
redirect_scope = dict(scope)
|
||||
redirect_scope["path"] += "/"
|
||||
|
||||
for route in self.routes:
|
||||
match, child_scope = route.matches(redirect_scope)
|
||||
if match != Match.NONE:
|
||||
redirect_url = URL(scope=redirect_scope)
|
||||
return RedirectResponse(url=str(redirect_url))
|
||||
for route in self.routes:
|
||||
match, child_scope = route.matches(redirect_scope)
|
||||
if match != Match.NONE:
|
||||
redirect_url = URL(scope=redirect_scope)
|
||||
return RedirectResponse(url=str(redirect_url))
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
return self.lifespan(scope)
|
||||
return self.default(scope)
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return isinstance(other, Router) and self.routes == other.routes
|
||||
|
||||
|
||||
class LifespanHandler:
|
||||
def __init__(self, scope: Scope) -> None:
|
||||
pass
|
||||
|
||||
async def __call__(self, receive: Receive, send: Send) -> None:
|
||||
message = await receive()
|
||||
assert message["type"] == "lifespan.startup"
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
|
||||
message = await receive()
|
||||
assert message["type"] == "lifespan.shutdown"
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
|
|
@ -2,6 +2,8 @@ import pytest
|
|||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware.lifespan import LifespanMiddleware
|
||||
from starlette.responses import PlainTextResponse
|
||||
from starlette.routing import Lifespan, Route, Router
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
|
||||
|
@ -98,6 +100,38 @@ def test_raise_on_shutdown():
|
|||
pass
|
||||
|
||||
|
||||
def test_routed_lifespan():
|
||||
startup_complete = False
|
||||
shutdown_complete = False
|
||||
|
||||
def hello_world(request):
|
||||
return PlainTextResponse("hello, world")
|
||||
|
||||
def run_startup():
|
||||
nonlocal startup_complete
|
||||
startup_complete = True
|
||||
|
||||
def run_shutdown():
|
||||
nonlocal shutdown_complete
|
||||
shutdown_complete = True
|
||||
|
||||
app = Router(
|
||||
routes=[
|
||||
Lifespan(on_startup=run_startup, on_shutdown=run_shutdown),
|
||||
Route("/", hello_world),
|
||||
]
|
||||
)
|
||||
|
||||
assert not startup_complete
|
||||
assert not shutdown_complete
|
||||
with TestClient(app) as client:
|
||||
assert startup_complete
|
||||
assert not shutdown_complete
|
||||
client.get("/")
|
||||
assert startup_complete
|
||||
assert shutdown_complete
|
||||
|
||||
|
||||
def test_app_lifespan():
|
||||
startup_complete = False
|
||||
cleanup_complete = False
|
||||
|
@ -120,3 +154,27 @@ def test_app_lifespan():
|
|||
assert not cleanup_complete
|
||||
assert startup_complete
|
||||
assert cleanup_complete
|
||||
|
||||
|
||||
def test_app_async_lifespan():
|
||||
startup_complete = False
|
||||
cleanup_complete = False
|
||||
app = Starlette()
|
||||
|
||||
@app.on_event("startup")
|
||||
async def run_startup():
|
||||
nonlocal startup_complete
|
||||
startup_complete = True
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def run_cleanup():
|
||||
nonlocal cleanup_complete
|
||||
cleanup_complete = True
|
||||
|
||||
assert not startup_complete
|
||||
assert not cleanup_complete
|
||||
with TestClient(app):
|
||||
assert startup_complete
|
||||
assert not cleanup_complete
|
||||
assert startup_complete
|
||||
assert cleanup_complete
|
||||
|
|
Loading…
Reference in New Issue