From 11b7ae7365b6972f8e90c4bb350e49c90cfbdca2 Mon Sep 17 00:00:00 2001 From: Scirlat Danut Date: Tue, 6 Feb 2024 22:30:47 +0200 Subject: [PATCH] Add type hints to `test_applications.py` (#2471) * added type annotations to test_applications.py * requested changes * Apply suggestions from code review * Apply suggestions from code review * Update tests/test_applications.py --------- Co-authored-by: Scirlat Danut Co-authored-by: Marcelo Trylesinski --- tests/test_applications.py | 117 ++++++++++++++++++++----------------- 1 file changed, 62 insertions(+), 55 deletions(-) diff --git a/tests/test_applications.py b/tests/test_applications.py index 6d0118b5..5b6c9d54 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,9 +1,9 @@ import os from contextlib import asynccontextmanager -from typing import AsyncIterator, Callable +from pathlib import Path +from typing import AsyncGenerator, AsyncIterator, Callable, Generator import anyio -import httpx import pytest from starlette import status @@ -11,63 +11,68 @@ from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware import Middleware +from starlette.middleware.base import RequestResponseEndpoint from starlette.middleware.trustedhost import TrustedHostMiddleware +from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles +from starlette.testclient import TestClient from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket +TestClientFactory = Callable[..., TestClient] -async def error_500(request, exc): + +async def error_500(request: Request, exc: HTTPException) -> JSONResponse: return JSONResponse({"detail": "Server Error"}, status_code=500) -async def method_not_allowed(request, exc): +async def method_not_allowed(request: Request, exc: HTTPException) -> JSONResponse: return JSONResponse({"detail": "Custom message"}, status_code=405) -async def http_exception(request, exc): +async def http_exception(request: Request, exc: HTTPException) -> JSONResponse: return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) -def func_homepage(request): +def func_homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, world!") -async def async_homepage(request): +async def async_homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, world!") class Homepage(HTTPEndpoint): - def get(self, request): + def get(self, request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, world!") -def all_users_page(request): +def all_users_page(request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, everyone!") -def user_page(request): +def user_page(request: Request) -> PlainTextResponse: username = request.path_params["username"] return PlainTextResponse(f"Hello, {username}!") -def custom_subdomain(request): +def custom_subdomain(request: Request) -> PlainTextResponse: return PlainTextResponse("Subdomain: " + request.path_params["subdomain"]) -def runtime_error(request): +def runtime_error(request: Request) -> None: raise RuntimeError() -async def websocket_endpoint(session): +async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") await session.close() -async def websocket_raise_websocket(websocket: WebSocket): +async def websocket_raise_websocket(websocket: WebSocket) -> None: await websocket.accept() raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA) @@ -76,12 +81,12 @@ class CustomWSException(Exception): pass -async def websocket_raise_custom(websocket: WebSocket): +async def websocket_raise_custom(websocket: WebSocket) -> None: await websocket.accept() raise CustomWSException() -def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException): +def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> None: anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER) @@ -121,22 +126,22 @@ app = Starlette( Mount("/users", app=users), Host("{subdomain}.example.org", app=subdomain), ], - exception_handlers=exception_handlers, + exception_handlers=exception_handlers, # type: ignore middleware=middleware, ) @pytest.fixture -def client(test_client_factory): +def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]: with test_client_factory(app) as client: yield client -def test_url_path_for(): +def test_url_path_for() -> None: assert app.url_path_for("func_homepage") == "/func" -def test_func_route(client): +def test_func_route(client: TestClient) -> None: response = client.get("/func") assert response.status_code == 200 assert response.text == "Hello, world!" @@ -146,31 +151,31 @@ def test_func_route(client): assert response.text == "" -def test_async_route(client): +def test_async_route(client: TestClient) -> None: response = client.get("/async") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_class_route(client): +def test_class_route(client: TestClient) -> None: response = client.get("/class") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_mounted_route(client): +def test_mounted_route(client: TestClient) -> None: response = client.get("/users/") assert response.status_code == 200 assert response.text == "Hello, everyone!" -def test_mounted_route_path_params(client): +def test_mounted_route_path_params(client: TestClient) -> None: response = client.get("/users/tomchristie") assert response.status_code == 200 assert response.text == "Hello, tomchristie!" -def test_subdomain_route(test_client_factory): +def test_subdomain_route(test_client_factory: TestClientFactory) -> None: client = test_client_factory(app, base_url="https://foo.example.org/") response = client.get("/") @@ -178,19 +183,19 @@ def test_subdomain_route(test_client_factory): assert response.text == "Subdomain: foo" -def test_websocket_route(client): +def test_websocket_route(client: TestClient) -> None: with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" -def test_400(client): +def test_400(client: TestClient) -> None: response = client.get("/404") assert response.status_code == 404 assert response.json() == {"detail": "Not Found"} -def test_405(client): +def test_405(client: TestClient) -> None: response = client.post("/func") assert response.status_code == 405 assert response.json() == {"detail": "Custom message"} @@ -200,14 +205,14 @@ def test_405(client): assert response.json() == {"detail": "Custom message"} -def test_500(test_client_factory): +def test_500(test_client_factory: TestClientFactory) -> None: client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/500") assert response.status_code == 500 assert response.json() == {"detail": "Server Error"} -def test_websocket_raise_websocket_exception(client): +def test_websocket_raise_websocket_exception(client: TestClient) -> None: with client.websocket_connect("/ws-raise-websocket") as session: response = session.receive() assert response == { @@ -217,7 +222,7 @@ def test_websocket_raise_websocket_exception(client): } -def test_websocket_raise_custom_exception(client): +def test_websocket_raise_custom_exception(client: TestClient) -> None: with client.websocket_connect("/ws-raise-custom") as session: response = session.receive() assert response == { @@ -227,14 +232,14 @@ def test_websocket_raise_custom_exception(client): } -def test_middleware(test_client_factory): +def test_middleware(test_client_factory: TestClientFactory) -> None: client = test_client_factory(app, base_url="http://incorrecthost") response = client.get("/func") assert response.status_code == 400 assert response.text == "Invalid host header" -def test_routes(): +def test_routes() -> None: assert app.routes == [ Route("/func", endpoint=func_homepage, methods=["GET"]), Route("/async", endpoint=async_homepage, methods=["GET"]), @@ -259,7 +264,7 @@ def test_routes(): ] -def test_app_mount(tmpdir, test_client_factory): +def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") @@ -281,8 +286,8 @@ def test_app_mount(tmpdir, test_client_factory): assert response.text == "Method Not Allowed" -def test_app_debug(test_client_factory): - async def homepage(request): +def test_app_debug(test_client_factory: TestClientFactory) -> None: + async def homepage(request: Request) -> None: raise RuntimeError() app = Starlette( @@ -299,8 +304,8 @@ def test_app_debug(test_client_factory): assert app.debug -def test_app_add_route(test_client_factory): - async def homepage(request): +def test_app_add_route(test_client_factory: TestClientFactory) -> None: + async def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, World!") app = Starlette( @@ -315,8 +320,8 @@ def test_app_add_route(test_client_factory): assert response.text == "Hello, World!" -def test_app_add_websocket_route(test_client_factory): - async def websocket_endpoint(session): +def test_app_add_websocket_route(test_client_factory: TestClientFactory) -> None: + async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") await session.close() @@ -333,15 +338,15 @@ def test_app_add_websocket_route(test_client_factory): assert text == "Hello, world!" -def test_app_add_event_handler(test_client_factory): +def test_app_add_event_handler(test_client_factory: TestClientFactory) -> None: startup_complete = False cleanup_complete = False - def run_startup(): + def run_startup() -> None: nonlocal startup_complete startup_complete = True - def run_cleanup(): + def run_cleanup() -> None: nonlocal cleanup_complete cleanup_complete = True @@ -362,12 +367,12 @@ def test_app_add_event_handler(test_client_factory): assert cleanup_complete -def test_app_async_cm_lifespan(test_client_factory): +def test_app_async_cm_lifespan(test_client_factory: TestClientFactory) -> None: startup_complete = False cleanup_complete = False @asynccontextmanager - async def lifespan(app): + async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]: nonlocal startup_complete, cleanup_complete startup_complete = True yield @@ -394,17 +399,17 @@ deprecated_lifespan = pytest.mark.filterwarnings( @deprecated_lifespan -def test_app_async_gen_lifespan(test_client_factory): +def test_app_async_gen_lifespan(test_client_factory: TestClientFactory) -> None: startup_complete = False cleanup_complete = False - async def lifespan(app): + async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]: nonlocal startup_complete, cleanup_complete startup_complete = True yield cleanup_complete = True - app = Starlette(lifespan=lifespan) + app = Starlette(lifespan=lifespan) # type: ignore assert not startup_complete assert not cleanup_complete @@ -416,17 +421,17 @@ def test_app_async_gen_lifespan(test_client_factory): @deprecated_lifespan -def test_app_sync_gen_lifespan(test_client_factory): +def test_app_sync_gen_lifespan(test_client_factory: TestClientFactory) -> None: startup_complete = False cleanup_complete = False - def lifespan(app): + def lifespan(app: ASGIApp) -> Generator[None, None, None]: nonlocal startup_complete, cleanup_complete startup_complete = True yield cleanup_complete = True - app = Starlette(lifespan=lifespan) + app = Starlette(lifespan=lifespan) # type: ignore assert not startup_complete assert not cleanup_complete @@ -456,7 +461,9 @@ def test_decorator_deprecations() -> None: ) ) as record: - async def middleware(request, call_next): + async def middleware( + request: Request, call_next: RequestResponseEndpoint + ) -> None: ... # pragma: no cover app.middleware("http")(middleware) @@ -487,14 +494,14 @@ def test_decorator_deprecations() -> None: ) ) as record: - async def startup(): + async def startup() -> None: ... # pragma: no cover app.on_event("startup")(startup) assert len(record) == 1 -def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]): +def test_middleware_stack_init(test_client_factory: TestClientFactory) -> None: class NoOpMiddleware: def __init__(self, app: ASGIApp): self.app = app @@ -536,7 +543,7 @@ def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Cl assert SimpleInitializableMiddleware.counter == 2 -def test_lifespan_app_subclass(): +def test_lifespan_app_subclass() -> None: # This test exists to make sure that subclasses of Starlette # (like FastAPI) are compatible with the types hints for Lifespan