2018-10-29 14:46:42 +00:00
|
|
|
import pytest
|
|
|
|
|
2018-10-29 13:02:43 +00:00
|
|
|
from starlette.applications import Starlette
|
2019-11-04 15:08:24 +00:00
|
|
|
from starlette.middleware import Middleware
|
2018-10-29 13:02:43 +00:00
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from starlette.responses import PlainTextResponse
|
2019-11-04 15:08:24 +00:00
|
|
|
from starlette.routing import Route
|
2018-10-29 13:02:43 +00:00
|
|
|
from starlette.testclient import TestClient
|
|
|
|
|
|
|
|
|
2019-06-17 14:29:49 +00:00
|
|
|
class CustomMiddleware(BaseHTTPMiddleware):
    """Middleware that adds a ``Custom-Header`` to the downstream response."""

    async def dispatch(self, request, call_next):
        """Pass *request* to ``call_next`` and tag the response it returns."""
        downstream_response = await call_next(request)
        # The tests in this module assert on this exact header name and value.
        downstream_response.headers["Custom-Header"] = "Example"
        return downstream_response
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level application shared by the route definitions that follow;
# CustomMiddleware is registered on it before any route is added.
app = Starlette()
app.add_middleware(CustomMiddleware)
|
|
|
|
|
2018-10-29 13:02:43 +00:00
|
|
|
|
2019-06-17 14:29:49 +00:00
|
|
|
@app.route("/")
def homepage(request):
    """Answer requests for ``/`` with the plain-text body ``Homepage``."""
    body = "Homepage"
    return PlainTextResponse(body)
|
2018-10-29 13:02:43 +00:00
|
|
|
|
|
|
|
|
2019-06-17 14:29:49 +00:00
|
|
|
@app.route("/exc")
def exc(request):
    """Endpoint that unconditionally raises a bare ``Exception``."""
    failure = Exception()
    raise failure
|
2018-10-29 13:02:43 +00:00
|
|
|
|
2019-03-19 16:03:19 +00:00
|
|
|
|
2019-06-17 14:29:49 +00:00
|
|
|
@app.route("/no-response")
class NoResponse:
    """Awaitable endpoint class whose ``dispatch`` never produces a response.

    ``test_custom_middleware`` expects a request to this route to end in a
    ``RuntimeError``.
    """

    def __init__(self, scope, receive, send):
        """Accept ``scope``, ``receive`` and ``send`` and ignore all three."""

    def __await__(self):
        """Make instances awaitable by delegating to :meth:`dispatch`."""
        pending = self.dispatch()
        return pending.__await__()

    async def dispatch(self):
        """Do nothing at all; in particular, nothing is sent back."""
|
2018-10-29 13:02:43 +00:00
|
|
|
|
2019-06-17 14:29:49 +00:00
|
|
|
|
|
|
|
@app.websocket_route("/ws")
async def websocket_endpoint(session):
    """Accept the connection, send one text greeting, then close it."""
    greeting = "Hello, world!"
    await session.accept()
    await session.send_text(greeting)
    await session.close()
|
|
|
|
|
|
|
|
|
|
|
|
def test_custom_middleware():
    """Exercise the module-level ``app`` through ``CustomMiddleware``.

    Checks the header added on a normal response, that an endpoint exception
    reaches the client, that the ``/no-response`` route ends in a
    ``RuntimeError``, and that the ``/ws`` websocket route still delivers
    its greeting.
    """
    test_client = TestClient(app)

    # Ordinary HTTP request: the middleware must have tagged the response.
    homepage_response = test_client.get("/")
    assert homepage_response.headers["Custom-Header"] == "Example"

    # The endpoint's exception is expected to surface from the client call.
    with pytest.raises(Exception):
        test_client.get("/exc")

    # An endpoint that never responds is expected to raise RuntimeError.
    with pytest.raises(RuntimeError):
        test_client.get("/no-response")

    # Websocket route on the same application.
    with test_client.websocket_connect("/ws") as websocket:
        received = websocket.receive_text()
        assert received == "Hello, world!"
|
2018-11-06 12:19:52 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_middleware_decorator():
    """Check middleware registered with the ``@app.middleware("http")`` decorator.

    The middleware answers ``/`` itself and, for any other path, forwards the
    request and adds a ``Custom`` header to the response.
    """
    local_app = Starlette()

    @local_app.route("/homepage")
    def homepage(request):
        return PlainTextResponse("Homepage")

    @local_app.middleware("http")
    async def plaintext(request, call_next):
        if request.url.path != "/":
            # Any non-root path goes to the application and gets tagged.
            forwarded = await call_next(request)
            forwarded.headers["Custom"] = "Example"
            return forwarded
        # Root path is answered here, without calling the application.
        return PlainTextResponse("OK")

    test_client = TestClient(local_app)

    root_response = test_client.get("/")
    assert root_response.text == "OK"

    homepage_response = test_client.get("/homepage")
    assert homepage_response.text == "Homepage"
    assert homepage_response.headers["Custom"] == "Example"
|
2019-06-16 00:22:50 +00:00
|
|
|
|
2019-06-17 12:56:29 +00:00
|
|
|
|
2019-06-16 00:22:50 +00:00
|
|
|
def test_state_data_across_multiple_middlewares():
    """Check that ``request.state`` values set in one middleware are readable
    from another middleware handling the same request.
    """
    expected_value1 = "foo"
    expected_value2 = "bar"

    class SetFooMiddleware(BaseHTTPMiddleware):
        """Stores ``foo`` on the request state before forwarding."""

        async def dispatch(self, request, call_next):
            request.state.foo = expected_value1
            return await call_next(request)

    class SetBarReadFooMiddleware(BaseHTTPMiddleware):
        """Stores ``bar`` on the state and echoes ``foo`` into a header."""

        async def dispatch(self, request, call_next):
            request.state.bar = expected_value2
            forwarded = await call_next(request)
            forwarded.headers["X-State-Foo"] = request.state.foo
            return forwarded

    class ReadBarMiddleware(BaseHTTPMiddleware):
        """Echoes ``bar`` from the request state into a header."""

        async def dispatch(self, request, call_next):
            forwarded = await call_next(request)
            forwarded.headers["X-State-Bar"] = request.state.bar
            return forwarded

    state_app = Starlette()
    # Registration order is kept exactly as it was: foo-setter first, bar-reader last.
    for middleware_cls in (
        SetFooMiddleware,
        SetBarReadFooMiddleware,
        ReadBarMiddleware,
    ):
        state_app.add_middleware(middleware_cls)

    @state_app.route("/")
    def homepage(request):
        return PlainTextResponse("OK")

    test_client = TestClient(state_app)
    result = test_client.get("/")

    assert result.text == "OK"
    assert result.headers["X-State-Foo"] == expected_value1
    assert result.headers["X-State-Bar"] == expected_value2
|
2019-11-04 15:08:24 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_app_middleware_argument():
    """Check middleware passed via the ``middleware=`` constructor argument."""

    def homepage(request):
        return PlainTextResponse("Homepage")

    routes = [Route("/", homepage)]
    middleware = [Middleware(CustomMiddleware)]
    configured_app = Starlette(routes=routes, middleware=middleware)

    result = TestClient(configured_app).get("/")
    assert result.headers["Custom-Header"] == "Example"
|
|
|
|
|
|
|
|
|
|
|
|
def test_middleware_repr():
    """Check the ``repr`` of a ``Middleware`` wrapping ``CustomMiddleware``."""
    wrapped = Middleware(CustomMiddleware)
    expected_repr = "Middleware(CustomMiddleware)"
    assert repr(wrapped) == expected_repr
|
2021-06-18 14:48:43 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_fully_evaluated_response():
    """Check that a middleware may discard the downstream response and
    return one of its own.
    """
    # Test for https://github.com/encode/starlette/issues/1022

    class ReplacingMiddleware(BaseHTTPMiddleware):
        """Calls the application, ignores its response, returns ``Custom``."""

        async def dispatch(self, request, call_next):
            # The downstream response is intentionally thrown away.
            await call_next(request)
            return PlainTextResponse("Custom")

    replaced_app = Starlette()
    replaced_app.add_middleware(ReplacingMiddleware)

    test_client = TestClient(replaced_app)
    # No route is registered on this app; the path name mirrors the original test.
    result = test_client.get("/does_not_exist")
    assert result.text == "Custom"
|