diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index e54ee309e42..f789635e0f8 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -273,7 +273,7 @@ async def restore(self): class Server(events.AbstractServer): def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, - ssl_handshake_timeout, ssl_shutdown_timeout=None): + ssl_handshake_timeout): self._loop = loop self._sockets = sockets self._active_count = 0 @@ -282,7 +282,6 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, self._backlog = backlog self._ssl_context = ssl_context self._ssl_handshake_timeout = ssl_handshake_timeout - self._ssl_shutdown_timeout = ssl_shutdown_timeout self._serving = False self._serving_forever_fut = None @@ -314,8 +313,7 @@ def _start_serving(self): sock.listen(self._backlog) self._loop._start_serving( self._protocol_factory, sock, self._ssl_context, - self, self._backlog, self._ssl_handshake_timeout, - self._ssl_shutdown_timeout) + self, self._backlog, self._ssl_handshake_timeout) def get_loop(self): return self._loop @@ -469,7 +467,6 @@ def _make_ssl_transport( *, server_side=False, server_hostname=None, extra=None, server=None, ssl_handshake_timeout=None, - ssl_shutdown_timeout=None, call_connection_made=True): """Create SSL transport.""" raise NotImplementedError @@ -972,7 +969,6 @@ async def create_connection( proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=None, - ssl_shutdown_timeout=None, happy_eyeballs_delay=None, interleave=None): """Connect to a TCP server. @@ -1008,10 +1004,6 @@ async def create_connection( raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') - if ssl_shutdown_timeout is not None and not ssl: - raise ValueError( - 'ssl_shutdown_timeout is only meaningful with ssl') - if happy_eyeballs_delay is not None and interleave is None: # If using happy eyeballs, default to interleave addresses by family interleave = 1 @@ -1087,8 +1079,7 @@ async def create_connection( transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout, - ssl_shutdown_timeout=ssl_shutdown_timeout) + ssl_handshake_timeout=ssl_handshake_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket @@ -1100,8 +1091,7 @@ async def create_connection( async def _create_connection_transport( self, sock, protocol_factory, ssl, server_hostname, server_side=False, - ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_handshake_timeout=None): sock.setblocking(False) @@ -1112,8 +1102,7 @@ async def _create_connection_transport( transport = self._make_ssl_transport( sock, protocol, sslcontext, waiter, server_side=server_side, server_hostname=server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout, - ssl_shutdown_timeout=ssl_shutdown_timeout) + ssl_handshake_timeout=ssl_handshake_timeout) else: transport = self._make_socket_transport(sock, protocol, waiter) @@ -1204,8 +1193,7 @@ async def _sendfile_fallback(self, transp, file, offset, count): async def start_tls(self, transport, protocol, sslcontext, *, server_side=False, server_hostname=None, - ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_handshake_timeout=None): """Upgrade transport to TLS. Return a new transport that *protocol* should start using @@ -1228,7 +1216,6 @@ async def start_tls(self, transport, protocol, sslcontext, *, self, protocol, sslcontext, waiter, server_side, server_hostname, ssl_handshake_timeout=ssl_handshake_timeout, - ssl_shutdown_timeout=ssl_shutdown_timeout, call_connection_made=False) # Pause early so that "ssl_protocol.data_received()" doesn't @@ -1427,7 +1414,6 @@ async def create_server( reuse_address=None, reuse_port=None, ssl_handshake_timeout=None, - ssl_shutdown_timeout=None, start_serving=True): """Create a TCP server. @@ -1451,10 +1437,6 @@ async def create_server( raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') - if ssl_shutdown_timeout is not None and ssl is None: - raise ValueError( - 'ssl_shutdown_timeout is only meaningful with ssl') - if host is not None or port is not None: if sock is not None: raise ValueError( @@ -1527,8 +1509,7 @@ async def create_server( sock.setblocking(False) server = Server(self, sockets, protocol_factory, - ssl, backlog, ssl_handshake_timeout, - ssl_shutdown_timeout) + ssl, backlog, ssl_handshake_timeout) if start_serving: server._start_serving() # Skip one loop iteration so that all 'loop.add_reader' @@ -1542,8 +1523,7 @@ async def create_server( async def connect_accepted_socket( self, protocol_factory, sock, *, ssl=None, - ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_handshake_timeout=None): if sock.type != socket.SOCK_STREAM: raise ValueError(f'A Stream Socket was expected, got {sock!r}') @@ -1551,14 +1531,9 @@ async def connect_accepted_socket( raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') - if ssl_shutdown_timeout is not None and not ssl: - raise ValueError( - 'ssl_shutdown_timeout is only meaningful with ssl') - transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, '', server_side=True, - ssl_handshake_timeout=ssl_handshake_timeout, - ssl_shutdown_timeout=ssl_shutdown_timeout) + ssl_handshake_timeout=ssl_handshake_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket diff --git a/Lib/asyncio/constants.py b/Lib/asyncio/constants.py index f171ead28fe..33feed60e55 100644 --- a/Lib/asyncio/constants.py +++ b/Lib/asyncio/constants.py @@ -15,17 +15,10 @@ # The default timeout matches that of Nginx. SSL_HANDSHAKE_TIMEOUT = 60.0 -# Number of seconds to wait for SSL shutdown to complete -# The default timeout mimics lingering_time -SSL_SHUTDOWN_TIMEOUT = 30.0 - # Used in sendfile fallback code. We use fallback for platforms # that don't support sendfile, or for TLS connections. SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256 -FLOW_CONTROL_HIGH_WATER_SSL_READ = 256 # KiB -FLOW_CONTROL_HIGH_WATER_SSL_WRITE = 512 # KiB - # The enum should be here to break circular dependencies between # base_events and sslproto class _SendfileMode(enum.Enum): diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index d5254fa5e7e..b966ad26bf4 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -304,7 +304,6 @@ async def create_connection( flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=None, - ssl_shutdown_timeout=None, happy_eyeballs_delay=None, interleave=None): raise NotImplementedError @@ -314,7 +313,6 @@ async def create_server( flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, ssl_handshake_timeout=None, - ssl_shutdown_timeout=None, start_serving=True): """A coroutine which creates a TCP server bound to host and port. @@ -355,10 +353,6 @@ async def create_server( will wait for completion of the SSL handshake before aborting the connection. Default is 60s. - ssl_shutdown_timeout is the time in seconds that an SSL server - will wait for completion of the SSL shutdown procedure - before aborting the connection. Default is 30s. - start_serving set to True (default) causes the created server to start accepting connections immediately. When set to False, the user should await Server.start_serving() or Server.serve_forever() @@ -377,8 +371,7 @@ async def sendfile(self, transport, file, offset=0, count=None, async def start_tls(self, transport, protocol, sslcontext, *, server_side=False, server_hostname=None, - ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_handshake_timeout=None): """Upgrade a transport to TLS. Return a new transport that *protocol* should start using @@ -390,15 +383,13 @@ async def create_unix_connection( self, protocol_factory, path=None, *, ssl=None, sock=None, server_hostname=None, - ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_handshake_timeout=None): raise NotImplementedError async def create_unix_server( self, protocol_factory, path=None, *, sock=None, backlog=100, ssl=None, ssl_handshake_timeout=None, - ssl_shutdown_timeout=None, start_serving=True): """A coroutine which creates a UNIX Domain Socket server. @@ -420,9 +411,6 @@ async def create_unix_server( ssl_handshake_timeout is the time in seconds that an SSL server will wait for the SSL handshake to complete (defaults to 60s). - ssl_shutdown_timeout is the time in seconds that an SSL server - will wait for the SSL shutdown to finish (defaults to 30s). - start_serving set to True (default) causes the created server to start accepting connections immediately. When set to False, the user should await Server.start_serving() or Server.serve_forever() @@ -433,8 +421,7 @@ async def create_unix_server( async def connect_accepted_socket( self, protocol_factory, sock, *, ssl=None, - ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_handshake_timeout=None): """Handle an accepted connection. This is used by servers that accept connections outside of diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 10852afe2b4..45c11ee4b48 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -642,13 +642,11 @@ def _make_ssl_transport( self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, extra=None, server=None, - ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_handshake_timeout=None): ssl_protocol = sslproto.SSLProtocol( self, protocol, sslcontext, waiter, server_side, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout, - ssl_shutdown_timeout=ssl_shutdown_timeout) + ssl_handshake_timeout=ssl_handshake_timeout) _ProactorSocketTransport(self, rawsock, ssl_protocol, extra=extra, server=server) return ssl_protocol._app_transport @@ -814,8 +812,7 @@ def _write_to_self(self): def _start_serving(self, protocol_factory, sock, sslcontext=None, server=None, backlog=100, - ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_handshake_timeout=None): def loop(f=None): try: @@ -829,8 +826,7 @@ def loop(f=None): self._make_ssl_transport( conn, protocol, sslcontext, server_side=True, extra={'peername': addr}, server=server, - ssl_handshake_timeout=ssl_handshake_timeout, - ssl_shutdown_timeout=ssl_shutdown_timeout) + ssl_handshake_timeout=ssl_handshake_timeout) else: self._make_socket_transport( conn, protocol, diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 63ab15f30fb..59cb6b1babe 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -70,15 +70,11 @@ def _make_ssl_transport( self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, extra=None, server=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, - ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, - ): + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_protocol = sslproto.SSLProtocol( - self, protocol, sslcontext, waiter, - server_side, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout, - ssl_shutdown_timeout=ssl_shutdown_timeout - ) + self, protocol, sslcontext, waiter, + server_side, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) _SelectorSocketTransport(self, rawsock, ssl_protocol, extra=extra, server=server) return ssl_protocol._app_transport @@ -150,17 +146,15 @@ def _write_to_self(self): def _start_serving(self, protocol_factory, sock, sslcontext=None, server=None, backlog=100, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, - ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): self._add_reader(sock.fileno(), self._accept_connection, protocol_factory, sock, sslcontext, server, backlog, - ssl_handshake_timeout, ssl_shutdown_timeout) + ssl_handshake_timeout) def _accept_connection( self, protocol_factory, sock, sslcontext=None, server=None, backlog=100, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, - ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): # This method is only called once for each event loop tick where the # listening socket has triggered an EVENT_READ. There may be multiple # connections waiting for an .accept() so it is called in a loop. @@ -191,22 +185,20 @@ def _accept_connection( self.call_later(constants.ACCEPT_RETRY_DELAY, self._start_serving, protocol_factory, sock, sslcontext, server, - backlog, ssl_handshake_timeout, - ssl_shutdown_timeout) + backlog, ssl_handshake_timeout) else: raise # The event loop will catch, log and ignore it. else: extra = {'peername': addr} accept = self._accept_connection2( protocol_factory, conn, extra, sslcontext, server, - ssl_handshake_timeout, ssl_shutdown_timeout) + ssl_handshake_timeout) self.create_task(accept) async def _accept_connection2( self, protocol_factory, conn, extra, sslcontext=None, server=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, - ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): protocol = None transport = None try: @@ -216,8 +208,7 @@ async def _accept_connection2( transport = self._make_ssl_transport( conn, protocol, sslcontext, waiter=waiter, server_side=True, extra=extra, server=server, - ssl_handshake_timeout=ssl_handshake_timeout, - ssl_shutdown_timeout=ssl_shutdown_timeout) + ssl_handshake_timeout=ssl_handshake_timeout) else: transport = self._make_socket_transport( conn, protocol, waiter=waiter, extra=extra, diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 79734ab63d2..cad25b26539 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -1,5 +1,4 @@ import collections -import enum import warnings try: import ssl @@ -7,38 +6,10 @@ ssl = None from . import constants -from . import exceptions from . import protocols from . import transports from .log import logger -if ssl is not None: - SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError) - - -class SSLProtocolState(enum.Enum): - UNWRAPPED = "UNWRAPPED" - DO_HANDSHAKE = "DO_HANDSHAKE" - WRAPPED = "WRAPPED" - FLUSHING = "FLUSHING" - SHUTDOWN = "SHUTDOWN" - - -class AppProtocolState(enum.Enum): - # This tracks the state of app protocol (https://git.io/fj59P): - # - # INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST - # - # * cm: connection_made() - # * dr: data_received() - # * er: eof_received() - # * cl: connection_lost() - - STATE_INIT = "STATE_INIT" - STATE_CON_MADE = "STATE_CON_MADE" - STATE_EOF = "STATE_EOF" - STATE_CON_LOST = "STATE_CON_LOST" - def _create_transport_context(server_side, server_hostname): if server_side: @@ -54,35 +25,269 @@ def _create_transport_context(server_side, server_hostname): return sslcontext -def add_flowcontrol_defaults(high, low, kb): - if high is None: - if low is None: - hi = kb * 1024 - else: - lo = low - hi = 4 * lo - else: - hi = high - if low is None: - lo = hi // 4 - else: - lo = low +# States of an _SSLPipe. +_UNWRAPPED = "UNWRAPPED" +_DO_HANDSHAKE = "DO_HANDSHAKE" +_WRAPPED = "WRAPPED" +_SHUTDOWN = "SHUTDOWN" - if not hi >= lo >= 0: - raise ValueError('high (%r) must be >= low (%r) must be >= 0' % - (hi, lo)) - return hi, lo +class _SSLPipe(object): + """An SSL "Pipe". + + An SSL pipe allows you to communicate with an SSL/TLS protocol instance + through memory buffers. It can be used to implement a security layer for an + existing connection where you don't have access to the connection's file + descriptor, or for some reason you don't want to use it. + + An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode, + data is passed through untransformed. In wrapped mode, application level + data is encrypted to SSL record level data and vice versa. The SSL record + level is the lowest level in the SSL protocol suite and is what travels + as-is over the wire. + + An SslPipe initially is in "unwrapped" mode. To start SSL, call + do_handshake(). To shutdown SSL again, call unwrap(). + """ + + max_size = 256 * 1024 # Buffer size passed to read() + + def __init__(self, context, server_side, server_hostname=None): + """ + The *context* argument specifies the ssl.SSLContext to use. + + The *server_side* argument indicates whether this is a server side or + client side transport. + + The optional *server_hostname* argument can be used to specify the + hostname you are connecting to. You may only specify this parameter if + the _ssl module supports Server Name Indication (SNI). + """ + self._context = context + self._server_side = server_side + self._server_hostname = server_hostname + self._state = _UNWRAPPED + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + self._sslobj = None + self._need_ssldata = False + self._handshake_cb = None + self._shutdown_cb = None + + @property + def context(self): + """The SSL context passed to the constructor.""" + return self._context + + @property + def ssl_object(self): + """The internal ssl.SSLObject instance. + + Return None if the pipe is not wrapped. + """ + return self._sslobj + + @property + def need_ssldata(self): + """Whether more record level data is needed to complete a handshake + that is currently in progress.""" + return self._need_ssldata + + @property + def wrapped(self): + """ + Whether a security layer is currently in effect. + + Return False during handshake. + """ + return self._state == _WRAPPED + + def do_handshake(self, callback=None): + """Start the SSL handshake. + + Return a list of ssldata. A ssldata element is a list of buffers + + The optional *callback* argument can be used to install a callback that + will be called when the handshake is complete. The callback will be + called with None if successful, else an exception instance. + """ + if self._state != _UNWRAPPED: + raise RuntimeError('handshake in progress or completed') + self._sslobj = self._context.wrap_bio( + self._incoming, self._outgoing, + server_side=self._server_side, + server_hostname=self._server_hostname) + self._state = _DO_HANDSHAKE + self._handshake_cb = callback + ssldata, appdata = self.feed_ssldata(b'', only_handshake=True) + assert len(appdata) == 0 + return ssldata + + def shutdown(self, callback=None): + """Start the SSL shutdown sequence. + + Return a list of ssldata. A ssldata element is a list of buffers + + The optional *callback* argument can be used to install a callback that + will be called when the shutdown is complete. The callback will be + called without arguments. + """ + if self._state == _UNWRAPPED: + raise RuntimeError('no security layer present') + if self._state == _SHUTDOWN: + raise RuntimeError('shutdown in progress') + assert self._state in (_WRAPPED, _DO_HANDSHAKE) + self._state = _SHUTDOWN + self._shutdown_cb = callback + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + return ssldata + + def feed_eof(self): + """Send a potentially "ragged" EOF. + + This method will raise an SSL_ERROR_EOF exception if the EOF is + unexpected. + """ + self._incoming.write_eof() + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + + def feed_ssldata(self, data, only_handshake=False): + """Feed SSL record level data into the pipe. + + The data must be a bytes instance. It is OK to send an empty bytes + instance. This can be used to get ssldata for a handshake initiated by + this endpoint. + + Return a (ssldata, appdata) tuple. The ssldata element is a list of + buffers containing SSL data that needs to be sent to the remote SSL. + + The appdata element is a list of buffers containing plaintext data that + needs to be forwarded to the application. The appdata list may contain + an empty buffer indicating an SSL "close_notify" alert. This alert must + be acknowledged by calling shutdown(). + """ + if self._state == _UNWRAPPED: + # If unwrapped, pass plaintext data straight through. + if data: + appdata = [data] + else: + appdata = [] + return ([], appdata) + + self._need_ssldata = False + if data: + self._incoming.write(data) + + ssldata = [] + appdata = [] + try: + if self._state == _DO_HANDSHAKE: + # Call do_handshake() until it doesn't raise anymore. + self._sslobj.do_handshake() + self._state = _WRAPPED + if self._handshake_cb: + self._handshake_cb(None) + if only_handshake: + return (ssldata, appdata) + # Handshake done: execute the wrapped block + + if self._state == _WRAPPED: + # Main state: read data from SSL until close_notify + while True: + chunk = self._sslobj.read(self.max_size) + appdata.append(chunk) + if not chunk: # close_notify + break + + elif self._state == _SHUTDOWN: + # Call shutdown() until it doesn't raise anymore. + self._sslobj.unwrap() + self._sslobj = None + self._state = _UNWRAPPED + if self._shutdown_cb: + self._shutdown_cb() + + elif self._state == _UNWRAPPED: + # Drain possible plaintext data after close_notify. + appdata.append(self._incoming.read()) + except (ssl.SSLError, ssl.CertificateError) as exc: + exc_errno = getattr(exc, 'errno', None) + if exc_errno not in ( + ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + if self._state == _DO_HANDSHAKE and self._handshake_cb: + self._handshake_cb(exc) + raise + self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ) + + # Check for record level data that needs to be sent back. + # Happens for the initial handshake and renegotiations. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + return (ssldata, appdata) + + def feed_appdata(self, data, offset=0): + """Feed plaintext data into the pipe. + + Return an (ssldata, offset) tuple. The ssldata element is a list of + buffers containing record level data that needs to be sent to the + remote SSL instance. The offset is the number of plaintext bytes that + were processed, which may be less than the length of data. + + NOTE: In case of short writes, this call MUST be retried with the SAME + buffer passed into the *data* argument (i.e. the id() must be the + same). This is an OpenSSL requirement. A further particularity is that + a short write will always have offset == 0, because the _ssl module + does not enable partial writes. And even though the offset is zero, + there will still be encrypted data in ssldata. + """ + assert 0 <= offset <= len(data) + if self._state == _UNWRAPPED: + # pass through data in unwrapped mode + if offset < len(data): + ssldata = [data[offset:]] + else: + ssldata = [] + return (ssldata, len(data)) + + ssldata = [] + view = memoryview(data) + while True: + self._need_ssldata = False + try: + if offset < len(view): + offset += self._sslobj.write(view[offset:]) + except ssl.SSLError as exc: + # It is not allowed to call write() after unwrap() until the + # close_notify is acknowledged. We return the condition to the + # caller as a short write. + exc_errno = getattr(exc, 'errno', None) + if exc.reason == 'PROTOCOL_IS_SHUTDOWN': + exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ + if exc_errno not in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + raise + self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ) + + # See if there's any record level data back for us. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + if offset == len(view) or self._need_ssldata: + break + return (ssldata, offset) class _SSLProtocolTransport(transports._FlowControlMixin, transports.Transport): - _start_tls_compatible = True _sendfile_compatible = constants._SendfileMode.FALLBACK def __init__(self, loop, ssl_protocol): self._loop = loop + # SSLProtocol instance self._ssl_protocol = ssl_protocol self._closed = False @@ -110,15 +315,16 @@ def close(self): self._closed = True self._ssl_protocol._start_shutdown() - def __del__(self, _warnings=warnings): + def __del__(self, _warn=warnings.warn): if not self._closed: - self._closed = True - _warnings.warn( - "unclosed transport ", ResourceWarning) + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self.close() def is_reading(self): - return not self._ssl_protocol._app_reading_paused + tr = self._ssl_protocol._transport + if tr is None: + raise RuntimeError('SSL transport has not been initialized yet') + return tr.is_reading() def pause_reading(self): """Pause the receiving end. @@ -126,7 +332,7 @@ def pause_reading(self): No data will be passed to the protocol's data_received() method until resume_reading() is called. """ - self._ssl_protocol._pause_reading() + self._ssl_protocol._transport.pause_reading() def resume_reading(self): """Resume the receiving end. @@ -134,7 +340,7 @@ def resume_reading(self): Data received will once again be passed to the protocol's data_received() method. """ - self._ssl_protocol._resume_reading() + self._ssl_protocol._transport.resume_reading() def set_write_buffer_limits(self, high=None, low=None): """Set the high- and low-water limits for write flow control. @@ -155,51 +361,16 @@ def set_write_buffer_limits(self, high=None, low=None): reduces opportunities for doing I/O and computation concurrently. """ - self._ssl_protocol._set_write_buffer_limits(high, low) - self._ssl_protocol._control_app_writing() - - def get_write_buffer_limits(self): - return (self._ssl_protocol._outgoing_low_water, - self._ssl_protocol._outgoing_high_water) + self._ssl_protocol._transport.set_write_buffer_limits(high, low) def get_write_buffer_size(self): - """Return the current size of the write buffers.""" - return self._ssl_protocol._get_write_buffer_size() - - def set_read_buffer_limits(self, high=None, low=None): - """Set the high- and low-water limits for read flow control. - - These two values control when to call the upstream transport's - pause_reading() and resume_reading() methods. If specified, - the low-water limit must be less than or equal to the - high-water limit. Neither value can be negative. - - The defaults are implementation-specific. If only the - high-water limit is given, the low-water limit defaults to an - implementation-specific value less than or equal to the - high-water limit. Setting high to zero forces low to zero as - well, and causes pause_reading() to be called whenever the - buffer becomes non-empty. Setting low to zero causes - resume_reading() to be called only once the buffer is empty. - Use of zero for either limit is generally sub-optimal as it - reduces opportunities for doing I/O and computation - concurrently. - """ - self._ssl_protocol._set_read_buffer_limits(high, low) - self._ssl_protocol._control_ssl_reading() - - def get_read_buffer_limits(self): - return (self._ssl_protocol._incoming_low_water, - self._ssl_protocol._incoming_high_water) - - def get_read_buffer_size(self): - """Return the current size of the read buffer.""" - return self._ssl_protocol._get_read_buffer_size() + """Return the current size of the write buffer.""" + return self._ssl_protocol._transport.get_write_buffer_size() @property def _protocol_paused(self): # Required for sendfile fallback pause_writing/resume_writing logic - return self._ssl_protocol._app_writing_paused + return self._ssl_protocol._transport._protocol_paused def write(self, data): """Write some data bytes to the transport. @@ -212,22 +383,7 @@ def write(self, data): f"got {type(data).__name__}") if not data: return - self._ssl_protocol._write_appdata((data,)) - - def writelines(self, list_of_data): - """Write a list (or any iterable) of data bytes to the transport. - - The default implementation concatenates the arguments and - calls write() on the result. - """ - self._ssl_protocol._write_appdata(list_of_data) - - def write_eof(self): - """Close the write end after flushing buffered data. - - This raises :exc:`NotImplementedError` right now. - """ - raise NotImplementedError + self._ssl_protocol._write_appdata(data) def can_write_eof(self): """Return True if this transport supports write_eof(), False if not.""" @@ -240,36 +396,23 @@ def abort(self): The protocol's connection_lost() method will (eventually) be called with None as its argument. """ - self._closed = True self._ssl_protocol._abort() - - def _force_close(self, exc): self._closed = True - self._ssl_protocol._abort(exc) - - def _test__append_write_backlog(self, data): - # for test only - self._ssl_protocol._write_backlog.append(data) - self._ssl_protocol._write_buffer_size += len(data) -class SSLProtocol(protocols.BufferedProtocol): - max_size = 256 * 1024 # Buffer size passed to read() +class SSLProtocol(protocols.Protocol): + """SSL protocol. - _handshake_start_time = None - _handshake_timeout_handle = None - _shutdown_timeout_handle = None + Implementation of SSL on top of a socket using incoming and outgoing + buffers which are ssl.MemoryBIO objects. + """ def __init__(self, loop, app_protocol, sslcontext, waiter, server_side=False, server_hostname=None, call_connection_made=True, - ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_handshake_timeout=None): if ssl is None: - raise RuntimeError("stdlib ssl module not available") - - self._ssl_buffer = bytearray(self.max_size) - self._ssl_buffer_view = memoryview(self._ssl_buffer) + raise RuntimeError('stdlib ssl module not available') if ssl_handshake_timeout is None: ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT @@ -277,12 +420,6 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, raise ValueError( f"ssl_handshake_timeout should be a positive number, " f"got {ssl_handshake_timeout}") - if ssl_shutdown_timeout is None: - ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT - elif ssl_shutdown_timeout <= 0: - raise ValueError( - f"ssl_shutdown_timeout should be a positive number, " - f"got {ssl_shutdown_timeout}") if not sslcontext: sslcontext = _create_transport_context( @@ -305,54 +442,21 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, self._waiter = waiter self._loop = loop self._set_app_protocol(app_protocol) - self._app_transport = None - self._app_transport_created = False + self._app_transport = _SSLProtocolTransport(self._loop, self) + # _SSLPipe instance (None until the connection is made) + self._sslpipe = None + self._session_established = False + self._in_handshake = False + self._in_shutdown = False # transport, ex: SelectorSocketTransport self._transport = None + self._call_connection_made = call_connection_made self._ssl_handshake_timeout = ssl_handshake_timeout - self._ssl_shutdown_timeout = ssl_shutdown_timeout - # SSL and state machine - self._incoming = ssl.MemoryBIO() - self._outgoing = ssl.MemoryBIO() - self._state = SSLProtocolState.UNWRAPPED - self._conn_lost = 0 # Set when connection_lost called - if call_connection_made: - self._app_state = AppProtocolState.STATE_INIT - else: - self._app_state = AppProtocolState.STATE_CON_MADE - self._sslobj = self._sslcontext.wrap_bio( - self._incoming, self._outgoing, - server_side=self._server_side, - server_hostname=self._server_hostname) - - # Flow Control - - self._ssl_writing_paused = False - - self._app_reading_paused = False - - self._ssl_reading_paused = False - self._incoming_high_water = 0 - self._incoming_low_water = 0 - self._set_read_buffer_limits() - self._eof_received = False - - self._app_writing_paused = False - self._outgoing_high_water = 0 - self._outgoing_low_water = 0 - self._set_write_buffer_limits() - self._get_app_transport() def _set_app_protocol(self, app_protocol): self._app_protocol = app_protocol - # Make fast hasattr check first - if (hasattr(app_protocol, 'get_buffer') and - isinstance(app_protocol, protocols.BufferedProtocol)): - self._app_protocol_get_buffer = app_protocol.get_buffer - self._app_protocol_buffer_updated = app_protocol.buffer_updated - self._app_protocol_is_buffer = True - else: - self._app_protocol_is_buffer = False + self._app_protocol_is_buffer = \ + isinstance(app_protocol, protocols.BufferedProtocol) def _wakeup_waiter(self, exc=None): if self._waiter is None: @@ -364,20 +468,15 @@ def _wakeup_waiter(self, exc=None): self._waiter.set_result(None) self._waiter = None - def _get_app_transport(self): - if self._app_transport is None: - if self._app_transport_created: - raise RuntimeError('Creating _SSLProtocolTransport twice') - self._app_transport = _SSLProtocolTransport(self._loop, self) - self._app_transport_created = True - return self._app_transport - def connection_made(self, transport): """Called when the low-level connection is made. Start the SSL handshake. """ self._transport = transport + self._sslpipe = _SSLPipe(self._sslcontext, + self._server_side, + self._server_hostname) self._start_handshake() def connection_lost(self, exc): @@ -387,58 +486,72 @@ def connection_lost(self, exc): meaning a regular EOF is received or the connection was aborted or closed). """ - self._write_backlog.clear() - self._outgoing.read() - self._conn_lost += 1 - - # Just mark the app transport as closed so that its __dealloc__ - # doesn't complain. - if self._app_transport is not None: - self._app_transport._closed = True - - if self._state != SSLProtocolState.DO_HANDSHAKE: - if ( - self._app_state == AppProtocolState.STATE_CON_MADE or - self._app_state == AppProtocolState.STATE_EOF - ): - self._app_state = AppProtocolState.STATE_CON_LOST - self._loop.call_soon(self._app_protocol.connection_lost, exc) - self._set_state(SSLProtocolState.UNWRAPPED) + if self._session_established: + self._session_established = False + self._loop.call_soon(self._app_protocol.connection_lost, exc) + else: + # Most likely an exception occurred while in SSL handshake. + # Just mark the app transport as closed so that its __del__ + # doesn't complain. + if self._app_transport is not None: + self._app_transport._closed = True self._transport = None self._app_transport = None - self._app_protocol = None - self._wakeup_waiter(exc) - - if self._shutdown_timeout_handle: - self._shutdown_timeout_handle.cancel() - self._shutdown_timeout_handle = None - if self._handshake_timeout_handle: + if getattr(self, '_handshake_timeout_handle', None): self._handshake_timeout_handle.cancel() - self._handshake_timeout_handle = None + self._wakeup_waiter(exc) + self._app_protocol = None + self._sslpipe = None - def get_buffer(self, n): - want = n - if want <= 0 or want > self.max_size: - want = self.max_size - if len(self._ssl_buffer) < want: - self._ssl_buffer = bytearray(want) - self._ssl_buffer_view = memoryview(self._ssl_buffer) - return self._ssl_buffer_view + def pause_writing(self): + """Called when the low-level transport's buffer goes over + the high-water mark. + """ + self._app_protocol.pause_writing() - def buffer_updated(self, nbytes): - self._incoming.write(self._ssl_buffer_view[:nbytes]) + def resume_writing(self): + """Called when the low-level transport's buffer drains below + the low-water mark. + """ + self._app_protocol.resume_writing() - if self._state == SSLProtocolState.DO_HANDSHAKE: - self._do_handshake() + def data_received(self, data): + """Called when some SSL data is received. - elif self._state == SSLProtocolState.WRAPPED: - self._do_read() + The argument is a bytes object. + """ + if self._sslpipe is None: + # transport closing, sslpipe is destroyed + return - elif self._state == SSLProtocolState.FLUSHING: - self._do_flush() + try: + ssldata, appdata = self._sslpipe.feed_ssldata(data) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as e: + self._fatal_error(e, 'SSL error in data received') + return - elif self._state == SSLProtocolState.SHUTDOWN: - self._do_shutdown() + for chunk in ssldata: + self._transport.write(chunk) + + for chunk in appdata: + if chunk: + try: + if self._app_protocol_is_buffer: + protocols._feed_data_to_buffered_proto( + self._app_protocol, chunk) + else: + self._app_protocol.data_received(chunk) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as ex: + self._fatal_error( + ex, 'application protocol failed to receive SSL data') + return + else: + self._start_shutdown() + break def eof_received(self): """Called when the other end of the low-level stream @@ -448,32 +561,19 @@ def eof_received(self): will close itself. If it returns a true value, closing the transport is up to the protocol. """ - self._eof_received = True try: if self._loop.get_debug(): logger.debug("%r received EOF", self) - if self._state == SSLProtocolState.DO_HANDSHAKE: - self._on_handshake_complete(ConnectionResetError) + self._wakeup_waiter(ConnectionResetError) - elif self._state == SSLProtocolState.WRAPPED: - self._set_state(SSLProtocolState.FLUSHING) - if self._app_reading_paused: - return True - else: - self._do_flush() - - elif self._state == SSLProtocolState.FLUSHING: - self._do_write() - self._set_state(SSLProtocolState.SHUTDOWN) - self._do_shutdown() - - elif self._state == SSLProtocolState.SHUTDOWN: - self._do_shutdown() - - except Exception: + if not self._in_handshake: + keep_open = self._app_protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + finally: self._transport.close() - raise def _get_extra_info(self, name, default=None): if name in self._extra: @@ -483,45 +583,19 @@ def _get_extra_info(self, name, default=None): else: return default - def _set_state(self, new_state): - allowed = False - - if new_state == SSLProtocolState.UNWRAPPED: - allowed = True - - elif ( - self._state == SSLProtocolState.UNWRAPPED and - new_state == SSLProtocolState.DO_HANDSHAKE - ): - allowed = True - - elif ( - self._state == SSLProtocolState.DO_HANDSHAKE and - new_state == SSLProtocolState.WRAPPED - ): - allowed = True - - elif ( - self._state == SSLProtocolState.WRAPPED and - new_state == SSLProtocolState.FLUSHING - ): - allowed = True - - elif ( - self._state == SSLProtocolState.FLUSHING and - new_state == SSLProtocolState.SHUTDOWN - ): - allowed = True - - if allowed: - self._state = new_state - + def _start_shutdown(self): + if self._in_shutdown: + return + if self._in_handshake: + self._abort() else: - raise RuntimeError( - 'cannot switch state from {} to {}'.format( - self._state, new_state)) + self._in_shutdown = True + self._write_appdata(b'') - # Handshake flow + def _write_appdata(self, data): + self._write_backlog.append((data, 0)) + self._write_buffer_size += len(data) + self._process_write_backlog() def _start_handshake(self): if self._loop.get_debug(): @@ -529,18 +603,17 @@ def _start_handshake(self): self._handshake_start_time = self._loop.time() else: self._handshake_start_time = None - - self._set_state(SSLProtocolState.DO_HANDSHAKE) - - # start handshake timeout count down + self._in_handshake = True + # (b'', 1) is a special value in _process_write_backlog() to do + # the SSL handshake + self._write_backlog.append((b'', 1)) self._handshake_timeout_handle = \ self._loop.call_later(self._ssl_handshake_timeout, - lambda: self._check_handshake_timeout()) - - self._do_handshake() + self._check_handshake_timeout) + self._process_write_backlog() def _check_handshake_timeout(self): - if self._state == SSLProtocolState.DO_HANDSHAKE: + if self._in_handshake is True: msg = ( f"SSL handshake is taking longer than " f"{self._ssl_handshake_timeout} seconds: " @@ -548,37 +621,24 @@ def _check_handshake_timeout(self): ) self._fatal_error(ConnectionAbortedError(msg)) - def _do_handshake(self): - try: - self._sslobj.do_handshake() - except SSLAgainErrors: - self._process_outgoing() - except ssl.SSLError as exc: - self._on_handshake_complete(exc) - else: - self._on_handshake_complete(None) - def _on_handshake_complete(self, handshake_exc): - if self._handshake_timeout_handle is not None: - self._handshake_timeout_handle.cancel() - self._handshake_timeout_handle = None + self._in_handshake = False + self._handshake_timeout_handle.cancel() - sslobj = self._sslobj + sslobj = self._sslpipe.ssl_object try: - if handshake_exc is None: - self._set_state(SSLProtocolState.WRAPPED) - else: + if handshake_exc is not None: raise handshake_exc peercert = sslobj.getpeercert() - except Exception as exc: - self._set_state(SSLProtocolState.UNWRAPPED) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: if isinstance(exc, ssl.CertificateError): msg = 'SSL handshake failed on verifying the certificate' else: msg = 'SSL handshake failed' self._fatal_error(exc, msg) - self._wakeup_waiter(exc) return if self._loop.get_debug(): @@ -589,330 +649,85 @@ def _on_handshake_complete(self, handshake_exc): self._extra.update(peercert=peercert, cipher=sslobj.cipher(), compression=sslobj.compression(), - ssl_object=sslobj) - if self._app_state == AppProtocolState.STATE_INIT: - self._app_state = AppProtocolState.STATE_CON_MADE - self._app_protocol.connection_made(self._get_app_transport()) + ssl_object=sslobj, + ) + if self._call_connection_made: + self._app_protocol.connection_made(self._app_transport) self._wakeup_waiter() - self._do_read() + self._session_established = True + # In case transport.write() was already called. Don't call + # immediately _process_write_backlog(), but schedule it: + # _on_handshake_complete() can be called indirectly from + # _process_write_backlog(), and _process_write_backlog() is not + # reentrant. + self._loop.call_soon(self._process_write_backlog) - # Shutdown flow - - def _start_shutdown(self): - if ( - self._state in ( - SSLProtocolState.FLUSHING, - SSLProtocolState.SHUTDOWN, - SSLProtocolState.UNWRAPPED - ) - ): - return - if self._app_transport is not None: - self._app_transport._closed = True - if self._state == SSLProtocolState.DO_HANDSHAKE: - self._abort() - else: - self._set_state(SSLProtocolState.FLUSHING) - self._shutdown_timeout_handle = self._loop.call_later( - self._ssl_shutdown_timeout, - lambda: self._check_shutdown_timeout() - ) - self._do_flush() - - def _check_shutdown_timeout(self): - if ( - self._state in ( - SSLProtocolState.FLUSHING, - SSLProtocolState.SHUTDOWN - ) - ): - self._transport._force_close( - exceptions.TimeoutError('SSL shutdown timed out')) - - def _do_flush(self): - self._do_read() - self._set_state(SSLProtocolState.SHUTDOWN) - self._do_shutdown() - - def _do_shutdown(self): - try: - if not self._eof_received: - self._sslobj.unwrap() - except SSLAgainErrors: - self._process_outgoing() - except ssl.SSLError as exc: - self._on_shutdown_complete(exc) - else: - self._process_outgoing() - self._call_eof_received() - self._on_shutdown_complete(None) - - def _on_shutdown_complete(self, shutdown_exc): - if self._shutdown_timeout_handle is not None: - self._shutdown_timeout_handle.cancel() - self._shutdown_timeout_handle = None - - if shutdown_exc: - self._fatal_error(shutdown_exc) - else: - self._loop.call_soon(self._transport.close) - - def _abort(self): - self._set_state(SSLProtocolState.UNWRAPPED) - if self._transport is not None: - self._transport.abort() - - # Outgoing flow - - def _write_appdata(self, list_of_data): - if ( - self._state in ( - SSLProtocolState.FLUSHING, - SSLProtocolState.SHUTDOWN, - SSLProtocolState.UNWRAPPED - ) - ): - if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - logger.warning('SSL connection is closed') - self._conn_lost += 1 + def _process_write_backlog(self): + # Try to make progress on the write backlog. + if self._transport is None or self._sslpipe is None: return - for data in list_of_data: - self._write_backlog.append(data) - self._write_buffer_size += len(data) - try: - if self._state == SSLProtocolState.WRAPPED: - self._do_write() - - except Exception as ex: - self._fatal_error(ex, 'Fatal error on SSL protocol') - - def _do_write(self): - try: - while self._write_backlog: - data = self._write_backlog[0] - count = self._sslobj.write(data) - data_len = len(data) - if count < data_len: - self._write_backlog[0] = data[count:] - self._write_buffer_size -= count + for i in range(len(self._write_backlog)): + data, offset = self._write_backlog[0] + if data: + ssldata, offset = self._sslpipe.feed_appdata(data, offset) + elif offset: + ssldata = self._sslpipe.do_handshake( + self._on_handshake_complete) + offset = 1 else: - del self._write_backlog[0] - self._write_buffer_size -= data_len - except SSLAgainErrors: - pass - self._process_outgoing() + ssldata = self._sslpipe.shutdown(self._finalize) + offset = 1 - def _process_outgoing(self): - if not self._ssl_writing_paused: - data = self._outgoing.read() - if len(data): - self._transport.write(data) - self._control_app_writing() + for chunk in ssldata: + self._transport.write(chunk) - # Incoming flow - - def _do_read(self): - if ( - self._state not in ( - SSLProtocolState.WRAPPED, - SSLProtocolState.FLUSHING, - ) - ): - return - try: - if not self._app_reading_paused: - if self._app_protocol_is_buffer: - self._do_read__buffered() - else: - self._do_read__copied() - if self._write_backlog: - self._do_write() - else: - self._process_outgoing() - self._control_ssl_reading() - except Exception as ex: - self._fatal_error(ex, 'Fatal error on SSL protocol') - - def _do_read__buffered(self): - offset = 0 - count = 1 - - buf = self._app_protocol_get_buffer(self._get_read_buffer_size()) - wants = len(buf) - - try: - count = self._sslobj.read(wants, buf) - - if count > 0: - offset = count - while offset < wants: - count = self._sslobj.read(wants - offset, buf[offset:]) - if count > 0: - offset += count - else: - break - else: - self._loop.call_soon(lambda: self._do_read()) - except SSLAgainErrors: - pass - if offset > 0: - self._app_protocol_buffer_updated(offset) - if not count: - # close_notify - self._call_eof_received() - self._start_shutdown() - - def _do_read__copied(self): - chunk = b'1' - zero = True - one = False - - try: - while True: - chunk = self._sslobj.read(self.max_size) - if not chunk: + if offset < len(data): + self._write_backlog[0] = (data, offset) + # A short write means that a write is blocked on a read + # We need to enable reading if it is paused! + assert self._sslpipe.need_ssldata + if self._transport._paused: + self._transport.resume_reading() break - if zero: - zero = False - one = True - first = chunk - elif one: - one = False - data = [first, chunk] - else: - data.append(chunk) - except SSLAgainErrors: - pass - if one: - self._app_protocol.data_received(first) - elif not zero: - self._app_protocol.data_received(b''.join(data)) - if not chunk: - # close_notify - self._call_eof_received() - self._start_shutdown() - def _call_eof_received(self): - try: - if self._app_state == AppProtocolState.STATE_CON_MADE: - self._app_state = AppProtocolState.STATE_EOF - keep_open = self._app_protocol.eof_received() - if keep_open: - logger.warning('returning true from eof_received() ' - 'has no effect when using ssl') - except (KeyboardInterrupt, SystemExit): + # An entire chunk from the backlog was processed. We can + # delete it and reduce the outstanding buffer size. + del self._write_backlog[0] + self._write_buffer_size -= len(data) + except (SystemExit, KeyboardInterrupt): raise - except BaseException as ex: - self._fatal_error(ex, 'Error calling eof_received()') - - # Flow control for writes from APP socket - - def _control_app_writing(self): - size = self._get_write_buffer_size() - if size >= self._outgoing_high_water and not self._app_writing_paused: - self._app_writing_paused = True - try: - self._app_protocol.pause_writing() - except (KeyboardInterrupt, SystemExit): - raise - except BaseException as exc: - self._loop.call_exception_handler({ - 'message': 'protocol.pause_writing() failed', - 'exception': exc, - 'transport': self._app_transport, - 'protocol': self, - }) - elif size <= self._outgoing_low_water and self._app_writing_paused: - self._app_writing_paused = False - try: - self._app_protocol.resume_writing() - except (KeyboardInterrupt, SystemExit): - raise - except BaseException as exc: - self._loop.call_exception_handler({ - 'message': 'protocol.resume_writing() failed', - 'exception': exc, - 'transport': self._app_transport, - 'protocol': self, - }) - - def _get_write_buffer_size(self): - return self._outgoing.pending + self._write_buffer_size - - def _set_write_buffer_limits(self, high=None, low=None): - high, low = add_flowcontrol_defaults( - high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE) - self._outgoing_high_water = high - self._outgoing_low_water = low - - # Flow control for reads to APP socket - - def _pause_reading(self): - self._app_reading_paused = True - - def _resume_reading(self): - if self._app_reading_paused: - self._app_reading_paused = False - - def resume(): - if self._state == SSLProtocolState.WRAPPED: - self._do_read() - elif self._state == SSLProtocolState.FLUSHING: - self._do_flush() - elif self._state == SSLProtocolState.SHUTDOWN: - self._do_shutdown() - self._loop.call_soon(resume) - - # Flow control for reads from SSL socket - - def _control_ssl_reading(self): - size = self._get_read_buffer_size() - if size >= self._incoming_high_water and not self._ssl_reading_paused: - self._ssl_reading_paused = True - self._transport.pause_reading() - elif size <= self._incoming_low_water and self._ssl_reading_paused: - self._ssl_reading_paused = False - self._transport.resume_reading() - - def _set_read_buffer_limits(self, high=None, low=None): - high, low = add_flowcontrol_defaults( - high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ) - self._incoming_high_water = high - self._incoming_low_water = low - - def _get_read_buffer_size(self): - return self._incoming.pending - - # Flow control for writes to SSL socket - - def pause_writing(self): - """Called when the low-level transport's buffer goes over - the high-water mark. - """ - assert not self._ssl_writing_paused - self._ssl_writing_paused = True - - def resume_writing(self): - """Called when the low-level transport's buffer drains below - the low-water mark. - """ - assert self._ssl_writing_paused - self._ssl_writing_paused = False - self._process_outgoing() + except BaseException as exc: + if self._in_handshake: + # Exceptions will be re-raised in _on_handshake_complete. + self._on_handshake_complete(exc) + else: + self._fatal_error(exc, 'Fatal error on SSL transport') def _fatal_error(self, exc, message='Fatal error on transport'): - if self._transport: - self._transport._force_close(exc) - if isinstance(exc, OSError): if self._loop.get_debug(): logger.debug("%r: %s", self, message, exc_info=True) - elif not isinstance(exc, exceptions.CancelledError): + else: self._loop.call_exception_handler({ 'message': message, 'exception': exc, 'transport': self._transport, 'protocol': self, }) + if self._transport: + self._transport._force_close(exc) + + def _finalize(self): + self._sslpipe = None + + if self._transport is not None: + self._transport.close() + + def _abort(self): + try: + if self._transport is not None: + self._transport.abort() + finally: + self._finalize() diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index 181e1885152..a55b3a375fa 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -229,8 +229,7 @@ async def create_unix_connection( self, protocol_factory, path=None, *, ssl=None, sock=None, server_hostname=None, - ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_handshake_timeout=None): assert server_hostname is None or isinstance(server_hostname, str) if ssl: if server_hostname is None: @@ -242,9 +241,6 @@ async def create_unix_connection( if ssl_handshake_timeout is not None: raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') - if ssl_shutdown_timeout is not None: - raise ValueError( - 'ssl_shutdown_timeout is only meaningful with ssl') if path is not None: if sock is not None: @@ -271,15 +267,13 @@ async def create_unix_connection( transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, server_hostname, - ssl_handshake_timeout=ssl_handshake_timeout, - ssl_shutdown_timeout=ssl_shutdown_timeout) + ssl_handshake_timeout=ssl_handshake_timeout) return transport, protocol async def create_unix_server( self, protocol_factory, path=None, *, sock=None, backlog=100, ssl=None, ssl_handshake_timeout=None, - ssl_shutdown_timeout=None, start_serving=True): if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') @@ -288,10 +282,6 @@ async def create_unix_server( raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') - if ssl_shutdown_timeout is not None and not ssl: - raise ValueError( - 'ssl_shutdown_timeout is only meaningful with ssl') - if path is not None: if sock is not None: raise ValueError( @@ -338,8 +328,7 @@ async def create_unix_server( sock.setblocking(False) server = base_events.Server(self, [sock], protocol_factory, - ssl, backlog, ssl_handshake_timeout, - ssl_shutdown_timeout) + ssl, backlog, ssl_handshake_timeout) if start_serving: server._start_serving() # Skip one loop iteration so that all 'loop.add_reader' diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index be5ea1e3c73..5691d4250ac 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -1437,51 +1437,44 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport ANY = mock.ANY handshake_timeout = object() - shutdown_timeout = object() # First try the default server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( MyProto, 'python.org', 80, ssl=True, - ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_handshake_timeout=handshake_timeout) transport, _ = self.loop.run_until_complete(coro) transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, server_hostname='python.org', - ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_handshake_timeout=handshake_timeout) # Next try an explicit server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( MyProto, 'python.org', 80, ssl=True, server_hostname='perl.com', - ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_handshake_timeout=handshake_timeout) transport, _ = self.loop.run_until_complete(coro) transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, server_hostname='perl.com', - ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_handshake_timeout=handshake_timeout) # Finally try an explicit empty server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection( MyProto, 'python.org', 80, ssl=True, server_hostname='', - ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_handshake_timeout=handshake_timeout) transport, _ = self.loop.run_until_complete(coro) transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, server_hostname='', - ssl_handshake_timeout=handshake_timeout, - ssl_shutdown_timeout=shutdown_timeout) + ssl_handshake_timeout=handshake_timeout) def test_create_connection_no_ssl_server_hostname_errors(self): # When not using ssl, server_hostname must be None. @@ -1888,7 +1881,7 @@ def test_accept_connection_exception(self, m_log): constants.ACCEPT_RETRY_DELAY, # self.loop._start_serving mock.ANY, - MyProto, sock, None, None, mock.ANY, mock.ANY, mock.ANY) + MyProto, sock, None, None, mock.ANY, mock.ANY) def test_call_coroutine(self): with self.assertWarns(DeprecationWarning): diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index 349e4f2dca0..1613c753c26 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -70,6 +70,44 @@ def test_make_socket_transport(self): close_transport(transport) + @unittest.skipIf(ssl is None, 'No ssl module') + def test_make_ssl_transport(self): + m = mock.Mock() + self.loop._add_reader = mock.Mock() + self.loop._add_reader._is_coroutine = False + self.loop._add_writer = mock.Mock() + self.loop._remove_reader = mock.Mock() + self.loop._remove_writer = mock.Mock() + waiter = self.loop.create_future() + with test_utils.disable_logger(): + transport = self.loop._make_ssl_transport( + m, asyncio.Protocol(), m, waiter) + + with self.assertRaisesRegex(RuntimeError, + r'SSL transport.*not.*initialized'): + transport.is_reading() + + # execute the handshake while the logger is disabled + # to ignore SSL handshake failure + test_utils.run_briefly(self.loop) + + self.assertTrue(transport.is_reading()) + transport.pause_reading() + transport.pause_reading() + self.assertFalse(transport.is_reading()) + transport.resume_reading() + transport.resume_reading() + self.assertTrue(transport.is_reading()) + + # Sanity check + class_name = transport.__class__.__name__ + self.assertIn("ssl", class_name.lower()) + self.assertIn("transport", class_name.lower()) + + transport.close() + # execute pending callbacks to close the socket transport + test_utils.run_briefly(self.loop) + @mock.patch('asyncio.selector_events.ssl', None) @mock.patch('asyncio.sslproto.ssl', None) def test_make_ssl_transport_without_ssl_error(self): diff --git a/Lib/test/test_asyncio/test_ssl.py b/Lib/test/test_asyncio/test_ssl.py deleted file mode 100644 index 9cdd281221c..00000000000 --- a/Lib/test/test_asyncio/test_ssl.py +++ /dev/null @@ -1,1723 +0,0 @@ -import asyncio -import asyncio.sslproto -import contextlib -import gc -import logging -import select -import socket -import tempfile -import threading -import time -import weakref -import unittest - -try: - import ssl -except ImportError: - ssl = None - -from test import support -from test.test_asyncio import utils as test_utils - - -def tearDownModule(): - asyncio.set_event_loop_policy(None) - - -class MyBaseProto(asyncio.Protocol): - connected = None - done = None - - def __init__(self, loop=None): - self.transport = None - self.state = 'INITIAL' - self.nbytes = 0 - if loop is not None: - self.connected = asyncio.Future(loop=loop) - self.done = asyncio.Future(loop=loop) - - def connection_made(self, transport): - self.transport = transport - assert self.state == 'INITIAL', self.state - self.state = 'CONNECTED' - if self.connected: - self.connected.set_result(None) - - def data_received(self, data): - assert self.state == 'CONNECTED', self.state - self.nbytes += len(data) - - def eof_received(self): - assert self.state == 'CONNECTED', self.state - self.state = 'EOF' - - def connection_lost(self, exc): - assert self.state in ('CONNECTED', 'EOF'), self.state - self.state = 'CLOSED' - if self.done: - self.done.set_result(None) - - -@unittest.skipIf(ssl is None, 'No ssl module') -class TestSSL(test_utils.TestCase): - - PAYLOAD_SIZE = 1024 * 100 - TIMEOUT = 60 - - def setUp(self): - super().setUp() - self.loop = asyncio.new_event_loop() - self.set_event_loop(self.loop) - self.addCleanup(self.loop.close) - - def tearDown(self): - # just in case if we have transport close callbacks - if not self.loop.is_closed(): - test_utils.run_briefly(self.loop) - - self.doCleanups() - support.gc_collect() - super().tearDown() - - def tcp_server(self, server_prog, *, - family=socket.AF_INET, - addr=None, - timeout=5, - backlog=1, - max_clients=10): - - if addr is None: - if family == getattr(socket, "AF_UNIX", None): - with tempfile.NamedTemporaryFile() as tmp: - addr = tmp.name - else: - addr = ('127.0.0.1', 0) - - sock = socket.socket(family, socket.SOCK_STREAM) - - if timeout is None: - raise RuntimeError('timeout is required') - if timeout <= 0: - raise RuntimeError('only blocking sockets are supported') - sock.settimeout(timeout) - - try: - sock.bind(addr) - sock.listen(backlog) - except OSError as ex: - sock.close() - raise ex - - return TestThreadedServer( - self, sock, server_prog, timeout, max_clients) - - def tcp_client(self, client_prog, - family=socket.AF_INET, - timeout=10): - - sock = socket.socket(family, socket.SOCK_STREAM) - - if timeout is None: - raise RuntimeError('timeout is required') - if timeout <= 0: - raise RuntimeError('only blocking sockets are supported') - sock.settimeout(timeout) - - return TestThreadedClient( - self, sock, client_prog, timeout) - - def unix_server(self, *args, **kwargs): - return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs) - - def unix_client(self, *args, **kwargs): - return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs) - - def _create_server_ssl_context(self, certfile, keyfile=None): - sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.load_cert_chain(certfile, keyfile) - return sslcontext - - def _create_client_ssl_context(self, *, disable_verify=True): - sslcontext = ssl.create_default_context() - sslcontext.check_hostname = False - if disable_verify: - sslcontext.verify_mode = ssl.CERT_NONE - return sslcontext - - @contextlib.contextmanager - def _silence_eof_received_warning(self): - # TODO This warning has to be fixed in asyncio. - logger = logging.getLogger('asyncio') - filter = logging.Filter('has no effect when using ssl') - logger.addFilter(filter) - try: - yield - finally: - logger.removeFilter(filter) - - def _abort_socket_test(self, ex): - try: - self.loop.stop() - finally: - self.fail(ex) - - def new_loop(self): - return asyncio.new_event_loop() - - def new_policy(self): - return asyncio.DefaultEventLoopPolicy() - - async def wait_closed(self, obj): - if not isinstance(obj, asyncio.StreamWriter): - return - try: - await obj.wait_closed() - except (BrokenPipeError, ConnectionError): - pass - - def test_create_server_ssl_1(self): - CNT = 0 # number of clients that were successful - TOTAL_CNT = 25 # total number of clients that test will create - TIMEOUT = 60.0 # timeout for this test - - A_DATA = b'A' * 1024 * 1024 - B_DATA = b'B' * 1024 * 1024 - - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, test_utils.ONLYKEY - ) - client_sslctx = self._create_client_ssl_context() - - clients = [] - - async def handle_client(reader, writer): - nonlocal CNT - - data = await reader.readexactly(len(A_DATA)) - self.assertEqual(data, A_DATA) - writer.write(b'OK') - - data = await reader.readexactly(len(B_DATA)) - self.assertEqual(data, B_DATA) - writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) - - await writer.drain() - writer.close() - - CNT += 1 - - async def test_client(addr): - fut = asyncio.Future() - - def prog(sock): - try: - sock.starttls(client_sslctx) - sock.connect(addr) - sock.send(A_DATA) - - data = sock.recv_all(2) - self.assertEqual(data, b'OK') - - sock.send(B_DATA) - data = sock.recv_all(4) - self.assertEqual(data, b'SPAM') - - sock.close() - - except Exception as ex: - self.loop.call_soon_threadsafe(fut.set_exception, ex) - else: - self.loop.call_soon_threadsafe(fut.set_result, None) - - client = self.tcp_client(prog) - client.start() - clients.append(client) - - await fut - - async def start_server(): - extras = {} - extras = dict(ssl_handshake_timeout=40.0) - - srv = await asyncio.start_server( - handle_client, - '127.0.0.1', 0, - family=socket.AF_INET, - ssl=sslctx, - **extras) - - try: - srv_socks = srv.sockets - self.assertTrue(srv_socks) - - addr = srv_socks[0].getsockname() - - tasks = [] - for _ in range(TOTAL_CNT): - tasks.append(test_client(addr)) - - await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) - - finally: - self.loop.call_soon(srv.close) - await srv.wait_closed() - - with self._silence_eof_received_warning(): - self.loop.run_until_complete(start_server()) - - self.assertEqual(CNT, TOTAL_CNT) - - for client in clients: - client.stop() - - def test_create_connection_ssl_1(self): - self.loop.set_exception_handler(None) - - CNT = 0 - TOTAL_CNT = 25 - - A_DATA = b'A' * 1024 * 1024 - B_DATA = b'B' * 1024 * 1024 - - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, - test_utils.ONLYKEY - ) - client_sslctx = self._create_client_ssl_context() - - def server(sock): - sock.starttls( - sslctx, - server_side=True) - - data = sock.recv_all(len(A_DATA)) - self.assertEqual(data, A_DATA) - sock.send(b'OK') - - data = sock.recv_all(len(B_DATA)) - self.assertEqual(data, B_DATA) - sock.send(b'SPAM') - - sock.close() - - async def client(addr): - extras = {} - extras = dict(ssl_handshake_timeout=40.0) - - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - **extras) - - writer.write(A_DATA) - self.assertEqual(await reader.readexactly(2), b'OK') - - writer.write(B_DATA) - self.assertEqual(await reader.readexactly(4), b'SPAM') - - nonlocal CNT - CNT += 1 - - writer.close() - await self.wait_closed(writer) - - async def client_sock(addr): - sock = socket.socket() - sock.connect(addr) - reader, writer = await asyncio.open_connection( - sock=sock, - ssl=client_sslctx, - server_hostname='') - - writer.write(A_DATA) - self.assertEqual(await reader.readexactly(2), b'OK') - - writer.write(B_DATA) - self.assertEqual(await reader.readexactly(4), b'SPAM') - - nonlocal CNT - CNT += 1 - - writer.close() - await self.wait_closed(writer) - sock.close() - - def run(coro): - nonlocal CNT - CNT = 0 - - async def _gather(*tasks): - # trampoline - return await asyncio.gather(*tasks) - - with self.tcp_server(server, - max_clients=TOTAL_CNT, - backlog=TOTAL_CNT) as srv: - tasks = [] - for _ in range(TOTAL_CNT): - tasks.append(coro(srv.addr)) - - self.loop.run_until_complete(_gather(*tasks)) - - self.assertEqual(CNT, TOTAL_CNT) - - with self._silence_eof_received_warning(): - run(client) - - with self._silence_eof_received_warning(): - run(client_sock) - - def test_create_connection_ssl_slow_handshake(self): - client_sslctx = self._create_client_ssl_context() - - # silence error logger - self.loop.set_exception_handler(lambda *args: None) - - def server(sock): - try: - sock.recv_all(1024 * 1024) - except ConnectionAbortedError: - pass - finally: - sock.close() - - async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - ssl_handshake_timeout=1.0) - writer.close() - await self.wait_closed(writer) - - with self.tcp_server(server, - max_clients=1, - backlog=1) as srv: - - with self.assertRaisesRegex( - ConnectionAbortedError, - r'SSL handshake.*is taking longer'): - - self.loop.run_until_complete(client(srv.addr)) - - def test_create_connection_ssl_failed_certificate(self): - # silence error logger - self.loop.set_exception_handler(lambda *args: None) - - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, - test_utils.ONLYKEY - ) - client_sslctx = self._create_client_ssl_context(disable_verify=False) - - def server(sock): - try: - sock.starttls( - sslctx, - server_side=True) - sock.connect() - except (ssl.SSLError, OSError): - pass - finally: - sock.close() - - async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - ssl_handshake_timeout=1.0) - writer.close() - await self.wait_closed(writer) - - with self.tcp_server(server, - max_clients=1, - backlog=1) as srv: - - with self.assertRaises(ssl.SSLCertVerificationError): - self.loop.run_until_complete(client(srv.addr)) - - def test_ssl_handshake_timeout(self): - # bpo-29970: Check that a connection is aborted if handshake is not - # completed in timeout period, instead of remaining open indefinitely - client_sslctx = test_utils.simple_client_sslcontext() - - # silence error logger - messages = [] - self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - - server_side_aborted = False - - def server(sock): - nonlocal server_side_aborted - try: - sock.recv_all(1024 * 1024) - except ConnectionAbortedError: - server_side_aborted = True - finally: - sock.close() - - async def client(addr): - await asyncio.wait_for( - self.loop.create_connection( - asyncio.Protocol, - *addr, - ssl=client_sslctx, - server_hostname='', - ssl_handshake_timeout=10.0), - 0.5) - - with self.tcp_server(server, - max_clients=1, - backlog=1) as srv: - - with self.assertRaises(asyncio.TimeoutError): - self.loop.run_until_complete(client(srv.addr)) - - self.assertTrue(server_side_aborted) - - # Python issue #23197: cancelling a handshake must not raise an - # exception or log an error, even if the handshake failed - self.assertEqual(messages, []) - - def test_ssl_handshake_connection_lost(self): - # #246: make sure that no connection_lost() is called before - # connection_made() is called first - - client_sslctx = test_utils.simple_client_sslcontext() - - # silence error logger - self.loop.set_exception_handler(lambda loop, ctx: None) - - connection_made_called = False - connection_lost_called = False - - def server(sock): - sock.recv(1024) - # break the connection during handshake - sock.close() - - class ClientProto(asyncio.Protocol): - def connection_made(self, transport): - nonlocal connection_made_called - connection_made_called = True - - def connection_lost(self, exc): - nonlocal connection_lost_called - connection_lost_called = True - - async def client(addr): - await self.loop.create_connection( - ClientProto, - *addr, - ssl=client_sslctx, - server_hostname=''), - - with self.tcp_server(server, - max_clients=1, - backlog=1) as srv: - - with self.assertRaises(ConnectionResetError): - self.loop.run_until_complete(client(srv.addr)) - - if connection_lost_called: - if connection_made_called: - self.fail("unexpected call to connection_lost()") - else: - self.fail("unexpected call to connection_lost() without" - "calling connection_made()") - elif connection_made_called: - self.fail("unexpected call to connection_made()") - - def test_ssl_connect_accepted_socket(self): - proto = ssl.PROTOCOL_TLS_SERVER - server_context = ssl.SSLContext(proto) - server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY) - if hasattr(server_context, 'check_hostname'): - server_context.check_hostname = False - server_context.verify_mode = ssl.CERT_NONE - - client_context = ssl.SSLContext(proto) - if hasattr(server_context, 'check_hostname'): - client_context.check_hostname = False - client_context.verify_mode = ssl.CERT_NONE - - def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None): - loop = self.loop - - class MyProto(MyBaseProto): - - def connection_lost(self, exc): - super().connection_lost(exc) - loop.call_soon(loop.stop) - - def data_received(self, data): - super().data_received(data) - self.transport.write(expected_response) - - lsock = socket.socket(socket.AF_INET) - lsock.bind(('127.0.0.1', 0)) - lsock.listen(1) - addr = lsock.getsockname() - - message = b'test data' - response = None - expected_response = b'roger' - - def client(): - nonlocal response - try: - csock = socket.socket(socket.AF_INET) - if client_ssl is not None: - csock = client_ssl.wrap_socket(csock) - csock.connect(addr) - csock.sendall(message) - response = csock.recv(99) - csock.close() - except Exception as exc: - print( - "Failure in client thread in test_connect_accepted_socket", - exc) - - thread = threading.Thread(target=client, daemon=True) - thread.start() - - conn, _ = lsock.accept() - proto = MyProto(loop=loop) - proto.loop = loop - - extras = {} - if server_ssl: - extras = dict(ssl_handshake_timeout=10.0) - - f = loop.create_task( - loop.connect_accepted_socket( - (lambda: proto), conn, ssl=server_ssl, - **extras)) - loop.run_forever() - conn.close() - lsock.close() - - thread.join(1) - self.assertFalse(thread.is_alive()) - self.assertEqual(proto.state, 'CLOSED') - self.assertEqual(proto.nbytes, len(message)) - self.assertEqual(response, expected_response) - tr, _ = f.result() - - if server_ssl: - self.assertIn('SSL', tr.__class__.__name__) - - tr.close() - # let it close - self.loop.run_until_complete(asyncio.sleep(0.1)) - - def test_start_tls_client_corrupted_ssl(self): - self.loop.set_exception_handler(lambda loop, ctx: None) - - sslctx = test_utils.simple_server_sslcontext() - client_sslctx = test_utils.simple_client_sslcontext() - - def server(sock): - orig_sock = sock.dup() - try: - sock.starttls( - sslctx, - server_side=True) - sock.sendall(b'A\n') - sock.recv_all(1) - orig_sock.send(b'please corrupt the SSL connection') - except ssl.SSLError: - pass - finally: - sock.close() - orig_sock.close() - - async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='') - - self.assertEqual(await reader.readline(), b'A\n') - writer.write(b'B') - with self.assertRaises(ssl.SSLError): - await reader.readline() - writer.close() - try: - await self.wait_closed(writer) - except ssl.SSLError: - pass - return 'OK' - - with self.tcp_server(server, - max_clients=1, - backlog=1) as srv: - - res = self.loop.run_until_complete(client(srv.addr)) - - self.assertEqual(res, 'OK') - - def test_start_tls_client_reg_proto_1(self): - HELLO_MSG = b'1' * self.PAYLOAD_SIZE - - server_context = test_utils.simple_server_sslcontext() - client_context = test_utils.simple_client_sslcontext() - - def serve(sock): - sock.settimeout(self.TIMEOUT) - - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.starttls(server_context, server_side=True) - - sock.sendall(b'O') - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.unwrap() - sock.close() - - class ClientProto(asyncio.Protocol): - def __init__(self, on_data, on_eof): - self.on_data = on_data - self.on_eof = on_eof - self.con_made_cnt = 0 - - def connection_made(proto, tr): - proto.con_made_cnt += 1 - # Ensure connection_made gets called only once. - self.assertEqual(proto.con_made_cnt, 1) - - def data_received(self, data): - self.on_data.set_result(data) - - def eof_received(self): - self.on_eof.set_result(True) - - async def client(addr): - await asyncio.sleep(0.5) - - on_data = self.loop.create_future() - on_eof = self.loop.create_future() - - tr, proto = await self.loop.create_connection( - lambda: ClientProto(on_data, on_eof), *addr) - - tr.write(HELLO_MSG) - new_tr = await self.loop.start_tls(tr, proto, client_context) - - self.assertEqual(await on_data, b'O') - new_tr.write(HELLO_MSG) - await on_eof - - new_tr.close() - - with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: - self.loop.run_until_complete( - asyncio.wait_for(client(srv.addr), timeout=10)) - - def test_create_connection_memory_leak(self): - HELLO_MSG = b'1' * self.PAYLOAD_SIZE - - server_context = self._create_server_ssl_context( - test_utils.ONLYCERT, test_utils.ONLYKEY) - client_context = self._create_client_ssl_context() - - def serve(sock): - sock.settimeout(self.TIMEOUT) - - sock.starttls(server_context, server_side=True) - - sock.sendall(b'O') - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.unwrap() - sock.close() - - class ClientProto(asyncio.Protocol): - def __init__(self, on_data, on_eof): - self.on_data = on_data - self.on_eof = on_eof - self.con_made_cnt = 0 - - def connection_made(proto, tr): - # XXX: We assume user stores the transport in protocol - proto.tr = tr - proto.con_made_cnt += 1 - # Ensure connection_made gets called only once. - self.assertEqual(proto.con_made_cnt, 1) - - def data_received(self, data): - self.on_data.set_result(data) - - def eof_received(self): - self.on_eof.set_result(True) - - async def client(addr): - await asyncio.sleep(0.5) - - on_data = self.loop.create_future() - on_eof = self.loop.create_future() - - tr, proto = await self.loop.create_connection( - lambda: ClientProto(on_data, on_eof), *addr, - ssl=client_context) - - self.assertEqual(await on_data, b'O') - tr.write(HELLO_MSG) - await on_eof - - tr.close() - - with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: - self.loop.run_until_complete( - asyncio.wait_for(client(srv.addr), timeout=10)) - - # No garbage is left for SSL client from loop.create_connection, even - # if user stores the SSLTransport in corresponding protocol instance - client_context = weakref.ref(client_context) - self.assertIsNone(client_context()) - - def test_start_tls_client_buf_proto_1(self): - HELLO_MSG = b'1' * self.PAYLOAD_SIZE - - server_context = test_utils.simple_server_sslcontext() - client_context = test_utils.simple_client_sslcontext() - - client_con_made_calls = 0 - - def serve(sock): - sock.settimeout(self.TIMEOUT) - - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.starttls(server_context, server_side=True) - - sock.sendall(b'O') - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.sendall(b'2') - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.unwrap() - sock.close() - - class ClientProtoFirst(asyncio.BufferedProtocol): - def __init__(self, on_data): - self.on_data = on_data - self.buf = bytearray(1) - - def connection_made(self, tr): - nonlocal client_con_made_calls - client_con_made_calls += 1 - - def get_buffer(self, sizehint): - return self.buf - - def buffer_updated(self, nsize): - assert nsize == 1 - self.on_data.set_result(bytes(self.buf[:nsize])) - - def eof_received(self): - pass - - class ClientProtoSecond(asyncio.Protocol): - def __init__(self, on_data, on_eof): - self.on_data = on_data - self.on_eof = on_eof - self.con_made_cnt = 0 - - def connection_made(self, tr): - nonlocal client_con_made_calls - client_con_made_calls += 1 - - def data_received(self, data): - self.on_data.set_result(data) - - def eof_received(self): - self.on_eof.set_result(True) - - async def client(addr): - await asyncio.sleep(0.5) - - on_data1 = self.loop.create_future() - on_data2 = self.loop.create_future() - on_eof = self.loop.create_future() - - tr, proto = await self.loop.create_connection( - lambda: ClientProtoFirst(on_data1), *addr) - - tr.write(HELLO_MSG) - new_tr = await self.loop.start_tls(tr, proto, client_context) - - self.assertEqual(await on_data1, b'O') - new_tr.write(HELLO_MSG) - - new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof)) - self.assertEqual(await on_data2, b'2') - new_tr.write(HELLO_MSG) - await on_eof - - new_tr.close() - - # connection_made() should be called only once -- when - # we establish connection for the first time. Start TLS - # doesn't call connection_made() on application protocols. - self.assertEqual(client_con_made_calls, 1) - - with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: - self.loop.run_until_complete( - asyncio.wait_for(client(srv.addr), - timeout=self.TIMEOUT)) - - def test_start_tls_slow_client_cancel(self): - HELLO_MSG = b'1' * self.PAYLOAD_SIZE - - client_context = test_utils.simple_client_sslcontext() - server_waits_on_handshake = self.loop.create_future() - - def serve(sock): - sock.settimeout(self.TIMEOUT) - - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - try: - self.loop.call_soon_threadsafe( - server_waits_on_handshake.set_result, None) - data = sock.recv_all(1024 * 1024) - except ConnectionAbortedError: - pass - finally: - sock.close() - - class ClientProto(asyncio.Protocol): - def __init__(self, on_data, on_eof): - self.on_data = on_data - self.on_eof = on_eof - self.con_made_cnt = 0 - - def connection_made(proto, tr): - proto.con_made_cnt += 1 - # Ensure connection_made gets called only once. - self.assertEqual(proto.con_made_cnt, 1) - - def data_received(self, data): - self.on_data.set_result(data) - - def eof_received(self): - self.on_eof.set_result(True) - - async def client(addr): - await asyncio.sleep(0.5) - - on_data = self.loop.create_future() - on_eof = self.loop.create_future() - - tr, proto = await self.loop.create_connection( - lambda: ClientProto(on_data, on_eof), *addr) - - tr.write(HELLO_MSG) - - await server_waits_on_handshake - - with self.assertRaises(asyncio.TimeoutError): - await asyncio.wait_for( - self.loop.start_tls(tr, proto, client_context), - 0.5) - - with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: - self.loop.run_until_complete( - asyncio.wait_for(client(srv.addr), timeout=10)) - - def test_start_tls_server_1(self): - HELLO_MSG = b'1' * self.PAYLOAD_SIZE - - server_context = test_utils.simple_server_sslcontext() - client_context = test_utils.simple_client_sslcontext() - - def client(sock, addr): - sock.settimeout(self.TIMEOUT) - - sock.connect(addr) - data = sock.recv_all(len(HELLO_MSG)) - self.assertEqual(len(data), len(HELLO_MSG)) - - sock.starttls(client_context) - sock.sendall(HELLO_MSG) - - sock.unwrap() - sock.close() - - class ServerProto(asyncio.Protocol): - def __init__(self, on_con, on_eof, on_con_lost): - self.on_con = on_con - self.on_eof = on_eof - self.on_con_lost = on_con_lost - self.data = b'' - - def connection_made(self, tr): - self.on_con.set_result(tr) - - def data_received(self, data): - self.data += data - - def eof_received(self): - self.on_eof.set_result(1) - - def connection_lost(self, exc): - if exc is None: - self.on_con_lost.set_result(None) - else: - self.on_con_lost.set_exception(exc) - - async def main(proto, on_con, on_eof, on_con_lost): - tr = await on_con - tr.write(HELLO_MSG) - - self.assertEqual(proto.data, b'') - - new_tr = await self.loop.start_tls( - tr, proto, server_context, - server_side=True, - ssl_handshake_timeout=self.TIMEOUT) - - await on_eof - await on_con_lost - self.assertEqual(proto.data, HELLO_MSG) - new_tr.close() - - async def run_main(): - on_con = self.loop.create_future() - on_eof = self.loop.create_future() - on_con_lost = self.loop.create_future() - proto = ServerProto(on_con, on_eof, on_con_lost) - - server = await self.loop.create_server( - lambda: proto, '127.0.0.1', 0) - addr = server.sockets[0].getsockname() - - with self.tcp_client(lambda sock: client(sock, addr), - timeout=self.TIMEOUT): - await asyncio.wait_for( - main(proto, on_con, on_eof, on_con_lost), - timeout=self.TIMEOUT) - - server.close() - await server.wait_closed() - - self.loop.run_until_complete(run_main()) - - def test_create_server_ssl_over_ssl(self): - CNT = 0 # number of clients that were successful - TOTAL_CNT = 25 # total number of clients that test will create - TIMEOUT = 10.0 # timeout for this test - - A_DATA = b'A' * 1024 * 1024 - B_DATA = b'B' * 1024 * 1024 - - sslctx_1 = self._create_server_ssl_context( - test_utils.ONLYCERT, test_utils.ONLYKEY) - client_sslctx_1 = self._create_client_ssl_context() - sslctx_2 = self._create_server_ssl_context( - test_utils.ONLYCERT, test_utils.ONLYKEY) - client_sslctx_2 = self._create_client_ssl_context() - - clients = [] - - async def handle_client(reader, writer): - nonlocal CNT - - data = await reader.readexactly(len(A_DATA)) - self.assertEqual(data, A_DATA) - writer.write(b'OK') - - data = await reader.readexactly(len(B_DATA)) - self.assertEqual(data, B_DATA) - writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) - - await writer.drain() - writer.close() - - CNT += 1 - - class ServerProtocol(asyncio.StreamReaderProtocol): - def connection_made(self, transport): - super_ = super() - transport.pause_reading() - fut = self._loop.create_task(self._loop.start_tls( - transport, self, sslctx_2, server_side=True)) - - def cb(_): - try: - tr = fut.result() - except Exception as ex: - super_.connection_lost(ex) - else: - super_.connection_made(tr) - fut.add_done_callback(cb) - - def server_protocol_factory(): - reader = asyncio.StreamReader() - protocol = ServerProtocol(reader, handle_client) - return protocol - - async def test_client(addr): - fut = asyncio.Future() - - def prog(sock): - try: - sock.connect(addr) - sock.starttls(client_sslctx_1) - - # because wrap_socket() doesn't work correctly on - # SSLSocket, we have to do the 2nd level SSL manually - incoming = ssl.MemoryBIO() - outgoing = ssl.MemoryBIO() - sslobj = client_sslctx_2.wrap_bio(incoming, outgoing) - - def do(func, *args): - while True: - try: - rv = func(*args) - break - except ssl.SSLWantReadError: - if outgoing.pending: - sock.send(outgoing.read()) - incoming.write(sock.recv(65536)) - if outgoing.pending: - sock.send(outgoing.read()) - return rv - - do(sslobj.do_handshake) - - do(sslobj.write, A_DATA) - data = do(sslobj.read, 2) - self.assertEqual(data, b'OK') - - do(sslobj.write, B_DATA) - data = b'' - while True: - chunk = do(sslobj.read, 4) - if not chunk: - break - data += chunk - self.assertEqual(data, b'SPAM') - - do(sslobj.unwrap) - sock.close() - - except Exception as ex: - self.loop.call_soon_threadsafe(fut.set_exception, ex) - sock.close() - else: - self.loop.call_soon_threadsafe(fut.set_result, None) - - client = self.tcp_client(prog) - client.start() - clients.append(client) - - await fut - - async def start_server(): - extras = {} - - srv = await self.loop.create_server( - server_protocol_factory, - '127.0.0.1', 0, - family=socket.AF_INET, - ssl=sslctx_1, - **extras) - - try: - srv_socks = srv.sockets - self.assertTrue(srv_socks) - - addr = srv_socks[0].getsockname() - - tasks = [] - for _ in range(TOTAL_CNT): - tasks.append(test_client(addr)) - - await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) - - finally: - self.loop.call_soon(srv.close) - await srv.wait_closed() - - with self._silence_eof_received_warning(): - self.loop.run_until_complete(start_server()) - - self.assertEqual(CNT, TOTAL_CNT) - - for client in clients: - client.stop() - - def test_shutdown_cleanly(self): - CNT = 0 - TOTAL_CNT = 25 - - A_DATA = b'A' * 1024 * 1024 - - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, test_utils.ONLYKEY) - client_sslctx = self._create_client_ssl_context() - - def server(sock): - sock.starttls( - sslctx, - server_side=True) - - data = sock.recv_all(len(A_DATA)) - self.assertEqual(data, A_DATA) - sock.send(b'OK') - - sock.unwrap() - - sock.close() - - async def client(addr): - extras = {} - extras = dict(ssl_handshake_timeout=10.0) - - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='', - **extras) - - writer.write(A_DATA) - self.assertEqual(await reader.readexactly(2), b'OK') - - self.assertEqual(await reader.read(), b'') - - nonlocal CNT - CNT += 1 - - writer.close() - await self.wait_closed(writer) - - def run(coro): - nonlocal CNT - CNT = 0 - - async def _gather(*tasks): - return await asyncio.gather(*tasks) - - with self.tcp_server(server, - max_clients=TOTAL_CNT, - backlog=TOTAL_CNT) as srv: - tasks = [] - for _ in range(TOTAL_CNT): - tasks.append(coro(srv.addr)) - - self.loop.run_until_complete( - _gather(*tasks)) - - self.assertEqual(CNT, TOTAL_CNT) - - with self._silence_eof_received_warning(): - run(client) - - def test_flush_before_shutdown(self): - CHUNK = 1024 * 128 - SIZE = 32 - - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, test_utils.ONLYKEY) - client_sslctx = self._create_client_ssl_context() - if hasattr(ssl, 'OP_NO_TLSv1_3'): - client_sslctx.options |= ssl.OP_NO_TLSv1_3 - - future = None - - def server(sock): - sock.starttls(sslctx, server_side=True) - self.assertEqual(sock.recv_all(4), b'ping') - sock.send(b'pong') - time.sleep(0.5) # hopefully stuck the TCP buffer - data = sock.recv_all(CHUNK * SIZE) - self.assertEqual(len(data), CHUNK * SIZE) - sock.close() - - def run(meth): - def wrapper(sock): - try: - meth(sock) - except Exception as ex: - self.loop.call_soon_threadsafe(future.set_exception, ex) - else: - self.loop.call_soon_threadsafe(future.set_result, None) - return wrapper - - async def client(addr): - nonlocal future - future = self.loop.create_future() - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='') - sslprotocol = writer.transport._ssl_protocol - writer.write(b'ping') - data = await reader.readexactly(4) - self.assertEqual(data, b'pong') - - sslprotocol.pause_writing() - for _ in range(SIZE): - writer.write(b'x' * CHUNK) - - writer.close() - sslprotocol.resume_writing() - - await self.wait_closed(writer) - try: - data = await reader.read() - self.assertEqual(data, b'') - except ConnectionResetError: - pass - await future - - with self.tcp_server(run(server)) as srv: - self.loop.run_until_complete(client(srv.addr)) - - def test_remote_shutdown_receives_trailing_data(self): - CHUNK = 1024 * 128 - SIZE = 32 - - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, - test_utils.ONLYKEY - ) - client_sslctx = self._create_client_ssl_context() - future = None - - def server(sock): - incoming = ssl.MemoryBIO() - outgoing = ssl.MemoryBIO() - sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) - - while True: - try: - sslobj.do_handshake() - except ssl.SSLWantReadError: - if outgoing.pending: - sock.send(outgoing.read()) - incoming.write(sock.recv(16384)) - else: - if outgoing.pending: - sock.send(outgoing.read()) - break - - while True: - try: - data = sslobj.read(4) - except ssl.SSLWantReadError: - incoming.write(sock.recv(16384)) - else: - break - - self.assertEqual(data, b'ping') - sslobj.write(b'pong') - sock.send(outgoing.read()) - - time.sleep(0.2) # wait for the peer to fill its backlog - - # send close_notify but don't wait for response - with self.assertRaises(ssl.SSLWantReadError): - sslobj.unwrap() - sock.send(outgoing.read()) - - # should receive all data - data_len = 0 - while True: - try: - chunk = len(sslobj.read(16384)) - data_len += chunk - except ssl.SSLWantReadError: - incoming.write(sock.recv(16384)) - except ssl.SSLZeroReturnError: - break - - self.assertEqual(data_len, CHUNK * SIZE) - - # verify that close_notify is received - sslobj.unwrap() - - sock.close() - - def eof_server(sock): - sock.starttls(sslctx, server_side=True) - self.assertEqual(sock.recv_all(4), b'ping') - sock.send(b'pong') - - time.sleep(0.2) # wait for the peer to fill its backlog - - # send EOF - sock.shutdown(socket.SHUT_WR) - - # should receive all data - data = sock.recv_all(CHUNK * SIZE) - self.assertEqual(len(data), CHUNK * SIZE) - - sock.close() - - async def client(addr): - nonlocal future - future = self.loop.create_future() - - reader, writer = await asyncio.open_connection( - *addr, - ssl=client_sslctx, - server_hostname='') - writer.write(b'ping') - data = await reader.readexactly(4) - self.assertEqual(data, b'pong') - - # fill write backlog in a hacky way - renegotiation won't help - for _ in range(SIZE): - writer.transport._test__append_write_backlog(b'x' * CHUNK) - - try: - data = await reader.read() - self.assertEqual(data, b'') - except (BrokenPipeError, ConnectionResetError): - pass - - await future - - writer.close() - await self.wait_closed(writer) - - def run(meth): - def wrapper(sock): - try: - meth(sock) - except Exception as ex: - self.loop.call_soon_threadsafe(future.set_exception, ex) - else: - self.loop.call_soon_threadsafe(future.set_result, None) - return wrapper - - with self.tcp_server(run(server)) as srv: - self.loop.run_until_complete(client(srv.addr)) - - with self.tcp_server(run(eof_server)) as srv: - self.loop.run_until_complete(client(srv.addr)) - - def test_connect_timeout_warning(self): - s = socket.socket(socket.AF_INET) - s.bind(('127.0.0.1', 0)) - addr = s.getsockname() - - async def test(): - try: - await asyncio.wait_for( - self.loop.create_connection(asyncio.Protocol, - *addr, ssl=True), - 0.1) - except (ConnectionRefusedError, asyncio.TimeoutError): - pass - else: - self.fail('TimeoutError is not raised') - - with s: - try: - with self.assertWarns(ResourceWarning) as cm: - self.loop.run_until_complete(test()) - gc.collect() - gc.collect() - gc.collect() - except AssertionError as e: - self.assertEqual(str(e), 'ResourceWarning not triggered') - else: - self.fail('Unexpected ResourceWarning: {}'.format(cm.warning)) - - def test_handshake_timeout_handler_leak(self): - s = socket.socket(socket.AF_INET) - s.bind(('127.0.0.1', 0)) - s.listen(1) - addr = s.getsockname() - - async def test(ctx): - try: - await asyncio.wait_for( - self.loop.create_connection(asyncio.Protocol, *addr, - ssl=ctx), - 0.1) - except (ConnectionRefusedError, asyncio.TimeoutError): - pass - else: - self.fail('TimeoutError is not raised') - - with s: - ctx = ssl.create_default_context() - self.loop.run_until_complete(test(ctx)) - ctx = weakref.ref(ctx) - - # SSLProtocol should be DECREF to 0 - self.assertIsNone(ctx()) - - def test_shutdown_timeout_handler_leak(self): - loop = self.loop - - def server(sock): - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, - test_utils.ONLYKEY - ) - sock = sslctx.wrap_socket(sock, server_side=True) - sock.recv(32) - sock.close() - - class Protocol(asyncio.Protocol): - def __init__(self): - self.fut = asyncio.Future(loop=loop) - - def connection_lost(self, exc): - self.fut.set_result(None) - - async def client(addr, ctx): - tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) - tr.close() - await pr.fut - - with self.tcp_server(server) as srv: - ctx = self._create_client_ssl_context() - loop.run_until_complete(client(srv.addr, ctx)) - ctx = weakref.ref(ctx) - - # asyncio has no shutdown timeout, but it ends up with a circular - # reference loop - not ideal (introduces gc glitches), but at least - # not leaking - gc.collect() - gc.collect() - gc.collect() - - # SSLProtocol should be DECREF to 0 - self.assertIsNone(ctx()) - - def test_shutdown_timeout_handler_not_set(self): - loop = self.loop - eof = asyncio.Event() - extra = None - - def server(sock): - sslctx = self._create_server_ssl_context( - test_utils.ONLYCERT, - test_utils.ONLYKEY - ) - sock = sslctx.wrap_socket(sock, server_side=True) - sock.send(b'hello') - assert sock.recv(1024) == b'world' - sock.send(b'extra bytes') - # sending EOF here - sock.shutdown(socket.SHUT_WR) - loop.call_soon_threadsafe(eof.set) - # make sure we have enough time to reproduce the issue - assert sock.recv(1024) == b'' - sock.close() - - class Protocol(asyncio.Protocol): - def __init__(self): - self.fut = asyncio.Future(loop=loop) - self.transport = None - - def connection_made(self, transport): - self.transport = transport - - def data_received(self, data): - if data == b'hello': - self.transport.write(b'world') - # pause reading would make incoming data stay in the sslobj - self.transport.pause_reading() - else: - nonlocal extra - extra = data - - def connection_lost(self, exc): - if exc is None: - self.fut.set_result(None) - else: - self.fut.set_exception(exc) - - async def client(addr): - ctx = self._create_client_ssl_context() - tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) - await eof.wait() - tr.resume_reading() - await pr.fut - tr.close() - assert extra == b'extra bytes' - - with self.tcp_server(server) as srv: - loop.run_until_complete(client(srv.addr)) - - -############################################################################### -# Socket Testing Utilities -############################################################################### - - -class TestSocketWrapper: - - def __init__(self, sock): - self.__sock = sock - - def recv_all(self, n): - buf = b'' - while len(buf) < n: - data = self.recv(n - len(buf)) - if data == b'': - raise ConnectionAbortedError - buf += data - return buf - - def starttls(self, ssl_context, *, - server_side=False, - server_hostname=None, - do_handshake_on_connect=True): - - assert isinstance(ssl_context, ssl.SSLContext) - - ssl_sock = ssl_context.wrap_socket( - self.__sock, server_side=server_side, - server_hostname=server_hostname, - do_handshake_on_connect=do_handshake_on_connect) - - if server_side: - ssl_sock.do_handshake() - - self.__sock.close() - self.__sock = ssl_sock - - def __getattr__(self, name): - return getattr(self.__sock, name) - - def __repr__(self): - return '<{} {!r}>'.format(type(self).__name__, self.__sock) - - -class SocketThread(threading.Thread): - - def stop(self): - self._active = False - self.join() - - def __enter__(self): - self.start() - return self - - def __exit__(self, *exc): - self.stop() - - -class TestThreadedClient(SocketThread): - - def __init__(self, test, sock, prog, timeout): - threading.Thread.__init__(self, None, None, 'test-client') - self.daemon = True - - self._timeout = timeout - self._sock = sock - self._active = True - self._prog = prog - self._test = test - - def run(self): - try: - self._prog(TestSocketWrapper(self._sock)) - except (KeyboardInterrupt, SystemExit): - raise - except BaseException as ex: - self._test._abort_socket_test(ex) - - -class TestThreadedServer(SocketThread): - - def __init__(self, test, sock, prog, timeout, max_clients): - threading.Thread.__init__(self, None, None, 'test-server') - self.daemon = True - - self._clients = 0 - self._finished_clients = 0 - self._max_clients = max_clients - self._timeout = timeout - self._sock = sock - self._active = True - - self._prog = prog - - self._s1, self._s2 = socket.socketpair() - self._s1.setblocking(False) - - self._test = test - - def stop(self): - try: - if self._s2 and self._s2.fileno() != -1: - try: - self._s2.send(b'stop') - except OSError: - pass - finally: - super().stop() - - def run(self): - try: - with self._sock: - self._sock.setblocking(0) - self._run() - finally: - self._s1.close() - self._s2.close() - - def _run(self): - while self._active: - if self._clients >= self._max_clients: - return - - r, w, x = select.select( - [self._sock, self._s1], [], [], self._timeout) - - if self._s1 in r: - return - - if self._sock in r: - try: - conn, addr = self._sock.accept() - except BlockingIOError: - continue - except socket.timeout: - if not self._active: - return - else: - raise - else: - self._clients += 1 - conn.settimeout(self._timeout) - try: - with conn: - self._handle_client(conn) - except (KeyboardInterrupt, SystemExit): - raise - except BaseException as ex: - self._active = False - try: - raise - finally: - self._test._abort_socket_test(ex) - - def _handle_client(self, sock): - self._prog(TestSocketWrapper(sock)) - - @property - def addr(self): - return self._sock.getsockname() diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index 79a81bd8c39..e87863eb712 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -15,6 +15,7 @@ from asyncio import log from asyncio import protocols from asyncio import sslproto +from test import support from test.test_asyncio import utils as test_utils from test.test_asyncio import functional as func_tests @@ -43,13 +44,16 @@ def ssl_protocol(self, *, waiter=None, proto=None): def connection_made(self, ssl_proto, *, do_handshake=None): transport = mock.Mock() - sslobj = mock.Mock() - # emulate reading decompressed data - sslobj.read.side_effect = ssl.SSLWantReadError - if do_handshake is not None: - sslobj.do_handshake = do_handshake - ssl_proto._sslobj = sslobj - ssl_proto.connection_made(transport) + sslpipe = mock.Mock() + sslpipe.shutdown.return_value = b'' + if do_handshake: + sslpipe.do_handshake.side_effect = do_handshake + else: + def mock_handshake(callback): + return [] + sslpipe.do_handshake.side_effect = mock_handshake + with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe): + ssl_proto.connection_made(transport) return transport def test_handshake_timeout_zero(self): @@ -71,10 +75,7 @@ def test_handshake_timeout_negative(self): def test_eof_received_waiter(self): waiter = self.loop.create_future() ssl_proto = self.ssl_protocol(waiter=waiter) - self.connection_made( - ssl_proto, - do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) - ) + self.connection_made(ssl_proto) ssl_proto.eof_received() test_utils.run_briefly(self.loop) self.assertIsInstance(waiter.exception(), ConnectionResetError) @@ -99,10 +100,7 @@ def test_connection_lost(self): # yield from waiter hang if lost_connection was called. waiter = self.loop.create_future() ssl_proto = self.ssl_protocol(waiter=waiter) - self.connection_made( - ssl_proto, - do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) - ) + self.connection_made(ssl_proto) ssl_proto.connection_lost(ConnectionAbortedError) test_utils.run_briefly(self.loop) self.assertIsInstance(waiter.exception(), ConnectionAbortedError) @@ -112,10 +110,7 @@ def test_close_during_handshake(self): waiter = self.loop.create_future() ssl_proto = self.ssl_protocol(waiter=waiter) - transport = self.connection_made( - ssl_proto, - do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError) - ) + transport = self.connection_made(ssl_proto) test_utils.run_briefly(self.loop) ssl_proto._app_transport.close() @@ -148,7 +143,7 @@ def test_data_received_after_closing(self): transp.close() # should not raise - self.assertIsNone(ssl_proto.buffer_updated(5)) + self.assertIsNone(ssl_proto.data_received(b'data')) def test_write_after_closing(self): ssl_proto = self.ssl_protocol() diff --git a/Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst b/Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst deleted file mode 100644 index e2b5a9e395e..00000000000 --- a/Misc/NEWS.d/next/Library/2021-05-02-23-44-21.bpo-44011.hd8iUO.rst +++ /dev/null @@ -1,2 +0,0 @@ -Reimplement SSL/TLS support in asyncio, borrow the impelementation from -uvloop library.