Hook `UpstreamConnectionPool` lifecycle within `Threadless` (#921)

* Hook connection pool lifecycle within threadless

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix test

* Fix spell

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Abhinav Singh 2021-12-29 17:37:15 +05:30 committed by GitHub
parent 263c067301
commit 46c942f947
9 changed files with 142 additions and 75 deletions
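In summary, this change makes `Threadless` drive the `UpstreamConnectionPool` lifecycle: on each run-loop tick it collects the descriptors the pool wants watched via `get_events()`, registers them on the shared selector under the reserved work id 0 (`data=0`), and routes readiness events for those descriptors back to `pool.handle_events()`, which drops connections whose upstream peer has gone away. Below is a simplified, hypothetical sketch of that flow; `ToyPool` and `tick` are stand-ins, not the actual proxy.py classes:

import selectors
import socket
from typing import Dict, List

POOL_WORK_ID = 0  # the real change registers pool descriptors with data=0

class ToyPool:
    """Stand-in for UpstreamConnectionPool: tracks idle upstream sockets."""

    def __init__(self) -> None:
        self.idle: Dict[int, socket.socket] = {}

    def get_events(self) -> Dict[int, int]:
        # Watch idle (reusable) connections for reads, mirroring
        # UpstreamConnectionPool.get_events() in this diff.
        return {fd: selectors.EVENT_READ for fd in self.idle}

    def handle_events(self, readables: List[int]) -> None:
        # A read event on an idle connection means the upstream closed it,
        # so drop it, mirroring UpstreamConnectionPool.handle_events().
        for fd in readables:
            self.idle.pop(fd).close()

def tick(selector: selectors.BaseSelector, pool: ToyPool) -> None:
    registered = {key.fd for key in selector.get_map().values()}
    # 1. Register descriptors the pool started watching since the last tick.
    for fd, mask in pool.get_events().items():
        if fd not in registered:
            selector.register(fd, events=mask, data=POOL_WORK_ID)
    # 2. Dispatch readiness; data == POOL_WORK_ID routes events back to the pool.
    pool_readables: List[int] = []
    for key, mask in selector.select(timeout=0.1):
        if key.data == POOL_WORK_ID and mask & selectors.EVENT_READ:
            selector.unregister(key.fd)
            pool_readables.append(key.fd)
    if pool_readables:
        pool.handle_events(pool_readables)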

View File

@ -88,6 +88,7 @@ class Threadless(ABC, Generic[T]):
self.cleanup_inactive_timeout: float = DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT
self._total: int = 0
self._upstream_conn_pool: Optional[UpstreamConnectionPool] = None
self._upstream_conn_filenos: Set[int] = set()
if self.flags.enable_conn_pool:
self._upstream_conn_pool = UpstreamConnectionPool()
@ -176,14 +177,25 @@ class Threadless(ABC, Generic[T]):
data=work_id,
)
self.registered_events_by_work_ids[work_id][fileno] = mask
# logger.debug(
# 'fd#{0} modified for mask#{1} by work#{2}'.format(
# fileno, mask, work_id,
# ),
# )
logger.debug(
'fd#{0} modified for mask#{1} by work#{2}'.format(
fileno, mask, work_id,
),
)
# else:
# logger.info(
# 'fd#{0} by work#{1} not modified'.format(fileno, work_id))
elif fileno in self._upstream_conn_filenos:
# Descriptor offered by work is already registered with the connection pool,
# most likely because the work has acquired a reusable connection.
self.selector.modify(fileno, events=mask, data=work_id)
self.registered_events_by_work_ids[work_id][fileno] = mask
self._upstream_conn_filenos.remove(fileno)
logger.debug(
'fd#{0} borrowed with mask#{1} by work#{2}'.format(
fileno, mask, work_id,
),
)
# Can throw ValueError: Invalid file descriptor: -1
#
# A guard within Work classes may not help here due to
@ -193,16 +205,33 @@ class Threadless(ABC, Generic[T]):
#
# TODO: Also remove offending work from pool to avoid spin loop.
elif fileno != -1:
self.selector.register(
fileno, events=mask,
data=work_id,
)
self.selector.register(fileno, events=mask, data=work_id)
self.registered_events_by_work_ids[work_id][fileno] = mask
# logger.debug(
# 'fd#{0} registered for mask#{1} by work#{2}'.format(
# fileno, mask, work_id,
# ),
# )
logger.debug(
'fd#{0} registered for mask#{1} by work#{2}'.format(
fileno, mask, work_id,
),
)
async def _update_conn_pool_events(self) -> None:
if not self._upstream_conn_pool:
return
assert self.selector is not None
new_conn_pool_events = await self._upstream_conn_pool.get_events()
old_conn_pool_filenos = self._upstream_conn_filenos.copy()
self._upstream_conn_filenos.clear()
new_conn_pool_filenos = set(new_conn_pool_events.keys())
new_conn_pool_filenos.difference_update(old_conn_pool_filenos)
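# Register, under the reserved work id 0 (data=0), any descriptor the pool started watching since the last tick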
for fileno in new_conn_pool_filenos:
self.selector.register(
fileno,
events=new_conn_pool_events[fileno],
data=0,
)
self._upstream_conn_filenos.add(fileno)
old_conn_pool_filenos.difference_update(self._upstream_conn_filenos)
for fileno in old_conn_pool_filenos:
self.selector.unregister(fileno)
async def _update_selector(self) -> None:
assert self.selector is not None
@ -215,6 +244,7 @@ class Threadless(ABC, Generic[T]):
if work_id in unfinished_work_ids:
continue
await self._update_work_events(work_id)
await self._update_conn_pool_events()
async def _selected_events(self) -> Tuple[
Dict[int, Tuple[Readables, Writables]],
@ -235,9 +265,6 @@ class Threadless(ABC, Generic[T]):
"""
assert self.selector is not None
await self._update_selector()
events = self.selector.select(
timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT,
)
# Keys are work_id and values are 2-tuple indicating
# readables & writables that work_id is interested in
# and are ready for IO.
@ -248,6 +275,11 @@ class Threadless(ABC, Generic[T]):
# When ``work_queue_fileno`` returns None,
# always return True for the boolean value.
new_work_available = True
events = self.selector.select(
timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT,
)
for key, mask in events:
if not new_work_available and wqfileno is not None and key.fileobj == wqfileno:
assert mask & selectors.EVENT_READ
@ -302,9 +334,17 @@ class Threadless(ABC, Generic[T]):
assert self.loop
tasks: Set['asyncio.Task[bool]'] = set()
for work_id in work_by_ids:
task = self.loop.create_task(
self.works[work_id].handle_events(*work_by_ids[work_id]),
)
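# Work id 0 is reserved for the connection pool (its descriptors were registered with data=0); route those events to the pool itself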
if work_id == 0:
assert self._upstream_conn_pool
task = self.loop.create_task(
self._upstream_conn_pool.handle_events(
*work_by_ids[work_id],
),
)
else:
task = self.loop.create_task(
self.works[work_id].handle_events(*work_by_ids[work_id]),
)
task._work_id = work_id # type: ignore[attr-defined]
# task.set_name(work_id)
tasks.add(task)

View File

@ -12,6 +12,7 @@
reusability
"""
import socket
import logging
import selectors
@ -45,11 +46,19 @@ class UpstreamConnectionPool(Work[TcpServerConnection]):
A separate pool is maintained for each upstream server.
So internally, it's a pool of pools.
TODO: Listen for read events from the connections
to remove them from the pool when peer closes the
connection. This can also be achieved lazily by
the pool users. Example, if acquired connection
is stale, reacquire.
The internal data structure maintains references to connection objects
that the pool owns or has lent out. Borrowed connections are marked as
NOT reusable.
For reusable connections only, the pool listens for read events
to detect broken connections. This can happen if the pool has opened
a connection that was never used and eventually hit the
upstream server's timeout limit.
When a borrowed connection is returned to the pool,
it is marked as reusable again. However, if the
returned connection has already been closed, it is removed
from the internal data structure.
TODO: Ideally, `UpstreamConnectionPool` must be shared across
all cores to make SSL session cache to also work
@ -60,29 +69,25 @@ class UpstreamConnectionPool(Work[TcpServerConnection]):
session cache, session ticket, abbr TLS handshake
and other necessary features to make it work.
NOTE: However, for all HTTP only connections, `UpstreamConnectionPool`
can be used to save upon connection setup time and
speed-up performance of requests.
NOTE: However, currently for all HTTP only upstream connections,
`UpstreamConnectionPool` can be used to remove slow starts.
"""
def __init__(self) -> None:
# Pools of connections, one per upstream server
self.connections: Dict[int, TcpServerConnection] = {}
self.pools: Dict[Tuple[str, int], Set[TcpServerConnection]] = {}
def add(self, addr: Tuple[str, int]) -> TcpServerConnection:
# Create new connection
"""Creates and add a new connection to the pool."""
new_conn = TcpServerConnection(addr[0], addr[1])
new_conn.connect()
if addr not in self.pools:
self.pools[addr] = set()
self.pools[addr].add(new_conn)
self.connections[new_conn.connection.fileno()] = new_conn
self._add(new_conn)
return new_conn
def acquire(self, addr: Tuple[str, int]) -> Tuple[bool, TcpServerConnection]:
"""Returns a connection for use with the server."""
# Return a reusable connection if available
"""Returns a reusable connection from the pool.
If none exists, will create and return a new connection."""
if addr in self.pools:
for old_conn in self.pools[addr]:
if old_conn.is_reusable():
@ -102,40 +107,63 @@ class UpstreamConnectionPool(Work[TcpServerConnection]):
return True, new_conn
def release(self, conn: TcpServerConnection) -> None:
"""Release the connection.
"""Release a previously acquired connection.
If the connection has not been closed,
then it will be retained in the pool for reusability.
"""
assert not conn.is_reusable()
if conn.closed:
logger.debug(
'Removing connection#{2} to upstream {0}:{1} from pool'.format(
conn.addr[0], conn.addr[1], id(conn),
),
)
self.pools[conn.addr].remove(conn)
self._remove(conn.connection.fileno())
else:
logger.debug(
'Retaining connection#{2} to upstream {0}:{1}'.format(
conn.addr[0], conn.addr[1], id(conn),
),
)
assert not conn.is_reusable()
# Reset for reusability
conn.reset()
async def get_events(self) -> Dict[int, int]:
"""Returns read event flag for all reusable connections in the pool."""
events = {}
for connections in self.pools.values():
for conn in connections:
events[conn.connection.fileno()] = selectors.EVENT_READ
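# Watch only reusable (pool-owned) connections for read events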
if conn.is_reusable():
events[conn.connection.fileno()] = selectors.EVENT_READ
return events
async def handle_events(self, readables: Readables, _writables: Writables) -> bool:
for r in readables:
"""Removes reusable connection from the pool.
When pool is the owner of connection, we don't expect a read event from upstream
server. A read event means either upstream closed the connection or connection
has somehow reached an illegal state e.g. upstream sending data for previous
connection acquisition lifecycle."""
for fileno in readables:
if TYPE_CHECKING:
assert isinstance(r, int)
conn = self.connections[r]
self.pools[conn.addr].remove(conn)
del self.connections[r]
assert isinstance(fileno, int)
logger.debug('Upstream fd#{0} is read ready'.format(fileno))
self._remove(fileno)
return False
def _add(self, conn: TcpServerConnection) -> None:
"""Adds a new connection to internal data structure."""
if conn.addr not in self.pools:
self.pools[conn.addr] = set()
self.pools[conn.addr].add(conn)
self.connections[conn.connection.fileno()] = conn
def _remove(self, fileno: int) -> None:
"""Remove a connection by descriptor from the internal data structure."""
conn = self.connections[fileno]
logger.debug('Removing conn#{0} from pool'.format(id(conn)))
conn.connection.shutdown(socket.SHUT_WR)
conn.close()
self.pools[conn.addr].remove(conn)
del self.connections[fileno]
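As a usage note, here is a minimal, hypothetical sketch of the acquire/release lifecycle described in the docstrings above (caller-side code, not part of this change; the upstream address is illustrative):

from proxy.core.connection.pool import UpstreamConnectionPool

pool = UpstreamConnectionPool()
# `created` is True when a fresh upstream connection had to be opened,
# False when an idle reusable one was handed back out.
created, conn = pool.acquire(('example.org', 80))
try:
    # `conn.connection` is the underlying socket object.
    conn.connection.sendall(b'GET / HTTP/1.1\r\nHost: example.org\r\n\r\n')
finally:
    # Retained in the pool for reuse, or dropped if `conn.closed` is True.
    pool.release(conn)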

View File

@ -39,11 +39,11 @@ class TcpServerConnection(TcpConnection):
addr: Optional[Tuple[str, int]] = None,
source_address: Optional[Tuple[str, int]] = None,
) -> None:
if self._conn is None:
self._conn = new_socket_connection(
addr or self.addr, source_address=source_address,
)
self.closed = False
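# Connecting an already connected instance is now a programming error; see the double connect test change further below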
assert self._conn is None
self._conn = new_socket_connection(
addr or self.addr, source_address=source_address,
)
self.closed = False
def wrap(self, hostname: str, ca_file: Optional[str]) -> None:
ctx = ssl.create_default_context(

View File

@ -103,7 +103,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
self.upstream_conn_pool,
)
self.plugins[instance.name()] = instance
logger.debug('Handling connection %r' % self.work.connection)
logger.debug('Handling connection %s' % self.work.address)
def is_inactive(self) -> bool:
if not self.work.has_buffer() and \
@ -123,9 +123,8 @@ class HttpProtocolHandler(BaseTcpServerHandler):
for plugin in self.plugins.values():
plugin.on_client_connection_close()
logger.debug(
'Closing client connection %r '
'at address %s has buffer %s' %
(self.work.connection, self.work.address, self.work.has_buffer()),
'Closing client connection %s has buffer %s' %
(self.work.address, self.work.has_buffer()),
)
conn = self.work.connection
# Unwrap if wrapped before shutdown.
@ -247,7 +246,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
async def handle_writables(self, writables: Writables) -> bool:
if self.work.connection.fileno() in writables and self.work.has_buffer():
logger.debug('Client is ready for writes, flushing buffer')
logger.debug('Client is write ready, flushing...')
self.last_activity = time.time()
# TODO(abhinavsingh): This hook could just reside within server recv block
@ -277,7 +276,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
async def handle_readables(self, readables: Readables) -> bool:
if self.work.connection.fileno() in readables:
logger.debug('Client is ready for reads, reading')
logger.debug('Client is read ready, receiving...')
self.last_activity = time.time()
try:
teardown = await super().handle_readables(readables)

View File

@ -77,7 +77,7 @@ class HttpParser:
self.total_size: int = 0
# Buffer to hold unprocessed bytes
self.buffer: bytes = b''
# Internal headers datastructure:
# Internal headers data structure:
# - Keys are lower case header names.
# - Values are 2-tuple containing original
# header and it's value as received.

View File

@ -217,7 +217,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
self.upstream and not self.upstream.closed and \
self.upstream.has_buffer() and \
self.upstream.connection.fileno() in w:
logger.debug('Server is write ready, flushing buffer')
logger.debug('Server is write ready, flushing...')
try:
self.upstream.flush()
except ssl.SSLWantWriteError:
@ -254,7 +254,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
and self.upstream \
and not self.upstream.closed \
and self.upstream.connection.fileno() in r:
logger.debug('Server is ready for reads, reading...')
logger.debug('Server is read ready, receiving...')
try:
raw = self.upstream.recv(self.flags.server_recvbuf_size)
except TimeoutError as e:
@ -401,7 +401,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
pass
finally:
# TODO: Unwrap if wrapped before close?
self.upstream.connection.close()
self.upstream.close()
except TcpConnectionUninitializedException:
pass
finally:
@ -587,10 +587,6 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
host, port = self.request.host, self.request.port
if host and port:
try:
logger.debug(
'Connecting to upstream %s:%d' %
(text_(host), port),
)
# Invoke plugin.resolve_dns
upstream_ip, source_addr = None, None
for plugin in self.plugins.values():
@ -599,6 +595,10 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
)
if upstream_ip or source_addr:
break
logger.debug(
'Connecting to upstream %s:%d' %
(text_(host), port),
)
if self.flags.enable_conn_pool:
assert self.upstream_conn_pool
with self.lock:

View File

@ -121,7 +121,7 @@ class ProxyPoolPlugin(TcpUpstreamConnectionHandler, HttpProxyBasePlugin):
#
# Failing upstream proxies must be removed from the pool temporarily.
# A periodic health check must put them back in the pool. This can be achieved
# using a datastructure without having to spawn separate thread/process for health
# using a data structure without having to spawn a separate thread/process for health
# check.
raise HttpProtocolException(
'Connection refused by upstream proxy {0}:{1}'.format(

View File

@ -23,9 +23,9 @@ class TestConnectionPool(unittest.TestCase):
@mock.patch('proxy.core.connection.pool.TcpServerConnection')
def test_acquire_and_release_and_reacquire(self, mock_tcp_server_connection: mock.Mock) -> None:
pool = UpstreamConnectionPool()
addr = ('localhost', 1234)
# Mock
mock_conn = mock_tcp_server_connection.return_value
addr = mock_conn.addr
mock_conn.is_reusable.side_effect = [
False, True, True,
]
@ -33,7 +33,7 @@ class TestConnectionPool(unittest.TestCase):
# Acquire
created, conn = pool.acquire(addr)
self.assertTrue(created)
mock_tcp_server_connection.assert_called_once_with(*addr)
mock_tcp_server_connection.assert_called_once_with(addr[0], addr[1])
self.assertEqual(conn, mock_conn)
self.assertEqual(len(pool.pools[addr]), 1)
self.assertTrue(conn in pool.pools[addr])
@ -54,26 +54,25 @@ class TestConnectionPool(unittest.TestCase):
self, mock_tcp_server_connection: mock.Mock,
) -> None:
pool = UpstreamConnectionPool()
addr = ('localhost', 1234)
# Mock
mock_conn = mock_tcp_server_connection.return_value
mock_conn.closed = True
mock_conn.addr = addr
addr = mock_conn.addr
# Acquire
created, conn = pool.acquire(addr)
self.assertTrue(created)
mock_tcp_server_connection.assert_called_once_with(*addr)
mock_tcp_server_connection.assert_called_once_with(addr[0], addr[1])
self.assertEqual(conn, mock_conn)
self.assertEqual(len(pool.pools[addr]), 1)
self.assertTrue(conn in pool.pools[addr])
# Release
mock_conn.is_reusable.return_value = False
pool.release(conn)
self.assertEqual(len(pool.pools[addr]), 0)
# Acquire
created, conn = pool.acquire(addr)
self.assertTrue(created)
self.assertEqual(mock_tcp_server_connection.call_count, 2)
mock_conn.is_reusable.assert_not_called()
class TestConnectionPoolAsync:
@ -84,10 +83,10 @@ class TestConnectionPoolAsync:
'proxy.core.connection.pool.TcpServerConnection',
)
pool = UpstreamConnectionPool()
addr = ('localhost', 1234)
mock_conn = mock_tcp_server_connection.return_value
addr = mock_conn.addr
pool.add(addr)
mock_tcp_server_connection.assert_called_once_with(*addr)
mock_tcp_server_connection.assert_called_once_with(addr[0], addr[1])
mock_conn.connect.assert_called_once()
events = await pool.get_events()
print(events)

View File

@ -79,7 +79,7 @@ class TestTcpConnection(unittest.TestCase):
)
@mock.patch('proxy.core.connection.server.new_socket_connection')
def testTcpServerIgnoresDoubleConnectSilently(
def testTcpServerWillNotIgnoreDoubleConnectAttemptsSilently(
self,
mock_new_socket_connection: mock.Mock,
) -> None:
@ -87,7 +87,8 @@ class TestTcpConnection(unittest.TestCase):
str(DEFAULT_IPV6_HOSTNAME), DEFAULT_PORT,
)
conn.connect()
conn.connect()
with self.assertRaises(AssertionError):
conn.connect()
mock_new_socket_connection.assert_called_once()
@mock.patch('socket.socket')