*: Minimally type-annotate remaining tests

Streamline setup.cfg. Includes a few code changes necessary to get
tests typechecking.
This commit is contained in:
Ben Darnell 2018-09-16 16:55:56 -04:00
parent fe8d406cb8
commit bc4db0d02c
13 changed files with 89 additions and 160 deletions

102
setup.cfg
View File

@ -4,106 +4,14 @@ license_file = LICENSE
[mypy]
python_version = 3.5
[mypy-tornado.util]
[mypy-tornado.*,tornado.platform.*]
disallow_untyped_defs = True
[mypy-tornado.httputil]
disallow_untyped_defs = True
[mypy-tornado.escape]
disallow_untyped_defs = True
[mypy-tornado.concurrent]
disallow_untyped_defs = True
[mypy-tornado.gen]
disallow_untyped_defs = True
[mypy-tornado.http1connection]
disallow_untyped_defs = True
[mypy-tornado.httpserver]
disallow_untyped_defs = True
[mypy-tornado.ioloop]
disallow_untyped_defs = True
[mypy-tornado.iostream]
disallow_untyped_defs = True
[mypy-tornado.locale]
disallow_untyped_defs = True
[mypy-tornado.log]
disallow_untyped_defs = True
[mypy-tornado.netutil]
disallow_untyped_defs = True
[mypy-tornado.options]
disallow_untyped_defs = True
[mypy-tornado.platform.*]
disallow_untyped_defs = True
[mypy-tornado.tcpclient]
disallow_untyped_defs = True
[mypy-tornado.tcpserver]
disallow_untyped_defs = True
[mypy-tornado.testing]
disallow_untyped_defs = True
[mypy-tornado.auth,tornado.autoreload,tornado.curl_httpclient,tornado.httpclient,tornado.locks,tornado.process,tornado.queues,tornado.routing,tornado.simple_httpclient,tornado.template,tornado.web,tornado.websocket,tornado.wsgi]
disallow_untyped_defs = False
# It's generally too tedious to require type annotations in tests, but
# we do want to type check them as much as type inference allows.
[mypy-tornado.test.util_test]
check_untyped_defs = True
[mypy-tornado.test.httputil_test]
check_untyped_defs = True
[mypy-tornado.test.escape_test]
check_untyped_defs = True
[mypy-tornado.test.asyncio_test]
check_untyped_defs = True
[mypy-tornado.test.concurrent_test]
check_untyped_defs = True
[mypy-tornado.test.gen_test]
check_untyped_defs = True
[mypy-tornado.test.http1connection_test]
check_untyped_defs = True
[mypy-tornado.test.httpserver_test]
check_untyped_defs = True
[mypy-tornado.test.ioloop_test]
check_untyped_defs = True
[mypy-tornado.test.iostream_test]
check_untyped_defs = True
[mypy-tornado.test.locale_test]
check_untyped_defs = True
[mypy-tornado.test.log_test]
check_untyped_defs = True
[mypy-tornado.test.netutil_test]
check_untyped_defs = True
[mypy-tornado.test.options_test]
check_untyped_defs = True
[mypy-tornado.test.tcpclient_test]
check_untyped_defs = True
[mypy-tornado.test.tcpserver_test]
check_untyped_defs = True
[mypy-tornado.test.testing_test]
[mypy-tornado.test.*]
disallow_untyped_defs = False
check_untyped_defs = True

View File

@ -26,7 +26,6 @@ import datetime
import email.utils
from http.client import responses
import http.cookies
import numbers
import re
from ssl import SSLError
import time
@ -794,7 +793,7 @@ def parse_multipart_form_data(boundary: bytes, data: bytes, arguments: Dict[str,
arguments.setdefault(name, []).append(value)
def format_timestamp(ts: Union[numbers.Real, tuple, time.struct_time, datetime.datetime]) -> str:
def format_timestamp(ts: Union[int, float, tuple, time.struct_time, datetime.datetime]) -> str:
"""Formats a timestamp in the format used by HTTP.
The argument may be a numeric timestamp as returned by `time.time`,
@ -804,15 +803,15 @@ def format_timestamp(ts: Union[numbers.Real, tuple, time.struct_time, datetime.d
>>> format_timestamp(1359312200)
'Sun, 27 Jan 2013 18:43:20 GMT'
"""
if isinstance(ts, numbers.Real):
time_float = typing.cast(float, ts)
if isinstance(ts, (int, float)):
time_num = ts
elif isinstance(ts, (tuple, time.struct_time)):
time_float = calendar.timegm(ts)
time_num = calendar.timegm(ts)
elif isinstance(ts, datetime.datetime):
time_float = calendar.timegm(ts.utctimetuple())
time_num = calendar.timegm(ts.utctimetuple())
else:
raise TypeError("unknown timestamp type: %r" % ts)
return email.utils.formatdate(time_float, usegmt=True)
return email.utils.formatdate(time_num, usegmt=True)
RequestStartLine = collections.namedtuple(

View File

@ -415,7 +415,8 @@ class Semaphore(_TimeoutGarbageCollector):
"Use Semaphore like 'with (yield semaphore.acquire())', not like"
" 'with semaphore'")
__exit__ = __enter__
def __exit__(self, typ, value, traceback):
self.__enter__()
@gen.coroutine
def __aenter__(self):
@ -513,7 +514,8 @@ class Lock(object):
raise RuntimeError(
"Use Lock like 'with (yield lock)', not like 'with lock'")
__exit__ = __enter__
def __exit__(self, typ, value, tb):
self.__enter__()
@gen.coroutine
def __aenter__(self):

View File

@ -7,6 +7,7 @@ import threading
import datetime
from io import BytesIO
import time
import typing # noqa: F401
import unicodedata
import unittest
@ -161,7 +162,7 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
def test_streaming_callback(self):
# streaming_callback is also tested in test_chunked
chunks = []
chunks = [] # type: typing.List[bytes]
response = self.fetch("/hello",
streaming_callback=chunks.append)
# with streaming_callback, data goes to the callback and not response.body
@ -178,7 +179,7 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase):
response = self.fetch("/chunk")
self.assertEqual(response.body, b"asdfqwer")
chunks = []
chunks = [] # type: typing.List[bytes]
response = self.fetch("/chunk",
streaming_callback=chunks.append)
self.assertEqual(chunks, [b"asdf", b"qwer"])
@ -209,7 +210,7 @@ Transfer-Encoding: chunked
""".replace(b"\n", b"\r\n"))
stream.close()
netutil.add_accept_handler(sock, accept_callback)
netutil.add_accept_handler(sock, accept_callback) # type: ignore
resp = self.fetch("http://127.0.0.1:%d/" % port)
resp.rethrow()
self.assertEqual(resp.body, b"12")
@ -374,7 +375,7 @@ X-XSS-Protection: 1;
""".replace(b"\n", b"\r\n"))
stream.close()
netutil.add_accept_handler(sock, accept_callback)
netutil.add_accept_handler(sock, accept_callback) # type: ignore
resp = self.fetch("http://127.0.0.1:%d/" % port)
resp.rethrow()
self.assertEqual(resp.headers['X-XSS-Protection'], "1; mode=block")

View File

@ -11,6 +11,7 @@
# under the License.
from datetime import timedelta
import typing # noqa: F401
import unittest
from tornado import gen, locks
@ -21,7 +22,7 @@ from tornado.testing import gen_test, AsyncTestCase
class ConditionTest(AsyncTestCase):
def setUp(self):
super(ConditionTest, self).setUp()
self.history = []
self.history = [] # type: typing.List[typing.Union[int, str]]
def record_done(self, future, key):
"""Record the resolution of a Future returned by Condition.wait."""
@ -97,7 +98,7 @@ class ConditionTest(AsyncTestCase):
# Callbacks execute in the order they were registered.
self.assertEqual(
list(range(4)) + ['notify_all'],
list(range(4)) + ['notify_all'], # type: ignore
self.history)
@gen_test

View File

@ -208,7 +208,7 @@ class SubprocessTest(AsyncTestCase):
stdout=Subprocess.STREAM)
self.addCleanup(subproc.stdout.close)
subproc.set_exit_callback(self.stop)
os.kill(subproc.pid, signal.SIGTERM)
os.kill(subproc.pid, signal.SIGTERM) # type: ignore
try:
ret = self.wait(timeout=1.0)
except AssertionError:

View File

@ -105,7 +105,7 @@ SecondHandler = _get_named_handler("second_handler")
class CustomRouter(ReversibleRouter):
def __init__(self):
super(CustomRouter, self).__init__()
self.routes = {}
self.routes = {} # type: typing.Dict[str, typing.Any]
def add_routes(self, routes):
self.routes.update(routes)
@ -122,11 +122,12 @@ class CustomRouter(ReversibleRouter):
class CustomRouterTestCase(AsyncHTTPTestCase):
def get_app(self):
router = CustomRouter()
class CustomApplication(Application):
def reverse_url(self, name, *args):
return router.reverse_url(name, *args)
router = CustomRouter()
app1 = CustomApplication(app_name="app1")
app2 = CustomApplication(app_name="app2")

View File

@ -61,7 +61,8 @@ def all():
def test_runner_factory(stderr):
class TornadoTextTestRunner(unittest.TextTestRunner):
def __init__(self, *args, **kwargs):
super(TornadoTextTestRunner, self).__init__(*args, stream=stderr, **kwargs)
kwargs['stream'] = stderr
super(TornadoTextTestRunner, self).__init__(*args, **kwargs)
def run(self, test):
result = super(TornadoTextTestRunner, self).run(test)
@ -156,8 +157,10 @@ def main():
"e.g. DEBUG_STATS or DEBUG_COLLECTABLE,DEBUG_OBJECTS",
callback=lambda values: gc.set_debug(
reduce(operator.or_, (getattr(gc, v) for v in values))))
define('locale', type=str, default=None,
callback=lambda x: locale.setlocale(locale.LC_ALL, x))
def set_locale(x):
locale.setlocale(locale.LC_ALL, x)
define('locale', type=str, default=None, callback=set_locale)
log_counter = LogCounter()
add_parse_callback(
@ -167,7 +170,8 @@ def main():
# destructors) go directly to stderr instead of logging. Count
# anything written by anything but the test runner as an error.
orig_stderr = sys.stderr
sys.stderr = CountingStderr(orig_stderr)
counting_stderr = CountingStderr(orig_stderr)
sys.stderr = counting_stderr # type: ignore
import tornado.testing
kwargs = {}
@ -188,10 +192,10 @@ def main():
if (log_counter.info_count > 0 or
log_counter.warning_count > 0 or
log_counter.error_count > 0 or
sys.stderr.byte_count > 0):
counting_stderr.byte_count > 0):
logging.error("logged %d infos, %d warnings, %d errors, and %d bytes to stderr",
log_counter.info_count, log_counter.warning_count,
log_counter.error_count, sys.stderr.byte_count)
log_counter.error_count, counting_stderr.byte_count)
sys.exit(1)

View File

@ -8,6 +8,7 @@ import re
import socket
import ssl
import sys
import typing # noqa: F401
from tornado.escape import to_unicode, utf8
from tornado import gen
@ -135,7 +136,7 @@ class RespondInPrepareHandler(RequestHandler):
class SimpleHTTPClientTestMixin(object):
def get_app(self):
# callable objects to finish pending /trigger requests
self.triggers = collections.deque()
self.triggers = collections.deque() # type: typing.Deque[str]
return Application([
url("/trigger", TriggerHandler, dict(queue=self.triggers,
wake_callback=self.stop)),
@ -165,8 +166,11 @@ class SimpleHTTPClientTestMixin(object):
SimpleAsyncHTTPClient(force_instance=True))
# different IOLoops use different objects
with closing(IOLoop()) as io_loop2:
client1 = self.io_loop.run_sync(gen.coroutine(SimpleAsyncHTTPClient))
client2 = io_loop2.run_sync(gen.coroutine(SimpleAsyncHTTPClient))
async def make_client():
await gen.sleep(0)
return SimpleAsyncHTTPClient()
client1 = self.io_loop.run_sync(make_client)
client2 = io_loop2.run_sync(make_client)
self.assertTrue(client1 is not client2)
def test_connection_limit(self):
@ -176,8 +180,10 @@ class SimpleHTTPClientTestMixin(object):
# Send 4 requests. Two can be sent immediately, while the others
# will be queued
for i in range(4):
client.fetch(self.get_url("/trigger")).add_done_callback(
lambda fut, i=i: (seen.append(i), self.stop()))
def cb(fut, i=i):
seen.append(i)
self.stop()
client.fetch(self.get_url("/trigger")).add_done_callback(cb)
self.wait(condition=lambda: len(self.triggers) == 2)
self.assertEqual(len(client.queue), 2)
@ -273,7 +279,7 @@ class SimpleHTTPClientTestMixin(object):
@skipIfNoIPv6
def test_ipv6(self):
[sock] = bind_sockets(None, '::1', family=socket.AF_INET6)
[sock] = bind_sockets(0, '::1', family=socket.AF_INET6)
port = sock.getsockname()[1]
self.http_server.add_socket(sock)
url = '%s://[::1]:%d/hello' % (self.get_protocol(), port)
@ -339,7 +345,7 @@ class SimpleHTTPClientTestMixin(object):
# cygwin returns EPERM instead of ECONNREFUSED here
contains_errno = str(errno.ECONNREFUSED) in str(cm.exception)
if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
contains_errno = str(errno.WSAECONNREFUSED) in str(cm.exception)
contains_errno = str(errno.WSAECONNREFUSED) in str(cm.exception) # type: ignore
self.assertTrue(contains_errno, cm.exception)
# This is usually "Connection refused".
# On windows, strerror is broken and returns "Unknown error".
@ -447,12 +453,12 @@ class SimpleHTTPClientTestMixin(object):
# simple_httpclient_test, but it fails with the version of libcurl
# available on travis-ci. Move it when that has been upgraded
# or we have a better framework to skip tests based on curl version.
headers = []
chunks = []
headers = [] # type: typing.List[str]
chunk_bytes = [] # type: typing.List[bytes]
self.fetch("/redirect?url=/hello",
header_callback=headers.append,
streaming_callback=chunks.append)
chunks = list(map(to_unicode, chunks))
streaming_callback=chunk_bytes.append)
chunks = list(map(to_unicode, chunk_bytes))
self.assertEqual(chunks, ['Hello world!'])
# Make sure we only got one set of headers.
num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
@ -524,22 +530,22 @@ class CreateAsyncHTTPClientTestCase(AsyncTestCase):
def test_max_clients(self):
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
with closing(AsyncHTTPClient(force_instance=True)) as client:
self.assertEqual(client.max_clients, 10)
self.assertEqual(client.max_clients, 10) # type: ignore
with closing(AsyncHTTPClient(
max_clients=11, force_instance=True)) as client:
self.assertEqual(client.max_clients, 11)
self.assertEqual(client.max_clients, 11) # type: ignore
# Now configure max_clients statically and try overriding it
# with each way max_clients can be passed
AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
with closing(AsyncHTTPClient(force_instance=True)) as client:
self.assertEqual(client.max_clients, 12)
self.assertEqual(client.max_clients, 12) # type: ignore
with closing(AsyncHTTPClient(
max_clients=13, force_instance=True)) as client:
self.assertEqual(client.max_clients, 13)
self.assertEqual(client.max_clients, 13) # type: ignore
with closing(AsyncHTTPClient(
max_clients=14, force_instance=True)) as client:
self.assertEqual(client.max_clients, 14)
self.assertEqual(client.max_clients, 14) # type: ignore
class HTTP100ContinueTestCase(AsyncHTTPTestCase):

View File

@ -6,6 +6,8 @@ from tornado.escape import utf8, native_str, to_unicode
from tornado.template import Template, DictLoader, ParseError, Loader
from tornado.util import ObjectDict
import typing # noqa: F401
class TemplateTest(unittest.TestCase):
def test_simple(self):
@ -198,9 +200,10 @@ three{%end%}
self.assertTrue("# test.html:2" in traceback.format_exc())
def test_error_line_number_module(self):
loader = None
loader = None # type: typing.Optional[DictLoader]
def load_generate(path, **kwargs):
assert loader is not None
return loader.load(path).generate(**kwargs)
loader = DictLoader({

View File

@ -4,6 +4,7 @@ import platform
import socket
import sys
import textwrap
import typing # noqa: F401
import unittest
import warnings
@ -65,7 +66,7 @@ def refusing_port():
# ephemeral port number to ensure that nothing can listen on that
# port.
server_socket, port = bind_unused_port()
server_socket.setblocking(1)
server_socket.setblocking(True)
client_socket = socket.socket()
client_socket.connect(("127.0.0.1", port))
conn, client_addr = server_socket.accept()
@ -84,7 +85,7 @@ def exec_test(caller_globals, caller_locals, s):
# globals: it's all global from the perspective of code defined
# in s.
global_namespace = dict(caller_globals, **caller_locals) # type: ignore
local_namespace = {}
local_namespace = {} # type: typing.Dict[str, typing.Any]
exec(textwrap.dedent(s), global_namespace, local_namespace)
return local_namespace

View File

@ -30,6 +30,7 @@ import logging
import os
import re
import socket
import typing # noqa: F401
import unittest
import urllib.parse
@ -73,7 +74,7 @@ class CookieTestRequestHandler(RequestHandler):
# stub out enough methods to make the secure_cookie functions work
def __init__(self, cookie_secret='0123456789', key_version=None):
# don't call super.__init__
self._cookies = {}
self._cookies = {} # type: typing.Dict[str, bytes]
if key_version is None:
self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret))
else:
@ -102,7 +103,7 @@ class SecureCookieV1Test(unittest.TestCase):
version=1)
cookie = handler._cookies['foo']
match = re.match(br'12345678\|([0-9]+)\|([0-9a-f]+)', cookie)
self.assertTrue(match)
assert match is not None
timestamp = match.group(1)
sig = match.group(2)
self.assertEqual(
@ -335,11 +336,12 @@ class CookieTest(WebTestCase):
response = self.fetch("/set_expires_days")
header = response.headers.get("Set-Cookie")
match = re.match("foo=bar; expires=(?P<expires>.+); Path=/", header)
self.assertIsNotNone(match)
assert match is not None
expires = datetime.datetime.utcnow() + datetime.timedelta(days=10)
header_expires = datetime.datetime(
*email.utils.parsedate(match.groupdict()["expires"])[:6])
parsed = email.utils.parsedate(match.groupdict()["expires"])
assert parsed is not None
header_expires = datetime.datetime(*parsed[:6])
self.assertTrue(abs((expires - header_expires).total_seconds()) < 10)
def test_set_cookie_false_flags(self):
@ -491,7 +493,7 @@ class RequestEncodingTest(WebTestCase):
class TypeCheckHandler(RequestHandler):
def prepare(self):
self.errors = {}
self.errors = {} # type: typing.Dict[str, str]
self.check_type('status', self.get_status(), int)
@ -1515,8 +1517,9 @@ class DateHeaderTest(SimpleHandlerTestCase):
def test_date_header(self):
response = self.fetch('/')
header_date = datetime.datetime(
*email.utils.parsedate(response.headers['Date'])[:6])
parsed = email.utils.parsedate(response.headers['Date'])
assert parsed is not None
header_date = datetime.datetime(*parsed[:6])
self.assertTrue(header_date - datetime.datetime.utcnow() <
datetime.timedelta(seconds=2))
@ -2085,9 +2088,9 @@ class StreamingRequestBodyTest(WebTestCase):
@gen_test
def test_streaming_body(self):
self.prepared = Future()
self.data = Future()
self.finished = Future()
self.prepared = Future() # type: Future[None]
self.data = Future() # type: Future[bytes]
self.finished = Future() # type: Future[None]
stream = self.connect(b"/stream_body", connection_close=True)
yield self.prepared
@ -2121,7 +2124,7 @@ class StreamingRequestBodyTest(WebTestCase):
@gen_test
def test_close_during_upload(self):
self.close_future = Future()
self.close_future = Future() # type: Future[None]
stream = self.connect(b"/close_detection", connection_close=False)
stream.close()
yield self.close_future
@ -2136,7 +2139,7 @@ class BaseFlowControlHandler(RequestHandler):
def initialize(self, test):
self.test = test
self.method = None
self.methods = []
self.methods = [] # type: typing.List[str]
@contextlib.contextmanager
def in_method(self, method):

View File

@ -199,7 +199,7 @@ class WebSocketBaseTestCase(AsyncHTTPTestCase):
class WebSocketTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future()
self.close_future = Future() # type: Future[None]
return Application([
('/echo', EchoHandler, dict(close_future=self.close_future)),
('/non_ws', NonWebSocketHandler),
@ -525,7 +525,7 @@ class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future()
self.close_future = Future() # type: Future[None]
return Application([
('/native', NativeCoroutineOnMessageHandler,
dict(close_future=self.close_future))])
@ -546,7 +546,7 @@ class CompressionTestMixin(object):
MESSAGE = 'Hello world. Testing 123 123'
def get_app(self):
self.close_future = Future()
self.close_future = Future() # type: Future[None]
class LimitedHandler(TestWebSocketHandler):
@property
@ -677,7 +677,7 @@ class ServerPeriodicPingTest(WebSocketBaseTestCase):
def on_pong(self, data):
self.write_message("got pong")
self.close_future = Future()
self.close_future = Future() # type: Future[None]
return Application([
('/', PingHandler, dict(close_future=self.close_future)),
], websocket_ping_interval=0.01)
@ -698,7 +698,7 @@ class ClientPeriodicPingTest(WebSocketBaseTestCase):
def on_ping(self, data):
self.write_message("got ping")
self.close_future = Future()
self.close_future = Future() # type: Future[None]
return Application([
('/', PingHandler, dict(close_future=self.close_future)),
])
@ -719,7 +719,7 @@ class ManualPingTest(WebSocketBaseTestCase):
def on_ping(self, data):
self.write_message(data, binary=isinstance(data, bytes))
self.close_future = Future()
self.close_future = Future() # type: Future[None]
return Application([
('/', PingHandler, dict(close_future=self.close_future)),
])
@ -743,7 +743,7 @@ class ManualPingTest(WebSocketBaseTestCase):
class MaxMessageSizeTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future()
self.close_future = Future() # type: Future[None]
return Application([
('/', EchoHandler, dict(close_future=self.close_future)),
], websocket_max_message_size=1024)