Lifespan route instance (#401)

* Add Mount(routes=...)
* Lifespan as a standard routing component
* Linting
This commit is contained in:
Tom Christie 2019-02-19 10:55:45 +00:00 committed by GitHub
parent 933d786627
commit 06adecd6ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 165 additions and 62 deletions

View File

@ -11,3 +11,4 @@ ${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs
${PREFIX}autoflake --in-place --recursive starlette tests
${PREFIX}black starlette tests
${PREFIX}isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --apply starlette tests
${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs

View File

@ -23,7 +23,6 @@ class Starlette:
self.error_middleware = ServerErrorMiddleware(
self.exception_middleware, debug=debug
)
self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator]
if template_directory is not None:
from starlette.templating import Jinja2Templates
@ -53,7 +52,7 @@ class Starlette:
return self.schema_generator.get_schema(self.routes)
def on_event(self, event_type: str) -> typing.Callable:
return self.lifespan_middleware.on_event(event_type)
return self.router.lifespan.on_event(event_type)
def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
self.router.mount(path, app=app, name=name)
@ -79,7 +78,7 @@ class Starlette:
)
def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
self.lifespan_middleware.add_event_handler(event_type, func)
self.router.lifespan.add_event_handler(event_type, func)
def add_route(
self,
@ -149,4 +148,4 @@ class Starlette:
def __call__(self, scope: Scope) -> ASGIInstance:
scope["app"] = self
return self.lifespan_middleware(scope)
return self.error_middleware(scope)

View File

@ -38,7 +38,7 @@ class LifespanMiddleware:
return LifespanHandler(
self.app, scope, self.startup_handlers, self.shutdown_handlers
)
return self.app(scope)
return self.app(scope) # pragma: no cover
class LifespanHandler:

View File

@ -286,11 +286,10 @@ class Mount(BaseRoute):
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
assert (
app is not None or routes is not None
), "Either 'app', or 'routes' must be specified"
), "Either 'app=...', or 'routes=' must be specified"
self.path = path.rstrip("/")
if routes is None:
assert app is not None
self.app = app
if app is not None:
self.app = app # type: ASGIApp
else:
self.app = Router(routes=routes)
self.name = name
@ -303,23 +302,24 @@ class Mount(BaseRoute):
return getattr(self.app, "routes", None)
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path = scope["path"]
match = self.path_regex.match(path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
remaining_path = "/" + matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {
"path_params": path_params,
"root_path": scope.get("root_path", "") + matched_path,
"path": remaining_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
if scope["type"] in ("http", "websocket"):
path = scope["path"]
match = self.path_regex.match(path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
remaining_path = "/" + matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {
"path_params": path_params,
"root_path": scope.get("root_path", "") + matched_path,
"path": remaining_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, **path_params: str) -> URLPath:
@ -375,17 +375,18 @@ class Host(BaseRoute):
return getattr(self.app, "routes", None)
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
match = self.host_regex.match(host)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"path_params": path_params, "endpoint": self.app}
return Match.FULL, child_scope
if scope["type"] in ("http", "websocket"):
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
match = self.host_regex.match(host)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"path_params": path_params, "endpoint": self.app}
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, **path_params: str) -> URLPath:
@ -426,6 +427,63 @@ class Host(BaseRoute):
)
class Lifespan(BaseRoute):
def __init__(
self, on_startup: typing.Callable = None, on_shutdown: typing.Callable = None
):
self.startup_handlers = [] if on_startup is None else [on_startup]
self.shutdown_handlers = [] if on_shutdown is None else [on_shutdown]
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
if scope["type"] == "lifespan":
return Match.FULL, {}
return Match.NONE, {}
def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
assert event_type in ("startup", "shutdown")
if event_type == "startup":
self.startup_handlers.append(func)
else:
assert event_type == "shutdown"
self.shutdown_handlers.append(func)
def on_event(self, event_type: str) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable:
self.add_event_handler(event_type, func)
return func
return decorator
async def startup(self) -> None:
for handler in self.startup_handlers:
if asyncio.iscoroutinefunction(handler):
await handler()
else:
handler()
async def shutdown(self) -> None:
for handler in self.shutdown_handlers:
if asyncio.iscoroutinefunction(handler):
await handler()
else:
handler()
def __call__(self, scope: Scope) -> ASGIInstance:
return self.asgi
async def asgi(self, receive: Receive, send: Send) -> None:
message = await receive()
assert message["type"] == "lifespan.startup"
await self.startup()
await send({"type": "lifespan.startup.complete"})
message = await receive()
assert message["type"] == "lifespan.shutdown"
await self.shutdown()
await send({"type": "lifespan.shutdown.complete"})
class Router:
def __init__(
self,
@ -436,6 +494,7 @@ class Router:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
self.default = self.not_found if default is None else default
self.lifespan = Lifespan()
def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
route = Mount(path, app=app, name=name)
@ -516,9 +575,6 @@ class Router:
def __call__(self, scope: Scope) -> ASGIInstance:
assert scope["type"] in ("http", "websocket", "lifespan")
if scope["type"] == "lifespan":
return LifespanHandler(scope)
if "router" not in scope:
scope["router"] = self
@ -537,31 +593,20 @@ class Router:
scope.update(partial_scope)
return partial(scope)
if self.redirect_slashes and not scope["path"].endswith("/"):
redirect_scope = dict(scope)
redirect_scope["path"] += "/"
if scope["type"] == "http" and self.redirect_slashes:
if not scope["path"].endswith("/"):
redirect_scope = dict(scope)
redirect_scope["path"] += "/"
for route in self.routes:
match, child_scope = route.matches(redirect_scope)
if match != Match.NONE:
redirect_url = URL(scope=redirect_scope)
return RedirectResponse(url=str(redirect_url))
for route in self.routes:
match, child_scope = route.matches(redirect_scope)
if match != Match.NONE:
redirect_url = URL(scope=redirect_scope)
return RedirectResponse(url=str(redirect_url))
if scope["type"] == "lifespan":
return self.lifespan(scope)
return self.default(scope)
def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, Router) and self.routes == other.routes
class LifespanHandler:
def __init__(self, scope: Scope) -> None:
pass
async def __call__(self, receive: Receive, send: Send) -> None:
message = await receive()
assert message["type"] == "lifespan.startup"
await send({"type": "lifespan.startup.complete"})
message = await receive()
assert message["type"] == "lifespan.shutdown"
await send({"type": "lifespan.shutdown.complete"})

View File

@ -2,6 +2,8 @@ import pytest
from starlette.applications import Starlette
from starlette.middleware.lifespan import LifespanMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Lifespan, Route, Router
from starlette.testclient import TestClient
@ -98,6 +100,38 @@ def test_raise_on_shutdown():
pass
def test_routed_lifespan():
startup_complete = False
shutdown_complete = False
def hello_world(request):
return PlainTextResponse("hello, world")
def run_startup():
nonlocal startup_complete
startup_complete = True
def run_shutdown():
nonlocal shutdown_complete
shutdown_complete = True
app = Router(
routes=[
Lifespan(on_startup=run_startup, on_shutdown=run_shutdown),
Route("/", hello_world),
]
)
assert not startup_complete
assert not shutdown_complete
with TestClient(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
assert startup_complete
assert shutdown_complete
def test_app_lifespan():
startup_complete = False
cleanup_complete = False
@ -120,3 +154,27 @@ def test_app_lifespan():
assert not cleanup_complete
assert startup_complete
assert cleanup_complete
def test_app_async_lifespan():
startup_complete = False
cleanup_complete = False
app = Starlette()
@app.on_event("startup")
async def run_startup():
nonlocal startup_complete
startup_complete = True
@app.on_event("shutdown")
async def run_cleanup():
nonlocal cleanup_complete
cleanup_complete = True
assert not startup_complete
assert not cleanup_complete
with TestClient(app):
assert startup_complete
assert not cleanup_complete
assert startup_complete
assert cleanup_complete