Missing annotations added datastructures & requests modules (#60)

* fix(dataStructures): missing annotations added

* fix(requests.py): annotations added

* fix(Annotations): wrong annotations fixed on datastructures and requests modules.
This commit is contained in:
Marcos Schroh 2018-10-02 13:31:05 +02:00 committed by Tom Christie
parent 7dc07267c2
commit 7eafe567df
2 changed files with 54 additions and 51 deletions

View File

@ -1,65 +1,66 @@
import typing
from urllib.parse import parse_qsl, urlparse
from urllib.parse import parse_qsl, urlparse, ParseResult
class URL(str):
@property
def components(self):
def components(self) -> ParseResult:
if not hasattr(self, "_components"):
self._components = urlparse(self)
return self._components
@property
def scheme(self):
def scheme(self) -> str:
return self.components.scheme
@property
def netloc(self):
def netloc(self) -> str:
return self.components.netloc
@property
def path(self):
def path(self) -> str:
return self.components.path
@property
def params(self):
def params(self) -> str:
return self.components.params
@property
def query(self):
def query(self) -> str:
return self.components.query
@property
def fragment(self):
def fragment(self) -> str:
return self.components.fragment
@property
def username(self):
def username(self) -> typing.Union[None, str]:
return self.components.username
@property
def password(self):
def password(self) -> typing.Union[None, str]:
return self.components.password
@property
def hostname(self):
def hostname(self) -> typing.Union[None, str]:
return self.components.hostname
@property
def port(self):
def port(self) -> typing.Optional[int]:
return self.components.port
def replace(self, **kwargs):
def replace(self, **kwargs: typing.Any) -> "URL": # type: ignore
components = self.components._replace(**kwargs)
return URL(components.geturl())
# Type annotations for valid `__init__` values to QueryParams and Headers.
StrPairs = typing.Sequence[typing.Tuple[str, str]]
BytesPairs = typing.List[typing.Tuple[bytes, bytes]]
StrDict = typing.Mapping[str, str]
class QueryParams(typing.Mapping[str, str]):
class QueryParams(StrDict):
"""
An immutable multidict.
"""
@ -79,42 +80,42 @@ class QueryParams(typing.Mapping[str, str]):
self._dict = {k: v for k, v in reversed(items)}
self._list = items
def getlist(self, key: str) -> typing.List[str]:
def getlist(self, key: typing.Any) -> typing.List[str]:
return [item_value for item_key, item_value in self._list if item_key == key]
def keys(self):
def keys(self) -> typing.List[str]: # type: ignore
return [key for key, value in self._list]
def values(self):
def values(self) -> typing.List[str]: # type: ignore
return [value for key, value in self._list]
def items(self):
def items(self) -> StrPairs: # type: ignore
return list(self._list)
def get(self, key, default=None):
def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
if key in self._dict:
return self._dict[key]
else:
return default
def __getitem__(self, key):
def __getitem__(self, key: typing.Any) -> str:
return self._dict[key]
def __contains__(self, key):
def __contains__(self, key: typing.Any) -> bool:
return key in self._dict
def __iter__(self):
def __iter__(self) -> typing.Iterator[typing.Any]:
return iter(self._list)
def __len__(self):
def __len__(self) -> int:
return len(self._list)
def __eq__(self, other):
def __eq__(self, other: typing.Any) -> bool:
if not isinstance(other, QueryParams):
other = QueryParams(other)
return sorted(self._list) == sorted(other._list)
def __repr__(self):
def __repr__(self) -> str:
return "QueryParams(%s)" % repr(self._list)
@ -123,9 +124,9 @@ class Headers(typing.Mapping[str, str]):
An immutable, case-insensitive multidict.
"""
def __init__(self, raw_headers=None) -> None:
def __init__(self, raw_headers: typing.Optional[BytesPairs] = None) -> None:
if raw_headers is None:
self._list = []
self._list = [] # type: BytesPairs
else:
for header_key, header_value in raw_headers:
assert isinstance(header_key, bytes)
@ -133,19 +134,19 @@ class Headers(typing.Mapping[str, str]):
assert header_key == header_key.lower()
self._list = raw_headers
def keys(self):
def keys(self) -> typing.List[str]: # type: ignore
return [key.decode("latin-1") for key, value in self._list]
def values(self):
def values(self) -> typing.List[str]: # type: ignore
return [value.decode("latin-1") for key, value in self._list]
def items(self):
def items(self) -> StrPairs: # type: ignore
return [
(key.decode("latin-1"), value.decode("latin-1"))
for key, value in self._list
]
def get(self, key: str, default: str = None):
def get(self, key: str, default: typing.Any = None) -> typing.Any:
try:
return self[key]
except KeyError:
@ -159,40 +160,40 @@ class Headers(typing.Mapping[str, str]):
if item_key == get_header_key
]
def mutablecopy(self):
def mutablecopy(self) -> "MutableHeaders":
return MutableHeaders(self._list[:])
def __getitem__(self, key: str):
def __getitem__(self, key: str) -> str:
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return header_value.decode("latin-1")
raise KeyError(key)
def __contains__(self, key: str):
def __contains__(self, key: typing.Any) -> bool:
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return True
return False
def __iter__(self):
def __iter__(self) -> typing.Iterator[typing.Any]:
return iter(self.items())
def __len__(self):
def __len__(self) -> int:
return len(self._list)
def __eq__(self, other):
def __eq__(self, other: typing.Any) -> bool:
if not isinstance(other, Headers):
return False
return sorted(self._list) == sorted(other._list)
def __repr__(self):
def __repr__(self) -> str:
return "%s(%s)" % (self.__class__.__name__, repr(self.items()))
class MutableHeaders(Headers):
def __setitem__(self, key: str, value: str):
def __setitem__(self, key: str, value: str) -> None:
"""
Set the header `key` to `value`, removing any duplicate entries.
Retains insertion order.
@ -214,7 +215,7 @@ class MutableHeaders(Headers):
else:
self._list.append((set_key, set_value))
def __delitem__(self, key: str):
def __delitem__(self, key: str) -> None:
"""
Remove the header `key`.
"""
@ -228,7 +229,7 @@ class MutableHeaders(Headers):
for idx in reversed(pop_indexes):
del (self._list[idx])
def setdefault(self, key: str, value: str):
def setdefault(self, key: str, value: str) -> str:
"""
If the header `key` does not exist, then set it to `value`.
Returns the header value.

View File

@ -1,8 +1,10 @@
from starlette.datastructures import URL, Headers, QueryParams
import typing
import json
from collections.abc import Mapping
from urllib.parse import unquote
import json
import typing
from starlette.datastructures import URL, Headers, QueryParams
from starlette.types import Scope, Receive
class ClientDisconnect(Exception):
@ -10,19 +12,19 @@ class ClientDisconnect(Exception):
class Request(Mapping):
def __init__(self, scope, receive=None):
def __init__(self, scope: Scope, receive: Receive = None) -> None:
assert scope["type"] == "http"
self._scope = scope
self._receive = receive
self._stream_consumed = False
def __getitem__(self, key):
def __getitem__(self, key: str) -> str:
return self._scope[key]
def __iter__(self):
def __iter__(self) -> typing.Iterator[str]:
return iter(self._scope)
def __len__(self):
def __len__(self) -> int:
return len(self._scope)
@property
@ -61,7 +63,7 @@ class Request(Mapping):
self._query_params = QueryParams(query_string)
return self._query_params
async def stream(self):
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
if hasattr(self, "_body"):
yield self._body
return
@ -82,7 +84,7 @@ class Request(Mapping):
elif message["type"] == "http.disconnect":
raise ClientDisconnect()
async def body(self):
async def body(self) -> bytes:
if not hasattr(self, "_body"):
body = b""
async for chunk in self.stream():
@ -90,7 +92,7 @@ class Request(Mapping):
self._body = body
return self._body
async def json(self):
async def json(self) -> typing.Any:
if not hasattr(self, "_json"):
body = await self.body()
self._json = json.loads(body)