diff --git a/plugin_examples.py b/plugin_examples.py index daff92b1..577af852 100644 --- a/plugin_examples.py +++ b/plugin_examples.py @@ -1,30 +1,38 @@ from urllib import parse as urlparse -from proxy import HttpProxyPlugin, ProxyRejectRequest +from proxy import HttpProtocolBasePlugin, ProxyRequestRejected -class RedirectToCustomServerPlugin(HttpProxyPlugin): +class RedirectToCustomServerPlugin(HttpProtocolBasePlugin): """Modifies client request to redirect all incoming requests to a fixed server address.""" def __init__(self): super(RedirectToCustomServerPlugin, self).__init__() - def handle_request(self, request): - if request.method != 'CONNECT': - request.url = urlparse.urlsplit(b'http://localhost:9999') - return request + def on_request_complete(self): + if self.request.method != 'CONNECT': + self.request.url = urlparse.urlsplit(b'http://localhost:9999') -class FilterByTargetDomainPlugin(HttpProxyPlugin): +class FilterByTargetDomainPlugin(HttpProtocolBasePlugin): """Only accepts specific requests dropping all other requests.""" def __init__(self): super(FilterByTargetDomainPlugin, self).__init__() self.allowed_domains = [b'google.com', b'www.google.com', b'google.com:443', b'www.google.com:443'] - def handle_request(self, request): + def on_request_complete(self): # TODO: Refactor internals to cleanup mess below, due to how urlparse works, hostname/path attributes # are not consistent between CONNECT and non-CONNECT requests. - if (request.method != b'CONNECT' and request.url.hostname not in self.allowed_domains) or \ - (request.method == b'CONNECT' and request.url.path not in self.allowed_domains): - raise ProxyRejectRequest(status_code=418, body='I\'m a tea pot') - return request + if (self.request.method != b'CONNECT' and self.request.url.hostname not in self.allowed_domains) or \ + (self.request.method == b'CONNECT' and self.request.url.path not in self.allowed_domains): + raise ProxyRequestRejected(status_code=418, body='I\'m a tea pot') + + +class SaveHttpResponses(HttpProtocolBasePlugin): + """Saves Http Responses locally on disk.""" + + def __init__(self): + super(SaveHttpResponses, self).__init__() + + def handle_response_chunk(self, chunk): + return chunk diff --git a/proxy.py b/proxy.py index b84fd81a..4fb8910c 100755 --- a/proxy.py +++ b/proxy.py @@ -22,6 +22,7 @@ import socket import sys import threading from collections import namedtuple +from typing import Dict, List if os.name != 'nt': import resource @@ -53,7 +54,7 @@ else: # pragma: no cover # Defaults DEFAULT_BACKLOG = 100 DEFAULT_BASIC_AUTH = None -DEFAULT_BUFFER_SIZE = 8192 +DEFAULT_BUFFER_SIZE = 1024 * 1024 DEFAULT_CLIENT_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE DEFAULT_SERVER_RECVBUF_SIZE = DEFAULT_BUFFER_SIZE DEFAULT_IPV4_HOSTNAME = '127.0.0.1' @@ -63,10 +64,12 @@ DEFAULT_IPV4 = False DEFAULT_LOG_LEVEL = 'INFO' DEFAULT_OPEN_FILE_LIMIT = 1024 DEFAULT_PAC_FILE = None +DEFAULT_PAC_FILE_URL_PATH = '/' DEFAULT_NUM_WORKERS = 0 -DEFAULT_PLUGINS = [] +DEFAULT_PLUGINS = {} DEFAULT_LOG_FORMAT = '%(asctime)s - %(levelname)s - pid:%(process)d - %(funcName)s:%(lineno)d - %(message)s' +# Set to True if under test UNDER_TEST = False @@ -91,7 +94,7 @@ def bytes_(s, encoding='utf-8', errors='strict'): # pragma: no cover version = bytes_(__version__) -CRLF, COLON, SP = b'\r\n', b':', b' ' +CRLF, COLON, WHITESPACE = b'\r\n', b':', b' ' PROXY_AGENT_HEADER = b'Proxy-agent: proxy.py v' + version PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT = CRLF.join([ @@ -123,31 +126,31 @@ class ChunkParser(object): self.chunk = b'' # Partial chunk received self.size = None # Expected size of next following chunk - def parse(self, data): - more = True if len(data) > 0 else False + def parse(self, raw): + more = True if len(raw) > 0 else False while more: - more, data = self.process(data) + more, raw = self.process(raw) - def process(self, data): + def process(self, raw): if self.state == ChunkParser.states.WAITING_FOR_SIZE: # Consume prior chunk in buffer # in case chunk size without CRLF was received - data = self.chunk + data + raw = self.chunk + raw self.chunk = b'' # Extract following chunk data size - line, data = HttpParser.split(data) + line, raw = HttpParser.split(raw) if not line: # CRLF not received - self.chunk = data - data = b'' + self.chunk = raw + raw = b'' else: self.size = int(line, 16) self.state = ChunkParser.states.WAITING_FOR_DATA elif self.state == ChunkParser.states.WAITING_FOR_DATA: remaining = self.size - len(self.chunk) - self.chunk += data[:remaining] - data = data[remaining:] + self.chunk += raw[:remaining] + raw = raw[remaining:] if len(self.chunk) == self.size: - data = data[len(CRLF):] + raw = raw[len(CRLF):] self.body += self.chunk if self.size == 0: self.state = ChunkParser.states.COMPLETE @@ -155,7 +158,7 @@ class ChunkParser(object): self.state = ChunkParser.states.WAITING_FOR_SIZE self.chunk = b'' self.size = None - return len(data) > 0, data + return len(raw) > 0, raw class HttpParser(object): @@ -179,7 +182,7 @@ class HttpParser(object): self.type = parser_type self.state = HttpParser.states.INITIALIZED - self.raw = b'' + self.bytes = b'' self.buffer = b'' self.headers = dict() @@ -193,22 +196,38 @@ class HttpParser(object): self.chunk_parser = None + # This cleans up developer APIs as Python urlparse.urlsplit behaves differently + # for incoming proxy request and incoming web request. Web request is the one + # which is broken. + self.host = None + self.port = None + + def set_host_port(self): + if self.type == HttpParser.types.REQUEST_PARSER: + if self.method == b'CONNECT': + self.host, self.port = self.url.path.split(COLON) + elif self.url: + self.host, self.port = self.url.hostname, self.url.port \ + if self.url.port else 80 + else: + raise Exception('Invalid request\n%s' % self.bytes) + def is_chunked_encoded_response(self): return self.type == HttpParser.types.RESPONSE_PARSER and \ b'transfer-encoding' in self.headers and \ self.headers[b'transfer-encoding'][1].lower() == b'chunked' - def parse(self, data): - self.raw += data - data = self.buffer + data + def parse(self, raw): + self.bytes += raw + raw = self.buffer + raw self.buffer = b'' - more = True if len(data) > 0 else False + more = True if len(raw) > 0 else False while more: - more, data = self.process(data) - self.buffer = data + more, raw = self.process(raw) + self.buffer = raw - def process(self, data): + def process(self, raw): if self.state in (HttpParser.states.HEADERS_COMPLETE, HttpParser.states.RCVING_BODY, HttpParser.states.COMPLETE) and \ @@ -218,22 +237,22 @@ class HttpParser(object): if b'content-length' in self.headers: self.state = HttpParser.states.RCVING_BODY - self.body += data + self.body += raw if len(self.body) >= int(self.headers[b'content-length'][1]): self.state = HttpParser.states.COMPLETE elif self.is_chunked_encoded_response(): if not self.chunk_parser: self.chunk_parser = ChunkParser() - self.chunk_parser.parse(data) + self.chunk_parser.parse(raw) if self.chunk_parser.state == ChunkParser.states.COMPLETE: self.body = self.chunk_parser.body self.state = HttpParser.states.COMPLETE return False, b'' - line, data = HttpParser.split(data) + line, raw = HttpParser.split(raw) if line is False: - return line, data + return line, raw if self.state == HttpParser.states.INITIALIZED: self.process_line(line) @@ -245,7 +264,7 @@ class HttpParser(object): if self.state == HttpParser.states.LINE_RCVD and \ self.type == HttpParser.types.REQUEST_PARSER and \ self.method == b'CONNECT' and \ - data == CRLF: + raw == CRLF: self.state = HttpParser.states.COMPLETE # When raw request has ended with \r\n\r\n and no more http headers are expected @@ -254,7 +273,7 @@ class HttpParser(object): elif self.state == HttpParser.states.HEADERS_COMPLETE and \ self.type == HttpParser.types.REQUEST_PARSER and \ self.method != b'POST' and \ - self.raw.endswith(CRLF * 2): + self.bytes.endswith(CRLF * 2): self.state = HttpParser.states.COMPLETE elif self.state == HttpParser.states.HEADERS_COMPLETE and \ self.type == HttpParser.types.REQUEST_PARSER and \ @@ -262,13 +281,13 @@ class HttpParser(object): (b'content-length' not in self.headers or (b'content-length' in self.headers and int(self.headers[b'content-length'][1]) == 0)) and \ - self.raw.endswith(CRLF * 2): + self.bytes.endswith(CRLF * 2): self.state = HttpParser.states.COMPLETE - return len(data) > 0, data + return len(raw) > 0, raw - def process_line(self, data): - line = data.split(SP) + def process_line(self, raw): + line = raw.split(WHITESPACE) if self.type == HttpParser.types.REQUEST_PARSER: self.method = line[0].upper() self.url = urlparse.urlsplit(line[1]) @@ -277,17 +296,18 @@ class HttpParser(object): self.version = line[0] self.code = line[1] self.reason = b' '.join(line[2:]) + self.set_host_port() self.state = HttpParser.states.LINE_RCVD - def process_header(self, data): - if len(data) == 0: + def process_header(self, raw): + if len(raw) == 0: if self.state == HttpParser.states.RCVING_HEADERS: self.state = HttpParser.states.HEADERS_COMPLETE elif self.state == HttpParser.states.LINE_RCVD: self.state = HttpParser.states.RCVING_HEADERS else: self.state = HttpParser.states.RCVING_HEADERS - parts = data.split(COLON) + parts = raw.split(COLON) key = parts[0].strip() value = COLON.join(parts[1:]).strip() self.headers[key.lower()] = (key, value) @@ -331,13 +351,23 @@ class HttpParser(object): return k + b': ' + v @staticmethod - def split(data): - pos = data.find(CRLF) + def split(raw): + pos = raw.find(CRLF) if pos == -1: - return False, data - line = data[:pos] - data = data[pos + len(CRLF):] - return line, data + return False, raw + line = raw[:pos] + raw = raw[pos + len(CRLF):] + return line, raw + + ################################################################################### + # HttpParser was originally written to parse the incoming raw Http requests. + # Since request / response objects passed to HttpProtocolBasePlugin methods + # are also HttpParser objects, methods below were added to simplify developer API. + #################################################################################### + + def has_upstream_server(self): + """Host field SHOULD be None for incoming local WebServer requests.""" + return True if self.host is not None else False class TcpConnection(object): @@ -488,69 +518,224 @@ class ProxyRequestRejected(ProxyError): return CRLF.join(pkt) if len(pkt) > 0 else None -class HttpProxyPlugin(object): - """Base HttpProxy Plugin class.""" - - def __init__(self): - pass - - def handle_request(self, request): - """Handle client request (HttpParser). - - Return optionally modified client request (HttpParser) object. - """ - return request - - def handle_response(self, data): - """Handle data chunks as received from the server.""" - return data - - -class HttpProxyConfig(object): - """Holds various configuration values applicable to HttpProxy. +class HttpProtocolConfig(object): + """Holds various configuration values applicable to HttpProtocolHandler. This config class helps us avoid passing around bunch of key/value pairs across methods. """ def __init__(self, auth_code=DEFAULT_BASIC_AUTH, server_recvbuf_size=DEFAULT_SERVER_RECVBUF_SIZE, - client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, pac_file=DEFAULT_PAC_FILE, plugins=None): + client_recvbuf_size=DEFAULT_CLIENT_RECVBUF_SIZE, pac_file=DEFAULT_PAC_FILE, + pac_file_url_path=DEFAULT_PAC_FILE_URL_PATH, plugins=None): self.auth_code = auth_code self.server_recvbuf_size = server_recvbuf_size self.client_recvbuf_size = client_recvbuf_size self.pac_file = pac_file + self.pac_file_url_path = pac_file_url_path if plugins is None: plugins = DEFAULT_PLUGINS - self.plugins = plugins + self.plugins: Dict = plugins -class HttpProxy(threading.Thread): - """HTTP proxy implementation. +class HttpProtocolBasePlugin(object): + """Base HttpProtocolHandler Plugin class. - Accepts `Client` connection object and act as a proxy between client and server. - """ - - def __init__(self, client, config=None): - super(HttpProxy, self).__init__() - - self.start_time = self.now() - self.last_activity = self.start_time + Implement various lifecycle event methods to customize behavior.""" + def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser): + self.config = config self.client = client - self.server = None - self.config = config if config else HttpProxyConfig() + self.request = request - self.request = HttpParser(HttpParser.types.REQUEST_PARSER) + def name(self) -> str: + """A unique name for your plugin. + + Defaults to name of the class. This helps plugin developers to directly + access a specific plugin by its name.""" + return self.__class__.__name__ + + def get_descriptors(self): + return [], [], [] + + def flush_to_descriptors(self, w): + pass + + def read_from_descriptors(self, r): + pass + + def on_client_data(self, raw: bytes): + return raw + + def on_request_complete(self): + """Called right after client request parser has reached COMPLETE state.""" + pass + + def handle_response_chunk(self, chunk: bytes): + """Handle data chunks as received from the server. + + Return optionally modified chunk to return back to client.""" + return chunk + + def access_log(self): + pass + + def on_client_connection_close(self): + pass + + +class HttpProxyPlugin(HttpProtocolBasePlugin): + """HttpProtocolHandler plugin which implements HttpProxy specifications.""" + + def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser): + super(HttpProxyPlugin, self).__init__(config, client, request) + self.server = None self.response = HttpParser(HttpParser.types.RESPONSE_PARSER) - @staticmethod - def now(): - return datetime.datetime.utcnow() + def get_descriptors(self): + if not self.request.has_upstream_server(): + return [], [], [] - def connection_inactive_for(self): - return (self.now() - self.last_activity).seconds + r, w, x = [], [], [] + if self.server and not self.server.closed: + r.append(self.server.conn) + if self.server and not self.server.closed and self.server.has_buffer(): + w.append(self.server.conn) + return r, w, x - def is_connection_inactive(self): - return self.connection_inactive_for() > 30 + def flush_to_descriptors(self, w): + if not self.request.has_upstream_server(): + return + + if self.server and not self.server.closed and self.server.conn in w: + logger.debug('Server is ready for writes, flushing server buffer') + self.server.flush() + + def read_from_descriptors(self, r): + if not self.request.has_upstream_server(): + return + + if self.server and not self.server.closed and self.server.conn in r: + logger.debug('Server is ready for reads, reading') + raw = self.server.recv(self.config.server_recvbuf_size) + # self.last_activity = HttpProtocolHandler.now() + if not raw: + logger.debug('Server closed connection, tearing down...') + return True + # parse incoming response packet + # only for non-https requests + if not self.request.method == b'CONNECT': + self.response.parse(raw) + else: + # Only purpose of increasing memory footprint is to print response length in access log + # Not worth it? Optimize to only persist lengths? + self.response.bytes += raw + # queue raw data for client + self.client.queue(raw) + + def on_client_connection_close(self): + if not self.request.has_upstream_server(): + return + + if self.server: + logger.debug( + 'Closed server connection with pending server buffer size %d bytes' % self.server.buffer_size()) + if not self.server.closed: + self.server.close() + + def on_client_data(self, raw): + if not self.request.has_upstream_server(): + return raw + + if self.server and not self.server.closed: + self.server.queue(raw) + return None + else: + return raw + + def on_request_complete(self): + if not self.request.has_upstream_server(): + return + + self.authenticate(self.request.headers) + self.connect_upstream(self.request.host, self.request.port) + # for http connect methods (https requests) + # queue appropriate response for client + # notifying about established connection + if self.request.method == b'CONNECT': + self.client.queue(PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) + # for general http requests, re-build request packet + # and queue for the server with appropriate headers + else: + self.server.queue(self.request.build( + del_headers=[b'proxy-authorization', b'proxy-connection', b'connection', + b'keep-alive'], + add_headers=[(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')] + )) + + def access_log(self): + if not self.request.has_upstream_server(): + return + + host, port = self.server.addr if self.server else (None, None) + if self.request.method == b'CONNECT': + logger.info( + '%s:%s - %s %s:%s - %s bytes' % (self.client.addr[0], self.client.addr[1], + text_(self.request.method), text_(host), + text_(port), len(self.response.bytes))) + elif self.request.method: + logger.info('%s:%s - %s %s:%s%s - %s %s - %s bytes' % ( + self.client.addr[0], self.client.addr[1], text_(self.request.method), text_(host), port, + text_(self.request.build_url()), text_(self.response.code), text_(self.response.reason), + len(self.response.bytes))) + + def authenticate(self, headers): + if self.config.auth_code: + if b'proxy-authorization' not in headers or \ + headers[b'proxy-authorization'][1] != self.config.auth_code: + raise ProxyAuthenticationFailed() + + def connect_upstream(self, host, port): + self.server = TcpServerConnection(host, port) + try: + logger.debug('Connecting to upstream %s:%s' % (host, port)) + self.server.connect() + logger.debug('Connected to upstream %s:%s' % (host, port)) + except Exception as e: # TimeoutError, socket.gaierror + self.server.closed = True + raise ProxyConnectionFailed(host, port, repr(e)) + + +class HttpWebServerPlugin(HttpProtocolBasePlugin): + """HttpProtocolHandler plugin which handles incoming requests to local webserver.""" + + def __init__(self, config: HttpProtocolConfig, client: TcpClientConnection, request: HttpParser): + super(HttpWebServerPlugin, self).__init__(config, client, request) + + def on_request_complete(self): + if self.request.has_upstream_server(): + return + + if self.config.pac_file and \ + self.request.url.path == self.config.pac_file_url_path: + self.serve_pac_file() + else: + self.client.queue(CRLF.join([ + b'HTTP/1.1 200 OK', + b'Server: proxy.py v%s' % version, + b'Content-Length: 42', + b'Connection: Close', + CRLF, + b'https://github.com/abhinavsingh/proxy.py' + ])) + self.client.flush() + + return True + + def access_log(self): + if self.request.has_upstream_server(): + return + logger.info('%s:%s - %s %s' % (self.client.addr[0], self.client.addr[1], + text_(self.request.method), text_(self.request.build_url()))) def serve_pac_file(self): self.client.queue(PAC_FILE_RESPONSE_PREFIX) @@ -563,129 +748,103 @@ class HttpProxy(threading.Thread): self.client.queue(self.config.pac_file) self.client.flush() - def access_log(self): - host, port = self.server.addr if self.server else (None, None) - if self.request.method == b'CONNECT': - logger.info( - '%s:%s - %s %s:%s - %s bytes' % (self.client.addr[0], self.client.addr[1], - text_(self.request.method), text_(host), - text_(port), len(self.response.raw))) - elif self.request.method: - logger.info('%s:%s - %s %s:%s%s - %s %s - %s bytes' % ( - self.client.addr[0], self.client.addr[1], text_(self.request.method), text_(host), port, - text_(self.request.build_url()), text_(self.response.code), text_(self.response.reason), - len(self.response.raw))) + +class HttpProtocolHandler(threading.Thread): + """HTTP, HTTPS, HTTP2, WebSockets protocol handler. + + Accepts `Client` connection object, manages plugin invocations. + """ + + def __init__(self, client, config=None): + super(HttpProtocolHandler, self).__init__() + + self.start_time = self.now() + self.last_activity = self.start_time + + self.client = client + self.config = config if config else HttpProtocolConfig() + self.request = HttpParser(HttpParser.types.REQUEST_PARSER) + + self.plugins: Dict[str, HttpProtocolBasePlugin] = {} + for klass in self.config.plugins: + instance = klass(self.config, self.client, self.request) + self.plugins[instance.name()] = instance + + @staticmethod + def now(): + return datetime.datetime.utcnow() + + def connection_inactive_for(self): + return (self.now() - self.last_activity).seconds + + def is_connection_inactive(self): + return self.connection_inactive_for() > 30 def run_once(self): """Returns True if proxy must teardown.""" # Prepare list of descriptors - rlist, wlist, xlist = [self.client.conn], [], [] + read_desc, write_desc, err_desc = [self.client.conn], [], [] if self.client.has_buffer(): - wlist.append(self.client.conn) - if self.server and not self.server.closed: - rlist.append(self.server.conn) - if self.server and not self.server.closed and self.server.has_buffer(): - wlist.append(self.server.conn) + write_desc.append(self.client.conn) - r, w, x = select.select(rlist, wlist, xlist, 1) + for plugin in self.plugins.values(): + plugin_read_desc, plugin_write_desc, plugin_err_desc = plugin.get_descriptors() + read_desc += plugin_read_desc + write_desc += plugin_write_desc + err_desc += plugin_err_desc - # Flush buffer from ready to write sockets - if self.client.conn in w: + readable, writable, errored = select.select(read_desc, write_desc, err_desc, 1) + + # Flush buffer for ready to write sockets + if self.client.conn in writable: logger.debug('Client is ready for writes, flushing client buffer') - self.client.flush() - if self.server and not self.server.closed and self.server.conn in w: - logger.debug('Server is ready for writes, flushing server buffer') - self.server.flush() + try: + self.client.flush() + except BrokenPipeError: + logging.error('BrokenPipeError when flushing buffer for client') + return True + + for plugin in self.plugins.values(): + plugin.flush_to_descriptors(writable) # Read from ready to read sockets - if self.client.conn in r: + if self.client.conn in readable: logger.debug('Client is ready for reads, reading') - data = self.client.recv(self.config.client_recvbuf_size) + client_data = self.client.recv(self.config.client_recvbuf_size) self.last_activity = self.now() - if not data: + if not client_data: logger.debug('Client closed connection, tearing down...') return True - try: - # Once we have connection to the server - # we don't parse the http request packets - # any further, instead just pipe incoming - # data from client to server - if self.server and not self.server.closed: - self.server.queue(data) - else: + + plugin_index = 0 + plugins = list(self.plugins.values()) + while plugin_index < len(plugins) and client_data: + client_data = plugins[plugin_index].on_client_data(client_data) + plugin_index += 1 + + if client_data: + try: # Parse http request - self.request.parse(data) + self.request.parse(client_data) if self.request.state == HttpParser.states.COMPLETE: - logger.debug('Request parser is in state complete') + for plugin in self.plugins.values(): + # TODO: Cleanup by not returning True for teardown cases + plugin_response = plugin.on_request_complete() + if type(plugin_response) is bool: + return True + except ProxyError as e: # ProxyAuthenticationFailed, ProxyConnectionFailed, ProxyRequestRejected + # logger.exception(e) + response = e.response(self.request) + if response: + self.client.queue(response) + # But is client also ready for writes? + self.client.flush() + raise e - if self.config.auth_code: - if b'proxy-authorization' not in self.request.headers or \ - self.request.headers[b'proxy-authorization'][1] != self.config.auth_code: - raise ProxyAuthenticationFailed() - - # Invoke HttpProxyPlugin.handle_request - for plugin in self.config.plugins: - self.request = plugin.handle_request(self.request) - - if self.request.method == b'CONNECT': - host, port = self.request.url.path.split(COLON) - elif self.request.url: - host, port = self.request.url.hostname, self.request.url.port \ - if self.request.url.port else 80 - else: - raise Exception('Invalid request\n%s' % self.request.raw) - - if host is None and self.config.pac_file: - self.serve_pac_file() - return True - - self.server = TcpServerConnection(host, port) - try: - logger.debug('Connecting to server %s:%s' % (host, port)) - self.server.connect() - logger.debug('Connected to server %s:%s' % (host, port)) - except Exception as e: # TimeoutError, socket.gaierror - self.server.closed = True - raise ProxyConnectionFailed(host, port, repr(e)) - - # for http connect methods (https requests) - # queue appropriate response for client - # notifying about established connection - if self.request.method == b'CONNECT': - self.client.queue(PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) - # for usual http requests, re-build request packet - # and queue for the server with appropriate headers - else: - self.server.queue(self.request.build( - del_headers=[b'proxy-authorization', b'proxy-connection', b'connection', - b'keep-alive'], - add_headers=[(b'Via', b'1.1 proxy.py v%s' % version), (b'Connection', b'Close')] - )) - except (ProxyAuthenticationFailed, ProxyConnectionFailed, ProxyRequestRejected) as e: - # logger.exception(e) - response = e.response(self.request) - if response: - self.client.queue(response) - # But is client also ready for writes? - self.client.flush() - raise e - if self.server and not self.server.closed and self.server.conn in r: - logger.debug('Server is ready for reads, reading') - data = self.server.recv(self.config.server_recvbuf_size) - self.last_activity = self.now() - if not data: - logger.debug('Server closed connection, tearing down...') + for plugin in self.plugins.values(): + teardown = plugin.read_from_descriptors(readable) + if teardown: return True - # parse incoming response packet - # only for non-https requests - if not self.request.method == b'CONNECT': - self.response.parse(data) - else: - # Only purpose of increasing memory footprint is to print response length in access log - # Not worth it? Optimize to only persist lengths? - self.response.raw += data - # queue data for client - self.client.queue(data) # Teardown if client buffer is empty and connection is inactive if self.client.buffer_size() == 0: @@ -706,16 +865,17 @@ class HttpProxy(threading.Thread): except Exception as e: logger.exception('Exception while handling connection %r with reason %r' % (self.client.conn, e)) finally: + for plugin in self.plugins.values(): + plugin.access_log() + self.client.close() - logger.debug( - 'Closed client connection with pending client buffer size %d bytes' % self.client.buffer_size()) - if self.server: - logger.debug( - 'Closed server connection with pending server buffer size %d bytes' % self.server.buffer_size()) - if not self.server.closed: - self.server.close() - self.access_log() - logger.debug('Closed proxy for connection %r at address %r' % (self.client.conn, self.client.addr)) + logger.debug('Closed client connection with pending ' + 'client buffer size %d bytes' % self.client.buffer_size()) + for plugin in self.plugins.values(): + plugin.on_client_connection_close() + + logger.debug('Closed proxy for connection %r ' + 'at address %r' % (self.client.conn, self.client.addr)) class TcpServer(object): @@ -797,7 +957,7 @@ class MultiCoreRequestDispatcher(TcpServer): self.workers.append(worker) def handle(self, client): - self.worker_queue.put((Worker.operations.HTTP_PROXY, client)) + self.worker_queue.put((Worker.operations.HTTP_PROTOCOL, client)) def shutdown(self): logger.info('Shutting down %d workers' % self.num_workers) @@ -815,7 +975,7 @@ class Worker(multiprocessing.Process): """ operations = namedtuple('WorkerOperations', ( - 'HTTP_PROXY', + 'HTTP_PROTOCOL', 'SHUTDOWN', ))(1, 2) @@ -828,8 +988,8 @@ class Worker(multiprocessing.Process): while True: try: op, payload = self.work_queue.get(True, 1) - if op == Worker.operations.HTTP_PROXY: - proxy = HttpProxy(payload, config=self.config) + if op == Worker.operations.HTTP_PROTOCOL: + proxy = HttpProtocolHandler(payload, config=self.config) proxy.setDaemon(True) proxy.start() elif op == Worker.operations.SHUTDOWN: @@ -852,6 +1012,22 @@ def set_open_file_limit(soft_limit): logger.debug('Open file descriptor soft limit set to %d' % soft_limit) +def load_plugins(plugins) -> List: + p = [] + plugins = plugins.split(',') + for plugin in plugins: + plugin = plugin.strip() + if plugin == '': + continue + logging.debug('Loading plugin %s', plugin) + module_name, klass_name = plugin.rsplit('.', 1) + module = importlib.import_module(module_name) + klass = getattr(module, klass_name) + logging.info('%s initialized' % klass.__name__) + p.append(klass) + return p + + def init_parser(): parser = argparse.ArgumentParser( description='proxy.py v%s' % __version__, @@ -882,12 +1058,14 @@ def init_parser(): parser.add_argument('--open-file-limit', type=int, default=DEFAULT_OPEN_FILE_LIMIT, help='Default: 1024. Maximum number of files (TCP connections) ' 'that proxy.py can open concurrently.') - parser.add_argument('--port', type=int, default=DEFAULT_PORT, - help='Default: 8899. Server port.') parser.add_argument('--pac-file', type=str, default=DEFAULT_PAC_FILE, help='A file (Proxy Auto Configuration) or string to serve when ' 'the server receives a direct file request.') - parser.add_argument('--plugins', type=str, default=None, help='Comma separated plugins') + parser.add_argument('--pac-file-url-path', type=str, default=DEFAULT_PAC_FILE_URL_PATH, + help='Web server path to serve the PAC file.') + parser.add_argument('--plugins', type=str, default='', help='Comma separated plugins') + parser.add_argument('--port', type=int, default=DEFAULT_PORT, + help='Default: 8899. Server port.') parser.add_argument('--server-recvbuf-size', type=int, default=DEFAULT_SERVER_RECVBUF_SIZE, help='Default: 8 KB. Maximum amount of data received from the ' 'server in a single recv() operation. Bump this ' @@ -898,18 +1076,6 @@ def init_parser(): return parser -def load_plugins(plugins): - p = [] - plugins = plugins.split(',') - for plugin in plugins: - module_name, klass_name = plugin.rsplit('.', 1) - module = importlib.import_module(module_name) - klass = getattr(module, klass_name) - logging.info('%s initialized' % klass.__name__) - p.append(klass()) - return p - - def main(): if not PY3 and not UNDER_TEST: print( @@ -924,7 +1090,6 @@ def main(): parser = init_parser() args = parser.parse_args(sys.argv[1:]) - logging.basicConfig(level=getattr(logging, { 'D': 'DEBUG', @@ -934,7 +1099,6 @@ def main(): 'C': 'CRITICAL' }[args.log_level.upper()[0]]), format=args.log_format) - plugins = load_plugins(args.plugins) if args.plugins else DEFAULT_PLUGINS try: set_open_file_limit(args.open_file_limit) @@ -943,16 +1107,18 @@ def main(): if args.basic_auth: auth_code = b'Basic %s' % base64.b64encode(bytes_(args.basic_auth)) + config = HttpProtocolConfig(auth_code=auth_code, + server_recvbuf_size=args.server_recvbuf_size, + client_recvbuf_size=args.client_recvbuf_size, + pac_file=args.pac_file, + pac_file_url_path=args.pac_file_url_path) + config.plugins = load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin,%s' % args.plugins) server = MultiCoreRequestDispatcher(hostname=args.hostname, port=args.port, backlog=args.backlog, ipv4=args.ipv4, num_workers=args.num_workers, - config=HttpProxyConfig(auth_code=auth_code, - server_recvbuf_size=args.server_recvbuf_size, - client_recvbuf_size=args.client_recvbuf_size, - pac_file=args.pac_file, - plugins=plugins)) + config=config) server.run() except KeyboardInterrupt: pass diff --git a/requirements-tests-py2.7.text b/requirements-tests-py2.7.text deleted file mode 100644 index 81ba9009..00000000 --- a/requirements-tests-py2.7.text +++ /dev/null @@ -1 +0,0 @@ -mock==3.0.5 diff --git a/tests.py b/tests.py index 5d955b45..fea8acef 100644 --- a/tests.py +++ b/tests.py @@ -3,9 +3,9 @@ proxy.py ~~~~~~~~ - HTTP Proxy Server in Python. + HTTP, HTTPS, HTTP2 and WebSockets Proxy Server in Python. - :copyright: (c) 2013-2018 by Abhinav Singh. + :copyright: (c) 2013-2020 by Abhinav Singh. :license: BSD, see LICENSE for more details. """ import os @@ -21,8 +21,8 @@ from threading import Thread import proxy -# logging.basicConfig(level=logging.DEBUG, -# format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s') +logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s') if sys.version_info[0] == 3: # Python3 specific imports from http.server import HTTPServer, BaseHTTPRequestHandler @@ -123,7 +123,7 @@ class TestMultiCoreRequestDispatcher(unittest.TestCase): tcp_server = None tcp_thread = None - @mock.patch.object(proxy, 'HttpProxy', side_effect=mock_tcp_proxy_side_effect) + @mock.patch.object(proxy, 'HttpProtocolHandler', side_effect=mock_tcp_proxy_side_effect) def testHttpProxyConnection(self, mock_tcp_proxy): try: self.tcp_port = get_available_port() @@ -523,6 +523,7 @@ class TestProxy(unittest.TestCase): http_server = None http_server_port = None http_server_thread = None + config = None @classmethod def setUpClass(cls): @@ -531,6 +532,8 @@ class TestProxy(unittest.TestCase): cls.http_server_thread = Thread(target=cls.http_server.serve_forever) cls.http_server_thread.setDaemon(True) cls.http_server_thread.start() + cls.config = proxy.HttpProtocolConfig() + cls.config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') @classmethod def tearDownClass(cls): @@ -541,7 +544,7 @@ class TestProxy(unittest.TestCase): def setUp(self): self._conn = MockTcpConnection() self._addr = ('127.0.0.1', 54382) - self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr)) + self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), config=self.config) @mock.patch('select.select') @mock.patch('proxy.TcpServerConnection') @@ -610,7 +613,7 @@ class TestProxy(unittest.TestCase): proxy.CRLF ])) self.proxy.run_once() - self.assertFalse(self.proxy.server is None) + self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None) self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) mock_server_connection.assert_called_once() server.connect.assert_called_once() @@ -659,9 +662,10 @@ class TestProxy(unittest.TestCase): @mock.patch('select.select') def test_proxy_authentication_failed(self, mock_select): mock_select.return_value = ([self._conn], [], []) - self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr), - config=proxy.HttpProxyConfig( - auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))) + config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass')) + config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') + self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), + config=config) self.proxy.client.conn.queue(proxy.CRLF.join([ b'GET http://abhinavsingh.com HTTP/1.1', b'Host: abhinavsingh.com', @@ -673,14 +677,15 @@ class TestProxy(unittest.TestCase): @mock.patch('select.select') @mock.patch('proxy.TcpServerConnection') def test_authenticated_proxy_http_get(self, mock_server_connection, mock_select): - client = proxy.TcpClientConnection(self._conn, self._addr) - config = proxy.HttpProxyConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass')) - mock_select.return_value = ([self._conn], [], []) server = mock_server_connection.return_value server.connect.return_value = True - self.proxy = proxy.HttpProxy(client, config=config) + client = proxy.TcpClientConnection(self._conn, self._addr) + config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass')) + config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') + + self.proxy = proxy.HttpProtocolHandler(client, config=config) self.proxy.client.conn.queue(b'GET http://localhost:%d HTTP/1.1' % self.http_server_port) self.proxy.run_once() self.assertEqual(self.proxy.request.state, proxy.HttpParser.states.INITIALIZED) @@ -731,9 +736,10 @@ class TestProxy(unittest.TestCase): server.connect.return_value = True mock_select.side_effect = [([self._conn], [], []), ([self._conn], [], []), ([], [server.conn], [])] - self.proxy = proxy.HttpProxy(proxy.TcpClientConnection(self._conn, self._addr), - config=proxy.HttpProxyConfig( - auth_code=b'Basic %s' % base64.b64encode(b'user:pass'))) + config = proxy.HttpProtocolConfig(auth_code=b'Basic %s' % base64.b64encode(b'user:pass')) + config.plugins = proxy.load_plugins('proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') + self.proxy = proxy.HttpProtocolHandler(proxy.TcpClientConnection(self._conn, self._addr), + config=config) self.proxy.client.conn.queue(proxy.CRLF.join([ b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port, b'Host: localhost:%d' % self.http_server_port, @@ -743,7 +749,7 @@ class TestProxy(unittest.TestCase): proxy.CRLF ])) self.proxy.run_once() - self.assertFalse(self.proxy.server is None) + self.assertFalse(self.proxy.plugins['HttpProxyPlugin'].server is None) self.assertEqual(self.proxy.client.buffer, proxy.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT) mock_server_connection.assert_called_once() server.connect.assert_called_once() @@ -780,15 +786,15 @@ class TestWorker(unittest.TestCase): self.queue = multiprocessing.Queue() self.worker = proxy.Worker(self.queue) - @mock.patch('proxy.HttpProxy') + @mock.patch('proxy.HttpProtocolHandler') def test_shutdown_op(self, mock_http_proxy): self.queue.put((proxy.Worker.operations.SHUTDOWN, None)) self.worker.run() # Worker should consume the prior shutdown operation self.assertFalse(mock_http_proxy.called) - @mock.patch('proxy.HttpProxy') + @mock.patch('proxy.HttpProtocolHandler') def test_spawns_http_proxy_threads(self, mock_http_proxy): - self.queue.put((proxy.Worker.operations.HTTP_PROXY, None)) + self.queue.put((proxy.Worker.operations.HTTP_PROTOCOL, None)) self.queue.put((proxy.Worker.operations.SHUTDOWN, None)) self.worker.run() self.assertTrue(mock_http_proxy.called)