diff --git a/scripts/test b/scripts/test index e2e2a006..09f9dea5 100755 --- a/scripts/test +++ b/scripts/test @@ -11,7 +11,7 @@ export PYTHON_VERSION=`python -c "$VERSION_SCRIPT"` set -x PYTHONPATH=. ${PREFIX}pytest --ignore venv --cov=starlette --cov=tests --cov-fail-under=100 --cov-report=term-missing ${@} -${PREFIX}mypy starlette --ignore-missing-imports +${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs if [ "${PYTHON_VERSION}" = '3.7' ]; then echo "Skipping 'black' on 3.7. See issue https://github.com/ambv/black/issues/494" else diff --git a/starlette/applications.py b/starlette/applications.py index 7351bac4..7216d8d0 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,6 +1,7 @@ import asyncio import inspect import typing +from concurrent.futures import ThreadPoolExecutor from starlette.exceptions import ExceptionMiddleware from starlette.lifespan import LifespanHandler @@ -54,6 +55,7 @@ class Starlette: self.lifespan_handler = LifespanHandler() self.app = self.router self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) + self.executor = ThreadPoolExecutor() @property def debug(self) -> bool: @@ -118,7 +120,8 @@ class Starlette: return decorator def __call__(self, scope: Scope) -> ASGIInstance: + scope["app"] = self + scope["executor"] = self.executor if scope["type"] == "lifespan": return self.lifespan_handler(scope) - scope["app"] = self return self.exception_middleware(scope) diff --git a/starlette/testclient.py b/starlette/testclient.py index a1f75021..d7c182ba 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -4,38 +4,54 @@ import io import json import threading import typing +import requests import queue -from starlette.websockets import WebSocketDisconnect from urllib.parse import unquote, urlparse, urljoin -import requests +from starlette.websockets import WebSocketDisconnect +from starlette.types import Message, Scope, ASGIApp + + +# Annotations for `Session.request()` +Cookies = typing.Union[ + typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar +] +Params = typing.Union[bytes, typing.MutableMapping[str, str]] +DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO] +TimeOut = typing.Union[float, typing.Tuple[float, float]] +FileType = typing.MutableMapping[str, typing.IO] +AuthType = typing.Union[ + typing.Tuple[str, str], + requests.auth.AuthBase, + typing.Callable[[requests.Request], requests.Request], +] class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): - def get_all(self, key, default): + def get_all(self, key: str, default: str) -> str: return self.getheaders(key) -class _MockOriginalResponse(object): +class _MockOriginalResponse: """ We have to jump through some hoops to present the response as if it was made using urllib3. """ - def __init__(self, headers): + def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None: self.msg = _HeaderDict(headers) self.closed = False - def isclosed(self): + def isclosed(self) -> bool: return self.closed class _Upgrade(Exception): - def __init__(self, session): + def __init__(self, session: "WebSocketTestSession") -> None: self.session = session -def _get_reason_phrase(status_code): +def _get_reason_phrase(status_code: int) -> str: try: return http.HTTPStatus(status_code).phrase except ValueError: @@ -43,37 +59,42 @@ def _get_reason_phrase(status_code): class _ASGIAdapter(requests.adapters.HTTPAdapter): - def __init__(self, app: typing.Callable, raise_server_exceptions=True) -> None: + def __init__(self, app: ASGIApp, raise_server_exceptions: bool = True) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions - def send(self, request, *args, **kwargs): - scheme, netloc, path, params, query, fragement = urlparse(request.url) + def send( # type: ignore + self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any + ) -> requests.Response: + scheme, netloc, path, params, query, fragement = urlparse( # type: ignore + request.url + ) + if ":" in netloc: - host, port = netloc.split(":", 1) - port = int(port) + host, port_string = netloc.split(":", 1) + port = int(port_string) else: host = netloc port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] # Include the 'host' header. if "host" in request.headers: - headers = [] + headers = [] # type: typing.List[typing.Tuple[bytes, bytes]] elif port == 80: - headers = [[b"host", host.encode()]] + headers = [(b"host", host.encode())] else: - headers = [[b"host", ("%s:%d" % (host, port)).encode()]] + headers = [(b"host", ("%s:%d" % (host, port)).encode())] # Include other request headers. headers += [ - [key.lower().encode(), value.encode()] + (key.lower().encode(), value.encode()) for key, value in request.headers.items() ] if scheme in {"ws", "wss"}: subprotocol = request.headers.get("sec-websocket-protocol", None) if subprotocol is None: - subprotocols = [] + subprotocols = [] # type: typing.Sequence[str] else: subprotocols = [value.strip() for value in subprotocol.split(",")] scope = { @@ -103,7 +124,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter): "server": [host, port], } - async def receive(): + async def receive() -> Message: body = request.body if isinstance(body, str): body_bytes = body.encode("utf-8") # type: bytes @@ -113,7 +134,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter): body_bytes = body return {"type": "http.request", "body": body_bytes} - async def send(message): + async def send(message: Message) -> None: nonlocal raw_kwargs, response_started, response_complete if message["type"] == "http.response.start": @@ -147,7 +168,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter): response_started = False response_complete = False - raw_kwargs = {"body": io.BytesIO()} + raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any] loop = asyncio.get_event_loop() @@ -176,12 +197,12 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter): class WebSocketTestSession: - def __init__(self, app, scope): + def __init__(self, app: ASGIApp, scope: Scope) -> None: self.accepted_subprotocol = None self._loop = asyncio.new_event_loop() self._instance = app(scope) - self._receive_queue = queue.Queue() - self._send_queue = queue.Queue() + self._receive_queue = queue.Queue() # type: queue.Queue + self._send_queue = queue.Queue() # type: queue.Queue self._thread = threading.Thread(target=self._run) self.send({"type": "websocket.connect"}) self._thread.start() @@ -189,10 +210,10 @@ class WebSocketTestSession: self._raise_on_close(message) self.accepted_subprotocol = message.get("subprotocol", None) - def __enter__(self): + def __enter__(self) -> "WebSocketTestSession": return self - def __exit__(self, *args): + def __exit__(self, *args: typing.Any) -> None: self.close(1000) self._thread.join() while not self._send_queue.empty(): @@ -200,7 +221,7 @@ class WebSocketTestSession: if isinstance(message, BaseException): raise message - def _run(self): + def _run(self) -> None: """ The sub-thread in which the websocket session runs. """ @@ -211,49 +232,49 @@ class WebSocketTestSession: except BaseException as exc: self._send_queue.put(exc) - async def _asgi_receive(self): + async def _asgi_receive(self) -> Message: return self._receive_queue.get() - async def _asgi_send(self, message): + async def _asgi_send(self, message: Message) -> None: self._send_queue.put(message) - def _raise_on_close(self, message): + def _raise_on_close(self, message: Message) -> None: if message["type"] == "websocket.close": raise WebSocketDisconnect(message.get("code", 1000)) - def send(self, message): + def send(self, message: Message) -> None: self._receive_queue.put(message) - def send_text(self, data): + def send_text(self, data: str) -> None: self.send({"type": "websocket.receive", "text": data}) - def send_bytes(self, data): + def send_bytes(self, data: bytes) -> None: self.send({"type": "websocket.receive", "bytes": data}) - def send_json(self, data): + def send_json(self, data: typing.Any) -> None: encoded = json.dumps(data).encode("utf-8") self.send({"type": "websocket.receive", "bytes": encoded}) - def close(self, code=1000): + def close(self, code: int = 1000) -> None: self.send({"type": "websocket.disconnect", "code": code}) - def receive(self): + def receive(self) -> Message: message = self._send_queue.get() if isinstance(message, BaseException): raise message return message - def receive_text(self): + def receive_text(self) -> str: message = self.receive() self._raise_on_close(message) return message["text"] - def receive_bytes(self): + def receive_bytes(self) -> bytes: message = self.receive() self._raise_on_close(message) return message["bytes"] - def receive_json(self): + def receive_json(self) -> typing.Any: message = self.receive() self._raise_on_close(message) encoded = message["bytes"] @@ -262,7 +283,7 @@ class WebSocketTestSession: class _TestClient(requests.Session): def __init__( - self, app: typing.Callable, base_url: str, raise_server_exceptions=True + self, app: ASGIApp, base_url: str, raise_server_exceptions: bool = True ) -> None: super(_TestClient, self).__init__() adapter = _ASGIAdapter(app, raise_server_exceptions=raise_server_exceptions) @@ -277,20 +298,20 @@ class _TestClient(requests.Session): self, method: str, url: str, - params=None, - data=None, - headers=None, - cookies=None, - files=None, - auth=None, - timeout=None, - allow_redirects=True, - proxies=None, - hooks=None, - stream=None, - verify=None, - cert=None, - json=None, + params: Params = None, + data: DataType = None, + headers: typing.MutableMapping[str, str] = None, + cookies: Cookies = None, + files: FileType = None, + auth: AuthType = None, + timeout: TimeOut = None, + allow_redirects: bool = None, + proxies: typing.MutableMapping[str, str] = None, + hooks: typing.Any = None, + stream: bool = None, + verify: typing.Union[bool, str] = None, + cert: typing.Union[str, typing.Tuple[str, str]] = None, + json: typing.Any = None, ) -> requests.Response: url = urljoin(self.base_url, url) return super().request( @@ -313,8 +334,8 @@ class _TestClient(requests.Session): ) def websocket_connect( - self, url: str, subprotocols=None, **kwargs - ) -> WebSocketTestSession: + self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any + ) -> typing.Any: url = urljoin("ws://testserver", url) headers = kwargs.get("headers", {}) headers.setdefault("connection", "upgrade") @@ -334,9 +355,9 @@ class _TestClient(requests.Session): def TestClient( - app: typing.Callable, + app: ASGIApp, base_url: str = "http://testserver", - raise_server_exceptions=True, + raise_server_exceptions: bool = True, ) -> _TestClient: """ We have to work around py.test discovery attempting to pick up diff --git a/tests/test_applications.py b/tests/test_applications.py index 1632a436..d91dd596 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -58,7 +58,7 @@ def user_page(request, username): @app.route("/500") -def func_homepage(request): +def runtime_error(request): raise RuntimeError()