httputil: Type-annotate all methods

This commit is contained in:
Ben Darnell 2018-07-21 18:07:16 -04:00
parent 27a726068b
commit c0d6434eb7
3 changed files with 134 additions and 89 deletions

View File

@ -3,3 +3,6 @@ python_version = 3.5
[mypy-tornado.util]
disallow_untyped_defs = True
[mypy-tornado.httputil]
disallow_untyped_defs = True

View File

@ -142,8 +142,22 @@ def parse_qs_bytes(qs, keep_blank_values=False, strict_parsing=False):
_UTF8_TYPES = (bytes, type(None))
def utf8(value):
# type: (typing.Union[bytes,unicode_type,None])->typing.Union[bytes,None]
@typing.overload
def utf8(value: bytes) -> bytes:
pass
@typing.overload # noqa: F811
def utf8(value: str) -> bytes:
pass
@typing.overload # noqa: F811
def utf8(value: None) -> None:
pass
def utf8(value): # noqa: F811
"""Converts a string argument to a byte string.
If the argument is already a byte string or None, it is returned unchanged.

View File

@ -25,7 +25,7 @@ import copy
import datetime
import email.utils
from http.client import responses
import http.cookies as Cookie
import http.cookies
import numbers
import re
from ssl import SSLError
@ -42,7 +42,13 @@ from tornado.util import ObjectDict, unicode_type
# Reference it so pyflakes doesn't complain.
responses
import typing # noqa: F401
import typing
from typing import (Tuple, Iterable, List, Mapping, Iterator, Dict, Union, Optional,
Awaitable, Generator)
if typing.TYPE_CHECKING:
from typing import Deque # noqa
import unittest # noqa
# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
@ -61,12 +67,12 @@ class _NormalizedHeaderCache(dict):
>>> normalized_headers["coNtent-TYPE"]
'Content-Type'
"""
def __init__(self, size: int) -> None:
    """Create an empty cache.

    :arg int size: stored as ``self.size``.
        NOTE(review): presumably the maximum number of cached entries;
        the code that enforces it is not visible here -- confirm.
    """
    super(_NormalizedHeaderCache, self).__init__()
    self.size = size
    # Keys in the order they were first seen (``__missing__`` appends here).
    self.queue = collections.deque()  # type: Deque[str]
def __missing__(self, key):
def __missing__(self, key: str) -> str:
normalized = "-".join([w.capitalize() for w in key.split("-")])
self[key] = normalized
self.queue.append(key)
@ -110,7 +116,19 @@ class HTTPHeaders(collections.MutableMapping):
Set-Cookie: A=B
Set-Cookie: C=D
"""
def __init__(self, *args, **kwargs):
@typing.overload
def __init__(self, __arg: Mapping[str, List[str]]) -> None:
pass
@typing.overload # noqa: F811
def __init__(self, *args: Tuple[str, str]) -> None:
pass
@typing.overload # noqa: F811
def __init__(self, **kwargs: str) -> None:
pass
def __init__(self, *args: typing.Any, **kwargs: str) -> None: # noqa: F811
self._dict = {} # type: typing.Dict[str, str]
self._as_list = {} # type: typing.Dict[str, typing.List[str]]
self._last_key = None
@ -125,8 +143,7 @@ class HTTPHeaders(collections.MutableMapping):
# new public methods
def add(self, name, value):
# type: (str, str) -> None
def add(self, name: str, value: str) -> None:
"""Adds a new value for the given key."""
norm_name = _normalized_headers[name]
self._last_key = norm_name
@ -137,13 +154,12 @@ class HTTPHeaders(collections.MutableMapping):
else:
self[norm_name] = value
def get_list(self, name: str) -> List[str]:
    """Returns all values for the given header as a list.

    :arg str name: header name; it is normalized before the lookup.
    :returns: the values recorded for the header, or an empty list if
        the header is not present.
    """
    norm_name = _normalized_headers[name]
    return self._as_list.get(norm_name, [])
def get_all(self):
# type: () -> typing.Iterable[typing.Tuple[str, str]]
def get_all(self) -> Iterable[Tuple[str, str]]:
"""Returns an iterable of all (name, value) pairs.
If a header has multiple values, multiple pairs will be
@ -153,7 +169,7 @@ class HTTPHeaders(collections.MutableMapping):
for value in values:
yield (name, value)
def parse_line(self, line):
def parse_line(self, line: str) -> None:
"""Updates the dictionary with a single header line.
>>> h = HTTPHeaders()
@ -176,7 +192,7 @@ class HTTPHeaders(collections.MutableMapping):
self.add(name, value.strip())
@classmethod
def parse(cls, headers):
def parse(cls, headers: str) -> 'HTTPHeaders':
"""Returns a dictionary from HTTP header text.
>>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n")
@ -197,27 +213,26 @@ class HTTPHeaders(collections.MutableMapping):
# MutableMapping abstract method implementations.
def __setitem__(self, name: str, value: str) -> None:
    """Set header ``name`` to the single value ``value``.

    Both internal views are overwritten, so any values previously
    recorded for this header are replaced.
    """
    norm_name = _normalized_headers[name]
    self._dict[norm_name] = value
    self._as_list[norm_name] = [value]
def __getitem__(self, name: str) -> str:
    """Return the value stored in ``self._dict`` for header ``name``.

    The name is normalized before the lookup.

    :raises KeyError: if the header is not present.
    """
    return self._dict[_normalized_headers[name]]
def __delitem__(self, name: str) -> None:
    """Remove header ``name`` (normalized) from both internal views.

    :raises KeyError: if the header is not present.
    """
    norm_name = _normalized_headers[name]
    del self._dict[norm_name]
    del self._as_list[norm_name]
def __len__(self) -> int:
    """Return the number of entries in ``self._dict``."""
    return len(self._dict)
def __iter__(self) -> Iterator[typing.Any]:
    """Iterate over the keys of ``self._dict``."""
    return iter(self._dict)
def copy(self) -> 'HTTPHeaders':
    """Return a new `HTTPHeaders` built from this one."""
    # defined in dict but not in MutableMapping.
    return HTTPHeaders(self)
@ -226,7 +241,7 @@ class HTTPHeaders(collections.MutableMapping):
# the appearance that HTTPHeaders is a single container.
__copy__ = copy
def __str__(self):
def __str__(self) -> str:
lines = []
for name, value in self.get_all():
lines.append("%s: %s\n" % (name, value))
@ -327,9 +342,13 @@ class HTTPServerRequest(object):
.. versionchanged:: 4.0
Moved from ``tornado.httpserver.HTTPRequest``.
"""
def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
body=None, host=None, files=None, connection=None,
start_line=None, server_connection=None):
path = None # type: str
query = None # type: str
def __init__(self, method: str=None, uri: str=None, version: str="HTTP/1.0",
headers: HTTPHeaders=None, body: bytes=None, host: str=None,
files: Dict[str, 'HTTPFile']=None, connection: 'HTTPConnection'=None,
start_line: 'RequestStartLine'=None, server_connection: object=None) -> None:
if start_line is not None:
method, uri, version = start_line
self.method = method
@ -351,16 +370,17 @@ class HTTPServerRequest(object):
self._start_time = time.time()
self._finish_time = None
self.path, sep, self.query = uri.partition('?')
if uri is not None:
self.path, sep, self.query = uri.partition('?')
self.arguments = parse_qs_bytes(self.query, keep_blank_values=True)
self.query_arguments = copy.deepcopy(self.arguments)
self.body_arguments = {}
self.body_arguments = {} # type: Dict[str, List[bytes]]
@property
def cookies(self):
"""A dictionary of Cookie.Morsel objects."""
def cookies(self) -> Dict[str, http.cookies.Morsel]:
"""A dictionary of ``http.cookies.Morsel`` objects."""
if not hasattr(self, "_cookies"):
self._cookies = Cookie.SimpleCookie()
self._cookies = http.cookies.SimpleCookie()
if "Cookie" in self.headers:
try:
parsed = parse_cookie(self.headers["Cookie"])
@ -377,18 +397,18 @@ class HTTPServerRequest(object):
pass
return self._cookies
def full_url(self) -> str:
    """Reconstructs the full URL for this request.

    The result is ``self.protocol``, ``"://"``, ``self.host`` and
    ``self.uri`` joined together.
    """
    return self.protocol + "://" + self.host + self.uri
def request_time(self) -> float:
    """Returns the amount of time it took for this request to execute.

    While ``self._finish_time`` is still ``None`` the elapsed time is
    measured up to the current time; otherwise it is the difference
    between the recorded finish and start times.
    """
    if self._finish_time is None:
        return time.time() - self._start_time
    else:
        return self._finish_time - self._start_time
def get_ssl_certificate(self, binary_form=False):
def get_ssl_certificate(self, binary_form: bool=False) -> Union[None, Dict, bytes]:
"""Returns the client's SSL certificate, if any.
To use client certificates, the HTTPServer's
@ -408,12 +428,15 @@ class HTTPServerRequest(object):
http://docs.python.org/library/ssl.html#sslsocket-objects
"""
try:
return self.connection.stream.socket.getpeercert(
if self.connection is None:
return None
# TODO: add a method to HTTPConnection for this so it can work with HTTP/2
return self.connection.stream.socket.getpeercert( # type: ignore
binary_form=binary_form)
except SSLError:
return None
def _parse_body(self):
def _parse_body(self) -> None:
parse_body_arguments(
self.headers.get("Content-Type", ""), self.body,
self.body_arguments, self.files,
@ -422,7 +445,7 @@ class HTTPServerRequest(object):
for k, v in self.body_arguments.items():
self.arguments.setdefault(k, []).extend(v)
def __repr__(self) -> str:
    """Return ``ClassName(attr=value, ...)`` for a fixed set of attributes."""
    attrs = ("protocol", "host", "method", "uri", "version", "remote_ip")
    args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs])
    return "%s(%s)" % (self.__class__.__name__, args)
@ -450,7 +473,8 @@ class HTTPServerConnectionDelegate(object):
.. versionadded:: 4.0
"""
def start_request(self, server_conn, request_conn):
def start_request(self, server_conn: object,
request_conn: 'HTTPConnection')-> 'HTTPMessageDelegate':
"""This method is called by the server when a new request has started.
:arg server_conn: is an opaque object representing the long-lived
@ -462,7 +486,7 @@ class HTTPServerConnectionDelegate(object):
"""
raise NotImplementedError()
def on_close(self, server_conn):
def on_close(self, server_conn: object) -> None:
"""This method is called when a connection has been closed.
:arg server_conn: is a server connection that has previously been
@ -476,7 +500,8 @@ class HTTPMessageDelegate(object):
.. versionadded:: 4.0
"""
def headers_received(self, start_line, headers):
def headers_received(self, start_line: Union['RequestStartLine', 'ResponseStartLine'],
headers: HTTPHeaders) -> Optional[Awaitable[None]]:
"""Called when the HTTP headers have been received and parsed.
:arg start_line: a `.RequestStartLine` or `.ResponseStartLine`
@ -491,18 +516,18 @@ class HTTPMessageDelegate(object):
"""
pass
def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
    """Called when a chunk of data has been received.

    May return a `.Future` for flow control.

    This base implementation does nothing and returns ``None``.
    """
    pass
def finish(self) -> None:
    """Called after the last chunk of data has been received.

    This base implementation does nothing.
    """
    pass
def on_connection_close(self):
def on_connection_close(self) -> None:
"""Called if the connection is closed without finishing the request.
If ``headers_received`` is called, either ``finish`` or
@ -516,7 +541,8 @@ class HTTPConnection(object):
.. versionadded:: 4.0
"""
def write_headers(self, start_line, headers, chunk=None):
def write_headers(self, start_line: Union['RequestStartLine', 'ResponseStartLine'],
headers: HTTPHeaders, chunk: bytes=None) -> Awaitable[None]:
"""Write an HTTP header block.
:arg start_line: a `.RequestStartLine` or `.ResponseStartLine`.
@ -524,11 +550,10 @@ class HTTPConnection(object):
:arg chunk: the first (optional) chunk of data. This is an optimization
so that small responses can be written in the same call as their
headers.
:arg callback: a callback to be run when the write is complete.
The ``version`` field of ``start_line`` is ignored.
Returns a `.Future` if no callback is given.
Returns an awaitable for flow control.
.. versionchanged:: 6.0
@ -536,11 +561,10 @@ class HTTPConnection(object):
"""
raise NotImplementedError()
def write(self, chunk):
def write(self, chunk: bytes) -> Awaitable[None]:
"""Writes a chunk of body data.
The callback will be run when the write is complete. If no callback
is given, returns a Future.
Returns an awaitable for flow control.
.. versionchanged:: 6.0
@ -548,13 +572,14 @@ class HTTPConnection(object):
"""
raise NotImplementedError()
def finish(self) -> None:
    """Indicates that the last body data has been written.

    :raises NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
def url_concat(url, args):
def url_concat(url: str, args: Union[Dict[str, str], List[Tuple[str, str]],
Tuple[Tuple[str, str], ...]]) -> str:
"""Concatenate url and arguments regardless of whether
url has existing query parameters.
@ -605,7 +630,7 @@ class HTTPFile(ObjectDict):
pass
def _parse_request_range(range_header):
def _parse_request_range(range_header: str) -> Optional[Tuple[Optional[int], Optional[int]]]:
"""Parses a Range header.
Returns either ``None`` or tuple ``(start, end)``.
@ -654,7 +679,7 @@ def _parse_request_range(range_header):
return (start, end)
def _get_content_range(start, end, total):
def _get_content_range(start: Optional[int], end: Optional[int], total: int) -> str:
"""Returns a suitable Content-Range header:
>>> print(_get_content_range(None, 1, 4))
@ -669,14 +694,15 @@ def _get_content_range(start, end, total):
return "bytes %s-%s/%s" % (start, end, total)
def _int_or_none(val: str) -> Optional[int]:
    """Convert ``val`` to an ``int``, treating blank input as missing.

    :arg str val: text to convert; surrounding whitespace is ignored.
    :returns: ``None`` if ``val`` is empty after stripping, otherwise
        ``int(val)``.
    :raises ValueError: if the stripped text is non-empty but is not a
        valid integer literal.
    """
    val = val.strip()
    if val == "":
        return None
    return int(val)
def parse_body_arguments(content_type, body, arguments, files, headers=None):
def parse_body_arguments(content_type: str, body: bytes, arguments: Dict[str, List[bytes]],
files: Dict[str, HTTPFile], headers: HTTPHeaders=None) -> None:
"""Parses a form request body.
Supports ``application/x-www-form-urlencoded`` and
@ -712,7 +738,8 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None):
gen_log.warning("Invalid multipart/form-data: %s", e)
def parse_multipart_form_data(boundary, data, arguments, files):
def parse_multipart_form_data(boundary: bytes, data: bytes, arguments: Dict[str, List[bytes]],
files: Dict[str, HTTPFile]) -> None:
"""Parses a ``multipart/form-data`` body.
The ``boundary`` and ``data`` parameters are both byte strings.
@ -763,7 +790,7 @@ def parse_multipart_form_data(boundary, data, arguments, files):
arguments.setdefault(name, []).append(value)
def format_timestamp(ts):
def format_timestamp(ts: Union[numbers.Real, tuple, time.struct_time, datetime.datetime]) -> str:
"""Formats a timestamp in the format used by HTTP.
The argument may be a numeric timestamp as returned by `time.time`,
@ -774,21 +801,21 @@ def format_timestamp(ts):
'Sun, 27 Jan 2013 18:43:20 GMT'
"""
if isinstance(ts, numbers.Real):
pass
time_float = typing.cast(float, ts)
elif isinstance(ts, (tuple, time.struct_time)):
ts = calendar.timegm(ts)
time_float = calendar.timegm(ts)
elif isinstance(ts, datetime.datetime):
ts = calendar.timegm(ts.utctimetuple())
time_float = calendar.timegm(ts.utctimetuple())
else:
raise TypeError("unknown timestamp type: %r" % ts)
return email.utils.formatdate(ts, usegmt=True)
return email.utils.formatdate(time_float, usegmt=True)
# Named (method, path, version) triple describing an HTTP 1.x request line.
RequestStartLine = collections.namedtuple(
    'RequestStartLine',
    ('method', 'path', 'version'),
)
def parse_request_start_line(line):
def parse_request_start_line(line: str) -> RequestStartLine:
"""Returns a (method, path, version) tuple for an HTTP 1.x request line.
The response is a `collections.namedtuple`.
@ -812,7 +839,7 @@ ResponseStartLine = collections.namedtuple(
'ResponseStartLine', ['version', 'code', 'reason'])
def parse_response_start_line(line):
def parse_response_start_line(line: str) -> ResponseStartLine:
"""Returns a (version, code, reason) tuple for an HTTP 1.x response line.
The response is a `collections.namedtuple`.
@ -835,7 +862,7 @@ def parse_response_start_line(line):
# RFC 2231/5987 format.
def _parseparam(s):
def _parseparam(s: str) -> Generator[str, None, None]:
while s[:1] == ';':
s = s[1:]
end = s.find(';')
@ -848,7 +875,7 @@ def _parseparam(s):
s = s[end:]
def _parse_header(line):
def _parse_header(line: str) -> Tuple[str, Dict[str, str]]:
r"""Parse a Content-type like header.
Return the main content-type and a dictionary of options.
@ -872,18 +899,18 @@ def _parse_header(line):
name = p[:i].strip().lower()
value = p[i + 1:].strip()
params.append((name, native_str(value)))
params = email.utils.decode_params(params)
params.pop(0) # get rid of the dummy again
decoded_params = email.utils.decode_params(params)
decoded_params.pop(0) # get rid of the dummy again
pdict = {}
for name, value in params:
value = email.utils.collapse_rfc2231_value(value)
for name, decoded_value in decoded_params:
value = email.utils.collapse_rfc2231_value(decoded_value)
if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
value = value[1:-1]
pdict[name] = value
return key, pdict
def _encode_header(key, pdict):
def _encode_header(key: str, pdict: Dict[str, str]) -> str:
"""Inverse of _parse_header.
>>> _encode_header('permessage-deflate',
@ -903,7 +930,7 @@ def _encode_header(key, pdict):
return '; '.join(out)
def encode_username_password(username, password):
def encode_username_password(username: Union[str, bytes], password: Union[str, bytes]) -> bytes:
"""Encodes a username/password pair in the format used by HTTP auth.
The return value is a byte string in the form ``username:password``.
@ -918,11 +945,12 @@ def encode_username_password(username, password):
def doctests():
    # type: () -> unittest.TestSuite
    """Return a test suite built by ``doctest.DocTestSuite()`` for this module."""
    # Imported locally so that doctest is only loaded when the suite is requested.
    import doctest
    return doctest.DocTestSuite()
def split_host_and_port(netloc):
def split_host_and_port(netloc: str) -> Tuple[str, Optional[int]]:
"""Returns ``(host, port)`` tuple from ``netloc``.
Returned ``port`` will be ``None`` if not present.
@ -932,14 +960,14 @@ def split_host_and_port(netloc):
match = re.match(r'^(.+):(\d+)$', netloc)
if match:
host = match.group(1)
port = int(match.group(2))
port = int(match.group(2)) # type: Optional[int]
else:
host = netloc
port = None
return (host, port)
def qs_to_qsl(qs):
def qs_to_qsl(qs: Dict[str, List[str]]) -> Iterable[Tuple[str, str]]:
"""Generator converting a result of ``parse_qs`` back to name-value pairs.
.. versionadded:: 5.0
@ -954,7 +982,7 @@ _QuotePatt = re.compile(r"[\\].")
_nulljoin = ''.join
def _unquote_cookie(str):
def _unquote_cookie(s: str) -> str:
"""Handle double quotes and escaping in cookie values.
This method is copied verbatim from the Python 3.5 standard
@ -963,29 +991,29 @@ def _unquote_cookie(str):
"""
# If there aren't any doublequotes,
# then there can't be any special characters. See RFC 2109.
if str is None or len(str) < 2:
return str
if str[0] != '"' or str[-1] != '"':
return str
if s is None or len(s) < 2:
return s
if s[0] != '"' or s[-1] != '"':
return s
# We have to assume that we must decode this string.
# Down to work.
# Remove the "s
str = str[1:-1]
s = s[1:-1]
# Check for special sequences. Examples:
# \012 --> \n
# \" --> "
#
i = 0
n = len(str)
n = len(s)
res = []
while 0 <= i < n:
o_match = _OctalPatt.search(str, i)
q_match = _QuotePatt.search(str, i)
o_match = _OctalPatt.search(s, i)
q_match = _QuotePatt.search(s, i)
if not o_match and not q_match: # Neither matched
res.append(str[i:])
res.append(s[i:])
break
# else:
j = k = -1
@ -994,17 +1022,17 @@ def _unquote_cookie(str):
if q_match:
k = q_match.start(0)
if q_match and (not o_match or k < j): # QuotePatt matched
res.append(str[i:k])
res.append(str[k + 1])
res.append(s[i:k])
res.append(s[k + 1])
i = k + 2
else: # OctalPatt matched
res.append(str[i:j])
res.append(chr(int(str[j + 1:j + 4], 8)))
res.append(s[i:j])
res.append(chr(int(s[j + 1:j + 4], 8)))
i = j + 4
return _nulljoin(res)
def parse_cookie(cookie):
def parse_cookie(cookie: str) -> Dict[str, str]:
"""Parse a ``Cookie`` HTTP header into a dict of name/value pairs.
This function attempts to mimic browser cookie parsing behavior;