This commit is contained in:
Tom Christie 2018-10-16 16:10:34 +01:00
commit 66cce0f1c3
4 changed files with 85 additions and 61 deletions

View File

@ -11,7 +11,7 @@ export PYTHON_VERSION=`python -c "$VERSION_SCRIPT"`
set -x
PYTHONPATH=. ${PREFIX}pytest --ignore venv --cov=starlette --cov=tests --cov-fail-under=100 --cov-report=term-missing ${@}
${PREFIX}mypy starlette --ignore-missing-imports
${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs
if [ "${PYTHON_VERSION}" = '3.7' ]; then
echo "Skipping 'black' on 3.7. See issue https://github.com/ambv/black/issues/494"
else

View File

@ -1,6 +1,7 @@
import asyncio
import inspect
import typing
from concurrent.futures import ThreadPoolExecutor
from starlette.exceptions import ExceptionMiddleware
from starlette.lifespan import LifespanHandler
@ -54,6 +55,7 @@ class Starlette:
self.lifespan_handler = LifespanHandler()
self.app = self.router
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
self.executor = ThreadPoolExecutor()
@property
def debug(self) -> bool:
@ -118,7 +120,8 @@ class Starlette:
return decorator
def __call__(self, scope: Scope) -> ASGIInstance:
scope["app"] = self
scope["executor"] = self.executor
if scope["type"] == "lifespan":
return self.lifespan_handler(scope)
scope["app"] = self
return self.exception_middleware(scope)

View File

@ -4,38 +4,54 @@ import io
import json
import threading
import typing
import requests
import queue
from starlette.websockets import WebSocketDisconnect
from urllib.parse import unquote, urlparse, urljoin
import requests
from starlette.websockets import WebSocketDisconnect
from starlette.types import Message, Scope, ASGIApp
# Annotations for `Session.request()`
Cookies = typing.Union[
typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
]
Params = typing.Union[bytes, typing.MutableMapping[str, str]]
DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO]
TimeOut = typing.Union[float, typing.Tuple[float, float]]
FileType = typing.MutableMapping[str, typing.IO]
AuthType = typing.Union[
typing.Tuple[str, str],
requests.auth.AuthBase,
typing.Callable[[requests.Request], requests.Request],
]
class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
def get_all(self, key, default):
def get_all(self, key: str, default: str) -> str:
return self.getheaders(key)
class _MockOriginalResponse(object):
class _MockOriginalResponse:
"""
We have to jump through some hoops to present the response as if
it was made using urllib3.
"""
def __init__(self, headers):
def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
self.msg = _HeaderDict(headers)
self.closed = False
def isclosed(self):
def isclosed(self) -> bool:
return self.closed
class _Upgrade(Exception):
def __init__(self, session):
def __init__(self, session: "WebSocketTestSession") -> None:
self.session = session
def _get_reason_phrase(status_code):
def _get_reason_phrase(status_code: int) -> str:
try:
return http.HTTPStatus(status_code).phrase
except ValueError:
@ -43,37 +59,42 @@ def _get_reason_phrase(status_code):
class _ASGIAdapter(requests.adapters.HTTPAdapter):
def __init__(self, app: typing.Callable, raise_server_exceptions=True) -> None:
def __init__(self, app: ASGIApp, raise_server_exceptions: bool = True) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
def send(self, request, *args, **kwargs):
scheme, netloc, path, params, query, fragement = urlparse(request.url)
def send( # type: ignore
self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
) -> requests.Response:
scheme, netloc, path, params, query, fragement = urlparse( # type: ignore
request.url
)
if ":" in netloc:
host, port = netloc.split(":", 1)
port = int(port)
host, port_string = netloc.split(":", 1)
port = int(port_string)
else:
host = netloc
port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
# Include the 'host' header.
if "host" in request.headers:
headers = []
headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
elif port == 80:
headers = [[b"host", host.encode()]]
headers = [(b"host", host.encode())]
else:
headers = [[b"host", ("%s:%d" % (host, port)).encode()]]
headers = [(b"host", ("%s:%d" % (host, port)).encode())]
# Include other request headers.
headers += [
[key.lower().encode(), value.encode()]
(key.lower().encode(), value.encode())
for key, value in request.headers.items()
]
if scheme in {"ws", "wss"}:
subprotocol = request.headers.get("sec-websocket-protocol", None)
if subprotocol is None:
subprotocols = []
subprotocols = [] # type: typing.Sequence[str]
else:
subprotocols = [value.strip() for value in subprotocol.split(",")]
scope = {
@ -103,7 +124,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
"server": [host, port],
}
async def receive():
async def receive() -> Message:
body = request.body
if isinstance(body, str):
body_bytes = body.encode("utf-8") # type: bytes
@ -113,7 +134,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
body_bytes = body
return {"type": "http.request", "body": body_bytes}
async def send(message):
async def send(message: Message) -> None:
nonlocal raw_kwargs, response_started, response_complete
if message["type"] == "http.response.start":
@ -147,7 +168,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
response_started = False
response_complete = False
raw_kwargs = {"body": io.BytesIO()}
raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any]
loop = asyncio.get_event_loop()
@ -176,12 +197,12 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
class WebSocketTestSession:
def __init__(self, app, scope):
def __init__(self, app: ASGIApp, scope: Scope) -> None:
self.accepted_subprotocol = None
self._loop = asyncio.new_event_loop()
self._instance = app(scope)
self._receive_queue = queue.Queue()
self._send_queue = queue.Queue()
self._receive_queue = queue.Queue() # type: queue.Queue
self._send_queue = queue.Queue() # type: queue.Queue
self._thread = threading.Thread(target=self._run)
self.send({"type": "websocket.connect"})
self._thread.start()
@ -189,10 +210,10 @@ class WebSocketTestSession:
self._raise_on_close(message)
self.accepted_subprotocol = message.get("subprotocol", None)
def __enter__(self):
def __enter__(self) -> "WebSocketTestSession":
return self
def __exit__(self, *args):
def __exit__(self, *args: typing.Any) -> None:
self.close(1000)
self._thread.join()
while not self._send_queue.empty():
@ -200,7 +221,7 @@ class WebSocketTestSession:
if isinstance(message, BaseException):
raise message
def _run(self):
def _run(self) -> None:
"""
The sub-thread in which the websocket session runs.
"""
@ -211,49 +232,49 @@ class WebSocketTestSession:
except BaseException as exc:
self._send_queue.put(exc)
async def _asgi_receive(self):
async def _asgi_receive(self) -> Message:
return self._receive_queue.get()
async def _asgi_send(self, message):
async def _asgi_send(self, message: Message) -> None:
self._send_queue.put(message)
def _raise_on_close(self, message):
def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDisconnect(message.get("code", 1000))
def send(self, message):
def send(self, message: Message) -> None:
self._receive_queue.put(message)
def send_text(self, data):
def send_text(self, data: str) -> None:
self.send({"type": "websocket.receive", "text": data})
def send_bytes(self, data):
def send_bytes(self, data: bytes) -> None:
self.send({"type": "websocket.receive", "bytes": data})
def send_json(self, data):
def send_json(self, data: typing.Any) -> None:
encoded = json.dumps(data).encode("utf-8")
self.send({"type": "websocket.receive", "bytes": encoded})
def close(self, code=1000):
def close(self, code: int = 1000) -> None:
self.send({"type": "websocket.disconnect", "code": code})
def receive(self):
def receive(self) -> Message:
message = self._send_queue.get()
if isinstance(message, BaseException):
raise message
return message
def receive_text(self):
def receive_text(self) -> str:
message = self.receive()
self._raise_on_close(message)
return message["text"]
def receive_bytes(self):
def receive_bytes(self) -> bytes:
message = self.receive()
self._raise_on_close(message)
return message["bytes"]
def receive_json(self):
def receive_json(self) -> typing.Any:
message = self.receive()
self._raise_on_close(message)
encoded = message["bytes"]
@ -262,7 +283,7 @@ class WebSocketTestSession:
class _TestClient(requests.Session):
def __init__(
self, app: typing.Callable, base_url: str, raise_server_exceptions=True
self, app: ASGIApp, base_url: str, raise_server_exceptions: bool = True
) -> None:
super(_TestClient, self).__init__()
adapter = _ASGIAdapter(app, raise_server_exceptions=raise_server_exceptions)
@ -277,20 +298,20 @@ class _TestClient(requests.Session):
self,
method: str,
url: str,
params=None,
data=None,
headers=None,
cookies=None,
files=None,
auth=None,
timeout=None,
allow_redirects=True,
proxies=None,
hooks=None,
stream=None,
verify=None,
cert=None,
json=None,
params: Params = None,
data: DataType = None,
headers: typing.MutableMapping[str, str] = None,
cookies: Cookies = None,
files: FileType = None,
auth: AuthType = None,
timeout: TimeOut = None,
allow_redirects: bool = None,
proxies: typing.MutableMapping[str, str] = None,
hooks: typing.Any = None,
stream: bool = None,
verify: typing.Union[bool, str] = None,
cert: typing.Union[str, typing.Tuple[str, str]] = None,
json: typing.Any = None,
) -> requests.Response:
url = urljoin(self.base_url, url)
return super().request(
@ -313,8 +334,8 @@ class _TestClient(requests.Session):
)
def websocket_connect(
self, url: str, subprotocols=None, **kwargs
) -> WebSocketTestSession:
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
) -> typing.Any:
url = urljoin("ws://testserver", url)
headers = kwargs.get("headers", {})
headers.setdefault("connection", "upgrade")
@ -334,9 +355,9 @@ class _TestClient(requests.Session):
def TestClient(
app: typing.Callable,
app: ASGIApp,
base_url: str = "http://testserver",
raise_server_exceptions=True,
raise_server_exceptions: bool = True,
) -> _TestClient:
"""
We have to work around py.test discovery attempting to pick up

View File

@ -58,7 +58,7 @@ def user_page(request, username):
@app.route("/500")
def func_homepage(request):
def runtime_error(request):
raise RuntimeError()