bpo-44011: Revert "New asyncio ssl implementation (GH-17975)" (GH-25848)

This reverts commit 5fb06edbbb and all
subsequent dependent commits.
This commit is contained in:
Pablo Galindo 2021-05-03 16:21:59 +01:00 committed by GitHub
parent 39494285e1
commit 7719953b30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 524 additions and 2477 deletions

View File

@ -273,7 +273,7 @@ async def restore(self):
class Server(events.AbstractServer):
def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
ssl_handshake_timeout, ssl_shutdown_timeout=None):
ssl_handshake_timeout):
self._loop = loop
self._sockets = sockets
self._active_count = 0
@ -282,7 +282,6 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
self._backlog = backlog
self._ssl_context = ssl_context
self._ssl_handshake_timeout = ssl_handshake_timeout
self._ssl_shutdown_timeout = ssl_shutdown_timeout
self._serving = False
self._serving_forever_fut = None
@ -314,8 +313,7 @@ def _start_serving(self):
sock.listen(self._backlog)
self._loop._start_serving(
self._protocol_factory, sock, self._ssl_context,
self, self._backlog, self._ssl_handshake_timeout,
self._ssl_shutdown_timeout)
self, self._backlog, self._ssl_handshake_timeout)
def get_loop(self):
return self._loop
@ -469,7 +467,6 @@ def _make_ssl_transport(
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
call_connection_made=True):
"""Create SSL transport."""
raise NotImplementedError
@ -972,7 +969,6 @@ async def create_connection(
proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
"""Connect to a TCP server.
@ -1008,10 +1004,6 @@ async def create_connection(
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')
if happy_eyeballs_delay is not None and interleave is None:
# If using happy eyeballs, default to interleave addresses by family
interleave = 1
@ -1087,8 +1079,7 @@ async def create_connection(
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
ssl_handshake_timeout=ssl_handshake_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket
@ -1100,8 +1091,7 @@ async def create_connection(
async def _create_connection_transport(
self, sock, protocol_factory, ssl,
server_hostname, server_side=False,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_handshake_timeout=None):
sock.setblocking(False)
@ -1112,8 +1102,7 @@ async def _create_connection_transport(
transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
ssl_handshake_timeout=ssl_handshake_timeout)
else:
transport = self._make_socket_transport(sock, protocol, waiter)
@ -1204,8 +1193,7 @@ async def _sendfile_fallback(self, transp, file, offset, count):
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_handshake_timeout=None):
"""Upgrade transport to TLS.
Return a new transport that *protocol* should start using
@ -1228,7 +1216,6 @@ async def start_tls(self, transport, protocol, sslcontext, *,
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout,
call_connection_made=False)
# Pause early so that "ssl_protocol.data_received()" doesn't
@ -1427,7 +1414,6 @@ async def create_server(
reuse_address=None,
reuse_port=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
"""Create a TCP server.
@ -1451,10 +1437,6 @@ async def create_server(
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
if ssl_shutdown_timeout is not None and ssl is None:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')
if host is not None or port is not None:
if sock is not None:
raise ValueError(
@ -1527,8 +1509,7 @@ async def create_server(
sock.setblocking(False)
server = Server(self, sockets, protocol_factory,
ssl, backlog, ssl_handshake_timeout,
ssl_shutdown_timeout)
ssl, backlog, ssl_handshake_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'
@ -1542,8 +1523,7 @@ async def create_server(
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_handshake_timeout=None):
if sock.type != socket.SOCK_STREAM:
raise ValueError(f'A Stream Socket was expected, got {sock!r}')
@ -1551,14 +1531,9 @@ async def connect_accepted_socket(
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
ssl_handshake_timeout=ssl_handshake_timeout)
if self._debug:
# Get the socket from the transport because SSL transport closes
# the old socket and creates a new SSL socket

View File

@ -15,17 +15,10 @@
# The default timeout matches that of Nginx.
SSL_HANDSHAKE_TIMEOUT = 60.0
# Number of seconds to wait for SSL shutdown to complete
# The default timeout mimics lingering_time
SSL_SHUTDOWN_TIMEOUT = 30.0
# Used in sendfile fallback code. We use fallback for platforms
# that don't support sendfile, or for TLS connections.
SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256
FLOW_CONTROL_HIGH_WATER_SSL_READ = 256 # KiB
FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512 # KiB
# The enum should be here to break circular dependencies between
# base_events and sslproto
class _SendfileMode(enum.Enum):

View File

@ -304,7 +304,6 @@ async def create_connection(
flags=0, sock=None, local_addr=None,
server_hostname=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
happy_eyeballs_delay=None, interleave=None):
raise NotImplementedError
@ -314,7 +313,6 @@ async def create_server(
flags=socket.AI_PASSIVE, sock=None, backlog=100,
ssl=None, reuse_address=None, reuse_port=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
"""A coroutine which creates a TCP server bound to host and port.
@ -355,10 +353,6 @@ async def create_server(
will wait for completion of the SSL handshake before aborting the
connection. Default is 60s.
ssl_shutdown_timeout is the time in seconds that an SSL server
will wait for completion of the SSL shutdown procedure
before aborting the connection. Default is 30s.
start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever()
@ -377,8 +371,7 @@ async def sendfile(self, transport, file, offset=0, count=None,
async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False,
server_hostname=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_handshake_timeout=None):
"""Upgrade a transport to TLS.
Return a new transport that *protocol* should start using
@ -390,15 +383,13 @@ async def create_unix_connection(
self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_handshake_timeout=None):
raise NotImplementedError
async def create_unix_server(
self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
"""A coroutine which creates a UNIX Domain Socket server.
@ -420,9 +411,6 @@ async def create_unix_server(
ssl_handshake_timeout is the time in seconds that an SSL server
will wait for the SSL handshake to complete (defaults to 60s).
ssl_shutdown_timeout is the time in seconds that an SSL server
will wait for the SSL shutdown to finish (defaults to 30s).
start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
the user should await Server.start_serving() or Server.serve_forever()
@ -433,8 +421,7 @@ async def create_unix_server(
async def connect_accepted_socket(
self, protocol_factory, sock,
*, ssl=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_handshake_timeout=None):
"""Handle an accepted connection.
This is used by servers that accept connections outside of

View File

@ -642,13 +642,11 @@ def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_handshake_timeout=None):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
ssl_handshake_timeout=ssl_handshake_timeout)
_ProactorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
@ -814,8 +812,7 @@ def _write_to_self(self):
def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_handshake_timeout=None):
def loop(f=None):
try:
@ -829,8 +826,7 @@ def loop(f=None):
self._make_ssl_transport(
conn, protocol, sslcontext, server_side=True,
extra={'peername': addr}, server=server,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
ssl_handshake_timeout=ssl_handshake_timeout)
else:
self._make_socket_transport(
conn, protocol,

View File

@ -70,15 +70,11 @@ def _make_ssl_transport(
self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None,
extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout
)
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout)
_SelectorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
return ssl_protocol._app_transport
@ -150,17 +146,15 @@ def _write_to_self(self):
def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
self._add_reader(sock.fileno(), self._accept_connection,
protocol_factory, sock, sslcontext, server, backlog,
ssl_handshake_timeout, ssl_shutdown_timeout)
ssl_handshake_timeout)
def _accept_connection(
self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
# This method is only called once for each event loop tick where the
# listening socket has triggered an EVENT_READ. There may be multiple
# connections waiting for an .accept() so it is called in a loop.
@ -191,22 +185,20 @@ def _accept_connection(
self.call_later(constants.ACCEPT_RETRY_DELAY,
self._start_serving,
protocol_factory, sock, sslcontext, server,
backlog, ssl_handshake_timeout,
ssl_shutdown_timeout)
backlog, ssl_handshake_timeout)
else:
raise # The event loop will catch, log and ignore it.
else:
extra = {'peername': addr}
accept = self._accept_connection2(
protocol_factory, conn, extra, sslcontext, server,
ssl_handshake_timeout, ssl_shutdown_timeout)
ssl_handshake_timeout)
self.create_task(accept)
async def _accept_connection2(
self, protocol_factory, conn, extra,
sslcontext=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT):
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
protocol = None
transport = None
try:
@ -216,8 +208,7 @@ async def _accept_connection2(
transport = self._make_ssl_transport(
conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra=extra, server=server,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
ssl_handshake_timeout=ssl_handshake_timeout)
else:
transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra,

File diff suppressed because it is too large Load Diff

View File

@ -229,8 +229,7 @@ async def create_unix_connection(
self, protocol_factory, path=None, *,
ssl=None, sock=None,
server_hostname=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None):
ssl_handshake_timeout=None):
assert server_hostname is None or isinstance(server_hostname, str)
if ssl:
if server_hostname is None:
@ -242,9 +241,6 @@ async def create_unix_connection(
if ssl_handshake_timeout is not None:
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
if ssl_shutdown_timeout is not None:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')
if path is not None:
if sock is not None:
@ -271,15 +267,13 @@ async def create_unix_connection(
transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
ssl_handshake_timeout=ssl_handshake_timeout)
return transport, protocol
async def create_unix_server(
self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')
@ -288,10 +282,6 @@ async def create_unix_server(
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
if ssl_shutdown_timeout is not None and not ssl:
raise ValueError(
'ssl_shutdown_timeout is only meaningful with ssl')
if path is not None:
if sock is not None:
raise ValueError(
@ -338,8 +328,7 @@ async def create_unix_server(
sock.setblocking(False)
server = base_events.Server(self, [sock], protocol_factory,
ssl, backlog, ssl_handshake_timeout,
ssl_shutdown_timeout)
ssl, backlog, ssl_handshake_timeout)
if start_serving:
server._start_serving()
# Skip one loop iteration so that all 'loop.add_reader'

View File

@ -1437,51 +1437,44 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter,
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
ANY = mock.ANY
handshake_timeout = object()
shutdown_timeout = object()
# First try the default server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
ssl_handshake_timeout=handshake_timeout,
ssl_shutdown_timeout=shutdown_timeout)
ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='python.org',
ssl_handshake_timeout=handshake_timeout,
ssl_shutdown_timeout=shutdown_timeout)
ssl_handshake_timeout=handshake_timeout)
# Next try an explicit server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout,
ssl_shutdown_timeout=shutdown_timeout)
ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='perl.com',
ssl_handshake_timeout=handshake_timeout,
ssl_shutdown_timeout=shutdown_timeout)
ssl_handshake_timeout=handshake_timeout)
# Finally try an explicit empty server_hostname.
self.loop._make_ssl_transport.reset_mock()
coro = self.loop.create_connection(
MyProto, 'python.org', 80, ssl=True,
server_hostname='',
ssl_handshake_timeout=handshake_timeout,
ssl_shutdown_timeout=shutdown_timeout)
ssl_handshake_timeout=handshake_timeout)
transport, _ = self.loop.run_until_complete(coro)
transport.close()
self.loop._make_ssl_transport.assert_called_with(
ANY, ANY, ANY, ANY,
server_side=False,
server_hostname='',
ssl_handshake_timeout=handshake_timeout,
ssl_shutdown_timeout=shutdown_timeout)
ssl_handshake_timeout=handshake_timeout)
def test_create_connection_no_ssl_server_hostname_errors(self):
# When not using ssl, server_hostname must be None.
@ -1888,7 +1881,7 @@ def test_accept_connection_exception(self, m_log):
constants.ACCEPT_RETRY_DELAY,
# self.loop._start_serving
mock.ANY,
MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY)
MyProto, sock, None, None, mock.ANY, mock.ANY)
def test_call_coroutine(self):
with self.assertWarns(DeprecationWarning):

View File

@ -70,6 +70,44 @@ def test_make_socket_transport(self):
close_transport(transport)
@unittest.skipIf(ssl is None, 'No ssl module')
def test_make_ssl_transport(self):
m = mock.Mock()
self.loop._add_reader = mock.Mock()
self.loop._add_reader._is_coroutine = False
self.loop._add_writer = mock.Mock()
self.loop._remove_reader = mock.Mock()
self.loop._remove_writer = mock.Mock()
waiter = self.loop.create_future()
with test_utils.disable_logger():
transport = self.loop._make_ssl_transport(
m, asyncio.Protocol(), m, waiter)
with self.assertRaisesRegex(RuntimeError,
r'SSL transport.*not.*initialized'):
transport.is_reading()
# execute the handshake while the logger is disabled
# to ignore SSL handshake failure
test_utils.run_briefly(self.loop)
self.assertTrue(transport.is_reading())
transport.pause_reading()
transport.pause_reading()
self.assertFalse(transport.is_reading())
transport.resume_reading()
transport.resume_reading()
self.assertTrue(transport.is_reading())
# Sanity check
class_name = transport.__class__.__name__
self.assertIn("ssl", class_name.lower())
self.assertIn("transport", class_name.lower())
transport.close()
# execute pending callbacks to close the socket transport
test_utils.run_briefly(self.loop)
@mock.patch('asyncio.selector_events.ssl', None)
@mock.patch('asyncio.sslproto.ssl', None)
def test_make_ssl_transport_without_ssl_error(self):

File diff suppressed because it is too large Load Diff

View File

@ -15,6 +15,7 @@
from asyncio import log
from asyncio import protocols
from asyncio import sslproto
from test import support
from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests
@ -43,13 +44,16 @@ def ssl_protocol(self, *, waiter=None, proto=None):
def connection_made(self, ssl_proto, *, do_handshake=None):
transport = mock.Mock()
sslobj = mock.Mock()
# emulate reading decompressed data
sslobj.read.side_effect = ssl.SSLWantReadError
if do_handshake is not None:
sslobj.do_handshake = do_handshake
ssl_proto._sslobj = sslobj
ssl_proto.connection_made(transport)
sslpipe = mock.Mock()
sslpipe.shutdown.return_value = b''
if do_handshake:
sslpipe.do_handshake.side_effect = do_handshake
else:
def mock_handshake(callback):
return []
sslpipe.do_handshake.side_effect = mock_handshake
with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
ssl_proto.connection_made(transport)
return transport
def test_handshake_timeout_zero(self):
@ -71,10 +75,7 @@ def test_handshake_timeout_negative(self):
def test_eof_received_waiter(self):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
self.connection_made(
ssl_proto,
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
)
self.connection_made(ssl_proto)
ssl_proto.eof_received()
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionResetError)
@ -99,10 +100,7 @@ def test_connection_lost(self):
# yield from waiter hang if lost_connection was called.
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
self.connection_made(
ssl_proto,
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
)
self.connection_made(ssl_proto)
ssl_proto.connection_lost(ConnectionAbortedError)
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
@ -112,10 +110,7 @@ def test_close_during_handshake(self):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
transport = self.connection_made(
ssl_proto,
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
)
transport = self.connection_made(ssl_proto)
test_utils.run_briefly(self.loop)
ssl_proto._app_transport.close()
@ -148,7 +143,7 @@ def test_data_received_after_closing(self):
transp.close()
# should not raise
self.assertIsNone(ssl_proto.buffer_updated(5))
self.assertIsNone(ssl_proto.data_received(b'data'))
def test_write_after_closing(self):
ssl_proto = self.ssl_protocol()

View File

@ -1,2 +0,0 @@
Reimplement SSL/TLS support in asyncio, borrow the impelementation from
uvloop library.