From 716f211a2b7777aeb68a318758d7b5084e458fe4 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Fri, 27 Sep 2019 09:18:58 -0700 Subject: [PATCH] Defer SSL Wrap (#100) * Always deduce address family as we have a structure * Convert TcpConnection into an ABC. * Fix tests * Rename MultiCoreRequestDispatcher to WorkerPool * Consistent naming refactor * Make WorkerPool independent of protocol config object * Fix tests * Proper fix for test_on_client_connection_called_on_teardown * TLS interception * better logging --- README.md | 30 +-- plugin_examples.py | 2 +- proxy.py | 427 +++++++++++++++++++++----------------- tests.py | 506 ++++++++++++++++----------------------------- 4 files changed, 426 insertions(+), 539 deletions(-) diff --git a/README.md b/README.md index f4d918fe..7b307b3d 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ Table of Contents * [ManInTheMiddlePlugin](#maninthemiddleplugin) * [Plugin Ordering](#plugin-ordering) * [End-to-End Encryption](#end-to-end-encryption) -* [TLS Encryption](#tls-interception) +* [TLS Interception](#tls-interception) * [Plugin Developer and Contributor Guide](#plugin-developer-and-contributor-guide) * [Everything is a plugin](#everything-is-a-plugin) * [proxy.py Internals](#proxypy-internals) @@ -230,7 +230,7 @@ Above `418 I'm a tea pot` is sent by our plugin. Verify the same by inspecting logs for `proxy.py`: ``` -2019-09-24 19:21:37,893 - ERROR - pid:50074 - handle_readables:1347 - HttpProtocolException type raised +2019-09-24 19:21:37,893 - ERROR - pid:50074 - handle_readables:1347 - ProtocolException type raised Traceback (most recent call last): ... [redacted] ... 2019-09-24 19:21:37,897 - INFO - pid:50074 - access_log:1157 - ::1:49911 - GET None:None/ - None None - 0 bytes @@ -485,38 +485,38 @@ As you might have guessed by now, in `proxy.py` everything is a plugin. Example, [FilterByUpstreamHostPlugin](#filterbyupstreamhostplugin). - We also enabled inbuilt web server using `--enable-web-server`. - Inbuilt web server implements `HttpProtocolBasePlugin` plugin. - See documentation of [HttpProtocolBasePlugin](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L793-L850) - for available lifecycle hooks. Use `HttpProtocolBasePlugin` to add + Inbuilt web server implements `ProtocolHandlerPlugin` plugin. + See documentation of [ProtocolHandlerPlugin](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L793-L850) + for available lifecycle hooks. Use `ProtocolHandlerPlugin` to add new features for http(s) clients. Example, [HttpWebServerPlugin](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L1185-L1260). - There also is a `--disable-http-proxy` flag. It disables inbuilt proxy server. Use this flag with `--enable-web-server` flag to run `proxy.py` as a programmable http(s) server. [HttpProxyPlugin](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L941-L1182) - also implements `HttpProtocolBasePlugin`. + also implements `ProtocolHandlerPlugin`. ## proxy.py Internals -- [HttpProtocolHandler](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L1263-L1440) +- [ProtocolHandler](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L1263-L1440) thread is started with the accepted [TcpClientConnection](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L230-L237). -`HttpProtocolHandler` is responsible for parsing incoming client request and invoking -`HttpProtocolBasePlugin` lifecycle hooks. +`ProtocolHandler` is responsible for parsing incoming client request and invoking +`ProtocolHandlerPlugin` lifecycle hooks. -- `HttpProxyPlugin` which implements `HttpProtocolBasePlugin` also has its own plugin +- `HttpProxyPlugin` which implements `ProtocolHandlerPlugin` also has its own plugin mechanism. Its responsibility is to establish connection between client and upstream [TcpServerConnection](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L204-L227) and invoke `HttpProxyBasePlugin` lifecycle hooks. -- `HttpProtocolHandler` threads are started by [Worker](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L424-L472) +- `ProtocolHandler` threads are started by [Worker](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L424-L472) processes. - `--num-workers` `Worker` processes are started by - [MultiCoreRequestDispatcher](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L368-L421) - on start-up. `Worker` processes receives `TcpClientConnection` over a pipe from `MultiCoreRequestDispatcher`. + [WorkerPool](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L368-L421) + on start-up. `Worker` processes receives `TcpClientConnection` over a pipe from `WorkerPool`. -- `MultiCoreRequestDispatcher` implements [TcpServer](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L240-L302) - abstract class. `TcpServer` accepts `TcpClientConnection`. `MultiCoreRequestDispatcher` +- `WorkerPool` implements [TcpServer](https://github.com/abhinavsingh/proxy.py/blob/b03629fa0df1595eb4995427bc601063be7fdca9/proxy.py#L240-L302) + abstract class. `TcpServer` accepts `TcpClientConnection`. `WorkerPool` ensures full utilization of available CPU cores, for which it dispatches accepted `TcpClientConnection` to `Worker` processes in a round-robin fashion. diff --git a/plugin_examples.py b/plugin_examples.py index 02cdcfbf..e970bb7d 100644 --- a/plugin_examples.py +++ b/plugin_examples.py @@ -66,7 +66,7 @@ class CacheResponsesPlugin(proxy.HttpProxyBasePlugin): CACHE_DIR = tempfile.gettempdir() - def __init__(self, config: proxy.HttpProtocolConfig, client: proxy.TcpClientConnection, + def __init__(self, config: proxy.ProtocolConfig, client: proxy.TcpClientConnection, request: proxy.HttpParser) -> None: super().__init__(config, client, request) self.cache_file_path: Optional[str] = None diff --git a/proxy.py b/proxy.py index bc0af48f..df8c696f 100755 --- a/proxy.py +++ b/proxy.py @@ -150,26 +150,31 @@ HttpProtocolTypes = NamedTuple('HttpProtocolTypes', [ httpProtocolTypes = HttpProtocolTypes(1, 2) -class TcpConnection: +class TcpConnectionUninitializedException(Exception): + pass + + +class TcpConnection(ABC): """TCP server/client connection abstraction.""" def __init__(self, tag: int): - self.conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None self.buffer: bytes = b'' self.closed: bool = False self.tag: str = 'server' if tag == tcpConnectionTypes.SERVER else 'client' + @property + @abstractmethod + def connection(self) -> Union[ssl.SSLSocket, socket.socket]: + """Must return the socket connection to use in this class.""" + raise TcpConnectionUninitializedException() + def send(self, data: bytes) -> int: """Users must handle BrokenPipeError exceptions""" - if not self.conn: - raise KeyError('conn is None') - return self.conn.send(data) + return self.connection.send(data) def recv(self, buffer_size: int = DEFAULT_BUFFER_SIZE) -> Optional[bytes]: - if not self.conn: - raise KeyError('conn is None') try: - data: bytes = self.conn.recv(buffer_size) + data: bytes = self.connection.recv(buffer_size) if len(data) > 0: logger.debug( 'received %d bytes from %s' % @@ -181,14 +186,12 @@ class TcpConnection: else: logger.exception( 'Exception while receiving from connection %s %r with reason %r' % - (self.tag, self.conn, e)) + (self.tag, self.connection, e)) return None def close(self) -> bool: - if not self.conn: - raise KeyError('conn is None') if not self.closed: - self.conn.close() + self.connection.close() self.closed = True return self.closed @@ -210,31 +213,37 @@ class TcpConnection: class TcpServerConnection(TcpConnection): - """Establishes connection to destination server.""" + """Establishes connection to upstream server.""" def __init__(self, host: str, port: int): super().__init__(tcpConnectionTypes.SERVER) self.addr: Tuple[str, int] = (host, int(port)) + self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None - def __del__(self) -> None: - if self.conn: - self.close() + @property + def connection(self) -> Union[ssl.SSLSocket, socket.socket]: + if self._conn is None: + raise TcpConnectionUninitializedException() + return self._conn def connect(self) -> None: + if self._conn is not None: + return + try: ip = ipaddress.ip_address(text_(self.addr[0])) if ip.version == 4: - self.conn = socket.socket( + self._conn = socket.socket( socket.AF_INET, socket.SOCK_STREAM, 0) - self.conn.connect((self.addr[0], self.addr[1])) + self._conn.connect((self.addr[0], self.addr[1])) else: - self.conn = socket.socket( + self._conn = socket.socket( socket.AF_INET6, socket.SOCK_STREAM, 0) - self.conn.connect((self.addr[0], self.addr[1], 0, 0)) + self._conn.connect((self.addr[0], self.addr[1], 0, 0)) except ValueError: # Not a valid IP address, most likely its a domain name, # try to establish dual stack IPv4/IPv6 connection. - self.conn = socket.create_connection((self.addr[0], self.addr[1])) + self._conn = socket.create_connection((self.addr[0], self.addr[1])) class TcpClientConnection(TcpConnection): @@ -243,9 +252,15 @@ class TcpClientConnection(TcpConnection): def __init__(self, conn: Union[ssl.SSLSocket, socket.socket], addr: Tuple[str, int]): super().__init__(tcpConnectionTypes.CLIENT) - self.conn: Union[ssl.SSLSocket, socket.socket] = conn + self._conn: Optional[Union[ssl.SSLSocket, socket.socket]] = conn self.addr: Tuple[str, int] = addr + @property + def connection(self) -> Union[ssl.SSLSocket, socket.socket]: + if self._conn is None: + raise TcpConnectionUninitializedException() + return self._conn + class TcpServer(ABC): """TcpServer server implementation. @@ -259,13 +274,12 @@ class TcpServer(ABC): hostname: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] = DEFAULT_IPV6_HOSTNAME, port: int = DEFAULT_PORT, - backlog: int = DEFAULT_BACKLOG, - family: socket.AddressFamily = socket.AF_INET6): + backlog: int = DEFAULT_BACKLOG): self.port: int = port self.backlog: int = backlog self.socket: Optional[socket.socket] = None self.running: bool = False - self.family: socket.AddressFamily = family + self.family: socket.AddressFamily = socket.AF_INET6 if hostname.version == 6 else socket.AF_INET self.hostname: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] = hostname @@ -315,100 +329,40 @@ class TcpServer(ABC): self.socket.close() -class HttpProtocolConfig: - """Holds various configuration values applicable to HttpProtocolHandler. - - This config class helps us avoid passing around bunch of key/value pairs across methods. - """ - - ROOT_DATA_DIR_NAME = '.proxy.py' - GENERATED_CERTS_DIR_NAME = 'certificates' - - def __init__( - self, - auth_code: Optional[bytes] = DEFAULT_BASIC_AUTH, - server_recvbuf_size: int = DEFAULT_SERVER_RECVBUF_SIZE, - client_recvbuf_size: int = DEFAULT_CLIENT_RECVBUF_SIZE, - pac_file: Optional[str] = DEFAULT_PAC_FILE, - pac_file_url_path: Optional[bytes] = DEFAULT_PAC_FILE_URL_PATH, - plugins: Optional[Dict[bytes, List[type]]] = None, - disable_headers: Optional[List[bytes]] = None, - certfile: Optional[str] = None, - keyfile: Optional[str] = None, - ca_cert_dir: Optional[str] = None, - ca_key_file: Optional[str] = None, - ca_cert_file: Optional[str] = None, - ca_signing_key_file: Optional[str] = None, - num_workers: int = 0, - hostname: Union[ipaddress.IPv4Address, - ipaddress.IPv6Address] = DEFAULT_IPV6_HOSTNAME, - port: int = DEFAULT_PORT, - backlog: int = DEFAULT_BACKLOG) -> None: - self.auth_code = auth_code - self.server_recvbuf_size = server_recvbuf_size - self.client_recvbuf_size = client_recvbuf_size - self.pac_file = pac_file - self.pac_file_url_path = pac_file_url_path - if plugins is None: - plugins = {} - self.plugins: Dict[bytes, List[type]] = plugins - if disable_headers is None: - disable_headers = DEFAULT_DISABLE_HEADERS - self.disable_headers = disable_headers - self.certfile: Optional[str] = certfile - self.keyfile: Optional[str] = keyfile - self.ca_key_file: Optional[str] = ca_key_file - self.ca_cert_file: Optional[str] = ca_cert_file - self.ca_signing_key_file: Optional[str] = ca_signing_key_file - self.num_workers: int = num_workers - self.hostname: Union[ipaddress.IPv4Address, - ipaddress.IPv6Address] = hostname - self.port: int = port - self.backlog: int = backlog - self.family: socket.AddressFamily = socket.AF_INET if hostname.version == 4 else socket.AF_INET6 - - self.proxy_py_data_dir = os.path.join( - str(pathlib.Path.home()), self.ROOT_DATA_DIR_NAME) - os.makedirs(self.proxy_py_data_dir, exist_ok=True) - - self.ca_cert_dir: Optional[str] = ca_cert_dir - if self.ca_cert_dir is None: - self.ca_cert_dir = os.path.join( - self.proxy_py_data_dir, self.GENERATED_CERTS_DIR_NAME) - os.makedirs(self.ca_cert_dir, exist_ok=True) - - -class MultiCoreRequestDispatcher(TcpServer): - """MultiCoreRequestDispatcher. +class WorkerPool(TcpServer): + """WorkerPool. Pre-spawns worker process to utilize all cores available on the system. Accepted `TcpClientConnection` is dispatched over a queue to workers. One of the worker picks up the work and starts a new thread to handle the client request. """ - def __init__(self, config: HttpProtocolConfig) -> None: - super().__init__( - hostname=config.hostname, - port=config.port, - backlog=config.backlog, - family=config.family) + def __init__(self, hostname: Union[ipaddress.IPv4Address, + ipaddress.IPv6Address], + port: int, backlog: int, num_workers: int, + work_klass: type, **kwargs: Any) -> None: + super().__init__(hostname=hostname, port=port, backlog=backlog) + self.num_workers = num_workers + self.workers: List[Worker] = [] self.work_queues: List[Tuple[connection.Connection, connection.Connection]] = [] self.current_worker_id = 0 - self.config: HttpProtocolConfig = config + + self.work_klass = work_klass + self.kwargs = kwargs def setup(self) -> None: - for worker_id in range(self.config.num_workers): + for worker_id in range(self.num_workers): work_queue = multiprocessing.Pipe() - worker = Worker(work_queue[1], self.config) + worker = Worker(work_queue[1], self.work_klass, **self.kwargs) worker.daemon = True worker.start() self.workers.append(worker) self.work_queues.append(work_queue) - logger.info('Started %d workers' % self.config.num_workers) + logger.info('Started %d workers' % self.num_workers) def handle(self, client: TcpClientConnection) -> None: # Dispatch in round robin fashion @@ -418,15 +372,15 @@ class MultiCoreRequestDispatcher(TcpServer): self.current_worker_id) # Dispatch non-socket data first, followed by fileno using reduction work_queue[0].send((workerOperations.HTTP_PROTOCOL, client.addr)) - send_handle(work_queue[0], client.conn.fileno(), + send_handle(work_queue[0], client.connection.fileno(), self.workers[self.current_worker_id].pid) # Close parent handler client.close() self.current_worker_id += 1 - self.current_worker_id %= self.config.num_workers + self.current_worker_id %= self.num_workers def shutdown(self) -> None: - logger.info('Shutting down %d workers' % self.config.num_workers) + logger.info('Shutting down %d workers' % self.num_workers) for work_queue in self.work_queues: work_queue[0].send((workerOperations.SHUTDOWN, None)) work_queue[0].close() @@ -444,40 +398,20 @@ class Worker(multiprocessing.Process): def __init__( self, work_queue: connection.Connection, - config: HttpProtocolConfig): + work_klass: type, + **kwargs: Any): super().__init__() self.work_queue: connection.Connection = work_queue - self.config: HttpProtocolConfig = config + self.work_klass = work_klass + self.kwargs = kwargs def run_once(self) -> bool: try: op, payload = self.work_queue.recv() if op == workerOperations.HTTP_PROTOCOL: fileno = recv_handle(self.work_queue) - conn = socket.fromfd( - fileno, family=self.config.family, type=socket.SOCK_STREAM) - # TODO(abhinavsingh): Move handshake logic within - # HttpProtocolHandler or should this go under TcpServer directly? - # Rationale behind deferring ssl wrap is that plugins can custom wrap - # sockets if necessary. - if self.config.certfile and self.config.keyfile: - try: - ctx = ssl.create_default_context( - ssl.Purpose.CLIENT_AUTH) - ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 - ctx.verify_mode = ssl.CERT_NONE - ctx.load_cert_chain( - certfile=self.config.certfile, - keyfile=self.config.keyfile) - conn = ctx.wrap_socket(conn, server_side=True) - except OSError as e: - logger.exception( - 'OSError encountered while ssl wrapping the client socket', exc_info=e) - conn.close() - return False - proxy = HttpProtocolHandler( - TcpClientConnection(conn=conn, addr=payload), - config=self.config) + proxy = self.work_klass( + fileno=fileno, addr=payload, **self.kwargs) proxy.setDaemon(True) proxy.start() elif op == workerOperations.SHUTDOWN: @@ -701,7 +635,7 @@ class HttpParser: else: self.version = line[0] self.code = line[1] - self.reason = b' '.join(line[2:]) + self.reason = WHITESPACE.join(line[2:]) self.set_host_port() def process_header(self, raw: bytes) -> None: @@ -789,7 +723,7 @@ class HttpParser: ########################################################################## # HttpParser was originally written to parse the incoming raw Http requests. - # Since request / response objects passed to HttpProtocolBasePlugin methods + # Since request / response objects passed to ProtocolHandlerPlugin methods # are also HttpParser objects, methods below were added to simplify developer API. ########################################################################## @@ -798,18 +732,18 @@ class HttpParser: return True if self.host is not None else False -class HttpProtocolException(Exception): - """Top level HttpProtocolException exception class. +class ProtocolException(Exception): + """Top level ProtocolException exception class. All exceptions raised during execution of Http request lifecycle MUST - inherit HttpProtocolException base class. Implement response() method + inherit ProtocolException base class. Implement response() method to optionally return custom response to client.""" def response(self, request: HttpParser) -> Optional[bytes]: pass # pragma: no cover -class HttpRequestRejected(HttpProtocolException): +class HttpRequestRejected(ProtocolException): """Generic exception that can be used to reject the client requests. Connections can either be dropped/closed or optionally an @@ -829,7 +763,7 @@ class HttpRequestRejected(HttpProtocolException): if self.status_code is not None: line = b'HTTP/1.1 ' + bytes_(str(self.status_code)) if self.reason: - line += b' ' + self.reason + line += WHITESPACE + self.reason pkt.append(line) pkt.append(PROXY_AGENT_HEADER) if self.body: @@ -842,17 +776,79 @@ class HttpRequestRejected(HttpProtocolException): return CRLF.join(pkt) if len(pkt) > 0 else None -class HttpProtocolBasePlugin(ABC): - """Base HttpProtocolHandler Plugin class. +class ProtocolConfig: + """Holds various configuration values applicable to ProtocolHandler. + + This config class helps us avoid passing around bunch of key/value pairs across methods. + """ + + ROOT_DATA_DIR_NAME = '.proxy.py' + GENERATED_CERTS_DIR_NAME = 'certificates' + + def __init__( + self, + auth_code: Optional[bytes] = DEFAULT_BASIC_AUTH, + server_recvbuf_size: int = DEFAULT_SERVER_RECVBUF_SIZE, + client_recvbuf_size: int = DEFAULT_CLIENT_RECVBUF_SIZE, + pac_file: Optional[str] = DEFAULT_PAC_FILE, + pac_file_url_path: Optional[bytes] = DEFAULT_PAC_FILE_URL_PATH, + plugins: Optional[Dict[bytes, List[type]]] = None, + disable_headers: Optional[List[bytes]] = None, + certfile: Optional[str] = None, + keyfile: Optional[str] = None, + ca_cert_dir: Optional[str] = None, + ca_key_file: Optional[str] = None, + ca_cert_file: Optional[str] = None, + ca_signing_key_file: Optional[str] = None, + num_workers: int = 0, + hostname: Union[ipaddress.IPv4Address, + ipaddress.IPv6Address] = DEFAULT_IPV6_HOSTNAME, + port: int = DEFAULT_PORT, + backlog: int = DEFAULT_BACKLOG) -> None: + self.auth_code = auth_code + self.server_recvbuf_size = server_recvbuf_size + self.client_recvbuf_size = client_recvbuf_size + self.pac_file = pac_file + self.pac_file_url_path = pac_file_url_path + if plugins is None: + plugins = {} + self.plugins: Dict[bytes, List[type]] = plugins + if disable_headers is None: + disable_headers = DEFAULT_DISABLE_HEADERS + self.disable_headers = disable_headers + self.certfile: Optional[str] = certfile + self.keyfile: Optional[str] = keyfile + self.ca_key_file: Optional[str] = ca_key_file + self.ca_cert_file: Optional[str] = ca_cert_file + self.ca_signing_key_file: Optional[str] = ca_signing_key_file + self.num_workers: int = num_workers + self.hostname: Union[ipaddress.IPv4Address, + ipaddress.IPv6Address] = hostname + self.port: int = port + self.backlog: int = backlog + + self.proxy_py_data_dir = os.path.join( + str(pathlib.Path.home()), self.ROOT_DATA_DIR_NAME) + os.makedirs(self.proxy_py_data_dir, exist_ok=True) + + self.ca_cert_dir: Optional[str] = ca_cert_dir + if self.ca_cert_dir is None: + self.ca_cert_dir = os.path.join( + self.proxy_py_data_dir, self.GENERATED_CERTS_DIR_NAME) + os.makedirs(self.ca_cert_dir, exist_ok=True) + + +class ProtocolHandlerPlugin(ABC): + """Base ProtocolHandler Plugin class. Implement various lifecycle event methods to customize behavior.""" def __init__( self, - config: HttpProtocolConfig, + config: ProtocolConfig, client: TcpClientConnection, request: HttpParser): - self.config: HttpProtocolConfig = config + self.config: ProtocolConfig = config self.client: TcpClientConnection = client self.request: HttpParser = request super().__init__() @@ -902,7 +898,7 @@ class HttpProtocolBasePlugin(ABC): pass # pragma: no cover -class ProxyConnectionFailed(HttpProtocolException): +class ProxyConnectionFailed(ProtocolException): """Exception raised when HttpProxyPlugin is unable to establish connection to upstream server.""" RESPONSE_PKT = HttpParser.build_response( @@ -925,7 +921,7 @@ class ProxyConnectionFailed(HttpProtocolException): self.host, self.port, self.reason) -class ProxyAuthenticationFailed(HttpProtocolException): +class ProxyAuthenticationFailed(ProtocolException): """Exception raised when Http Proxy auth is enabled and incoming request doesn't present necessary credentials.""" @@ -947,7 +943,7 @@ class HttpProxyBasePlugin(ABC): def __init__( self, - config: HttpProtocolConfig, + config: ProtocolConfig, client: TcpClientConnection, request: HttpParser): self.config = config @@ -987,8 +983,8 @@ class HttpProxyBasePlugin(ABC): pass # pragma: no cover -class HttpProxyPlugin(HttpProtocolBasePlugin): - """HttpProtocolHandler plugin which implements HttpProxy specifications.""" +class HttpProxyPlugin(ProtocolHandlerPlugin): + """ProtocolHandler plugin which implements HttpProxy specifications.""" PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = HttpParser.build_response( 200, reason=b'Connection established' @@ -1000,7 +996,7 @@ class HttpProxyPlugin(HttpProtocolBasePlugin): def __init__( self, - config: HttpProtocolConfig, + config: ProtocolConfig, client: TcpClientConnection, request: HttpParser): super().__init__(config, client, request) @@ -1020,16 +1016,16 @@ class HttpProxyPlugin(HttpProtocolBasePlugin): r: List[socket.socket] = [] w: List[socket.socket] = [] - if self.server and not self.server.closed and self.server.conn: - r.append(self.server.conn) - if self.server and not self.server.closed and self.server.has_buffer() and self.server.conn: - w.append(self.server.conn) + if self.server and not self.server.closed and self.server.connection: + r.append(self.server.connection) + if self.server and not self.server.closed and self.server.has_buffer() and self.server.connection: + w.append(self.server.connection) return r, w, [] def flush_to_descriptors(self, w: List[socket.socket]) -> bool: if self.request.has_upstream_server() and \ - self.server and not self.server.closed and self.server.conn in w: - logger.debug('Server is ready for writes, flushing server buffer') + self.server and not self.server.closed and self.server.connection in w: + logger.debug('Server is write ready, flushing buffer') try: self.server.flush() except BrokenPipeError: @@ -1040,10 +1036,10 @@ class HttpProxyPlugin(HttpProtocolBasePlugin): def read_from_descriptors(self, r: List[socket.socket]) -> bool: if self.request.has_upstream_server( - ) and self.server and not self.server.closed and self.server.conn in r: + ) and self.server and not self.server.closed and self.server.connection in r: logger.debug('Server is ready for reads, reading') raw = self.server.recv(self.config.server_recvbuf_size) - # self.last_activity = HttpProtocolHandler.now() + # self.last_activity = ProtocolHandler.now() if not raw: logger.debug('Server closed connection, tearing down...') return True @@ -1148,21 +1144,22 @@ class HttpProxyPlugin(HttpProtocolBasePlugin): # connection to the client. Below we handle the scenario # when client is communicating to proxy.py using http. if not (self.config.keyfile and self.config.certfile) and \ - self.server and isinstance(self.server.conn, socket.socket): - self.client.conn = ssl.wrap_socket(self.client.conn, - server_side=True, - keyfile=self.config.ca_signing_key_file, - certfile=generated_cert) + self.server and isinstance(self.server.connection, socket.socket): + self.client._conn = ssl.wrap_socket( + self.client.connection, + server_side=True, + keyfile=self.config.ca_signing_key_file, + certfile=generated_cert) # Wrap our connection to upstream server connection ctx = ssl.create_default_context( ssl.Purpose.SERVER_AUTH) ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 - self.server.conn = ctx.wrap_socket( - self.server.conn, server_hostname=text_( + self.server._conn = ctx.wrap_socket( + self.server.connection, server_hostname=text_( self.request.host)) logger.info( 'TLS interception using %s', generated_cert) - return self.client.conn + return self.client.connection # for general http requests, re-build request packet # and queue for the server with appropriate headers elif self.server: @@ -1232,7 +1229,7 @@ class HttpProxyPlugin(HttpProtocolBasePlugin): raise ProxyConnectionFailed(text_(host), port, repr(e)) from e else: logger.exception('Both host and port must exist') - raise HttpProtocolException() + raise ProtocolException() class HttpWebServerRoutePlugin(ABC): @@ -1240,7 +1237,7 @@ class HttpWebServerRoutePlugin(ABC): def __init__( self, - config: HttpProtocolConfig, + config: ProtocolConfig, client: TcpClientConnection): self.config = config self.client = client @@ -1279,7 +1276,7 @@ class HttpWebServerPacFilePlugin(HttpWebServerRoutePlugin): def __init__( self, - config: HttpProtocolConfig, + config: ProtocolConfig, client: TcpClientConnection): super().__init__(config, client) self.pac_file_response: Optional[bytes] = None @@ -1312,8 +1309,8 @@ class HttpWebServerPacFilePlugin(HttpWebServerRoutePlugin): ) -class HttpWebServerPlugin(HttpProtocolBasePlugin): - """HttpProtocolHandler plugin which handles incoming requests to local webserver.""" +class HttpWebServerPlugin(ProtocolHandlerPlugin): + """ProtocolHandler plugin which handles incoming requests to local webserver.""" DEFAULT_404_RESPONSE = HttpParser.build_response( 404, reason=b'NOT FOUND', @@ -1323,7 +1320,7 @@ class HttpWebServerPlugin(HttpProtocolBasePlugin): def __init__( self, - config: HttpProtocolConfig, + config: ProtocolConfig, client: TcpClientConnection, request: HttpParser): super().__init__(config, client, request) @@ -1392,28 +1389,62 @@ class HttpWebServerPlugin(HttpProtocolBasePlugin): return [], [], [] -class HttpProtocolHandler(threading.Thread): +class ProtocolHandler(threading.Thread): """HTTP, HTTPS, HTTP2, WebSockets protocol handler. - Accepts `Client` connection object and manages HttpProtocolBasePlugin invocations. + Accepts `Client` connection object and manages ProtocolHandlerPlugin invocations. """ - def __init__(self, client: TcpClientConnection, - config: Optional[HttpProtocolConfig] = None): + def __init__(self, fileno: int, addr: Tuple[str, int], + config: Optional[ProtocolConfig] = None): super().__init__() self.start_time: datetime.datetime = self.now() self.last_activity: datetime.datetime = self.start_time - self.client: TcpClientConnection = client - self.config: HttpProtocolConfig = config if config else HttpProtocolConfig() + self.config: ProtocolConfig = config if config else ProtocolConfig() self.request: HttpParser = HttpParser(httpParserTypes.REQUEST_PARSER) - self.plugins: Dict[str, HttpProtocolBasePlugin] = {} - if b'HttpProtocolBasePlugin' in self.config.plugins: - for klass in self.config.plugins[b'HttpProtocolBasePlugin']: + conn = self.optionally_wrap_socket(self.fromfd(fileno)) + if conn is None: + raise TcpConnectionUninitializedException() + self.client: TcpClientConnection = TcpClientConnection( + conn=conn, + addr=addr) + + self.plugins: Dict[str, ProtocolHandlerPlugin] = {} + if b'ProtocolHandlerPlugin' in self.config.plugins: + for klass in self.config.plugins[b'ProtocolHandlerPlugin']: instance = klass(self.config, self.client, self.request) self.plugins[instance.name()] = instance + def fromfd(self, fileno: int) -> socket.socket: + return socket.fromfd( + fileno, family=socket.AF_INET if self.config.hostname.version == 4 else socket.AF_INET6, + type=socket.SOCK_STREAM) + + def optionally_wrap_socket(self, conn: socket.socket) -> Optional[Union[ssl.SSLSocket, socket.socket]]: + """Attempts to wrap accepted client connection using provided certificates. + + Shutdown and closes client connection upon error. + """ + if self.config.certfile and self.config.keyfile: + try: + ctx = ssl.create_default_context( + ssl.Purpose.CLIENT_AUTH) + ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 + ctx.verify_mode = ssl.CERT_NONE + ctx.load_cert_chain( + certfile=self.config.certfile, + keyfile=self.config.keyfile) + conn = ctx.wrap_socket(conn, server_side=True) + return conn + except Exception as e: + logger.exception('Error encountered', exc_info=e) + conn.shutdown(socket.SHUT_RDWR) + conn.close() + return None + return conn + @staticmethod def now() -> datetime.datetime: return datetime.datetime.utcnow() @@ -1426,8 +1457,8 @@ class HttpProtocolHandler(threading.Thread): return self.connection_inactive_for() > 30 def handle_writables(self, writables: List[socket.socket]) -> bool: - if self.client.conn in writables: - logger.debug('Client is ready for writes, flushing client buffer') + if self.client.connection in writables: + logger.debug('Client is write, flushing buffer') try: self.client.flush() except BrokenPipeError: @@ -1437,7 +1468,7 @@ class HttpProtocolHandler(threading.Thread): return False def handle_readables(self, readables: List[socket.socket]) -> bool: - if self.client.conn in readables: + if self.client.connection in readables: logger.debug('Client is ready for reads, reading') client_data = self.client.recv(self.config.client_recvbuf_size) self.last_activity = self.now() @@ -1446,7 +1477,7 @@ class HttpProtocolHandler(threading.Thread): self.client.closed = True return True - # HttpProtocolBasePlugin.on_client_data + # ProtocolHandlerPlugin.on_client_data plugin_index = 0 plugins = list(self.plugins.values()) while plugin_index < len(plugins) and client_data: @@ -1464,19 +1495,19 @@ class HttpProtocolHandler(threading.Thread): if isinstance(upgraded_sock, ssl.SSLSocket): logger.debug( 'Updated client conn to %s', upgraded_sock) - self.client.conn = upgraded_sock + self.client._conn = upgraded_sock # Update self.client.conn references for all # plugins for plugin_ in self.plugins.values(): if plugin_ != plugin: - plugin_.client.conn = upgraded_sock + plugin_.client._conn = upgraded_sock logger.debug( 'Upgraded client conn for plugin %s', str(plugin_)) elif isinstance(upgraded_sock, bool) and upgraded_sock: return True - except HttpProtocolException as e: + except ProtocolException as e: logger.exception( - 'HttpProtocolException type raised', exc_info=e) + 'ProtocolException type raised', exc_info=e) response = e.response(self.request) if response: self.client.queue(response) @@ -1490,13 +1521,13 @@ class HttpProtocolHandler(threading.Thread): def run_once(self) -> bool: """Returns True if proxy must teardown.""" # Prepare list of descriptors - read_desc: List[socket.socket] = [self.client.conn] + read_desc: List[socket.socket] = [self.client.connection] write_desc: List[socket.socket] = [] err_desc: List[socket.socket] = [] if self.client.has_buffer(): - write_desc.append(self.client.conn) + write_desc.append(self.client.connection) - # HttpProtocolBasePlugin.get_descriptors + # ProtocolHandlerPlugin.get_descriptors for plugin in self.plugins.values(): plugin_read_desc, plugin_write_desc, plugin_err_desc = plugin.get_descriptors() read_desc += plugin_read_desc @@ -1539,7 +1570,7 @@ class HttpProtocolHandler(threading.Thread): return False def run(self) -> None: - logger.debug('Proxying connection %r' % self.client.conn) + logger.debug('Proxying connection %r' % self.client.connection) try: while True: teardown = self.run_once() @@ -1550,7 +1581,7 @@ class HttpProtocolHandler(threading.Thread): except Exception as e: logger.exception( 'Exception while handling connection %r with reason %r' % - (self.client.conn, e)) + (self.client.connection, e)) finally: # Invoke plugin.access_log for plugin in self.plugins.values(): @@ -1558,7 +1589,7 @@ class HttpProtocolHandler(threading.Thread): if not self.client.closed: try: - self.client.conn.shutdown(socket.SHUT_RDWR) + self.client.connection.shutdown(socket.SHUT_RDWR) self.client.close() except OSError: pass @@ -1570,7 +1601,7 @@ class HttpProtocolHandler(threading.Thread): logger.debug( 'Closed proxy for connection %r ' 'at address %r with pending client buffer size %d bytes' % - (self.client.conn, self.client.addr, self.client.buffer_size())) + (self.client.connection, self.client.addr, self.client.buffer_size())) def is_py3() -> bool: @@ -1595,7 +1626,7 @@ def load_plugins(plugins: bytes) -> Dict[bytes, List[type]]: """Accepts a comma separated list of Python modules and returns a list of respective Python classes.""" p: Dict[bytes, List[type]] = { - b'HttpProtocolBasePlugin': [], + b'ProtocolHandlerPlugin': [], b'HttpProxyBasePlugin': [], b'HttpWebServerRoutePlugin': [], } @@ -1820,7 +1851,7 @@ def main(input_args: List[str]) -> None: if args.basic_auth: auth_code = b'Basic %s' % base64.b64encode(bytes_(args.basic_auth)) - config = HttpProtocolConfig( + config = ProtocolConfig( auth_code=auth_code, server_recvbuf_size=args.server_recvbuf_size, client_recvbuf_size=args.client_recvbuf_size, @@ -1853,7 +1884,13 @@ def main(input_args: List[str]) -> None: '%s%s' % (default_plugins, args.plugins))) - server = MultiCoreRequestDispatcher(config=config) + server = WorkerPool( + hostname=config.hostname, + port=config.port, + backlog=config.backlog, + num_workers=config.num_workers, + work_klass=ProtocolHandler, + config=config) if args.pid_file: with open(args.pid_file, 'wb') as pid_file: pid_file.write(bytes_(str(os.getpid()))) diff --git a/tests.py b/tests.py index 4c08d514..d84ebffe 100644 --- a/tests.py +++ b/tests.py @@ -20,7 +20,7 @@ import unittest from contextlib import closing from http.server import HTTPServer, BaseHTTPRequestHandler from threading import Thread -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, Union from unittest import mock import proxy @@ -46,21 +46,32 @@ def get_available_port() -> int: class TestTcpConnection(unittest.TestCase): + class TcpConnectionToTest(proxy.TcpConnection): + + def __init__(self, conn: Optional[Union[ssl.SSLSocket, socket.socket]] = None, + tag: int = proxy.tcpConnectionTypes.CLIENT) -> None: + super().__init__(tag) + self._conn = conn + + @property + def connection(self) -> Union[ssl.SSLSocket, socket.socket]: + if self._conn is None: + raise proxy.TcpConnectionUninitializedException() + return self._conn + def testThrowsKeyErrorIfNoConn(self) -> None: - self.conn = proxy.TcpConnection(proxy.TcpConnectionTypes.CLIENT) - self.conn.conn = None - with self.assertRaises(KeyError): + self.conn = TestTcpConnection.TcpConnectionToTest() + with self.assertRaises(proxy.TcpConnectionUninitializedException): self.conn.send(b'dummy') - with self.assertRaises(KeyError): + with self.assertRaises(proxy.TcpConnectionUninitializedException): self.conn.recv() - with self.assertRaises(KeyError): + with self.assertRaises(proxy.TcpConnectionUninitializedException): self.conn.close() def testHandlesIOError(self) -> None: - self.conn = proxy.TcpConnection(proxy.TcpConnectionTypes.CLIENT) _conn = mock.MagicMock() _conn.recv.side_effect = IOError() - self.conn.conn = _conn + self.conn = TestTcpConnection.TcpConnectionToTest(_conn) with mock.patch('proxy.logger') as mock_logger: self.conn.recv() mock_logger.exception.assert_called() @@ -68,12 +79,11 @@ class TestTcpConnection(unittest.TestCase): 'Exception while receiving from connection')) def testHandlesConnReset(self) -> None: - self.conn = proxy.TcpConnection(proxy.TcpConnectionTypes.CLIENT) _conn = mock.MagicMock() e = IOError() e.errno = errno.ECONNRESET _conn.recv.side_effect = e - self.conn.conn = _conn + self.conn = TestTcpConnection.TcpConnectionToTest(_conn) with mock.patch('proxy.logger') as mock_logger: self.conn.recv() mock_logger.exception.assert_not_called() @@ -81,58 +91,46 @@ class TestTcpConnection(unittest.TestCase): self.assertEqual(mock_logger.debug.call_args[0][0], '%r' % e) def testClosesIfNotClosed(self) -> None: - self.conn = proxy.TcpConnection(proxy.TcpConnectionTypes.CLIENT) _conn = mock.MagicMock() - self.conn.conn = _conn + self.conn = TestTcpConnection.TcpConnectionToTest(_conn) self.conn.close() _conn.close.assert_called() self.assertTrue(self.conn.closed) def testNoOpIfAlreadyClosed(self) -> None: - self.conn = proxy.TcpConnection(proxy.TcpConnectionTypes.CLIENT) _conn = mock.MagicMock() - self.conn.conn = _conn + self.conn = TestTcpConnection.TcpConnectionToTest(_conn) self.conn.closed = True self.conn.close() _conn.close.assert_not_called() self.assertTrue(self.conn.closed) - @mock.patch('socket.socket') - def testTcpServerClosesConnOnGC(self, mock_socket: mock.Mock) -> None: - conn = mock.MagicMock() - mock_socket.return_value = conn - self.conn = proxy.TcpServerConnection( - str(proxy.DEFAULT_IPV4_HOSTNAME), proxy.DEFAULT_PORT) - self.conn.connect() - del self.conn - conn.close.assert_called() - @mock.patch('socket.socket') def testTcpServerEstablishesIPv6Connection(self, mock_socket: mock.Mock) -> None: - self.conn = proxy.TcpServerConnection( + conn = proxy.TcpServerConnection( str(proxy.DEFAULT_IPV6_HOSTNAME), proxy.DEFAULT_PORT) - self.conn.connect() + conn.connect() mock_socket.assert_called() mock_socket.return_value.connect.assert_called_with( (str(proxy.DEFAULT_IPV6_HOSTNAME), proxy.DEFAULT_PORT, 0, 0)) @mock.patch('socket.socket') def testTcpServerEstablishesIPv4Connection(self, mock_socket: mock.Mock) -> None: - self.conn = proxy.TcpServerConnection( + conn = proxy.TcpServerConnection( str(proxy.DEFAULT_IPV4_HOSTNAME), proxy.DEFAULT_PORT) - self.conn.connect() + conn.connect() mock_socket.assert_called() mock_socket.return_value.connect.assert_called_with( (str(proxy.DEFAULT_IPV4_HOSTNAME), proxy.DEFAULT_PORT)) -class BasicTcpServer(proxy.TcpServer): +class TcpServerUnderTest(proxy.TcpServer): def handle(self, client: proxy.TcpClientConnection) -> None: data = client.recv(proxy.DEFAULT_BUFFER_SIZE) if data != b'HELLO': raise ValueError('Expected HELLO') - client.conn.sendall(b'WORLD') + client.connection.sendall(b'WORLD') client.close() def setup(self) -> None: @@ -148,8 +146,8 @@ class TestTcpServerIntegration(unittest.TestCase): ipv4_port: Optional[int] = None ipv6_port: Optional[int] = None - ipv4_server: Optional[BasicTcpServer] = None - ipv6_server: Optional[BasicTcpServer] = None + ipv4_server: Optional[TcpServerUnderTest] = None + ipv6_server: Optional[TcpServerUnderTest] = None ipv4_thread: Optional[Thread] = None ipv6_thread: Optional[Thread] = None @@ -157,14 +155,12 @@ class TestTcpServerIntegration(unittest.TestCase): def setUpClass(cls) -> None: cls.ipv4_port = get_available_port() cls.ipv6_port = get_available_port() - cls.ipv4_server = BasicTcpServer( + cls.ipv4_server = TcpServerUnderTest( hostname=proxy.DEFAULT_IPV4_HOSTNAME, - port=cls.ipv4_port, - family=socket.AF_INET) - cls.ipv6_server = BasicTcpServer( + port=cls.ipv4_port) + cls.ipv6_server = TcpServerUnderTest( hostname=proxy.DEFAULT_IPV6_HOSTNAME, - port=cls.ipv6_port, - family=socket.AF_INET6) + port=cls.ipv6_port) cls.ipv4_thread = Thread(target=cls.ipv4_server.run) cls.ipv6_thread = Thread(target=cls.ipv6_server.run) cls.ipv4_thread.setDaemon(True) @@ -221,7 +217,7 @@ class TestTcpServer(unittest.TestCase): def testAcceptSSLErrorsSilentlyIgnored(self, mock_socket: mock.Mock, mock_select: mock.Mock) -> None: mock_socket.accept.side_effect = ssl.SSLError() mock_select.return_value = ([mock_socket], [], []) - server = BasicTcpServer(hostname=proxy.DEFAULT_IPV6_HOSTNAME, port=1234) + server = TcpServerUnderTest(hostname=proxy.DEFAULT_IPV6_HOSTNAME, port=1234) server.socket = mock_socket with mock.patch('proxy.logger') as mock_logger: server.run_once() @@ -229,71 +225,6 @@ class TestTcpServer(unittest.TestCase): self.assertTrue(mock_logger.exception.call_args[0][0], 'SSLError encountered') -class MockHttpProxy: - - def __init__(self, client: proxy.TcpClientConnection, **kwargs) -> None: # type: ignore - self.client = client - self.kwargs = kwargs - - def setDaemon(self, _val: bool) -> None: - pass - - def start(self) -> None: - self.client.conn.sendall(proxy.CRLF.join( - [b'HTTP/1.1 200 OK', proxy.CRLF])) - self.client.conn.shutdown(socket.SHUT_RDWR) - self.client.conn.close() - - -def mock_tcp_proxy_side_effect(client: proxy.TcpClientConnection, **kwargs) -> MockHttpProxy: # type: ignore - return MockHttpProxy(client, **kwargs) - - -@unittest.skipIf(os.getenv('TESTING_ON_TRAVIS', 0), - 'Opening sockets not allowed on Travis') -@unittest.skipIf(os.getenv('GITHUB_WORKFLOW', 0), - 'This test fails on GitHub Windows environment') -class TestMultiCoreRequestDispatcherIntegration(unittest.TestCase): - tcp_port = None - tcp_server = None - tcp_thread = None - - @mock.patch.object( - proxy, # type: ignore - 'HttpProtocolHandler', - side_effect=mock_tcp_proxy_side_effect) - def testHttpProxyConnection(self, _mock_tcp_proxy): - try: - self.tcp_port = get_available_port() - self.tcp_server = proxy.MultiCoreRequestDispatcher( - proxy.HttpProtocolConfig( - hostname=proxy.DEFAULT_IPV4_HOSTNAME, - port=self.tcp_port, - num_workers=1)) - self.tcp_thread = Thread(target=self.tcp_server.run) - self.tcp_thread.setDaemon(True) - self.tcp_thread.start() - - while True: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) as sock: - try: - sock.connect( - (str(proxy.DEFAULT_IPV4_HOSTNAME), self.tcp_port)) - sock.send(proxy.HttpParser.build_request( - b'GET', b'http://httpbin.org/get', b'HTTP/1.1', - headers={b'Host': b'httpbin.org'}) - ) - data = sock.recv(proxy.DEFAULT_BUFFER_SIZE) - self.assertEqual(data, proxy.CRLF.join( - [b'HTTP/1.1 200 OK', proxy.CRLF])) - break - except ConnectionRefusedError: - time.sleep(0.1) - finally: - self.tcp_server.stop() - self.tcp_thread.join() - - class TestChunkParser(unittest.TestCase): def setUp(self) -> None: @@ -770,61 +701,31 @@ class TestHttpParser(unittest.TestCase): self.assertTrue(k in dictionary) -# TODO: Replace MockTcpSocket with mock.MagicMock instances. -class MockTcpSocket(socket.socket): - - def __init__(self, buf: bytes = b'') -> None: - self.buffer = buf - self.received = b'' - self.closed = False - super().__init__() - - def recv(self, b: int = 8192, flags: Optional[int] = None) -> bytes: - data = self.buffer[:b] - self.buffer = self.buffer[b:] - return data - - def send(self, data: bytes, flags: Optional[int] = None) -> int: - self.received += data - return len(data) - - def queue(self, data: bytes) -> None: - self.buffer += data - - def close(self) -> None: - self.closed = True - super().close() - - def shutdown(self, _how: int) -> None: - pass - - -class HTTPRequestHandler(BaseHTTPRequestHandler): - - def do_GET(self) -> None: - self.send_response(200) - # TODO(abhinavsingh): Proxy should work just fine even without - # content-length header - self.send_header('content-length', '2') - self.end_headers() - self.wfile.write(b'OK') - - class TestHttpProtocolHandler(unittest.TestCase): http_server = None http_server_port = None http_server_thread = None config = None + class HTTPRequestHandler(BaseHTTPRequestHandler): + + def do_GET(self) -> None: + self.send_response(200) + # TODO(abhinavsingh): Proxy should work just fine even without + # content-length header + self.send_header('content-length', '2') + self.end_headers() + self.wfile.write(b'OK') + @classmethod def setUpClass(cls) -> None: cls.http_server_port = get_available_port() cls.http_server = HTTPServer( - ('127.0.0.1', cls.http_server_port), HTTPRequestHandler) + ('127.0.0.1', cls.http_server_port), TestHttpProtocolHandler.HTTPRequestHandler) cls.http_server_thread = Thread(target=cls.http_server.serve_forever) cls.http_server_thread.setDaemon(True) cls.http_server_thread.start() - cls.config = proxy.HttpProtocolConfig() + cls.config = proxy.ProtocolConfig() cls.config.plugins = proxy.load_plugins( b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') @@ -836,15 +737,12 @@ class TestHttpProtocolHandler(unittest.TestCase): if cls.http_server_thread: cls.http_server_thread.join() - def setUp(self) -> None: - self._conn = MockTcpSocket() + @mock.patch('socket.fromfd') + def setUp(self, mock_fromfd: mock.Mock) -> None: + self.fileno = 10 self._addr = ('127.0.0.1', 54382) - self.proxy = proxy.HttpProtocolHandler( - proxy.TcpClientConnection( - self._conn, self._addr), config=self.config) - - def tearDown(self) -> None: - self._conn.close() + self._conn = mock_fromfd.return_value + self.proxy = proxy.ProtocolHandler(self.fileno, self._addr, config=self.config) @mock.patch('select.select') @mock.patch('proxy.TcpServerConnection') @@ -852,13 +750,14 @@ class TestHttpProtocolHandler(unittest.TestCase): server = mock_server_connection.return_value server.connect.return_value = True mock_select.side_effect = [ - ([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])] + ([self._conn], [], []), + ([self._conn], [], []), + ([], [server.connection], [])] # Send request line assert self.http_server_port is not None - self.proxy.client.conn.queue( # type: ignore - (b'GET http://localhost:%d HTTP/1.1' % - self.http_server_port) + proxy.CRLF) + self._conn.recv.return_value = (b'GET http://localhost:%d HTTP/1.1' % + self.http_server_port) + proxy.CRLF self.proxy.run_once() self.assertEqual( self.proxy.request.state, @@ -869,20 +768,20 @@ class TestHttpProtocolHandler(unittest.TestCase): # Send headers and blank line, thus completing HTTP request assert self.http_server_port is not None - self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore + self._conn.recv.return_value = proxy.CRLF.join([ b'User-Agent: proxy.py/%s' % proxy.version, b'Host: localhost:%d' % self.http_server_port, b'Accept: */*', b'Proxy-Connection: Keep-Alive', proxy.CRLF - ])) + ]) self.assert_data_queued(mock_server_connection, server) self.proxy.run_once() server.flush.assert_called_once() def assert_tunnel_response(self, mock_server_connection: mock.Mock, server: mock.Mock) -> None: self.proxy.run_once() - self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None) # type: ignore + self.assertTrue(self.proxy.plugins['HttpProxyPlugin'].server is not None) # type: ignore self.assertEqual( self.proxy.client.buffer, proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) @@ -903,17 +802,21 @@ class TestHttpProtocolHandler(unittest.TestCase): server = mock_server_connection.return_value server.connect.return_value = True server.has_buffer.side_effect = [False, False, False, True] - mock_select.side_effect = [([self._conn], [], []), ([], [self._conn], []), - ([self._conn], [], []), ([], [server.conn], [])] + mock_select.side_effect = [ + ([self._conn], [], []), # client read ready + ([], [self._conn], []), # client write ready + ([self._conn], [], []), # client read ready + ([], [server.connection], []) # server write ready + ] assert self.http_server_port is not None - self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore + self._conn.recv.return_value = proxy.CRLF.join([ b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port, b'Host: localhost:%d' % self.http_server_port, b'User-Agent: proxy.py/%s' % proxy.version, b'Proxy-Connection: Keep-Alive', proxy.CRLF - ])) + ]) self.assert_tunnel_response(mock_server_connection, server) # Dispatch tunnel established response to client @@ -927,105 +830,111 @@ class TestHttpProtocolHandler(unittest.TestCase): @mock.patch('select.select') def test_proxy_connection_failed(self, mock_select: mock.Mock) -> None: mock_select.return_value = ([self._conn], [], []) - self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore + self._conn.recv.return_value = proxy.CRLF.join([ b'GET http://unknown.domain HTTP/1.1', b'Host: unknown.domain', proxy.CRLF - ])) + ]) self.proxy.run_once() - self.assertEqual( - self.proxy.client.conn.received, # type: ignore - proxy.ProxyConnectionFailed.RESPONSE_PKT) + received = self._conn.send.call_args[0][0] + self.assertEqual(received, proxy.ProxyConnectionFailed.RESPONSE_PKT) + @mock.patch('socket.fromfd') @mock.patch('select.select') - def test_proxy_authentication_failed(self, mock_select: mock.Mock) -> None: + def test_proxy_authentication_failed(self, mock_select: mock.Mock, mock_fromfd: mock.Mock) -> None: + self._conn = mock_fromfd.return_value mock_select.return_value = ([self._conn], [], []) - config = proxy.HttpProtocolConfig( + config = proxy.ProtocolConfig( auth_code=b'Basic %s' % base64.b64encode(b'user:pass')) config.plugins = proxy.load_plugins( b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') - self.proxy = proxy.HttpProtocolHandler( - proxy.TcpClientConnection( - self._conn, self._addr), config=config) - self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore + self.proxy = proxy.ProtocolHandler(self.fileno, self._addr, config=config) + self._conn.recv.return_value = proxy.CRLF.join([ b'GET http://abhinavsingh.com HTTP/1.1', b'Host: abhinavsingh.com', proxy.CRLF - ])) + ]) self.proxy.run_once() self.assertEqual( - self.proxy.client.conn.received, # type: ignore + self._conn.send.call_args[0][0], proxy.ProxyAuthenticationFailed.RESPONSE_PKT) + @mock.patch('socket.fromfd') @mock.patch('select.select') @mock.patch('proxy.TcpServerConnection') def test_authenticated_proxy_http_get( - self, mock_server_connection: mock.Mock, mock_select: mock.Mock) -> None: + self, mock_server_connection: mock.Mock, + mock_select: mock.Mock, + mock_fromfd: mock.Mock) -> None: + self._conn = mock_fromfd.return_value mock_select.return_value = ([self._conn], [], []) server = mock_server_connection.return_value server.connect.return_value = True - client = proxy.TcpClientConnection(self._conn, self._addr) - config = proxy.HttpProtocolConfig( + config = proxy.ProtocolConfig( auth_code=b'Basic %s' % base64.b64encode(b'user:pass')) config.plugins = proxy.load_plugins( b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') - self.proxy = proxy.HttpProtocolHandler(client, config=config) + self.proxy = proxy.ProtocolHandler(self.fileno, addr=self._addr, config=config) assert self.http_server_port is not None - self.proxy.client.conn.queue( # type: ignore - b'GET http://localhost:%d HTTP/1.1' % - self.http_server_port) + + self._conn.recv.return_value = b'GET http://localhost:%d HTTP/1.1' % self.http_server_port self.proxy.run_once() self.assertEqual( self.proxy.request.state, proxy.httpParserStates.INITIALIZED) - self.proxy.client.conn.queue(proxy.CRLF) # type: ignore + self._conn.recv.return_value = proxy.CRLF self.proxy.run_once() self.assertEqual( self.proxy.request.state, proxy.httpParserStates.LINE_RCVD) assert self.http_server_port is not None - self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore + self._conn.recv.return_value = proxy.CRLF.join([ b'User-Agent: proxy.py/%s' % proxy.version, b'Host: localhost:%d' % self.http_server_port, b'Accept: */*', b'Proxy-Connection: Keep-Alive', b'Proxy-Authorization: Basic dXNlcjpwYXNz', proxy.CRLF - ])) + ]) self.assert_data_queued(mock_server_connection, server) + @mock.patch('socket.fromfd') @mock.patch('select.select') @mock.patch('proxy.TcpServerConnection') def test_authenticated_proxy_http_tunnel( - self, mock_server_connection: mock.Mock, mock_select: mock.Mock) -> None: + self, mock_server_connection: mock.Mock, + mock_select: mock.Mock, + mock_fromfd: mock.Mock) -> None: server = mock_server_connection.return_value server.connect.return_value = True - mock_select.side_effect = [ - ([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])] - config = proxy.HttpProtocolConfig( + self._conn = mock_fromfd.return_value + mock_select.side_effect = [ + ([self._conn], [], []), ([self._conn], [], []), ([], [server.connection], [])] + + config = proxy.ProtocolConfig( auth_code=b'Basic %s' % base64.b64encode(b'user:pass')) config.plugins = proxy.load_plugins( b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') - self.proxy = proxy.HttpProtocolHandler( - proxy.TcpClientConnection( - self._conn, self._addr), config=config) + + self.proxy = proxy.ProtocolHandler(self.fileno, self._addr, config=config) + assert self.http_server_port is not None - self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore + self._conn.recv.return_value = proxy.CRLF.join([ b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port, b'Host: localhost:%d' % self.http_server_port, b'User-Agent: proxy.py/%s' % proxy.version, b'Proxy-Connection: Keep-Alive', b'Proxy-Authorization: Basic dXNlcjpwYXNz', proxy.CRLF - ])) + ]) self.assert_tunnel_response(mock_server_connection, server) self.proxy.client.flush() self.assert_data_queued_to_server(server) @@ -1033,73 +942,72 @@ class TestHttpProtocolHandler(unittest.TestCase): self.proxy.run_once() server.flush.assert_called_once() + @mock.patch('socket.fromfd') @mock.patch('select.select') - def test_pac_file_served_from_disk(self, mock_select: mock.Mock) -> None: + def test_pac_file_served_from_disk(self, mock_select: mock.Mock, mock_fromfd: mock.Mock) -> None: + pac_file = 'proxy.pac' + self._conn = mock_fromfd.return_value mock_select.return_value = [self._conn], [], [] - config = proxy.HttpProtocolConfig(pac_file='proxy.pac') - self.init_and_make_pac_file_request(config) + self.init_and_make_pac_file_request(pac_file) self.proxy.run_once() self.assertEqual( self.proxy.request.state, proxy.httpParserStates.COMPLETE) - with open('proxy.pac', 'rb') as pac_file: - self.assertEqual( - self._conn.received, - proxy.HttpParser.build_response( - 200, reason=b'OK', headers={ - b'Content-Type': b'application/x-ns-proxy-autoconfig', - b'Connection': b'close' - }, body=pac_file.read() - )) - - @mock.patch('select.select') - def test_pac_file_served_from_buffer(self, mock_select: mock.Mock) -> None: - pac_file_content = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }' - mock_select.return_value = [self._conn], [], [] - config = proxy.HttpProtocolConfig(pac_file=proxy.text_(pac_file_content)) - self.init_and_make_pac_file_request(config) - self.proxy.run_once() - self.assertEqual( - self.proxy.request.state, - proxy.httpParserStates.COMPLETE) - self.assertEqual( - self._conn.received, - proxy.HttpParser.build_response( + with open('proxy.pac', 'rb') as f: + self._conn.send.called_once_with(proxy.HttpParser.build_response( 200, reason=b'OK', headers={ b'Content-Type': b'application/x-ns-proxy-autoconfig', b'Connection': b'close' - }, body=pac_file_content + }, body=f.read() )) + @mock.patch('socket.fromfd') @mock.patch('select.select') - def test_default_web_server_returns_404(self, mock_select: mock.Mock) -> None: + def test_pac_file_served_from_buffer(self, mock_select: mock.Mock, mock_fromfd: mock.Mock) -> None: + self._conn = mock_fromfd.return_value + pac_file_content = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }' mock_select.return_value = [self._conn], [], [] - config = proxy.HttpProtocolConfig() + self.init_and_make_pac_file_request(proxy.text_(pac_file_content)) + self.proxy.run_once() + self.assertEqual( + self.proxy.request.state, + proxy.httpParserStates.COMPLETE) + self._conn.send.called_once_with(proxy.HttpParser.build_response( + 200, reason=b'OK', headers={ + b'Content-Type': b'application/x-ns-proxy-autoconfig', + b'Connection': b'close' + }, body=pac_file_content + )) + + @mock.patch('socket.fromfd') + @mock.patch('select.select') + def test_default_web_server_returns_404(self, mock_select: mock.Mock, mock_fromfd: mock.Mock) -> None: + self._conn = mock_fromfd.return_value + mock_select.return_value = [self._conn], [], [] + config = proxy.ProtocolConfig() config.plugins = proxy.load_plugins( b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') - self.proxy = proxy.HttpProtocolHandler( - proxy.TcpClientConnection( - self._conn, self._addr), config=config) - self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore + self.proxy = proxy.ProtocolHandler(self.fileno, self._addr, config=config) + self._conn.recv.return_value = proxy.CRLF.join([ b'GET /hello HTTP/1.1', proxy.CRLF, proxy.CRLF - ])) + ]) self.proxy.run_once() self.assertEqual( self.proxy.request.state, proxy.httpParserStates.COMPLETE) self.assertEqual( - self._conn.received, + self._conn.send.call_args[0][0], proxy.HttpWebServerPlugin.DEFAULT_404_RESPONSE) - def test_on_client_connection_called_on_teardown(self) -> None: - config = proxy.HttpProtocolConfig() + @mock.patch('socket.fromfd') + def test_on_client_connection_called_on_teardown(self, mock_fromfd: mock.Mock) -> None: + config = proxy.ProtocolConfig() plugin = mock.MagicMock() - config.plugins = {b'HttpProtocolBasePlugin': [plugin]} - self.proxy = proxy.HttpProtocolHandler( - proxy.TcpClientConnection( - self._conn, self._addr), config=config) + config.plugins = {b'ProtocolHandlerPlugin': [plugin]} + self._conn = mock_fromfd.return_value + self.proxy = proxy.ProtocolHandler(self.fileno, self._addr, config=config) plugin.assert_called() with mock.patch.object(self.proxy, 'run_once') as mock_run_once: mock_run_once.return_value = True @@ -1108,17 +1016,15 @@ class TestHttpProtocolHandler(unittest.TestCase): plugin.return_value.access_log.assert_called() plugin.return_value.on_client_connection_close.assert_called() - def init_and_make_pac_file_request(self, config: proxy.HttpProtocolConfig) -> None: + def init_and_make_pac_file_request(self, pac_file: str) -> None: + config = proxy.ProtocolConfig(pac_file=pac_file) config.plugins = proxy.load_plugins( b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin,proxy.HttpWebServerPacFilePlugin') - self.proxy = proxy.HttpProtocolHandler( - proxy.TcpClientConnection( - self._conn, self._addr), config=config) - self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore + self.proxy = proxy.ProtocolHandler(self.fileno, self._addr, config=config) + self._conn.recv.return_value = proxy.CRLF.join([ b'GET / HTTP/1.1', proxy.CRLF, - proxy.CRLF - ])) + ]) def assert_data_queued(self, mock_server_connection: mock.Mock, server: mock.Mock) -> None: self.proxy.run_once() @@ -1140,14 +1046,14 @@ class TestHttpProtocolHandler(unittest.TestCase): def assert_data_queued_to_server(self, server: mock.Mock) -> None: assert self.http_server_port is not None - self.assertEqual(self.proxy.client.buffer_size(), 0) + self.assertEqual(self._conn.send.call_args[0][0], proxy.HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) - self.proxy.client.conn.queue(proxy.CRLF.join([ # type: ignore + self._conn.recv.return_value = proxy.CRLF.join([ b'GET / HTTP/1.1', b'Host: localhost:%d' % self.http_server_port, b'User-Agent: proxy.py/%s' % proxy.version, proxy.CRLF - ])) + ]) self.proxy.run_once() server.queue.assert_called_once_with(proxy.CRLF.join([ b'GET / HTTP/1.1', @@ -1160,33 +1066,26 @@ class TestHttpProtocolHandler(unittest.TestCase): class TestWorker(unittest.TestCase): - def setUp(self) -> None: + @mock.patch('proxy.ProtocolHandler') + def setUp(self, mock_protocol_handler: mock.Mock) -> None: self.pipe = multiprocessing.Pipe() - self.worker = proxy.Worker(self.pipe[1], proxy.HttpProtocolConfig()) + self.worker = proxy.Worker(self.pipe[1], mock_protocol_handler, config=proxy.ProtocolConfig()) + self.mock_protocol_handler = mock_protocol_handler - @mock.patch('proxy.HttpProtocolHandler') + @mock.patch('proxy.ProtocolHandler') def test_shutdown_op(self, mock_http_proxy: mock.Mock) -> None: self.pipe[0].send((proxy.workerOperations.SHUTDOWN, None)) self.worker.run() self.assertFalse(mock_http_proxy.called) - @mock.patch('socket.fromfd') @mock.patch('proxy.recv_handle') - @mock.patch('proxy.HttpProtocolHandler') - def test_spawns_http_proxy_threads( - self, - mock_http_proxy: mock.Mock, - mock_recv_handle: mock.Mock, - mock_fromfd: mock.Mock) -> None: + def test_spawns_http_proxy_threads(self, mock_recv_handle: mock.Mock) -> None: fileno = 10 mock_recv_handle.return_value = fileno self.pipe[0].send((proxy.workerOperations.HTTP_PROTOCOL, None)) self.pipe[0].send((proxy.workerOperations.SHUTDOWN, None)) self.worker.run() - self.assertTrue(mock_fromfd.called) - mock_fromfd.assert_called_with( - fileno, family=socket.AF_INET6, type=socket.SOCK_STREAM) - self.assertTrue(mock_http_proxy.called) + self.assertTrue(self.mock_protocol_handler.called) def test_handles_work_queue_recv_connection_refused(self) -> None: with mock.patch.object(self.worker.work_queue, 'recv') as mock_recv: @@ -1194,61 +1093,6 @@ class TestWorker(unittest.TestCase): self.assertFalse(self.worker.run_once()) # doesn't teardown -class TestWorkerSSLWrap(unittest.TestCase): - - CERTFILE = 'my-https-cert.pem' - KEYFILE = 'my-https-key.pem' - - def setUp(self) -> None: - self.pipe = multiprocessing.Pipe() - self.worker = proxy.Worker(self.pipe[1], - proxy.HttpProtocolConfig(certfile=self.CERTFILE, keyfile=self.KEYFILE)) - - @mock.patch('ssl.create_default_context') - @mock.patch('socket.fromfd') - @mock.patch('proxy.recv_handle') - @mock.patch('proxy.HttpProtocolHandler') - def test_worker_performs_ssl_wrap( - self, - mock_http_proxy: mock.Mock, - mock_recv_handle: mock.Mock, - mock_fromfd: mock.Mock, - mock_context: mock.Mock) -> None: - fileno = 10 - mock_recv_handle.return_value = fileno - self.pipe[0].send((proxy.workerOperations.HTTP_PROTOCOL, None)) - self.pipe[0].send((proxy.workerOperations.SHUTDOWN, None)) - self.worker.run() - mock_fromfd.assert_called() - mock_context.assert_called_with(ssl.Purpose.CLIENT_AUTH) - mock_context.return_value.load_cert_chain.assert_called_with(certfile=self.CERTFILE, keyfile=self.KEYFILE) - mock_http_proxy.assert_called() - - @mock.patch('ssl.create_default_context') - @mock.patch('socket.fromfd') - @mock.patch('proxy.recv_handle') - @mock.patch('proxy.HttpProtocolHandler') - def test_client_conn_closed_on_os_error( - self, - mock_http_proxy: mock.Mock, - mock_recv_handle: mock.Mock, - mock_fromfd: mock.Mock, - mock_context: mock.Mock) -> None: - fileno = 10 - mock_recv_handle.return_value = fileno - mock_context.return_value.wrap_socket.side_effect = OSError() - self.pipe[0].send((proxy.workerOperations.HTTP_PROTOCOL, None)) - self.pipe[0].send((proxy.workerOperations.SHUTDOWN, None)) - with mock.patch('proxy.logger') as mock_logger: - self.worker.run() - mock_logger.exception.assert_called() - self.assertEqual(mock_logger.exception.call_args[0][0], - 'OSError encountered while ssl wrapping the client socket') - mock_fromfd.assert_called() - mock_fromfd.return_value.close.assert_called() - mock_http_proxy.assert_not_called() - - class TestHttpRequestRejected(unittest.TestCase): def setUp(self) -> None: @@ -1282,7 +1126,7 @@ class TestHttpRequestRejected(unittest.TestCase): class TestMain(unittest.TestCase): @mock.patch('proxy.set_open_file_limit') - @mock.patch('proxy.MultiCoreRequestDispatcher') + @mock.patch('proxy.WorkerPool') @mock.patch('proxy.logging.basicConfig') def test_log_file_setup( self, @@ -1305,7 +1149,7 @@ class TestMain(unittest.TestCase): @mock.patch('os.path.exists') @mock.patch('builtins.open') @mock.patch('proxy.set_open_file_limit') - @mock.patch('proxy.MultiCoreRequestDispatcher') + @mock.patch('proxy.WorkerPool') @unittest.skipIf( True, # type: ignore 'This test passes while development on Intellij but fails via CLI :(') @@ -1327,9 +1171,9 @@ class TestMain(unittest.TestCase): mock_exists.assert_called_with(pid_file) mock_remove.assert_called_with(pid_file) - @mock.patch('proxy.HttpProtocolConfig') + @mock.patch('proxy.ProtocolConfig') @mock.patch('proxy.set_open_file_limit') - @mock.patch('proxy.MultiCoreRequestDispatcher') + @mock.patch('proxy.WorkerPool') def test_main( self, mock_multicore_dispatcher: mock.Mock, @@ -1337,8 +1181,14 @@ class TestMain(unittest.TestCase): mock_config: mock.Mock) -> None: proxy.main(['--basic-auth', 'user:pass']) self.assertTrue(mock_set_open_file_limit.called) + config = mock_config.return_value mock_multicore_dispatcher.assert_called_with( - config=mock_config.return_value) + hostname=config.hostname, + port=config.port, + backlog=config.backlog, + num_workers=config.num_workers, + work_klass=proxy.ProtocolHandler, + config=config) mock_config.assert_called_with( auth_code=b'Basic dXNlcjpwYXNz', client_recvbuf_size=proxy.DEFAULT_CLIENT_RECVBUF_SIZE, @@ -1359,9 +1209,9 @@ class TestMain(unittest.TestCase): ) @mock.patch('builtins.print') - @mock.patch('proxy.HttpProtocolConfig') + @mock.patch('proxy.ProtocolConfig') @mock.patch('proxy.set_open_file_limit') - @mock.patch('proxy.MultiCoreRequestDispatcher') + @mock.patch('proxy.WorkerPool') def test_main_version( self, mock_multicore_dispatcher: mock.Mock, @@ -1376,9 +1226,9 @@ class TestMain(unittest.TestCase): mock_config.assert_not_called() @mock.patch('builtins.print') - @mock.patch('proxy.HttpProtocolConfig') + @mock.patch('proxy.ProtocolConfig') @mock.patch('proxy.set_open_file_limit') - @mock.patch('proxy.MultiCoreRequestDispatcher') + @mock.patch('proxy.WorkerPool') @mock.patch('proxy.is_py3') def test_main_py3_runs( self, @@ -1396,9 +1246,9 @@ class TestMain(unittest.TestCase): mock_config.assert_called() @mock.patch('builtins.print') - @mock.patch('proxy.HttpProtocolConfig') + @mock.patch('proxy.ProtocolConfig') @mock.patch('proxy.set_open_file_limit') - @mock.patch('proxy.MultiCoreRequestDispatcher') + @mock.patch('proxy.WorkerPool') @mock.patch('proxy.is_py3') @unittest.skipIf( True, # type: ignore