mirror of https://github.com/encode/starlette.git
Add type hints to `test_applications.py` (#2471)
* added type annotations to test_applications.py * requested changes * Apply suggestions from code review * Apply suggestions from code review * Update tests/test_applications.py --------- Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan> Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
parent
8c222960ba
commit
11b7ae7365
|
@ -1,9 +1,9 @@
|
|||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator, Callable
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, AsyncIterator, Callable, Generator
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from starlette import status
|
||||
|
@ -11,63 +11,68 @@ from starlette.applications import Starlette
|
|||
from starlette.endpoints import HTTPEndpoint
|
||||
from starlette.exceptions import HTTPException, WebSocketException
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.base import RequestResponseEndpoint
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, PlainTextResponse
|
||||
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
TestClientFactory = Callable[..., TestClient]
|
||||
|
||||
async def error_500(request, exc):
|
||||
|
||||
async def error_500(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
return JSONResponse({"detail": "Server Error"}, status_code=500)
|
||||
|
||||
|
||||
async def method_not_allowed(request, exc):
|
||||
async def method_not_allowed(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
return JSONResponse({"detail": "Custom message"}, status_code=405)
|
||||
|
||||
|
||||
async def http_exception(request, exc):
|
||||
async def http_exception(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
|
||||
|
||||
|
||||
def func_homepage(request):
|
||||
def func_homepage(request: Request) -> PlainTextResponse:
|
||||
return PlainTextResponse("Hello, world!")
|
||||
|
||||
|
||||
async def async_homepage(request):
|
||||
async def async_homepage(request: Request) -> PlainTextResponse:
|
||||
return PlainTextResponse("Hello, world!")
|
||||
|
||||
|
||||
class Homepage(HTTPEndpoint):
|
||||
def get(self, request):
|
||||
def get(self, request: Request) -> PlainTextResponse:
|
||||
return PlainTextResponse("Hello, world!")
|
||||
|
||||
|
||||
def all_users_page(request):
|
||||
def all_users_page(request: Request) -> PlainTextResponse:
|
||||
return PlainTextResponse("Hello, everyone!")
|
||||
|
||||
|
||||
def user_page(request):
|
||||
def user_page(request: Request) -> PlainTextResponse:
|
||||
username = request.path_params["username"]
|
||||
return PlainTextResponse(f"Hello, {username}!")
|
||||
|
||||
|
||||
def custom_subdomain(request):
|
||||
def custom_subdomain(request: Request) -> PlainTextResponse:
|
||||
return PlainTextResponse("Subdomain: " + request.path_params["subdomain"])
|
||||
|
||||
|
||||
def runtime_error(request):
|
||||
def runtime_error(request: Request) -> None:
|
||||
raise RuntimeError()
|
||||
|
||||
|
||||
async def websocket_endpoint(session):
|
||||
async def websocket_endpoint(session: WebSocket) -> None:
|
||||
await session.accept()
|
||||
await session.send_text("Hello, world!")
|
||||
await session.close()
|
||||
|
||||
|
||||
async def websocket_raise_websocket(websocket: WebSocket):
|
||||
async def websocket_raise_websocket(websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)
|
||||
|
||||
|
@ -76,12 +81,12 @@ class CustomWSException(Exception):
|
|||
pass
|
||||
|
||||
|
||||
async def websocket_raise_custom(websocket: WebSocket):
|
||||
async def websocket_raise_custom(websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
raise CustomWSException()
|
||||
|
||||
|
||||
def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException):
|
||||
def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> None:
|
||||
anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)
|
||||
|
||||
|
||||
|
@ -121,22 +126,22 @@ app = Starlette(
|
|||
Mount("/users", app=users),
|
||||
Host("{subdomain}.example.org", app=subdomain),
|
||||
],
|
||||
exception_handlers=exception_handlers,
|
||||
exception_handlers=exception_handlers, # type: ignore
|
||||
middleware=middleware,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(test_client_factory):
|
||||
def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]:
|
||||
with test_client_factory(app) as client:
|
||||
yield client
|
||||
|
||||
|
||||
def test_url_path_for():
|
||||
def test_url_path_for() -> None:
|
||||
assert app.url_path_for("func_homepage") == "/func"
|
||||
|
||||
|
||||
def test_func_route(client):
|
||||
def test_func_route(client: TestClient) -> None:
|
||||
response = client.get("/func")
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
|
@ -146,31 +151,31 @@ def test_func_route(client):
|
|||
assert response.text == ""
|
||||
|
||||
|
||||
def test_async_route(client):
|
||||
def test_async_route(client: TestClient) -> None:
|
||||
response = client.get("/async")
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
|
||||
|
||||
def test_class_route(client):
|
||||
def test_class_route(client: TestClient) -> None:
|
||||
response = client.get("/class")
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world!"
|
||||
|
||||
|
||||
def test_mounted_route(client):
|
||||
def test_mounted_route(client: TestClient) -> None:
|
||||
response = client.get("/users/")
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, everyone!"
|
||||
|
||||
|
||||
def test_mounted_route_path_params(client):
|
||||
def test_mounted_route_path_params(client: TestClient) -> None:
|
||||
response = client.get("/users/tomchristie")
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, tomchristie!"
|
||||
|
||||
|
||||
def test_subdomain_route(test_client_factory):
|
||||
def test_subdomain_route(test_client_factory: TestClientFactory) -> None:
|
||||
client = test_client_factory(app, base_url="https://foo.example.org/")
|
||||
|
||||
response = client.get("/")
|
||||
|
@ -178,19 +183,19 @@ def test_subdomain_route(test_client_factory):
|
|||
assert response.text == "Subdomain: foo"
|
||||
|
||||
|
||||
def test_websocket_route(client):
|
||||
def test_websocket_route(client: TestClient) -> None:
|
||||
with client.websocket_connect("/ws") as session:
|
||||
text = session.receive_text()
|
||||
assert text == "Hello, world!"
|
||||
|
||||
|
||||
def test_400(client):
|
||||
def test_400(client: TestClient) -> None:
|
||||
response = client.get("/404")
|
||||
assert response.status_code == 404
|
||||
assert response.json() == {"detail": "Not Found"}
|
||||
|
||||
|
||||
def test_405(client):
|
||||
def test_405(client: TestClient) -> None:
|
||||
response = client.post("/func")
|
||||
assert response.status_code == 405
|
||||
assert response.json() == {"detail": "Custom message"}
|
||||
|
@ -200,14 +205,14 @@ def test_405(client):
|
|||
assert response.json() == {"detail": "Custom message"}
|
||||
|
||||
|
||||
def test_500(test_client_factory):
|
||||
def test_500(test_client_factory: TestClientFactory) -> None:
|
||||
client = test_client_factory(app, raise_server_exceptions=False)
|
||||
response = client.get("/500")
|
||||
assert response.status_code == 500
|
||||
assert response.json() == {"detail": "Server Error"}
|
||||
|
||||
|
||||
def test_websocket_raise_websocket_exception(client):
|
||||
def test_websocket_raise_websocket_exception(client: TestClient) -> None:
|
||||
with client.websocket_connect("/ws-raise-websocket") as session:
|
||||
response = session.receive()
|
||||
assert response == {
|
||||
|
@ -217,7 +222,7 @@ def test_websocket_raise_websocket_exception(client):
|
|||
}
|
||||
|
||||
|
||||
def test_websocket_raise_custom_exception(client):
|
||||
def test_websocket_raise_custom_exception(client: TestClient) -> None:
|
||||
with client.websocket_connect("/ws-raise-custom") as session:
|
||||
response = session.receive()
|
||||
assert response == {
|
||||
|
@ -227,14 +232,14 @@ def test_websocket_raise_custom_exception(client):
|
|||
}
|
||||
|
||||
|
||||
def test_middleware(test_client_factory):
|
||||
def test_middleware(test_client_factory: TestClientFactory) -> None:
|
||||
client = test_client_factory(app, base_url="http://incorrecthost")
|
||||
response = client.get("/func")
|
||||
assert response.status_code == 400
|
||||
assert response.text == "Invalid host header"
|
||||
|
||||
|
||||
def test_routes():
|
||||
def test_routes() -> None:
|
||||
assert app.routes == [
|
||||
Route("/func", endpoint=func_homepage, methods=["GET"]),
|
||||
Route("/async", endpoint=async_homepage, methods=["GET"]),
|
||||
|
@ -259,7 +264,7 @@ def test_routes():
|
|||
]
|
||||
|
||||
|
||||
def test_app_mount(tmpdir, test_client_factory):
|
||||
def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
|
||||
path = os.path.join(tmpdir, "example.txt")
|
||||
with open(path, "w") as file:
|
||||
file.write("<file content>")
|
||||
|
@ -281,8 +286,8 @@ def test_app_mount(tmpdir, test_client_factory):
|
|||
assert response.text == "Method Not Allowed"
|
||||
|
||||
|
||||
def test_app_debug(test_client_factory):
|
||||
async def homepage(request):
|
||||
def test_app_debug(test_client_factory: TestClientFactory) -> None:
|
||||
async def homepage(request: Request) -> None:
|
||||
raise RuntimeError()
|
||||
|
||||
app = Starlette(
|
||||
|
@ -299,8 +304,8 @@ def test_app_debug(test_client_factory):
|
|||
assert app.debug
|
||||
|
||||
|
||||
def test_app_add_route(test_client_factory):
|
||||
async def homepage(request):
|
||||
def test_app_add_route(test_client_factory: TestClientFactory) -> None:
|
||||
async def homepage(request: Request) -> PlainTextResponse:
|
||||
return PlainTextResponse("Hello, World!")
|
||||
|
||||
app = Starlette(
|
||||
|
@ -315,8 +320,8 @@ def test_app_add_route(test_client_factory):
|
|||
assert response.text == "Hello, World!"
|
||||
|
||||
|
||||
def test_app_add_websocket_route(test_client_factory):
|
||||
async def websocket_endpoint(session):
|
||||
def test_app_add_websocket_route(test_client_factory: TestClientFactory) -> None:
|
||||
async def websocket_endpoint(session: WebSocket) -> None:
|
||||
await session.accept()
|
||||
await session.send_text("Hello, world!")
|
||||
await session.close()
|
||||
|
@ -333,15 +338,15 @@ def test_app_add_websocket_route(test_client_factory):
|
|||
assert text == "Hello, world!"
|
||||
|
||||
|
||||
def test_app_add_event_handler(test_client_factory):
|
||||
def test_app_add_event_handler(test_client_factory: TestClientFactory) -> None:
|
||||
startup_complete = False
|
||||
cleanup_complete = False
|
||||
|
||||
def run_startup():
|
||||
def run_startup() -> None:
|
||||
nonlocal startup_complete
|
||||
startup_complete = True
|
||||
|
||||
def run_cleanup():
|
||||
def run_cleanup() -> None:
|
||||
nonlocal cleanup_complete
|
||||
cleanup_complete = True
|
||||
|
||||
|
@ -362,12 +367,12 @@ def test_app_add_event_handler(test_client_factory):
|
|||
assert cleanup_complete
|
||||
|
||||
|
||||
def test_app_async_cm_lifespan(test_client_factory):
|
||||
def test_app_async_cm_lifespan(test_client_factory: TestClientFactory) -> None:
|
||||
startup_complete = False
|
||||
cleanup_complete = False
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app):
|
||||
async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
|
||||
nonlocal startup_complete, cleanup_complete
|
||||
startup_complete = True
|
||||
yield
|
||||
|
@ -394,17 +399,17 @@ deprecated_lifespan = pytest.mark.filterwarnings(
|
|||
|
||||
|
||||
@deprecated_lifespan
|
||||
def test_app_async_gen_lifespan(test_client_factory):
|
||||
def test_app_async_gen_lifespan(test_client_factory: TestClientFactory) -> None:
|
||||
startup_complete = False
|
||||
cleanup_complete = False
|
||||
|
||||
async def lifespan(app):
|
||||
async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
|
||||
nonlocal startup_complete, cleanup_complete
|
||||
startup_complete = True
|
||||
yield
|
||||
cleanup_complete = True
|
||||
|
||||
app = Starlette(lifespan=lifespan)
|
||||
app = Starlette(lifespan=lifespan) # type: ignore
|
||||
|
||||
assert not startup_complete
|
||||
assert not cleanup_complete
|
||||
|
@ -416,17 +421,17 @@ def test_app_async_gen_lifespan(test_client_factory):
|
|||
|
||||
|
||||
@deprecated_lifespan
|
||||
def test_app_sync_gen_lifespan(test_client_factory):
|
||||
def test_app_sync_gen_lifespan(test_client_factory: TestClientFactory) -> None:
|
||||
startup_complete = False
|
||||
cleanup_complete = False
|
||||
|
||||
def lifespan(app):
|
||||
def lifespan(app: ASGIApp) -> Generator[None, None, None]:
|
||||
nonlocal startup_complete, cleanup_complete
|
||||
startup_complete = True
|
||||
yield
|
||||
cleanup_complete = True
|
||||
|
||||
app = Starlette(lifespan=lifespan)
|
||||
app = Starlette(lifespan=lifespan) # type: ignore
|
||||
|
||||
assert not startup_complete
|
||||
assert not cleanup_complete
|
||||
|
@ -456,7 +461,9 @@ def test_decorator_deprecations() -> None:
|
|||
)
|
||||
) as record:
|
||||
|
||||
async def middleware(request, call_next):
|
||||
async def middleware(
|
||||
request: Request, call_next: RequestResponseEndpoint
|
||||
) -> None:
|
||||
... # pragma: no cover
|
||||
|
||||
app.middleware("http")(middleware)
|
||||
|
@ -487,14 +494,14 @@ def test_decorator_deprecations() -> None:
|
|||
)
|
||||
) as record:
|
||||
|
||||
async def startup():
|
||||
async def startup() -> None:
|
||||
... # pragma: no cover
|
||||
|
||||
app.on_event("startup")(startup)
|
||||
assert len(record) == 1
|
||||
|
||||
|
||||
def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]):
|
||||
def test_middleware_stack_init(test_client_factory: TestClientFactory) -> None:
|
||||
class NoOpMiddleware:
|
||||
def __init__(self, app: ASGIApp):
|
||||
self.app = app
|
||||
|
@ -536,7 +543,7 @@ def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Cl
|
|||
assert SimpleInitializableMiddleware.counter == 2
|
||||
|
||||
|
||||
def test_lifespan_app_subclass():
|
||||
def test_lifespan_app_subclass() -> None:
|
||||
# This test exists to make sure that subclasses of Starlette
|
||||
# (like FastAPI) are compatible with the types hints for Lifespan
|
||||
|
||||
|
|
Loading…
Reference in New Issue