diff --git a/examples/https_connect_tunnel.py b/examples/https_connect_tunnel.py index ea2e49bc..64307242 100644 --- a/examples/https_connect_tunnel.py +++ b/examples/https_connect_tunnel.py @@ -53,7 +53,7 @@ class HttpsConnectTunnelHandler(BaseTcpTunnelHandler): # Drop the request if not a CONNECT request if self.request.method != httpMethods.CONNECT: - self.client.queue( + self.work.queue( HttpsConnectTunnelHandler.PROXY_TUNNEL_UNSUPPORTED_SCHEME, ) return True @@ -66,7 +66,7 @@ class HttpsConnectTunnelHandler(BaseTcpTunnelHandler): self.connect_upstream() # Queue tunnel established response to client - self.client.queue( + self.work.queue( HttpsConnectTunnelHandler.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, ) diff --git a/examples/ssl_echo_server.py b/examples/ssl_echo_server.py index 5c46ef63..7a7804f3 100644 --- a/examples/ssl_echo_server.py +++ b/examples/ssl_echo_server.py @@ -27,19 +27,19 @@ class EchoSSLServerHandler(BaseTcpServerHandler): # here using wrap_socket() utility. assert self.flags.keyfile is not None and self.flags.certfile is not None conn = wrap_socket( - self.client.connection, + self.work.connection, self.flags.keyfile, self.flags.certfile, ) conn.setblocking(False) # Upgrade plain TcpClientConnection to SSL connection object - self.client = TcpClientConnection( - conn=conn, addr=self.client.addr, + self.work = TcpClientConnection( + conn=conn, addr=self.work.addr, ) def handle_data(self, data: memoryview) -> Optional[bool]: # echo back to client - self.client.queue(data) + self.work.queue(data) return None diff --git a/examples/tcp_echo_server.py b/examples/tcp_echo_server.py index 38d194cf..9e0e8f7e 100644 --- a/examples/tcp_echo_server.py +++ b/examples/tcp_echo_server.py @@ -20,11 +20,11 @@ class EchoServerHandler(BaseTcpServerHandler): """Sets client socket to non-blocking during initialization.""" def initialize(self) -> None: - self.client.connection.setblocking(False) + self.work.connection.setblocking(False) def handle_data(self, data: memoryview) -> Optional[bool]: # echo back to client - self.client.queue(data) + self.work.queue(data) return None diff --git a/proxy/core/acceptor/work.py b/proxy/core/acceptor/work.py index 5556a319..52d5103d 100644 --- a/proxy/core/acceptor/work.py +++ b/proxy/core/acceptor/work.py @@ -25,15 +25,18 @@ class Work(ABC): def __init__( self, - client: TcpClientConnection, + work: TcpClientConnection, flags: argparse.Namespace, event_queue: Optional[EventQueue] = None, uid: Optional[UUID] = None, ) -> None: - self.client = client - self.flags = flags - self.event_queue = event_queue + # Work uuid self.uid: UUID = uid if uid is not None else uuid4() + self.flags = flags + # Eventing core queue + self.event_queue = event_queue + # Accept work + self.work = work @abstractmethod def get_events(self) -> Dict[socket.socket, int]: diff --git a/proxy/core/base/tcp_server.py b/proxy/core/base/tcp_server.py index b0b44500..3141c5fb 100644 --- a/proxy/core/base/tcp_server.py +++ b/proxy/core/base/tcp_server.py @@ -45,7 +45,7 @@ class BaseTcpServerHandler(Work): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.must_flush_before_shutdown = False - logger.debug('Connection accepted from {0}'.format(self.client.addr)) + logger.debug('Connection accepted from {0}'.format(self.work.addr)) @abstractmethod def handle_data(self, data: memoryview) -> Optional[bool]: @@ -57,14 +57,14 @@ class BaseTcpServerHandler(Work): # We always want to read from client # Register for EVENT_READ events if self.must_flush_before_shutdown is False: - events[self.client.connection] = selectors.EVENT_READ + events[self.work.connection] = selectors.EVENT_READ # If there is pending buffer for client # also register for EVENT_WRITE events - if self.client.has_buffer(): - if self.client.connection in events: - events[self.client.connection] |= selectors.EVENT_WRITE + if self.work.has_buffer(): + if self.work.connection in events: + events[self.work.connection] |= selectors.EVENT_WRITE else: - events[self.client.connection] = selectors.EVENT_WRITE + events[self.work.connection] = selectors.EVENT_WRITE return events def handle_events( @@ -79,32 +79,32 @@ class BaseTcpServerHandler(Work): if teardown: logger.debug( 'Shutting down client {0} connection'.format( - self.client.addr, + self.work.addr, ), ) return teardown def handle_writables(self, writables: Writables) -> bool: teardown = False - if self.client.connection in writables and self.client.has_buffer(): + if self.work.connection in writables and self.work.has_buffer(): logger.debug( - 'Flushing buffer to client {0}'.format(self.client.addr), + 'Flushing buffer to client {0}'.format(self.work.addr), ) - self.client.flush() + self.work.flush() if self.must_flush_before_shutdown is True: - if not self.client.has_buffer(): + if not self.work.has_buffer(): teardown = True self.must_flush_before_shutdown = False return teardown def handle_readables(self, readables: Readables) -> bool: teardown = False - if self.client.connection in readables: - data = self.client.recv(self.flags.client_recvbuf_size) + if self.work.connection in readables: + data = self.work.recv(self.flags.client_recvbuf_size) if data is None: logger.debug( 'Connection closed by client {0}'.format( - self.client.addr, + self.work.addr, ), ) teardown = True @@ -113,13 +113,13 @@ class BaseTcpServerHandler(Work): if isinstance(r, bool) and r is True: logger.debug( 'Implementation signaled shutdown for client {0}'.format( - self.client.addr, + self.work.addr, ), ) - if self.client.has_buffer(): + if self.work.has_buffer(): logger.debug( 'Client {0} has pending buffer, will be flushed before shutting down'.format( - self.client.addr, + self.work.addr, ), ) self.must_flush_before_shutdown = True diff --git a/proxy/core/base/tcp_tunnel.py b/proxy/core/base/tcp_tunnel.py index 94e058c1..2d87a156 100644 --- a/proxy/core/base/tcp_tunnel.py +++ b/proxy/core/base/tcp_tunnel.py @@ -43,7 +43,7 @@ class BaseTcpTunnelHandler(BaseTcpServerHandler): pass # pragma: no cover def initialize(self) -> None: - self.client.connection.setblocking(False) + self.work.connection.setblocking(False) def shutdown(self) -> None: if self.upstream: @@ -87,7 +87,7 @@ class BaseTcpTunnelHandler(BaseTcpServerHandler): print('Connection closed by server') return True # tunnel data to client - self.client.queue(data) + self.work.queue(data) if self.upstream and self.upstream.connection in writables: self.upstream.flush() return False diff --git a/proxy/http/handler.py b/proxy/http/handler.py index 5fe54527..ae1b0d8f 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -89,25 +89,25 @@ class HttpProtocolHandler(BaseTcpServerHandler): def initialize(self) -> None: """Optionally upgrades connection to HTTPS, set conn in non-blocking mode and initializes plugins.""" - conn = self._optionally_wrap_socket(self.client.connection) + conn = self._optionally_wrap_socket(self.work.connection) conn.setblocking(False) # Update client connection reference if connection was wrapped if self._encryption_enabled(): - self.client = TcpClientConnection(conn=conn, addr=self.client.addr) + self.work = TcpClientConnection(conn=conn, addr=self.work.addr) if b'HttpProtocolHandlerPlugin' in self.flags.plugins: for klass in self.flags.plugins[b'HttpProtocolHandlerPlugin']: instance: HttpProtocolHandlerPlugin = klass( self.uid, self.flags, - self.client, + self.work, self.request, self.event_queue, ) self.plugins[instance.name()] = instance - logger.debug('Handling connection %r' % self.client.connection) + logger.debug('Handling connection %r' % self.work.connection) def is_inactive(self) -> bool: - if not self.client.has_buffer() and \ + if not self.work.has_buffer() and \ self._connection_inactive_for() > self.flags.timeout: return True return False @@ -127,20 +127,20 @@ class HttpProtocolHandler(BaseTcpServerHandler): logger.debug( 'Closing client connection %r ' 'at address %r has buffer %s' % - (self.client.connection, self.client.addr, self.client.has_buffer()), + (self.work.connection, self.work.addr, self.work.has_buffer()), ) - conn = self.client.connection + conn = self.work.connection # Unwrap if wrapped before shutdown. if self._encryption_enabled() and \ - isinstance(self.client.connection, ssl.SSLSocket): - conn = self.client.connection.unwrap() + isinstance(self.work.connection, ssl.SSLSocket): + conn = self.work.connection.unwrap() conn.shutdown(socket.SHUT_WR) logger.debug('Client connection shutdown successful') except OSError: pass finally: - self.client.connection.close() + self.work.connection.close() logger.debug('Client connection closed') super().shutdown() @@ -196,7 +196,7 @@ class HttpProtocolHandler(BaseTcpServerHandler): def handle_data(self, data: memoryview) -> Optional[bool]: if data is None: logger.debug('Client closed connection, tearing down...') - self.client.closed = True + self.work.closed = True return True try: @@ -227,7 +227,7 @@ class HttpProtocolHandler(BaseTcpServerHandler): logger.debug( 'Updated client conn to %s', upgraded_sock, ) - self.client._conn = upgraded_sock + self.work._conn = upgraded_sock for plugin_ in self.plugins.values(): if plugin_ != plugin: plugin_.client._conn = upgraded_sock @@ -237,12 +237,12 @@ class HttpProtocolHandler(BaseTcpServerHandler): logger.debug('HttpProtocolException raised') response: Optional[memoryview] = e.response(self.request) if response: - self.client.queue(response) + self.work.queue(response) return True return False def handle_writables(self, writables: Writables) -> bool: - if self.client.connection in writables and self.client.has_buffer(): + if self.work.connection in writables and self.work.has_buffer(): logger.debug('Client is ready for writes, flushing buffer') self.last_activity = time.time() @@ -250,7 +250,7 @@ class HttpProtocolHandler(BaseTcpServerHandler): # instead of invoking when flushed to client. # # Invoke plugin.on_response_chunk - chunk = self.client.buffer + chunk = self.work.buffer for plugin in self.plugins.values(): chunk = plugin.on_response_chunk(chunk) if chunk is None: @@ -272,7 +272,7 @@ class HttpProtocolHandler(BaseTcpServerHandler): return False def handle_readables(self, readables: Readables) -> bool: - if self.client.connection in readables: + if self.work.connection in readables: logger.debug('Client is ready for reads, reading') self.last_activity = time.time() try: @@ -290,7 +290,7 @@ class HttpProtocolHandler(BaseTcpServerHandler): else: logger.exception( 'Exception while receiving from %s connection %r with reason %r' % - (self.client.tag, self.client.connection, e), + (self.work.tag, self.work.connection, e), ) return True return False @@ -324,7 +324,7 @@ class HttpProtocolHandler(BaseTcpServerHandler): except Exception as e: logger.exception( 'Exception while handling connection %r' % - self.client.connection, exc_info=e, + self.work.connection, exc_info=e, ) finally: self.shutdown() @@ -377,24 +377,24 @@ class HttpProtocolHandler(BaseTcpServerHandler): def _flush(self) -> None: assert self.selector - if not self.client.has_buffer(): + if not self.work.has_buffer(): return try: self.selector.register( - self.client.connection, + self.work.connection, selectors.EVENT_WRITE, ) - while self.client.has_buffer(): + while self.work.has_buffer(): ev: List[ Tuple[selectors.SelectorKey, int] ] = self.selector.select(timeout=1) if len(ev) == 0: continue - self.client.flush() + self.work.flush() except BrokenPipeError: pass finally: - self.selector.unregister(self.client.connection) + self.selector.unregister(self.work.connection) def _connection_inactive_for(self) -> float: return time.time() - self.last_activity diff --git a/tests/http/exceptions/test_http_proxy_auth_failed.py b/tests/http/exceptions/test_http_proxy_auth_failed.py index b24a87d6..312106fb 100644 --- a/tests/http/exceptions/test_http_proxy_auth_failed.py +++ b/tests/http/exceptions/test_http_proxy_auth_failed.py @@ -63,9 +63,9 @@ class TestHttpProxyAuthFailed(unittest.TestCase): self.protocol_handler._run_once() mock_server_conn.assert_not_called() - self.assertEqual(self.protocol_handler.client.has_buffer(), True) + self.assertEqual(self.protocol_handler.work.has_buffer(), True) self.assertEqual( - self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT, + self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT, ) self._conn.send.assert_not_called() @@ -92,9 +92,9 @@ class TestHttpProxyAuthFailed(unittest.TestCase): self.protocol_handler._run_once() mock_server_conn.assert_not_called() - self.assertEqual(self.protocol_handler.client.has_buffer(), True) + self.assertEqual(self.protocol_handler.work.has_buffer(), True) self.assertEqual( - self.protocol_handler.client.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT, + self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT, ) self._conn.send.assert_not_called() @@ -121,7 +121,7 @@ class TestHttpProxyAuthFailed(unittest.TestCase): self.protocol_handler._run_once() mock_server_conn.assert_called_once() - self.assertEqual(self.protocol_handler.client.has_buffer(), False) + self.assertEqual(self.protocol_handler.work.has_buffer(), False) @mock.patch('proxy.http.proxy.server.TcpServerConnection') def test_proxy_auth_works_with_mixed_case_basic_string(self, mock_server_conn: mock.Mock) -> None: @@ -146,4 +146,4 @@ class TestHttpProxyAuthFailed(unittest.TestCase): self.protocol_handler._run_once() mock_server_conn.assert_called_once() - self.assertEqual(self.protocol_handler.client.has_buffer(), False) + self.assertEqual(self.protocol_handler.work.has_buffer(), False) diff --git a/tests/http/test_http_proxy_tls_interception.py b/tests/http/test_http_proxy_tls_interception.py index 1c030685..1d6e7df0 100644 --- a/tests/http/test_http_proxy_tls_interception.py +++ b/tests/http/test_http_proxy_tls_interception.py @@ -201,7 +201,7 @@ class TestHttpProxyTlsInterception(unittest.TestCase): ) self.assertEqual(self._conn.setblocking.call_count, 2) self.assertEqual( - self.protocol_handler.client.connection, + self.protocol_handler.work.connection, self.mock_ssl_wrap.return_value, ) diff --git a/tests/http/test_protocol_handler.py b/tests/http/test_protocol_handler.py index d312b8a4..bb228c09 100644 --- a/tests/http/test_protocol_handler.py +++ b/tests/http/test_protocol_handler.py @@ -102,7 +102,7 @@ class TestHttpProtocolHandler(unittest.TestCase): ).upstream is not None, ) self.assertEqual( - self.protocol_handler.client.buffer[0], + self.protocol_handler.work.buffer[0], HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, ) mock_server_connection.assert_called_once() @@ -111,7 +111,7 @@ class TestHttpProtocolHandler(unittest.TestCase): server.closed = False parser = HttpParser(httpParserTypes.RESPONSE_PARSER) - parser.parse(self.protocol_handler.client.buffer[0].tobytes()) + parser.parse(self.protocol_handler.work.buffer[0].tobytes()) self.assertEqual(parser.state, httpParserStates.COMPLETE) assert parser.code is not None self.assertEqual(int(parser.code), 200) @@ -199,7 +199,7 @@ class TestHttpProtocolHandler(unittest.TestCase): ]) self.protocol_handler._run_once() self.assertEqual( - self.protocol_handler.client.buffer[0], + self.protocol_handler.work.buffer[0], ProxyConnectionFailed.RESPONSE_PKT, ) @@ -231,7 +231,7 @@ class TestHttpProtocolHandler(unittest.TestCase): ]) self.protocol_handler._run_once() self.assertEqual( - self.protocol_handler.client.buffer[0], + self.protocol_handler.work.buffer[0], ProxyAuthenticationFailed.RESPONSE_PKT, ) @@ -328,7 +328,7 @@ class TestHttpProtocolHandler(unittest.TestCase): CRLF, ]) self.assert_tunnel_response(mock_server_connection, server) - self.protocol_handler.client.flush() + self.protocol_handler.work.flush() self.assert_data_queued_to_server(server) self.protocol_handler._run_once() diff --git a/tests/http/test_web_server.py b/tests/http/test_web_server.py index af1eadad..37df71e7 100644 --- a/tests/http/test_web_server.py +++ b/tests/http/test_web_server.py @@ -132,7 +132,7 @@ class TestWebServerPlugin(unittest.TestCase): httpParserStates.COMPLETE, ) self.assertEqual( - self.protocol_handler.client.buffer[0], + self.protocol_handler.work.buffer[0], HttpWebServerPlugin.DEFAULT_404_RESPONSE, ) diff --git a/tests/plugin/test_http_proxy_plugins.py b/tests/plugin/test_http_proxy_plugins.py index aab0cb30..99138f1c 100644 --- a/tests/plugin/test_http_proxy_plugins.py +++ b/tests/plugin/test_http_proxy_plugins.py @@ -139,7 +139,7 @@ class TestHttpProxyPluginExamples(unittest.TestCase): mock_server_conn.assert_not_called() self.assertEqual( - self.protocol_handler.client.buffer[0].tobytes(), + self.protocol_handler.work.buffer[0].tobytes(), build_http_response( httpStatusCodes.OK, reason=b'OK', headers={b'Content-Type': b'application/json'}, @@ -215,7 +215,7 @@ class TestHttpProxyPluginExamples(unittest.TestCase): mock_server_conn.assert_not_called() self.assertEqual( - self.protocol_handler.client.buffer[0].tobytes(), + self.protocol_handler.work.buffer[0].tobytes(), build_http_response( status_code=httpStatusCodes.I_AM_A_TEAPOT, reason=b'I\'m a tea pot', @@ -305,7 +305,7 @@ class TestHttpProxyPluginExamples(unittest.TestCase): ) self.protocol_handler._run_once() self.assertEqual( - self.protocol_handler.client.buffer[0].tobytes(), + self.protocol_handler.work.buffer[0].tobytes(), build_http_response( httpStatusCodes.OK, reason=b'OK', body=b'Hello from man in the middle', @@ -337,7 +337,7 @@ class TestHttpProxyPluginExamples(unittest.TestCase): self.protocol_handler._run_once() self.assertEqual( - self.protocol_handler.client.buffer[0].tobytes(), + self.protocol_handler.work.buffer[0].tobytes(), build_http_response( status_code=httpStatusCodes.NOT_FOUND, reason=b'Blocked', diff --git a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py index 1f276c45..48d7d1cd 100644 --- a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py +++ b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py @@ -170,14 +170,14 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase): self.mock_server_conn.assert_called_once_with('uni.corn', 443) self.server.connect.assert_called() self.assertEqual( - self.protocol_handler.client.connection, + self.protocol_handler.work.connection, self.client_ssl_connection, ) self.assertEqual(self.server.connection, self.server_ssl_connection) self._conn.send.assert_called_with( HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, ) - self.assertFalse(self.protocol_handler.client.has_buffer()) + self.assertFalse(self.protocol_handler.work.has_buffer()) def test_modify_post_data_plugin(self) -> None: original = b'{"key": "value"}' @@ -229,7 +229,7 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase): ) self.protocol_handler._run_once() self.assertEqual( - self.protocol_handler.client.buffer[0].tobytes(), + self.protocol_handler.work.buffer[0].tobytes(), build_http_response( httpStatusCodes.OK, reason=b'OK', body=b'Hello from man in the middle',