mirror of https://github.com/encode/starlette.git
Add missing type annotations (#113)
* Add missing type annotations * Type annotation tweak
This commit is contained in:
parent
e2cdd2e1ee
commit
43fb676439
|
@ -1,14 +1,17 @@
|
|||
import asyncio
|
||||
import functools
|
||||
import typing
|
||||
|
||||
|
||||
class BackgroundTask:
|
||||
def __init__(self, func, *args, **kwargs):
|
||||
def __init__(
|
||||
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
|
||||
) -> None:
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
async def __call__(self):
|
||||
async def __call__(self) -> None:
|
||||
if asyncio.iscoroutinefunction(self.func):
|
||||
await asyncio.ensure_future(self.func(*self.args, **self.kwargs))
|
||||
else:
|
||||
|
|
|
@ -89,7 +89,7 @@ class URL:
|
|||
def __str__(self) -> str:
|
||||
return self._url
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "%s(%s)" % (self.__class__.__name__, repr(self._url))
|
||||
|
||||
|
||||
|
@ -163,7 +163,7 @@ class Headers(typing.Mapping[str, str]):
|
|||
An immutable, case-insensitive multidict.
|
||||
"""
|
||||
|
||||
def __init__(self, raw_headers: typing.Optional[BytesPairs] = None) -> None:
|
||||
def __init__(self, raw_headers: BytesPairs = None) -> None:
|
||||
if raw_headers is None:
|
||||
self._list = [] # type: BytesPairs
|
||||
else:
|
||||
|
@ -286,7 +286,7 @@ class MutableHeaders(Headers):
|
|||
for key, val in other.items():
|
||||
self[key] = val
|
||||
|
||||
def add_vary_header(self, vary):
|
||||
def add_vary_header(self, vary: str) -> None:
|
||||
existing = self.get("vary")
|
||||
if existing is not None:
|
||||
vary = ", ".join([existing, vary])
|
||||
|
|
|
@ -5,7 +5,7 @@ from starlette.exceptions import HTTPException
|
|||
from starlette.requests import Request
|
||||
from starlette.websockets import WebSocket
|
||||
from starlette.responses import Response, PlainTextResponse
|
||||
from starlette.types import Receive, Send, Scope
|
||||
from starlette.types import Message, Receive, Send, Scope
|
||||
|
||||
|
||||
class HTTPEndpoint:
|
||||
|
@ -52,7 +52,7 @@ class WebSocketEndpoint:
|
|||
kwargs = self.scope.get("kwargs", {})
|
||||
await self.on_connect(websocket, **kwargs)
|
||||
|
||||
close_code = None
|
||||
close_code = 1000
|
||||
|
||||
try:
|
||||
while True:
|
||||
|
@ -61,12 +61,15 @@ class WebSocketEndpoint:
|
|||
data = await self.decode(websocket, message)
|
||||
await self.on_receive(websocket, data)
|
||||
elif message["type"] == "websocket.disconnect":
|
||||
close_code = message.get("code", 1000)
|
||||
return
|
||||
close_code = int(message.get("code", 1000))
|
||||
break
|
||||
except Exception as exc:
|
||||
close_code = 1011
|
||||
raise exc from None
|
||||
finally:
|
||||
await self.on_disconnect(websocket, close_code)
|
||||
|
||||
async def decode(self, websocket, message):
|
||||
async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
|
||||
|
||||
if self.encoding == "text":
|
||||
if "text" not in message:
|
||||
|
@ -93,12 +96,12 @@ class WebSocketEndpoint:
|
|||
), f"Unsupported 'encoding' attribute {self.encoding}"
|
||||
return message["text"] if "text" in message else message["bytes"]
|
||||
|
||||
async def on_connect(self, websocket, **kwargs):
|
||||
async def on_connect(self, websocket: WebSocket, **kwargs: typing.Any) -> None:
|
||||
"""Override to handle an incoming websocket connection"""
|
||||
await websocket.accept()
|
||||
|
||||
async def on_receive(self, websocket, data):
|
||||
async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
|
||||
"""Override to handle an incoming websocket message"""
|
||||
|
||||
async def on_disconnect(self, websocket, close_code):
|
||||
async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
|
||||
"""Override to handle a disconnecting websocket"""
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from starlette.datastructures import Headers, MutableHeaders, URL
|
||||
from starlette.responses import PlainTextResponse
|
||||
from starlette.types import ASGIApp, ASGIInstance, Scope
|
||||
from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send, Message
|
||||
import functools
|
||||
import typing
|
||||
import re
|
||||
|
@ -63,7 +63,7 @@ class CORSMiddleware:
|
|||
self.simple_headers = simple_headers
|
||||
self.preflight_headers = preflight_headers
|
||||
|
||||
def __call__(self, scope: Scope):
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
if scope["type"] == "http":
|
||||
method = scope["method"]
|
||||
headers = Headers(scope["headers"])
|
||||
|
@ -79,7 +79,7 @@ class CORSMiddleware:
|
|||
|
||||
return self.app(scope)
|
||||
|
||||
def is_allowed_origin(self, origin):
|
||||
def is_allowed_origin(self, origin: str) -> bool:
|
||||
if self.allow_all_origins:
|
||||
return True
|
||||
|
||||
|
@ -90,7 +90,7 @@ class CORSMiddleware:
|
|||
|
||||
return origin in self.allow_origins
|
||||
|
||||
def preflight_response(self, request_headers):
|
||||
def preflight_response(self, request_headers: Headers) -> ASGIInstance:
|
||||
requested_origin = request_headers["origin"]
|
||||
requested_method = request_headers["access-control-request-method"]
|
||||
requested_headers = request_headers.get("access-control-request-headers")
|
||||
|
@ -130,17 +130,20 @@ class CORSMiddleware:
|
|||
|
||||
return PlainTextResponse("OK", status_code=200, headers=headers)
|
||||
|
||||
async def simple_response(self, receive, send, scope=None, request_headers=None):
|
||||
async def simple_response(
|
||||
self, receive: Receive, send: Send, scope: Scope, request_headers: Headers
|
||||
) -> None:
|
||||
inner = self.app(scope)
|
||||
send = functools.partial(self.send, send=send, request_headers=request_headers)
|
||||
await inner(receive, send)
|
||||
|
||||
async def send(self, message, send=None, request_headers=None):
|
||||
async def send(
|
||||
self, message: Message, send: Send, request_headers: Headers
|
||||
) -> None:
|
||||
if message["type"] != "http.response.start":
|
||||
await send(message)
|
||||
return
|
||||
|
||||
print(message)
|
||||
message.setdefault("headers", [])
|
||||
headers = MutableHeaders(message["headers"])
|
||||
origin = request_headers["Origin"]
|
||||
|
@ -156,6 +159,5 @@ class CORSMiddleware:
|
|||
elif not self.allow_all_origins and self.is_allowed_origin(origin=origin):
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
headers.add_vary_header("Origin")
|
||||
print(headers)
|
||||
headers.update(self.simple_headers)
|
||||
await send(message)
|
||||
|
|
|
@ -107,7 +107,7 @@ class Request(Mapping):
|
|||
self._json = json.loads(body)
|
||||
return self._json
|
||||
|
||||
async def form(self):
|
||||
async def form(self) -> dict:
|
||||
if not hasattr(self, "_form"):
|
||||
assert (
|
||||
parse_options_header is not None
|
||||
|
@ -115,17 +115,17 @@ class Request(Mapping):
|
|||
content_type_header = self.headers.get("Content-Type")
|
||||
content_type, options = parse_options_header(content_type_header)
|
||||
if content_type == b"multipart/form-data":
|
||||
parser = MultiPartParser(self.headers, self.stream())
|
||||
self._form = await parser.parse()
|
||||
multipart_parser = MultiPartParser(self.headers, self.stream())
|
||||
self._form = await multipart_parser.parse()
|
||||
elif content_type == b"application/x-www-form-urlencoded":
|
||||
parser = FormParser(self.headers, self.stream())
|
||||
self._form = await parser.parse()
|
||||
from_parser = FormParser(self.headers, self.stream())
|
||||
self._form = await from_parser.parse()
|
||||
else:
|
||||
self._form = {}
|
||||
return self._form
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
if hasattr(self, "_form"):
|
||||
for item in self._form.values():
|
||||
if hasattr(item, "close"):
|
||||
await item.close()
|
||||
await item.close() # type: ignore
|
||||
|
|
Loading…
Reference in New Issue