From 43fb6764396d10eed82028466f1071c472d5f7cf Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 16 Oct 2018 13:20:10 +0100 Subject: [PATCH] Add missing type annotations (#113) * Add missing type annotations * Type annotation tweak --- starlette/background.py | 7 +++++-- starlette/datastructures.py | 6 +++--- starlette/endpoints.py | 19 +++++++++++-------- starlette/middleware/cors.py | 18 ++++++++++-------- starlette/requests.py | 14 +++++++------- 5 files changed, 36 insertions(+), 28 deletions(-) diff --git a/starlette/background.py b/starlette/background.py index 98387350..a666cc61 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,14 +1,17 @@ import asyncio import functools +import typing class BackgroundTask: - def __init__(self, func, *args, **kwargs): + def __init__( + self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any + ) -> None: self.func = func self.args = args self.kwargs = kwargs - async def __call__(self): + async def __call__(self) -> None: if asyncio.iscoroutinefunction(self.func): await asyncio.ensure_future(self.func(*self.args, **self.kwargs)) else: diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 7f3aef4c..e5278603 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -89,7 +89,7 @@ class URL: def __str__(self) -> str: return self._url - def __repr__(self): + def __repr__(self) -> str: return "%s(%s)" % (self.__class__.__name__, repr(self._url)) @@ -163,7 +163,7 @@ class Headers(typing.Mapping[str, str]): An immutable, case-insensitive multidict. """ - def __init__(self, raw_headers: typing.Optional[BytesPairs] = None) -> None: + def __init__(self, raw_headers: BytesPairs = None) -> None: if raw_headers is None: self._list = [] # type: BytesPairs else: @@ -286,7 +286,7 @@ class MutableHeaders(Headers): for key, val in other.items(): self[key] = val - def add_vary_header(self, vary): + def add_vary_header(self, vary: str) -> None: existing = self.get("vary") if existing is not None: vary = ", ".join([existing, vary]) diff --git a/starlette/endpoints.py b/starlette/endpoints.py index 098d09f6..efb9d335 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -5,7 +5,7 @@ from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.websockets import WebSocket from starlette.responses import Response, PlainTextResponse -from starlette.types import Receive, Send, Scope +from starlette.types import Message, Receive, Send, Scope class HTTPEndpoint: @@ -52,7 +52,7 @@ class WebSocketEndpoint: kwargs = self.scope.get("kwargs", {}) await self.on_connect(websocket, **kwargs) - close_code = None + close_code = 1000 try: while True: @@ -61,12 +61,15 @@ class WebSocketEndpoint: data = await self.decode(websocket, message) await self.on_receive(websocket, data) elif message["type"] == "websocket.disconnect": - close_code = message.get("code", 1000) - return + close_code = int(message.get("code", 1000)) + break + except Exception as exc: + close_code = 1011 + raise exc from None finally: await self.on_disconnect(websocket, close_code) - async def decode(self, websocket, message): + async def decode(self, websocket: WebSocket, message: Message) -> typing.Any: if self.encoding == "text": if "text" not in message: @@ -93,12 +96,12 @@ class WebSocketEndpoint: ), f"Unsupported 'encoding' attribute {self.encoding}" return message["text"] if "text" in message else message["bytes"] - async def on_connect(self, websocket, **kwargs): + async def on_connect(self, websocket: WebSocket, **kwargs: typing.Any) -> None: """Override to handle an incoming websocket connection""" await websocket.accept() - async def on_receive(self, websocket, data): + async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None: """Override to handle an incoming websocket message""" - async def on_disconnect(self, websocket, close_code): + async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: """Override to handle a disconnecting websocket""" diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 8b0abeae..8820ec9c 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -1,6 +1,6 @@ from starlette.datastructures import Headers, MutableHeaders, URL from starlette.responses import PlainTextResponse -from starlette.types import ASGIApp, ASGIInstance, Scope +from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send, Message import functools import typing import re @@ -63,7 +63,7 @@ class CORSMiddleware: self.simple_headers = simple_headers self.preflight_headers = preflight_headers - def __call__(self, scope: Scope): + def __call__(self, scope: Scope) -> ASGIInstance: if scope["type"] == "http": method = scope["method"] headers = Headers(scope["headers"]) @@ -79,7 +79,7 @@ class CORSMiddleware: return self.app(scope) - def is_allowed_origin(self, origin): + def is_allowed_origin(self, origin: str) -> bool: if self.allow_all_origins: return True @@ -90,7 +90,7 @@ class CORSMiddleware: return origin in self.allow_origins - def preflight_response(self, request_headers): + def preflight_response(self, request_headers: Headers) -> ASGIInstance: requested_origin = request_headers["origin"] requested_method = request_headers["access-control-request-method"] requested_headers = request_headers.get("access-control-request-headers") @@ -130,17 +130,20 @@ class CORSMiddleware: return PlainTextResponse("OK", status_code=200, headers=headers) - async def simple_response(self, receive, send, scope=None, request_headers=None): + async def simple_response( + self, receive: Receive, send: Send, scope: Scope, request_headers: Headers + ) -> None: inner = self.app(scope) send = functools.partial(self.send, send=send, request_headers=request_headers) await inner(receive, send) - async def send(self, message, send=None, request_headers=None): + async def send( + self, message: Message, send: Send, request_headers: Headers + ) -> None: if message["type"] != "http.response.start": await send(message) return - print(message) message.setdefault("headers", []) headers = MutableHeaders(message["headers"]) origin = request_headers["Origin"] @@ -156,6 +159,5 @@ class CORSMiddleware: elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): headers["Access-Control-Allow-Origin"] = origin headers.add_vary_header("Origin") - print(headers) headers.update(self.simple_headers) await send(message) diff --git a/starlette/requests.py b/starlette/requests.py index c721844e..c4133273 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -107,7 +107,7 @@ class Request(Mapping): self._json = json.loads(body) return self._json - async def form(self): + async def form(self) -> dict: if not hasattr(self, "_form"): assert ( parse_options_header is not None @@ -115,17 +115,17 @@ class Request(Mapping): content_type_header = self.headers.get("Content-Type") content_type, options = parse_options_header(content_type_header) if content_type == b"multipart/form-data": - parser = MultiPartParser(self.headers, self.stream()) - self._form = await parser.parse() + multipart_parser = MultiPartParser(self.headers, self.stream()) + self._form = await multipart_parser.parse() elif content_type == b"application/x-www-form-urlencoded": - parser = FormParser(self.headers, self.stream()) - self._form = await parser.parse() + from_parser = FormParser(self.headers, self.stream()) + self._form = await from_parser.parse() else: self._form = {} return self._form - async def close(self): + async def close(self) -> None: if hasattr(self, "_form"): for item in self._form.values(): if hasattr(item, "close"): - await item.close() + await item.close() # type: ignore