Hook `UpstreamConnectionPool` lifecycle within `Threadless` (#921)
* Hook connection pool lifecycle within threadless
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Fix test
* Fix spell

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 263c067301
commit 46c942f947

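In a nutshell, the commit teaches `Threadless` to treat `UpstreamConnectionPool` like any other work unit: the pool's descriptors are registered with the shared selector (with `data=0`), and read events on idle pooled connections are routed back to the pool so it can evict sockets that upstream closed. A simplified, illustrative sketch of that contract (not the actual `Threadless` code; the selector setup and timeout below are assumptions):

    import selectors

    from proxy.core.connection.pool import UpstreamConnectionPool


    async def poll_pool_once(pool: UpstreamConnectionPool) -> None:
        sel = selectors.DefaultSelector()
        # Pool asks for read events only on reusable (idle, pool-owned) connections.
        for fileno, mask in (await pool.get_events()).items():
            sel.register(fileno, events=mask, data=0)   # data=0 marks "owned by pool"
        readables = [key.fd for key, _ in sel.select(timeout=0.1)]
        if readables:
            # A read event on an idle pooled connection means upstream closed it
            # (or sent unexpected data), so the pool evicts it from its data structure.
            await pool.handle_events(readables, [])
        sel.close()
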
@@ -88,6 +88,7 @@ class Threadless(ABC, Generic[T]):
         self.cleanup_inactive_timeout: float = DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT
         self._total: int = 0
         self._upstream_conn_pool: Optional[UpstreamConnectionPool] = None
+        self._upstream_conn_filenos: Set[int] = set()
         if self.flags.enable_conn_pool:
             self._upstream_conn_pool = UpstreamConnectionPool()

@@ -176,14 +177,25 @@ class Threadless(ABC, Generic[T]):
                     data=work_id,
                 )
                 self.registered_events_by_work_ids[work_id][fileno] = mask
-                # logger.debug(
-                #     'fd#{0} modified for mask#{1} by work#{2}'.format(
-                #         fileno, mask, work_id,
-                #     ),
-                # )
+                logger.debug(
+                    'fd#{0} modified for mask#{1} by work#{2}'.format(
+                        fileno, mask, work_id,
+                    ),
+                )
             # else:
             #     logger.info(
             #         'fd#{0} by work#{1} not modified'.format(fileno, work_id))
+            elif fileno in self._upstream_conn_filenos:
+                # Descriptor offered by work, but is already registered by connection pool
+                # Most likely because work has acquired a reusable connection.
+                self.selector.modify(fileno, events=mask, data=work_id)
+                self.registered_events_by_work_ids[work_id][fileno] = mask
+                self._upstream_conn_filenos.remove(fileno)
+                logger.debug(
+                    'fd#{0} borrowed with mask#{1} by work#{2}'.format(
+                        fileno, mask, work_id,
+                    ),
+                )
             # Can throw ValueError: Invalid file descriptor: -1
             #
             # A guard within Work classes may not help here due to

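Note on the new `elif fileno in self._upstream_conn_filenos` branch above: when a work acquires a reusable connection, that descriptor is typically still registered with the selector by the pool (with `data=0`), so registering it again would fail; the loop instead modifies the existing registration to hand ownership to the work. A small self-contained illustration of that selector behaviour (the socket pair and work id are hypothetical stand-ins):

    import selectors
    import socket

    sel = selectors.DefaultSelector()
    a, b = socket.socketpair()          # stand-in for a pooled upstream socket
    sel.register(a, events=selectors.EVENT_READ, data=0)        # pool owns the descriptor
    # A second sel.register(a, ...) would raise KeyError; modify() re-keys it instead.
    work_id = 42                        # hypothetical work id
    sel.modify(a, events=selectors.EVENT_WRITE, data=work_id)   # work borrows it
    assert sel.get_key(a).data == work_id
    sel.close()
    a.close()
    b.close()
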
@@ -193,16 +205,33 @@ class Threadless(ABC, Generic[T]):
             #
             # TODO: Also remove offending work from pool to avoid spin loop.
             elif fileno != -1:
-                self.selector.register(
-                    fileno, events=mask,
-                    data=work_id,
-                )
+                self.selector.register(fileno, events=mask, data=work_id)
                 self.registered_events_by_work_ids[work_id][fileno] = mask
-                # logger.debug(
-                #     'fd#{0} registered for mask#{1} by work#{2}'.format(
-                #         fileno, mask, work_id,
-                #     ),
-                # )
+                logger.debug(
+                    'fd#{0} registered for mask#{1} by work#{2}'.format(
+                        fileno, mask, work_id,
+                    ),
+                )
+
+    async def _update_conn_pool_events(self) -> None:
+        if not self._upstream_conn_pool:
+            return
+        assert self.selector is not None
+        new_conn_pool_events = await self._upstream_conn_pool.get_events()
+        old_conn_pool_filenos = self._upstream_conn_filenos.copy()
+        self._upstream_conn_filenos.clear()
+        new_conn_pool_filenos = set(new_conn_pool_events.keys())
+        new_conn_pool_filenos.difference_update(old_conn_pool_filenos)
+        for fileno in new_conn_pool_filenos:
+            self.selector.register(
+                fileno,
+                events=new_conn_pool_events[fileno],
+                data=0,
+            )
+            self._upstream_conn_filenos.add(fileno)
+        old_conn_pool_filenos.difference_update(self._upstream_conn_filenos)
+        for fileno in old_conn_pool_filenos:
+            self.selector.unregister(fileno)

     async def _update_selector(self) -> None:
         assert self.selector is not None

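The `_update_conn_pool_events()` bookkeeping above is a plain set reconciliation between the descriptors registered for the pool on the previous pass and the ones `get_events()` reports now; a tiny illustration with hypothetical descriptor numbers:

    old_filenos = {7, 8, 9}     # registered on behalf of the pool last pass
    new_filenos = {8, 9, 11}    # what the pool asks for now

    to_register = new_filenos - old_filenos      # {11}: newly pooled, register with data=0
    to_keep = new_filenos & old_filenos          # {8, 9}: already registered, left untouched
    to_unregister = old_filenos - new_filenos    # {7}: acquired by a work or evicted, unregister
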
@@ -215,6 +244,7 @@ class Threadless(ABC, Generic[T]):
             if work_id in unfinished_work_ids:
                 continue
             await self._update_work_events(work_id)
+        await self._update_conn_pool_events()

     async def _selected_events(self) -> Tuple[
         Dict[int, Tuple[Readables, Writables]],

@@ -235,9 +265,6 @@ class Threadless(ABC, Generic[T]):
         """
         assert self.selector is not None
         await self._update_selector()
-        events = self.selector.select(
-            timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT,
-        )
         # Keys are work_id and values are 2-tuple indicating
         # readables & writables that work_id is interested in
         # and are ready for IO.

@@ -248,6 +275,11 @@ class Threadless(ABC, Generic[T]):
         # When ``work_queue_fileno`` returns None,
         # always return True for the boolean value.
         new_work_available = True
+
+        events = self.selector.select(
+            timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT,
+        )
+
         for key, mask in events:
             if not new_work_available and wqfileno is not None and key.fileobj == wqfileno:
                 assert mask & selectors.EVENT_READ

@@ -302,9 +334,17 @@ class Threadless(ABC, Generic[T]):
         assert self.loop
         tasks: Set['asyncio.Task[bool]'] = set()
         for work_id in work_by_ids:
-            task = self.loop.create_task(
-                self.works[work_id].handle_events(*work_by_ids[work_id]),
-            )
+            if work_id == 0:
+                assert self._upstream_conn_pool
+                task = self.loop.create_task(
+                    self._upstream_conn_pool.handle_events(
+                        *work_by_ids[work_id],
+                    ),
+                )
+            else:
+                task = self.loop.create_task(
+                    self.works[work_id].handle_events(*work_by_ids[work_id]),
+                )
             task._work_id = work_id  # type: ignore[attr-defined]
             # task.set_name(work_id)
             tasks.add(task)

@@ -12,6 +12,7 @@

        reusability
 """
+import socket
 import logging
 import selectors

@@ -45,11 +46,19 @@ class UpstreamConnectionPool(Work[TcpServerConnection]):
     A separate pool is maintained for each upstream server.
     So internally, it's a pool of pools.

-    TODO: Listen for read events from the connections
-    to remove them from the pool when peer closes the
-    connection. This can also be achieved lazily by
-    the pool users. Example, if acquired connection
-    is stale, reacquire.
+    Internal data structure maintains references to connection objects
+    that pool owns or has borrowed. Borrowed connections are marked as
+    NOT reusable.
+
+    For reusable connections only, pool listens for read events
+    to detect broken connections. This can happen if pool has opened
+    a connection, which was never used and eventually reaches
+    upstream server timeout limit.
+
+    When a borrowed connection is returned back to the pool,
+    the connection is marked as reusable again. However, if
+    returned connection has already been closed, it is removed
+    from the internal data structure.

     TODO: Ideally, `UpstreamConnectionPool` must be shared across
     all cores to make SSL session cache to also work

@@ -60,29 +69,25 @@ class UpstreamConnectionPool(Work[TcpServerConnection]):
     session cache, session ticket, abbr TLS handshake
     and other necessary features to make it work.

-    NOTE: However, for all HTTP only connections, `UpstreamConnectionPool`
-    can be used to save upon connection setup time and
-    speed-up performance of requests.
+    NOTE: However, currently for all HTTP only upstream connections,
+    `UpstreamConnectionPool` can be used to remove slow starts.
     """

     def __init__(self) -> None:
         # Pools of connection per upstream server
         self.connections: Dict[int, TcpServerConnection] = {}
         self.pools: Dict[Tuple[str, int], Set[TcpServerConnection]] = {}

     def add(self, addr: Tuple[str, int]) -> TcpServerConnection:
-        # Create new connection
+        """Creates and add a new connection to the pool."""
         new_conn = TcpServerConnection(addr[0], addr[1])
         new_conn.connect()
-        if addr not in self.pools:
-            self.pools[addr] = set()
-        self.pools[addr].add(new_conn)
-        self.connections[new_conn.connection.fileno()] = new_conn
+        self._add(new_conn)
         return new_conn

     def acquire(self, addr: Tuple[str, int]) -> Tuple[bool, TcpServerConnection]:
-        """Returns a connection for use with the server."""
-        # Return a reusable connection if available
+        """Returns a reusable connection from the pool.
+
+        If none exists, will create and return a new connection."""
         if addr in self.pools:
             for old_conn in self.pools[addr]:
                 if old_conn.is_reusable():

@@ -102,40 +107,63 @@ class UpstreamConnectionPool(Work[TcpServerConnection]):
         return True, new_conn

     def release(self, conn: TcpServerConnection) -> None:
-        """Release the connection.
+        """Release a previously acquired connection.

         If the connection has not been closed,
         then it will be retained in the pool for reusability.
         """
+        assert not conn.is_reusable()
         if conn.closed:
             logger.debug(
                 'Removing connection#{2} from pool from upstream {0}:{1}'.format(
                     conn.addr[0], conn.addr[1], id(conn),
                 ),
             )
-            self.pools[conn.addr].remove(conn)
+            self._remove(conn.connection.fileno())
         else:
             logger.debug(
                 'Retaining connection#{2} to upstream {0}:{1}'.format(
                     conn.addr[0], conn.addr[1], id(conn),
                 ),
             )
-            assert not conn.is_reusable()
             # Reset for reusability
             conn.reset()

     async def get_events(self) -> Dict[int, int]:
+        """Returns read event flag for all reusable connections in the pool."""
         events = {}
         for connections in self.pools.values():
             for conn in connections:
-                events[conn.connection.fileno()] = selectors.EVENT_READ
+                if conn.is_reusable():
+                    events[conn.connection.fileno()] = selectors.EVENT_READ
         return events

     async def handle_events(self, readables: Readables, _writables: Writables) -> bool:
-        for r in readables:
+        """Removes reusable connection from the pool.
+
+        When pool is the owner of connection, we don't expect a read event from upstream
+        server. A read event means either upstream closed the connection or connection
+        has somehow reached an illegal state e.g. upstream sending data for previous
+        connection acquisition lifecycle."""
+        for fileno in readables:
             if TYPE_CHECKING:
-                assert isinstance(r, int)
-            conn = self.connections[r]
-            self.pools[conn.addr].remove(conn)
-            del self.connections[r]
+                assert isinstance(fileno, int)
+            logger.debug('Upstream fd#{0} is read ready'.format(fileno))
+            self._remove(fileno)
         return False
+
+    def _add(self, conn: TcpServerConnection) -> None:
+        """Adds a new connection to internal data structure."""
+        if conn.addr not in self.pools:
+            self.pools[conn.addr] = set()
+        self.pools[conn.addr].add(conn)
+        self.connections[conn.connection.fileno()] = conn
+
+    def _remove(self, fileno: int) -> None:
+        """Remove a connection by descriptor from the internal data structure."""
+        conn = self.connections[fileno]
+        logger.debug('Removing conn#{0} from pool'.format(id(conn)))
+        conn.connection.shutdown(socket.SHUT_WR)
+        conn.close()
+        self.pools[conn.addr].remove(conn)
+        del self.connections[fileno]

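Taken together, `acquire()`, `release()` and the new `_add()`/`_remove()` helpers give the pool a borrow-and-return lifecycle. A minimal usage sketch (the upstream address is a placeholder and assumes something is actually listening there):

    from proxy.core.connection.pool import UpstreamConnectionPool

    pool = UpstreamConnectionPool()

    # First acquire finds nothing reusable, so a new upstream connection is
    # created (created is True) and handed out marked as in-use.
    created, conn = pool.acquire(('localhost', 8899))   # placeholder upstream
    assert created

    conn.connection.sendall(b'GET / HTTP/1.1\r\nHost: localhost\r\n\r\n')

    # Returning the connection resets it to reusable, unless it was closed,
    # in which case release() evicts it from the pool instead.
    pool.release(conn)

    # A later acquire against the same address can then reuse the same object,
    # assuming the connection stayed healthy in between.
    created, conn2 = pool.acquire(('localhost', 8899))
    assert not created and conn2 is conn
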
@@ -39,11 +39,11 @@ class TcpServerConnection(TcpConnection):
         addr: Optional[Tuple[str, int]] = None,
         source_address: Optional[Tuple[str, int]] = None,
     ) -> None:
-        if self._conn is None:
-            self._conn = new_socket_connection(
-                addr or self.addr, source_address=source_address,
-            )
-            self.closed = False
+        assert self._conn is None
+        self._conn = new_socket_connection(
+            addr or self.addr, source_address=source_address,
+        )
+        self.closed = False

     def wrap(self, hostname: str, ca_file: Optional[str]) -> None:
         ctx = ssl.create_default_context(

@@ -103,7 +103,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
                 self.upstream_conn_pool,
             )
             self.plugins[instance.name()] = instance
-        logger.debug('Handling connection %r' % self.work.connection)
+        logger.debug('Handling connection %s' % self.work.address)

     def is_inactive(self) -> bool:
         if not self.work.has_buffer() and \

@@ -123,9 +123,8 @@ class HttpProtocolHandler(BaseTcpServerHandler):
         for plugin in self.plugins.values():
             plugin.on_client_connection_close()
         logger.debug(
-            'Closing client connection %r '
-            'at address %s has buffer %s' %
-            (self.work.connection, self.work.address, self.work.has_buffer()),
+            'Closing client connection %s has buffer %s' %
+            (self.work.address, self.work.has_buffer()),
         )
         conn = self.work.connection
         # Unwrap if wrapped before shutdown.

@@ -247,7 +246,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):

     async def handle_writables(self, writables: Writables) -> bool:
         if self.work.connection.fileno() in writables and self.work.has_buffer():
-            logger.debug('Client is ready for writes, flushing buffer')
+            logger.debug('Client is write ready, flushing...')
             self.last_activity = time.time()

             # TODO(abhinavsingh): This hook could just reside within server recv block

@@ -277,7 +276,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):

     async def handle_readables(self, readables: Readables) -> bool:
         if self.work.connection.fileno() in readables:
-            logger.debug('Client is ready for reads, reading')
+            logger.debug('Client is read ready, receiving...')
             self.last_activity = time.time()
             try:
                 teardown = await super().handle_readables(readables)

@@ -77,7 +77,7 @@ class HttpParser:
         self.total_size: int = 0
         # Buffer to hold unprocessed bytes
         self.buffer: bytes = b''
-        # Internal headers datastructure:
+        # Internal headers data structure:
         # - Keys are lower case header names.
         # - Values are 2-tuple containing original
         #   header and it's value as received.

@@ -217,7 +217,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
                 self.upstream and not self.upstream.closed and \
                 self.upstream.has_buffer() and \
                 self.upstream.connection.fileno() in w:
-            logger.debug('Server is write ready, flushing buffer')
+            logger.debug('Server is write ready, flushing...')
             try:
                 self.upstream.flush()
             except ssl.SSLWantWriteError:

@@ -254,7 +254,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
                 and self.upstream \
                 and not self.upstream.closed \
                 and self.upstream.connection.fileno() in r:
-            logger.debug('Server is ready for reads, reading...')
+            logger.debug('Server is read ready, receiving...')
             try:
                 raw = self.upstream.recv(self.flags.server_recvbuf_size)
             except TimeoutError as e:

@@ -401,7 +401,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
                 pass
             finally:
                 # TODO: Unwrap if wrapped before close?
-                self.upstream.connection.close()
+                self.upstream.close()
         except TcpConnectionUninitializedException:
             pass
         finally:

@@ -587,10 +587,6 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
         host, port = self.request.host, self.request.port
         if host and port:
             try:
-                logger.debug(
-                    'Connecting to upstream %s:%d' %
-                    (text_(host), port),
-                )
                 # Invoke plugin.resolve_dns
                 upstream_ip, source_addr = None, None
                 for plugin in self.plugins.values():

@@ -599,6 +595,10 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
                     )
                     if upstream_ip or source_addr:
                         break
+                logger.debug(
+                    'Connecting to upstream %s:%d' %
+                    (text_(host), port),
+                )
                 if self.flags.enable_conn_pool:
                     assert self.upstream_conn_pool
                     with self.lock:

@@ -121,7 +121,7 @@ class ProxyPoolPlugin(TcpUpstreamConnectionHandler, HttpProxyBasePlugin):
         #
         # Failing upstream proxies, must be removed from the pool temporarily.
         # A periodic health check must put them back in the pool. This can be achieved
-        # using a datastructure without having to spawn separate thread/process for health
+        # using a data structure without having to spawn separate thread/process for health
         # check.
         raise HttpProtocolException(
             'Connection refused by upstream proxy {0}:{1}'.format(

@@ -23,9 +23,9 @@ class TestConnectionPool(unittest.TestCase):
     @mock.patch('proxy.core.connection.pool.TcpServerConnection')
     def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: mock.Mock) -> None:
         pool = UpstreamConnectionPool()
-        addr = ('localhost', 1234)
         # Mock
         mock_conn = mock_tcp_server_connection.return_value
+        addr = mock_conn.addr
         mock_conn.is_reusable.side_effect = [
             False, True, True,
         ]

@@ -33,7 +33,7 @@ class TestConnectionPool(unittest.TestCase):
         # Acquire
         created, conn = pool.acquire(addr)
         self.assertTrue(created)
-        mock_tcp_server_connection.assert_called_once_with(*addr)
+        mock_tcp_server_connection.assert_called_once_with(addr[0], addr[1])
         self.assertEqual(conn, mock_conn)
         self.assertEqual(len(pool.pools[addr]), 1)
         self.assertTrue(conn in pool.pools[addr])

@@ -54,26 +54,25 @@ class TestConnectionPool(unittest.TestCase):
         self, mock_tcp_server_connection: mock.Mock,
     ) -> None:
         pool = UpstreamConnectionPool()
-        addr = ('localhost', 1234)
         # Mock
         mock_conn = mock_tcp_server_connection.return_value
         mock_conn.closed = True
-        mock_conn.addr = addr
+        addr = mock_conn.addr
         # Acquire
         created, conn = pool.acquire(addr)
         self.assertTrue(created)
-        mock_tcp_server_connection.assert_called_once_with(*addr)
+        mock_tcp_server_connection.assert_called_once_with(addr[0], addr[1])
         self.assertEqual(conn, mock_conn)
         self.assertEqual(len(pool.pools[addr]), 1)
         self.assertTrue(conn in pool.pools[addr])
         # Release
+        mock_conn.is_reusable.return_value = False
         pool.release(conn)
         self.assertEqual(len(pool.pools[addr]), 0)
         # Acquire
         created, conn = pool.acquire(addr)
         self.assertTrue(created)
         self.assertEqual(mock_tcp_server_connection.call_count, 2)
-        mock_conn.is_reusable.assert_not_called()


 class TestConnectionPoolAsync:

@@ -84,10 +83,10 @@ class TestConnectionPoolAsync:
             'proxy.core.connection.pool.TcpServerConnection',
         )
         pool = UpstreamConnectionPool()
-        addr = ('localhost', 1234)
         mock_conn = mock_tcp_server_connection.return_value
+        addr = mock_conn.addr
         pool.add(addr)
-        mock_tcp_server_connection.assert_called_once_with(*addr)
+        mock_tcp_server_connection.assert_called_once_with(addr[0], addr[1])
         mock_conn.connect.assert_called_once()
         events = await pool.get_events()
         print(events)

@@ -79,7 +79,7 @@ class TestTcpConnection(unittest.TestCase):
         )

     @mock.patch('proxy.core.connection.server.new_socket_connection')
-    def testTcpServerIgnoresDoubleConnectSilently(
+    def testTcpServerWillNotIgnoreDoubleConnectAttemptsSilently(
         self,
         mock_new_socket_connection: mock.Mock,
     ) -> None:

@@ -87,7 +87,8 @@ class TestTcpConnection(unittest.TestCase):
             str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT,
         )
         conn.connect()
-        conn.connect()
+        with self.assertRaises(AssertionError):
+            conn.connect()
         mock_new_socket_connection.assert_called_once()

     @mock.patch('socket.socket')