From f6214a46c9822e47acbf0bd1dc94afc02036eb26 Mon Sep 17 00:00:00 2001 From: Abhinav Singh <126065+abhinavsingh@users.noreply.github.com> Date: Tue, 28 Dec 2021 13:51:20 +0530 Subject: [PATCH] Move `UpstreamConnectionPool` lifecycle within `Threadless` (#917) * Tie connection pool into Threadless * Pass upstream conn pool reference to work instances * Mark upstream conn pool as optional * spellcheck * Fix unused import --- docs/conf.py | 1 + proxy/core/acceptor/executors.py | 1 + proxy/core/acceptor/threadless.py | 6 +++++- proxy/core/acceptor/work.py | 7 ++++++- proxy/core/connection/__init__.py | 4 ++-- proxy/core/connection/pool.py | 19 ++++++++++++++----- proxy/http/handler.py | 1 + proxy/http/plugin.py | 7 ++++++- proxy/http/proxy/server.py | 18 +++++++++--------- proxy/plugin/proxy_pool.py | 2 +- tests/core/test_acceptor.py | 1 + tests/core/test_conn_pool.py | 6 +++--- 12 files changed, 50 insertions(+), 23 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 8172d57f..402e263b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -300,6 +300,7 @@ nitpick_ignore = [ (_py_class_role, 'unittest.case.TestCase'), (_py_class_role, 'unittest.result.TestResult'), (_py_class_role, 'UUID'), + (_py_class_role, 'UpstreamConnectionPool'), (_py_class_role, 'Url'), (_py_class_role, 'WebsocketFrame'), (_py_class_role, 'Work'), diff --git a/proxy/core/acceptor/executors.py b/proxy/core/acceptor/executors.py index 95127626..d0c5a912 100644 --- a/proxy/core/acceptor/executors.py +++ b/proxy/core/acceptor/executors.py @@ -131,6 +131,7 @@ class ThreadlessPool: TcpClientConnection(conn, addr), flags=flags, event_queue=event_queue, + upstream_conn_pool=None, ) # TODO: Keep reference to threads and join during shutdown. # This will ensure connections are not abruptly closed on shutdown diff --git a/proxy/core/acceptor/threadless.py b/proxy/core/acceptor/threadless.py index 5140c934..9858712a 100644 --- a/proxy/core/acceptor/threadless.py +++ b/proxy/core/acceptor/threadless.py @@ -25,7 +25,7 @@ from ...common.types import Readables, Writables from ...common.constants import DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT, DEFAULT_SELECTOR_SELECT_TIMEOUT from ...common.constants import DEFAULT_WAIT_FOR_TASKS_TIMEOUT -from ..connection import TcpClientConnection +from ..connection import TcpClientConnection, UpstreamConnectionPool from ..event import eventNames, EventQueue from .work import Work @@ -87,6 +87,9 @@ class Threadless(ABC, Generic[T]): self.wait_timeout: float = DEFAULT_WAIT_FOR_TASKS_TIMEOUT self.cleanup_inactive_timeout: float = DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT self._total: int = 0 + self._upstream_conn_pool: Optional[UpstreamConnectionPool] = None + if self.flags.enable_conn_pool: + self._upstream_conn_pool = UpstreamConnectionPool() @property @abstractmethod @@ -134,6 +137,7 @@ class Threadless(ABC, Generic[T]): flags=self.flags, event_queue=self.event_queue, uid=uid, + upstream_conn_pool=self._upstream_conn_pool, ) self.works[fileno].publish_event( event_name=eventNames.WORK_STARTED, diff --git a/proxy/core/acceptor/work.py b/proxy/core/acceptor/work.py index 5a7ba072..37f0bf59 100644 --- a/proxy/core/acceptor/work.py +++ b/proxy/core/acceptor/work.py @@ -16,11 +16,14 @@ import argparse from abc import ABC, abstractmethod from uuid import uuid4 -from typing import Optional, Dict, Any, TypeVar, Generic +from typing import Optional, Dict, Any, TypeVar, Generic, TYPE_CHECKING from ..event import eventNames, EventQueue from ...common.types import Readables, Writables +if TYPE_CHECKING: + from ..connection import UpstreamConnectionPool + T = TypeVar('T') @@ -33,6 +36,7 @@ class Work(ABC, Generic[T]): flags: argparse.Namespace, event_queue: Optional[EventQueue] = None, uid: Optional[str] = None, + upstream_conn_pool: Optional['UpstreamConnectionPool'] = None, ) -> None: # Work uuid self.uid: str = uid if uid is not None else uuid4().hex @@ -41,6 +45,7 @@ class Work(ABC, Generic[T]): self.event_queue = event_queue # Accept work self.work = work + self.upstream_conn_pool = upstream_conn_pool @abstractmethod async def get_events(self) -> Dict[int, int]: diff --git a/proxy/core/connection/__init__.py b/proxy/core/connection/__init__.py index 952ee08f..58d100a8 100644 --- a/proxy/core/connection/__init__.py +++ b/proxy/core/connection/__init__.py @@ -16,7 +16,7 @@ from .connection import TcpConnection, TcpConnectionUninitializedException from .client import TcpClientConnection from .server import TcpServerConnection -from .pool import ConnectionPool +from .pool import UpstreamConnectionPool from .types import tcpConnectionTypes __all__ = [ @@ -25,5 +25,5 @@ __all__ = [ 'TcpServerConnection', 'TcpClientConnection', 'tcpConnectionTypes', - 'ConnectionPool', + 'UpstreamConnectionPool', ] diff --git a/proxy/core/connection/pool.py b/proxy/core/connection/pool.py index 16cd5096..5f92066b 100644 --- a/proxy/core/connection/pool.py +++ b/proxy/core/connection/pool.py @@ -17,6 +17,9 @@ import logging from typing import Set, Dict, Tuple from ...common.flag import flags +from ...common.types import Readables, Writables + +from ..acceptor.work import Work from .server import TcpServerConnection @@ -31,10 +34,10 @@ flags.add_argument( ) -class ConnectionPool: +class UpstreamConnectionPool(Work[TcpServerConnection]): """Manages connection pool to upstream servers. - `ConnectionPool` avoids need to reconnect with the upstream + `UpstreamConnectionPool` avoids need to reconnect with the upstream servers repeatedly when a reusable connection is available in the pool. @@ -47,16 +50,16 @@ class ConnectionPool: the pool users. Example, if acquired connection is stale, reacquire. - TODO: Ideally, ConnectionPool must be shared across + TODO: Ideally, `UpstreamConnectionPool` must be shared across all cores to make SSL session cache to also work without additional out-of-bound synchronizations. - TODO: ConnectionPool currently WON'T work for + TODO: `UpstreamConnectionPool` currently WON'T work for HTTPS connection. This is because of missing support for session cache, session ticket, abbr TLS handshake and other necessary features to make it work. - NOTE: However, for all HTTP only connections, ConnectionPool + NOTE: However, for all HTTP only connections, `UpstreamConnectionPool` can be used to save upon connection setup time and speed-up performance of requests. """ @@ -113,3 +116,9 @@ class ConnectionPool: assert not conn.is_reusable() # Reset for reusability conn.reset() + + async def get_events(self) -> Dict[int, int]: + return await super().get_events() + + async def handle_events(self, readables: Readables, writables: Writables) -> bool: + return await super().handle_events(readables, writables) diff --git a/proxy/http/handler.py b/proxy/http/handler.py index 38429df7..ae6c0d66 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -100,6 +100,7 @@ class HttpProtocolHandler(BaseTcpServerHandler): self.work, self.request, self.event_queue, + self.upstream_conn_pool, ) self.plugins[instance.name()] = instance logger.debug('Handling connection %r' % self.work.connection) diff --git a/proxy/http/plugin.py b/proxy/http/plugin.py index eafcd053..0180e5f9 100644 --- a/proxy/http/plugin.py +++ b/proxy/http/plugin.py @@ -12,7 +12,7 @@ import socket import argparse from abc import ABC, abstractmethod -from typing import Tuple, List, Union, Optional +from typing import Tuple, List, Union, Optional, TYPE_CHECKING from .parser import HttpParser @@ -20,6 +20,9 @@ from ..common.types import Readables, Writables from ..core.event import EventQueue from ..core.connection import TcpClientConnection +if TYPE_CHECKING: + from ..core.connection import UpstreamConnectionPool + class HttpProtocolHandlerPlugin(ABC): """Base HttpProtocolHandler Plugin class. @@ -50,12 +53,14 @@ class HttpProtocolHandlerPlugin(ABC): client: TcpClientConnection, request: HttpParser, event_queue: EventQueue, + upstream_conn_pool: Optional['UpstreamConnectionPool'] = None, ): self.uid: str = uid self.flags: argparse.Namespace = flags self.client: TcpClientConnection = client self.request: HttpParser = request self.event_queue = event_queue + self.upstream_conn_pool = upstream_conn_pool super().__init__() def name(self) -> str: diff --git a/proxy/http/proxy/server.py b/proxy/http/proxy/server.py index 55604abf..228d8ecc 100644 --- a/proxy/http/proxy/server.py +++ b/proxy/http/proxy/server.py @@ -44,7 +44,7 @@ from ...common.utils import text_ from ...common.pki import gen_public_key, gen_csr, sign_csr from ...core.event import eventNames -from ...core.connection import TcpServerConnection, ConnectionPool +from ...core.connection import TcpServerConnection from ...core.connection import TcpConnectionUninitializedException from ...common.flag import flags @@ -140,9 +140,6 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): # connection pool operations. lock = threading.Lock() - # Shared connection pool - pool = ConnectionPool() - def __init__( self, *args: Any, **kwargs: Any, @@ -200,10 +197,10 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): def _close_and_release(self) -> bool: if self.flags.enable_conn_pool: - assert self.upstream and not self.upstream.closed + assert self.upstream and not self.upstream.closed and self.upstream_conn_pool self.upstream.closed = True with self.lock: - self.pool.release(self.upstream) + self.upstream_conn_pool.release(self.upstream) self.upstream = None return True @@ -391,9 +388,10 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): return if self.flags.enable_conn_pool: + assert self.upstream_conn_pool # Release the connection for reusability with self.lock: - self.pool.release(self.upstream) + self.upstream_conn_pool.release(self.upstream) return try: @@ -589,8 +587,9 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): host, port = self.request.host, self.request.port if host and port: if self.flags.enable_conn_pool: + assert self.upstream_conn_pool with self.lock: - created, self.upstream = self.pool.acquire( + created, self.upstream = self.upstream_conn_pool.acquire( text_(host), port, ) else: @@ -642,8 +641,9 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin): ), ) if self.flags.enable_conn_pool: + assert self.upstream_conn_pool with self.lock: - self.pool.release(self.upstream) + self.upstream_conn_pool.release(self.upstream) raise ProxyConnectionFailed( text_(host), port, repr(e), ) from e diff --git a/proxy/plugin/proxy_pool.py b/proxy/plugin/proxy_pool.py index cfc80178..641d95d6 100644 --- a/proxy/plugin/proxy_pool.py +++ b/proxy/plugin/proxy_pool.py @@ -88,7 +88,7 @@ class ProxyPoolPlugin(TcpUpstreamConnectionHandler, HttpProxyBasePlugin): must be bootstrapped within it's own re-usable and garbage collected pool, to avoid establishing a new upstream proxy connection for each client request. - See :class:`~proxy.core.connection.pool.ConnectionPool` which is a work + See :class:`~proxy.core.connection.pool.UpstreamConnectionPool` which is a work in progress for SSL cache handling. """ # We don't want to send private IP requests to remote proxies diff --git a/tests/core/test_acceptor.py b/tests/core/test_acceptor.py index 2a4d0898..89bbce46 100644 --- a/tests/core/test_acceptor.py +++ b/tests/core/test_acceptor.py @@ -101,6 +101,7 @@ class TestAcceptor(unittest.TestCase): mock_client.return_value, flags=self.flags, event_queue=None, + upstream_conn_pool=None, ) mock_thread.assert_called_with( target=self.flags.work_klass.return_value.run, diff --git a/tests/core/test_conn_pool.py b/tests/core/test_conn_pool.py index 3eaad052..db3de3d7 100644 --- a/tests/core/test_conn_pool.py +++ b/tests/core/test_conn_pool.py @@ -12,14 +12,14 @@ import unittest from unittest import mock -from proxy.core.connection import ConnectionPool +from proxy.core.connection import UpstreamConnectionPool class TestConnectionPool(unittest.TestCase): @mock.patch('proxy.core.connection.pool.TcpServerConnection') def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: mock.Mock) -> None: - pool = ConnectionPool() + pool = UpstreamConnectionPool() addr = ('localhost', 1234) # Mock mock_conn = mock_tcp_server_connection.return_value @@ -50,7 +50,7 @@ class TestConnectionPool(unittest.TestCase): def test_closed_connections_are_removed_on_release( self, mock_tcp_server_connection: mock.Mock, ) -> None: - pool = ConnectionPool() + pool = UpstreamConnectionPool() addr = ('localhost', 1234) # Mock mock_conn = mock_tcp_server_connection.return_value