mirror of https://github.com/encode/starlette.git
1293 lines
41 KiB
Python
1293 lines
41 KiB
Python
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import functools
|
|
import json
|
|
import typing
|
|
import uuid
|
|
|
|
import pytest
|
|
|
|
from starlette.applications import Starlette
|
|
from starlette.exceptions import HTTPException
|
|
from starlette.middleware import Middleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse, PlainTextResponse, Response
|
|
from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute
|
|
from starlette.testclient import TestClient
|
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
|
from starlette.websockets import WebSocket, WebSocketDisconnect
|
|
from tests.types import TestClientFactory
|
|
|
|
|
|
def homepage(request: Request) -> Response:
|
|
return Response("Hello, world", media_type="text/plain")
|
|
|
|
|
|
def users(request: Request) -> Response:
|
|
return Response("All users", media_type="text/plain")
|
|
|
|
|
|
def user(request: Request) -> Response:
|
|
content = "User " + request.path_params["username"]
|
|
return Response(content, media_type="text/plain")
|
|
|
|
|
|
def user_me(request: Request) -> Response:
|
|
content = "User fixed me"
|
|
return Response(content, media_type="text/plain")
|
|
|
|
|
|
def disable_user(request: Request) -> Response:
|
|
content = "User " + request.path_params["username"] + " disabled"
|
|
return Response(content, media_type="text/plain")
|
|
|
|
|
|
def user_no_match(request: Request) -> Response: # pragma: no cover
|
|
content = "User fixed no match"
|
|
return Response(content, media_type="text/plain")
|
|
|
|
|
|
async def partial_endpoint(arg: str, request: Request) -> JSONResponse:
|
|
return JSONResponse({"arg": arg})
|
|
|
|
|
|
async def partial_ws_endpoint(websocket: WebSocket) -> None:
|
|
await websocket.accept()
|
|
await websocket.send_json({"url": str(websocket.url)})
|
|
await websocket.close()
|
|
|
|
|
|
class PartialRoutes:
|
|
@classmethod
|
|
async def async_endpoint(cls, arg: str, request: Request) -> JSONResponse:
|
|
return JSONResponse({"arg": arg})
|
|
|
|
@classmethod
|
|
async def async_ws_endpoint(cls, websocket: WebSocket) -> None:
|
|
await websocket.accept()
|
|
await websocket.send_json({"url": str(websocket.url)})
|
|
await websocket.close()
|
|
|
|
|
|
def func_homepage(request: Request) -> Response:
|
|
return Response("Hello, world!", media_type="text/plain")
|
|
|
|
|
|
def contact(request: Request) -> Response:
|
|
return Response("Hello, POST!", media_type="text/plain")
|
|
|
|
|
|
def int_convertor(request: Request) -> JSONResponse:
|
|
number = request.path_params["param"]
|
|
return JSONResponse({"int": number})
|
|
|
|
|
|
def float_convertor(request: Request) -> JSONResponse:
|
|
num = request.path_params["param"]
|
|
return JSONResponse({"float": num})
|
|
|
|
|
|
def path_convertor(request: Request) -> JSONResponse:
|
|
path = request.path_params["param"]
|
|
return JSONResponse({"path": path})
|
|
|
|
|
|
def uuid_converter(request: Request) -> JSONResponse:
|
|
uuid_param = request.path_params["param"]
|
|
return JSONResponse({"uuid": str(uuid_param)})
|
|
|
|
|
|
def path_with_parentheses(request: Request) -> JSONResponse:
|
|
number = request.path_params["param"]
|
|
return JSONResponse({"int": number})
|
|
|
|
|
|
async def websocket_endpoint(session: WebSocket) -> None:
|
|
await session.accept()
|
|
await session.send_text("Hello, world!")
|
|
await session.close()
|
|
|
|
|
|
async def websocket_params(session: WebSocket) -> None:
|
|
await session.accept()
|
|
await session.send_text(f"Hello, {session.path_params['room']}!")
|
|
await session.close()
|
|
|
|
|
|
app = Router(
|
|
[
|
|
Route("/", endpoint=homepage, methods=["GET"]),
|
|
Mount(
|
|
"/users",
|
|
routes=[
|
|
Route("/", endpoint=users),
|
|
Route("/me", endpoint=user_me),
|
|
Route("/{username}", endpoint=user),
|
|
Route("/{username}:disable", endpoint=disable_user, methods=["PUT"]),
|
|
Route("/nomatch", endpoint=user_no_match),
|
|
],
|
|
),
|
|
Mount(
|
|
"/partial",
|
|
routes=[
|
|
Route("/", endpoint=functools.partial(partial_endpoint, "foo")),
|
|
Route(
|
|
"/cls",
|
|
endpoint=functools.partial(PartialRoutes.async_endpoint, "foo"),
|
|
),
|
|
WebSocketRoute("/ws", endpoint=functools.partial(partial_ws_endpoint)),
|
|
WebSocketRoute(
|
|
"/ws/cls",
|
|
endpoint=functools.partial(PartialRoutes.async_ws_endpoint),
|
|
),
|
|
],
|
|
),
|
|
Mount("/static", app=Response("xxxxx", media_type="image/png")),
|
|
Route("/func", endpoint=func_homepage, methods=["GET"]),
|
|
Route("/func", endpoint=contact, methods=["POST"]),
|
|
Route("/int/{param:int}", endpoint=int_convertor, name="int-convertor"),
|
|
Route("/float/{param:float}", endpoint=float_convertor, name="float-convertor"),
|
|
Route("/path/{param:path}", endpoint=path_convertor, name="path-convertor"),
|
|
Route("/uuid/{param:uuid}", endpoint=uuid_converter, name="uuid-convertor"),
|
|
# Route with chars that conflict with regex meta chars
|
|
Route(
|
|
"/path-with-parentheses({param:int})",
|
|
endpoint=path_with_parentheses,
|
|
name="path-with-parentheses",
|
|
),
|
|
WebSocketRoute("/ws", endpoint=websocket_endpoint),
|
|
WebSocketRoute("/ws/{room}", endpoint=websocket_params),
|
|
]
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def client(
|
|
test_client_factory: TestClientFactory,
|
|
) -> typing.Generator[TestClient, None, None]:
|
|
with test_client_factory(app) as client:
|
|
yield client
|
|
|
|
|
|
@pytest.mark.filterwarnings(
|
|
r"ignore"
|
|
r":Trying to detect encoding from a tiny portion of \(5\) byte\(s\)\."
|
|
r":UserWarning"
|
|
r":charset_normalizer.api"
|
|
)
|
|
def test_router(client: TestClient) -> None:
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
assert response.text == "Hello, world"
|
|
|
|
response = client.post("/")
|
|
assert response.status_code == 405
|
|
assert response.text == "Method Not Allowed"
|
|
assert set(response.headers["allow"].split(", ")) == {"HEAD", "GET"}
|
|
|
|
response = client.get("/foo")
|
|
assert response.status_code == 404
|
|
assert response.text == "Not Found"
|
|
|
|
response = client.get("/users")
|
|
assert response.status_code == 200
|
|
assert response.text == "All users"
|
|
|
|
response = client.get("/users/tomchristie")
|
|
assert response.status_code == 200
|
|
assert response.text == "User tomchristie"
|
|
|
|
response = client.get("/users/me")
|
|
assert response.status_code == 200
|
|
assert response.text == "User fixed me"
|
|
|
|
response = client.get("/users/tomchristie/")
|
|
assert response.status_code == 200
|
|
assert response.url == "http://testserver/users/tomchristie"
|
|
assert response.text == "User tomchristie"
|
|
|
|
response = client.put("/users/tomchristie:disable")
|
|
assert response.status_code == 200
|
|
assert response.url == "http://testserver/users/tomchristie:disable"
|
|
assert response.text == "User tomchristie disabled"
|
|
|
|
response = client.get("/users/nomatch")
|
|
assert response.status_code == 200
|
|
assert response.text == "User nomatch"
|
|
|
|
response = client.get("/static/123")
|
|
assert response.status_code == 200
|
|
assert response.text == "xxxxx"
|
|
|
|
|
|
def test_route_converters(client: TestClient) -> None:
|
|
# Test integer conversion
|
|
response = client.get("/int/5")
|
|
assert response.status_code == 200
|
|
assert response.json() == {"int": 5}
|
|
assert app.url_path_for("int-convertor", param=5) == "/int/5"
|
|
|
|
# Test path with parentheses
|
|
response = client.get("/path-with-parentheses(7)")
|
|
assert response.status_code == 200
|
|
assert response.json() == {"int": 7}
|
|
assert app.url_path_for("path-with-parentheses", param=7) == "/path-with-parentheses(7)"
|
|
|
|
# Test float conversion
|
|
response = client.get("/float/25.5")
|
|
assert response.status_code == 200
|
|
assert response.json() == {"float": 25.5}
|
|
assert app.url_path_for("float-convertor", param=25.5) == "/float/25.5"
|
|
|
|
# Test path conversion
|
|
response = client.get("/path/some/example")
|
|
assert response.status_code == 200
|
|
assert response.json() == {"path": "some/example"}
|
|
assert app.url_path_for("path-convertor", param="some/example") == "/path/some/example"
|
|
|
|
# Test UUID conversion
|
|
response = client.get("/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a")
|
|
assert response.status_code == 200
|
|
assert response.json() == {"uuid": "ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"}
|
|
assert (
|
|
app.url_path_for("uuid-convertor", param=uuid.UUID("ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"))
|
|
== "/uuid/ec38df32-ceda-4cfa-9b4a-1aeb94ad551a"
|
|
)
|
|
|
|
|
|
def test_url_path_for() -> None:
|
|
assert app.url_path_for("homepage") == "/"
|
|
assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie"
|
|
assert app.url_path_for("websocket_endpoint") == "/ws"
|
|
with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "".'):
|
|
assert app.url_path_for("broken")
|
|
with pytest.raises(NoMatchFound, match='No route exists for name "broken" and params "key, key2".'):
|
|
assert app.url_path_for("broken", key="value", key2="value2")
|
|
with pytest.raises(AssertionError):
|
|
app.url_path_for("user", username="tom/christie")
|
|
with pytest.raises(AssertionError):
|
|
app.url_path_for("user", username="")
|
|
|
|
|
|
def test_url_for() -> None:
|
|
assert app.url_path_for("homepage").make_absolute_url(base_url="https://example.org") == "https://example.org/"
|
|
assert (
|
|
app.url_path_for("homepage").make_absolute_url(base_url="https://example.org/root_path/")
|
|
== "https://example.org/root_path/"
|
|
)
|
|
assert (
|
|
app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org")
|
|
== "https://example.org/users/tomchristie"
|
|
)
|
|
assert (
|
|
app.url_path_for("user", username="tomchristie").make_absolute_url(base_url="https://example.org/root_path/")
|
|
== "https://example.org/root_path/users/tomchristie"
|
|
)
|
|
assert (
|
|
app.url_path_for("websocket_endpoint").make_absolute_url(base_url="https://example.org")
|
|
== "wss://example.org/ws"
|
|
)
|
|
|
|
|
|
def test_router_add_route(client: TestClient) -> None:
|
|
response = client.get("/func")
|
|
assert response.status_code == 200
|
|
assert response.text == "Hello, world!"
|
|
|
|
|
|
def test_router_duplicate_path(client: TestClient) -> None:
|
|
response = client.post("/func")
|
|
assert response.status_code == 200
|
|
assert response.text == "Hello, POST!"
|
|
|
|
|
|
def test_router_add_websocket_route(client: TestClient) -> None:
|
|
with client.websocket_connect("/ws") as session:
|
|
text = session.receive_text()
|
|
assert text == "Hello, world!"
|
|
|
|
with client.websocket_connect("/ws/test") as session:
|
|
text = session.receive_text()
|
|
assert text == "Hello, test!"
|
|
|
|
|
|
def test_router_middleware(test_client_factory: TestClientFactory) -> None:
|
|
class CustomMiddleware:
|
|
def __init__(self, app: ASGIApp) -> None:
|
|
self.app = app
|
|
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
response = PlainTextResponse("OK")
|
|
await response(scope, receive, send)
|
|
|
|
app = Router(
|
|
routes=[Route("/", homepage)],
|
|
middleware=[Middleware(CustomMiddleware)],
|
|
)
|
|
|
|
client = test_client_factory(app)
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
assert response.text == "OK"
|
|
|
|
|
|
def http_endpoint(request: Request) -> Response:
|
|
url = request.url_for("http_endpoint")
|
|
return Response(f"URL: {url}", media_type="text/plain")
|
|
|
|
|
|
class WebSocketEndpoint:
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
websocket = WebSocket(scope=scope, receive=receive, send=send)
|
|
await websocket.accept()
|
|
await websocket.send_json({"URL": str(websocket.url_for("websocket_endpoint"))})
|
|
await websocket.close()
|
|
|
|
|
|
mixed_protocol_app = Router(
|
|
routes=[
|
|
Route("/", endpoint=http_endpoint),
|
|
WebSocketRoute("/", endpoint=WebSocketEndpoint(), name="websocket_endpoint"),
|
|
]
|
|
)
|
|
|
|
|
|
def test_protocol_switch(test_client_factory: TestClientFactory) -> None:
|
|
client = test_client_factory(mixed_protocol_app)
|
|
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
assert response.text == "URL: http://testserver/"
|
|
|
|
with client.websocket_connect("/") as session:
|
|
assert session.receive_json() == {"URL": "ws://testserver/"}
|
|
|
|
with pytest.raises(WebSocketDisconnect):
|
|
with client.websocket_connect("/404"):
|
|
pass # pragma: no cover
|
|
|
|
|
|
ok = PlainTextResponse("OK")
|
|
|
|
|
|
def test_mount_urls(test_client_factory: TestClientFactory) -> None:
|
|
mounted = Router([Mount("/users", ok, name="users")])
|
|
client = test_client_factory(mounted)
|
|
assert client.get("/users").status_code == 200
|
|
assert client.get("/users").url == "http://testserver/users/"
|
|
assert client.get("/users/").status_code == 200
|
|
assert client.get("/users/a").status_code == 200
|
|
assert client.get("/usersa").status_code == 404
|
|
|
|
|
|
def test_reverse_mount_urls() -> None:
|
|
mounted = Router([Mount("/users", ok, name="users")])
|
|
assert mounted.url_path_for("users", path="/a") == "/users/a"
|
|
|
|
users = Router([Route("/{username}", ok, name="user")])
|
|
mounted = Router([Mount("/{subpath}/users", users, name="users")])
|
|
assert mounted.url_path_for("users:user", subpath="test", username="tom") == "/test/users/tom"
|
|
assert mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom"
|
|
|
|
mounted = Router([Mount("/users", ok, name="users")])
|
|
with pytest.raises(NoMatchFound):
|
|
mounted.url_path_for("users", path="/a", foo="bar")
|
|
|
|
mounted = Router([Mount("/users", ok, name="users")])
|
|
with pytest.raises(NoMatchFound):
|
|
mounted.url_path_for("users")
|
|
|
|
|
|
def test_mount_at_root(test_client_factory: TestClientFactory) -> None:
|
|
mounted = Router([Mount("/", ok, name="users")])
|
|
client = test_client_factory(mounted)
|
|
assert client.get("/").status_code == 200
|
|
|
|
|
|
def users_api(request: Request) -> JSONResponse:
|
|
return JSONResponse({"users": [{"username": "tom"}]})
|
|
|
|
|
|
mixed_hosts_app = Router(
|
|
routes=[
|
|
Host(
|
|
"www.example.org",
|
|
app=Router(
|
|
[
|
|
Route("/", homepage, name="homepage"),
|
|
Route("/users", users, name="users"),
|
|
]
|
|
),
|
|
),
|
|
Host(
|
|
"api.example.org",
|
|
name="api",
|
|
app=Router([Route("/users", users_api, name="users")]),
|
|
),
|
|
Host(
|
|
"port.example.org:3600",
|
|
name="port",
|
|
app=Router([Route("/", homepage, name="homepage")]),
|
|
),
|
|
]
|
|
)
|
|
|
|
|
|
def test_host_routing(test_client_factory: TestClientFactory) -> None:
|
|
client = test_client_factory(mixed_hosts_app, base_url="https://api.example.org/")
|
|
|
|
response = client.get("/users")
|
|
assert response.status_code == 200
|
|
assert response.json() == {"users": [{"username": "tom"}]}
|
|
|
|
response = client.get("/")
|
|
assert response.status_code == 404
|
|
|
|
client = test_client_factory(mixed_hosts_app, base_url="https://www.example.org/")
|
|
|
|
response = client.get("/users")
|
|
assert response.status_code == 200
|
|
assert response.text == "All users"
|
|
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
|
|
client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:3600/")
|
|
|
|
response = client.get("/users")
|
|
assert response.status_code == 404
|
|
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
|
|
# Port in requested Host is irrelevant.
|
|
|
|
client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org/")
|
|
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
|
|
client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org:5600/")
|
|
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
|
|
|
|
def test_host_reverse_urls() -> None:
|
|
assert mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever") == "https://www.example.org/"
|
|
assert (
|
|
mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever") == "https://www.example.org/users"
|
|
)
|
|
assert (
|
|
mixed_hosts_app.url_path_for("api:users").make_absolute_url("https://whatever")
|
|
== "https://api.example.org/users"
|
|
)
|
|
assert (
|
|
mixed_hosts_app.url_path_for("port:homepage").make_absolute_url("https://whatever")
|
|
== "https://port.example.org:3600/"
|
|
)
|
|
with pytest.raises(NoMatchFound):
|
|
mixed_hosts_app.url_path_for("api", path="whatever", foo="bar")
|
|
|
|
|
|
async def subdomain_app(scope: Scope, receive: Receive, send: Send) -> None:
|
|
response = JSONResponse({"subdomain": scope["path_params"]["subdomain"]})
|
|
await response(scope, receive, send)
|
|
|
|
|
|
subdomain_router = Router(routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")])
|
|
|
|
|
|
def test_subdomain_routing(test_client_factory: TestClientFactory) -> None:
|
|
client = test_client_factory(subdomain_router, base_url="https://foo.example.org/")
|
|
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
assert response.json() == {"subdomain": "foo"}
|
|
|
|
|
|
def test_subdomain_reverse_urls() -> None:
|
|
assert (
|
|
subdomain_router.url_path_for("subdomains", subdomain="foo", path="/homepage").make_absolute_url(
|
|
"https://whatever"
|
|
)
|
|
== "https://foo.example.org/homepage"
|
|
)
|
|
|
|
|
|
async def echo_urls(request: Request) -> JSONResponse:
|
|
return JSONResponse(
|
|
{
|
|
"index": str(request.url_for("index")),
|
|
"submount": str(request.url_for("mount:submount")),
|
|
}
|
|
)
|
|
|
|
|
|
echo_url_routes = [
|
|
Route("/", echo_urls, name="index", methods=["GET"]),
|
|
Mount(
|
|
"/submount",
|
|
name="mount",
|
|
routes=[Route("/", echo_urls, name="submount", methods=["GET"])],
|
|
),
|
|
]
|
|
|
|
|
|
def test_url_for_with_root_path(test_client_factory: TestClientFactory) -> None:
|
|
app = Starlette(routes=echo_url_routes)
|
|
client = test_client_factory(app, base_url="https://www.example.org/", root_path="/sub_path")
|
|
response = client.get("/sub_path/")
|
|
assert response.json() == {
|
|
"index": "https://www.example.org/sub_path/",
|
|
"submount": "https://www.example.org/sub_path/submount/",
|
|
}
|
|
response = client.get("/sub_path/submount/")
|
|
assert response.json() == {
|
|
"index": "https://www.example.org/sub_path/",
|
|
"submount": "https://www.example.org/sub_path/submount/",
|
|
}
|
|
|
|
|
|
async def stub_app(scope: Scope, receive: Receive, send: Send) -> None:
|
|
pass # pragma: no cover
|
|
|
|
|
|
double_mount_routes = [
|
|
Mount("/mount", name="mount", routes=[Mount("/static", stub_app, name="static")]),
|
|
]
|
|
|
|
|
|
def test_url_for_with_double_mount() -> None:
|
|
app = Starlette(routes=double_mount_routes)
|
|
url = app.url_path_for("mount:static", path="123")
|
|
assert url == "/mount/static/123"
|
|
|
|
|
|
def test_standalone_route_matches(
|
|
test_client_factory: TestClientFactory,
|
|
) -> None:
|
|
app = Route("/", PlainTextResponse("Hello, World!"))
|
|
client = test_client_factory(app)
|
|
response = client.get("/")
|
|
assert response.status_code == 200
|
|
assert response.text == "Hello, World!"
|
|
|
|
|
|
def test_standalone_route_does_not_match(
|
|
test_client_factory: typing.Callable[..., TestClient],
|
|
) -> None:
|
|
app = Route("/", PlainTextResponse("Hello, World!"))
|
|
client = test_client_factory(app)
|
|
response = client.get("/invalid")
|
|
assert response.status_code == 404
|
|
assert response.text == "Not Found"
|
|
|
|
|
|
async def ws_helloworld(websocket: WebSocket) -> None:
|
|
await websocket.accept()
|
|
await websocket.send_text("Hello, world!")
|
|
await websocket.close()
|
|
|
|
|
|
def test_standalone_ws_route_matches(
|
|
test_client_factory: TestClientFactory,
|
|
) -> None:
|
|
app = WebSocketRoute("/", ws_helloworld)
|
|
client = test_client_factory(app)
|
|
with client.websocket_connect("/") as websocket:
|
|
text = websocket.receive_text()
|
|
assert text == "Hello, world!"
|
|
|
|
|
|
def test_standalone_ws_route_does_not_match(
|
|
test_client_factory: TestClientFactory,
|
|
) -> None:
|
|
app = WebSocketRoute("/", ws_helloworld)
|
|
client = test_client_factory(app)
|
|
with pytest.raises(WebSocketDisconnect):
|
|
with client.websocket_connect("/invalid"):
|
|
pass # pragma: no cover
|
|
|
|
|
|
def test_lifespan_async(test_client_factory: TestClientFactory) -> None:
|
|
startup_complete = False
|
|
shutdown_complete = False
|
|
|
|
async def hello_world(request: Request) -> PlainTextResponse:
|
|
return PlainTextResponse("hello, world")
|
|
|
|
async def run_startup() -> None:
|
|
nonlocal startup_complete
|
|
startup_complete = True
|
|
|
|
async def run_shutdown() -> None:
|
|
nonlocal shutdown_complete
|
|
shutdown_complete = True
|
|
|
|
with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
|
|
app = Router(
|
|
on_startup=[run_startup],
|
|
on_shutdown=[run_shutdown],
|
|
routes=[Route("/", hello_world)],
|
|
)
|
|
|
|
assert not startup_complete
|
|
assert not shutdown_complete
|
|
with test_client_factory(app) as client:
|
|
assert startup_complete
|
|
assert not shutdown_complete
|
|
client.get("/")
|
|
assert startup_complete
|
|
assert shutdown_complete
|
|
|
|
|
|
def test_lifespan_with_on_events(test_client_factory: TestClientFactory) -> None:
|
|
lifespan_called = False
|
|
startup_called = False
|
|
shutdown_called = False
|
|
|
|
@contextlib.asynccontextmanager
|
|
async def lifespan(app: Starlette) -> typing.AsyncGenerator[None, None]:
|
|
nonlocal lifespan_called
|
|
lifespan_called = True
|
|
yield
|
|
|
|
# We do not expected, neither of run_startup nor run_shutdown to be called
|
|
# we thus mark them as #pragma: no cover, to fulfill test coverage
|
|
def run_startup() -> None: # pragma: no cover
|
|
nonlocal startup_called
|
|
startup_called = True
|
|
|
|
def run_shutdown() -> None: # pragma: no cover
|
|
nonlocal shutdown_called
|
|
shutdown_called = True
|
|
|
|
with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
|
|
with pytest.warns(
|
|
UserWarning, match="The `lifespan` parameter cannot be used with `on_startup` or `on_shutdown`."
|
|
):
|
|
app = Router(on_startup=[run_startup], on_shutdown=[run_shutdown], lifespan=lifespan)
|
|
|
|
assert not lifespan_called
|
|
assert not startup_called
|
|
assert not shutdown_called
|
|
|
|
# Triggers the lifespan events
|
|
with test_client_factory(app):
|
|
...
|
|
|
|
assert lifespan_called
|
|
assert not startup_called
|
|
assert not shutdown_called
|
|
|
|
|
|
def test_lifespan_sync(test_client_factory: TestClientFactory) -> None:
|
|
startup_complete = False
|
|
shutdown_complete = False
|
|
|
|
def hello_world(request: Request) -> PlainTextResponse:
|
|
return PlainTextResponse("hello, world")
|
|
|
|
def run_startup() -> None:
|
|
nonlocal startup_complete
|
|
startup_complete = True
|
|
|
|
def run_shutdown() -> None:
|
|
nonlocal shutdown_complete
|
|
shutdown_complete = True
|
|
|
|
with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
|
|
app = Router(
|
|
on_startup=[run_startup],
|
|
on_shutdown=[run_shutdown],
|
|
routes=[Route("/", hello_world)],
|
|
)
|
|
|
|
assert not startup_complete
|
|
assert not shutdown_complete
|
|
with test_client_factory(app) as client:
|
|
assert startup_complete
|
|
assert not shutdown_complete
|
|
client.get("/")
|
|
assert startup_complete
|
|
assert shutdown_complete
|
|
|
|
|
|
def test_lifespan_state_unsupported(
|
|
test_client_factory: TestClientFactory,
|
|
) -> None:
|
|
@contextlib.asynccontextmanager
|
|
async def lifespan(
|
|
app: ASGIApp,
|
|
) -> typing.AsyncGenerator[dict[str, str], None]:
|
|
yield {"foo": "bar"}
|
|
|
|
app = Router(
|
|
lifespan=lifespan,
|
|
routes=[Mount("/", PlainTextResponse("hello, world"))],
|
|
)
|
|
|
|
async def no_state_wrapper(scope: Scope, receive: Receive, send: Send) -> None:
|
|
del scope["state"]
|
|
await app(scope, receive, send)
|
|
|
|
with pytest.raises(RuntimeError, match='The server does not support "state" in the lifespan scope'):
|
|
with test_client_factory(no_state_wrapper):
|
|
raise AssertionError("Should not be called") # pragma: no cover
|
|
|
|
|
|
def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None:
|
|
startup_complete = False
|
|
shutdown_complete = False
|
|
|
|
class State(typing.TypedDict):
|
|
count: int
|
|
items: list[int]
|
|
|
|
async def hello_world(request: Request) -> Response:
|
|
# modifications to the state should not leak across requests
|
|
assert request.state.count == 0
|
|
# modify the state, this should not leak to the lifespan or other requests
|
|
request.state.count += 1
|
|
# since state.items is a mutable object this modification _will_ leak across
|
|
# requests and to the lifespan
|
|
request.state.items.append(1)
|
|
return PlainTextResponse("hello, world")
|
|
|
|
@contextlib.asynccontextmanager
|
|
async def lifespan(app: Starlette) -> typing.AsyncIterator[State]:
|
|
nonlocal startup_complete, shutdown_complete
|
|
startup_complete = True
|
|
state = State(count=0, items=[])
|
|
yield state
|
|
shutdown_complete = True
|
|
# modifications made to the state from a request do not leak to the lifespan
|
|
assert state["count"] == 0
|
|
# unless of course the request mutates a mutable object that is referenced
|
|
# via state
|
|
assert state["items"] == [1, 1]
|
|
|
|
app = Router(
|
|
lifespan=lifespan,
|
|
routes=[Route("/", hello_world)],
|
|
)
|
|
|
|
assert not startup_complete
|
|
assert not shutdown_complete
|
|
with test_client_factory(app) as client:
|
|
assert startup_complete
|
|
assert not shutdown_complete
|
|
client.get("/")
|
|
# Calling it a second time to ensure that the state is preserved.
|
|
client.get("/")
|
|
assert startup_complete
|
|
assert shutdown_complete
|
|
|
|
|
|
def test_raise_on_startup(test_client_factory: TestClientFactory) -> None:
|
|
def run_startup() -> None:
|
|
raise RuntimeError()
|
|
|
|
with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
|
|
router = Router(on_startup=[run_startup])
|
|
startup_failed = False
|
|
|
|
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
|
async def _send(message: Message) -> None:
|
|
nonlocal startup_failed
|
|
if message["type"] == "lifespan.startup.failed":
|
|
startup_failed = True
|
|
return await send(message)
|
|
|
|
await router(scope, receive, _send)
|
|
|
|
with pytest.raises(RuntimeError):
|
|
with test_client_factory(app):
|
|
pass # pragma: no cover
|
|
assert startup_failed
|
|
|
|
|
|
def test_raise_on_shutdown(test_client_factory: TestClientFactory) -> None:
|
|
def run_shutdown() -> None:
|
|
raise RuntimeError()
|
|
|
|
with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
|
|
app = Router(on_shutdown=[run_shutdown])
|
|
|
|
with pytest.raises(RuntimeError):
|
|
with test_client_factory(app):
|
|
pass # pragma: no cover
|
|
|
|
|
|
def test_partial_async_endpoint(test_client_factory: TestClientFactory) -> None:
|
|
test_client = test_client_factory(app)
|
|
response = test_client.get("/partial")
|
|
assert response.status_code == 200
|
|
assert response.json() == {"arg": "foo"}
|
|
|
|
cls_method_response = test_client.get("/partial/cls")
|
|
assert cls_method_response.status_code == 200
|
|
assert cls_method_response.json() == {"arg": "foo"}
|
|
|
|
|
|
def test_partial_async_ws_endpoint(
|
|
test_client_factory: TestClientFactory,
|
|
) -> None:
|
|
test_client = test_client_factory(app)
|
|
with test_client.websocket_connect("/partial/ws") as websocket:
|
|
data = websocket.receive_json()
|
|
assert data == {"url": "ws://testserver/partial/ws"}
|
|
|
|
with test_client.websocket_connect("/partial/ws/cls") as websocket:
|
|
data = websocket.receive_json()
|
|
assert data == {"url": "ws://testserver/partial/ws/cls"}
|
|
|
|
|
|
def test_duplicated_param_names() -> None:
|
|
with pytest.raises(
|
|
ValueError,
|
|
match="Duplicated param name id at path /{id}/{id}",
|
|
):
|
|
Route("/{id}/{id}", user)
|
|
|
|
with pytest.raises(
|
|
ValueError,
|
|
match="Duplicated param names id, name at path /{id}/{name}/{id}/{name}",
|
|
):
|
|
Route("/{id}/{name}/{id}/{name}", user)
|
|
|
|
|
|
class Endpoint:
|
|
async def my_method(self, request: Request) -> None: ... # pragma: no cover
|
|
|
|
@classmethod
|
|
async def my_classmethod(cls, request: Request) -> None: ... # pragma: no cover
|
|
|
|
@staticmethod
|
|
async def my_staticmethod(request: Request) -> None: ... # pragma: no cover
|
|
|
|
def __call__(self, request: Request) -> None: ... # pragma: no cover
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"endpoint, expected_name",
|
|
[
|
|
pytest.param(func_homepage, "func_homepage", id="function"),
|
|
pytest.param(Endpoint().my_method, "my_method", id="method"),
|
|
pytest.param(Endpoint.my_classmethod, "my_classmethod", id="classmethod"),
|
|
pytest.param(
|
|
Endpoint.my_staticmethod,
|
|
"my_staticmethod",
|
|
id="staticmethod",
|
|
),
|
|
pytest.param(Endpoint(), "Endpoint", id="object"),
|
|
pytest.param(lambda request: ..., "<lambda>", id="lambda"),
|
|
],
|
|
)
|
|
def test_route_name(endpoint: typing.Callable[..., Response], expected_name: str) -> None:
|
|
assert Route(path="/", endpoint=endpoint).name == expected_name
|
|
|
|
|
|
class AddHeadersMiddleware:
|
|
def __init__(self, app: ASGIApp) -> None:
|
|
self.app = app
|
|
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
scope["add_headers_middleware"] = True
|
|
|
|
async def modified_send(msg: Message) -> None:
|
|
if msg["type"] == "http.response.start":
|
|
msg["headers"].append((b"X-Test", b"Set by middleware"))
|
|
await send(msg)
|
|
|
|
await self.app(scope, receive, modified_send)
|
|
|
|
|
|
def assert_middleware_header_route(request: Request) -> Response:
|
|
assert request.scope["add_headers_middleware"] is True
|
|
return Response()
|
|
|
|
|
|
route_with_middleware = Starlette(
|
|
routes=[
|
|
Route(
|
|
"/http",
|
|
endpoint=assert_middleware_header_route,
|
|
methods=["GET"],
|
|
middleware=[Middleware(AddHeadersMiddleware)],
|
|
),
|
|
Route("/home", homepage),
|
|
]
|
|
)
|
|
|
|
mounted_routes_with_middleware = Starlette(
|
|
routes=[
|
|
Mount(
|
|
"/http",
|
|
routes=[
|
|
Route(
|
|
"/",
|
|
endpoint=assert_middleware_header_route,
|
|
methods=["GET"],
|
|
name="route",
|
|
),
|
|
],
|
|
middleware=[Middleware(AddHeadersMiddleware)],
|
|
),
|
|
Route("/home", homepage),
|
|
]
|
|
)
|
|
|
|
|
|
mounted_app_with_middleware = Starlette(
|
|
routes=[
|
|
Mount(
|
|
"/http",
|
|
app=Route(
|
|
"/",
|
|
endpoint=assert_middleware_header_route,
|
|
methods=["GET"],
|
|
name="route",
|
|
),
|
|
middleware=[Middleware(AddHeadersMiddleware)],
|
|
),
|
|
Route("/home", homepage),
|
|
]
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"app",
|
|
[
|
|
mounted_routes_with_middleware,
|
|
mounted_app_with_middleware,
|
|
route_with_middleware,
|
|
],
|
|
)
|
|
def test_base_route_middleware(
|
|
test_client_factory: TestClientFactory,
|
|
app: Starlette,
|
|
) -> None:
|
|
test_client = test_client_factory(app)
|
|
|
|
response = test_client.get("/home")
|
|
assert response.status_code == 200
|
|
assert "X-Test" not in response.headers
|
|
|
|
response = test_client.get("/http")
|
|
assert response.status_code == 200
|
|
assert response.headers["X-Test"] == "Set by middleware"
|
|
|
|
|
|
def test_mount_routes_with_middleware_url_path_for() -> None:
|
|
"""Checks that url_path_for still works with mounted routes with Middleware"""
|
|
assert mounted_routes_with_middleware.url_path_for("route") == "/http/"
|
|
|
|
|
|
def test_mount_asgi_app_with_middleware_url_path_for() -> None:
|
|
"""Mounted ASGI apps do not work with url path for,
|
|
middleware does not change this
|
|
"""
|
|
with pytest.raises(NoMatchFound):
|
|
mounted_app_with_middleware.url_path_for("route")
|
|
|
|
|
|
def test_add_route_to_app_after_mount(
|
|
test_client_factory: typing.Callable[..., TestClient],
|
|
) -> None:
|
|
"""Checks that Mount will pick up routes
|
|
added to the underlying app after it is mounted
|
|
"""
|
|
inner_app = Router()
|
|
app = Mount("/http", app=inner_app)
|
|
inner_app.add_route(
|
|
"/inner",
|
|
endpoint=homepage,
|
|
methods=["GET"],
|
|
)
|
|
client = test_client_factory(app)
|
|
response = client.get("/http/inner")
|
|
assert response.status_code == 200
|
|
|
|
|
|
def test_exception_on_mounted_apps(
|
|
test_client_factory: TestClientFactory,
|
|
) -> None:
|
|
def exc(request: Request) -> None:
|
|
raise Exception("Exc")
|
|
|
|
sub_app = Starlette(routes=[Route("/", exc)])
|
|
app = Starlette(routes=[Mount("/sub", app=sub_app)])
|
|
|
|
client = test_client_factory(app)
|
|
with pytest.raises(Exception) as ctx:
|
|
client.get("/sub/")
|
|
assert str(ctx.value) == "Exc"
|
|
|
|
|
|
def test_mounted_middleware_does_not_catch_exception(
|
|
test_client_factory: typing.Callable[..., TestClient],
|
|
) -> None:
|
|
# https://github.com/encode/starlette/pull/1649#discussion_r960236107
|
|
def exc(request: Request) -> Response:
|
|
raise HTTPException(status_code=403, detail="auth")
|
|
|
|
class NamedMiddleware:
|
|
def __init__(self, app: ASGIApp, name: str) -> None:
|
|
self.app = app
|
|
self.name = name
|
|
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
async def modified_send(msg: Message) -> None:
|
|
if msg["type"] == "http.response.start":
|
|
msg["headers"].append((f"X-{self.name}".encode(), b"true"))
|
|
await send(msg)
|
|
|
|
await self.app(scope, receive, modified_send)
|
|
|
|
app = Starlette(
|
|
routes=[
|
|
Mount(
|
|
"/mount",
|
|
routes=[
|
|
Route("/err", exc),
|
|
Route("/home", homepage),
|
|
],
|
|
middleware=[Middleware(NamedMiddleware, name="Mounted")],
|
|
),
|
|
Route("/err", exc),
|
|
Route("/home", homepage),
|
|
],
|
|
middleware=[Middleware(NamedMiddleware, name="Outer")],
|
|
)
|
|
|
|
client = test_client_factory(app)
|
|
|
|
resp = client.get("/home")
|
|
assert resp.status_code == 200, resp.content
|
|
assert "X-Outer" in resp.headers
|
|
|
|
resp = client.get("/err")
|
|
assert resp.status_code == 403, resp.content
|
|
assert "X-Outer" in resp.headers
|
|
|
|
resp = client.get("/mount/home")
|
|
assert resp.status_code == 200, resp.content
|
|
assert "X-Mounted" in resp.headers
|
|
|
|
resp = client.get("/mount/err")
|
|
assert resp.status_code == 403, resp.content
|
|
assert "X-Mounted" in resp.headers
|
|
|
|
|
|
def test_websocket_route_middleware(
|
|
test_client_factory: TestClientFactory,
|
|
) -> None:
|
|
async def websocket_endpoint(session: WebSocket) -> None:
|
|
await session.accept()
|
|
await session.send_text("Hello, world!")
|
|
await session.close()
|
|
|
|
class WebsocketMiddleware:
|
|
def __init__(self, app: ASGIApp) -> None:
|
|
self.app = app
|
|
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
async def modified_send(msg: Message) -> None:
|
|
if msg["type"] == "websocket.accept":
|
|
msg["headers"].append((b"X-Test", b"Set by middleware"))
|
|
await send(msg)
|
|
|
|
await self.app(scope, receive, modified_send)
|
|
|
|
app = Starlette(
|
|
routes=[
|
|
WebSocketRoute(
|
|
"/ws",
|
|
endpoint=websocket_endpoint,
|
|
middleware=[Middleware(WebsocketMiddleware)],
|
|
)
|
|
]
|
|
)
|
|
|
|
client = test_client_factory(app)
|
|
|
|
with client.websocket_connect("/ws") as websocket:
|
|
text = websocket.receive_text()
|
|
assert text == "Hello, world!"
|
|
assert websocket.extra_headers == [(b"X-Test", b"Set by middleware")]
|
|
|
|
|
|
def test_route_repr() -> None:
|
|
route = Route("/welcome", endpoint=homepage)
|
|
assert repr(route) == "Route(path='/welcome', name='homepage', methods=['GET', 'HEAD'])"
|
|
|
|
|
|
def test_route_repr_without_methods() -> None:
|
|
route = Route("/welcome", endpoint=Endpoint, methods=None)
|
|
assert repr(route) == "Route(path='/welcome', name='Endpoint', methods=[])"
|
|
|
|
|
|
def test_websocket_route_repr() -> None:
|
|
route = WebSocketRoute("/ws", endpoint=websocket_endpoint)
|
|
assert repr(route) == "WebSocketRoute(path='/ws', name='websocket_endpoint')"
|
|
|
|
|
|
def test_mount_repr() -> None:
|
|
route = Mount(
|
|
"/app",
|
|
routes=[
|
|
Route("/", endpoint=homepage),
|
|
],
|
|
)
|
|
# test for substring because repr(Router) returns unique object ID
|
|
assert repr(route).startswith("Mount(path='/app', name='', app=")
|
|
|
|
|
|
def test_mount_named_repr() -> None:
|
|
route = Mount(
|
|
"/app",
|
|
name="app",
|
|
routes=[
|
|
Route("/", endpoint=homepage),
|
|
],
|
|
)
|
|
# test for substring because repr(Router) returns unique object ID
|
|
assert repr(route).startswith("Mount(path='/app', name='app', app=")
|
|
|
|
|
|
def test_host_repr() -> None:
|
|
route = Host(
|
|
"example.com",
|
|
app=Router(
|
|
[
|
|
Route("/", endpoint=homepage),
|
|
]
|
|
),
|
|
)
|
|
# test for substring because repr(Router) returns unique object ID
|
|
assert repr(route).startswith("Host(host='example.com', name='', app=")
|
|
|
|
|
|
def test_host_named_repr() -> None:
|
|
route = Host(
|
|
"example.com",
|
|
name="app",
|
|
app=Router(
|
|
[
|
|
Route("/", endpoint=homepage),
|
|
]
|
|
),
|
|
)
|
|
# test for substring because repr(Router) returns unique object ID
|
|
assert repr(route).startswith("Host(host='example.com', name='app', app=")
|
|
|
|
|
|
def test_decorator_deprecations() -> None:
|
|
router = Router()
|
|
|
|
with pytest.deprecated_call():
|
|
router.route("/")(homepage)
|
|
|
|
with pytest.deprecated_call():
|
|
router.websocket_route("/ws")(websocket_endpoint)
|
|
|
|
with pytest.deprecated_call():
|
|
|
|
async def startup() -> None: ... # pragma: no cover
|
|
|
|
router.on_event("startup")(startup)
|
|
|
|
|
|
async def echo_paths(request: Request, name: str) -> JSONResponse:
|
|
return JSONResponse(
|
|
{
|
|
"name": name,
|
|
"path": request.scope["path"],
|
|
"root_path": request.scope["root_path"],
|
|
}
|
|
)
|
|
|
|
|
|
async def pure_asgi_echo_paths(scope: Scope, receive: Receive, send: Send, name: str) -> None:
|
|
data = {"name": name, "path": scope["path"], "root_path": scope["root_path"]}
|
|
content = json.dumps(data).encode("utf-8")
|
|
await send(
|
|
{
|
|
"type": "http.response.start",
|
|
"status": 200,
|
|
"headers": [(b"content-type", b"application/json")],
|
|
}
|
|
)
|
|
await send({"type": "http.response.body", "body": content})
|
|
|
|
|
|
echo_paths_routes = [
|
|
Route(
|
|
"/path",
|
|
functools.partial(echo_paths, name="path"),
|
|
name="path",
|
|
methods=["GET"],
|
|
),
|
|
Route(
|
|
"/root-queue/path",
|
|
functools.partial(echo_paths, name="queue_path"),
|
|
name="queue_path",
|
|
methods=["POST"],
|
|
),
|
|
Mount("/asgipath", app=functools.partial(pure_asgi_echo_paths, name="asgipath")),
|
|
Mount(
|
|
"/sub",
|
|
name="mount",
|
|
routes=[
|
|
Route(
|
|
"/path",
|
|
functools.partial(echo_paths, name="subpath"),
|
|
name="subpath",
|
|
methods=["GET"],
|
|
),
|
|
],
|
|
),
|
|
]
|
|
|
|
|
|
def test_paths_with_root_path(test_client_factory: TestClientFactory) -> None:
|
|
app = Starlette(routes=echo_paths_routes)
|
|
client = test_client_factory(app, base_url="https://www.example.org/", root_path="/root")
|
|
response = client.get("/root/path")
|
|
assert response.status_code == 200
|
|
assert response.json() == {
|
|
"name": "path",
|
|
"path": "/root/path",
|
|
"root_path": "/root",
|
|
}
|
|
response = client.get("/root/asgipath/")
|
|
assert response.status_code == 200
|
|
assert response.json() == {
|
|
"name": "asgipath",
|
|
"path": "/root/asgipath/",
|
|
# Things that mount other ASGI apps, like WSGIMiddleware, would not be aware
|
|
# of the prefixed path, and would have their own notion of their own paths,
|
|
# so they need to be able to rely on the root_path to know the location they
|
|
# are mounted on
|
|
"root_path": "/root/asgipath",
|
|
}
|
|
|
|
response = client.get("/root/sub/path")
|
|
assert response.status_code == 200
|
|
assert response.json() == {
|
|
"name": "subpath",
|
|
"path": "/root/sub/path",
|
|
"root_path": "/root/sub",
|
|
}
|
|
|
|
response = client.post("/root/root-queue/path")
|
|
assert response.status_code == 200
|
|
assert response.json() == {
|
|
"name": "queue_path",
|
|
"path": "/root/root-queue/path",
|
|
"root_path": "/root",
|
|
}
|