Top-level notion of work not client (#695)
* Top-level notion of work not client * Update ssl echo server example
This commit is contained in:
parent
d3cee32909
commit
f48771fb41
|
@ -53,7 +53,7 @@ class HttpsConnectTunnelHandler(BaseTcpTunnelHandler):
|
|||
|
||||
# Drop the request if not a CONNECT request
|
||||
if self.request.method != httpMethods.CONNECT:
|
||||
self.client.queue(
|
||||
self.work.queue(
|
||||
HttpsConnectTunnelHandler.PROXY_TUNNEL_UNSUPPORTED_SCHEME,
|
||||
)
|
||||
return True
|
||||
|
@ -66,7 +66,7 @@ class HttpsConnectTunnelHandler(BaseTcpTunnelHandler):
|
|||
self.connect_upstream()
|
||||
|
||||
# Queue tunnel established response to client
|
||||
self.client.queue(
|
||||
self.work.queue(
|
||||
HttpsConnectTunnelHandler.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
|
||||
)
|
||||
|
||||
|
|
|
@ -27,19 +27,19 @@ class EchoSSLServerHandler(BaseTcpServerHandler):
|
|||
# here using wrap_socket() utility.
|
||||
assert self.flags.keyfile is not None and self.flags.certfile is not None
|
||||
conn = wrap_socket(
|
||||
self.client.connection,
|
||||
self.work.connection,
|
||||
self.flags.keyfile,
|
||||
self.flags.certfile,
|
||||
)
|
||||
conn.setblocking(False)
|
||||
# Upgrade plain TcpClientConnection to SSL connection object
|
||||
self.client = TcpClientConnection(
|
||||
conn=conn, addr=self.client.addr,
|
||||
self.work = TcpClientConnection(
|
||||
conn=conn, addr=self.work.addr,
|
||||
)
|
||||
|
||||
def handle_data(self, data: memoryview) -> Optional[bool]:
|
||||
# echo back to client
|
||||
self.client.queue(data)
|
||||
self.work.queue(data)
|
||||
return None
|
||||
|
||||
|
||||
|
|
|
@ -20,11 +20,11 @@ class EchoServerHandler(BaseTcpServerHandler):
|
|||
"""Sets client socket to non-blocking during initialization."""
|
||||
|
||||
def initialize(self) -> None:
|
||||
self.client.connection.setblocking(False)
|
||||
self.work.connection.setblocking(False)
|
||||
|
||||
def handle_data(self, data: memoryview) -> Optional[bool]:
|
||||
# echo back to client
|
||||
self.client.queue(data)
|
||||
self.work.queue(data)
|
||||
return None
|
||||
|
||||
|
||||
|
|
|
@ -25,15 +25,18 @@ class Work(ABC):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
client: TcpClientConnection,
|
||||
work: TcpClientConnection,
|
||||
flags: argparse.Namespace,
|
||||
event_queue: Optional[EventQueue] = None,
|
||||
uid: Optional[UUID] = None,
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.flags = flags
|
||||
self.event_queue = event_queue
|
||||
# Work uuid
|
||||
self.uid: UUID = uid if uid is not None else uuid4()
|
||||
self.flags = flags
|
||||
# Eventing core queue
|
||||
self.event_queue = event_queue
|
||||
# Accept work
|
||||
self.work = work
|
||||
|
||||
@abstractmethod
|
||||
def get_events(self) -> Dict[socket.socket, int]:
|
||||
|
|
|
@ -45,7 +45,7 @@ class BaseTcpServerHandler(Work):
|
|||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.must_flush_before_shutdown = False
|
||||
logger.debug('Connection accepted from {0}'.format(self.client.addr))
|
||||
logger.debug('Connection accepted from {0}'.format(self.work.addr))
|
||||
|
||||
@abstractmethod
|
||||
def handle_data(self, data: memoryview) -> Optional[bool]:
|
||||
|
@ -57,14 +57,14 @@ class BaseTcpServerHandler(Work):
|
|||
# We always want to read from client
|
||||
# Register for EVENT_READ events
|
||||
if self.must_flush_before_shutdown is False:
|
||||
events[self.client.connection] = selectors.EVENT_READ
|
||||
events[self.work.connection] = selectors.EVENT_READ
|
||||
# If there is pending buffer for client
|
||||
# also register for EVENT_WRITE events
|
||||
if self.client.has_buffer():
|
||||
if self.client.connection in events:
|
||||
events[self.client.connection] |= selectors.EVENT_WRITE
|
||||
if self.work.has_buffer():
|
||||
if self.work.connection in events:
|
||||
events[self.work.connection] |= selectors.EVENT_WRITE
|
||||
else:
|
||||
events[self.client.connection] = selectors.EVENT_WRITE
|
||||
events[self.work.connection] = selectors.EVENT_WRITE
|
||||
return events
|
||||
|
||||
def handle_events(
|
||||
|
@ -79,32 +79,32 @@ class BaseTcpServerHandler(Work):
|
|||
if teardown:
|
||||
logger.debug(
|
||||
'Shutting down client {0} connection'.format(
|
||||
self.client.addr,
|
||||
self.work.addr,
|
||||
),
|
||||
)
|
||||
return teardown
|
||||
|
||||
def handle_writables(self, writables: Writables) -> bool:
|
||||
teardown = False
|
||||
if self.client.connection in writables and self.client.has_buffer():
|
||||
if self.work.connection in writables and self.work.has_buffer():
|
||||
logger.debug(
|
||||
'Flushing buffer to client {0}'.format(self.client.addr),
|
||||
'Flushing buffer to client {0}'.format(self.work.addr),
|
||||
)
|
||||
self.client.flush()
|
||||
self.work.flush()
|
||||
if self.must_flush_before_shutdown is True:
|
||||
if not self.client.has_buffer():
|
||||
if not self.work.has_buffer():
|
||||
teardown = True
|
||||
self.must_flush_before_shutdown = False
|
||||
return teardown
|
||||
|
||||
def handle_readables(self, readables: Readables) -> bool:
|
||||
teardown = False
|
||||
if self.client.connection in readables:
|
||||
data = self.client.recv(self.flags.client_recvbuf_size)
|
||||
if self.work.connection in readables:
|
||||
data = self.work.recv(self.flags.client_recvbuf_size)
|
||||
if data is None:
|
||||
logger.debug(
|
||||
'Connection closed by client {0}'.format(
|
||||
self.client.addr,
|
||||
self.work.addr,
|
||||
),
|
||||
)
|
||||
teardown = True
|
||||
|
@ -113,13 +113,13 @@ class BaseTcpServerHandler(Work):
|
|||
if isinstance(r, bool) and r is True:
|
||||
logger.debug(
|
||||
'Implementation signaled shutdown for client {0}'.format(
|
||||
self.client.addr,
|
||||
self.work.addr,
|
||||
),
|
||||
)
|
||||
if self.client.has_buffer():
|
||||
if self.work.has_buffer():
|
||||
logger.debug(
|
||||
'Client {0} has pending buffer, will be flushed before shutting down'.format(
|
||||
self.client.addr,
|
||||
self.work.addr,
|
||||
),
|
||||
)
|
||||
self.must_flush_before_shutdown = True
|
||||
|
|
|
@ -43,7 +43,7 @@ class BaseTcpTunnelHandler(BaseTcpServerHandler):
|
|||
pass # pragma: no cover
|
||||
|
||||
def initialize(self) -> None:
|
||||
self.client.connection.setblocking(False)
|
||||
self.work.connection.setblocking(False)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self.upstream:
|
||||
|
@ -87,7 +87,7 @@ class BaseTcpTunnelHandler(BaseTcpServerHandler):
|
|||
print('Connection closed by server')
|
||||
return True
|
||||
# tunnel data to client
|
||||
self.client.queue(data)
|
||||
self.work.queue(data)
|
||||
if self.upstream and self.upstream.connection in writables:
|
||||
self.upstream.flush()
|
||||
return False
|
||||
|
|
|
@ -89,25 +89,25 @@ class HttpProtocolHandler(BaseTcpServerHandler):
|
|||
|
||||
def initialize(self) -> None:
|
||||
"""Optionally upgrades connection to HTTPS, set conn in non-blocking mode and initializes plugins."""
|
||||
conn = self._optionally_wrap_socket(self.client.connection)
|
||||
conn = self._optionally_wrap_socket(self.work.connection)
|
||||
conn.setblocking(False)
|
||||
# Update client connection reference if connection was wrapped
|
||||
if self._encryption_enabled():
|
||||
self.client = TcpClientConnection(conn=conn, addr=self.client.addr)
|
||||
self.work = TcpClientConnection(conn=conn, addr=self.work.addr)
|
||||
if b'HttpProtocolHandlerPlugin' in self.flags.plugins:
|
||||
for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']:
|
||||
instance: HttpProtocolHandlerPlugin = klass(
|
||||
self.uid,
|
||||
self.flags,
|
||||
self.client,
|
||||
self.work,
|
||||
self.request,
|
||||
self.event_queue,
|
||||
)
|
||||
self.plugins[instance.name()] = instance
|
||||
logger.debug('Handling connection %r' % self.client.connection)
|
||||
logger.debug('Handling connection %r' % self.work.connection)
|
||||
|
||||
def is_inactive(self) -> bool:
|
||||
if not self.client.has_buffer() and \
|
||||
if not self.work.has_buffer() and \
|
||||
self._connection_inactive_for() > self.flags.timeout:
|
||||
return True
|
||||
return False
|
||||
|
@ -127,20 +127,20 @@ class HttpProtocolHandler(BaseTcpServerHandler):
|
|||
logger.debug(
|
||||
'Closing client connection %r '
|
||||
'at address %r has buffer %s' %
|
||||
(self.client.connection, self.client.addr, self.client.has_buffer()),
|
||||
(self.work.connection, self.work.addr, self.work.has_buffer()),
|
||||
)
|
||||
|
||||
conn = self.client.connection
|
||||
conn = self.work.connection
|
||||
# Unwrap if wrapped before shutdown.
|
||||
if self._encryption_enabled() and \
|
||||
isinstance(self.client.connection, ssl.SSLSocket):
|
||||
conn = self.client.connection.unwrap()
|
||||
isinstance(self.work.connection, ssl.SSLSocket):
|
||||
conn = self.work.connection.unwrap()
|
||||
conn.shutdown(socket.SHUT_WR)
|
||||
logger.debug('Client connection shutdown successful')
|
||||
except OSError:
|
||||
pass
|
||||
finally:
|
||||
self.client.connection.close()
|
||||
self.work.connection.close()
|
||||
logger.debug('Client connection closed')
|
||||
super().shutdown()
|
||||
|
||||
|
@ -196,7 +196,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
|
|||
def handle_data(self, data: memoryview) -> Optional[bool]:
|
||||
if data is None:
|
||||
logger.debug('Client closed connection, tearing down...')
|
||||
self.client.closed = True
|
||||
self.work.closed = True
|
||||
return True
|
||||
|
||||
try:
|
||||
|
@ -227,7 +227,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
|
|||
logger.debug(
|
||||
'Updated client conn to %s', upgraded_sock,
|
||||
)
|
||||
self.client._conn = upgraded_sock
|
||||
self.work._conn = upgraded_sock
|
||||
for plugin_ in self.plugins.values():
|
||||
if plugin_ != plugin:
|
||||
plugin_.client._conn = upgraded_sock
|
||||
|
@ -237,12 +237,12 @@ class HttpProtocolHandler(BaseTcpServerHandler):
|
|||
logger.debug('HttpProtocolException raised')
|
||||
response: Optional[memoryview] = e.response(self.request)
|
||||
if response:
|
||||
self.client.queue(response)
|
||||
self.work.queue(response)
|
||||
return True
|
||||
return False
|
||||
|
||||
def handle_writables(self, writables: Writables) -> bool:
|
||||
if self.client.connection in writables and self.client.has_buffer():
|
||||
if self.work.connection in writables and self.work.has_buffer():
|
||||
logger.debug('Client is ready for writes, flushing buffer')
|
||||
self.last_activity = time.time()
|
||||
|
||||
|
@ -250,7 +250,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
|
|||
# instead of invoking when flushed to client.
|
||||
#
|
||||
# Invoke plugin.on_response_chunk
|
||||
chunk = self.client.buffer
|
||||
chunk = self.work.buffer
|
||||
for plugin in self.plugins.values():
|
||||
chunk = plugin.on_response_chunk(chunk)
|
||||
if chunk is None:
|
||||
|
@ -272,7 +272,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
|
|||
return False
|
||||
|
||||
def handle_readables(self, readables: Readables) -> bool:
|
||||
if self.client.connection in readables:
|
||||
if self.work.connection in readables:
|
||||
logger.debug('Client is ready for reads, reading')
|
||||
self.last_activity = time.time()
|
||||
try:
|
||||
|
@ -290,7 +290,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
|
|||
else:
|
||||
logger.exception(
|
||||
'Exception while receiving from %s connection %r with reason %r' %
|
||||
(self.client.tag, self.client.connection, e),
|
||||
(self.work.tag, self.work.connection, e),
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
@ -324,7 +324,7 @@ class HttpProtocolHandler(BaseTcpServerHandler):
|
|||
except Exception as e:
|
||||
logger.exception(
|
||||
'Exception while handling connection %r' %
|
||||
self.client.connection, exc_info=e,
|
||||
self.work.connection, exc_info=e,
|
||||
)
|
||||
finally:
|
||||
self.shutdown()
|
||||
|
@ -377,24 +377,24 @@ class HttpProtocolHandler(BaseTcpServerHandler):
|
|||
|
||||
def _flush(self) -> None:
|
||||
assert self.selector
|
||||
if not self.client.has_buffer():
|
||||
if not self.work.has_buffer():
|
||||
return
|
||||
try:
|
||||
self.selector.register(
|
||||
self.client.connection,
|
||||
self.work.connection,
|
||||
selectors.EVENT_WRITE,
|
||||
)
|
||||
while self.client.has_buffer():
|
||||
while self.work.has_buffer():
|
||||
ev: List[
|
||||
Tuple[selectors.SelectorKey, int]
|
||||
] = self.selector.select(timeout=1)
|
||||
if len(ev) == 0:
|
||||
continue
|
||||
self.client.flush()
|
||||
self.work.flush()
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
finally:
|
||||
self.selector.unregister(self.client.connection)
|
||||
self.selector.unregister(self.work.connection)
|
||||
|
||||
def _connection_inactive_for(self) -> float:
|
||||
return time.time() - self.last_activity
|
||||
|
|
|
@ -63,9 +63,9 @@ class TestHttpProxyAuthFailed(unittest.TestCase):
|
|||
|
||||
self.protocol_handler._run_once()
|
||||
mock_server_conn.assert_not_called()
|
||||
self.assertEqual(self.protocol_handler.client.has_buffer(), True)
|
||||
self.assertEqual(self.protocol_handler.work.has_buffer(), True)
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
|
||||
self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
|
||||
)
|
||||
self._conn.send.assert_not_called()
|
||||
|
||||
|
@ -92,9 +92,9 @@ class TestHttpProxyAuthFailed(unittest.TestCase):
|
|||
|
||||
self.protocol_handler._run_once()
|
||||
mock_server_conn.assert_not_called()
|
||||
self.assertEqual(self.protocol_handler.client.has_buffer(), True)
|
||||
self.assertEqual(self.protocol_handler.work.has_buffer(), True)
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
|
||||
self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT,
|
||||
)
|
||||
self._conn.send.assert_not_called()
|
||||
|
||||
|
@ -121,7 +121,7 @@ class TestHttpProxyAuthFailed(unittest.TestCase):
|
|||
|
||||
self.protocol_handler._run_once()
|
||||
mock_server_conn.assert_called_once()
|
||||
self.assertEqual(self.protocol_handler.client.has_buffer(), False)
|
||||
self.assertEqual(self.protocol_handler.work.has_buffer(), False)
|
||||
|
||||
@mock.patch('proxy.http.proxy.server.TcpServerConnection')
|
||||
def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: mock.Mock) -> None:
|
||||
|
@ -146,4 +146,4 @@ class TestHttpProxyAuthFailed(unittest.TestCase):
|
|||
|
||||
self.protocol_handler._run_once()
|
||||
mock_server_conn.assert_called_once()
|
||||
self.assertEqual(self.protocol_handler.client.has_buffer(), False)
|
||||
self.assertEqual(self.protocol_handler.work.has_buffer(), False)
|
||||
|
|
|
@ -201,7 +201,7 @@ class TestHttpProxyTlsInterception(unittest.TestCase):
|
|||
)
|
||||
self.assertEqual(self._conn.setblocking.call_count, 2)
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.connection,
|
||||
self.protocol_handler.work.connection,
|
||||
self.mock_ssl_wrap.return_value,
|
||||
)
|
||||
|
||||
|
|
|
@ -102,7 +102,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
).upstream is not None,
|
||||
)
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.buffer[0],
|
||||
self.protocol_handler.work.buffer[0],
|
||||
HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
|
||||
)
|
||||
mock_server_connection.assert_called_once()
|
||||
|
@ -111,7 +111,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
server.closed = False
|
||||
|
||||
parser = HttpParser(httpParserTypes.RESPONSE_PARSER)
|
||||
parser.parse(self.protocol_handler.client.buffer[0].tobytes())
|
||||
parser.parse(self.protocol_handler.work.buffer[0].tobytes())
|
||||
self.assertEqual(parser.state, httpParserStates.COMPLETE)
|
||||
assert parser.code is not None
|
||||
self.assertEqual(int(parser.code), 200)
|
||||
|
@ -199,7 +199,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
])
|
||||
self.protocol_handler._run_once()
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.buffer[0],
|
||||
self.protocol_handler.work.buffer[0],
|
||||
ProxyConnectionFailed.RESPONSE_PKT,
|
||||
)
|
||||
|
||||
|
@ -231,7 +231,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
])
|
||||
self.protocol_handler._run_once()
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.buffer[0],
|
||||
self.protocol_handler.work.buffer[0],
|
||||
ProxyAuthenticationFailed.RESPONSE_PKT,
|
||||
)
|
||||
|
||||
|
@ -328,7 +328,7 @@ class TestHttpProtocolHandler(unittest.TestCase):
|
|||
CRLF,
|
||||
])
|
||||
self.assert_tunnel_response(mock_server_connection, server)
|
||||
self.protocol_handler.client.flush()
|
||||
self.protocol_handler.work.flush()
|
||||
self.assert_data_queued_to_server(server)
|
||||
|
||||
self.protocol_handler._run_once()
|
||||
|
|
|
@ -132,7 +132,7 @@ class TestWebServerPlugin(unittest.TestCase):
|
|||
httpParserStates.COMPLETE,
|
||||
)
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.buffer[0],
|
||||
self.protocol_handler.work.buffer[0],
|
||||
HttpWebServerPlugin.DEFAULT_404_RESPONSE,
|
||||
)
|
||||
|
||||
|
|
|
@ -139,7 +139,7 @@ class TestHttpProxyPluginExamples(unittest.TestCase):
|
|||
|
||||
mock_server_conn.assert_not_called()
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.buffer[0].tobytes(),
|
||||
self.protocol_handler.work.buffer[0].tobytes(),
|
||||
build_http_response(
|
||||
httpStatusCodes.OK, reason=b'OK',
|
||||
headers={b'Content-Type': b'application/json'},
|
||||
|
@ -215,7 +215,7 @@ class TestHttpProxyPluginExamples(unittest.TestCase):
|
|||
|
||||
mock_server_conn.assert_not_called()
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.buffer[0].tobytes(),
|
||||
self.protocol_handler.work.buffer[0].tobytes(),
|
||||
build_http_response(
|
||||
status_code=httpStatusCodes.I_AM_A_TEAPOT,
|
||||
reason=b'I\'m a tea pot',
|
||||
|
@ -305,7 +305,7 @@ class TestHttpProxyPluginExamples(unittest.TestCase):
|
|||
)
|
||||
self.protocol_handler._run_once()
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.buffer[0].tobytes(),
|
||||
self.protocol_handler.work.buffer[0].tobytes(),
|
||||
build_http_response(
|
||||
httpStatusCodes.OK,
|
||||
reason=b'OK', body=b'Hello from man in the middle',
|
||||
|
@ -337,7 +337,7 @@ class TestHttpProxyPluginExamples(unittest.TestCase):
|
|||
self.protocol_handler._run_once()
|
||||
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.buffer[0].tobytes(),
|
||||
self.protocol_handler.work.buffer[0].tobytes(),
|
||||
build_http_response(
|
||||
status_code=httpStatusCodes.NOT_FOUND,
|
||||
reason=b'Blocked',
|
||||
|
|
|
@ -170,14 +170,14 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase):
|
|||
self.mock_server_conn.assert_called_once_with('uni.corn', 443)
|
||||
self.server.connect.assert_called()
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.connection,
|
||||
self.protocol_handler.work.connection,
|
||||
self.client_ssl_connection,
|
||||
)
|
||||
self.assertEqual(self.server.connection, self.server_ssl_connection)
|
||||
self._conn.send.assert_called_with(
|
||||
HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
|
||||
)
|
||||
self.assertFalse(self.protocol_handler.client.has_buffer())
|
||||
self.assertFalse(self.protocol_handler.work.has_buffer())
|
||||
|
||||
def test_modify_post_data_plugin(self) -> None:
|
||||
original = b'{"key": "value"}'
|
||||
|
@ -229,7 +229,7 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase):
|
|||
)
|
||||
self.protocol_handler._run_once()
|
||||
self.assertEqual(
|
||||
self.protocol_handler.client.buffer[0].tobytes(),
|
||||
self.protocol_handler.work.buffer[0].tobytes(),
|
||||
build_http_response(
|
||||
httpStatusCodes.OK,
|
||||
reason=b'OK', body=b'Hello from man in the middle',
|
||||
|
|
Loading…
Reference in New Issue