From 46c942f9477936a36b77aa6dd505d1f5e8e79550 Mon Sep 17 00:00:00 2001 From: Abhinav Singh <126065+abhinavsingh@users.noreply.github.com> Date: Wed, 29 Dec 2021 17:37:15 +0530 Subject: [PATCH] Hook `UpstreamConnectionPool` lifecycle within `Threadless` (#921) * Hook connection pool lifecycle within threadless * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix test * Fix spell Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- proxy/core/acceptor/threadless.py | 80 +++++++++++++++++++++++-------- proxy/core/connection/pool.py | 78 ++++++++++++++++++++---------- proxy/core/connection/server.py | 10 ++-- proxy/http/handler.py | 11 ++--- proxy/http/parser/parser.py | 2 +- proxy/http/proxy/server.py | 14 +++--- proxy/plugin/proxy_pool.py | 2 +- tests/core/test_conn_pool.py | 15 +++--- tests/core/test_connection.py | 5 +- 9 files changed, 142 insertions(+), 75 deletions(-) diff --git a/proxy/core/acceptor/threadless.py b/proxy/core/acceptor/threadless.py index 9858712a..fa107216 100644 --- a/proxy/core/acceptor/threadless.py +++ b/proxy/core/acceptor/threadless.py @@ -88,6 +88,7 @@ class Threadless(ABC, Generic[T]): self.cleanup_inactive_timeout: float = DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT self._total: int = 0 self._upstream_conn_pool: Optional[UpstreamConnectionPool] = None + self._upstream_conn_filenos: Set[int] = set() if self.flags.enable_conn_pool: self._upstream_conn_pool = UpstreamConnectionPool() @@ -176,14 +177,25 @@ class Threadless(ABC, Generic[T]): data=work_id, ) self.registered_events_by_work_ids[work_id][fileno] = mask - # logger.debug( - # 'fd#{0} modified for mask#{1} by work#{2}'.format( - # fileno, mask, work_id, - # ), - # ) + logger.debug( + 'fd#{0} modified for mask#{1} by work#{2}'.format( + fileno, mask, work_id, + ), + ) # else: # logger.info( # 'fd#{0} by work#{1} not modified'.format(fileno, work_id)) + elif fileno in self._upstream_conn_filenos: + # Descriptor offered by work, but is already registered by connection pool + # Most likely because work has acquired a reusable connection. + self.selector.modify(fileno, events=mask, data=work_id) + self.registered_events_by_work_ids[work_id][fileno] = mask + self._upstream_conn_filenos.remove(fileno) + logger.debug( + 'fd#{0} borrowed with mask#{1} by work#{2}'.format( + fileno, mask, work_id, + ), + ) # Can throw ValueError: Invalid file descriptor: -1 # # A guard within Work classes may not help here due to @@ -193,16 +205,33 @@ class Threadless(ABC, Generic[T]): # # TODO: Also remove offending work from pool to avoid spin loop. 
elif fileno != -1: - self.selector.register( - fileno, events=mask, - data=work_id, - ) + self.selector.register(fileno, events=mask, data=work_id) self.registered_events_by_work_ids[work_id][fileno] = mask - # logger.debug( - # 'fd#{0} registered for mask#{1} by work#{2}'.format( - # fileno, mask, work_id, - # ), - # ) + logger.debug( + 'fd#{0} registered for mask#{1} by work#{2}'.format( + fileno, mask, work_id, + ), + ) + + async def _update_conn_pool_events(self) -> None: + if not self._upstream_conn_pool: + return + assert self.selector is not None + new_conn_pool_events = await self._upstream_conn_pool.get_events() + old_conn_pool_filenos = self._upstream_conn_filenos.copy() + self._upstream_conn_filenos.clear() + new_conn_pool_filenos = set(new_conn_pool_events.keys()) + new_conn_pool_filenos.difference_update(old_conn_pool_filenos) + for fileno in new_conn_pool_filenos: + self.selector.register( + fileno, + events=new_conn_pool_events[fileno], + data=0, + ) + self._upstream_conn_filenos.add(fileno) + old_conn_pool_filenos.difference_update(self._upstream_conn_filenos) + for fileno in old_conn_pool_filenos: + self.selector.unregister(fileno) async def _update_selector(self) -> None: assert self.selector is not None @@ -215,6 +244,7 @@ class Threadless(ABC, Generic[T]): if work_id in unfinished_work_ids: continue await self._update_work_events(work_id) + await self._update_conn_pool_events() async def _selected_events(self) -> Tuple[ Dict[int, Tuple[Readables, Writables]], @@ -235,9 +265,6 @@ class Threadless(ABC, Generic[T]): """ assert self.selector is not None await self._update_selector() - events = self.selector.select( - timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT, - ) # Keys are work_id and values are 2-tuple indicating # readables & writables that work_id is interested in # and are ready for IO. @@ -248,6 +275,11 @@ class Threadless(ABC, Generic[T]): # When ``work_queue_fileno`` returns None, # always return True for the boolean value. new_work_available = True + + events = self.selector.select( + timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT, + ) + for key, mask in events: if not new_work_available and wqfileno is not None and key.fileobj == wqfileno: assert mask & selectors.EVENT_READ @@ -302,9 +334,17 @@ class Threadless(ABC, Generic[T]): assert self.loop tasks: Set['asyncio.Task[bool]'] = set() for work_id in work_by_ids: - task = self.loop.create_task( - self.works[work_id].handle_events(*work_by_ids[work_id]), - ) + if work_id == 0: + assert self._upstream_conn_pool + task = self.loop.create_task( + self._upstream_conn_pool.handle_events( + *work_by_ids[work_id], + ), + ) + else: + task = self.loop.create_task( + self.works[work_id].handle_events(*work_by_ids[work_id]), + ) task._work_id = work_id # type: ignore[attr-defined] # task.set_name(work_id) tasks.add(task) diff --git a/proxy/core/connection/pool.py b/proxy/core/connection/pool.py index a9b0585a..e51ce3fb 100644 --- a/proxy/core/connection/pool.py +++ b/proxy/core/connection/pool.py @@ -12,6 +12,7 @@ reusability """ +import socket import logging import selectors @@ -45,11 +46,19 @@ class UpstreamConnectionPool(Work[TcpServerConnection]): A separate pool is maintained for each upstream server. So internally, it's a pool of pools. - TODO: Listen for read events from the connections - to remove them from the pool when peer closes the - connection. This can also be achieved lazily by - the pool users. Example, if acquired connection - is stale, reacquire. 
+ Internal data structure maintains references to connection objects + that pool owns or has borrowed. Borrowed connections are marked as + NOT reusable. + + For reusable connections only, pool listens for read events + to detect broken connections. This can happen if pool has opened + a connection, which was never used and eventually reaches + upstream server timeout limit. + + When a borrowed connection is returned back to the pool, + the connection is marked as reusable again. However, if + returned connection has already been closed, it is removed + from the internal data structure. TODO: Ideally, `UpstreamConnectionPool` must be shared across all cores to make SSL session cache to also work @@ -60,29 +69,25 @@ class UpstreamConnectionPool(Work[TcpServerConnection]): session cache, session ticket, abbr TLS handshake and other necessary features to make it work. - NOTE: However, for all HTTP only connections, `UpstreamConnectionPool` - can be used to save upon connection setup time and - speed-up performance of requests. + NOTE: However, currently for all HTTP only upstream connections, + `UpstreamConnectionPool` can be used to remove slow starts. """ def __init__(self) -> None: - # Pools of connection per upstream server self.connections: Dict[int, TcpServerConnection] = {} self.pools: Dict[Tuple[str, int], Set[TcpServerConnection]] = {} def add(self, addr: Tuple[str, int]) -> TcpServerConnection: - # Create new connection + """Creates and add a new connection to the pool.""" new_conn = TcpServerConnection(addr[0], addr[1]) new_conn.connect() - if addr not in self.pools: - self.pools[addr] = set() - self.pools[addr].add(new_conn) - self.connections[new_conn.connection.fileno()] = new_conn + self._add(new_conn) return new_conn def acquire(self, addr: Tuple[str, int]) -> Tuple[bool, TcpServerConnection]: - """Returns a connection for use with the server.""" - # Return a reusable connection if available + """Returns a reusable connection from the pool. + + If none exists, will create and return a new connection.""" if addr in self.pools: for old_conn in self.pools[addr]: if old_conn.is_reusable(): @@ -102,40 +107,63 @@ class UpstreamConnectionPool(Work[TcpServerConnection]): return True, new_conn def release(self, conn: TcpServerConnection) -> None: - """Release the connection. + """Release a previously acquired connection. If the connection has not been closed, then it will be retained in the pool for reusability. """ + assert not conn.is_reusable() if conn.closed: logger.debug( 'Removing connection#{2} from pool from upstream {0}:{1}'.format( conn.addr[0], conn.addr[1], id(conn), ), ) - self.pools[conn.addr].remove(conn) + self._remove(conn.connection.fileno()) else: logger.debug( 'Retaining connection#{2} to upstream {0}:{1}'.format( conn.addr[0], conn.addr[1], id(conn), ), ) - assert not conn.is_reusable() # Reset for reusability conn.reset() async def get_events(self) -> Dict[int, int]: + """Returns read event flag for all reusable connections in the pool.""" events = {} for connections in self.pools.values(): for conn in connections: - events[conn.connection.fileno()] = selectors.EVENT_READ + if conn.is_reusable(): + events[conn.connection.fileno()] = selectors.EVENT_READ return events async def handle_events(self, readables: Readables, _writables: Writables) -> bool: - for r in readables: + """Removes reusable connection from the pool. + + When pool is the owner of connection, we don't expect a read event from upstream + server. 
A read event means either upstream closed the connection or connection + has somehow reached an illegal state e.g. upstream sending data for previous + connection acquisition lifecycle.""" + for fileno in readables: if TYPE_CHECKING: - assert isinstance(r, int) - conn = self.connections[r] - self.pools[conn.addr].remove(conn) - del self.connections[r] + assert isinstance(fileno, int) + logger.debug('Upstream fd#{0} is read ready'.format(fileno)) + self._remove(fileno) return False + + def _add(self, conn: TcpServerConnection) -> None: + """Adds a new connection to internal data structure.""" + if conn.addr not in self.pools: + self.pools[conn.addr] = set() + self.pools[conn.addr].add(conn) + self.connections[conn.connection.fileno()] = conn + + def _remove(self, fileno: int) -> None: + """Remove a connection by descriptor from the internal data structure.""" + conn = self.connections[fileno] + logger.debug('Removing conn#{0} from pool'.format(id(conn))) + conn.connection.shutdown(socket.SHUT_WR) + conn.close() + self.pools[conn.addr].remove(conn) + del self.connections[fileno] diff --git a/proxy/core/connection/server.py b/proxy/core/connection/server.py index 2c2af73f..109c238e 100644 --- a/proxy/core/connection/server.py +++ b/proxy/core/connection/server.py @@ -39,11 +39,11 @@ class TcpServerConnection(TcpConnection): addr: Optional[Tuple[str, int]] = None, source_address: Optional[Tuple[str, int]] = None, ) -> None: - if self._conn is None: - self._conn = new_socket_connection( - addr or self.addr, source_address=source_address, - ) - self.closed = False + assert self._conn is None + self._conn = new_socket_connection( + addr or self.addr, source_address=source_address, + ) + self.closed = False def wrap(self, hostname: str, ca_file: Optional[str]) -> None: ctx = ssl.create_default_context( diff --git a/proxy/http/handler.py b/proxy/http/handler.py index ae6c0d66..f46f5944 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -103,7 +103,7 @@ class HttpProtocolHandler(BaseTcpServerHandler): self.upstream_conn_pool, ) self.plugins[instance.name()] = instance - logger.debug('Handling connection %r' % self.work.connection) + logger.debug('Handling connection %s' % self.work.address) def is_inactive(self) -> bool: if not self.work.has_buffer() and \ @@ -123,9 +123,8 @@ class HttpProtocolHandler(BaseTcpServerHandler): for plugin in self.plugins.values(): plugin.on_client_connection_close() logger.debug( - 'Closing client connection %r ' - 'at address %s has buffer %s' % - (self.work.connection, self.work.address, self.work.has_buffer()), + 'Closing client connection %s has buffer %s' % + (self.work.address, self.work.has_buffer()), ) conn = self.work.connection # Unwrap if wrapped before shutdown. 
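
The acquire/release contract introduced above can be read as: a caller asks the pool for a connection to an upstream address, uses it, and hands it back; release() then either resets the connection for reuse or evicts it if it was closed in the meantime. A minimal caller-side sketch follows (not part of this patch; the upstream address and request bytes are illustrative, and it relies on the queue()/flush() helpers that TcpServerConnection inherits from TcpConnection):

    from proxy.core.connection.pool import UpstreamConnectionPool

    pool = UpstreamConnectionPool()
    # acquire() returns (created, conn); created is True when no reusable
    # connection to this upstream existed and a fresh one was opened.
    created, upstream = pool.acquire(('example.com', 80))
    try:
        upstream.queue(memoryview(b'GET / HTTP/1.1\r\nHost: example.com\r\n\r\n'))
        upstream.flush()
    finally:
        # Acquired connections are marked NOT reusable; release() resets them
        # for reuse, or drops them from the pool if they have been closed.
        pool.release(upstream)

Note also that, with this patch, TcpServerConnection.connect() asserts that the underlying socket is not already initialized, so a second connect() on the same object now raises AssertionError instead of being silently ignored (see the updated test further below).
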
@@ -247,7 +246,7 @@ class HttpProtocolHandler(BaseTcpServerHandler): async def handle_writables(self, writables: Writables) -> bool: if self.work.connection.fileno() in writables and self.work.has_buffer(): - logger.debug('Client is ready for writes, flushing buffer') + logger.debug('Client is write ready, flushing...') self.last_activity = time.time() # TODO(abhinavsingh): This hook could just reside within server recv block @@ -277,7 +276,7 @@ class HttpProtocolHandler(BaseTcpServerHandler): async def handle_readables(self, readables: Readables) -> bool: if self.work.connection.fileno() in readables: - logger.debug('Client is ready for reads, reading') + logger.debug('Client is read ready, receiving...') self.last_activity = time.time() try: teardown = await super().handle_readables(readables) diff --git a/proxy/http/parser/parser.py b/proxy/http/parser/parser.py index 5ed70724..197923ca 100644 --- a/proxy/http/parser/parser.py +++ b/proxy/http/parser/parser.py @@ -77,7 +77,7 @@ class HttpParser: self.total_size: int = 0 # Buffer to hold unprocessed bytes self.buffer: bytes = b'' - # Internal headers datastructure: + # Internal headers data structure: # - Keys are lower case header names. # - Values are 2-tuple containing original # header and it's value as received. diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index fdb74d71..e6f66e12 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -217,7 +217,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): self.upstream and not self.upstream.closed and \ self.upstream.has_buffer() and \ self.upstream.connection.fileno() in w: - logger.debug('Server is write ready, flushing buffer') + logger.debug('Server is write ready, flushing...') try: self.upstream.flush() except ssl.SSLWantWriteError: @@ -254,7 +254,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): and self.upstream \ and not self.upstream.closed \ and self.upstream.connection.fileno() in r: - logger.debug('Server is ready for reads, reading...') + logger.debug('Server is read ready, receiving...') try: raw = self.upstream.recv(self.flags.server_recvbuf_size) except TimeoutError as e: @@ -401,7 +401,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): pass finally: # TODO: Unwrap if wrapped before close? - self.upstream.connection.close() + self.upstream.close() except TcpConnectionUninitializedException: pass finally: @@ -587,10 +587,6 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): host, port = self.request.host, self.request.port if host and port: try: - logger.debug( - 'Connecting to upstream %s:%d' % - (text_(host), port), - ) # Invoke plugin.resolve_dns upstream_ip, source_addr = None, None for plugin in self.plugins.values(): @@ -599,6 +595,10 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): ) if upstream_ip or source_addr: break + logger.debug( + 'Connecting to upstream %s:%d' % + (text_(host), port), + ) if self.flags.enable_conn_pool: assert self.upstream_conn_pool with self.lock: diff --git a/proxy/plugin/proxy_pool.py b/proxy/plugin/proxy_pool.py index 641d95d6..1ae13ed1 100644 --- a/proxy/plugin/proxy_pool.py +++ b/proxy/plugin/proxy_pool.py @@ -121,7 +121,7 @@ class ProxyPoolPlugin(TcpUpstreamConnectionHandler, HttpProxyBasePlugin): # # Failing upstream proxies, must be removed from the pool temporarily. # A periodic health check must put them back in the pool. 
This can be achieved - # using a datastructure without having to spawn separate thread/process for health + # using a data structure without having to spawn separate thread/process for health # check. raise HttpProtocolException( 'Connection refused by upstream proxy {0}:{1}'.format( diff --git a/tests/core/test_conn_pool.py b/tests/core/test_conn_pool.py index e00436cd..ae2d2c71 100644 --- a/tests/core/test_conn_pool.py +++ b/tests/core/test_conn_pool.py @@ -23,9 +23,9 @@ class TestConnectionPool(unittest.TestCase): @mock.patch('proxy.core.connection.pool.TcpServerConnection') def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: mock.Mock) -> None: pool = UpstreamConnectionPool() - addr = ('localhost', 1234) # Mock mock_conn = mock_tcp_server_connection.return_value + addr = mock_conn.addr mock_conn.is_reusable.side_effect = [ False, True, True, ] @@ -33,7 +33,7 @@ class TestConnectionPool(unittest.TestCase): # Acquire created, conn = pool.acquire(addr) self.assertTrue(created) - mock_tcp_server_connection.assert_called_once_with(*addr) + mock_tcp_server_connection.assert_called_once_with(addr[0], addr[1]) self.assertEqual(conn, mock_conn) self.assertEqual(len(pool.pools[addr]), 1) self.assertTrue(conn in pool.pools[addr]) @@ -54,26 +54,25 @@ class TestConnectionPool(unittest.TestCase): self, mock_tcp_server_connection: mock.Mock, ) -> None: pool = UpstreamConnectionPool() - addr = ('localhost', 1234) # Mock mock_conn = mock_tcp_server_connection.return_value mock_conn.closed = True - mock_conn.addr = addr + addr = mock_conn.addr # Acquire created, conn = pool.acquire(addr) self.assertTrue(created) - mock_tcp_server_connection.assert_called_once_with(*addr) + mock_tcp_server_connection.assert_called_once_with(addr[0], addr[1]) self.assertEqual(conn, mock_conn) self.assertEqual(len(pool.pools[addr]), 1) self.assertTrue(conn in pool.pools[addr]) # Release + mock_conn.is_reusable.return_value = False pool.release(conn) self.assertEqual(len(pool.pools[addr]), 0) # Acquire created, conn = pool.acquire(addr) self.assertTrue(created) self.assertEqual(mock_tcp_server_connection.call_count, 2) - mock_conn.is_reusable.assert_not_called() class TestConnectionPoolAsync: @@ -84,10 +83,10 @@ class TestConnectionPoolAsync: 'proxy.core.connection.pool.TcpServerConnection', ) pool = UpstreamConnectionPool() - addr = ('localhost', 1234) mock_conn = mock_tcp_server_connection.return_value + addr = mock_conn.addr pool.add(addr) - mock_tcp_server_connection.assert_called_once_with(*addr) + mock_tcp_server_connection.assert_called_once_with(addr[0], addr[1]) mock_conn.connect.assert_called_once() events = await pool.get_events() print(events) diff --git a/tests/core/test_connection.py b/tests/core/test_connection.py index 905ab56d..95bc0006 100644 --- a/tests/core/test_connection.py +++ b/tests/core/test_connection.py @@ -79,7 +79,7 @@ class TestTcpConnection(unittest.TestCase): ) @mock.patch('proxy.core.connection.server.new_socket_connection') - def testTcpServerIgnoresDoubleConnectSilently( + def testTcpServerWillNotIgnoreDoubleConnectAttemptsSilently( self, mock_new_socket_connection: mock.Mock, ) -> None: @@ -87,7 +87,8 @@ class TestTcpConnection(unittest.TestCase): str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT, ) conn.connect() - conn.connect() + with self.assertRaises(AssertionError): + conn.connect() mock_new_socket_connection.assert_called_once() @mock.patch('socket.socket')
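
Tying the two halves together: Threadless now registers every reusable pool descriptor with data=0, so selector events on those descriptors map to work id 0 and are dispatched to UpstreamConnectionPool.handle_events(), which closes and evicts connections that upstream closed (or wrote to) while they sat idle in the pool. The sketch below replays that loop standalone, purely for illustration; in proxy.py itself this logic lives in Threadless._update_conn_pool_events() and _create_tasks() and is only active when flags.enable_conn_pool is set:

    import asyncio
    import selectors
    from typing import Set

    from proxy.core.connection.pool import UpstreamConnectionPool

    async def watch_pool(pool: UpstreamConnectionPool) -> None:
        sel = selectors.DefaultSelector()
        registered: Set[int] = set()
        while True:
            # Register read interest for reusable connections; data=0 marks
            # descriptors owned by the pool rather than by a work instance.
            for fileno, mask in (await pool.get_events()).items():
                if fileno not in registered:
                    sel.register(fileno, events=mask, data=0)
                    registered.add(fileno)
            readables = [
                key.fd for key, mask in sel.select(timeout=0.1)
                if mask & selectors.EVENT_READ
            ]
            if readables:
                # A read event on an idle pooled connection means upstream
                # closed it or sent unexpected data; handle_events() evicts it.
                await pool.handle_events(readables, [])
                for fileno in readables:
                    sel.unregister(fileno)
                    registered.discard(fileno)
            await asyncio.sleep(0)

Borrowed descriptors take the opposite path: when a work acquires a pooled connection, Threadless drops its fileno from the pool-owned set and re-registers it under the work's own id, so readiness events flow to that work until the connection is released back to the pool.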