Add type hints to `test_applications.py` (#2471)

* added type annotations to test_applications.py

* requested changes

* Apply suggestions from code review

* Apply suggestions from code review

* Update tests/test_applications.py

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
Scirlat Danut 2024-02-06 22:30:47 +02:00 committed by GitHub
parent 8c222960ba
commit 11b7ae7365
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 62 additions and 55 deletions

View File

@ -1,9 +1,9 @@
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncIterator, Callable from pathlib import Path
from typing import AsyncGenerator, AsyncIterator, Callable, Generator
import anyio import anyio
import httpx
import pytest import pytest
from starlette import status from starlette import status
@ -11,63 +11,68 @@ from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException, WebSocketException from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware import Middleware from starlette.middleware import Middleware
from starlette.middleware.base import RequestResponseEndpoint
from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles from starlette.staticfiles import StaticFiles
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
TestClientFactory = Callable[..., TestClient]
async def error_500(request, exc):
async def error_500(request: Request, exc: HTTPException) -> JSONResponse:
return JSONResponse({"detail": "Server Error"}, status_code=500) return JSONResponse({"detail": "Server Error"}, status_code=500)
async def method_not_allowed(request, exc): async def method_not_allowed(request: Request, exc: HTTPException) -> JSONResponse:
return JSONResponse({"detail": "Custom message"}, status_code=405) return JSONResponse({"detail": "Custom message"}, status_code=405)
async def http_exception(request, exc): async def http_exception(request: Request, exc: HTTPException) -> JSONResponse:
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) return JSONResponse({"detail": exc.detail}, status_code=exc.status_code)
def func_homepage(request): def func_homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Hello, world!") return PlainTextResponse("Hello, world!")
async def async_homepage(request): async def async_homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Hello, world!") return PlainTextResponse("Hello, world!")
class Homepage(HTTPEndpoint): class Homepage(HTTPEndpoint):
def get(self, request): def get(self, request: Request) -> PlainTextResponse:
return PlainTextResponse("Hello, world!") return PlainTextResponse("Hello, world!")
def all_users_page(request): def all_users_page(request: Request) -> PlainTextResponse:
return PlainTextResponse("Hello, everyone!") return PlainTextResponse("Hello, everyone!")
def user_page(request): def user_page(request: Request) -> PlainTextResponse:
username = request.path_params["username"] username = request.path_params["username"]
return PlainTextResponse(f"Hello, {username}!") return PlainTextResponse(f"Hello, {username}!")
def custom_subdomain(request): def custom_subdomain(request: Request) -> PlainTextResponse:
return PlainTextResponse("Subdomain: " + request.path_params["subdomain"]) return PlainTextResponse("Subdomain: " + request.path_params["subdomain"])
def runtime_error(request): def runtime_error(request: Request) -> None:
raise RuntimeError() raise RuntimeError()
async def websocket_endpoint(session): async def websocket_endpoint(session: WebSocket) -> None:
await session.accept() await session.accept()
await session.send_text("Hello, world!") await session.send_text("Hello, world!")
await session.close() await session.close()
async def websocket_raise_websocket(websocket: WebSocket): async def websocket_raise_websocket(websocket: WebSocket) -> None:
await websocket.accept() await websocket.accept()
raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA) raise WebSocketException(code=status.WS_1003_UNSUPPORTED_DATA)
@ -76,12 +81,12 @@ class CustomWSException(Exception):
pass pass
async def websocket_raise_custom(websocket: WebSocket): async def websocket_raise_custom(websocket: WebSocket) -> None:
await websocket.accept() await websocket.accept()
raise CustomWSException() raise CustomWSException()
def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException): def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> None:
anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER) anyio.from_thread.run(websocket.close, status.WS_1013_TRY_AGAIN_LATER)
@ -121,22 +126,22 @@ app = Starlette(
Mount("/users", app=users), Mount("/users", app=users),
Host("{subdomain}.example.org", app=subdomain), Host("{subdomain}.example.org", app=subdomain),
], ],
exception_handlers=exception_handlers, exception_handlers=exception_handlers, # type: ignore
middleware=middleware, middleware=middleware,
) )
@pytest.fixture @pytest.fixture
def client(test_client_factory): def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]:
with test_client_factory(app) as client: with test_client_factory(app) as client:
yield client yield client
def test_url_path_for(): def test_url_path_for() -> None:
assert app.url_path_for("func_homepage") == "/func" assert app.url_path_for("func_homepage") == "/func"
def test_func_route(client): def test_func_route(client: TestClient) -> None:
response = client.get("/func") response = client.get("/func")
assert response.status_code == 200 assert response.status_code == 200
assert response.text == "Hello, world!" assert response.text == "Hello, world!"
@ -146,31 +151,31 @@ def test_func_route(client):
assert response.text == "" assert response.text == ""
def test_async_route(client): def test_async_route(client: TestClient) -> None:
response = client.get("/async") response = client.get("/async")
assert response.status_code == 200 assert response.status_code == 200
assert response.text == "Hello, world!" assert response.text == "Hello, world!"
def test_class_route(client): def test_class_route(client: TestClient) -> None:
response = client.get("/class") response = client.get("/class")
assert response.status_code == 200 assert response.status_code == 200
assert response.text == "Hello, world!" assert response.text == "Hello, world!"
def test_mounted_route(client): def test_mounted_route(client: TestClient) -> None:
response = client.get("/users/") response = client.get("/users/")
assert response.status_code == 200 assert response.status_code == 200
assert response.text == "Hello, everyone!" assert response.text == "Hello, everyone!"
def test_mounted_route_path_params(client): def test_mounted_route_path_params(client: TestClient) -> None:
response = client.get("/users/tomchristie") response = client.get("/users/tomchristie")
assert response.status_code == 200 assert response.status_code == 200
assert response.text == "Hello, tomchristie!" assert response.text == "Hello, tomchristie!"
def test_subdomain_route(test_client_factory): def test_subdomain_route(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app, base_url="https://foo.example.org/") client = test_client_factory(app, base_url="https://foo.example.org/")
response = client.get("/") response = client.get("/")
@ -178,19 +183,19 @@ def test_subdomain_route(test_client_factory):
assert response.text == "Subdomain: foo" assert response.text == "Subdomain: foo"
def test_websocket_route(client): def test_websocket_route(client: TestClient) -> None:
with client.websocket_connect("/ws") as session: with client.websocket_connect("/ws") as session:
text = session.receive_text() text = session.receive_text()
assert text == "Hello, world!" assert text == "Hello, world!"
def test_400(client): def test_400(client: TestClient) -> None:
response = client.get("/404") response = client.get("/404")
assert response.status_code == 404 assert response.status_code == 404
assert response.json() == {"detail": "Not Found"} assert response.json() == {"detail": "Not Found"}
def test_405(client): def test_405(client: TestClient) -> None:
response = client.post("/func") response = client.post("/func")
assert response.status_code == 405 assert response.status_code == 405
assert response.json() == {"detail": "Custom message"} assert response.json() == {"detail": "Custom message"}
@ -200,14 +205,14 @@ def test_405(client):
assert response.json() == {"detail": "Custom message"} assert response.json() == {"detail": "Custom message"}
def test_500(test_client_factory): def test_500(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app, raise_server_exceptions=False) client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/500") response = client.get("/500")
assert response.status_code == 500 assert response.status_code == 500
assert response.json() == {"detail": "Server Error"} assert response.json() == {"detail": "Server Error"}
def test_websocket_raise_websocket_exception(client): def test_websocket_raise_websocket_exception(client: TestClient) -> None:
with client.websocket_connect("/ws-raise-websocket") as session: with client.websocket_connect("/ws-raise-websocket") as session:
response = session.receive() response = session.receive()
assert response == { assert response == {
@ -217,7 +222,7 @@ def test_websocket_raise_websocket_exception(client):
} }
def test_websocket_raise_custom_exception(client): def test_websocket_raise_custom_exception(client: TestClient) -> None:
with client.websocket_connect("/ws-raise-custom") as session: with client.websocket_connect("/ws-raise-custom") as session:
response = session.receive() response = session.receive()
assert response == { assert response == {
@ -227,14 +232,14 @@ def test_websocket_raise_custom_exception(client):
} }
def test_middleware(test_client_factory): def test_middleware(test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app, base_url="http://incorrecthost") client = test_client_factory(app, base_url="http://incorrecthost")
response = client.get("/func") response = client.get("/func")
assert response.status_code == 400 assert response.status_code == 400
assert response.text == "Invalid host header" assert response.text == "Invalid host header"
def test_routes(): def test_routes() -> None:
assert app.routes == [ assert app.routes == [
Route("/func", endpoint=func_homepage, methods=["GET"]), Route("/func", endpoint=func_homepage, methods=["GET"]),
Route("/async", endpoint=async_homepage, methods=["GET"]), Route("/async", endpoint=async_homepage, methods=["GET"]),
@ -259,7 +264,7 @@ def test_routes():
] ]
def test_app_mount(tmpdir, test_client_factory): def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
path = os.path.join(tmpdir, "example.txt") path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file: with open(path, "w") as file:
file.write("<file content>") file.write("<file content>")
@ -281,8 +286,8 @@ def test_app_mount(tmpdir, test_client_factory):
assert response.text == "Method Not Allowed" assert response.text == "Method Not Allowed"
def test_app_debug(test_client_factory): def test_app_debug(test_client_factory: TestClientFactory) -> None:
async def homepage(request): async def homepage(request: Request) -> None:
raise RuntimeError() raise RuntimeError()
app = Starlette( app = Starlette(
@ -299,8 +304,8 @@ def test_app_debug(test_client_factory):
assert app.debug assert app.debug
def test_app_add_route(test_client_factory): def test_app_add_route(test_client_factory: TestClientFactory) -> None:
async def homepage(request): async def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Hello, World!") return PlainTextResponse("Hello, World!")
app = Starlette( app = Starlette(
@ -315,8 +320,8 @@ def test_app_add_route(test_client_factory):
assert response.text == "Hello, World!" assert response.text == "Hello, World!"
def test_app_add_websocket_route(test_client_factory): def test_app_add_websocket_route(test_client_factory: TestClientFactory) -> None:
async def websocket_endpoint(session): async def websocket_endpoint(session: WebSocket) -> None:
await session.accept() await session.accept()
await session.send_text("Hello, world!") await session.send_text("Hello, world!")
await session.close() await session.close()
@ -333,15 +338,15 @@ def test_app_add_websocket_route(test_client_factory):
assert text == "Hello, world!" assert text == "Hello, world!"
def test_app_add_event_handler(test_client_factory): def test_app_add_event_handler(test_client_factory: TestClientFactory) -> None:
startup_complete = False startup_complete = False
cleanup_complete = False cleanup_complete = False
def run_startup(): def run_startup() -> None:
nonlocal startup_complete nonlocal startup_complete
startup_complete = True startup_complete = True
def run_cleanup(): def run_cleanup() -> None:
nonlocal cleanup_complete nonlocal cleanup_complete
cleanup_complete = True cleanup_complete = True
@ -362,12 +367,12 @@ def test_app_add_event_handler(test_client_factory):
assert cleanup_complete assert cleanup_complete
def test_app_async_cm_lifespan(test_client_factory): def test_app_async_cm_lifespan(test_client_factory: TestClientFactory) -> None:
startup_complete = False startup_complete = False
cleanup_complete = False cleanup_complete = False
@asynccontextmanager @asynccontextmanager
async def lifespan(app): async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
nonlocal startup_complete, cleanup_complete nonlocal startup_complete, cleanup_complete
startup_complete = True startup_complete = True
yield yield
@ -394,17 +399,17 @@ deprecated_lifespan = pytest.mark.filterwarnings(
@deprecated_lifespan @deprecated_lifespan
def test_app_async_gen_lifespan(test_client_factory): def test_app_async_gen_lifespan(test_client_factory: TestClientFactory) -> None:
startup_complete = False startup_complete = False
cleanup_complete = False cleanup_complete = False
async def lifespan(app): async def lifespan(app: ASGIApp) -> AsyncGenerator[None, None]:
nonlocal startup_complete, cleanup_complete nonlocal startup_complete, cleanup_complete
startup_complete = True startup_complete = True
yield yield
cleanup_complete = True cleanup_complete = True
app = Starlette(lifespan=lifespan) app = Starlette(lifespan=lifespan) # type: ignore
assert not startup_complete assert not startup_complete
assert not cleanup_complete assert not cleanup_complete
@ -416,17 +421,17 @@ def test_app_async_gen_lifespan(test_client_factory):
@deprecated_lifespan @deprecated_lifespan
def test_app_sync_gen_lifespan(test_client_factory): def test_app_sync_gen_lifespan(test_client_factory: TestClientFactory) -> None:
startup_complete = False startup_complete = False
cleanup_complete = False cleanup_complete = False
def lifespan(app): def lifespan(app: ASGIApp) -> Generator[None, None, None]:
nonlocal startup_complete, cleanup_complete nonlocal startup_complete, cleanup_complete
startup_complete = True startup_complete = True
yield yield
cleanup_complete = True cleanup_complete = True
app = Starlette(lifespan=lifespan) app = Starlette(lifespan=lifespan) # type: ignore
assert not startup_complete assert not startup_complete
assert not cleanup_complete assert not cleanup_complete
@ -456,7 +461,9 @@ def test_decorator_deprecations() -> None:
) )
) as record: ) as record:
async def middleware(request, call_next): async def middleware(
request: Request, call_next: RequestResponseEndpoint
) -> None:
... # pragma: no cover ... # pragma: no cover
app.middleware("http")(middleware) app.middleware("http")(middleware)
@ -487,14 +494,14 @@ def test_decorator_deprecations() -> None:
) )
) as record: ) as record:
async def startup(): async def startup() -> None:
... # pragma: no cover ... # pragma: no cover
app.on_event("startup")(startup) app.on_event("startup")(startup)
assert len(record) == 1 assert len(record) == 1
def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]): def test_middleware_stack_init(test_client_factory: TestClientFactory) -> None:
class NoOpMiddleware: class NoOpMiddleware:
def __init__(self, app: ASGIApp): def __init__(self, app: ASGIApp):
self.app = app self.app = app
@ -536,7 +543,7 @@ def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Cl
assert SimpleInitializableMiddleware.counter == 2 assert SimpleInitializableMiddleware.counter == 2
def test_lifespan_app_subclass(): def test_lifespan_app_subclass() -> None:
# This test exists to make sure that subclasses of Starlette # This test exists to make sure that subclasses of Starlette
# (like FastAPI) are compatible with the types hints for Lifespan # (like FastAPI) are compatible with the types hints for Lifespan