mirror of https://github.com/encode/starlette.git
Merge branch 'master' of https://github.com/encode/starlette
This commit is contained in:
commit
66cce0f1c3
|
@ -11,7 +11,7 @@ export PYTHON_VERSION=`python -c "$VERSION_SCRIPT"`
|
|||
set -x
|
||||
|
||||
PYTHONPATH=. ${PREFIX}pytest --ignore venv --cov=starlette --cov=tests --cov-fail-under=100 --cov-report=term-missing ${@}
|
||||
${PREFIX}mypy starlette --ignore-missing-imports
|
||||
${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs
|
||||
if [ "${PYTHON_VERSION}" = '3.7' ]; then
|
||||
echo "Skipping 'black' on 3.7. See issue https://github.com/ambv/black/issues/494"
|
||||
else
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import typing
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from starlette.exceptions import ExceptionMiddleware
|
||||
from starlette.lifespan import LifespanHandler
|
||||
|
@ -54,6 +55,7 @@ class Starlette:
|
|||
self.lifespan_handler = LifespanHandler()
|
||||
self.app = self.router
|
||||
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
|
||||
self.executor = ThreadPoolExecutor()
|
||||
|
||||
@property
|
||||
def debug(self) -> bool:
|
||||
|
@ -118,7 +120,8 @@ class Starlette:
|
|||
return decorator
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
scope["app"] = self
|
||||
scope["executor"] = self.executor
|
||||
if scope["type"] == "lifespan":
|
||||
return self.lifespan_handler(scope)
|
||||
scope["app"] = self
|
||||
return self.exception_middleware(scope)
|
||||
|
|
|
@ -4,38 +4,54 @@ import io
|
|||
import json
|
||||
import threading
|
||||
import typing
|
||||
import requests
|
||||
import queue
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
from urllib.parse import unquote, urlparse, urljoin
|
||||
|
||||
import requests
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
from starlette.types import Message, Scope, ASGIApp
|
||||
|
||||
|
||||
# Annotations for `Session.request()`
|
||||
Cookies = typing.Union[
|
||||
typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
|
||||
]
|
||||
Params = typing.Union[bytes, typing.MutableMapping[str, str]]
|
||||
DataType = typing.Union[bytes, typing.MutableMapping[str, str], typing.IO]
|
||||
TimeOut = typing.Union[float, typing.Tuple[float, float]]
|
||||
FileType = typing.MutableMapping[str, typing.IO]
|
||||
AuthType = typing.Union[
|
||||
typing.Tuple[str, str],
|
||||
requests.auth.AuthBase,
|
||||
typing.Callable[[requests.Request], requests.Request],
|
||||
]
|
||||
|
||||
|
||||
class _HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
|
||||
def get_all(self, key, default):
|
||||
def get_all(self, key: str, default: str) -> str:
|
||||
return self.getheaders(key)
|
||||
|
||||
|
||||
class _MockOriginalResponse(object):
|
||||
class _MockOriginalResponse:
|
||||
"""
|
||||
We have to jump through some hoops to present the response as if
|
||||
it was made using urllib3.
|
||||
"""
|
||||
|
||||
def __init__(self, headers):
|
||||
def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
|
||||
self.msg = _HeaderDict(headers)
|
||||
self.closed = False
|
||||
|
||||
def isclosed(self):
|
||||
def isclosed(self) -> bool:
|
||||
return self.closed
|
||||
|
||||
|
||||
class _Upgrade(Exception):
|
||||
def __init__(self, session):
|
||||
def __init__(self, session: "WebSocketTestSession") -> None:
|
||||
self.session = session
|
||||
|
||||
|
||||
def _get_reason_phrase(status_code):
|
||||
def _get_reason_phrase(status_code: int) -> str:
|
||||
try:
|
||||
return http.HTTPStatus(status_code).phrase
|
||||
except ValueError:
|
||||
|
@ -43,37 +59,42 @@ def _get_reason_phrase(status_code):
|
|||
|
||||
|
||||
class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
||||
def __init__(self, app: typing.Callable, raise_server_exceptions=True) -> None:
|
||||
def __init__(self, app: ASGIApp, raise_server_exceptions: bool = True) -> None:
|
||||
self.app = app
|
||||
self.raise_server_exceptions = raise_server_exceptions
|
||||
|
||||
def send(self, request, *args, **kwargs):
|
||||
scheme, netloc, path, params, query, fragement = urlparse(request.url)
|
||||
def send( # type: ignore
|
||||
self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> requests.Response:
|
||||
scheme, netloc, path, params, query, fragement = urlparse( # type: ignore
|
||||
request.url
|
||||
)
|
||||
|
||||
if ":" in netloc:
|
||||
host, port = netloc.split(":", 1)
|
||||
port = int(port)
|
||||
host, port_string = netloc.split(":", 1)
|
||||
port = int(port_string)
|
||||
else:
|
||||
host = netloc
|
||||
port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
|
||||
|
||||
# Include the 'host' header.
|
||||
if "host" in request.headers:
|
||||
headers = []
|
||||
headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
|
||||
elif port == 80:
|
||||
headers = [[b"host", host.encode()]]
|
||||
headers = [(b"host", host.encode())]
|
||||
else:
|
||||
headers = [[b"host", ("%s:%d" % (host, port)).encode()]]
|
||||
headers = [(b"host", ("%s:%d" % (host, port)).encode())]
|
||||
|
||||
# Include other request headers.
|
||||
headers += [
|
||||
[key.lower().encode(), value.encode()]
|
||||
(key.lower().encode(), value.encode())
|
||||
for key, value in request.headers.items()
|
||||
]
|
||||
|
||||
if scheme in {"ws", "wss"}:
|
||||
subprotocol = request.headers.get("sec-websocket-protocol", None)
|
||||
if subprotocol is None:
|
||||
subprotocols = []
|
||||
subprotocols = [] # type: typing.Sequence[str]
|
||||
else:
|
||||
subprotocols = [value.strip() for value in subprotocol.split(",")]
|
||||
scope = {
|
||||
|
@ -103,7 +124,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
|||
"server": [host, port],
|
||||
}
|
||||
|
||||
async def receive():
|
||||
async def receive() -> Message:
|
||||
body = request.body
|
||||
if isinstance(body, str):
|
||||
body_bytes = body.encode("utf-8") # type: bytes
|
||||
|
@ -113,7 +134,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
|||
body_bytes = body
|
||||
return {"type": "http.request", "body": body_bytes}
|
||||
|
||||
async def send(message):
|
||||
async def send(message: Message) -> None:
|
||||
nonlocal raw_kwargs, response_started, response_complete
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
|
@ -147,7 +168,7 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
|||
|
||||
response_started = False
|
||||
response_complete = False
|
||||
raw_kwargs = {"body": io.BytesIO()}
|
||||
raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any]
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
@ -176,12 +197,12 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
|
|||
|
||||
|
||||
class WebSocketTestSession:
|
||||
def __init__(self, app, scope):
|
||||
def __init__(self, app: ASGIApp, scope: Scope) -> None:
|
||||
self.accepted_subprotocol = None
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._instance = app(scope)
|
||||
self._receive_queue = queue.Queue()
|
||||
self._send_queue = queue.Queue()
|
||||
self._receive_queue = queue.Queue() # type: queue.Queue
|
||||
self._send_queue = queue.Queue() # type: queue.Queue
|
||||
self._thread = threading.Thread(target=self._run)
|
||||
self.send({"type": "websocket.connect"})
|
||||
self._thread.start()
|
||||
|
@ -189,10 +210,10 @@ class WebSocketTestSession:
|
|||
self._raise_on_close(message)
|
||||
self.accepted_subprotocol = message.get("subprotocol", None)
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "WebSocketTestSession":
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
def __exit__(self, *args: typing.Any) -> None:
|
||||
self.close(1000)
|
||||
self._thread.join()
|
||||
while not self._send_queue.empty():
|
||||
|
@ -200,7 +221,7 @@ class WebSocketTestSession:
|
|||
if isinstance(message, BaseException):
|
||||
raise message
|
||||
|
||||
def _run(self):
|
||||
def _run(self) -> None:
|
||||
"""
|
||||
The sub-thread in which the websocket session runs.
|
||||
"""
|
||||
|
@ -211,49 +232,49 @@ class WebSocketTestSession:
|
|||
except BaseException as exc:
|
||||
self._send_queue.put(exc)
|
||||
|
||||
async def _asgi_receive(self):
|
||||
async def _asgi_receive(self) -> Message:
|
||||
return self._receive_queue.get()
|
||||
|
||||
async def _asgi_send(self, message):
|
||||
async def _asgi_send(self, message: Message) -> None:
|
||||
self._send_queue.put(message)
|
||||
|
||||
def _raise_on_close(self, message):
|
||||
def _raise_on_close(self, message: Message) -> None:
|
||||
if message["type"] == "websocket.close":
|
||||
raise WebSocketDisconnect(message.get("code", 1000))
|
||||
|
||||
def send(self, message):
|
||||
def send(self, message: Message) -> None:
|
||||
self._receive_queue.put(message)
|
||||
|
||||
def send_text(self, data):
|
||||
def send_text(self, data: str) -> None:
|
||||
self.send({"type": "websocket.receive", "text": data})
|
||||
|
||||
def send_bytes(self, data):
|
||||
def send_bytes(self, data: bytes) -> None:
|
||||
self.send({"type": "websocket.receive", "bytes": data})
|
||||
|
||||
def send_json(self, data):
|
||||
def send_json(self, data: typing.Any) -> None:
|
||||
encoded = json.dumps(data).encode("utf-8")
|
||||
self.send({"type": "websocket.receive", "bytes": encoded})
|
||||
|
||||
def close(self, code=1000):
|
||||
def close(self, code: int = 1000) -> None:
|
||||
self.send({"type": "websocket.disconnect", "code": code})
|
||||
|
||||
def receive(self):
|
||||
def receive(self) -> Message:
|
||||
message = self._send_queue.get()
|
||||
if isinstance(message, BaseException):
|
||||
raise message
|
||||
return message
|
||||
|
||||
def receive_text(self):
|
||||
def receive_text(self) -> str:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
return message["text"]
|
||||
|
||||
def receive_bytes(self):
|
||||
def receive_bytes(self) -> bytes:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
return message["bytes"]
|
||||
|
||||
def receive_json(self):
|
||||
def receive_json(self) -> typing.Any:
|
||||
message = self.receive()
|
||||
self._raise_on_close(message)
|
||||
encoded = message["bytes"]
|
||||
|
@ -262,7 +283,7 @@ class WebSocketTestSession:
|
|||
|
||||
class _TestClient(requests.Session):
|
||||
def __init__(
|
||||
self, app: typing.Callable, base_url: str, raise_server_exceptions=True
|
||||
self, app: ASGIApp, base_url: str, raise_server_exceptions: bool = True
|
||||
) -> None:
|
||||
super(_TestClient, self).__init__()
|
||||
adapter = _ASGIAdapter(app, raise_server_exceptions=raise_server_exceptions)
|
||||
|
@ -277,20 +298,20 @@ class _TestClient(requests.Session):
|
|||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
params=None,
|
||||
data=None,
|
||||
headers=None,
|
||||
cookies=None,
|
||||
files=None,
|
||||
auth=None,
|
||||
timeout=None,
|
||||
allow_redirects=True,
|
||||
proxies=None,
|
||||
hooks=None,
|
||||
stream=None,
|
||||
verify=None,
|
||||
cert=None,
|
||||
json=None,
|
||||
params: Params = None,
|
||||
data: DataType = None,
|
||||
headers: typing.MutableMapping[str, str] = None,
|
||||
cookies: Cookies = None,
|
||||
files: FileType = None,
|
||||
auth: AuthType = None,
|
||||
timeout: TimeOut = None,
|
||||
allow_redirects: bool = None,
|
||||
proxies: typing.MutableMapping[str, str] = None,
|
||||
hooks: typing.Any = None,
|
||||
stream: bool = None,
|
||||
verify: typing.Union[bool, str] = None,
|
||||
cert: typing.Union[str, typing.Tuple[str, str]] = None,
|
||||
json: typing.Any = None,
|
||||
) -> requests.Response:
|
||||
url = urljoin(self.base_url, url)
|
||||
return super().request(
|
||||
|
@ -313,8 +334,8 @@ class _TestClient(requests.Session):
|
|||
)
|
||||
|
||||
def websocket_connect(
|
||||
self, url: str, subprotocols=None, **kwargs
|
||||
) -> WebSocketTestSession:
|
||||
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
|
||||
) -> typing.Any:
|
||||
url = urljoin("ws://testserver", url)
|
||||
headers = kwargs.get("headers", {})
|
||||
headers.setdefault("connection", "upgrade")
|
||||
|
@ -334,9 +355,9 @@ class _TestClient(requests.Session):
|
|||
|
||||
|
||||
def TestClient(
|
||||
app: typing.Callable,
|
||||
app: ASGIApp,
|
||||
base_url: str = "http://testserver",
|
||||
raise_server_exceptions=True,
|
||||
raise_server_exceptions: bool = True,
|
||||
) -> _TestClient:
|
||||
"""
|
||||
We have to work around py.test discovery attempting to pick up
|
||||
|
|
|
@ -58,7 +58,7 @@ def user_page(request, username):
|
|||
|
||||
|
||||
@app.route("/500")
|
||||
def func_homepage(request):
|
||||
def runtime_error(request):
|
||||
raise RuntimeError()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue