From c0d6434eb79acae879f653aa04616a6572476158 Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Sat, 21 Jul 2018 18:07:16 -0400 Subject: [PATCH] httputil: Type-annotate all methods --- setup.cfg | 3 + tornado/escape.py | 18 +++- tornado/httputil.py | 202 +++++++++++++++++++++++++------------------- 3 files changed, 134 insertions(+), 89 deletions(-) diff --git a/setup.cfg b/setup.cfg index da6ebfab..c5fc0248 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,3 +3,6 @@ python_version = 3.5 [mypy-tornado.util] disallow_untyped_defs = True + +[mypy-tornado.httputil] +disallow_untyped_defs = True diff --git a/tornado/escape.py b/tornado/escape.py index 382133f5..17754b36 100644 --- a/tornado/escape.py +++ b/tornado/escape.py @@ -142,8 +142,22 @@ def parse_qs_bytes(qs, keep_blank_values=False, strict_parsing=False): _UTF8_TYPES = (bytes, type(None)) -def utf8(value): - # type: (typing.Union[bytes,unicode_type,None])->typing.Union[bytes,None] +@typing.overload +def utf8(value: bytes) -> bytes: + pass + + +@typing.overload # noqa: F811 +def utf8(value: str) -> bytes: + pass + + +@typing.overload # noqa: F811 +def utf8(value: None) -> None: + pass + + +def utf8(value): # noqa: F811 """Converts a string argument to a byte string. If the argument is already a byte string or None, it is returned unchanged. diff --git a/tornado/httputil.py b/tornado/httputil.py index 0e498235..72f1eb6b 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -25,7 +25,7 @@ import copy import datetime import email.utils from http.client import responses -import http.cookies as Cookie +import http.cookies import numbers import re from ssl import SSLError @@ -42,7 +42,13 @@ from tornado.util import ObjectDict, unicode_type # Reference it so pyflakes doesn't complain. responses -import typing # noqa: F401 +import typing +from typing import (Tuple, Iterable, List, Mapping, Iterator, Dict, Union, Optional, + Awaitable, Generator) + +if typing.TYPE_CHECKING: + from typing import Deque # noqa + import unittest # noqa # RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line @@ -61,12 +67,12 @@ class _NormalizedHeaderCache(dict): >>> normalized_headers["coNtent-TYPE"] 'Content-Type' """ - def __init__(self, size): + def __init__(self, size: int) -> None: super(_NormalizedHeaderCache, self).__init__() self.size = size - self.queue = collections.deque() + self.queue = collections.deque() # type: Deque[str] - def __missing__(self, key): + def __missing__(self, key: str) -> str: normalized = "-".join([w.capitalize() for w in key.split("-")]) self[key] = normalized self.queue.append(key) @@ -110,7 +116,19 @@ class HTTPHeaders(collections.MutableMapping): Set-Cookie: A=B Set-Cookie: C=D """ - def __init__(self, *args, **kwargs): + @typing.overload + def __init__(self, __arg: Mapping[str, List[str]]) -> None: + pass + + @typing.overload # noqa: F811 + def __init__(self, *args: Tuple[str, str]) -> None: + pass + + @typing.overload # noqa: F811 + def __init__(self, **kwargs: str) -> None: + pass + + def __init__(self, *args: typing.Any, **kwargs: str) -> None: # noqa: F811 self._dict = {} # type: typing.Dict[str, str] self._as_list = {} # type: typing.Dict[str, typing.List[str]] self._last_key = None @@ -125,8 +143,7 @@ class HTTPHeaders(collections.MutableMapping): # new public methods - def add(self, name, value): - # type: (str, str) -> None + def add(self, name: str, value: str) -> None: """Adds a new value for the given key.""" norm_name = _normalized_headers[name] self._last_key = norm_name @@ -137,13 +154,12 @@ class HTTPHeaders(collections.MutableMapping): else: self[norm_name] = value - def get_list(self, name): + def get_list(self, name: str) -> List[str]: """Returns all values for the given header as a list.""" norm_name = _normalized_headers[name] return self._as_list.get(norm_name, []) - def get_all(self): - # type: () -> typing.Iterable[typing.Tuple[str, str]] + def get_all(self) -> Iterable[Tuple[str, str]]: """Returns an iterable of all (name, value) pairs. If a header has multiple values, multiple pairs will be @@ -153,7 +169,7 @@ class HTTPHeaders(collections.MutableMapping): for value in values: yield (name, value) - def parse_line(self, line): + def parse_line(self, line: str) -> None: """Updates the dictionary with a single header line. >>> h = HTTPHeaders() @@ -176,7 +192,7 @@ class HTTPHeaders(collections.MutableMapping): self.add(name, value.strip()) @classmethod - def parse(cls, headers): + def parse(cls, headers: str) -> 'HTTPHeaders': """Returns a dictionary from HTTP header text. >>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n") @@ -197,27 +213,26 @@ class HTTPHeaders(collections.MutableMapping): # MutableMapping abstract method implementations. - def __setitem__(self, name, value): + def __setitem__(self, name: str, value: str) -> None: norm_name = _normalized_headers[name] self._dict[norm_name] = value self._as_list[norm_name] = [value] - def __getitem__(self, name): - # type: (str) -> str + def __getitem__(self, name: str) -> str: return self._dict[_normalized_headers[name]] - def __delitem__(self, name): + def __delitem__(self, name: str) -> None: norm_name = _normalized_headers[name] del self._dict[norm_name] del self._as_list[norm_name] - def __len__(self): + def __len__(self) -> int: return len(self._dict) - def __iter__(self): + def __iter__(self) -> Iterator[typing.Any]: return iter(self._dict) - def copy(self): + def copy(self) -> 'HTTPHeaders': # defined in dict but not in MutableMapping. return HTTPHeaders(self) @@ -226,7 +241,7 @@ class HTTPHeaders(collections.MutableMapping): # the appearance that HTTPHeaders is a single container. __copy__ = copy - def __str__(self): + def __str__(self) -> str: lines = [] for name, value in self.get_all(): lines.append("%s: %s\n" % (name, value)) @@ -327,9 +342,13 @@ class HTTPServerRequest(object): .. versionchanged:: 4.0 Moved from ``tornado.httpserver.HTTPRequest``. """ - def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None, - body=None, host=None, files=None, connection=None, - start_line=None, server_connection=None): + path = None # type: str + query = None # type: str + + def __init__(self, method: str=None, uri: str=None, version: str="HTTP/1.0", + headers: HTTPHeaders=None, body: bytes=None, host: str=None, + files: Dict[str, 'HTTPFile']=None, connection: 'HTTPConnection'=None, + start_line: 'RequestStartLine'=None, server_connection: object=None) -> None: if start_line is not None: method, uri, version = start_line self.method = method @@ -351,16 +370,17 @@ class HTTPServerRequest(object): self._start_time = time.time() self._finish_time = None - self.path, sep, self.query = uri.partition('?') + if uri is not None: + self.path, sep, self.query = uri.partition('?') self.arguments = parse_qs_bytes(self.query, keep_blank_values=True) self.query_arguments = copy.deepcopy(self.arguments) - self.body_arguments = {} + self.body_arguments = {} # type: Dict[str, List[bytes]] @property - def cookies(self): - """A dictionary of Cookie.Morsel objects.""" + def cookies(self) -> Dict[str, http.cookies.Morsel]: + """A dictionary of ``http.cookies.Morsel`` objects.""" if not hasattr(self, "_cookies"): - self._cookies = Cookie.SimpleCookie() + self._cookies = http.cookies.SimpleCookie() if "Cookie" in self.headers: try: parsed = parse_cookie(self.headers["Cookie"]) @@ -377,18 +397,18 @@ class HTTPServerRequest(object): pass return self._cookies - def full_url(self): + def full_url(self) -> str: """Reconstructs the full URL for this request.""" return self.protocol + "://" + self.host + self.uri - def request_time(self): + def request_time(self) -> float: """Returns the amount of time it took for this request to execute.""" if self._finish_time is None: return time.time() - self._start_time else: return self._finish_time - self._start_time - def get_ssl_certificate(self, binary_form=False): + def get_ssl_certificate(self, binary_form: bool=False) -> Union[None, Dict, bytes]: """Returns the client's SSL certificate, if any. To use client certificates, the HTTPServer's @@ -408,12 +428,15 @@ class HTTPServerRequest(object): http://docs.python.org/library/ssl.html#sslsocket-objects """ try: - return self.connection.stream.socket.getpeercert( + if self.connection is None: + return None + # TODO: add a method to HTTPConnection for this so it can work with HTTP/2 + return self.connection.stream.socket.getpeercert( # type: ignore binary_form=binary_form) except SSLError: return None - def _parse_body(self): + def _parse_body(self) -> None: parse_body_arguments( self.headers.get("Content-Type", ""), self.body, self.body_arguments, self.files, @@ -422,7 +445,7 @@ class HTTPServerRequest(object): for k, v in self.body_arguments.items(): self.arguments.setdefault(k, []).extend(v) - def __repr__(self): + def __repr__(self) -> str: attrs = ("protocol", "host", "method", "uri", "version", "remote_ip") args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs]) return "%s(%s)" % (self.__class__.__name__, args) @@ -450,7 +473,8 @@ class HTTPServerConnectionDelegate(object): .. versionadded:: 4.0 """ - def start_request(self, server_conn, request_conn): + def start_request(self, server_conn: object, + request_conn: 'HTTPConnection')-> 'HTTPMessageDelegate': """This method is called by the server when a new request has started. :arg server_conn: is an opaque object representing the long-lived @@ -462,7 +486,7 @@ class HTTPServerConnectionDelegate(object): """ raise NotImplementedError() - def on_close(self, server_conn): + def on_close(self, server_conn: object) -> None: """This method is called when a connection has been closed. :arg server_conn: is a server connection that has previously been @@ -476,7 +500,8 @@ class HTTPMessageDelegate(object): .. versionadded:: 4.0 """ - def headers_received(self, start_line, headers): + def headers_received(self, start_line: Union['RequestStartLine', 'ResponseStartLine'], + headers: HTTPHeaders) -> Optional[Awaitable[None]]: """Called when the HTTP headers have been received and parsed. :arg start_line: a `.RequestStartLine` or `.ResponseStartLine` @@ -491,18 +516,18 @@ class HTTPMessageDelegate(object): """ pass - def data_received(self, chunk): + def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]: """Called when a chunk of data has been received. May return a `.Future` for flow control. """ pass - def finish(self): + def finish(self) -> None: """Called after the last chunk of data has been received.""" pass - def on_connection_close(self): + def on_connection_close(self) -> None: """Called if the connection is closed without finishing the request. If ``headers_received`` is called, either ``finish`` or @@ -516,7 +541,8 @@ class HTTPConnection(object): .. versionadded:: 4.0 """ - def write_headers(self, start_line, headers, chunk=None): + def write_headers(self, start_line: Union['RequestStartLine', 'ResponseStartLine'], + headers: HTTPHeaders, chunk: bytes=None) -> Awaitable[None]: """Write an HTTP header block. :arg start_line: a `.RequestStartLine` or `.ResponseStartLine`. @@ -524,11 +550,10 @@ class HTTPConnection(object): :arg chunk: the first (optional) chunk of data. This is an optimization so that small responses can be written in the same call as their headers. - :arg callback: a callback to be run when the write is complete. The ``version`` field of ``start_line`` is ignored. - Returns a `.Future` if no callback is given. + Returns an awaitable for flow control. .. versionchanged:: 6.0 @@ -536,11 +561,10 @@ class HTTPConnection(object): """ raise NotImplementedError() - def write(self, chunk): + def write(self, chunk: bytes) -> Awaitable[None]: """Writes a chunk of body data. - The callback will be run when the write is complete. If no callback - is given, returns a Future. + Returns an awaitable for flow control. .. versionchanged:: 6.0 @@ -548,13 +572,14 @@ class HTTPConnection(object): """ raise NotImplementedError() - def finish(self): + def finish(self) -> None: """Indicates that the last body data has been written. """ raise NotImplementedError() -def url_concat(url, args): +def url_concat(url: str, args: Union[Dict[str, str], List[Tuple[str, str]], + Tuple[Tuple[str, str], ...]]) -> str: """Concatenate url and arguments regardless of whether url has existing query parameters. @@ -605,7 +630,7 @@ class HTTPFile(ObjectDict): pass -def _parse_request_range(range_header): +def _parse_request_range(range_header: str) -> Optional[Tuple[Optional[int], Optional[int]]]: """Parses a Range header. Returns either ``None`` or tuple ``(start, end)``. @@ -654,7 +679,7 @@ def _parse_request_range(range_header): return (start, end) -def _get_content_range(start, end, total): +def _get_content_range(start: Optional[int], end: Optional[int], total: int) -> str: """Returns a suitable Content-Range header: >>> print(_get_content_range(None, 1, 4)) @@ -669,14 +694,15 @@ def _get_content_range(start, end, total): return "bytes %s-%s/%s" % (start, end, total) -def _int_or_none(val): +def _int_or_none(val: str) -> Optional[int]: val = val.strip() if val == "": return None return int(val) -def parse_body_arguments(content_type, body, arguments, files, headers=None): +def parse_body_arguments(content_type: str, body: bytes, arguments: Dict[str, List[bytes]], + files: Dict[str, HTTPFile], headers: HTTPHeaders=None) -> None: """Parses a form request body. Supports ``application/x-www-form-urlencoded`` and @@ -712,7 +738,8 @@ def parse_body_arguments(content_type, body, arguments, files, headers=None): gen_log.warning("Invalid multipart/form-data: %s", e) -def parse_multipart_form_data(boundary, data, arguments, files): +def parse_multipart_form_data(boundary: bytes, data: bytes, arguments: Dict[str, List[bytes]], + files: Dict[str, HTTPFile]) -> None: """Parses a ``multipart/form-data`` body. The ``boundary`` and ``data`` parameters are both byte strings. @@ -763,7 +790,7 @@ def parse_multipart_form_data(boundary, data, arguments, files): arguments.setdefault(name, []).append(value) -def format_timestamp(ts): +def format_timestamp(ts: Union[numbers.Real, tuple, time.struct_time, datetime.datetime]) -> str: """Formats a timestamp in the format used by HTTP. The argument may be a numeric timestamp as returned by `time.time`, @@ -774,21 +801,21 @@ def format_timestamp(ts): 'Sun, 27 Jan 2013 18:43:20 GMT' """ if isinstance(ts, numbers.Real): - pass + time_float = typing.cast(float, ts) elif isinstance(ts, (tuple, time.struct_time)): - ts = calendar.timegm(ts) + time_float = calendar.timegm(ts) elif isinstance(ts, datetime.datetime): - ts = calendar.timegm(ts.utctimetuple()) + time_float = calendar.timegm(ts.utctimetuple()) else: raise TypeError("unknown timestamp type: %r" % ts) - return email.utils.formatdate(ts, usegmt=True) + return email.utils.formatdate(time_float, usegmt=True) RequestStartLine = collections.namedtuple( 'RequestStartLine', ['method', 'path', 'version']) -def parse_request_start_line(line): +def parse_request_start_line(line: str) -> RequestStartLine: """Returns a (method, path, version) tuple for an HTTP 1.x request line. The response is a `collections.namedtuple`. @@ -812,7 +839,7 @@ ResponseStartLine = collections.namedtuple( 'ResponseStartLine', ['version', 'code', 'reason']) -def parse_response_start_line(line): +def parse_response_start_line(line: str) -> ResponseStartLine: """Returns a (version, code, reason) tuple for an HTTP 1.x response line. The response is a `collections.namedtuple`. @@ -835,7 +862,7 @@ def parse_response_start_line(line): # RFC 2231/5987 format. -def _parseparam(s): +def _parseparam(s: str) -> Generator[str, None, None]: while s[:1] == ';': s = s[1:] end = s.find(';') @@ -848,7 +875,7 @@ def _parseparam(s): s = s[end:] -def _parse_header(line): +def _parse_header(line: str) -> Tuple[str, Dict[str, str]]: r"""Parse a Content-type like header. Return the main content-type and a dictionary of options. @@ -872,18 +899,18 @@ def _parse_header(line): name = p[:i].strip().lower() value = p[i + 1:].strip() params.append((name, native_str(value))) - params = email.utils.decode_params(params) - params.pop(0) # get rid of the dummy again + decoded_params = email.utils.decode_params(params) + decoded_params.pop(0) # get rid of the dummy again pdict = {} - for name, value in params: - value = email.utils.collapse_rfc2231_value(value) + for name, decoded_value in decoded_params: + value = email.utils.collapse_rfc2231_value(decoded_value) if len(value) >= 2 and value[0] == '"' and value[-1] == '"': value = value[1:-1] pdict[name] = value return key, pdict -def _encode_header(key, pdict): +def _encode_header(key: str, pdict: Dict[str, str]) -> str: """Inverse of _parse_header. >>> _encode_header('permessage-deflate', @@ -903,7 +930,7 @@ def _encode_header(key, pdict): return '; '.join(out) -def encode_username_password(username, password): +def encode_username_password(username: Union[str, bytes], password: Union[str, bytes]) -> bytes: """Encodes a username/password pair in the format used by HTTP auth. The return value is a byte string in the form ``username:password``. @@ -918,11 +945,12 @@ def encode_username_password(username, password): def doctests(): + # type: () -> unittest.TestSuite import doctest return doctest.DocTestSuite() -def split_host_and_port(netloc): +def split_host_and_port(netloc: str) -> Tuple[str, Optional[int]]: """Returns ``(host, port)`` tuple from ``netloc``. Returned ``port`` will be ``None`` if not present. @@ -932,14 +960,14 @@ def split_host_and_port(netloc): match = re.match(r'^(.+):(\d+)$', netloc) if match: host = match.group(1) - port = int(match.group(2)) + port = int(match.group(2)) # type: Optional[int] else: host = netloc port = None return (host, port) -def qs_to_qsl(qs): +def qs_to_qsl(qs: Dict[str, List[str]]) -> Iterable[Tuple[str, str]]: """Generator converting a result of ``parse_qs`` back to name-value pairs. .. versionadded:: 5.0 @@ -954,7 +982,7 @@ _QuotePatt = re.compile(r"[\\].") _nulljoin = ''.join -def _unquote_cookie(str): +def _unquote_cookie(s: str) -> str: """Handle double quotes and escaping in cookie values. This method is copied verbatim from the Python 3.5 standard @@ -963,29 +991,29 @@ def _unquote_cookie(str): """ # If there aren't any doublequotes, # then there can't be any special characters. See RFC 2109. - if str is None or len(str) < 2: - return str - if str[0] != '"' or str[-1] != '"': - return str + if s is None or len(s) < 2: + return s + if s[0] != '"' or s[-1] != '"': + return s # We have to assume that we must decode this string. # Down to work. # Remove the "s - str = str[1:-1] + s = s[1:-1] # Check for special sequences. Examples: # \012 --> \n # \" --> " # i = 0 - n = len(str) + n = len(s) res = [] while 0 <= i < n: - o_match = _OctalPatt.search(str, i) - q_match = _QuotePatt.search(str, i) + o_match = _OctalPatt.search(s, i) + q_match = _QuotePatt.search(s, i) if not o_match and not q_match: # Neither matched - res.append(str[i:]) + res.append(s[i:]) break # else: j = k = -1 @@ -994,17 +1022,17 @@ def _unquote_cookie(str): if q_match: k = q_match.start(0) if q_match and (not o_match or k < j): # QuotePatt matched - res.append(str[i:k]) - res.append(str[k + 1]) + res.append(s[i:k]) + res.append(s[k + 1]) i = k + 2 else: # OctalPatt matched - res.append(str[i:j]) - res.append(chr(int(str[j + 1:j + 4], 8))) + res.append(s[i:j]) + res.append(chr(int(s[j + 1:j + 4], 8))) i = j + 4 return _nulljoin(res) -def parse_cookie(cookie): +def parse_cookie(cookie: str) -> Dict[str, str]: """Parse a ``Cookie`` HTTP header into a dict of name/value pairs. This function attempts to mimic browser cookie parsing behavior;