Move `UpstreamConnectionPool` lifecycle within `Threadless` (#917)

* Tie connection pool into Threadless

* Pass upstream conn pool reference to work instances

* Mark upstream conn pool as optional

* spellcheck

* Fix unused import
This commit is contained in:
Abhinav Singh 2021-12-28 13:51:20 +05:30 committed by GitHub
parent ea66280827
commit f6214a46c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 50 additions and 23 deletions

View File

@ -300,6 +300,7 @@ nitpick_ignore = [
(_py_class_role, 'unittest.case.TestCase'),
(_py_class_role, 'unittest.result.TestResult'),
(_py_class_role, 'UUID'),
(_py_class_role, 'UpstreamConnectionPool'),
(_py_class_role, 'Url'),
(_py_class_role, 'WebsocketFrame'),
(_py_class_role, 'Work'),

View File

@ -131,6 +131,7 @@ class ThreadlessPool:
TcpClientConnection(conn, addr),
flags=flags,
event_queue=event_queue,
upstream_conn_pool=None,
)
# TODO: Keep reference to threads and join during shutdown.
# This will ensure connections are not abruptly closed on shutdown

View File

@ -25,7 +25,7 @@ from ...common.types import Readables, Writables
from ...common.constants import DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT, DEFAULT_SELECTOR_SELECT_TIMEOUT
from ...common.constants import DEFAULT_WAIT_FOR_TASKS_TIMEOUT
from ..connection import TcpClientConnection
from ..connection import TcpClientConnection, UpstreamConnectionPool
from ..event import eventNames, EventQueue
from .work import Work
@ -87,6 +87,9 @@ class Threadless(ABC, Generic[T]):
self.wait_timeout: float = DEFAULT_WAIT_FOR_TASKS_TIMEOUT
self.cleanup_inactive_timeout: float = DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT
self._total: int = 0
self._upstream_conn_pool: Optional[UpstreamConnectionPool] = None
if self.flags.enable_conn_pool:
self._upstream_conn_pool = UpstreamConnectionPool()
@property
@abstractmethod
@ -134,6 +137,7 @@ class Threadless(ABC, Generic[T]):
flags=self.flags,
event_queue=self.event_queue,
uid=uid,
upstream_conn_pool=self._upstream_conn_pool,
)
self.works[fileno].publish_event(
event_name=eventNames.WORK_STARTED,

View File

@ -16,11 +16,14 @@ import argparse
from abc import ABC, abstractmethod
from uuid import uuid4
from typing import Optional, Dict, Any, TypeVar, Generic
from typing import Optional, Dict, Any, TypeVar, Generic, TYPE_CHECKING
from ..event import eventNames, EventQueue
from ...common.types import Readables, Writables
if TYPE_CHECKING:
from ..connection import UpstreamConnectionPool
T = TypeVar('T')
@ -33,6 +36,7 @@ class Work(ABC, Generic[T]):
flags: argparse.Namespace,
event_queue: Optional[EventQueue] = None,
uid: Optional[str] = None,
upstream_conn_pool: Optional['UpstreamConnectionPool'] = None,
) -> None:
# Work uuid
self.uid: str = uid if uid is not None else uuid4().hex
@ -41,6 +45,7 @@ class Work(ABC, Generic[T]):
self.event_queue = event_queue
# Accept work
self.work = work
self.upstream_conn_pool = upstream_conn_pool
@abstractmethod
async def get_events(self) -> Dict[int, int]:

View File

@ -16,7 +16,7 @@
from .connection import TcpConnection, TcpConnectionUninitializedException
from .client import TcpClientConnection
from .server import TcpServerConnection
from .pool import ConnectionPool
from .pool import UpstreamConnectionPool
from .types import tcpConnectionTypes
__all__ = [
@ -25,5 +25,5 @@ __all__ = [
'TcpServerConnection',
'TcpClientConnection',
'tcpConnectionTypes',
'ConnectionPool',
'UpstreamConnectionPool',
]

View File

@ -17,6 +17,9 @@ import logging
from typing import Set, Dict, Tuple
from ...common.flag import flags
from ...common.types import Readables, Writables
from ..acceptor.work import Work
from .server import TcpServerConnection
@ -31,10 +34,10 @@ flags.add_argument(
)
class ConnectionPool:
class UpstreamConnectionPool(Work[TcpServerConnection]):
"""Manages connection pool to upstream servers.
`ConnectionPool` avoids need to reconnect with the upstream
`UpstreamConnectionPool` avoids need to reconnect with the upstream
servers repeatedly when a reusable connection is available
in the pool.
@ -47,16 +50,16 @@ class ConnectionPool:
the pool users. Example, if acquired connection
is stale, reacquire.
TODO: Ideally, ConnectionPool must be shared across
TODO: Ideally, `UpstreamConnectionPool` must be shared across
all cores to make SSL session cache to also work
without additional out-of-bound synchronizations.
TODO: ConnectionPool currently WON'T work for
TODO: `UpstreamConnectionPool` currently WON'T work for
HTTPS connection. This is because of missing support for
session cache, session ticket, abbr TLS handshake
and other necessary features to make it work.
NOTE: However, for all HTTP only connections, ConnectionPool
NOTE: However, for all HTTP only connections, `UpstreamConnectionPool`
can be used to save upon connection setup time and
speed-up performance of requests.
"""
@ -113,3 +116,9 @@ class ConnectionPool:
assert not conn.is_reusable()
# Reset for reusability
conn.reset()
async def get_events(self) -> Dict[int, int]:
return await super().get_events()
async def handle_events(self, readables: Readables, writables: Writables) -> bool:
return await super().handle_events(readables, writables)

View File

@ -100,6 +100,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
self.work,
self.request,
self.event_queue,
self.upstream_conn_pool,
)
self.plugins[instance.name()] = instance
logger.debug('Handling connection %r' % self.work.connection)

View File

@ -12,7 +12,7 @@ import socket
import argparse
from abc import ABC, abstractmethod
from typing import Tuple, List, Union, Optional
from typing import Tuple, List, Union, Optional, TYPE_CHECKING
from .parser import HttpParser
@ -20,6 +20,9 @@ from ..common.types import Readables, Writables
from ..core.event import EventQueue
from ..core.connection import TcpClientConnection
if TYPE_CHECKING:
from ..core.connection import UpstreamConnectionPool
class HttpProtocolHandlerPlugin(ABC):
"""Base HttpProtocolHandler Plugin class.
@ -50,12 +53,14 @@ class HttpProtocolHandlerPlugin(ABC):
client: TcpClientConnection,
request: HttpParser,
event_queue: EventQueue,
upstream_conn_pool: Optional['UpstreamConnectionPool'] = None,
):
self.uid: str = uid
self.flags: argparse.Namespace = flags
self.client: TcpClientConnection = client
self.request: HttpParser = request
self.event_queue = event_queue
self.upstream_conn_pool = upstream_conn_pool
super().__init__()
def name(self) -> str:

View File

@ -44,7 +44,7 @@ from ...common.utils import text_
from ...common.pki import gen_public_key, gen_csr, sign_csr
from ...core.event import eventNames
from ...core.connection import TcpServerConnection, ConnectionPool
from ...core.connection import TcpServerConnection
from ...core.connection import TcpConnectionUninitializedException
from ...common.flag import flags
@ -140,9 +140,6 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
# connection pool operations.
lock = threading.Lock()
# Shared connection pool
pool = ConnectionPool()
def __init__(
self,
*args: Any, **kwargs: Any,
@ -200,10 +197,10 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
def _close_and_release(self) -> bool:
if self.flags.enable_conn_pool:
assert self.upstream and not self.upstream.closed
assert self.upstream and not self.upstream.closed and self.upstream_conn_pool
self.upstream.closed = True
with self.lock:
self.pool.release(self.upstream)
self.upstream_conn_pool.release(self.upstream)
self.upstream = None
return True
@ -391,9 +388,10 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
return
if self.flags.enable_conn_pool:
assert self.upstream_conn_pool
# Release the connection for reusability
with self.lock:
self.pool.release(self.upstream)
self.upstream_conn_pool.release(self.upstream)
return
try:
@ -589,8 +587,9 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
host, port = self.request.host, self.request.port
if host and port:
if self.flags.enable_conn_pool:
assert self.upstream_conn_pool
with self.lock:
created, self.upstream = self.pool.acquire(
created, self.upstream = self.upstream_conn_pool.acquire(
text_(host), port,
)
else:
@ -642,8 +641,9 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
),
)
if self.flags.enable_conn_pool:
assert self.upstream_conn_pool
with self.lock:
self.pool.release(self.upstream)
self.upstream_conn_pool.release(self.upstream)
raise ProxyConnectionFailed(
text_(host), port, repr(e),
) from e

View File

@ -88,7 +88,7 @@ class ProxyPoolPlugin(TcpUpstreamConnectionHandler, HttpProxyBasePlugin):
must be bootstrapped within it's own re-usable and garbage collected pool,
to avoid establishing a new upstream proxy connection for each client request.
See :class:`~proxy.core.connection.pool.ConnectionPool` which is a work
See :class:`~proxy.core.connection.pool.UpstreamConnectionPool` which is a work
in progress for SSL cache handling.
"""
# We don't want to send private IP requests to remote proxies

View File

@ -101,6 +101,7 @@ class TestAcceptor(unittest.TestCase):
mock_client.return_value,
flags=self.flags,
event_queue=None,
upstream_conn_pool=None,
)
mock_thread.assert_called_with(
target=self.flags.work_klass.return_value.run,

View File

@ -12,14 +12,14 @@ import unittest
from unittest import mock
from proxy.core.connection import ConnectionPool
from proxy.core.connection import UpstreamConnectionPool
class TestConnectionPool(unittest.TestCase):
@mock.patch('proxy.core.connection.pool.TcpServerConnection')
def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: mock.Mock) -> None:
pool = ConnectionPool()
pool = UpstreamConnectionPool()
addr = ('localhost', 1234)
# Mock
mock_conn = mock_tcp_server_connection.return_value
@ -50,7 +50,7 @@ class TestConnectionPool(unittest.TestCase):
def test_closed_connections_are_removed_on_release(
self, mock_tcp_server_connection: mock.Mock,
) -> None:
pool = ConnectionPool()
pool = UpstreamConnectionPool()
addr = ('localhost', 1234)
# Mock
mock_conn = mock_tcp_server_connection.return_value