Move `UpstreamConnectionPool` lifecycle within `Threadless` (#917)
* Tie connection pool into Threadless * Pass upstream conn pool reference to work instances * Mark upstream conn pool as optional * spellcheck * Fix unused import
This commit is contained in:
parent
ea66280827
commit
f6214a46c9
|
@ -300,6 +300,7 @@ nitpick_ignore = [
|
|||
(_py_class_role, 'unittest.case.TestCase'),
|
||||
(_py_class_role, 'unittest.result.TestResult'),
|
||||
(_py_class_role, 'UUID'),
|
||||
(_py_class_role, 'UpstreamConnectionPool'),
|
||||
(_py_class_role, 'Url'),
|
||||
(_py_class_role, 'WebsocketFrame'),
|
||||
(_py_class_role, 'Work'),
|
||||
|
|
|
@ -131,6 +131,7 @@ class ThreadlessPool:
|
|||
TcpClientConnection(conn, addr),
|
||||
flags=flags,
|
||||
event_queue=event_queue,
|
||||
upstream_conn_pool=None,
|
||||
)
|
||||
# TODO: Keep reference to threads and join during shutdown.
|
||||
# This will ensure connections are not abruptly closed on shutdown
|
||||
|
|
|
@ -25,7 +25,7 @@ from ...common.types import Readables, Writables
|
|||
from ...common.constants import DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT, DEFAULT_SELECTOR_SELECT_TIMEOUT
|
||||
from ...common.constants import DEFAULT_WAIT_FOR_TASKS_TIMEOUT
|
||||
|
||||
from ..connection import TcpClientConnection
|
||||
from ..connection import TcpClientConnection, UpstreamConnectionPool
|
||||
from ..event import eventNames, EventQueue
|
||||
|
||||
from .work import Work
|
||||
|
@ -87,6 +87,9 @@ class Threadless(ABC, Generic[T]):
|
|||
self.wait_timeout: float = DEFAULT_WAIT_FOR_TASKS_TIMEOUT
|
||||
self.cleanup_inactive_timeout: float = DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT
|
||||
self._total: int = 0
|
||||
self._upstream_conn_pool: Optional[UpstreamConnectionPool] = None
|
||||
if self.flags.enable_conn_pool:
|
||||
self._upstream_conn_pool = UpstreamConnectionPool()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
|
@ -134,6 +137,7 @@ class Threadless(ABC, Generic[T]):
|
|||
flags=self.flags,
|
||||
event_queue=self.event_queue,
|
||||
uid=uid,
|
||||
upstream_conn_pool=self._upstream_conn_pool,
|
||||
)
|
||||
self.works[fileno].publish_event(
|
||||
event_name=eventNames.WORK_STARTED,
|
||||
|
|
|
@ -16,11 +16,14 @@ import argparse
|
|||
|
||||
from abc import ABC, abstractmethod
|
||||
from uuid import uuid4
|
||||
from typing import Optional, Dict, Any, TypeVar, Generic
|
||||
from typing import Optional, Dict, Any, TypeVar, Generic, TYPE_CHECKING
|
||||
|
||||
from ..event import eventNames, EventQueue
|
||||
from ...common.types import Readables, Writables
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..connection import UpstreamConnectionPool
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
|
@ -33,6 +36,7 @@ class Work(ABC, Generic[T]):
|
|||
flags: argparse.Namespace,
|
||||
event_queue: Optional[EventQueue] = None,
|
||||
uid: Optional[str] = None,
|
||||
upstream_conn_pool: Optional['UpstreamConnectionPool'] = None,
|
||||
) -> None:
|
||||
# Work uuid
|
||||
self.uid: str = uid if uid is not None else uuid4().hex
|
||||
|
@ -41,6 +45,7 @@ class Work(ABC, Generic[T]):
|
|||
self.event_queue = event_queue
|
||||
# Accept work
|
||||
self.work = work
|
||||
self.upstream_conn_pool = upstream_conn_pool
|
||||
|
||||
@abstractmethod
|
||||
async def get_events(self) -> Dict[int, int]:
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from .connection import TcpConnection, TcpConnectionUninitializedException
|
||||
from .client import TcpClientConnection
|
||||
from .server import TcpServerConnection
|
||||
from .pool import ConnectionPool
|
||||
from .pool import UpstreamConnectionPool
|
||||
from .types import tcpConnectionTypes
|
||||
|
||||
__all__ = [
|
||||
|
@ -25,5 +25,5 @@ __all__ = [
|
|||
'TcpServerConnection',
|
||||
'TcpClientConnection',
|
||||
'tcpConnectionTypes',
|
||||
'ConnectionPool',
|
||||
'UpstreamConnectionPool',
|
||||
]
|
||||
|
|
|
@ -17,6 +17,9 @@ import logging
|
|||
from typing import Set, Dict, Tuple
|
||||
|
||||
from ...common.flag import flags
|
||||
from ...common.types import Readables, Writables
|
||||
|
||||
from ..acceptor.work import Work
|
||||
|
||||
from .server import TcpServerConnection
|
||||
|
||||
|
@ -31,10 +34,10 @@ flags.add_argument(
|
|||
)
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
class UpstreamConnectionPool(Work[TcpServerConnection]):
|
||||
"""Manages connection pool to upstream servers.
|
||||
|
||||
`ConnectionPool` avoids need to reconnect with the upstream
|
||||
`UpstreamConnectionPool` avoids need to reconnect with the upstream
|
||||
servers repeatedly when a reusable connection is available
|
||||
in the pool.
|
||||
|
||||
|
@ -47,16 +50,16 @@ class ConnectionPool:
|
|||
the pool users. Example, if acquired connection
|
||||
is stale, reacquire.
|
||||
|
||||
TODO: Ideally, ConnectionPool must be shared across
|
||||
TODO: Ideally, `UpstreamConnectionPool` must be shared across
|
||||
all cores to make SSL session cache to also work
|
||||
without additional out-of-bound synchronizations.
|
||||
|
||||
TODO: ConnectionPool currently WON'T work for
|
||||
TODO: `UpstreamConnectionPool` currently WON'T work for
|
||||
HTTPS connection. This is because of missing support for
|
||||
session cache, session ticket, abbr TLS handshake
|
||||
and other necessary features to make it work.
|
||||
|
||||
NOTE: However, for all HTTP only connections, ConnectionPool
|
||||
NOTE: However, for all HTTP only connections, `UpstreamConnectionPool`
|
||||
can be used to save upon connection setup time and
|
||||
speed-up performance of requests.
|
||||
"""
|
||||
|
@ -113,3 +116,9 @@ class ConnectionPool:
|
|||
assert not conn.is_reusable()
|
||||
# Reset for reusability
|
||||
conn.reset()
|
||||
|
||||
async def get_events(self) -> Dict[int, int]:
|
||||
return await super().get_events()
|
||||
|
||||
async def handle_events(self, readables: Readables, writables: Writables) -> bool:
|
||||
return await super().handle_events(readables, writables)
|
||||
|
|
|
@ -100,6 +100,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
|
|||
self.work,
|
||||
self.request,
|
||||
self.event_queue,
|
||||
self.upstream_conn_pool,
|
||||
)
|
||||
self.plugins[instance.name()] = instance
|
||||
logger.debug('Handling connection %r' % self.work.connection)
|
||||
|
|
|
@ -12,7 +12,7 @@ import socket
|
|||
import argparse
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, List, Union, Optional
|
||||
from typing import Tuple, List, Union, Optional, TYPE_CHECKING
|
||||
|
||||
from .parser import HttpParser
|
||||
|
||||
|
@ -20,6 +20,9 @@ from ..common.types import Readables, Writables
|
|||
from ..core.event import EventQueue
|
||||
from ..core.connection import TcpClientConnection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..core.connection import UpstreamConnectionPool
|
||||
|
||||
|
||||
class HttpProtocolHandlerPlugin(ABC):
|
||||
"""Base HttpProtocolHandler Plugin class.
|
||||
|
@ -50,12 +53,14 @@ class HttpProtocolHandlerPlugin(ABC):
|
|||
client: TcpClientConnection,
|
||||
request: HttpParser,
|
||||
event_queue: EventQueue,
|
||||
upstream_conn_pool: Optional['UpstreamConnectionPool'] = None,
|
||||
):
|
||||
self.uid: str = uid
|
||||
self.flags: argparse.Namespace = flags
|
||||
self.client: TcpClientConnection = client
|
||||
self.request: HttpParser = request
|
||||
self.event_queue = event_queue
|
||||
self.upstream_conn_pool = upstream_conn_pool
|
||||
super().__init__()
|
||||
|
||||
def name(self) -> str:
|
||||
|
|
|
@ -44,7 +44,7 @@ from ...common.utils import text_
|
|||
from ...common.pki import gen_public_key, gen_csr, sign_csr
|
||||
|
||||
from ...core.event import eventNames
|
||||
from ...core.connection import TcpServerConnection, ConnectionPool
|
||||
from ...core.connection import TcpServerConnection
|
||||
from ...core.connection import TcpConnectionUninitializedException
|
||||
from ...common.flag import flags
|
||||
|
||||
|
@ -140,9 +140,6 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
|
|||
# connection pool operations.
|
||||
lock = threading.Lock()
|
||||
|
||||
# Shared connection pool
|
||||
pool = ConnectionPool()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any, **kwargs: Any,
|
||||
|
@ -200,10 +197,10 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
|
|||
|
||||
def _close_and_release(self) -> bool:
|
||||
if self.flags.enable_conn_pool:
|
||||
assert self.upstream and not self.upstream.closed
|
||||
assert self.upstream and not self.upstream.closed and self.upstream_conn_pool
|
||||
self.upstream.closed = True
|
||||
with self.lock:
|
||||
self.pool.release(self.upstream)
|
||||
self.upstream_conn_pool.release(self.upstream)
|
||||
self.upstream = None
|
||||
return True
|
||||
|
||||
|
@ -391,9 +388,10 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
|
|||
return
|
||||
|
||||
if self.flags.enable_conn_pool:
|
||||
assert self.upstream_conn_pool
|
||||
# Release the connection for reusability
|
||||
with self.lock:
|
||||
self.pool.release(self.upstream)
|
||||
self.upstream_conn_pool.release(self.upstream)
|
||||
return
|
||||
|
||||
try:
|
||||
|
@ -589,8 +587,9 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
|
|||
host, port = self.request.host, self.request.port
|
||||
if host and port:
|
||||
if self.flags.enable_conn_pool:
|
||||
assert self.upstream_conn_pool
|
||||
with self.lock:
|
||||
created, self.upstream = self.pool.acquire(
|
||||
created, self.upstream = self.upstream_conn_pool.acquire(
|
||||
text_(host), port,
|
||||
)
|
||||
else:
|
||||
|
@ -642,8 +641,9 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
|
|||
),
|
||||
)
|
||||
if self.flags.enable_conn_pool:
|
||||
assert self.upstream_conn_pool
|
||||
with self.lock:
|
||||
self.pool.release(self.upstream)
|
||||
self.upstream_conn_pool.release(self.upstream)
|
||||
raise ProxyConnectionFailed(
|
||||
text_(host), port, repr(e),
|
||||
) from e
|
||||
|
|
|
@ -88,7 +88,7 @@ class ProxyPoolPlugin(TcpUpstreamConnectionHandler, HttpProxyBasePlugin):
|
|||
must be bootstrapped within it's own re-usable and garbage collected pool,
|
||||
to avoid establishing a new upstream proxy connection for each client request.
|
||||
|
||||
See :class:`~proxy.core.connection.pool.ConnectionPool` which is a work
|
||||
See :class:`~proxy.core.connection.pool.UpstreamConnectionPool` which is a work
|
||||
in progress for SSL cache handling.
|
||||
"""
|
||||
# We don't want to send private IP requests to remote proxies
|
||||
|
|
|
@ -101,6 +101,7 @@ class TestAcceptor(unittest.TestCase):
|
|||
mock_client.return_value,
|
||||
flags=self.flags,
|
||||
event_queue=None,
|
||||
upstream_conn_pool=None,
|
||||
)
|
||||
mock_thread.assert_called_with(
|
||||
target=self.flags.work_klass.return_value.run,
|
||||
|
|
|
@ -12,14 +12,14 @@ import unittest
|
|||
|
||||
from unittest import mock
|
||||
|
||||
from proxy.core.connection import ConnectionPool
|
||||
from proxy.core.connection import UpstreamConnectionPool
|
||||
|
||||
|
||||
class TestConnectionPool(unittest.TestCase):
|
||||
|
||||
@mock.patch('proxy.core.connection.pool.TcpServerConnection')
|
||||
def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: mock.Mock) -> None:
|
||||
pool = ConnectionPool()
|
||||
pool = UpstreamConnectionPool()
|
||||
addr = ('localhost', 1234)
|
||||
# Mock
|
||||
mock_conn = mock_tcp_server_connection.return_value
|
||||
|
@ -50,7 +50,7 @@ class TestConnectionPool(unittest.TestCase):
|
|||
def test_closed_connections_are_removed_on_release(
|
||||
self, mock_tcp_server_connection: mock.Mock,
|
||||
) -> None:
|
||||
pool = ConnectionPool()
|
||||
pool = UpstreamConnectionPool()
|
||||
addr = ('localhost', 1234)
|
||||
# Mock
|
||||
mock_conn = mock_tcp_server_connection.return_value
|
||||
|
|
Loading…
Reference in New Issue