2018-10-29 14:46:42 +00:00
import pytest
2018-10-29 13:02:43 +00:00
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from starlette.testclient import TestClient
class CustomMiddleware(BaseHTTPMiddleware):
    """Middleware that stamps ``Custom-Header: Example`` onto responses.

    The response obtained from ``call_next`` is returned as-is, apart from
    the one added header.
    """

    async def dispatch(self, request, call_next):
        """Obtain the downstream response for *request* and tag it.

        Args:
            request: the incoming request, handed unchanged to ``call_next``.
            call_next: callable awaited with the request to obtain the
                downstream response.

        Returns:
            The downstream response with ``Custom-Header`` set to ``"Example"``.
        """
        downstream = await call_next(request)
        headers = downstream.headers
        headers["Custom-Header"] = "Example"
        return downstream
# Application under test; test_custom_middleware drives it via TestClient(app).
# NOTE(review): no ``add_middleware`` call or route registration is visible in
# this view, yet test_custom_middleware expects CustomMiddleware to be active
# and the "/", "/exc", "/no-response" and "/ws" paths to be routed. Confirm
# those registrations exist (they may be on lines not shown here).
app = Starlette()
def homepage(request):
    """Respond with the plain-text body ``Homepage``; *request* is unused."""
    body = "Homepage"
    return PlainTextResponse(body)
def exc(request):
    """Unconditionally raise a bare ``Exception``; *request* is unused.

    NOTE(review): presumably the handler behind ``/exc``, where
    test_custom_middleware expects the exception to propagate to the client.
    """
    raise Exception
class App:
    """ASGI-style app (scope first, then receive/send) that never responds.

    NOTE(review): presumably the handler behind ``/no-response``, where
    test_custom_middleware expects a RuntimeError because no response is
    produced. Confirm against the route registration.
    """

    def __init__(self, scope):
        """Accept the connection *scope* and deliberately store nothing."""

    async def __call__(self, receive, send):
        """Return immediately without calling *send*, so nothing is emitted."""
async def websocket_endpoint(session):
    """Accept *session*, send one text greeting, then close the session."""
    greeting = "Hello, world!"
    await session.accept()
    await session.send_text(greeting)
    await session.close()
def test_custom_middleware():
    """Exercise ``app`` through TestClient with CustomMiddleware expected active.

    Checks that:
    * the response from ``/`` carries ``Custom-Header: Example``;
    * requesting ``/exc`` raises an ``Exception`` out of the client;
    * requesting ``/no-response`` raises a ``RuntimeError``;
    * a WebSocket connection to ``/ws`` receives ``"Hello, world!"``.
    """
    client = TestClient(app)

    response = client.get("/")
    assert response.headers["Custom-Header"] == "Example"

    # Only the raised exception matters here; the calls' return values are
    # never inspected, so they are not bound to a name.
    with pytest.raises(Exception):
        client.get("/exc")

    with pytest.raises(RuntimeError):
        client.get("/no-response")

    with client.websocket_connect("/ws") as session:
        text = session.receive_text()
        assert text == "Hello, world!"