starlette/tests/middleware/test_base.py

124 lines
3.4 KiB
Python
Raw Normal View History

import pytest
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from starlette.testclient import TestClient
def test_custom_middleware():
    """Exercise a ``BaseHTTPMiddleware`` subclass on HTTP and WebSocket routes.

    The middleware stamps ``Custom-Header: Example`` on whatever response
    ``call_next`` hands back. The test then checks four things:

    * a plain route carries the header,
    * an exception raised by an endpoint reaches the test client,
    * an endpoint that finishes without sending anything is reported as
      ``RuntimeError``,
    * a WebSocket route on the same app still delivers its message.
    """

    class CustomMiddleware(BaseHTTPMiddleware):
        async def dispatch(self, request, call_next):
            downstream = await call_next(request)
            downstream.headers["Custom-Header"] = "Example"
            return downstream

    app = Starlette()
    app.add_middleware(CustomMiddleware)

    @app.route("/")
    def homepage(request):
        return PlainTextResponse("Homepage")

    @app.route("/exc")
    def exc(request):
        raise Exception()

    @app.route("/no-response")
    class NoResponse:
        # Awaitable endpoint whose coroutine does nothing, so no response
        # is ever produced for the request.
        def __init__(self, scope, receive, send):
            pass

        def __await__(self):
            return self.dispatch().__await__()

        async def dispatch(self):
            pass

    @app.websocket_route("/ws")
    async def websocket_endpoint(session):
        await session.accept()
        await session.send_text("Hello, world!")
        await session.close()

    test_client = TestClient(app)

    # Normal route: the middleware must have added its header.
    home = test_client.get("/")
    assert home.headers["Custom-Header"] == "Example"

    # Endpoint failure is expected to surface to the caller.
    with pytest.raises(Exception):
        test_client.get("/exc")

    # An endpoint that never responds is expected to raise RuntimeError.
    with pytest.raises(RuntimeError):
        test_client.get("/no-response")

    # WebSocket traffic on the same application.
    with test_client.websocket_connect("/ws") as session:
        received = session.receive_text()
        assert received == "Hello, world!"
def test_middleware_decorator():
    """Check middleware registered through ``@app.middleware("http")``.

    The decorated function answers requests for ``/`` itself with ``"OK"``
    and, for any other path, forwards to ``call_next`` and adds a
    ``Custom: Example`` header to the returned response.
    """
    app = Starlette()

    @app.route("/homepage")
    def homepage(request):
        return PlainTextResponse("Homepage")

    @app.middleware("http")
    async def plaintext(request, call_next):
        if request.url.path != "/":
            # Pass through to the routed endpoint, then decorate its response.
            forwarded = await call_next(request)
            forwarded.headers["Custom"] = "Example"
            return forwarded
        # Short-circuit: "/" is answered without calling the next handler.
        return PlainTextResponse("OK")

    test_client = TestClient(app)

    short_circuited = test_client.get("/")
    assert short_circuited.text == "OK"

    passed_through = test_client.get("/homepage")
    assert passed_through.text == "Homepage"
    assert passed_through.headers["Custom"] == "Example"
def test_state_data_across_multiple_middlewares():
    """Check that ``request.state`` values set in one middleware are visible in others.

    Three ``BaseHTTPMiddleware`` subclasses are stacked on one app:

    * ``AMiddleware`` sets ``request.state.foo`` before calling on.
    * ``BMiddleware`` sets ``request.state.bar`` before calling on, and
      afterwards copies ``request.state.foo`` into the ``X-State-Foo`` header.
    * ``CMiddleware`` copies ``request.state.bar`` into ``X-State-Bar`` after
      calling on.

    The final response must carry both headers with the expected values.
    """
    expected_value1 = "foo"
    expected_value2 = "bar"

    class AMiddleware(BaseHTTPMiddleware):
        async def dispatch(self, request, call_next):
            request.state.foo = expected_value1
            response = await call_next(request)
            return response

    class BMiddleware(BaseHTTPMiddleware):
        async def dispatch(self, request, call_next):
            request.state.bar = expected_value2
            response = await call_next(request)
            # Reads a value that was written by a different middleware.
            response.headers["X-State-Foo"] = request.state.foo
            return response

    class CMiddleware(BaseHTTPMiddleware):
        async def dispatch(self, request, call_next):
            response = await call_next(request)
            # Reads a value that was written by a different middleware.
            response.headers["X-State-Bar"] = request.state.bar
            return response

    app = Starlette()
    # NOTE(review): the reads above happen only after ``call_next`` returns,
    # so each reader must wrap the middleware that wrote the value. This
    # presumably relies on later-added middleware wrapping earlier ones;
    # confirm against Starlette's ``add_middleware`` before reordering.
    app.add_middleware(AMiddleware)
    app.add_middleware(BMiddleware)
    app.add_middleware(CMiddleware)

    @app.route("/")
    def homepage(request):
        return PlainTextResponse("OK")

    client = TestClient(app)
    response = client.get("/")
    assert response.text == "OK"
    assert response.headers["X-State-Foo"] == expected_value1
    assert response.headers["X-State-Bar"] == expected_value2