mirror of https://github.com/encode/starlette.git
Missing annotations added datastructures & requests modules (#60)
* fix(dataStructures): missing annotations added * fix(requests.py): annotations added * fix(Annotations): wrong annotations fixed on datastructures and requests modules.
This commit is contained in:
parent
7dc07267c2
commit
7eafe567df
|
@ -1,65 +1,66 @@
|
|||
import typing
|
||||
from urllib.parse import parse_qsl, urlparse
|
||||
from urllib.parse import parse_qsl, urlparse, ParseResult
|
||||
|
||||
|
||||
class URL(str):
|
||||
@property
|
||||
def components(self):
|
||||
def components(self) -> ParseResult:
|
||||
if not hasattr(self, "_components"):
|
||||
self._components = urlparse(self)
|
||||
return self._components
|
||||
|
||||
@property
|
||||
def scheme(self):
|
||||
def scheme(self) -> str:
|
||||
return self.components.scheme
|
||||
|
||||
@property
|
||||
def netloc(self):
|
||||
def netloc(self) -> str:
|
||||
return self.components.netloc
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
def path(self) -> str:
|
||||
return self.components.path
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
def params(self) -> str:
|
||||
return self.components.params
|
||||
|
||||
@property
|
||||
def query(self):
|
||||
def query(self) -> str:
|
||||
return self.components.query
|
||||
|
||||
@property
|
||||
def fragment(self):
|
||||
def fragment(self) -> str:
|
||||
return self.components.fragment
|
||||
|
||||
@property
|
||||
def username(self):
|
||||
def username(self) -> typing.Union[None, str]:
|
||||
return self.components.username
|
||||
|
||||
@property
|
||||
def password(self):
|
||||
def password(self) -> typing.Union[None, str]:
|
||||
return self.components.password
|
||||
|
||||
@property
|
||||
def hostname(self):
|
||||
def hostname(self) -> typing.Union[None, str]:
|
||||
return self.components.hostname
|
||||
|
||||
@property
|
||||
def port(self):
|
||||
def port(self) -> typing.Optional[int]:
|
||||
return self.components.port
|
||||
|
||||
def replace(self, **kwargs):
|
||||
def replace(self, **kwargs: typing.Any) -> "URL": # type: ignore
|
||||
components = self.components._replace(**kwargs)
|
||||
return URL(components.geturl())
|
||||
|
||||
|
||||
# Type annotations for valid `__init__` values to QueryParams and Headers.
|
||||
StrPairs = typing.Sequence[typing.Tuple[str, str]]
|
||||
BytesPairs = typing.List[typing.Tuple[bytes, bytes]]
|
||||
StrDict = typing.Mapping[str, str]
|
||||
|
||||
|
||||
class QueryParams(typing.Mapping[str, str]):
|
||||
class QueryParams(StrDict):
|
||||
"""
|
||||
An immutable multidict.
|
||||
"""
|
||||
|
@ -79,42 +80,42 @@ class QueryParams(typing.Mapping[str, str]):
|
|||
self._dict = {k: v for k, v in reversed(items)}
|
||||
self._list = items
|
||||
|
||||
def getlist(self, key: str) -> typing.List[str]:
|
||||
def getlist(self, key: typing.Any) -> typing.List[str]:
|
||||
return [item_value for item_key, item_value in self._list if item_key == key]
|
||||
|
||||
def keys(self):
|
||||
def keys(self) -> typing.List[str]: # type: ignore
|
||||
return [key for key, value in self._list]
|
||||
|
||||
def values(self):
|
||||
def values(self) -> typing.List[str]: # type: ignore
|
||||
return [value for key, value in self._list]
|
||||
|
||||
def items(self):
|
||||
def items(self) -> StrPairs: # type: ignore
|
||||
return list(self._list)
|
||||
|
||||
def get(self, key, default=None):
|
||||
def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
|
||||
if key in self._dict:
|
||||
return self._dict[key]
|
||||
else:
|
||||
return default
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: typing.Any) -> str:
|
||||
return self._dict[key]
|
||||
|
||||
def __contains__(self, key):
|
||||
def __contains__(self, key: typing.Any) -> bool:
|
||||
return key in self._dict
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> typing.Iterator[typing.Any]:
|
||||
return iter(self._list)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self._list)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
if not isinstance(other, QueryParams):
|
||||
other = QueryParams(other)
|
||||
return sorted(self._list) == sorted(other._list)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "QueryParams(%s)" % repr(self._list)
|
||||
|
||||
|
||||
|
@ -123,9 +124,9 @@ class Headers(typing.Mapping[str, str]):
|
|||
An immutable, case-insensitive multidict.
|
||||
"""
|
||||
|
||||
def __init__(self, raw_headers=None) -> None:
|
||||
def __init__(self, raw_headers: typing.Optional[BytesPairs] = None) -> None:
|
||||
if raw_headers is None:
|
||||
self._list = []
|
||||
self._list = [] # type: BytesPairs
|
||||
else:
|
||||
for header_key, header_value in raw_headers:
|
||||
assert isinstance(header_key, bytes)
|
||||
|
@ -133,19 +134,19 @@ class Headers(typing.Mapping[str, str]):
|
|||
assert header_key == header_key.lower()
|
||||
self._list = raw_headers
|
||||
|
||||
def keys(self):
|
||||
def keys(self) -> typing.List[str]: # type: ignore
|
||||
return [key.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def values(self):
|
||||
def values(self) -> typing.List[str]: # type: ignore
|
||||
return [value.decode("latin-1") for key, value in self._list]
|
||||
|
||||
def items(self):
|
||||
def items(self) -> StrPairs: # type: ignore
|
||||
return [
|
||||
(key.decode("latin-1"), value.decode("latin-1"))
|
||||
for key, value in self._list
|
||||
]
|
||||
|
||||
def get(self, key: str, default: str = None):
|
||||
def get(self, key: str, default: typing.Any = None) -> typing.Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
|
@ -159,40 +160,40 @@ class Headers(typing.Mapping[str, str]):
|
|||
if item_key == get_header_key
|
||||
]
|
||||
|
||||
def mutablecopy(self):
|
||||
def mutablecopy(self) -> "MutableHeaders":
|
||||
return MutableHeaders(self._list[:])
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
def __getitem__(self, key: str) -> str:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
for header_key, header_value in self._list:
|
||||
if header_key == get_header_key:
|
||||
return header_value.decode("latin-1")
|
||||
raise KeyError(key)
|
||||
|
||||
def __contains__(self, key: str):
|
||||
def __contains__(self, key: typing.Any) -> bool:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
for header_key, header_value in self._list:
|
||||
if header_key == get_header_key:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> typing.Iterator[typing.Any]:
|
||||
return iter(self.items())
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self._list)
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
if not isinstance(other, Headers):
|
||||
return False
|
||||
return sorted(self._list) == sorted(other._list)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return "%s(%s)" % (self.__class__.__name__, repr(self.items()))
|
||||
|
||||
|
||||
class MutableHeaders(Headers):
|
||||
def __setitem__(self, key: str, value: str):
|
||||
def __setitem__(self, key: str, value: str) -> None:
|
||||
"""
|
||||
Set the header `key` to `value`, removing any duplicate entries.
|
||||
Retains insertion order.
|
||||
|
@ -214,7 +215,7 @@ class MutableHeaders(Headers):
|
|||
else:
|
||||
self._list.append((set_key, set_value))
|
||||
|
||||
def __delitem__(self, key: str):
|
||||
def __delitem__(self, key: str) -> None:
|
||||
"""
|
||||
Remove the header `key`.
|
||||
"""
|
||||
|
@ -228,7 +229,7 @@ class MutableHeaders(Headers):
|
|||
for idx in reversed(pop_indexes):
|
||||
del (self._list[idx])
|
||||
|
||||
def setdefault(self, key: str, value: str):
|
||||
def setdefault(self, key: str, value: str) -> str:
|
||||
"""
|
||||
If the header `key` does not exist, then set it to `value`.
|
||||
Returns the header value.
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from starlette.datastructures import URL, Headers, QueryParams
|
||||
import typing
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from urllib.parse import unquote
|
||||
import json
|
||||
import typing
|
||||
|
||||
from starlette.datastructures import URL, Headers, QueryParams
|
||||
from starlette.types import Scope, Receive
|
||||
|
||||
|
||||
class ClientDisconnect(Exception):
|
||||
|
@ -10,19 +12,19 @@ class ClientDisconnect(Exception):
|
|||
|
||||
|
||||
class Request(Mapping):
|
||||
def __init__(self, scope, receive=None):
|
||||
def __init__(self, scope: Scope, receive: Receive = None) -> None:
|
||||
assert scope["type"] == "http"
|
||||
self._scope = scope
|
||||
self._receive = receive
|
||||
self._stream_consumed = False
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> str:
|
||||
return self._scope[key]
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> typing.Iterator[str]:
|
||||
return iter(self._scope)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self._scope)
|
||||
|
||||
@property
|
||||
|
@ -61,7 +63,7 @@ class Request(Mapping):
|
|||
self._query_params = QueryParams(query_string)
|
||||
return self._query_params
|
||||
|
||||
async def stream(self):
|
||||
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
|
||||
if hasattr(self, "_body"):
|
||||
yield self._body
|
||||
return
|
||||
|
@ -82,7 +84,7 @@ class Request(Mapping):
|
|||
elif message["type"] == "http.disconnect":
|
||||
raise ClientDisconnect()
|
||||
|
||||
async def body(self):
|
||||
async def body(self) -> bytes:
|
||||
if not hasattr(self, "_body"):
|
||||
body = b""
|
||||
async for chunk in self.stream():
|
||||
|
@ -90,7 +92,7 @@ class Request(Mapping):
|
|||
self._body = body
|
||||
return self._body
|
||||
|
||||
async def json(self):
|
||||
async def json(self) -> typing.Any:
|
||||
if not hasattr(self, "_json"):
|
||||
body = await self.body()
|
||||
self._json = json.loads(body)
|
||||
|
|
Loading…
Reference in New Issue