mirror of https://github.com/encode/starlette.git
Add type hints to `test_applications.py` (#2471)
* added type annotations to test_applications.py * requested changes * Apply suggestions from code review * Apply suggestions from code review * Update tests/test_applications.py --------- Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan> Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
parent
8c222960ba
commit
11b7ae7365
|
@ -1,9 +1,9 @@
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import AsyncIterator, Callable
|
from pathlib import Path
|
||||||
|
from typing import AsyncGenerator, AsyncIterator, Callable, Generator
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import httpx
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from starlette import status
|
from starlette import status
|
||||||
|
@ -11,63 +11,68 @@ from starlette.applications import Starlette
|
||||||
from starlette.endpoints import HTTPEndpoint
|
from starlette.endpoints import HTTPEndpoint
|
||||||
from starlette.exceptions import HTTPException, WebSocketException
|
from starlette.exceptions import HTTPException, WebSocketException
|
||||||
from starlette.middleware import Middleware
|
from starlette.middleware import Middleware
|
||||||
|
from starlette.middleware.base import RequestResponseEndpoint
|
||||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse, PlainTextResponse
|
from starlette.responses import JSONResponse, PlainTextResponse
|
||||||
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
|
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
|
||||||
from starlette.staticfiles import StaticFiles
|
from starlette.staticfiles import StaticFiles
|
||||||
|
from starlette.testclient import TestClient
|
||||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
from starlette.websockets import WebSocket
|
from starlette.websockets import WebSocket
|
||||||
|
|
||||||
|
TestClientFactory = Callable[..., TestClient]
|
||||||
|
|
||||||
async def error_500(request, exc):
|
|
||||||
|
async def error_500(request: Request, exc: HTTPException) -> JSONResponse:
|
||||||
return JSONResponse({"detail": "Server Error"}, status_code=500)
|
return JSONResponse({"detail": "Server Error"}, status_code=500)
|
||||||
|
|
||||||
|
|
||||||
async def method_not_allowed(request, exc):
|
async def method_not_allowed(request: Request, exc: HTTPException) -> JSONResponse:
|
||||||
return JSONResponse({"detail": "Custom message"}, status_code=405)
|
return JSONResponse({"detail": "Custom message"}, status_code=405)
|
||||||
|
|
||||||
|
|
||||||
async def http_exception(request, exc):
|
async def http_exception(request: Request, exc: HTTPException) -> JSONResponse:
|
||||||
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
|
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
|
||||||
|
|
||||||
|
|
||||||
def func_homepage(request):
|
def func_homepage(request: Request) -> PlainTextResponse:
|
||||||
return PlainTextResponse("Hello, world!")
|
return PlainTextResponse("Hello, world!")
|
||||||
|
|
||||||
|
|
||||||
async def async_homepage(request):
|
async def async_homepage(request: Request) -> PlainTextResponse:
|
||||||
return PlainTextResponse("Hello, world!")
|
return PlainTextResponse("Hello, world!")
|
||||||
|
|
||||||
|
|
||||||
class Homepage(HTTPEndpoint):
|
class Homepage(HTTPEndpoint):
|
||||||
def get(self, request):
|
def get(self, request: Request) -> PlainTextResponse:
|
||||||
return PlainTextResponse("Hello, world!")
|
return PlainTextResponse("Hello, world!")
|
||||||
|
|
||||||
|
|
||||||
def all_users_page(request):
|
def all_users_page(request: Request) -> PlainTextResponse:
|
||||||
return PlainTextResponse("Hello, everyone!")
|
return PlainTextResponse("Hello, everyone!")
|
||||||
|
|
||||||
|
|
||||||
def user_page(request):
|
def user_page(request: Request) -> PlainTextResponse:
|
||||||
username = request.path_params["username"]
|
username = request.path_params["username"]
|
||||||
return PlainTextResponse(f"Hello, {username}!")
|
return PlainTextResponse(f"Hello, {username}!")
|
||||||
|
|
||||||
|
|
||||||
def custom_subdomain(request):
|
def custom_subdomain(request: Request) -> PlainTextResponse:
|
||||||
return PlainTextResponse("Subdomain: " + request.path_params["subdomain"])
|
return PlainTextResponse("Subdomain: " + request.path_params["subdomain"])
|
||||||
|
|
||||||
|
|
||||||
def runtime_error(request):
|
def runtime_error(request: Request) -> None:
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
|
|
||||||
async def websocket_endpoint(session):
|
async def websocket_endpoint(session: WebSocket) -> None:
|
||||||
await session.accept()
|
await session.accept()
|
||||||
await session.send_text("Hello, world!")
|
await session.send_text("Hello, world!")
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
async def websocket_raise_websocket(websocket: WebSocket):
|
async def websocket_raise_websocket(websocket: WebSocket) -> None:
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)
|
raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||||
|
|
||||||
|
@ -76,12 +81,12 @@ class CustomWSException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def websocket_raise_custom(websocket: WebSocket):
|
async def websocket_raise_custom(websocket: WebSocket) -> None:
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
raise CustomWSException()
|
raise CustomWSException()
|
||||||
|
|
||||||
|
|
||||||
def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException):
|
def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> None:
|
||||||
anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)
|
anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)
|
||||||
|
|
||||||
|
|
||||||
|
@ -121,22 +126,22 @@ app = Starlette(
|
||||||
Mount("/users", app=users),
|
Mount("/users", app=users),
|
||||||
Host("{subdomain}.example.org", app=subdomain),
|
Host("{subdomain}.example.org", app=subdomain),
|
||||||
],
|
],
|
||||||
exception_handlers=exception_handlers,
|
exception_handlers=exception_handlers, # type: ignore
|
||||||
middleware=middleware,
|
middleware=middleware,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client(test_client_factory):
|
def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]:
|
||||||
with test_client_factory(app) as client:
|
with test_client_factory(app) as client:
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
def test_url_path_for():
|
def test_url_path_for() -> None:
|
||||||
assert app.url_path_for("func_homepage") == "/func"
|
assert app.url_path_for("func_homepage") == "/func"
|
||||||
|
|
||||||
|
|
||||||
def test_func_route(client):
|
def test_func_route(client: TestClient) -> None:
|
||||||
response = client.get("/func")
|
response = client.get("/func")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.text == "Hello, world!"
|
assert response.text == "Hello, world!"
|
||||||
|
@ -146,31 +151,31 @@ def test_func_route(client):
|
||||||
assert response.text == ""
|
assert response.text == ""
|
||||||
|
|
||||||
|
|
||||||
def test_async_route(client):
|
def test_async_route(client: TestClient) -> None:
|
||||||
response = client.get("/async")
|
response = client.get("/async")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.text == "Hello, world!"
|
assert response.text == "Hello, world!"
|
||||||
|
|
||||||
|
|
||||||
def test_class_route(client):
|
def test_class_route(client: TestClient) -> None:
|
||||||
response = client.get("/class")
|
response = client.get("/class")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.text == "Hello, world!"
|
assert response.text == "Hello, world!"
|
||||||
|
|
||||||
|
|
||||||
def test_mounted_route(client):
|
def test_mounted_route(client: TestClient) -> None:
|
||||||
response = client.get("/users/")
|
response = client.get("/users/")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.text == "Hello, everyone!"
|
assert response.text == "Hello, everyone!"
|
||||||
|
|
||||||
|
|
||||||
def test_mounted_route_path_params(client):
|
def test_mounted_route_path_params(client: TestClient) -> None:
|
||||||
response = client.get("/users/tomchristie")
|
response = client.get("/users/tomchristie")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.text == "Hello, tomchristie!"
|
assert response.text == "Hello, tomchristie!"
|
||||||
|
|
||||||
|
|
||||||
def test_subdomain_route(test_client_factory):
|
def test_subdomain_route(test_client_factory: TestClientFactory) -> None:
|
||||||
client = test_client_factory(app, base_url="https://foo.example.org/")
|
client = test_client_factory(app, base_url="https://foo.example.org/")
|
||||||
|
|
||||||
response = client.get("/")
|
response = client.get("/")
|
||||||
|
@ -178,19 +183,19 @@ def test_subdomain_route(test_client_factory):
|
||||||
assert response.text == "Subdomain: foo"
|
assert response.text == "Subdomain: foo"
|
||||||
|
|
||||||
|
|
||||||
def test_websocket_route(client):
|
def test_websocket_route(client: TestClient) -> None:
|
||||||
with client.websocket_connect("/ws") as session:
|
with client.websocket_connect("/ws") as session:
|
||||||
text = session.receive_text()
|
text = session.receive_text()
|
||||||
assert text == "Hello, world!"
|
assert text == "Hello, world!"
|
||||||
|
|
||||||
|
|
||||||
def test_400(client):
|
def test_400(client: TestClient) -> None:
|
||||||
response = client.get("/404")
|
response = client.get("/404")
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
assert response.json() == {"detail": "Not Found"}
|
assert response.json() == {"detail": "Not Found"}
|
||||||
|
|
||||||
|
|
||||||
def test_405(client):
|
def test_405(client: TestClient) -> None:
|
||||||
response = client.post("/func")
|
response = client.post("/func")
|
||||||
assert response.status_code == 405
|
assert response.status_code == 405
|
||||||
assert response.json() == {"detail": "Custom message"}
|
assert response.json() == {"detail": "Custom message"}
|
||||||
|
@ -200,14 +205,14 @@ def test_405(client):
|
||||||
assert response.json() == {"detail": "Custom message"}
|
assert response.json() == {"detail": "Custom message"}
|
||||||
|
|
||||||
|
|
||||||
def test_500(test_client_factory):
|
def test_500(test_client_factory: TestClientFactory) -> None:
|
||||||
client = test_client_factory(app, raise_server_exceptions=False)
|
client = test_client_factory(app, raise_server_exceptions=False)
|
||||||
response = client.get("/500")
|
response = client.get("/500")
|
||||||
assert response.status_code == 500
|
assert response.status_code == 500
|
||||||
assert response.json() == {"detail": "Server Error"}
|
assert response.json() == {"detail": "Server Error"}
|
||||||
|
|
||||||
|
|
||||||
def test_websocket_raise_websocket_exception(client):
|
def test_websocket_raise_websocket_exception(client: TestClient) -> None:
|
||||||
with client.websocket_connect("/ws-raise-websocket") as session:
|
with client.websocket_connect("/ws-raise-websocket") as session:
|
||||||
response = session.receive()
|
response = session.receive()
|
||||||
assert response == {
|
assert response == {
|
||||||
|
@ -217,7 +222,7 @@ def test_websocket_raise_websocket_exception(client):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_websocket_raise_custom_exception(client):
|
def test_websocket_raise_custom_exception(client: TestClient) -> None:
|
||||||
with client.websocket_connect("/ws-raise-custom") as session:
|
with client.websocket_connect("/ws-raise-custom") as session:
|
||||||
response = session.receive()
|
response = session.receive()
|
||||||
assert response == {
|
assert response == {
|
||||||
|
@ -227,14 +232,14 @@ def test_websocket_raise_custom_exception(client):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_middleware(test_client_factory):
|
def test_middleware(test_client_factory: TestClientFactory) -> None:
|
||||||
client = test_client_factory(app, base_url="http://incorrecthost")
|
client = test_client_factory(app, base_url="http://incorrecthost")
|
||||||
response = client.get("/func")
|
response = client.get("/func")
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert response.text == "Invalid host header"
|
assert response.text == "Invalid host header"
|
||||||
|
|
||||||
|
|
||||||
def test_routes():
|
def test_routes() -> None:
|
||||||
assert app.routes == [
|
assert app.routes == [
|
||||||
Route("/func", endpoint=func_homepage, methods=["GET"]),
|
Route("/func", endpoint=func_homepage, methods=["GET"]),
|
||||||
Route("/async", endpoint=async_homepage, methods=["GET"]),
|
Route("/async", endpoint=async_homepage, methods=["GET"]),
|
||||||
|
@ -259,7 +264,7 @@ def test_routes():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_app_mount(tmpdir, test_client_factory):
|
def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
|
||||||
path = os.path.join(tmpdir, "example.txt")
|
path = os.path.join(tmpdir, "example.txt")
|
||||||
with open(path, "w") as file:
|
with open(path, "w") as file:
|
||||||
file.write("<file content>")
|
file.write("<file content>")
|
||||||
|
@ -281,8 +286,8 @@ def test_app_mount(tmpdir, test_client_factory):
|
||||||
assert response.text == "Method Not Allowed"
|
assert response.text == "Method Not Allowed"
|
||||||
|
|
||||||
|
|
||||||
def test_app_debug(test_client_factory):
|
def test_app_debug(test_client_factory: TestClientFactory) -> None:
|
||||||
async def homepage(request):
|
async def homepage(request: Request) -> None:
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
app = Starlette(
|
app = Starlette(
|
||||||
|
@ -299,8 +304,8 @@ def test_app_debug(test_client_factory):
|
||||||
assert app.debug
|
assert app.debug
|
||||||
|
|
||||||
|
|
||||||
def test_app_add_route(test_client_factory):
|
def test_app_add_route(test_client_factory: TestClientFactory) -> None:
|
||||||
async def homepage(request):
|
async def homepage(request: Request) -> PlainTextResponse:
|
||||||
return PlainTextResponse("Hello, World!")
|
return PlainTextResponse("Hello, World!")
|
||||||
|
|
||||||
app = Starlette(
|
app = Starlette(
|
||||||
|
@ -315,8 +320,8 @@ def test_app_add_route(test_client_factory):
|
||||||
assert response.text == "Hello, World!"
|
assert response.text == "Hello, World!"
|
||||||
|
|
||||||
|
|
||||||
def test_app_add_websocket_route(test_client_factory):
|
def test_app_add_websocket_route(test_client_factory: TestClientFactory) -> None:
|
||||||
async def websocket_endpoint(session):
|
async def websocket_endpoint(session: WebSocket) -> None:
|
||||||
await session.accept()
|
await session.accept()
|
||||||
await session.send_text("Hello, world!")
|
await session.send_text("Hello, world!")
|
||||||
await session.close()
|
await session.close()
|
||||||
|
@ -333,15 +338,15 @@ def test_app_add_websocket_route(test_client_factory):
|
||||||
assert text == "Hello, world!"
|
assert text == "Hello, world!"
|
||||||
|
|
||||||
|
|
||||||
def test_app_add_event_handler(test_client_factory):
|
def test_app_add_event_handler(test_client_factory: TestClientFactory) -> None:
|
||||||
startup_complete = False
|
startup_complete = False
|
||||||
cleanup_complete = False
|
cleanup_complete = False
|
||||||
|
|
||||||
def run_startup():
|
def run_startup() -> None:
|
||||||
nonlocal startup_complete
|
nonlocal startup_complete
|
||||||
startup_complete = True
|
startup_complete = True
|
||||||
|
|
||||||
def run_cleanup():
|
def run_cleanup() -> None:
|
||||||
nonlocal cleanup_complete
|
nonlocal cleanup_complete
|
||||||
cleanup_complete = True
|
cleanup_complete = True
|
||||||
|
|
||||||
|
@ -362,12 +367,12 @@ def test_app_add_event_handler(test_client_factory):
|
||||||
assert cleanup_complete
|
assert cleanup_complete
|
||||||
|
|
||||||
|
|
||||||
def test_app_async_cm_lifespan(test_client_factory):
|
def test_app_async_cm_lifespan(test_client_factory: TestClientFactory) -> None:
|
||||||
startup_complete = False
|
startup_complete = False
|
||||||
cleanup_complete = False
|
cleanup_complete = False
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app):
|
async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
|
||||||
nonlocal startup_complete, cleanup_complete
|
nonlocal startup_complete, cleanup_complete
|
||||||
startup_complete = True
|
startup_complete = True
|
||||||
yield
|
yield
|
||||||
|
@ -394,17 +399,17 @@ deprecated_lifespan = pytest.mark.filterwarnings(
|
||||||
|
|
||||||
|
|
||||||
@deprecated_lifespan
|
@deprecated_lifespan
|
||||||
def test_app_async_gen_lifespan(test_client_factory):
|
def test_app_async_gen_lifespan(test_client_factory: TestClientFactory) -> None:
|
||||||
startup_complete = False
|
startup_complete = False
|
||||||
cleanup_complete = False
|
cleanup_complete = False
|
||||||
|
|
||||||
async def lifespan(app):
|
async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
|
||||||
nonlocal startup_complete, cleanup_complete
|
nonlocal startup_complete, cleanup_complete
|
||||||
startup_complete = True
|
startup_complete = True
|
||||||
yield
|
yield
|
||||||
cleanup_complete = True
|
cleanup_complete = True
|
||||||
|
|
||||||
app = Starlette(lifespan=lifespan)
|
app = Starlette(lifespan=lifespan) # type: ignore
|
||||||
|
|
||||||
assert not startup_complete
|
assert not startup_complete
|
||||||
assert not cleanup_complete
|
assert not cleanup_complete
|
||||||
|
@ -416,17 +421,17 @@ def test_app_async_gen_lifespan(test_client_factory):
|
||||||
|
|
||||||
|
|
||||||
@deprecated_lifespan
|
@deprecated_lifespan
|
||||||
def test_app_sync_gen_lifespan(test_client_factory):
|
def test_app_sync_gen_lifespan(test_client_factory: TestClientFactory) -> None:
|
||||||
startup_complete = False
|
startup_complete = False
|
||||||
cleanup_complete = False
|
cleanup_complete = False
|
||||||
|
|
||||||
def lifespan(app):
|
def lifespan(app: ASGIApp) -> Generator[None, None, None]:
|
||||||
nonlocal startup_complete, cleanup_complete
|
nonlocal startup_complete, cleanup_complete
|
||||||
startup_complete = True
|
startup_complete = True
|
||||||
yield
|
yield
|
||||||
cleanup_complete = True
|
cleanup_complete = True
|
||||||
|
|
||||||
app = Starlette(lifespan=lifespan)
|
app = Starlette(lifespan=lifespan) # type: ignore
|
||||||
|
|
||||||
assert not startup_complete
|
assert not startup_complete
|
||||||
assert not cleanup_complete
|
assert not cleanup_complete
|
||||||
|
@ -456,7 +461,9 @@ def test_decorator_deprecations() -> None:
|
||||||
)
|
)
|
||||||
) as record:
|
) as record:
|
||||||
|
|
||||||
async def middleware(request, call_next):
|
async def middleware(
|
||||||
|
request: Request, call_next: RequestResponseEndpoint
|
||||||
|
) -> None:
|
||||||
... # pragma: no cover
|
... # pragma: no cover
|
||||||
|
|
||||||
app.middleware("http")(middleware)
|
app.middleware("http")(middleware)
|
||||||
|
@ -487,14 +494,14 @@ def test_decorator_deprecations() -> None:
|
||||||
)
|
)
|
||||||
) as record:
|
) as record:
|
||||||
|
|
||||||
async def startup():
|
async def startup() -> None:
|
||||||
... # pragma: no cover
|
... # pragma: no cover
|
||||||
|
|
||||||
app.on_event("startup")(startup)
|
app.on_event("startup")(startup)
|
||||||
assert len(record) == 1
|
assert len(record) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]):
|
def test_middleware_stack_init(test_client_factory: TestClientFactory) -> None:
|
||||||
class NoOpMiddleware:
|
class NoOpMiddleware:
|
||||||
def __init__(self, app: ASGIApp):
|
def __init__(self, app: ASGIApp):
|
||||||
self.app = app
|
self.app = app
|
||||||
|
@ -536,7 +543,7 @@ def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Cl
|
||||||
assert SimpleInitializableMiddleware.counter == 2
|
assert SimpleInitializableMiddleware.counter == 2
|
||||||
|
|
||||||
|
|
||||||
def test_lifespan_app_subclass():
|
def test_lifespan_app_subclass() -> None:
|
||||||
# This test exists to make sure that subclasses of Starlette
|
# This test exists to make sure that subclasses of Starlette
|
||||||
# (like FastAPI) are compatible with the types hints for Lifespan
|
# (like FastAPI) are compatible with the types hints for Lifespan
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue