diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 32468dcd..2eb802be 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -4,7 +4,7 @@ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse, StreamingResponse -from starlette.routing import Route +from starlette.routing import Route, WebSocketRoute class CustomMiddleware(BaseHTTPMiddleware): @@ -14,21 +14,14 @@ class CustomMiddleware(BaseHTTPMiddleware): return response -app = Starlette() -app.add_middleware(CustomMiddleware) - - -@app.route("/") def homepage(request): return PlainTextResponse("Homepage") -@app.route("/exc") def exc(request): raise Exception("Exc") -@app.route("/exc-stream") def exc_stream(request): return StreamingResponse(_generate_faulty_stream()) @@ -38,7 +31,6 @@ def _generate_faulty_stream(): raise Exception("Faulty Stream") -@app.route("/no-response") class NoResponse: def __init__(self, scope, receive, send): pass @@ -50,13 +42,24 @@ class NoResponse: pass -@app.websocket_route("/ws") async def websocket_endpoint(session): await session.accept() await session.send_text("Hello, world!") await session.close() +app = Starlette( + routes=[ + Route("/", endpoint=homepage), + Route("/exc", endpoint=exc), + Route("/exc-stream", endpoint=exc_stream), + Route("/no-response", endpoint=NoResponse), + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ], + middleware=[Middleware(CustomMiddleware)], +) + + def test_custom_middleware(test_client_factory): client = test_client_factory(app) response = client.get("/")