Lazily build middleware stack (#2017)

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
Adrian Garcia Badaracco 2023-02-05 21:35:09 -08:00 committed by GitHub
parent ca1711fab7
commit 51c1de1839
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 16 deletions

View File

@ -65,7 +65,7 @@ class Starlette:
on_startup is None and on_shutdown is None
), "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both."
self._debug = debug
self.debug = debug
self.state = State()
self.router = Router(
routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
@ -74,7 +74,7 @@ class Starlette:
{} if exception_handlers is None else dict(exception_handlers)
)
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack = self.build_middleware_stack()
self.middleware_stack: typing.Optional[ASGIApp] = None
def build_middleware_stack(self) -> ASGIApp:
debug = self.debug
@ -108,20 +108,13 @@ class Starlette:
def routes(self) -> typing.List[BaseRoute]:
return self.router.routes
@property
def debug(self) -> bool:
return self._debug
@debug.setter
def debug(self, value: bool) -> None:
self._debug = value
self.middleware_stack = self.build_middleware_stack()
def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
return self.router.url_path_for(name, **path_params)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["app"] = self
if self.middleware_stack is None:
self.middleware_stack = self.build_middleware_stack()
await self.middleware_stack(scope, receive, send)
def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover
@ -137,11 +130,10 @@ class Starlette:
) -> None: # pragma: no cover
self.router.host(host, app=app, name=name)
def add_middleware(
self, middleware_class: type, **options: typing.Any
) -> None: # pragma: no cover
def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
if self.middleware_stack is not None: # pragma: no cover
raise RuntimeError("Cannot add middleware after an application has started")
self.user_middleware.insert(0, Middleware(middleware_class, **options))
self.middleware_stack = self.build_middleware_stack()
def add_exception_handler(
self,
@ -149,7 +141,6 @@ class Starlette:
handler: typing.Callable,
) -> None: # pragma: no cover
self.exception_handlers[exc_class_or_status_code] = handler
self.middleware_stack = self.build_middleware_stack()
def add_event_handler(
self, event_type: str, func: typing.Callable

View File

@ -1,7 +1,9 @@
import os
from contextlib import asynccontextmanager
from typing import Any, Callable
import anyio
import httpx
import pytest
from starlette import status
@ -13,6 +15,7 @@ from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.types import ASGIApp
from starlette.websockets import WebSocket
@ -486,3 +489,45 @@ def test_decorator_deprecations() -> None:
app.on_event("startup")(startup)
assert len(record) == 1
def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]):
class NoOpMiddleware:
def __init__(self, app: ASGIApp):
self.app = app
async def __call__(self, *args: Any):
await self.app(*args)
class SimpleInitializableMiddleware:
counter = 0
def __init__(self, app: ASGIApp):
self.app = app
SimpleInitializableMiddleware.counter += 1
async def __call__(self, *args: Any):
await self.app(*args)
def get_app() -> ASGIApp:
app = Starlette()
app.add_middleware(SimpleInitializableMiddleware)
app.add_middleware(NoOpMiddleware)
return app
app = get_app()
with test_client_factory(app):
pass
assert SimpleInitializableMiddleware.counter == 1
test_client_factory(app).get("/foo")
assert SimpleInitializableMiddleware.counter == 1
app = get_app()
test_client_factory(app).get("/foo")
assert SimpleInitializableMiddleware.counter == 2