Add missing type annotations (#113)

* Add missing type annotations

* Type annotation tweak
This commit is contained in:
Tom Christie 2018-10-16 13:20:10 +01:00 committed by GitHub
parent e2cdd2e1ee
commit 43fb676439
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 28 deletions

View File

@ -1,14 +1,17 @@
import asyncio
import functools
import typing
class BackgroundTask:
def __init__(self, func, *args, **kwargs):
def __init__(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> None:
self.func = func
self.args = args
self.kwargs = kwargs
async def __call__(self):
async def __call__(self) -> None:
if asyncio.iscoroutinefunction(self.func):
await asyncio.ensure_future(self.func(*self.args, **self.kwargs))
else:

View File

@ -89,7 +89,7 @@ class URL:
def __str__(self) -> str:
return self._url
def __repr__(self):
def __repr__(self) -> str:
return "%s(%s)" % (self.__class__.__name__, repr(self._url))
@ -163,7 +163,7 @@ class Headers(typing.Mapping[str, str]):
An immutable, case-insensitive multidict.
"""
def __init__(self, raw_headers: typing.Optional[BytesPairs] = None) -> None:
def __init__(self, raw_headers: BytesPairs = None) -> None:
if raw_headers is None:
self._list = [] # type: BytesPairs
else:
@ -286,7 +286,7 @@ class MutableHeaders(Headers):
for key, val in other.items():
self[key] = val
def add_vary_header(self, vary):
def add_vary_header(self, vary: str) -> None:
existing = self.get("vary")
if existing is not None:
vary = ", ".join([existing, vary])

View File

@ -5,7 +5,7 @@ from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.websockets import WebSocket
from starlette.responses import Response, PlainTextResponse
from starlette.types import Receive, Send, Scope
from starlette.types import Message, Receive, Send, Scope
class HTTPEndpoint:
@ -52,7 +52,7 @@ class WebSocketEndpoint:
kwargs = self.scope.get("kwargs", {})
await self.on_connect(websocket, **kwargs)
close_code = None
close_code = 1000
try:
while True:
@ -61,12 +61,15 @@ class WebSocketEndpoint:
data = await self.decode(websocket, message)
await self.on_receive(websocket, data)
elif message["type"] == "websocket.disconnect":
close_code = message.get("code", 1000)
return
close_code = int(message.get("code", 1000))
break
except Exception as exc:
close_code = 1011
raise exc from None
finally:
await self.on_disconnect(websocket, close_code)
async def decode(self, websocket, message):
async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
if self.encoding == "text":
if "text" not in message:
@ -93,12 +96,12 @@ class WebSocketEndpoint:
), f"Unsupported 'encoding' attribute {self.encoding}"
return message["text"] if "text" in message else message["bytes"]
async def on_connect(self, websocket, **kwargs):
async def on_connect(self, websocket: WebSocket, **kwargs: typing.Any) -> None:
"""Override to handle an incoming websocket connection"""
await websocket.accept()
async def on_receive(self, websocket, data):
async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
"""Override to handle an incoming websocket message"""
async def on_disconnect(self, websocket, close_code):
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
"""Override to handle a disconnecting websocket"""

View File

@ -1,6 +1,6 @@
from starlette.datastructures import Headers, MutableHeaders, URL
from starlette.responses import PlainTextResponse
from starlette.types import ASGIApp, ASGIInstance, Scope
from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send, Message
import functools
import typing
import re
@ -63,7 +63,7 @@ class CORSMiddleware:
self.simple_headers = simple_headers
self.preflight_headers = preflight_headers
def __call__(self, scope: Scope):
def __call__(self, scope: Scope) -> ASGIInstance:
if scope["type"] == "http":
method = scope["method"]
headers = Headers(scope["headers"])
@ -79,7 +79,7 @@ class CORSMiddleware:
return self.app(scope)
def is_allowed_origin(self, origin):
def is_allowed_origin(self, origin: str) -> bool:
if self.allow_all_origins:
return True
@ -90,7 +90,7 @@ class CORSMiddleware:
return origin in self.allow_origins
def preflight_response(self, request_headers):
def preflight_response(self, request_headers: Headers) -> ASGIInstance:
requested_origin = request_headers["origin"]
requested_method = request_headers["access-control-request-method"]
requested_headers = request_headers.get("access-control-request-headers")
@ -130,17 +130,20 @@ class CORSMiddleware:
return PlainTextResponse("OK", status_code=200, headers=headers)
async def simple_response(self, receive, send, scope=None, request_headers=None):
async def simple_response(
self, receive: Receive, send: Send, scope: Scope, request_headers: Headers
) -> None:
inner = self.app(scope)
send = functools.partial(self.send, send=send, request_headers=request_headers)
await inner(receive, send)
async def send(self, message, send=None, request_headers=None):
async def send(
self, message: Message, send: Send, request_headers: Headers
) -> None:
if message["type"] != "http.response.start":
await send(message)
return
print(message)
message.setdefault("headers", [])
headers = MutableHeaders(message["headers"])
origin = request_headers["Origin"]
@ -156,6 +159,5 @@ class CORSMiddleware:
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
headers["Access-Control-Allow-Origin"] = origin
headers.add_vary_header("Origin")
print(headers)
headers.update(self.simple_headers)
await send(message)

View File

@ -107,7 +107,7 @@ class Request(Mapping):
self._json = json.loads(body)
return self._json
async def form(self):
async def form(self) -> dict:
if not hasattr(self, "_form"):
assert (
parse_options_header is not None
@ -115,17 +115,17 @@ class Request(Mapping):
content_type_header = self.headers.get("Content-Type")
content_type, options = parse_options_header(content_type_header)
if content_type == b"multipart/form-data":
parser = MultiPartParser(self.headers, self.stream())
self._form = await parser.parse()
multipart_parser = MultiPartParser(self.headers, self.stream())
self._form = await multipart_parser.parse()
elif content_type == b"application/x-www-form-urlencoded":
parser = FormParser(self.headers, self.stream())
self._form = await parser.parse()
from_parser = FormParser(self.headers, self.stream())
self._form = await from_parser.parse()
else:
self._form = {}
return self._form
async def close(self):
async def close(self) -> None:
if hasattr(self, "_form"):
for item in self._form.values():
if hasattr(item, "close"):
await item.close()
await item.close() # type: ignore