mirror of https://github.com/encode/starlette.git
595 lines
18 KiB
Python
595 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
from typing import AsyncGenerator, AsyncIterator, Callable, Generator
|
|
|
|
import anyio.from_thread
|
|
import pytest
|
|
|
|
from starlette import status
|
|
from starlette.applications import Starlette
|
|
from starlette.endpoints import HTTPEndpoint
|
|
from starlette.exceptions import HTTPException, WebSocketException
|
|
from starlette.middleware import Middleware
|
|
from starlette.middleware.base import RequestResponseEndpoint
|
|
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse, PlainTextResponse
|
|
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
|
|
from starlette.staticfiles import StaticFiles
|
|
from starlette.testclient import TestClient, WebSocketDenialResponse
|
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
from starlette.websockets import WebSocket
|
|
from tests.types import TestClientFactory
|
|
|
|
|
|
async def error_500(request: Request, exc: Exception) -> JSONResponse:
    """Render the JSON body for the ``500`` entry in ``exception_handlers``.

    A status-500 handler is invoked for errors that no other handler caught
    (e.g. the bare ``RuntimeError`` raised by ``runtime_error``), so ``exc`` is
    an arbitrary ``Exception`` rather than an ``HTTPException``.
    """
    return JSONResponse({"detail": "Server Error"}, status_code=500)
|
|
|
|
|
|
async def method_not_allowed(request: Request, exc: HTTPException) -> JSONResponse:
    """Answer every 405 Method Not Allowed with a fixed JSON message."""
    payload = {"detail": "Custom message"}
    return JSONResponse(content=payload, status_code=405)
|
|
|
|
|
|
async def http_exception(request: Request, exc: HTTPException) -> JSONResponse:
    """Mirror an ``HTTPException``'s ``detail`` and ``status_code`` into a JSON response."""
    body = {"detail": exc.detail}
    return JSONResponse(content=body, status_code=exc.status_code)
|
|
|
|
|
|
def func_homepage(request: Request) -> PlainTextResponse:
    """Synchronous function endpoint returning a plain-text greeting."""
    greeting = "Hello, world!"
    return PlainTextResponse(content=greeting)
|
|
|
|
|
|
async def async_homepage(request: Request) -> PlainTextResponse:
    """Coroutine endpoint returning the same greeting as ``func_homepage``."""
    greeting = "Hello, world!"
    return PlainTextResponse(content=greeting)
|
|
|
|
|
|
class Homepage(HTTPEndpoint):
    """Class-based endpoint that implements only the ``GET`` handler."""

    def get(self, request: Request) -> PlainTextResponse:
        """Return the plain-text greeting."""
        greeting = "Hello, world!"
        return PlainTextResponse(content=greeting)
|
|
|
|
|
|
def all_users_page(request: Request) -> PlainTextResponse:
    """Index page of the ``users`` sub-router."""
    message = "Hello, everyone!"
    return PlainTextResponse(content=message)
|
|
|
|
|
|
def user_page(request: Request) -> PlainTextResponse:
    """Greet the user named by the ``username`` path parameter."""
    params = request.path_params
    return PlainTextResponse(f"Hello, {params['username']}!")
|
|
|
|
|
|
def custom_subdomain(request: Request) -> PlainTextResponse:
    """Echo the ``subdomain`` path parameter captured by the ``Host`` pattern."""
    label = request.path_params["subdomain"]
    return PlainTextResponse(content="Subdomain: " + label)
|
|
|
|
|
|
def runtime_error(request: Request) -> None:
    """Always fail with ``RuntimeError`` so unhandled-error handling can be tested."""
    raise RuntimeError
|
|
|
|
|
|
async def websocket_endpoint(session: WebSocket) -> None:
    """Accept the connection, send one text greeting, then close it."""
    greeting = "Hello, world!"
    await session.accept()
    await session.send_text(greeting)
    await session.close()
|
|
|
|
|
|
async def websocket_raise_websocket_exception(websocket: WebSocket) -> None:
    """Accept, then abort with a ``WebSocketException`` carrying close code 1003."""
    await websocket.accept()
    close_code = status.WS_1003_UNSUPPORTED_DATA
    raise WebSocketException(code=close_code)
|
|
|
|
|
|
async def websocket_raise_http_exception(websocket: WebSocket) -> None:
    """Raise an ``HTTPException`` (401) before the handshake is ever accepted."""
    error = HTTPException(status_code=401, detail="Unauthorized")
    raise error
|
|
|
|
|
|
class CustomWSException(Exception):
    """Test-only error type mapped to ``custom_ws_exception_handler``."""
|
|
|
|
|
|
async def websocket_raise_custom(websocket: WebSocket) -> None:
    """Accept, then raise an exception that only a registered custom handler knows."""
    await websocket.accept()
    raise CustomWSException
|
|
|
|
|
|
def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> None:
    """Close the socket with code 1013 (Try Again Later) on ``CustomWSException``.

    This handler is a plain ``def``. The async ``websocket.close`` is therefore
    driven through ``anyio.from_thread.run``, which assumes the handler is
    executed in a worker thread.
    """
    close_code = status.WS_1013_TRY_AGAIN_LATER
    anyio.from_thread.run(websocket.close, close_code)
|
|
|
|
|
|
# Sub-router mounted under ``/users``.
users = Router(
    routes=[
        Route("/", endpoint=all_users_page),
        Route("/{username}", endpoint=user_page),
    ]
)

# Sub-router served for hosts matching ``{subdomain}.example.org``.
subdomain = Router(routes=[Route("/", endpoint=custom_subdomain)])

# Handlers keyed either by HTTP status code or by exception class.
exception_handlers = {
    500: error_500,
    405: method_not_allowed,
    HTTPException: http_exception,
    CustomWSException: custom_ws_exception_handler,
}

# Only ``testserver`` and ``*.example.org`` are accepted as Host headers.
middleware = [
    Middleware(
        TrustedHostMiddleware,
        allowed_hosts=["testserver", "*.example.org"],
    ),
]

app = Starlette(
    routes=[
        # HTTP endpoints: function, coroutine, class-based, and always-failing.
        Route("/func", endpoint=func_homepage),
        Route("/async", endpoint=async_homepage),
        Route("/class", endpoint=Homepage),
        Route("/500", endpoint=runtime_error),
        # WebSocket endpoints, including ones that raise.
        WebSocketRoute("/ws", endpoint=websocket_endpoint),
        WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
        WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception),
        WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
        # Nested applications.
        Mount("/users", app=users),
        Host("{subdomain}.example.org", app=subdomain),
    ],
    exception_handlers=exception_handlers,  # type: ignore
    middleware=middleware,
)
|
|
|
|
|
|
@pytest.fixture
def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]:
    """Yield a ``TestClient`` for the module-level ``app``, used inside its context manager."""
    with test_client_factory(app) as test_client:
        yield test_client
|
|
|
|
|
|
def test_url_path_for() -> None:
    """An unnamed route can be reversed by its endpoint function's name."""
    url = app.url_path_for("func_homepage")
    assert url == "/func"
|
|
|
|
|
|
def test_func_route(client: TestClient) -> None:
    """GET returns the greeting; HEAD returns the same status with an empty body."""
    get_response = client.get("/func")
    assert get_response.status_code == 200
    assert get_response.text == "Hello, world!"

    head_response = client.head("/func")
    assert head_response.status_code == 200
    assert head_response.text == ""
|
|
|
|
|
|
def test_async_route(client: TestClient) -> None:
    """A coroutine endpoint is served just like a sync one."""
    resp = client.get("/async")
    assert resp.status_code == 200
    assert resp.text == "Hello, world!"
|
|
|
|
|
|
def test_class_route(client: TestClient) -> None:
    """An ``HTTPEndpoint`` subclass dispatches GET to its ``get`` method."""
    resp = client.get("/class")
    assert resp.status_code == 200
    assert resp.text == "Hello, world!"
|
|
|
|
|
|
def test_mounted_route(client: TestClient) -> None:
    """The root of the router mounted at ``/users`` is reachable."""
    resp = client.get("/users/")
    assert resp.status_code == 200
    assert resp.text == "Hello, everyone!"
|
|
|
|
|
|
def test_mounted_route_path_params(client: TestClient) -> None:
    """Path parameters of a mounted sub-router reach the endpoint."""
    resp = client.get("/users/tomchristie")
    assert resp.status_code == 200
    assert resp.text == "Hello, tomchristie!"
|
|
|
|
|
|
def test_subdomain_route(test_client_factory: TestClientFactory) -> None:
    """Requests for ``foo.example.org`` are dispatched to the ``Host`` sub-router."""
    subdomain_client = test_client_factory(app, base_url="https://foo.example.org/")
    resp = subdomain_client.get("/")
    assert resp.status_code == 200
    assert resp.text == "Subdomain: foo"
|
|
|
|
|
|
def test_websocket_route(client: TestClient) -> None:
    """The ``/ws`` endpoint sends a single text greeting."""
    with client.websocket_connect("/ws") as ws:
        assert ws.receive_text() == "Hello, world!"
|
|
|
|
|
|
def test_400(client: TestClient) -> None:
    """An unknown path returns a JSON ``Not Found`` body with status 404.

    NOTE(review): despite the test name, the asserted status is 404, not 400.
    """
    resp = client.get("/404")
    assert resp.status_code == 404
    assert resp.json() == {"detail": "Not Found"}
|
|
|
|
|
|
def test_405(client: TestClient) -> None:
    """POST to GET-only function and class endpoints hits the custom 405 handler."""
    for path in ("/func", "/class"):
        resp = client.post(path)
        assert resp.status_code == 405
        assert resp.json() == {"detail": "Custom message"}
|
|
|
|
|
|
def test_500(test_client_factory: TestClientFactory) -> None:
    """An unhandled error is rendered by the 500 handler when not re-raised to the client."""
    lenient_client = test_client_factory(app, raise_server_exceptions=False)
    resp = lenient_client.get("/500")
    assert resp.status_code == 500
    assert resp.json() == {"detail": "Server Error"}
|
|
|
|
|
|
def test_websocket_raise_websocket_exception(client: TestClient) -> None:
    """A raised ``WebSocketException`` closes the socket with its own code."""
    expected_close = {
        "type": "websocket.close",
        "code": status.WS_1003_UNSUPPORTED_DATA,
        "reason": "",
    }
    with client.websocket_connect("/ws-raise-websocket") as ws:
        message = ws.receive()
        assert message == expected_close
|
|
|
|
|
|
def test_websocket_raise_http_exception(client: TestClient) -> None:
    """An ``HTTPException`` before accept turns the handshake into an HTTP denial response."""
    with pytest.raises(WebSocketDenialResponse) as exc_info, client.websocket_connect("/ws-raise-http"):
        pass  # pragma: no cover
    denial = exc_info.value
    assert denial.status_code == 401
    assert denial.content == b'{"detail":"Unauthorized"}'
|
|
|
|
|
|
def test_websocket_raise_custom_exception(client: TestClient) -> None:
    """A custom exception handler can close the socket with a chosen code."""
    expected_close = {
        "type": "websocket.close",
        "code": status.WS_1013_TRY_AGAIN_LATER,
        "reason": "",
    }
    with client.websocket_connect("/ws-raise-custom") as ws:
        message = ws.receive()
        assert message == expected_close
|
|
|
|
|
|
def test_middleware(test_client_factory: TestClientFactory) -> None:
    """A Host outside ``TrustedHostMiddleware``'s allow-list is rejected with 400."""
    untrusted_client = test_client_factory(app, base_url="http://incorrecthost")
    resp = untrusted_client.get("/func")
    assert resp.status_code == 400
    assert resp.text == "Invalid host header"
|
|
|
|
|
|
def test_routes() -> None:
    """``app.routes`` compares equal to an independently built, identical route list."""
    users_router = Router(
        routes=[
            Route("/", endpoint=all_users_page),
            Route("/{username}", endpoint=user_page),
        ]
    )
    subdomain_router = Router(routes=[Route("/", endpoint=custom_subdomain)])

    expected_routes = [
        Route("/func", endpoint=func_homepage, methods=["GET"]),
        Route("/async", endpoint=async_homepage, methods=["GET"]),
        Route("/class", endpoint=Homepage),
        Route("/500", endpoint=runtime_error, methods=["GET"]),
        WebSocketRoute("/ws", endpoint=websocket_endpoint),
        WebSocketRoute("/ws-raise-websocket", endpoint=websocket_raise_websocket_exception),
        WebSocketRoute("/ws-raise-http", endpoint=websocket_raise_http_exception),
        WebSocketRoute("/ws-raise-custom", endpoint=websocket_raise_custom),
        Mount("/users", app=users_router),
        Host("{subdomain}.example.org", app=subdomain_router),
    ]
    assert app.routes == expected_routes
|
|
|
|
|
|
def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
    """``StaticFiles`` mounted at ``/static`` serves files and answers POST with 405."""
    with open(os.path.join(tmpdir, "example.txt"), "w") as fh:
        fh.write("<file content>")

    static_app = Starlette(routes=[Mount("/static", StaticFiles(directory=tmpdir))])
    static_client = test_client_factory(static_app)

    get_resp = static_client.get("/static/example.txt")
    assert get_resp.status_code == 200
    assert get_resp.text == "<file content>"

    post_resp = static_client.post("/static/example.txt")
    assert post_resp.status_code == 405
    assert post_resp.text == "Method Not Allowed"
|
|
|
|
|
|
def test_app_debug(test_client_factory: TestClientFactory) -> None:
    """With ``debug`` switched on, the 500 response body names the exception."""

    async def homepage(request: Request) -> None:
        raise RuntimeError()

    debug_app = Starlette(routes=[Route("/", homepage)])
    # Enabled after construction, via the attribute rather than the constructor.
    debug_app.debug = True

    lenient_client = test_client_factory(debug_app, raise_server_exceptions=False)
    resp = lenient_client.get("/")
    assert resp.status_code == 500
    assert "RuntimeError" in resp.text
    assert debug_app.debug
|
|
|
|
|
|
def test_app_add_route(test_client_factory: TestClientFactory) -> None:
    """A route supplied to the ``Starlette`` constructor is served."""

    async def homepage(request: Request) -> PlainTextResponse:
        return PlainTextResponse("Hello, World!")

    route_app = Starlette(routes=[Route("/", endpoint=homepage)])
    resp = test_client_factory(route_app).get("/")
    assert resp.status_code == 200
    assert resp.text == "Hello, World!"
|
|
|
|
|
|
def test_app_add_websocket_route(test_client_factory: TestClientFactory) -> None:
    """A ``WebSocketRoute`` supplied to the constructor is served."""

    async def websocket_endpoint(session: WebSocket) -> None:
        await session.accept()
        await session.send_text("Hello, world!")
        await session.close()

    ws_app = Starlette(routes=[WebSocketRoute("/ws", endpoint=websocket_endpoint)])
    ws_client = test_client_factory(ws_app)

    with ws_client.websocket_connect("/ws") as ws:
        assert ws.receive_text() == "Hello, world!"
|
|
|
|
|
|
def test_app_add_event_handler(test_client_factory: TestClientFactory) -> None:
    """The deprecated ``on_startup``/``on_shutdown`` hooks still run around the client context."""
    fired = {"startup": False, "cleanup": False}

    def run_startup() -> None:
        fired["startup"] = True

    def run_cleanup() -> None:
        fired["cleanup"] = True

    with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"):
        app = Starlette(on_startup=[run_startup], on_shutdown=[run_cleanup])

    assert fired == {"startup": False, "cleanup": False}
    with test_client_factory(app):
        assert fired == {"startup": True, "cleanup": False}
    assert fired == {"startup": True, "cleanup": True}
|
|
|
|
|
|
def test_app_async_cm_lifespan(test_client_factory: TestClientFactory) -> None:
    """An ``@asynccontextmanager`` lifespan runs startup on enter and cleanup on exit."""
    progress = {"startup": False, "cleanup": False}

    @asynccontextmanager
    async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
        progress["startup"] = True
        yield
        progress["cleanup"] = True

    app = Starlette(lifespan=lifespan)

    assert progress == {"startup": False, "cleanup": False}
    with test_client_factory(app):
        assert progress == {"startup": True, "cleanup": False}
    assert progress == {"startup": True, "cleanup": True}
|
|
|
|
|
|
# Marker that silences the DeprecationWarning emitted from ``starlette.routing``
# when a (sync or async) generator function is used directly as a lifespan.
# Filter format is "action:message-regex:category:module".
deprecated_lifespan = pytest.mark.filterwarnings(
    ":".join(
        (
            "ignore",
            r"(async )?generator function lifespans are deprecated, use an "
            r"@contextlib\.asynccontextmanager function instead",
            "DeprecationWarning",
            "starlette.routing",
        )
    )
)
|
|
|
|
|
|
@deprecated_lifespan
def test_app_async_gen_lifespan(test_client_factory: TestClientFactory) -> None:
    """A bare async-generator lifespan (deprecated form) still runs startup and cleanup."""
    progress = {"startup": False, "cleanup": False}

    async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
        progress["startup"] = True
        yield
        progress["cleanup"] = True

    app = Starlette(lifespan=lifespan)  # type: ignore

    assert progress == {"startup": False, "cleanup": False}
    with test_client_factory(app):
        assert progress == {"startup": True, "cleanup": False}
    assert progress == {"startup": True, "cleanup": True}
|
|
|
|
|
|
@deprecated_lifespan
def test_app_sync_gen_lifespan(test_client_factory: TestClientFactory) -> None:
    """A bare sync-generator lifespan (deprecated form) still runs startup and cleanup."""
    progress = {"startup": False, "cleanup": False}

    def lifespan(app: ASGIApp) -> Generator[None, None, None]:
        progress["startup"] = True
        yield
        progress["cleanup"] = True

    app = Starlette(lifespan=lifespan)  # type: ignore

    assert progress == {"startup": False, "cleanup": False}
    with test_client_factory(app):
        assert progress == {"startup": True, "cleanup": False}
    assert progress == {"startup": True, "cleanup": True}
|
|
|
|
|
|
def test_decorator_deprecations() -> None:
    """Each legacy decorator on ``Starlette`` emits exactly one DeprecationWarning.

    ``pytest.deprecated_call(match=...)`` treats ``match`` as a regular expression.
    The expected messages contain ``.`` characters that are meant literally, so
    they are passed through ``re.escape`` rather than acting as wildcards.
    """
    import re

    app = Starlette()

    def expect_deprecation(decorator: str) -> pytest.WarningsRecorder:
        # All decorator deprecations share this message shape.
        message = f"The `{decorator}` decorator is deprecated, and will be removed in version 1.0.0."
        return pytest.deprecated_call(match=re.escape(message))

    with expect_deprecation("exception_handler") as record:
        app.exception_handler(500)(http_exception)
        assert len(record) == 1

    async def middleware(request: Request, call_next: RequestResponseEndpoint) -> None: ...  # pragma: no cover

    with expect_deprecation("middleware") as record:
        app.middleware("http")(middleware)
        assert len(record) == 1

    with expect_deprecation("route") as record:
        app.route("/")(async_homepage)
        assert len(record) == 1

    with expect_deprecation("websocket_route") as record:
        app.websocket_route("/ws")(websocket_endpoint)
        assert len(record) == 1

    async def startup() -> None: ...  # pragma: no cover

    with expect_deprecation("on_event") as record:
        app.on_event("startup")(startup)
        assert len(record) == 1
|
|
|
|
|
|
def test_middleware_stack_init(test_client_factory: TestClientFactory) -> None:
    """The middleware stack is instantiated once per app instance and then reused."""

    class NoOpMiddleware:
        def __init__(self, app: ASGIApp):
            self.app = app

        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            await self.app(scope, receive, send)

    class SimpleInitializableMiddleware:
        # Number of times this middleware has been constructed so far.
        counter = 0

        def __init__(self, app: ASGIApp):
            self.app = app
            SimpleInitializableMiddleware.counter += 1

        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            await self.app(scope, receive, send)

    def get_app() -> ASGIApp:
        fresh_app = Starlette()
        fresh_app.add_middleware(SimpleInitializableMiddleware)
        fresh_app.add_middleware(NoOpMiddleware)
        return fresh_app

    first_app = get_app()

    # Entering the client context instantiates the middleware once.
    with test_client_factory(first_app):
        pass
    assert SimpleInitializableMiddleware.counter == 1

    # A later request to the same app does not instantiate it again.
    test_client_factory(first_app).get("/foo")
    assert SimpleInitializableMiddleware.counter == 1

    # A brand-new app builds its own stack.
    test_client_factory(get_app()).get("/foo")
    assert SimpleInitializableMiddleware.counter == 2
|
|
|
|
|
|
def test_middleware_args(test_client_factory: TestClientFactory) -> None:
    """Extra positional arguments to ``add_middleware`` reach the middleware class."""
    calls: list[str] = []

    class MiddlewareWithArgs:
        def __init__(self, app: ASGIApp, arg: str) -> None:
            self.app = app
            self.arg = arg

        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            calls.append(self.arg)
            await self.app(scope, receive, send)

    app = Starlette()
    for label in ("foo", "bar"):
        app.add_middleware(MiddlewareWithArgs, label)

    with test_client_factory(app):
        pass

    # The middleware added last wraps the others, so it records first.
    assert calls == ["bar", "foo"]
|
|
|
|
|
|
def test_middleware_factory(test_client_factory: TestClientFactory) -> None:
    """Plain factory callables (not only classes) are accepted as middleware."""
    calls: list[str] = []

    def _middleware_factory(app: ASGIApp, arg: str) -> ASGIApp:
        async def wrapped(scope: Scope, receive: Receive, send: Send) -> None:
            calls.append(arg)
            await app(scope, receive, send)

        return wrapped

    def get_middleware_factory() -> Callable[[ASGIApp, str], ASGIApp]:
        return _middleware_factory

    app = Starlette()
    # The extra argument is forwarded once by keyword, once positionally.
    app.add_middleware(_middleware_factory, arg="foo")
    app.add_middleware(get_middleware_factory(), "bar")

    with test_client_factory(app):
        pass

    # The factory added last wraps the other one, so it records first.
    assert calls == ["bar", "foo"]
|
|
|
|
|
|
def test_lifespan_app_subclass() -> None:
    """Type-hint guard: a lifespan annotated with a ``Starlette`` subclass is accepted.

    Subclasses of Starlette (like FastAPI) depend on the ``Lifespan`` type hints
    allowing this.
    """

    class SubclassedApp(Starlette):
        pass

    @asynccontextmanager
    async def lifespan(app: SubclassedApp) -> AsyncIterator[None]:  # pragma: no cover
        yield

    SubclassedApp(lifespan=lifespan)
|