mirror of https://github.com/encode/starlette.git
Lazily build middleware stack (#2017)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
parent
ca1711fab7
commit
51c1de1839
|
@ -65,7 +65,7 @@ class Starlette:
|
|||
on_startup is None and on_shutdown is None
|
||||
), "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both."
|
||||
|
||||
self._debug = debug
|
||||
self.debug = debug
|
||||
self.state = State()
|
||||
self.router = Router(
|
||||
routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
|
||||
|
@ -74,7 +74,7 @@ class Starlette:
|
|||
{} if exception_handlers is None else dict(exception_handlers)
|
||||
)
|
||||
self.user_middleware = [] if middleware is None else list(middleware)
|
||||
self.middleware_stack = self.build_middleware_stack()
|
||||
self.middleware_stack: typing.Optional[ASGIApp] = None
|
||||
|
||||
def build_middleware_stack(self) -> ASGIApp:
|
||||
debug = self.debug
|
||||
|
@ -108,20 +108,13 @@ class Starlette:
|
|||
def routes(self) -> typing.List[BaseRoute]:
|
||||
return self.router.routes
|
||||
|
||||
@property
|
||||
def debug(self) -> bool:
|
||||
return self._debug
|
||||
|
||||
@debug.setter
|
||||
def debug(self, value: bool) -> None:
|
||||
self._debug = value
|
||||
self.middleware_stack = self.build_middleware_stack()
|
||||
|
||||
def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
|
||||
return self.router.url_path_for(name, **path_params)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
scope["app"] = self
|
||||
if self.middleware_stack is None:
|
||||
self.middleware_stack = self.build_middleware_stack()
|
||||
await self.middleware_stack(scope, receive, send)
|
||||
|
||||
def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover
|
||||
|
@ -137,11 +130,10 @@ class Starlette:
|
|||
) -> None: # pragma: no cover
|
||||
self.router.host(host, app=app, name=name)
|
||||
|
||||
def add_middleware(
|
||||
self, middleware_class: type, **options: typing.Any
|
||||
) -> None: # pragma: no cover
|
||||
def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
|
||||
if self.middleware_stack is not None: # pragma: no cover
|
||||
raise RuntimeError("Cannot add middleware after an application has started")
|
||||
self.user_middleware.insert(0, Middleware(middleware_class, **options))
|
||||
self.middleware_stack = self.build_middleware_stack()
|
||||
|
||||
def add_exception_handler(
|
||||
self,
|
||||
|
@ -149,7 +141,6 @@ class Starlette:
|
|||
handler: typing.Callable,
|
||||
) -> None: # pragma: no cover
|
||||
self.exception_handlers[exc_class_or_status_code] = handler
|
||||
self.middleware_stack = self.build_middleware_stack()
|
||||
|
||||
def add_event_handler(
|
||||
self, event_type: str, func: typing.Callable
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Callable
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from starlette import status
|
||||
|
@ -13,6 +15,7 @@ from starlette.middleware.trustedhost import TrustedHostMiddleware
|
|||
from starlette.responses import JSONResponse, PlainTextResponse
|
||||
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from starlette.types import ASGIApp
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
|
@ -486,3 +489,45 @@ def test_decorator_deprecations() -> None:
|
|||
|
||||
app.on_event("startup")(startup)
|
||||
assert len(record) == 1
|
||||
|
||||
|
||||
def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]):
|
||||
class NoOpMiddleware:
|
||||
def __init__(self, app: ASGIApp):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, *args: Any):
|
||||
await self.app(*args)
|
||||
|
||||
class SimpleInitializableMiddleware:
|
||||
counter = 0
|
||||
|
||||
def __init__(self, app: ASGIApp):
|
||||
self.app = app
|
||||
SimpleInitializableMiddleware.counter += 1
|
||||
|
||||
async def __call__(self, *args: Any):
|
||||
await self.app(*args)
|
||||
|
||||
def get_app() -> ASGIApp:
|
||||
app = Starlette()
|
||||
app.add_middleware(SimpleInitializableMiddleware)
|
||||
app.add_middleware(NoOpMiddleware)
|
||||
return app
|
||||
|
||||
app = get_app()
|
||||
|
||||
with test_client_factory(app):
|
||||
pass
|
||||
|
||||
assert SimpleInitializableMiddleware.counter == 1
|
||||
|
||||
test_client_factory(app).get("/foo")
|
||||
|
||||
assert SimpleInitializableMiddleware.counter == 1
|
||||
|
||||
app = get_app()
|
||||
|
||||
test_client_factory(app).get("/foo")
|
||||
|
||||
assert SimpleInitializableMiddleware.counter == 2
|
||||
|
|
Loading…
Reference in New Issue