Top-level notion of work not client (#695)

* Top-level notion of work not client

* Update ssl echo server example
This commit is contained in:
Abhinav Singh 2021-11-07 21:43:38 +05:30 committed by GitHub
parent d3cee32909
commit f48771fb41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 77 additions and 74 deletions

View File

@ -53,7 +53,7 @@ class HttpsConnectTunnelHandler(BaseTcpTunnelHandler):
# Drop the request if not a CONNECT request
if self.request.method != httpMethods.CONNECT:
self.client.queue(
self.work.queue(
HttpsConnectTunnelHandler.PROXY_TUNNEL_UNSUPPORTED_SCHEME,
)
return True
@ -66,7 +66,7 @@ class HttpsConnectTunnelHandler(BaseTcpTunnelHandler):
self.connect_upstream()
# Queue tunnel established response to client
self.client.queue(
self.work.queue(
HttpsConnectTunnelHandler.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
)

View File

@ -27,19 +27,19 @@ class EchoSSLServerHandler(BaseTcpServerHandler):
# here using wrap_socket() utility.
assert self.flags.keyfile is not None and self.flags.certfile is not None
conn = wrap_socket(
self.client.connection,
self.work.connection,
self.flags.keyfile,
self.flags.certfile,
)
conn.setblocking(False)
# Upgrade plain TcpClientConnection to SSL connection object
self.client = TcpClientConnection(
conn=conn, addr=self.client.addr,
self.work = TcpClientConnection(
conn=conn, addr=self.work.addr,
)
def handle_data(self, data: memoryview) -> Optional[bool]:
# echo back to client
self.client.queue(data)
self.work.queue(data)
return None

View File

@ -20,11 +20,11 @@ class EchoServerHandler(BaseTcpServerHandler):
"""Sets client socket to non-blocking during initialization."""
def initialize(self) -> None:
self.client.connection.setblocking(False)
self.work.connection.setblocking(False)
def handle_data(self, data: memoryview) -> Optional[bool]:
# echo back to client
self.client.queue(data)
self.work.queue(data)
return None

View File

@ -25,15 +25,18 @@ class Work(ABC):
def __init__(
self,
client: TcpClientConnection,
work: TcpClientConnection,
flags: argparse.Namespace,
event_queue: Optional[EventQueue] = None,
uid: Optional[UUID] = None,
) -> None:
self.client = client
self.flags = flags
self.event_queue = event_queue
# Work uuid
self.uid: UUID = uid if uid is not None else uuid4()
self.flags = flags
# Eventing core queue
self.event_queue = event_queue
# Accept work
self.work = work
@abstractmethod
def get_events(self) -> Dict[socket.socket, int]:

View File

@ -45,7 +45,7 @@ class BaseTcpServerHandler(Work):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.must_flush_before_shutdown = False
logger.debug('Connection accepted from {0}'.format(self.client.addr))
logger.debug('Connection accepted from {0}'.format(self.work.addr))
@abstractmethod
def handle_data(self, data: memoryview) -> Optional[bool]:
@ -57,14 +57,14 @@ class BaseTcpServerHandler(Work):
# We always want to read from client
# Register for EVENT_READ events
if self.must_flush_before_shutdown is False:
events[self.client.connection] = selectors.EVENT_READ
events[self.work.connection] = selectors.EVENT_READ
# If there is pending buffer for client
# also register for EVENT_WRITE events
if self.client.has_buffer():
if self.client.connection in events:
events[self.client.connection] |= selectors.EVENT_WRITE
if self.work.has_buffer():
if self.work.connection in events:
events[self.work.connection] |= selectors.EVENT_WRITE
else:
events[self.client.connection] = selectors.EVENT_WRITE
events[self.work.connection] = selectors.EVENT_WRITE
return events
def handle_events(
@ -79,32 +79,32 @@ class BaseTcpServerHandler(Work):
if teardown:
logger.debug(
'Shutting down client {0} connection'.format(
self.client.addr,
self.work.addr,
),
)
return teardown
def handle_writables(self, writables: Writables) -> bool:
teardown = False
if self.client.connection in writables and self.client.has_buffer():
if self.work.connection in writables and self.work.has_buffer():
logger.debug(
'Flushing buffer to client {0}'.format(self.client.addr),
'Flushing buffer to client {0}'.format(self.work.addr),
)
self.client.flush()
self.work.flush()
if self.must_flush_before_shutdown is True:
if not self.client.has_buffer():
if not self.work.has_buffer():
teardown = True
self.must_flush_before_shutdown = False
return teardown
def handle_readables(self, readables: Readables) -> bool:
teardown = False
if self.client.connection in readables:
data = self.client.recv(self.flags.client_recvbuf_size)
if self.work.connection in readables:
data = self.work.recv(self.flags.client_recvbuf_size)
if data is None:
logger.debug(
'Connection closed by client {0}'.format(
self.client.addr,
self.work.addr,
),
)
teardown = True
@ -113,13 +113,13 @@ class BaseTcpServerHandler(Work):
if isinstance(r, bool) and r is True:
logger.debug(
'Implementation signaled shutdown for client {0}'.format(
self.client.addr,
self.work.addr,
),
)
if self.client.has_buffer():
if self.work.has_buffer():
logger.debug(
'Client {0} has pending buffer, will be flushed before shutting down'.format(
self.client.addr,
self.work.addr,
),
)
self.must_flush_before_shutdown = True

View File

@ -43,7 +43,7 @@ class BaseTcpTunnelHandler(BaseTcpServerHandler):
pass # pragma: no cover
def initialize(self) -> None:
self.client.connection.setblocking(False)
self.work.connection.setblocking(False)
def shutdown(self) -> None:
if self.upstream:
@ -87,7 +87,7 @@ class BaseTcpTunnelHandler(BaseTcpServerHandler):
print('Connection closed by server')
return True
# tunnel data to client
self.client.queue(data)
self.work.queue(data)
if self.upstream and self.upstream.connection in writables:
self.upstream.flush()
return False

View File

@ -89,25 +89,25 @@ class HttpProtocolHandler(BaseTcpServerHandler):
def initialize(self) -> None:
"""Optionally upgrades connection to HTTPS, set conn in non-blocking mode and initializes plugins."""
conn = self._optionally_wrap_socket(self.client.connection)
conn = self._optionally_wrap_socket(self.work.connection)
conn.setblocking(False)
# Update client connection reference if connection was wrapped
if self._encryption_enabled():
self.client = TcpClientConnection(conn=conn, addr=self.client.addr)
self.work = TcpClientConnection(conn=conn, addr=self.work.addr)
if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
instance: HttpProtocolHandlerPlugin = klass(
self.uid,
self.flags,
self.client,
self.work,
self.request,
self.event_queue,
)
self.plugins[instance.name()] = instance
logger.debug('Handling connection %r' % self.client.connection)
logger.debug('Handling connection %r' % self.work.connection)
def is_inactive(self) -> bool:
if not self.client.has_buffer() and \
if not self.work.has_buffer() and \
self._connection_inactive_for() > self.flags.timeout:
return True
return False
@ -127,20 +127,20 @@ class HttpProtocolHandler(BaseTcpServerHandler):
logger.debug(
'Closing client connection %r '
'at address %r has buffer %s' %
(self.client.connection, self.client.addr, self.client.has_buffer()),
(self.work.connection, self.work.addr, self.work.has_buffer()),
)
conn = self.client.connection
conn = self.work.connection
# Unwrap if wrapped before shutdown.
if self._encryption_enabled() and \
isinstance(self.client.connection, ssl.SSLSocket):
conn = self.client.connection.unwrap()
isinstance(self.work.connection, ssl.SSLSocket):
conn = self.work.connection.unwrap()
conn.shutdown(socket.SHUT_WR)
logger.debug('Client connection shutdown successful')
except OSError:
pass
finally:
self.client.connection.close()
self.work.connection.close()
logger.debug('Client connection closed')
super().shutdown()
@ -196,7 +196,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
def handle_data(self, data: memoryview) -> Optional[bool]:
if data is None:
logger.debug('Client closed connection, tearing down...')
self.client.closed = True
self.work.closed = True
return True
try:
@ -227,7 +227,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
logger.debug(
'Updated client conn to %s', upgraded_sock,
)
self.client._conn = upgraded_sock
self.work._conn = upgraded_sock
for plugin_ in self.plugins.values():
if plugin_ != plugin:
plugin_.client._conn = upgraded_sock
@ -237,12 +237,12 @@ class HttpProtocolHandler(BaseTcpServerHandler):
logger.debug('HttpProtocolException raised')
response: Optional[memoryview] = e.response(self.request)
if response:
self.client.queue(response)
self.work.queue(response)
return True
return False
def handle_writables(self, writables: Writables) -> bool:
if self.client.connection in writables and self.client.has_buffer():
if self.work.connection in writables and self.work.has_buffer():
logger.debug('Client is ready for writes, flushing buffer')
self.last_activity = time.time()
@ -250,7 +250,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
# instead of invoking when flushed to client.
#
# Invoke plugin.on_response_chunk
chunk = self.client.buffer
chunk = self.work.buffer
for plugin in self.plugins.values():
chunk = plugin.on_response_chunk(chunk)
if chunk is None:
@ -272,7 +272,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
return False
def handle_readables(self, readables: Readables) -> bool:
if self.client.connection in readables:
if self.work.connection in readables:
logger.debug('Client is ready for reads, reading')
self.last_activity = time.time()
try:
@ -290,7 +290,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
else:
logger.exception(
'Exception while receiving from %s connection %r with reason %r' %
(self.client.tag, self.client.connection, e),
(self.work.tag, self.work.connection, e),
)
return True
return False
@ -324,7 +324,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
except Exception as e:
logger.exception(
'Exception while handling connection %r' %
self.client.connection, exc_info=e,
self.work.connection, exc_info=e,
)
finally:
self.shutdown()
@ -377,24 +377,24 @@ class HttpProtocolHandler(BaseTcpServerHandler):
def _flush(self) -> None:
assert self.selector
if not self.client.has_buffer():
if not self.work.has_buffer():
return
try:
self.selector.register(
self.client.connection,
self.work.connection,
selectors.EVENT_WRITE,
)
while self.client.has_buffer():
while self.work.has_buffer():
ev: List[
Tuple[selectors.SelectorKey, int]
] = self.selector.select(timeout=1)
if len(ev) == 0:
continue
self.client.flush()
self.work.flush()
except BrokenPipeError:
pass
finally:
self.selector.unregister(self.client.connection)
self.selector.unregister(self.work.connection)
def _connection_inactive_for(self) -> float:
return time.time() - self.last_activity

View File

@ -63,9 +63,9 @@ class TestHttpProxyAuthFailed(unittest.TestCase):
self.protocol_handler._run_once()
mock_server_conn.assert_not_called()
self.assertEqual(self.protocol_handler.client.has_buffer(), True)
self.assertEqual(self.protocol_handler.work.has_buffer(), True)
self.assertEqual(
self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
)
self._conn.send.assert_not_called()
@ -92,9 +92,9 @@ class TestHttpProxyAuthFailed(unittest.TestCase):
self.protocol_handler._run_once()
mock_server_conn.assert_not_called()
self.assertEqual(self.protocol_handler.client.has_buffer(), True)
self.assertEqual(self.protocol_handler.work.has_buffer(), True)
self.assertEqual(
self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
)
self._conn.send.assert_not_called()
@ -121,7 +121,7 @@ class TestHttpProxyAuthFailed(unittest.TestCase):
self.protocol_handler._run_once()
mock_server_conn.assert_called_once()
self.assertEqual(self.protocol_handler.client.has_buffer(), False)
self.assertEqual(self.protocol_handler.work.has_buffer(), False)
@mock.patch('proxy.http.proxy.server.TcpServerConnection')
def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: mock.Mock) -> None:
@ -146,4 +146,4 @@ class TestHttpProxyAuthFailed(unittest.TestCase):
self.protocol_handler._run_once()
mock_server_conn.assert_called_once()
self.assertEqual(self.protocol_handler.client.has_buffer(), False)
self.assertEqual(self.protocol_handler.work.has_buffer(), False)

View File

@ -201,7 +201,7 @@ class TestHttpProxyTlsInterception(unittest.TestCase):
)
self.assertEqual(self._conn.setblocking.call_count, 2)
self.assertEqual(
self.protocol_handler.client.connection,
self.protocol_handler.work.connection,
self.mock_ssl_wrap.return_value,
)

View File

@ -102,7 +102,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
).upstream is not None,
)
self.assertEqual(
self.protocol_handler.client.buffer[0],
self.protocol_handler.work.buffer[0],
HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
)
mock_server_connection.assert_called_once()
@ -111,7 +111,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
server.closed = False
parser = HttpParser(httpParserTypes.RESPONSE_PARSER)
parser.parse(self.protocol_handler.client.buffer[0].tobytes())
parser.parse(self.protocol_handler.work.buffer[0].tobytes())
self.assertEqual(parser.state, httpParserStates.COMPLETE)
assert parser.code is not None
self.assertEqual(int(parser.code), 200)
@ -199,7 +199,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
])
self.protocol_handler._run_once()
self.assertEqual(
self.protocol_handler.client.buffer[0],
self.protocol_handler.work.buffer[0],
ProxyConnectionFailed.RESPONSE_PKT,
)
@ -231,7 +231,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
])
self.protocol_handler._run_once()
self.assertEqual(
self.protocol_handler.client.buffer[0],
self.protocol_handler.work.buffer[0],
ProxyAuthenticationFailed.RESPONSE_PKT,
)
@ -328,7 +328,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
CRLF,
])
self.assert_tunnel_response(mock_server_connection, server)
self.protocol_handler.client.flush()
self.protocol_handler.work.flush()
self.assert_data_queued_to_server(server)
self.protocol_handler._run_once()

View File

@ -132,7 +132,7 @@ class TestWebServerPlugin(unittest.TestCase):
httpParserStates.COMPLETE,
)
self.assertEqual(
self.protocol_handler.client.buffer[0],
self.protocol_handler.work.buffer[0],
HttpWebServerPlugin.DEFAULT_404_RESPONSE,
)

View File

@ -139,7 +139,7 @@ class TestHttpProxyPluginExamples(unittest.TestCase):
mock_server_conn.assert_not_called()
self.assertEqual(
self.protocol_handler.client.buffer[0].tobytes(),
self.protocol_handler.work.buffer[0].tobytes(),
build_http_response(
httpStatusCodes.OK, reason=b'OK',
headers={b'Content-Type': b'application/json'},
@ -215,7 +215,7 @@ class TestHttpProxyPluginExamples(unittest.TestCase):
mock_server_conn.assert_not_called()
self.assertEqual(
self.protocol_handler.client.buffer[0].tobytes(),
self.protocol_handler.work.buffer[0].tobytes(),
build_http_response(
status_code=httpStatusCodes.I_AM_A_TEAPOT,
reason=b'I\'m a tea pot',
@ -305,7 +305,7 @@ class TestHttpProxyPluginExamples(unittest.TestCase):
)
self.protocol_handler._run_once()
self.assertEqual(
self.protocol_handler.client.buffer[0].tobytes(),
self.protocol_handler.work.buffer[0].tobytes(),
build_http_response(
httpStatusCodes.OK,
reason=b'OK', body=b'Hello from man in the middle',
@ -337,7 +337,7 @@ class TestHttpProxyPluginExamples(unittest.TestCase):
self.protocol_handler._run_once()
self.assertEqual(
self.protocol_handler.client.buffer[0].tobytes(),
self.protocol_handler.work.buffer[0].tobytes(),
build_http_response(
status_code=httpStatusCodes.NOT_FOUND,
reason=b'Blocked',

View File

@ -170,14 +170,14 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase):
self.mock_server_conn.assert_called_once_with('uni.corn', 443)
self.server.connect.assert_called()
self.assertEqual(
self.protocol_handler.client.connection,
self.protocol_handler.work.connection,
self.client_ssl_connection,
)
self.assertEqual(self.server.connection, self.server_ssl_connection)
self._conn.send.assert_called_with(
HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
)
self.assertFalse(self.protocol_handler.client.has_buffer())
self.assertFalse(self.protocol_handler.work.has_buffer())
def test_modify_post_data_plugin(self) -> None:
original = b'{"key": "value"}'
@ -229,7 +229,7 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase):
)
self.protocol_handler._run_once()
self.assertEqual(
self.protocol_handler.client.buffer[0].tobytes(),
self.protocol_handler.work.buffer[0].tobytes(),
build_http_response(
httpStatusCodes.OK,
reason=b'OK', body=b'Hello from man in the middle',