"""Tests for class-based and decorator-based HTTP middleware on a Starlette app."""
import pytest
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from starlette.testclient import TestClient


class CustomMiddleware(BaseHTTPMiddleware):
    """Middleware that stamps a ``Custom-Header`` onto each response it handles."""

    async def dispatch(self, request, call_next):
        """Pass *request* down the chain and tag the response that comes back."""
        downstream_response = await call_next(request)
        downstream_response.headers["Custom-Header"] = "Example"
        return downstream_response


app = Starlette()
app.add_middleware(CustomMiddleware)


@app.route("/")
def homepage(request):
    """Return a fixed plain-text body."""
    return PlainTextResponse("Homepage")


@app.route("/exc")
def exc(request):
    """Always raise, so the test can observe the error surfacing in the client."""
    raise Exception()


@app.route("/no-response")
class App:
    """Endpoint class whose call does nothing at all (no response is sent)."""

    def __init__(self, scope):
        pass

    async def __call__(self, receive, send):
        pass


@app.websocket_route("/ws")
async def websocket_endpoint(session):
    """Accept the connection, send one greeting, then close."""
    await session.accept()
    await session.send_text("Hello, world!")
    await session.close()


def test_custom_middleware():
    """Exercise the module-level app, which has ``CustomMiddleware`` installed."""
    test_client = TestClient(app)

    home_response = test_client.get("/")
    assert home_response.headers["Custom-Header"] == "Example"

    # The endpoint's exception is expected to reach the caller.
    with pytest.raises(Exception):
        test_client.get("/exc")

    # An endpoint that never sends a response is expected to raise RuntimeError.
    with pytest.raises(RuntimeError):
        test_client.get("/no-response")

    # The websocket route is still reachable with the HTTP middleware installed.
    with test_client.websocket_connect("/ws") as ws:
        greeting = ws.receive_text()
        assert greeting == "Hello, world!"


def test_middleware_decorator():
    """Register middleware through ``@app.middleware("http")`` on a fresh app."""
    decorated_app = Starlette()

    @decorated_app.route("/homepage")
    def homepage(request):
        return PlainTextResponse("Homepage")

    @decorated_app.middleware("http")
    async def plaintext(request, call_next):
        # Requests for "/" are answered here and never reach the router.
        is_root = request.url.path == "/"
        if is_root:
            return PlainTextResponse("OK")
        routed_response = await call_next(request)
        routed_response.headers["Custom"] = "Example"
        return routed_response

    test_client = TestClient(decorated_app)

    root_response = test_client.get("/")
    assert root_response.text == "OK"

    page_response = test_client.get("/homepage")
    assert page_response.text == "Homepage"
    assert page_response.headers["Custom"] == "Example"